Diff of /callbacks/ensemble.py [000000] .. [c0da92]

Switch to unified view

--- a/callbacks/ensemble.py
+++ b/callbacks/ensemble.py
1
# -*- coding: utf-8 -*-
2
3
import os
4
from keras.callbacks import Callback
5
6
7
class SWA(Callback):
    """Stochastic Weight Averaging (SWA) with a constant learning rate.

    Maintains a running average of the training model's weights in a
    separate ``swa_model``, as presented in:
        "Izmailov et al. Averaging Weights Leads to Wider Optima and
        Better Generalization" (https://arxiv.org/abs/1803.05407)
    Author's implementation: https://github.com/timgaripov/swa
    """

    def __init__(self, swa_model, checkpoint_dir, model_name, swa_start=1):
        """
        :param swa_model: the model that stores the running average of the
                          weights once SWA begins. It must share the layer
                          structure of the model being trained.
                          NOTE: it is passed in pre-built rather than created
                          here, because copy.deepcopy fails on models with
                          custom layers and keras.models.clone_model also
                          misbehaves (see
                          https://github.com/keras-team/keras/issues/1765).
        :param checkpoint_dir: the directory where the SWA weights will be
                               saved.
        :param model_name: the name of the model we're training; used to
                           build the checkpoint filename.
        :param swa_start: the 1-based epoch when averaging begins. We
                          generally pre-train the network for a certain
                          number of epochs first (swa_start > 1), as opposed
                          to tracking the average from the very beginning.
        """
        super(SWA, self).__init__()
        self.checkpoint_dir = checkpoint_dir
        self.model_name = model_name
        self.swa_start = swa_start
        self.swa_model = swa_model
        # Number of weight snapshots folded into the average so far.
        # Initialized here as well as in on_train_begin so the object is
        # always in a consistent state.
        self.swa_n = 0

    def on_train_begin(self, logs=None):
        # Reset the snapshot count and seed the SWA model with the current
        # training weights. If training ends before `swa_start` is reached,
        # `swa_model` simply keeps these initial weights.
        self.swa_n = 0
        self.swa_model.set_weights(self.model.get_weights())

    def on_epoch_end(self, epoch, logs=None):
        # Use the 0-based epoch number Keras passes in (robust to resumed
        # training with initial_epoch > 0, unlike a hand-kept counter).
        # Averaging starts once the 1-based epoch reaches `swa_start`.
        if (epoch + 1) >= self.swa_start:
            self.update_average_model()
            self.swa_n += 1

    def update_average_model(self):
        # Incremental running mean: after n snapshots, fold in the current
        # weights with alpha = 1 / (n + 1) so every snapshot contributes
        # equally to the average.
        alpha = 1. / (self.swa_n + 1)
        for layer, swa_layer in zip(self.model.layers, self.swa_model.layers):
            averaged = [
                (1 - alpha) * w_swa + alpha * w_cur
                for w_swa, w_cur in zip(swa_layer.get_weights(),
                                        layer.get_weights())
            ]
            swa_layer.set_weights(averaged)

    def on_train_end(self, logs=None):
        # Persist only the averaged weights; the architecture is assumed to
        # be reconstructable by the caller.
        print(f'Logging Info - Saving SWA model checkpoint: {self.model_name}_swa.hdf5')
        self.swa_model.save_weights(os.path.join(self.checkpoint_dir,
                                                 f'{self.model_name}_swa.hdf5'))
        print('Logging Info - SWA model Saved')