--- /dev/null
+++ b/callbacks/ensemble.py
@@ -0,0 +1,82 @@
+# -*- coding: utf-8 -*-
+
+import os
+from keras.callbacks import Callback
+
+
+class SWA(Callback):
+    """
+    This callback implements stochastic weight averaging (SWA) with a constant
+    learning rate, as presented in the paper:
+        "Izmailov et al. Averaging Weights Leads to Wider Optima and Better Generalization"
+        (https://arxiv.org/abs/1803.05407)
+    Authors' implementation: https://github.com/timgaripov/swa
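+
+    Example (a minimal sketch; ``build_model``, ``x_train`` and ``y_train``
+    are hypothetical placeholders; the two models only need to share the
+    same architecture):
+
+        model = build_model()
+        swa_model = build_model()  # holds the running average of the weights
+        swa = SWA(swa_model, checkpoint_dir='checkpoints',
+                  model_name='my_model', swa_start=5)
+        model.fit(x_train, y_train, epochs=20, callbacks=[swa])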
+    """
+    def __init__(self, swa_model, checkpoint_dir, model_name, swa_start=1):
+        """
+        :param swa_model: the model used to store the running average of the
+                          weights once SWA begins
+        :param checkpoint_dir: the directory where the model will be saved
+        :param model_name: the name of the model being trained
+        :param swa_start: the epoch at which averaging begins. The network is
+                          usually pre-trained for a number of epochs first
+                          (swa_start > 1), rather than tracking the average
+                          from the very beginning.
+        """
+        super().__init__()
+        self.checkpoint_dir = checkpoint_dir
+        self.model_name = model_name
+        self.swa_start = swa_start
+        self.swa_model = swa_model
+
+    def on_train_begin(self, logs=None):
+        self.swa_n = 0  # number of snapshots averaged into swa_model so far
+
+        # Note: deep-copying a model that contains custom layers raises errors:
+        # self.swa_model = copy.deepcopy(self.model)
+        # keras.models.clone_model misbehaves with custom layers as well
+        # (see https://github.com/keras-team/keras/issues/1765), so swa_model
+        # is built outside this callback and passed in as an argument.
+        # self.swa_model = keras.models.clone_model(self.model)
+        self.swa_model.set_weights(self.model.get_weights())
+
+    def on_epoch_end(self, epoch, logs=None):
+        # Keras passes a 0-based epoch index, so averaging starts once
+        # epoch + 1 reaches swa_start.
+        if (epoch + 1) >= self.swa_start:
+            self.update_average_model()
+            self.swa_n += 1
+
+    def update_average_model(self):
+        # Update the running mean of the weights: after n snapshots, the
+        # incoming weights get coefficient 1 / (n + 1), so every snapshot
+        # ends up contributing equally.
+        alpha = 1. / (self.swa_n + 1)
+        for layer, swa_layer in zip(self.model.layers, self.swa_model.layers):
+            weights = [(1 - alpha) * w_swa + alpha * w
+                       for w_swa, w in zip(swa_layer.get_weights(), layer.get_weights())]
+            swa_layer.set_weights(weights)
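+
+    # Sanity check of the recursion above (a worked step, by induction): if
+    # swa_model currently holds the mean of n snapshots,
+    #     m_n = (w_1 + ... + w_n) / n,
+    # then with alpha = 1 / (n + 1):
+    #     (1 - alpha) * m_n + alpha * w_{n+1} = (n * m_n + w_{n+1}) / (n + 1)
+    #                                         = (w_1 + ... + w_{n+1}) / (n + 1),
+    # i.e. exactly the plain arithmetic mean used in the SWA paper.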
+
+    def on_train_end(self, logs=None):
+        checkpoint_path = os.path.join(self.checkpoint_dir, f'{self.model_name}_swa.hdf5')
+        print(f'Logging Info - Saving SWA model checkpoint: {self.model_name}_swa.hdf5')
+        self.swa_model.save_weights(checkpoint_path)
+        print('Logging Info - SWA model saved')