--- /dev/null
+++ b/ddc_pub/custom_callbacks.py
@@ -0,0 +1,178 @@
+from tensorflow.keras.callbacks import ModelCheckpoint
+import tempfile, pickle, shutil, warnings
+import numpy as np
+
+
+class LearningRateSchedule:
+    """
+    Custom learning rate schedules, written to be plugged into
+    ``tf.keras.callbacks.LearningRateScheduler``.
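+
+    Example (a minimal sketch; assumes ``tensorflow`` is imported as ``tf``
+    and ``model`` is a compiled Keras model):
+
+        schedule = LearningRateSchedule(epoch_to_start=500, last_epoch=999)
+        lr_cb = tf.keras.callbacks.LearningRateScheduler(schedule.exp_decay)
+        model.fit(x, y, epochs=1000, callbacks=[lr_cb])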
+    """
+
+    def __init__(self, epoch_to_start=500, last_epoch=999, lr_init=1e-3, lr_final=1e-6):
+        self.epoch_to_start = epoch_to_start
+        self.last_epoch = last_epoch
+        self.lr_init = lr_init
+        self.lr_final = lr_final
+
+    def exp_decay(self, epoch, lr):
+        """
+        Exponential decay.
+        """
+
+        decay_duration = self.last_epoch - self.epoch_to_start
+        if epoch < self.epoch_to_start:
+            return lr
+        else:
+            # Decay rate k chosen so that lr(last_epoch) == lr_final
+            k = (1 / decay_duration) * np.log(self.lr_init / self.lr_final)
+
+            lr = self.lr_init * np.exp(-k * (epoch - self.epoch_to_start))
+            return lr
+
+
+class ModelAndHistoryCheckpoint(ModelCheckpoint):
+    """
+    Callback to save all sub-models and training history.
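+
+    Example (a minimal sketch; ``ddc_model`` is assumed to be a DDC instance
+    whose ``__dict__`` holds the sub-models under the attribute names used in
+    ``save_models`` below):
+
+        checkpoint = ModelAndHistoryCheckpoint(
+            "checkpoints/epoch_{epoch:03d}",
+            model_dict=ddc_model.__dict__,
+            period=10,
+        )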
+    """
+
+    def __init__(
+        self,
+        filepath,
+        model_dict,
+        monitor="val_loss",
+        verbose=0,
+        save_best_only=False,
+        save_weights_only=False,
+        mode="auto",
+        period=1,
+        history=None,
+    ):
+
+        # Pass the caller's settings through to ModelCheckpoint
+        super().__init__(
+            filepath,
+            monitor=monitor,
+            verbose=verbose,
+            save_best_only=save_best_only,
+            save_weights_only=save_weights_only,
+            mode=mode,
+            period=period,
+        )
+
+        self.period = period
+        self.model_dict = model_dict
+
+        # A mutable default dict would be shared across instances
+        self.history = {} if history is None else history
+
+    def save_models(self, filepath):
+        """
+        Save everything in a zip file.
+        """
+
+        with tempfile.TemporaryDirectory() as dirpath:
+
+            # Save the Keras models
+            try:  # Original DDC
+                if self.model_dict["_DDC__mol_to_latent_model"] is not None:
+                    self.model_dict["_DDC__mol_to_latent_model"].save(
+                        dirpath + "/mol_to_latent_model.h5"
+                    )
+
+                self.model_dict["_DDC__latent_to_states_model"].save(
+                    dirpath + "/latent_to_states_model.h5"
+                )
+                self.model_dict["_DDC__batch_model"].save(dirpath + "/batch_model.h5")
+
+            except KeyError:  # Unbiased DDC: only a single full model exists
+                self.model_dict["_DDC__model"].save(dirpath + "/model.h5")
+
+            # Exclude un-picklable and unwanted attributes
+            excl_attr = [
+                "_DDC__mode",  # mode is excluded because it is defined inside __init__
+                "_DDC__train_gen",
+                "_DDC__valid_gen",
+                "_DDC__mol_to_latent_model",
+                "_DDC__latent_to_states_model",
+                "_DDC__batch_model",
+                "_DDC__sample_model",
+                "_DDC__multi_sample_model",
+                "_DDC__model",
+            ]
+
+            # model_dict cannot be deep-copied because of Keras' thread locks,
+            # so the un-picklable attributes are popped, the rest is pickled,
+            # and the popped attributes are re-inserted afterwards
+            to_add = {}
+            # Remove un-picklable attributes
+            for attr in excl_attr:
+                to_add[attr] = self.model_dict.pop(attr, None)
+
+            # Pickle metadata, closing the file handle when done
+            with open(dirpath + "/metadata.pickle", "wb") as f:
+                pickle.dump(self.model_dict, f)
+
+            # Zip directory with its contents
+            shutil.make_archive(filepath, "zip", dirpath)
+
+            # Finally, re-insert the popped attributes
+            for attr in excl_attr:
+                self.model_dict[attr] = to_add[attr]
+
+            print("Model saved in %s.zip." % filepath)
+
+    def on_train_begin(self, logs=None):
+        self.epoch = []
+
+    def on_epoch_end(self, epoch, logs=None):
+        # Save training history
+        logs = logs or {}
+        self.epoch.append(epoch)
+        for k, v in logs.items():
+            self.history.setdefault(k, []).append(v)
+
+        # Save model(s)
+        self.epochs_since_last_save += 1
+        if self.epochs_since_last_save >= self.period:
+            self.epochs_since_last_save = 0
+            filepath = self.filepath.format(epoch=epoch + 1, **logs)
+            if self.save_best_only:
+                current = logs.get(self.monitor)
+                if current is None:
+                    warnings.warn(
+                        "Can save best model only with %s available, "
+                        "skipping." % (self.monitor),
+                        RuntimeWarning,
+                    )
+                else:
+                    if self.monitor_op(current, self.best):
+                        if self.verbose > 0:
+                            print(
+                                "\nEpoch %05d: %s improved from %0.5f to %0.5f,"
+                                " saving model to %s"
+                                % (
+                                    epoch + 1,
+                                    self.monitor,
+                                    self.best,
+                                    current,
+                                    filepath,
+                                )
+                            )
+                        self.best = current
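+                        # NOTE: the best-only path saves only the attached
+                        # Keras model, not the full sub-model archive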
+                        if self.save_weights_only:
+                            self.model.save_weights(filepath, overwrite=True)
+                        else:
+                            self.model.save(filepath, overwrite=True)
+                    else:
+                        if self.verbose > 0:
+                            print(
+                                "\nEpoch %05d: %s did not improve from %0.5f"
+                                % (epoch + 1, self.monitor, self.best)
+                            )
+            else:
+                if self.verbose > 0:
+                    print("\nEpoch %05d: saving model to %s" % (epoch + 1, filepath))
+                if self.save_weights_only:
+                    print("Saving weights of full model, ONLY.")
+                    self.model.save_weights(filepath, overwrite=True)
+                else:
+                    self.save_models(filepath)
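+
+
+# A minimal sketch of restoring an archive written by save_models(); the
+# archive name is illustrative and the single-model layout corresponds to
+# the "unbiased DDC" branch above (for the original DDC, load the three
+# sub-model .h5 files instead):
+#
+#     import zipfile, pickle, tempfile
+#     from tensorflow.keras.models import load_model
+#
+#     with tempfile.TemporaryDirectory() as d:
+#         zipfile.ZipFile("checkpoint.zip").extractall(d)
+#         with open(d + "/metadata.pickle", "rb") as f:
+#             metadata = pickle.load(f)
+#         model = load_model(d + "/model.h5")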