[58db57]: / ddc_pub / custom_callbacks.py

Download this file

179 lines (152 with data), 6.2 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
import pickle
import shutil
import tempfile
import warnings

import numpy as np

from tensorflow.keras.callbacks import ModelCheckpoint
class LearningRateSchedule:
    """
    Custom learning-rate schedules for training.

    Holds the schedule boundaries (start epoch, final epoch) and the
    initial/final learning rates; `exp_decay` maps an epoch to a rate.
    """

    def __init__(self, epoch_to_start=500, last_epoch=999, lr_init=1e-3, lr_final=1e-6):
        # epoch_to_start: first epoch at which decay kicks in
        # last_epoch:     epoch at which the rate reaches lr_final
        self.epoch_to_start = epoch_to_start
        self.last_epoch = last_epoch
        self.lr_init = lr_init
        self.lr_final = lr_final

    def exp_decay(self, epoch, lr):
        """
        Exponential decay from lr_init down to lr_final.

        Before `epoch_to_start` the incoming `lr` is returned unchanged.
        From `epoch_to_start` onwards the rate follows
        lr_init * exp(rate * (epoch - epoch_to_start)), where `rate` is
        chosen so the value equals lr_final exactly at `last_epoch`.
        """
        if epoch < self.epoch_to_start:
            # Warm-up / hold phase: keep the current learning rate.
            return lr
        span = self.last_epoch - self.epoch_to_start
        # Negative rate (lr_final < lr_init), sign folded into the log.
        rate = np.log(self.lr_final / self.lr_init) / span
        return self.lr_init * np.exp(rate * (epoch - self.epoch_to_start))
class ModelAndHistoryCheckpoint(ModelCheckpoint):
    """
    Callback to save all sub-models and training history.

    Extends Keras' ModelCheckpoint: every `period` epochs it either saves
    the full model's weights (`save_weights_only=True`) or zips all DDC
    sub-models plus picklable metadata via `save_models`. Per-epoch logs
    are accumulated into `self.history`.

    Parameters
    ----------
    filepath : str
        Checkpoint path template (may contain `{epoch}` / log keys).
    model_dict : dict
        The DDC model's `__dict__`; holds the Keras sub-models under
        name-mangled keys such as "_DDC__batch_model".
    monitor, verbose, save_best_only, save_weights_only, mode, period :
        Same meaning as in keras.callbacks.ModelCheckpoint.
    history : dict, optional
        Pre-existing history to append to; a fresh dict if omitted.
    """

    def __init__(
        self,
        filepath,
        model_dict,
        monitor="val_loss",
        verbose=0,
        save_best_only=False,
        save_weights_only=False,
        mode="auto",
        period=1,
        history=None,
    ):
        # BUG FIX: forward the caller's arguments; the original passed
        # hard-coded defaults so monitor/verbose/etc. were silently ignored.
        super().__init__(
            filepath,
            monitor=monitor,
            verbose=verbose,
            save_best_only=save_best_only,
            save_weights_only=save_weights_only,
            mode=mode,
            period=period,
        )
        self.period = period
        self.model_dict = model_dict
        # BUG FIX: `history={}` as a default argument would be shared
        # across all instances; create a fresh dict per instance instead.
        self.history = {} if history is None else history

    def save_models(self, filepath):
        """
        Save the Keras sub-models and picklable metadata in a zip file.

        Works inside a temporary directory; the directory (and its
        contents) is removed automatically when the `with` block exits.
        """
        with tempfile.TemporaryDirectory() as dirpath:
            # Save the Keras models
            try:  # Original DDC: three sub-models present in model_dict
                if self.model_dict["_DDC__mol_to_latent_model"] is not None:
                    self.model_dict["_DDC__mol_to_latent_model"].save(
                        dirpath + "/mol_to_latent_model.h5"
                    )
                self.model_dict["_DDC__latent_to_states_model"].save(
                    dirpath + "/latent_to_states_model.h5"
                )
                self.model_dict["_DDC__batch_model"].save(dirpath + "/batch_model.h5")
            except KeyError:  # Unbiased DDC: only a single combined model
                # BUG FIX: was a bare `except:` that hid unrelated errors
                # (e.g. disk-full) behind the fallback save path.
                self.model_dict["_DDC__model"].save(dirpath + "/model.h5")
            # Exclude un-picklable and un-wanted attributes
            excl_attr = [
                "_DDC__mode",  # mode is excluded because it is defined inside __init__
                "_DDC__train_gen",
                "_DDC__valid_gen",
                "_DDC__mol_to_latent_model",
                "_DDC__latent_to_states_model",
                "_DDC__batch_model",
                "_DDC__sample_model",
                "_DDC__multi_sample_model",
                "_DDC__model",
            ]
            # Cannot deepcopy self.__dict__ because of Keras' thread lock so this is
            # bypassed by popping, saving and re-inserting the un-picklable attributes
            to_add = {attr: self.model_dict.pop(attr, None) for attr in excl_attr}
            try:
                # Pickle metadata (BUG FIX: close the file via `with`;
                # the original leaked the handle from a bare open()).
                with open(dirpath + "/metadata.pickle", "wb") as fh:
                    pickle.dump(self.model_dict, fh)
                # Zip directory with its contents
                shutil.make_archive(filepath, "zip", dirpath)
            finally:
                # BUG FIX: re-insert the popped attributes even if pickling
                # or archiving raised, so model_dict is never left gutted.
                self.model_dict.update(to_add)
        print("Model saved in %s." % filepath)

    def on_train_begin(self, logs=None):
        # Track which epochs this run has seen.
        self.epoch = []

    def on_epoch_end(self, epoch, logs=None):
        """Record the epoch's logs and save the model(s) every `period` epochs."""
        # Save training history
        logs = logs or {}
        self.epoch.append(epoch)
        for k, v in logs.items():
            self.history.setdefault(k, []).append(v)
        # Save model(s)
        self.epochs_since_last_save += 1
        if self.epochs_since_last_save < self.period:
            return
        self.epochs_since_last_save = 0
        filepath = self.filepath.format(epoch=epoch + 1, **logs)
        if self.save_best_only:
            current = logs.get(self.monitor)
            if current is None:
                warnings.warn(
                    "Can save best model only with %s available, "
                    "skipping." % (self.monitor),
                    RuntimeWarning,
                )
            elif self.monitor_op(current, self.best):
                if self.verbose > 0:
                    print(
                        "\nEpoch %05d: %s improved from %0.5f to %0.5f,"
                        " saving model to %s"
                        % (
                            epoch + 1,
                            self.monitor,
                            self.best,
                            current,
                            filepath,
                        )
                    )
                self.best = current
                if self.save_weights_only:
                    self.model.save_weights(filepath, overwrite=True)
                else:
                    self.model.save(filepath, overwrite=True)
            else:
                if self.verbose > 0:
                    print(
                        "\nEpoch %05d: %s did not improve from %0.5f"
                        % (epoch + 1, self.monitor, self.best)
                    )
        else:
            if self.verbose > 0:
                print("\nEpoch %05d: saving model to %s" % (epoch + 1, filepath))
            if self.save_weights_only:
                print("Saving weights of full model, ONLY.")
                self.model.save_weights(filepath, overwrite=True)
            else:
                self.save_models(filepath)