callbacks/ensemble.py

# -*- coding: utf-8 -*-
import os

from keras.callbacks import Callback


class SWA(Callback):
    """
    This callback implements stochastic weight averaging (SWA) with a constant
    learning rate, as presented in the paper
    "Izmailov et al. Averaging Weights Leads to Wider Optima and Better Generalization"
    (https://arxiv.org/abs/1803.05407)
    Authors' implementation: https://github.com/timgaripov/swa
    """
    def __init__(self, swa_model, checkpoint_dir, model_name, swa_start=1):
        """
        :param swa_model: the model used to store the running average of the
                          weights once SWA begins
        :param checkpoint_dir: the directory where the model will be saved
        :param model_name: the name of the model being trained
        :param swa_start: the epoch at which averaging begins. The network is
                          generally pre-trained for a number of epochs first
                          (swa_start > 1), rather than tracking the average
                          from the very beginning.
        """
        super(SWA, self).__init__()
        self.checkpoint_dir = checkpoint_dir
        self.model_name = model_name
        self.swa_start = swa_start
        self.swa_model = swa_model
    def on_train_begin(self, logs=None):
        self.epoch = 0
        self.swa_n = 0
        # Note: deep-copying a model with custom layers raised errors.
        # self.swa_model = copy.deepcopy(self.model)  # make a copy of the model we're training
        # Note: something weird still happened even when using keras.models.clone_model,
        # so swa_model is built outside this callback and passed in as an argument.
        # It's not fancy, but the best I can do :)
        # self.swa_model = keras.models.clone_model(self.model)
        # see: https://github.com/keras-team/keras/issues/1765
        self.swa_model.set_weights(self.model.get_weights())
    def on_epoch_end(self, epoch, logs=None):
        if (self.epoch + 1) >= self.swa_start:
            self.update_average_model()
            self.swa_n += 1
        self.epoch += 1
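    # The update below maintains a running arithmetic mean: with
    # alpha = 1 / (n + 1), the step
    #     w_swa <- (1 - alpha) * w_swa + alpha * w
    # is algebraically identical to w_swa <- (n * w_swa + w) / (n + 1),
    # i.e. the mean of the n + 1 weight snapshots collected so far.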
    def update_average_model(self):
        # update the running average of the parameters
        alpha = 1. / (self.swa_n + 1)
        for layer, swa_layer in zip(self.model.layers, self.swa_model.layers):
            weights = []
            for w1, w2 in zip(swa_layer.get_weights(), layer.get_weights()):
                weights.append((1 - alpha) * w1 + alpha * w2)
            swa_layer.set_weights(weights)
    def on_train_end(self, logs=None):
        print(f'Logging Info - Saving SWA model checkpoint: {self.model_name}_swa.hdf5')
        self.swa_model.save_weights(os.path.join(self.checkpoint_dir,
                                                 f'{self.model_name}_swa.hdf5'))
        print('Logging Info - SWA model saved')
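

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustration only, not part of the original module).
# The SWA model must be a second, architecturally identical model built
# outside the callback (see the note in on_train_begin). The tiny Sequential
# network and the random data below are hypothetical stand-ins for whatever
# model and dataset you actually train on.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import numpy as np
    from keras.models import Sequential
    from keras.layers import Dense

    def build_model():
        model = Sequential([Dense(8, activation='relu', input_shape=(4,)),
                            Dense(1)])
        model.compile(optimizer='sgd', loss='mse')
        return model

    model = build_model()
    swa_model = build_model()  # same architecture; weights overwritten in on_train_begin
    swa = SWA(swa_model, checkpoint_dir='.', model_name='demo', swa_start=3)

    x = np.random.rand(64, 4)
    y = np.random.rand(64, 1)
    model.fit(x, y, epochs=5, callbacks=[swa], verbose=0)
    # swa_model now holds the average of the weights from epochs 3-5, and a
    # copy has been saved to ./demo_swa.hdf5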