Diff of /trainers/DLMODEL.py [000000] .. [978658]


a b/trainers/DLMODEL.py
import json
import os
import re
from abc import abstractmethod

import matplotlib.pyplot
import numpy as np
import tensorflow as tf


# Baseline class for all your Deep Learning needs with TensorFlow #

class DLMODEL(object):
    class Config(object):
        def __init__(self):
            self.modelname = ''
            self.model_config = {}
            self.checkpointDir = None
            self.description = ''
            self.batchsize = 6
            self.useTensorboard = True
            self.tensorboardPort = 8008
            self.useMatplotlib = False
            self.debugGradients = False
            self.tfSummaryAfter = 100
            self.dataset = ''
            self.beta1 = 0.5

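    # A hedged usage note (assumption, not original code): Config is a plain
    # attribute bag, so callers typically instantiate it and overwrite fields:
    #   cfg = DLMODEL.Config()
    #   cfg.modelname = 'mymodel'
    #   cfg.dataset = 'mnist'
    #   cfg.batchsize = 16
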
    def __init__(self, sess, config=None):
        """

        Args:
          sess: TensorFlow session
          config: (optional) a DLMODEL.Config object with your options
        """
        self.sess = sess
        # Avoid a shared mutable default argument: build a fresh Config when none is given
        self.config = config if config is not None else DLMODEL.Config()
        self.variables = {}
        self.curves = {}  # For plotting via matplotlib
        self.phase = tf.placeholder(tf.bool, name='phase')
        self.handles = {}
        if self.config.useMatplotlib:
            self.handles['curves'] = matplotlib.pyplot.figure()
            self.handles['samples'] = matplotlib.pyplot.figure()

        self.saver = None
        self.losses = None

    @abstractmethod
    def train(self, dataset):
        """Train a Deep Neural Network"""

    def initialize_variables(self):
        uninitialized_var_names_raw = set(self.sess.run(tf.report_uninitialized_variables()))
        uninitialized_var_names = [v.decode() for v in uninitialized_var_names_raw]
        variables_to_initialize = [v for v in tf.global_variables() if v.name.split(':')[0] in uninitialized_var_names]
        self.sess.run(tf.variables_initializer(variables_to_initialize))
        print("Initialized all uninitialized variables.")

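    # A hedged usage note (assumption, not original code): running this after
    # load(...) keeps the restored weights and only initializes variables the
    # checkpoint did not cover:
    #   self.saver = tf.train.Saver()
    #   self.load(self.config.checkpointDir)
    #   self.initialize_variables()
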
    @property
    def model_dir(self):
        return "{}_d{}_b{}_{}".format(self.config.modelname, self.config.dataset, self.config.batchsize, self.config.description)

    def save(self, checkpoint_dir, step):
        model_name = self.config.modelname + ".model"
        checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)

        # Create the checkpoint directory if it does not exist
        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)

        # Save the current model state to the checkpoint directory
        self.saver.save(self.sess,
                        os.path.join(checkpoint_dir, model_name),
                        global_step=step)

        # Save the config to a JSON file so it can be inspected later
        with open(os.path.join(checkpoint_dir, 'Config-{}.json'.format(step)), 'w') as outfile:
            try:
                json.dump(self.config.__dict__, outfile)
            except TypeError:
                print("Failed to save config json")

        # Save the curves to an .npy file so the entire training process can be recovered and monitored
        np.save(os.path.join(checkpoint_dir, 'Curves.npy'), self.curves)

    def load(self, checkpoint_dir, iteration=None):
        print(" [*] Reading checkpoints...")
        checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)

        # Load training curves, if any
        curves_file = os.path.join(checkpoint_dir, 'Curves.npy')
        if os.path.isfile(curves_file):
            self.curves = np.load(curves_file, allow_pickle=True).item()

        if iteration is not None:
            self.saver.restore(self.sess, os.path.join(checkpoint_dir, self.config.modelname + '.model-' + str(iteration)))
            counter = iteration
            return True, counter

        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
            # Recover the iteration counter from the last number in the checkpoint name
            counter = int(next(re.finditer(r'(\d+)(?!.*\d)', ckpt_name)).group(0))
            print(" [*] Successfully read {}".format(ckpt_name))
            return True, counter
        else:
            print(" [*] Failed to find a checkpoint")
            return False, 0

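    # A hedged usage sketch (assumption, not original code): a training loop
    # would typically resume from the latest checkpoint and save periodically:
    #   could_load, counter = self.load(self.config.checkpointDir)
    #   while training:
    #       ...
    #       if counter % 1000 == 0:
    #           self.save(self.config.checkpointDir, counter)
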
    @staticmethod
    def create_optimizer(loss, var_list=None, learningrate=0.001, type='ADAM', beta1=0.05, momentum=0.9, name='optimizer', minimize=True, scope=None):
        if type == 'ADAM':
            optim = tf.train.AdamOptimizer(learningrate, beta1=beta1, name=name)
        elif type == 'SGD':
            optim = tf.train.GradientDescentOptimizer(learningrate)
        elif type == 'MOMENTUM':
            optim = tf.train.MomentumOptimizer(learning_rate=learningrate, momentum=momentum)
        elif type == 'RMS':
            optim = tf.train.RMSPropOptimizer(learning_rate=learningrate, momentum=momentum)
        else:
            raise ValueError('Invalid optimizer type')

        if minimize:
            # Run pending update ops (e.g. batch norm moving averages) before each train step
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                train_op = optim.minimize(loss, var_list=var_list)
            return train_op
        else:
            return optim

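    # A hedged usage sketch (assumption, not original code): a GAN-style
    # subclass might build one train op per variable scope, e.g.
    #   d_vars = [v for v in tf.trainable_variables() if v.name.startswith('discriminator')]
    #   d_optim = self.create_optimizer(d_loss, var_list=d_vars, learningrate=2e-4,
    #                                   type='ADAM', beta1=self.config.beta1)
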
    @staticmethod
    def get_number_of_trainable_params():
        def inner_get_number_of_trainable_params(_scope):
            total_parameters = 0
            _variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, _scope)
            for variable in _variables:
                # shape is an array of tf.Dimension
                shape = variable.get_shape()
                variable_parameters = 1
                for dim in shape:
                    variable_parameters *= dim.value
                total_parameters += variable_parameters
            return total_parameters

        variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "")
        scopes = list(set(map(lambda variable: os.path.split(os.path.split(variable.name)[0])[0], variables)))
        for scope in scopes:
            if scope != '':
                print(f'#Params in {scope}: {inner_get_number_of_trainable_params(scope)}')
        print(f'#Params in total: {inner_get_number_of_trainable_params("")}')
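

# ---------------------------------------------------------------------------
# Hypothetical usage sketch -- NOT part of the original module. A minimal
# subclass that fits y = 2*x with plain SGD, illustrating how Config,
# create_optimizer and initialize_variables are meant to be wired together.
# The _LinearDemo name and the toy data are illustrative assumptions.

class _LinearDemo(DLMODEL):
    def train(self, dataset):
        xs, ys = dataset
        x = tf.placeholder(tf.float32, [None], name='x')
        y = tf.placeholder(tf.float32, [None], name='y')
        w = tf.get_variable('w', shape=[], initializer=tf.zeros_initializer())
        loss = tf.reduce_mean(tf.square(w * x - y))
        train_op = self.create_optimizer(loss, var_list=[w], learningrate=0.01, type='SGD')
        self.saver = tf.train.Saver()
        self.initialize_variables()
        for _ in range(200):
            self.sess.run(train_op, feed_dict={x: xs, y: ys})
        return self.sess.run(w)


if __name__ == '__main__':
    with tf.Session() as sess:
        cfg = DLMODEL.Config()
        cfg.modelname = 'lineardemo'
        model = _LinearDemo(sess, cfg)
        xs = np.arange(8, dtype=np.float32)
        fitted_w = model.train((xs, 2.0 * xs))
        print('fitted w ~= {:.3f}'.format(fitted_w))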