[978658]: / trainers / VAE.py

Download this file

124 lines (95 with data), 4.6 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
from collections import defaultdict
from math import inf
from tensorflow.python.ops.losses.losses_impl import Reduction
from trainers import trainer_utils
from trainers.AEMODEL import AEMODEL, update_log_dicts, Phase, indicate_early_stopping
from trainers.DLMODEL import *
# Switch pyplot into interactive mode as soon as this module is imported.
# NOTE(review): `matplotlib` is not imported explicitly in this file; it is
# presumably made available by `from trainers.DLMODEL import *` — confirm.
# This is a module-level side effect: it changes global pyplot state for the
# whole process the first time this module is imported.
matplotlib.pyplot.ion()
class VAE(AEMODEL):
    """Variational autoencoder trainer built on top of ``AEMODEL``.

    The wrapped ``network`` is called on the input placeholder and must return
    a dict with the keys ``'x_hat'`` (reconstruction), ``'z_mu'`` (latent mean)
    and ``'z_sigma'`` (latent standard deviation, squared in the KL term).
    """

    class Config(AEMODEL.Config):
        """Configuration for :class:`VAE`; passes ``'VAE'`` to the base config."""

        def __init__(self):
            super().__init__('VAE')

    def __init__(self, sess, config, network=None):
        """Build the input placeholder and the model graph.

        Args:
            sess: TensorFlow session handed to ``AEMODEL.__init__``; later used
                as ``self.sess`` for all ``run`` calls.
            config: Configuration object; ``outputHeight``, ``outputWidth`` and
                ``numChannels`` define the placeholder shape.
            network: Graph-building callable handed to ``AEMODEL.__init__``;
                presumably stored there as ``self.network`` (called below).
        """
        super().__init__(sess, config, network)
        self.x = tf.placeholder(
            tf.float32,
            [None, self.config.outputHeight, self.config.outputWidth, self.config.numChannels],
            name='x',
        )
        self.outputs = self.network(self.x, dropout_rate=self.dropout_rate, dropout=self.dropout,
                                    config=self.config)
        self.reconstruction = self.outputs['x_hat']
        self.z_mu = self.outputs['z_mu']
        self.z_sigma = self.outputs['z_sigma']

        # Print Stats
        self.get_number_of_trainable_params()

        # Instantiate Saver
        self.saver = tf.train.Saver()

    def train(self, dataset):
        """Build the VAE loss and run the training loop with early stopping.

        Per sample, the loss is the summed absolute reconstruction error plus the
        closed-form KL divergence between N(z_mu, z_sigma^2) and N(0, 1); the
        ``'loss'`` entry of ``self.losses`` is its mean over the batch.

        Training resumes from the epoch returned by ``load_checkpoint``. A
        checkpoint is saved after every training epoch, followed by a validation
        pass whose ``'loss'`` drives ``indicate_early_stopping``.

        Args:
            dataset: Data source forwarded to :meth:`process`.
        """
        # Determine trainable variables
        self.variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)

        # Build losses: element-wise L1 map, then one reconstruction value per sample.
        self.losses['L1'] = tf.losses.absolute_difference(self.x, self.reconstruction,
                                                          reduction=Reduction.NONE)
        rec = tf.reduce_sum(self.losses['L1'], axis=[1, 2, 3])
        # Closed-form KL(N(mu, sigma^2) || N(0, 1)), summed over axis 1.
        # NOTE(review): assumes z_mu / z_sigma are (batch, latent_dim) and that the
        # network keeps z_sigma strictly positive; log(0) would yield inf/NaN.
        kl = 0.5 * tf.reduce_sum(
            tf.square(self.z_mu) + tf.square(self.z_sigma) - tf.log(tf.square(self.z_sigma)) - 1,
            axis=1,
        )

        self.losses['reconstructionLoss'] = tf.reduce_mean(rec)
        self.losses['kl'] = tf.reduce_mean(kl)
        self.losses['loss'] = tf.reduce_mean(rec + kl)

        # Set the optimizer
        optim = self.create_optimizer(self.losses['loss'], var_list=self.variables,
                                      learningrate=self.config.learningrate,
                                      beta1=self.config.beta1, type=self.config.optimizer)

        # Initialize all variables before restoring any checkpoint on top of them.
        tf.global_variables_initializer().run(session=self.sess)

        best_cost = inf
        last_improvement = 0
        last_epoch = self.load_checkpoint()

        for epoch in range(last_epoch, self.config.numEpochs):
            # TRAINING
            self.process(dataset, epoch, Phase.TRAIN, optim)

            # Increment last_epoch counter and save model
            last_epoch += 1
            self.save(self.checkpointDir, last_epoch)

            # VALIDATION
            val_scalars = self.process(dataset, epoch, Phase.VAL)
            best_cost, last_improvement, stop = indicate_early_stopping(
                val_scalars['loss'], best_cost, last_improvement)
            if stop:
                # NOTE(review): the patience is decided inside indicate_early_stopping;
                # keep the "5 epochs" in this message in sync with it.
                print('Early stopping was triggered due to no improvement over the last 5 epochs')
                break

    def process(self, dataset, epoch, phase: Phase, optim=None):
        """Run one pass over the ``phase`` split of ``dataset``.

        Args:
            dataset: Provides ``num_batches(batchsize, set=...)`` and
                ``next_batch(batchsize, set=...)``.
            epoch: Epoch index used for console output and TensorBoard logging.
            phase: Split to iterate. In ``Phase.TRAIN`` the optimizer op is run
                and the dropout switch is fed ``True``.
            optim: Optimizer op; only fetched when ``phase == Phase.TRAIN``.

        Returns:
            The ``defaultdict(list)`` of scalars accumulated by ``update_log_dicts``.
        """
        scalars = defaultdict(list)
        visuals = []
        is_train = phase == Phase.TRAIN
        num_batches = dataset.num_batches(self.config.batchsize, set=phase.value)

        # The fetched ops are the same for every batch, so build them once.
        fetches = {
            'reconstruction': self.reconstruction,
            **self.losses,
        }
        if is_train:
            fetches['optimizer'] = optim

        for idx in range(num_batches):
            batch, _, _ = dataset.next_batch(self.config.batchsize, set=phase.value)
            feed_dict = {
                self.x: batch,
                self.dropout: is_train,
                self.dropout_rate: self.config.dropout_rate,
            }
            run = self.sess.run(fetches, feed_dict=feed_dict)

            # Print to console
            print(f'Epoch ({phase.value}): [{epoch:2d}] [{idx:4d}/{num_batches:4d}] loss: {run["loss"]:.8f}')
            update_log_dicts(*trainer_utils.get_summary_dict(batch, run), scalars, visuals)

        self.log_to_tensorboard(epoch, scalars, visuals, phase)
        return scalars

    def reconstruct(self, x, dropout=False):
        """Reconstruct ``x`` and report its reconstruction error.

        Args:
            x: NumPy array of inputs; one leading batch axis is prepended when
                ``x.ndim < 4``.
            dropout: Value fed to the dropout switch (intended for MC sampling).

        Returns:
            Dict with ``'reconstruction'`` (network output), ``'l1err'`` (sum of
            absolute errors) and ``'l2err'`` (Euclidean norm of the error).
        """
        if x.ndim < 4:
            x = np.expand_dims(x, 0)
        fetches = {
            'reconstruction': self.reconstruction
        }
        feed_dict = {
            self.x: x,
            self.dropout: dropout,  # apply only during MC sampling.
            self.dropout_rate: self.config.dropout_rate
        }
        results = self.sess.run(fetches, feed_dict=feed_dict)
        residual = x - results['reconstruction']
        results['l1err'] = np.sum(np.abs(residual))
        # The root must wrap the sum: an element-wise sqrt(r ** 2) is just |r|,
        # which would make l2err a copy of l1err.
        results['l2err'] = np.sqrt(np.sum(np.square(residual)))
        return results