--- a
+++ b/trainers/DLMODEL.py
@@ -0,0 +1,152 @@
+import json
+import os
+from abc import abstractmethod
+
+import matplotlib.pyplot
+import numpy as np
+import tensorflow as tf
+
+
+# Baseline class for all your Deep Learning needs with TensorFlow #
+
+class DLMODEL(object):
+    class Config(object):
+        def __init__(self):
+            self.modelname = ''
+            self.model_config = {}
+            self.checkpointDir = None
+            self.description = ''
+            self.batchsize = 6
+            self.useTensorboard = True
+            self.tensorboardPort = 8008
+            self.useMatplotlib = False
+            self.debugGradients = False
+            self.tfSummaryAfter = 100
+            self.dataset = ''
+            self.beta1 = 0.5
+
+    def __init__(self, sess, config=Config()):
+        """
+
+        Args:
+            sess: TensorFlow session
+            config: (optional) a DLMODEL.Config object with your options
+        """
+        self.sess = sess
+        self.config = config
+        self.variables = {}
+        self.curves = {}  # For plotting via matplotlib
+        self.phase = tf.placeholder(tf.bool, name='phase')
+        self.handles = {}
+        if self.config.useMatplotlib:
+            self.handles['curves'] = matplotlib.pyplot.figure()
+            self.handles['samples'] = matplotlib.pyplot.figure()
+
+        self.saver = None
+        self.losses = None
+
+    @abstractmethod
+    def train(self, dataset):
+        """Train a Deep Neural Network"""
+
+    def initialize_variables(self):
+        uninitialized_var_names_raw = set(self.sess.run(tf.report_uninitialized_variables()))
+        uninitialized_var_names = [v.decode() for v in uninitialized_var_names_raw]
+        variables_to_initialize = [v for v in tf.global_variables() if v.name.split(':')[0] in uninitialized_var_names]
+        self.sess.run(tf.variables_initializer(variables_to_initialize))
+        print("Initialized all uninitialized variables.")
+
+    @property
+    def model_dir(self):
+        return "{}_d{}_b{}_{}".format(self.config.modelname, self.config.dataset, self.config.batchsize, self.config.description)
+
+    def save(self, checkpoint_dir, step):
+        model_name = self.config.modelname + ".model"
+        checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)
+
+        # Create checkpoint directory, if it does not exist
+        if not os.path.exists(checkpoint_dir):
+            os.makedirs(checkpoint_dir)
+
+        # Save the current model state to the checkpoint directory
+        self.saver.save(self.sess,
+                        os.path.join(checkpoint_dir, model_name),
+                        global_step=step)
+
+        # Save the config to a json file, such that you can inspect it later
+        with open(os.path.join(checkpoint_dir, 'Config-{}.json'.format(step)), 'w') as outfile:
+            try:
+                json.dump(self.config.__dict__, outfile)
+            except (TypeError, ValueError) as e:
+                print("Failed to save config json: {}".format(e))
+
+        # Save the curves to a np file such that we can recover and monitor the entire training process
+        np.save(os.path.join(checkpoint_dir, 'Curves.npy'), self.curves)
+
+    def load(self, checkpoint_dir, iteration=None):
+        import re
+        print(" [*] Reading checkpoints...")
+        checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)
+
+        # Load training curves, if any
+        curves_file = os.path.join(checkpoint_dir, 'Curves.npy')
+        if os.path.isfile(curves_file):
+            self.curves = np.load(curves_file, allow_pickle=True).item()
+
+        if iteration is not None:
+            self.saver.restore(self.sess, os.path.join(checkpoint_dir, self.config.modelname + '.model-' + str(iteration)))
+            counter = iteration
+            return True, counter
+
+        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
+        if ckpt and ckpt.model_checkpoint_path:
+            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
+            self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
+            counter = int(next(re.finditer(r'(\d+)(?!.*\d)', ckpt_name)).group(0))
+            print(" [*] Successfully read {}".format(ckpt_name))
+            return True, counter
+        else:
+            print(" [*] Failed to find a checkpoint")
+            return False, 0
+
+    @staticmethod
+    def create_optimizer(loss, var_list=None, learningrate=0.001, type='ADAM', beta1=0.05, momentum=0.9, name='optimizer', minimize=True, scope=None):
+        if type == 'ADAM':
+            optim = tf.train.AdamOptimizer(learningrate, beta1=beta1, name=name)
+        elif type == 'SGD':
+            optim = tf.train.GradientDescentOptimizer(learningrate)
+        elif type == 'MOMENTUM':
+            optim = tf.train.MomentumOptimizer(learning_rate=learningrate, momentum=momentum)
+        elif type == 'RMS':
+            optim = tf.train.RMSPropOptimizer(learning_rate=learningrate, momentum=momentum)
+        else:
+            raise ValueError('Invalid optimizer type')
+
+        if minimize:
+            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
+            with tf.control_dependencies(update_ops):
+                train_op = optim.minimize(loss, var_list=var_list)
+            return train_op
+        else:
+            return optim
+
+    @staticmethod
+    def get_number_of_trainable_params():
+        def inner_get_number_of_trainable_params(_scope):
+            total_parameters = 0
+            _variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, _scope)
+            for variable in _variables:
+                # shape is an array of tf.Dimension
+                shape = variable.get_shape()
+                variable_parameters = 1
+                for dim in shape:
+                    variable_parameters *= dim.value
+                total_parameters += variable_parameters
+            return total_parameters
+
+        variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "")
+        scopes = list(set(map(lambda variable: os.path.split(os.path.split(variable.name)[0])[0], variables)))
+        for scope in scopes:
+            if scope != '':
+                print(f'#Params in {scope}: {inner_get_number_of_trainable_params(scope)}')
+        print(f'#Params in total: {inner_get_number_of_trainable_params("")}')