--- /dev/null
+++ b/jz-char-rnn-tensorflow/train.py
@@ -0,0 +1,127 @@
+from __future__ import print_function
+import numpy as np
+import tensorflow as tf
+
+import argparse
+import time
+import os
+from six.moves import cPickle
+
+from utils import TextLoader
+from model import Model
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--data_dir', type=str, default='data/tinyshakespeare',
+                        help='data directory containing input.txt')
+    parser.add_argument('--save_dir', type=str, default='save',
+                        help='directory to store checkpointed models')
+    parser.add_argument('--rnn_size', type=int, default=128,
+                        help='size of RNN hidden state')
+    parser.add_argument('--num_layers', type=int, default=2,
+                        help='number of layers in the RNN')
+    parser.add_argument('--model', type=str, default='lstm',
+                        help='rnn, gru, or lstm')
+    parser.add_argument('--activation', type=str, default='tanh',
+                        help='tanh, sigmoid, or relu')
+    parser.add_argument('--batch_size', type=int, default=50,
+                        help='minibatch size')
+    parser.add_argument('--seq_length', type=int, default=50,
+                        help='RNN sequence length')
+    parser.add_argument('--dropout', type=float, default=1.0,
+                        help='dropout keep probability (1.0 disables dropout)')
+    parser.add_argument('--num_epochs', type=int, default=50,
+                        help='number of epochs')
+    parser.add_argument('--save_every', type=int, default=1000,
+                        help='save frequency, in batches')
+    parser.add_argument('--grad_clip', type=float, default=5.,
+                        help='clip gradients at this value')
+    parser.add_argument('--learning_rate', type=float, default=0.002,
+                        help='learning rate')
+    # NB: argparse's type=bool would treat any non-empty string (including '0')
+    # as True, so take this flag as an int instead
+    parser.add_argument('--bidirectional', type=int, default=0,
+                        help='1 or 0')
+    parser.add_argument('--decay_rate', type=float, default=0.97,
+                        help='per-epoch decay rate for the learning rate')
+    parser.add_argument('--init_from', type=str, default=None,
+                        help="""continue training from a saved model at this path. Path must contain files saved by a previous training process:
+                            'config.pkl'      : configuration;
+                            'chars_vocab.pkl' : vocabulary definitions;
+                            'checkpoint'      : paths to model file(s) (created by tf).
+                            Note: this file contains absolute paths; be careful when moving files around;
+                            'model.ckpt-*'    : file(s) with model definition (created by tf)
+                            """)
+    args = parser.parse_args()
+    train(args)
+
+def train(args):
+    data_loader = TextLoader(args.data_dir, args.batch_size, args.seq_length)
+    args.vocab_size = data_loader.vocab_size
+
+    # check compatibility if training is continued from a previously saved model
+    if args.init_from is not None:
+        # check that all necessary files exist
+        assert os.path.isdir(args.init_from), "%s must be a path" % args.init_from
+        assert os.path.isfile(os.path.join(args.init_from, "config.pkl")), "config.pkl file does not exist in path %s" % args.init_from
+        assert os.path.isfile(os.path.join(args.init_from, "chars_vocab.pkl")), "chars_vocab.pkl file does not exist in path %s" % args.init_from
+        ckpt = tf.train.get_checkpoint_state(args.init_from)
+        assert ckpt, "No checkpoint found"
+        assert ckpt.model_checkpoint_path, "No model path found in checkpoint"
+
+        # open old config and check if models are compatible
+        with open(os.path.join(args.init_from, 'config.pkl'), 'rb') as f:
+            saved_model_args = cPickle.load(f)
+        need_be_same = ["model", "rnn_size", "num_layers", "seq_length"]
+        for checkme in need_be_same:
+            assert vars(saved_model_args)[checkme] == vars(args)[checkme], \
+                "Command line argument and saved model disagree on '%s'" % checkme
+
+        # open saved vocab/dict and check if vocabs/dicts are compatible
+        with open(os.path.join(args.init_from, 'chars_vocab.pkl'), 'rb') as f:
+            saved_chars, saved_vocab = cPickle.load(f)
+        assert saved_chars == data_loader.chars, "Data and loaded model disagree on character set!"
+        assert saved_vocab == data_loader.vocab, "Data and loaded model disagree on dictionary mappings!"
+
+    # JZ: make sure the save directory exists (portable replacement for 'mkdir -p')
+    if not os.path.isdir(args.save_dir):
+        os.makedirs(args.save_dir)
+
+    with open(os.path.join(args.save_dir, 'config.pkl'), 'wb') as f:
+        cPickle.dump(args, f)
+    with open(os.path.join(args.save_dir, 'chars_vocab.pkl'), 'wb') as f:
+        cPickle.dump((data_loader.chars, data_loader.vocab), f)
+
+    model = Model(args)
+
+    # JZ: optional per-batch loss log, keyed by model hyperparameters
+#    f = open(args.save_dir+'/train_losses/'+args.model+'_'+str(args.seq_length)+'_'+str(args.num_layers)+'_'+str(args.learning_rate),'a')
+
+    with tf.Session() as sess:
+        tf.initialize_all_variables().run()
+        saver = tf.train.Saver(tf.all_variables())
+        # restore model
+        if args.init_from is not None:
+            saver.restore(sess, ckpt.model_checkpoint_path)
+        for e in range(args.num_epochs):
+            # exponential learning-rate decay, applied once per epoch
+            sess.run(tf.assign(model.lr, args.learning_rate * (args.decay_rate ** e)))
+            data_loader.reset_batch_pointer()
+            state = model.initial_state.eval()
+            for b in range(data_loader.num_batches):
+                start = time.time()
+                x, y = data_loader.next_batch()
+                feed = {model.input_data: x, model.targets: y, model.initial_state: state}
+                train_loss, state, _ = sess.run([model.cost, model.final_state, model.train_op], feed)
+                end = time.time()
+                print("{}/{} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}"
+                      .format(e * data_loader.num_batches + b,
+                              args.num_epochs * data_loader.num_batches,
+                              e, train_loss, end - start))
+                if ((e * data_loader.num_batches + b) % args.save_every == 0
+                        or (e == args.num_epochs - 1 and b == data_loader.num_batches - 1)):  # also save after the last batch
+                    checkpoint_path = os.path.join(args.save_dir, 'model.ckpt')
+                    saver.save(sess, checkpoint_path, global_step=e * data_loader.num_batches + b)
+                    print("model saved to {}".format(checkpoint_path))
+
+                # JZ: log the loss for this batch
+#                f.write(str(train_loss)+'\n')
+#    f.close()
+
+if __name__ == '__main__':
+    main()
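Note on the epoch loop above: the sess.run(tf.assign(model.lr, ...)) line applies plain exponential
learning-rate decay once per epoch. As a minimal sketch (plain Python, no TensorFlow needed), using
the script's default values for --learning_rate, --decay_rate, and --num_epochs:

    # lr_e = learning_rate * decay_rate ** e, recomputed at the start of each epoch
    learning_rate = 0.002   # --learning_rate default
    decay_rate = 0.97       # --decay_rate default
    for e in range(50):     # --num_epochs default
        lr = learning_rate * decay_rate ** e
        # e=0 -> 0.00200, e=10 -> ~0.00147, e=49 -> ~0.00045

Resuming training just points --init_from at an earlier save directory, e.g.
"python train.py --init_from save" with the default --save_dir; the asserts in train() then verify
that the saved config and vocabulary are compatible with the current data and flags.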