trainers/AnoVAEGAN.py
from collections import defaultdict
from math import inf

from tensorflow.python.ops.losses.losses_impl import Reduction

from trainers import trainer_utils
from trainers.AEMODEL import AEMODEL, Phase, indicate_early_stopping, update_log_dicts
from trainers.DLMODEL import *


class AnoVAEGAN(AEMODEL):
    class Config(AEMODEL.Config):
        def __init__(self):
            super().__init__('AnoVAEGAN')
            self.scale = 10.0      # weight of the WGAN gradient penalty
            self.kappa = 1.0       # stored on the model but unused in this trainer
            self.kl_weight = 1.0   # weight of the KL term in the encoder loss

    def __init__(self, sess, config, network=None):
        super().__init__(sess, config, network)
        self.x = tf.placeholder(tf.float32, [None, self.config.outputHeight, self.config.outputWidth, self.config.numChannels], name='x')
        self.z = tf.placeholder(tf.float32, [None, self.config.zDim], name='z')

        self.outputs = self.network(self.x, dropout_rate=self.dropout_rate, dropout=self.dropout, config=self.config)

        self.reconstruction = self.outputs['out']
        self.z_mu = self.outputs['z_mu']
        self.z_sigma = self.outputs['z_sigma']
        self.d_fake_features = self.outputs['d_fake_features']
        self.d_ = self.outputs['d_']
        self.d_features = self.outputs['d_features']
        self.d = self.outputs['d']
        self.x_hat = self.outputs['x_hat']
        self.d_hat = self.outputs['d_hat']
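        # 'x_hat' and 'd_hat' are presumably the random interpolates between
        # real and reconstructed images and the critic's response to them;
        # they feed the WGAN-GP gradient penalty assembled in train() below.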

        self.kappa = self.config.kappa
        self.kl_weight = self.config.kl_weight
        self.scale = self.config.scale

        # Print Stats
        self.get_number_of_trainable_params()
        # Instantiate Saver
        self.saver = tf.train.Saver()

    def train(self, dataset):
        # Determine trainable variables
        self.variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)

        # Build the WGAN critic losses
        self.losses['disc_fake'] = disc_fake = tf.reduce_mean(self.d_)
        self.losses['disc_real'] = disc_real = tf.reduce_mean(self.d)
        disc_loss = disc_fake - disc_real

        ddx = tf.gradients(self.d_hat, self.x_hat)[0]  # gradient of the critic w.r.t. the interpolates
        ddx = tf.sqrt(tf.reduce_sum(tf.square(ddx), axis=[1, 2, 3]))  # per-sample slope over all image axes
        ddx = tf.reduce_mean(tf.square(ddx - 1.0)) * self.scale  # gradient penalty
        self.losses['disc_loss'] = disc_loss = disc_loss + ddx
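        # Together these terms form the WGAN-GP critic objective of Gulrajani
        # et al.: E[D(fake)] - E[D(real)] + scale * E[(||grad D(x_hat)||_2 - 1)^2],
        # which softly enforces the critic's 1-Lipschitz constraint.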

        # Build the VAE losses
        kl = 0.5 * tf.reduce_sum(tf.square(self.z_mu) + tf.square(self.z_sigma) - tf.log(tf.square(self.z_sigma)) - 1,
                                 axis=1)
        self.losses['kl'] = loss_kl = tf.reduce_mean(kl)
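        # This is the closed-form KL divergence between the approximate
        # posterior N(z_mu, z_sigma^2) and the unit Gaussian prior:
        # KL = 0.5 * sum(mu^2 + sigma^2 - log(sigma^2) - 1), summed over z.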

        self.losses['loss_img'] = tf.reduce_mean(
            tf.reduce_mean(tf.losses.mean_squared_error(self.x, self.reconstruction, reduction=Reduction.NONE), axis=[1, 2, 3]))
        self.losses['loss_fts'] = tf.reduce_mean(
            tf.reduce_mean(tf.losses.mean_squared_error(self.d_fake_features, self.d_features, reduction=Reduction.NONE), axis=[1, 2, 3]))
        self.losses['L1'] = tf.losses.absolute_difference(self.x, self.reconstruction, reduction=Reduction.NONE)
        self.losses['reconstructionLoss'] = self.losses['loss'] = tf.reduce_mean(tf.reduce_sum(self.losses['L1'], axis=[1, 2, 3]))

        self.losses['gen_loss'] = gen_loss = -disc_fake
        self.losses['enc_loss'] = enc_loss = self.losses['reconstructionLoss'] + self.kl_weight * loss_kl
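        # enc_loss is the usual VAE objective (reconstruction plus weighted KL),
        # while gen_loss is the adversarial term pushing reconstructions to
        # score highly under the critic. loss_img and loss_fts are tracked for
        # monitoring only and do not enter any optimizer's objective.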

        with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
            # Set up one optimizer per player, each scoped to its own variables
            t_vars = tf.trainable_variables()
            dis_vars = [var for var in t_vars if 'Discriminator' in var.name]
            gen_vars = [var for var in t_vars if 'Generator' in var.name]
            enc_vars = [var for var in t_vars if 'Encoder' in var.name]

            optim_dis = tf.train.AdamOptimizer(learning_rate=self.config.learningrate, beta1=0.5, beta2=0.9).minimize(disc_loss, var_list=dis_vars)
            optim_gen = tf.train.AdamOptimizer(learning_rate=self.config.learningrate, beta1=0.5, beta2=0.9).minimize(gen_loss, var_list=gen_vars)
            optim_vae = tf.train.AdamOptimizer(learning_rate=self.config.learningrate, beta1=0.5, beta2=0.9).minimize(enc_loss, var_list=enc_vars + gen_vars)
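            # The VAE step deliberately updates enc_vars + gen_vars: in
            # AnoVAEGAN the decoder doubles as the GAN generator. The Adam
            # settings beta1=0.5, beta2=0.9 follow common WGAN-GP practice.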

        # initialize all variables
        tf.global_variables_initializer().run(session=self.sess)

        best_cost = inf
        last_improvement = 0
        last_epoch = self.load_checkpoint()

        # Go go go!
        for epoch in range(last_epoch, self.config.numEpochs):
            #################
            # TRAINING WGAN #
            #################
            phase = Phase.TRAIN
            scalars = defaultdict(list)
            visuals = []
            d_iters = 5
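            # d_iters follows the usual WGAN recipe of several critic updates
            # per generator/encoder update; the critic loop below reuses the
            # same batch.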
            num_batches = dataset.num_batches(self.config.batchsize, set=phase.value)
            for idx in range(0, num_batches):
                batch, _, _ = dataset.next_batch(self.config.batchsize, set=phase.value)

                # Encoder/decoder (VAE) optimization
                fetches = {
                    # 'generated': self.generated,
                    'reconstruction': self.reconstruction,
                    'reconstructionLoss': self.losses['reconstructionLoss'],
                    'L1': self.losses['L1'],
                    'enc_loss': self.losses['enc_loss'],
                    'optimizer_e': optim_vae,
                }

                feed_dict = {
                    self.x: batch,
                    self.dropout: phase == Phase.TRAIN,
                    self.dropout_rate: self.config.dropout_rate
                }
                run = self.sess.run(fetches, feed_dict=feed_dict)

                # Generator optimization
                fetches = {
                    'gen_loss': self.losses['gen_loss'],
                    'optimizer_g': optim_gen,
                }

                feed_dict = {
                    self.x: batch,
                    self.dropout: phase == Phase.TRAIN,
                    self.dropout_rate: self.config.dropout_rate
                }
                run = {**run, **self.sess.run(fetches, feed_dict=feed_dict)}

                for _ in range(0, d_iters):
                    # Discriminator optimization
                    fetches = {
                        'disc_loss': self.losses['disc_loss'],
                        'disc_fake': self.losses['disc_fake'],
                        'disc_real': self.losses['disc_real'],
                        'optimizer_d': optim_dis,
                    }
                    feed_dict = {
                        self.x: batch,
                        self.dropout: phase == Phase.TRAIN,
                        self.dropout_rate: self.config.dropout_rate
                    }
                    run = {**run, **self.sess.run(fetches, feed_dict=feed_dict)}

                # Print to console
                print(f'Epoch ({phase.value}): [{epoch:2d}] [{idx:4d}/{num_batches:4d}]'
                      f' gen_loss: {run["gen_loss"]:.8f}, disc_loss: {run["disc_loss"]:.8f}, reconstructionLoss: {run["reconstructionLoss"]:.8f}')
                update_log_dicts(*trainer_utils.get_summary_dict(batch, run), scalars, visuals)

            self.log_to_tensorboard(epoch, scalars, visuals, phase)

            # Increment last_epoch counter and save model
            last_epoch += 1
            self.save(self.checkpointDir, last_epoch)

            ##############
            # VALIDATION #
            ##############
            phase = Phase.VAL
            scalars = defaultdict(list)
            visuals = []
            num_batches = dataset.num_batches(self.config.batchsize, set=phase.value)
            for idx in range(0, num_batches):
                batch, _, _ = dataset.next_batch(self.config.batchsize, set=phase.value)

                # Validation forward pass: fetch the losses without any optimizer ops
                fetches = {
                    'reconstruction': self.reconstruction,
                    'reconstructionLoss': self.losses['reconstructionLoss'],
                    'L1': self.losses['L1'],
                    'enc_loss': self.losses['enc_loss'],
                }

                feed_dict = {
                    self.x: batch,
                    self.dropout: phase == Phase.TRAIN,  # evaluates to False here
                    self.dropout_rate: self.config.dropout_rate
                }
                run = self.sess.run(fetches, feed_dict=feed_dict)
                # Print to console
                print(f'Epoch ({phase.value}): [{epoch:2d}] [{idx:4d}/{num_batches:4d}] reconstructionLoss: {run["reconstructionLoss"]:.8f}')
                update_log_dicts(*trainer_utils.get_summary_dict(batch, run), scalars, visuals)

            self.log_to_tensorboard(epoch, scalars, visuals, phase)

            best_cost, last_improvement, stop = indicate_early_stopping(scalars['reconstructionLoss'], best_cost, last_improvement)
            if stop:
                print('Early stopping was triggered due to no improvement over the last 5 epochs')
                break

    def reconstruct(self, x, dropout=False):
        # Accept a single image by adding a batch dimension
        if x.ndim < 4:
            x = np.expand_dims(x, 0)

        fetches = {
            'reconstruction': self.reconstruction
        }
        feed_dict = {
            self.x: x,
            self.dropout: dropout,  # apply only during MC sampling.
            self.dropout_rate: self.config.dropout_rate
        }
        results = self.sess.run(fetches, feed_dict=feed_dict)

        results['l1err'] = np.sum(np.abs(x - results['reconstruction']))  # summed absolute residual
        results['l2err'] = np.sqrt(np.sum((x - results['reconstruction']) ** 2))  # Euclidean norm of the residual
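        # The per-pixel residual between x and the reconstruction is what
        # downstream anomaly-segmentation code would typically threshold; the
        # scalar errors above only summarize it per call.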

        return results
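

# Minimal usage sketch (hypothetical: `build_network` stands in for the
# network-builder function passed as `network`, and `dataset` for an object
# exposing num_batches/next_batch as used in train() above):
#
#   config = AnoVAEGAN.Config()
#   with tf.Session() as sess:
#       model = AnoVAEGAN(sess, config, network=build_network)
#       model.train(dataset)
#       results = model.reconstruct(images)  # 'l1err' and 'l2err' included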