--- a +++ b/ddc_pub/custom_callbacks.py @@ -0,0 +1,178 @@ +from tensorflow.keras.callbacks import ModelCheckpoint +import tempfile, pickle, shutil +import numpy as np + + +class LearningRateSchedule: + """ + Class for custom learning rate schedules. + """ + + def __init__(self, epoch_to_start=500, last_epoch=999, lr_init=1e-3, lr_final=1e-6): + self.epoch_to_start = epoch_to_start + self.last_epoch = last_epoch + self.lr_init = lr_init + self.lr_final = lr_final + + def exp_decay(self, epoch, lr): + """ + Exponential decay. + """ + + decay_duration = self.last_epoch - self.epoch_to_start + if epoch < self.epoch_to_start: + return lr + else: + # Slope of the decay + k = -(1 / decay_duration) * np.log(self.lr_final / self.lr_init) + + lr = self.lr_init * np.exp(-k * (epoch - self.epoch_to_start)) + return lr + + +class ModelAndHistoryCheckpoint(ModelCheckpoint): + """ + Callback to save all sub-models and training history. + """ + + def __init__( + self, + filepath, + model_dict, + monitor="val_loss", + verbose=0, + save_best_only=False, + save_weights_only=False, + mode="auto", + period=1, + history={}, + ): + + super().__init__( + filepath, + monitor="val_loss", + verbose=0, + save_best_only=False, + save_weights_only=False, + mode="auto", + period=1, + ) + + self.period = period + self.model_dict = model_dict + + self.history = history + + def save_models(self, filepath): + """ + Save everything in a zip file. + """ + + with tempfile.TemporaryDirectory() as dirpath: + + # Save the Keras models + try: # Original DDC + if self.model_dict["_DDC__mol_to_latent_model"] is not None: + self.model_dict["_DDC__mol_to_latent_model"].save( + dirpath + "/mol_to_latent_model.h5" + ) + + self.model_dict["_DDC__latent_to_states_model"].save( + dirpath + "/latent_to_states_model.h5" + ) + self.model_dict["_DDC__batch_model"].save(dirpath + "/batch_model.h5") + + except: # Unbiased DDC + self.model_dict["_DDC__model"].save(dirpath + "/model.h5") + print(dirpath) + + # Exclude un-picklable and un-wanted attributes + excl_attr = [ + "_DDC__mode", # mode is excluded because it is defined inside __init__ + "_DDC__train_gen", + "_DDC__valid_gen", + "_DDC__mol_to_latent_model", + "_DDC__latent_to_states_model", + "_DDC__batch_model", + "_DDC__sample_model", + "_DDC__multi_sample_model", + "_DDC__model", + ] + + # Cannot deepcopy self.__dict__ because of Keras' thread lock so this is + # bypassed by popping, saving and re-inserting the un-picklable attributes + to_add = {} + # Remove un-picklable attributes + for attr in excl_attr: + to_add[attr] = self.model_dict.pop(attr, None) + + # Pickle metadata + pickle.dump(self.model_dict, open(dirpath + "/metadata.pickle", "wb")) + + # Zip directory with its contents + shutil.make_archive(filepath, "zip", dirpath) + + # Finally, re-load the popped elements + for attr in excl_attr: + # if attr == "_DDC__mol_to_latent_model" and to_add[attr] is None: + # continue + self.model_dict[attr] = to_add[attr] + + print("Model saved in %s." % filepath) + + def on_train_begin(self, logs=None): + self.epoch = [] + + def on_epoch_end(self, epoch, logs=None): + # Save training history + logs = logs or {} + self.epoch.append(epoch) + for k, v in logs.items(): + self.history.setdefault(k, []).append(v) + + # Save model(s) + self.epochs_since_last_save += 1 + if self.epochs_since_last_save >= self.period: + self.epochs_since_last_save = 0 + filepath = self.filepath.format(epoch=epoch + 1, **logs) + if self.save_best_only: + current = logs.get(self.monitor) + if current is None: + warnings.warn( + "Can save best model only with %s available, " + "skipping." % (self.monitor), + RuntimeWarning, + ) + else: + if self.monitor_op(current, self.best): + if self.verbose > 0: + print( + "\nEpoch %05d: %s improved from %0.5f to %0.5f," + " saving model to %s" + % ( + epoch + 1, + self.monitor, + self.best, + current, + filepath, + ) + ) + self.best = current + if self.save_weights_only: + self.model.save_weights(filepath, overwrite=True) + else: + self.model.save(filepath, overwrite=True) + else: + if self.verbose > 0: + print( + "\nEpoch %05d: %s did not improve from %0.5f" + % (epoch + 1, self.monitor, self.best) + ) + else: + if self.verbose > 0: + print("\nEpoch %05d: saving model to %s" % (epoch + 1, filepath)) + if self.save_weights_only: + print("Saving weights of full model, ONLY.") + self.model.save_weights(filepath, overwrite=True) + else: + self.save_models(filepath)