--- a +++ b/trainers/AEMODEL.py @@ -0,0 +1,79 @@ +import os +from abc import ABC +from datetime import datetime + +import numpy as np +import tensorflow as tf + +from trainers.DLMODEL import DLMODEL +from utils.logger import Logger, Phase + + +class AEMODEL(DLMODEL, ABC): + class Config(DLMODEL.Config): + def __init__(self, modelname='AE'): + super().__init__() + self.modelname = modelname + self.intermediateResolutions = [8, 8] + self.outputWidth = 256 + self.outputHeight = 256 + self.numChannels = 3 + self.dropout = False + self.dropout_rate = 0.2 + self.zDim = 128 + + def __init__(self, sess, config=Config(), network=None): + super().__init__(sess, config) + self.losses = {} + self.dropout = tf.placeholder(tf.bool, name='dropout') + self.dropout_rate = tf.placeholder(tf.float32, name='dropout_rate') + + self.network = network + self.checkpointDir = os.path.join(self.config.checkpointDir, self.network.__name__) + self.logDir = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'logs', self.network.__name__, self.model_dir, + datetime.now().strftime('%Y%m%d_%H%M%S')) + self.logger = Logger(self.sess, self.logDir) + + def log_to_tensorboard(self, epoch, scalars, visuals, phase: Phase, name='x'): + for key in scalars.keys(): + scalars[key] = np.mean(scalars[key]) + + if visuals: + self.logger.summarize(epoch, phase=phase, summaries_dict={**scalars, **{name: np.vstack(visuals)[:50]}}) + + def load_checkpoint(self): + could_load, checkpoint_counter = self.load(self.checkpointDir) + if could_load: + last_epoch = checkpoint_counter + print(" [*] Load SUCCESS") + else: + last_epoch = 0 + print(" [!] Load failed...") + return last_epoch + + @property + def model_dir(self): + return "{}_d{}_s{}x{}_{}_b{}_z{}_{}".format(self.config.modelname, self.config.dataset, + self.config.outputWidth, + self.config.outputHeight, + self.network.__name__, + self.config.batchsize, self.config.zDim, + self.config.description) + + +def update_log_dicts(scalars, visuals, train_scalars, train_visuals): + for k, v in list(scalars.items()): + train_scalars[k].append(v) + train_visuals.append(visuals) + + +def indicate_early_stopping(current_cost, best_cost, last_improvement): + if current_cost < best_cost: + best_cost = current_cost + last_improvement = 0 + return best_cost, last_improvement, False + else: + last_improvement += 1 + if last_improvement >= 5: + return best_cost, last_improvement, True + return best_cost, last_improvement, False