trainers/VAE_You.py
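"""Trainer for a VAE with TV-regularized restoration.

train() fits a standard VAE (L1 reconstruction + KL); reconstruct() then
iteratively restores an input by gradient descent on the VAE loss plus a
total-variation penalty on the residual, weighted by tv_lambda.
"""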
from collections import defaultdict
from math import inf

from tensorflow.python.ops.losses.losses_impl import Reduction

from trainers import trainer_utils
from trainers.AEMODEL import Phase, update_log_dicts, indicate_early_stopping, AEMODEL
from trainers.DLMODEL import *  # star import also provides the tf and np used below

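# Note: this module uses the TensorFlow 1.x graph API throughout
# (tf.placeholder, tf.train.Saver, explicit sessions).
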
class VAE_You(AEMODEL):
    class Config(AEMODEL.Config):
        def __init__(self):
            super().__init__('VAE_You')
            self.restore_lr = 1e-3
            self.restore_steps = 150
            self.tv_lambda = 1.8
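            # tv_lambda == -1 acts as a sentinel: train() will then grid-search
            # the TV weight on validation data via determine_best_lambda().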

    def __init__(self, sess, config, network=None):
        super().__init__(sess, config, network)
        self.x = tf.placeholder(tf.float32, [None, self.config.outputHeight, self.config.outputWidth, self.config.numChannels], name='x')
        # tv_lambda is fed at run time so the TV weight can be swept without rebuilding the graph
        self.tv_lambda = tf.placeholder(tf.float32, shape=())

        # Additional Parameters
        self.restore_lr = self.config.restore_lr
        self.restore_steps = self.config.restore_steps
        self.tv_lambda_value = self.config.tv_lambda

        self.outputs = self.network(self.x, dropout_rate=self.dropout_rate, dropout=self.dropout, config=self.config)
        self.reconstruction = self.outputs['x_hat']
        self.z_mu = self.outputs['z_mu']
        self.z_sigma = self.outputs['z_sigma']

        # Print Stats
        self.get_number_of_trainable_params()
        # Instantiate Saver
        self.saver = tf.train.Saver()

    def train(self, dataset):
        # Determine trainable variables
        self.variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)

        # Build losses
        self.losses['L1'] = tf.losses.absolute_difference(self.x, self.reconstruction, reduction=Reduction.NONE)
        rec = tf.reduce_sum(self.losses['L1'], axis=[1, 2, 3])
        # Closed-form KL(N(mu, sigma^2) || N(0, 1)) for a diagonal Gaussian posterior
        kl = 0.5 * tf.reduce_sum(tf.square(self.z_mu) + tf.square(self.z_sigma) - tf.log(tf.square(self.z_sigma)) - 1, axis=1)
        self.losses['pixel_loss'] = rec + kl  # per-sample loss, kept unreduced for the restoration gradient
        self.losses['reconstructionLoss'] = tf.reduce_mean(rec)
        self.losses['kl'] = tf.reduce_mean(kl)
        self.losses['loss'] = tf.reduce_mean(rec + kl)

        # For restoration: TV penalty on the residual, weighted by tv_lambda,
        # plus the gradient of the full restoration objective w.r.t. the input
        self.losses['restore'] = self.tv_lambda * tf.image.total_variation(tf.subtract(self.x, self.reconstruction))
        self.losses['grads'] = tf.gradients(self.losses['pixel_loss'] + self.losses['restore'], self.x)[0]
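        # At test time, reconstruct() descends this gradient on the input image:
        #   x <- x - restore_lr * d/dx [pixel_loss(x) + tv_lambda * TV(x - x_hat)]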

        # Set the optimizer
        optim = self.create_optimizer(self.losses['loss'], var_list=self.variables, learningrate=self.config.learningrate,
                                      beta1=self.config.beta1, type=self.config.optimizer)

        # initialize all variables
        tf.global_variables_initializer().run(session=self.sess)

        best_cost = inf
        last_improvement = 0
        last_epoch = self.load_checkpoint()

        # Go go go!
        for epoch in range(last_epoch, self.config.numEpochs):
            ############
            # TRAINING #
            ############
            self.process(dataset, epoch, Phase.TRAIN, optim)

            # Increment last_epoch counter and save model
            last_epoch += 1
            self.save(self.checkpointDir, last_epoch)

            ##############
            # VALIDATION #
            ##############
            val_scalars = self.process(dataset, epoch, Phase.VAL)

            best_cost, last_improvement, stop = indicate_early_stopping(val_scalars['loss'], best_cost, last_improvement)
            if stop:
                print('Early stopping was triggered due to no improvement over the last 5 epochs')
                break

        if self.tv_lambda_value == -1 and self.restore_steps > 0:
            ####################
            # Determine lambda #
            ####################
            print('Determining best lambda')
            self.determine_best_lambda(dataset)

    def process(self, dataset, epoch, phase: Phase, optim=None):
        scalars = defaultdict(list)
        visuals = []
        num_batches = dataset.num_batches(self.config.batchsize, set=phase.value)
        for idx in range(0, num_batches):
            batch, _, _ = dataset.next_batch(self.config.batchsize, set=phase.value)

            fetches = {
                'reconstruction': self.reconstruction,
                **self.losses
            }
            if phase == Phase.TRAIN:
                fetches['optimizer'] = optim

            feed_dict = {
                self.x: batch,
                self.tv_lambda: self.tv_lambda_value,
                self.dropout: phase == Phase.TRAIN,
                self.dropout_rate: self.config.dropout_rate
            }

            run = self.sess.run(fetches, feed_dict=feed_dict)

            # Print to console
            print(f'Epoch ({phase.value}): [{epoch:2d}] [{idx:4d}/{num_batches:4d}] loss: {run["loss"]:.8f}')
            update_log_dicts(*trainer_utils.get_summary_dict(batch, run), scalars, visuals)

        self.log_to_tensorboard(epoch, scalars, visuals, phase)
        return scalars

    def reconstruct(self, x, dropout=False):
        if x.ndim < 4:
            x = np.expand_dims(x, 0)

        restored = x.copy()
        for step in range(self.restore_steps):
            feed_dict = {
                self.x: restored,
                self.tv_lambda: self.tv_lambda_value,
                self.dropout: dropout,  # apply only during MC sampling.
                self.dropout_rate: self.config.dropout_rate
            }
            run = self.sess.run({'grads': self.losses['grads']}, feed_dict=feed_dict)
            gradients = run['grads']
            restored -= self.restore_lr * gradients
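
        # Note: self.losses['grads'] is defined in train(), so the loss graph
        # must be built before reconstruct() is called.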

        results = {
            'reconstruction': restored
        }
        results['l1err'] = np.sum(np.abs(x - results['reconstruction']))
        results['l2err'] = np.sqrt(np.sum((x - results['reconstruction']) ** 2))  # Euclidean norm of the residual

        return results

    def determine_best_lambda(self, dataset):
        # Grid-search the TV weight over [0.0, 1.9] in steps of 0.1
        lambdas = np.arange(20) / 10.0
        mean_errors = []
        fetches = self.losses  # includes 'grads', the restoration gradient

        for tv_lambda in lambdas:
            errors = []
            # Restore 20% of the validation batches per candidate lambda
            for idx in range(int(dataset.num_batches(self.config.batchsize, set=Phase.VAL.value) * 0.2)):
                batch, _, _ = dataset.next_batch(self.config.batchsize, set=Phase.VAL.value)
                restored = batch.copy()
                for step in range(self.restore_steps):
                    feed_dict = {
                        self.x: restored,
                        self.tv_lambda: tv_lambda,
                        self.dropout: False,
                        self.dropout_rate: self.config.dropout_rate
                    }
                    run = self.sess.run(fetches, feed_dict=feed_dict)
                    restored -= self.restore_lr * run['grads']
                errors.append(np.sum(np.abs(batch - restored)))

            mean_error = np.mean(errors)
            mean_errors.append(mean_error)
            print(f'mean_error for lambda {tv_lambda}: {mean_error}')

        # Keep the lambda with the lowest mean L1 restoration error
        self.tv_lambda_value = lambdas[mean_errors.index(min(mean_errors))]
        print(f'Best lambda: {self.tv_lambda_value}')
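
# Minimal usage sketch, assuming a `dataset` and a VAE `network` supplied by
# the surrounding framework (both are stand-ins, not defined in this file):
#
#     config = VAE_You.Config()
#     config.tv_lambda = -1  # sentinel: grid-search the TV weight after training
#     with tf.Session() as sess:
#         model = VAE_You(sess, config, network=network)
#         model.train(dataset)
#         results = model.reconstruct(image)
#         residual = np.abs(image - results['reconstruction'])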