trainers/ceVAE.py

from collections import defaultdict
from math import inf
from tensorflow.python.ops.losses.losses_impl import Reduction
from trainers import trainer_utils
from trainers.AEMODEL import AEMODEL, Phase, indicate_early_stopping, update_log_dicts
from trainers.CE import retrieve_masked_batch
from trainers.DLMODEL import *
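
# ceVAE trainer: combines a standard VAE objective (reconstruction + KL) with a
# context-encoding (CE) branch that reconstructs the input from a masked version
# of itself (cf. Zimmerer et al., "Context-encoding Variational Autoencoder").
# Anomalies are scored per pixel by the VAE reconstruction error weighted by the
# gradient of the VAE loss w.r.t. the input, which also drives the optional
# gradient-based restoration in reconstruct().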
class ceVAE(AEMODEL):
    class Config(AEMODEL.Config):
        def __init__(self):
            super().__init__('ceVAE')
            self.use_gradient_based_restoration = True

    def __init__(self, sess, config, network=None):
        super().__init__(sess, config, network)
        self.x = tf.placeholder(tf.float32, [None, self.config.outputHeight, self.config.outputWidth, self.config.numChannels], name='x')
        self.x_ce = tf.placeholder(tf.float32, [None, self.config.outputHeight, self.config.outputWidth, self.config.numChannels], name='x_ce')

        self.outputs = self.network(self.x, self.x_ce, dropout_rate=self.dropout_rate, dropout=self.dropout, config=self.config)
        self.reconstruction = self.outputs['x_hat']
        self.reconstruction_ce = self.outputs['x_hat_ce']
        self.z_mu = self.outputs['z_mu']
        self.z_sigma = self.outputs['z_sigma']

        # Print Stats
        self.get_number_of_trainable_params()
        # Instantiate Saver
        self.saver = tf.train.Saver()

    def train(self, dataset):
        # Determine trainable variables
        self.variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)

        # Build losses
        self.losses['L1_vae'] = tf.losses.absolute_difference(self.x, self.reconstruction, reduction=Reduction.NONE)
        self.losses['L1_ce'] = tf.losses.absolute_difference(self.x_ce, self.reconstruction_ce, reduction=Reduction.NONE)
        self.losses['L1'] = 0.5 * (self.losses['L1_vae'] + self.losses['L1_ce'])
        rec_vae = tf.reduce_sum(self.losses['L1_vae'], axis=[1, 2, 3])
        rec_ce = tf.reduce_sum(self.losses['L1_ce'], axis=[1, 2, 3])
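
        # Closed-form KL divergence between the approximate posterior N(z_mu, z_sigma^2)
        # and the standard normal prior, summed over the latent dimensions:
        # KL = 0.5 * sum(mu^2 + sigma^2 - log(sigma^2) - 1)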
        kl = 0.5 * tf.reduce_sum(tf.square(self.z_mu) + tf.square(self.z_sigma) - tf.log(tf.square(self.z_sigma)) - 1, axis=1)

        self.losses['Rec_ce'] = tf.reduce_mean(rec_ce)
        self.losses['Rec_vae'] = tf.reduce_mean(rec_vae)
        self.losses['reconstructionLoss'] = 0.5 * tf.reduce_mean(rec_vae + rec_ce)
        self.losses['kl'] = tf.reduce_mean(kl)
        self.losses['loss'] = tf.reduce_mean(rec_vae + kl + rec_ce)
        self.losses['loss_vae'] = tf.reduce_mean(rec_vae + kl)
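
        # Per-pixel anomaly score: the VAE reconstruction error weighted by the magnitude of the
        # gradient of the VAE loss w.r.t. the input (the signal used for gradient-based restoration)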
        self.losses['anomaly'] = self.losses['L1_vae'] * tf.abs(tf.gradients(self.losses['loss_vae'], self.x))[0]

        # Set the optimizer
        optim = self.create_optimizer(self.losses['loss'], var_list=self.variables, learningrate=self.config.learningrate,
                                      beta1=self.config.beta1, type=self.config.optimizer)

        # initialize all variables
        tf.global_variables_initializer().run(session=self.sess)

        best_cost = inf
        last_improvement = 0
        last_epoch = self.load_checkpoint()

        visualization_keys = ['reconstruction', 'reconstruction_ce', 'anomaly']

        # Go go go!
        for epoch in range(last_epoch, self.config.numEpochs):
            ############
            # TRAINING #
            ############
            self.process(dataset, epoch, Phase.TRAIN, optim, visualization_keys=visualization_keys)

            # Increment last_epoch counter and save model
            last_epoch += 1
            self.save(self.checkpointDir, last_epoch)

            ##############
            # VALIDATION #
            ##############
            val_scalars = self.process(dataset, epoch, Phase.VAL, visualization_keys=visualization_keys)

            best_cost, last_improvement, stop = indicate_early_stopping(val_scalars['loss'], best_cost, last_improvement)
            if stop:
                print('Early stopping was triggered due to no improvement over the last 5 epochs')
                break

    def process(self, dataset, epoch, phase: Phase, optim=None, visualization_keys=None):
        scalars = defaultdict(list)
        visuals = []
        num_batches = dataset.num_batches(self.config.batchsize, set=phase.value)
        for idx in range(0, num_batches):
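            # retrieve_masked_batch builds the context-encoding input: the batch with parts of the
            # image masked out based on the brain masks. It is fed to x_ce only during training;
            # at validation time the CE branch receives the unmasked batch (see feed_dict below).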
            batch, _, brainmasks = dataset.next_batch(self.config.batchsize, return_brainmask=True, set=phase.value)
            masked_batch = retrieve_masked_batch(batch, brainmasks)

            fetches = {
                'reconstruction': self.reconstruction,
                'reconstruction_ce': self.reconstruction_ce,
                **self.losses
            }
            if phase == Phase.TRAIN:
                fetches['optimizer'] = optim

            feed_dict = {
                self.x: batch,
                self.x_ce: masked_batch if phase == Phase.TRAIN else batch,
                self.dropout: phase == Phase.TRAIN,
                self.dropout_rate: self.config.dropout_rate
            }

            run = self.sess.run(fetches, feed_dict=feed_dict)

            # Print to console
            print(f'Epoch ({phase.value}): [{epoch:2d}] [{idx:4d}/{num_batches:4d}] loss: {run["loss"]:.8f}')
            update_log_dicts(*trainer_utils.get_summary_dict(batch, run, visualization_keys), scalars, visuals)

        self.log_to_tensorboard(epoch, scalars, visuals, phase)
        return scalars

    def reconstruct(self, x, dropout=False):
        if x.ndim < 4:
            x = np.expand_dims(x, 0)

        fetches = {
            'reconstruction': self.reconstruction,
            **self.losses
        }

        feed_dict = {
            self.x: x,
            self.x_ce: x,
            self.dropout: dropout,
            self.dropout_rate: self.config.dropout_rate
        }

        results = self.sess.run(fetches, feed_dict=feed_dict)
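
        # Note: self.losses (including 'anomaly') is built in train(), so that graph must already
        # exist before reconstruct() is called. use_gradient_based_restoration also acts as the
        # step size of the restoration below; the default value True is treated as 1.0.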
        if self.config.use_gradient_based_restoration:
            # This is not actually the real 'reconstruction', but for convenience we treat it
            # as such to avoid changes in our evaluation script.
            results['reconstruction'] = x - self.config.use_gradient_based_restoration * results['anomaly']

        results['l1err'] = np.sum(np.abs(x - results['reconstruction']))
        results['l2err'] = np.sqrt(np.sum((x - results['reconstruction']) ** 2))

        return results
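

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original training
# pipeline). Assumes `my_network` is a callable returning the tensor dict used
# in __init__ ('x_hat', 'x_hat_ce', 'z_mu', 'z_sigma') and `my_dataset` exposes
# num_batches(...) / next_batch(...) as used in process(); both names, and the
# example image size, are hypothetical.
#
#   config = ceVAE.Config()
#   config.outputHeight, config.outputWidth, config.numChannels = 128, 128, 1
#   with tf.Session() as sess:
#       model = ceVAE(sess, config, network=my_network)
#       model.train(my_dataset)
#       results = model.reconstruct(my_image)  # restored image + anomaly map
# ---------------------------------------------------------------------------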