a b/trainers/AEMODEL.py
1
import os
2
from abc import ABC
3
from datetime import datetime
4
5
import numpy as np
6
import tensorflow as tf
7
8
from trainers.DLMODEL import DLMODEL
9
from utils.logger import Logger, Phase
10
11
12
class AEMODEL(DLMODEL, ABC):
    """Abstract base class for autoencoder-style trainers built on ``DLMODEL``.

    On construction it creates the dropout placeholders, derives checkpoint and
    log directories from the network's ``__name__`` and ``model_dir``, and opens
    a ``Logger`` on the log directory.
    """

    class Config(DLMODEL.Config):
        """Default configuration values for autoencoder models."""

        def __init__(self, modelname='AE'):
            super().__init__()
            self.modelname = modelname
            self.intermediateResolutions = [8, 8]
            self.outputWidth = 256
            self.outputHeight = 256
            self.numChannels = 3
            self.dropout = False
            self.dropout_rate = 0.2
            self.zDim = 128

    def __init__(self, sess, config=None, network=None):
        """Set up placeholders, output directories and the logger.

        Args:
            sess: session object forwarded to ``DLMODEL`` and to ``Logger``.
            config: model configuration; a fresh ``AEMODEL.Config()`` is created
                when omitted.
            network: object whose ``__name__`` names the checkpoint/log
                directories (presumably the network-building function — confirm
                against callers). Required.

        Raises:
            TypeError: if ``network`` is None.
        """
        # Build the default config per instance. A `config=Config()` default
        # would be evaluated once at class-definition time and shared (and
        # mutated) by every model constructed without an explicit config.
        if config is None:
            config = AEMODEL.Config()
        # Fail early with a clear message instead of an AttributeError on
        # `None.__name__` below.
        if network is None:
            raise TypeError("AEMODEL requires a `network` argument; got None")
        super().__init__(sess, config)
        self.losses = {}
        self.dropout = tf.placeholder(tf.bool, name='dropout')
        self.dropout_rate = tf.placeholder(tf.float32, name='dropout_rate')

        self.network = network
        self.checkpointDir = os.path.join(self.config.checkpointDir, self.network.__name__)
        project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        self.logDir = os.path.join(project_root, 'logs', self.network.__name__, self.model_dir,
                                   datetime.now().strftime('%Y%m%d_%H%M%S'))
        self.logger = Logger(self.sess, self.logDir)

    def log_to_tensorboard(self, epoch, scalars, visuals, phase: Phase, name='x'):
        """Average accumulated scalars and write them, plus visuals, to the logger.

        Args:
            epoch: step index passed to ``Logger.summarize``.
            scalars: dict mapping summary name -> sequence of values. Each entry
                is replaced IN PLACE by its ``np.mean``.
            visuals: list of arrays concatenated with ``np.vstack``; only the
                first 50 rows are logged, under ``name``. May be empty.
            phase: logging phase passed to ``Logger.summarize``.
            name: summary key used for the stacked visuals.
        """
        for key in scalars.keys():
            scalars[key] = np.mean(scalars[key])

        summaries = dict(scalars)
        if visuals:
            # Same precedence as before: the visuals entry overrides a scalar
            # that happens to share its key.
            summaries[name] = np.vstack(visuals)[:50]
        # Scalars are written even when there are no visuals; previously they
        # were silently dropped in that case.
        if summaries:
            self.logger.summarize(epoch, phase=phase, summaries_dict=summaries)

    def load_checkpoint(self):
        """Try to restore from ``self.checkpointDir``.

        Returns:
            The checkpoint counter reported by ``self.load`` on success (used by
            callers as the last completed epoch), or 0 if loading failed.
        """
        could_load, checkpoint_counter = self.load(self.checkpointDir)
        if not could_load:
            print(" [!] Load failed...")
            return 0
        print(" [*] Load SUCCESS")
        return checkpoint_counter

    @property
    def model_dir(self):
        """Run identifier built from config fields and the network name.

        Used as a path component of the log directory.
        """
        return "{}_d{}_s{}x{}_{}_b{}_z{}_{}".format(self.config.modelname, self.config.dataset,
                                                    self.config.outputWidth,
                                                    self.config.outputHeight,
                                                    self.network.__name__,
                                                    self.config.batchsize, self.config.zDim,
                                                    self.config.description)
62
63
64
def update_log_dicts(scalars, visuals, train_scalars, train_visuals):
    """Append one step's scalars and visuals to the running logs.

    Each value in ``scalars`` is appended to the list stored under the same key
    in ``train_scalars``. That mapping must already yield a list for the key,
    e.g. a ``defaultdict(list)``; otherwise ``KeyError`` propagates before
    ``visuals`` is recorded. ``visuals`` is then appended to ``train_visuals``.
    Both accumulators are mutated in place; nothing is returned.
    """
    snapshot = list(scalars.items())
    for metric_name, metric_value in snapshot:
        history = train_scalars[metric_name]
        history.append(metric_value)
    train_visuals.append(visuals)
68
69
70
def indicate_early_stopping(current_cost, best_cost, last_improvement, patience=5):
    """Update the best-cost tracker and decide whether training should stop.

    Args:
        current_cost: cost observed in the current evaluation.
        best_cost: lowest cost seen so far.
        last_improvement: number of consecutive evaluations without improvement.
        patience: number of consecutive non-improving evaluations after which
            stopping is signalled. Defaults to 5, the previously hard-coded value.

    Returns:
        Tuple ``(best_cost, last_improvement, stop)``: the updated best cost, the
        updated no-improvement counter, and whether to stop.

    Only a strict decrease counts as improvement. Equal costs (and a NaN
    ``current_cost``, which fails the ``<`` comparison) advance the counter.
    """
    if current_cost < best_cost:
        # New best: record it and reset the no-improvement counter.
        return current_cost, 0, False
    last_improvement += 1
    return best_cost, last_improvement, last_improvement >= patience