trainers/AEMODEL.py

import os
from abc import ABC
from datetime import datetime

import numpy as np
import tensorflow as tf

from trainers.DLMODEL import DLMODEL
from utils.logger import Logger, Phase
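

# AEMODEL extends DLMODEL with an autoencoder-specific Config (output size,
# channels, dropout, latent dimension zDim), run-time dropout placeholders,
# per-network checkpoint and log directories, and TensorBoard logging helpers.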
class AEMODEL(DLMODEL, ABC):
    class Config(DLMODEL.Config):
        def __init__(self, modelname='AE'):
            super().__init__()
            self.modelname = modelname
            self.intermediateResolutions = [8, 8]
            self.outputWidth = 256
            self.outputHeight = 256
            self.numChannels = 3
            self.dropout = False
            self.dropout_rate = 0.2
            self.zDim = 128
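
    # Dropout is toggled at run time through graph placeholders rather than the
    # Config flags, so the same graph can train with dropout enabled and
    # evaluate with it disabled. Note: tf.placeholder is the TensorFlow 1.x
    # graph-mode API (tf.compat.v1.placeholder under TensorFlow 2).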
    def __init__(self, sess, config=Config(), network=None):
        super().__init__(sess, config)
        self.losses = {}
        self.dropout = tf.placeholder(tf.bool, name='dropout')
        self.dropout_rate = tf.placeholder(tf.float32, name='dropout_rate')
        self.network = network
        self.checkpointDir = os.path.join(self.config.checkpointDir, self.network.__name__)
        self.logDir = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'logs',
                                   self.network.__name__, self.model_dir, datetime.now().strftime('%Y%m%d_%H%M%S'))
        self.logger = Logger(self.sess, self.logDir)
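
    # Reduces each per-batch scalar list collected over an epoch to its mean and
    # writes it, together with up to 50 stacked visual samples, to TensorBoard
    # via the Logger. Nothing is written when `visuals` is empty.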
    def log_to_tensorboard(self, epoch, scalars, visuals, phase: Phase, name='x'):
        for key in scalars.keys():
            scalars[key] = np.mean(scalars[key])
        if visuals:
            self.logger.summarize(epoch, phase=phase, summaries_dict={**scalars, **{name: np.vstack(visuals)[:50]}})
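
    # Restores the latest checkpoint for this network, if one exists, and
    # returns the epoch to resume from (0 when loading fails).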
    def load_checkpoint(self):
        could_load, checkpoint_counter = self.load(self.checkpointDir)
        if could_load:
            last_epoch = checkpoint_counter
            print(" [*] Load SUCCESS")
        else:
            last_epoch = 0
            print(" [!] Load failed...")
        return last_epoch
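
    # Directory name encoding the experiment setup, e.g.
    # "AE_d<dataset>_s256x256_<network>_b<batchsize>_z128_<description>".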
    @property
    def model_dir(self):
        return "{}_d{}_s{}x{}_{}_b{}_z{}_{}".format(self.config.modelname, self.config.dataset,
                                                    self.config.outputWidth, self.config.outputHeight,
                                                    self.network.__name__, self.config.batchsize,
                                                    self.config.zDim, self.config.description)
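

# Appends the scalars and visuals from one iteration to the running per-epoch
# collections (train_scalars maps metric name -> list of values).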
def update_log_dicts(scalars, visuals, train_scalars, train_visuals):
    for k, v in list(scalars.items()):
        train_scalars[k].append(v)
    train_visuals.append(visuals)
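

# Early-stopping bookkeeping: resets the counter whenever the cost improves and
# signals stopping (third return value True) after 5 epochs without improvement.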
def indicate_early_stopping(current_cost, best_cost, last_improvement):
    if current_cost < best_cost:
        best_cost = current_cost
        last_improvement = 0
        return best_cost, last_improvement, False
    else:
        last_improvement += 1
        if last_improvement >= 5:
            return best_cost, last_improvement, True
        return best_cost, last_improvement, False
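

# A minimal usage sketch (not part of this file) of how a concrete trainer
# might drive these helpers in its epoch loop. `trainer.train_step`, the data
# iterator and `config.numEpochs` are hypothetical placeholders, and Phase is
# assumed to expose a TRAIN member:
#
#     from collections import defaultdict
#
#     best_cost, last_improvement = float('inf'), 0
#     for epoch in range(last_epoch, config.numEpochs):
#         train_scalars, train_visuals = defaultdict(list), []
#         for batch in dataset:
#             scalars, visuals = trainer.train_step(batch)
#             update_log_dicts(scalars, visuals, train_scalars, train_visuals)
#         trainer.log_to_tensorboard(epoch, train_scalars, train_visuals, Phase.TRAIN)
#         best_cost, last_improvement, stop = indicate_early_stopping(
#             np.mean(train_scalars['loss']), best_cost, last_improvement)
#         if stop:
#             break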