b/callbacks/ensemble.py
# -*- coding: utf-8 -*-

import os

from keras.callbacks import Callback


class SWA(Callback):
    """
    This callback implements stochastic weight averaging (SWA) with a constant
    learning rate, as presented in the paper
    "Izmailov et al. Averaging Weights Leads to Wider Optima and Better Generalization"
    (https://arxiv.org/abs/1803.05407).
    Authors' implementation: https://github.com/timgaripov/swa
    """
    def __init__(self, swa_model, checkpoint_dir, model_name, swa_start=1):
        """
        :param swa_model: the model used to store the running average of the
                          weights once SWA begins
        :param checkpoint_dir: the directory the model will be saved in
        :param model_name: the name of the model being trained
        :param swa_start: the epoch at which averaging begins. The network is
                          generally pre-trained for a number of epochs first
                          (swa_start > 1), rather than tracking the average
                          from the very beginning.
        """
        super(SWA, self).__init__()
        self.checkpoint_dir = checkpoint_dir
        self.model_name = model_name
        self.swa_start = swa_start
        self.swa_model = swa_model

    def on_train_begin(self, logs=None):
        self.epoch = 0
        self.swa_n = 0

        # Note: deep-copying a model that contains a custom layer raises errors:
        # self.swa_model = copy.deepcopy(self.model)  # make a copy of the model we're training

        # Note: something weird still happens even with the keras.models.clone_model
        # method (see https://github.com/keras-team/keras/issues/1765), so swa_model
        # is built outside this callback and passed in as an argument instead.
        # It's not fancy, but the best I can do :)
        # self.swa_model = keras.models.clone_model(self.model)
        self.swa_model.set_weights(self.model.get_weights())

    def on_epoch_end(self, epoch, logs=None):
        if (self.epoch + 1) >= self.swa_start:
            self.update_average_model()
            self.swa_n += 1

        self.epoch += 1

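    # update_average_model() below applies the incremental form of the
    # arithmetic mean: with alpha = 1 / (swa_n + 1), the update
    #   w_swa <- (1 - alpha) * w_swa + alpha * w
    # keeps w_swa equal to the average of the swa_n + 1 weight snapshots
    # collected since swa_start (the first update, with alpha = 1, simply
    # copies the current weights).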
    def update_average_model(self):
        # update running average of parameters
        alpha = 1. / (self.swa_n + 1)
        for layer, swa_layer in zip(self.model.layers, self.swa_model.layers):
            weights = []
            for w1, w2 in zip(swa_layer.get_weights(), layer.get_weights()):
                weights.append((1 - alpha) * w1 + alpha * w2)
            swa_layer.set_weights(weights)

    def on_train_end(self, logs=None):
        print(f'Logging Info - Saving SWA model checkpoint: {self.model_name}_swa.hdf5')
        self.swa_model.save_weights(os.path.join(self.checkpoint_dir,
                                                 f'{self.model_name}_swa.hdf5'))
        print('Logging Info - SWA model saved')
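

if __name__ == '__main__':
    # Minimal usage sketch (illustrative only, not part of the original file):
    # SWA needs a second model with the same architecture, built outside the
    # callback (see the clone_model note in on_train_begin). The tiny model,
    # data, and names below are made up purely for demonstration.
    import numpy as np
    from keras.models import Sequential
    from keras.layers import Dense

    def build_model():
        model = Sequential([Dense(8, activation='relu', input_shape=(4,)),
                            Dense(1)])
        model.compile(optimizer='sgd', loss='mse')
        return model

    x = np.random.rand(64, 4)
    y = np.random.rand(64, 1)

    model = build_model()      # the model being trained
    swa_model = build_model()  # same architecture; holds the running average
    swa = SWA(swa_model, checkpoint_dir='.', model_name='demo', swa_start=3)
    model.fit(x, y, epochs=10, verbose=0, callbacks=[swa])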