[7e75f2]: / jz-char-rnn-tensorflow / train_multiple_hyperparams.py

Download this file

24 lines (18 with data), 728 Bytes

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import os,itertools,sys
import multiprocessing as mp
# Usage: python train_multiple_hyperparams.py <text>
# data/<text> is passed to train.py as --data_dir; every run uses save_<text>
# as --save_dir, and save_<text>/train_losses is created up front.
if len(sys.argv) < 2:
    sys.exit('usage: %s <text>  (training data is read from data/<text>)' % sys.argv[0])
text = sys.argv[1]
textdir = os.path.join('data', text)
sd = 'save_' + text
# Create sd and sd/train_losses in one call, without going through a shell, so
# dataset names containing spaces or shell metacharacters are handled safely.
# The try/except form tolerates the directory already existing (or being
# created concurrently) and still works on Python 2, which lacks exist_ok.
losses_dir = os.path.join(sd, 'train_losses')
try:
    os.makedirs(losses_dir)
except OSError:
    if not os.path.isdir(losses_dir):
        raise
# combinations of hyperparameters to test; each list feeds the matching
# train.py flag (--model, --num_layers, --seq_length, --learning_rate)
models = ['rnn','lstm','gru']
layers = [2,3]
seq_lengths = [50,100,500,1000]
learning_rates = [0.0001]
def train_char_rnn(inputs):
    """Run train.py once for a single hyperparameter combination.

    inputs -- a 4-tuple (model, num_layers, seq_length, learning_rate), i.e.
              one element of the module-level ``allinputs`` grid.

    Uses the module-level ``textdir`` and ``sd`` for --data_dir/--save_dir and
    always trains for one epoch. Returns the exit status of the train.py
    process (0 on success) so the caller can tell which runs failed.
    """
    import subprocess

    m, nl, sl, lr = inputs
    # An argument list instead of a shell string: no quoting or injection
    # problems with the dataset name. sys.executable runs train.py under the
    # same interpreter as this script rather than whatever `python` is on PATH.
    # NOTE(review): every run gets the same --save_dir and several run
    # concurrently; whether train.py keeps their outputs apart is not visible
    # here -- confirm.
    cmd = [sys.executable, 'train.py',
           '--data_dir', textdir,
           '--save_dir', sd,
           '--num_layers', str(nl),
           '--model', m,
           '--seq_length', str(sl),
           '--num_epochs', '1',
           '--learning_rate', str(lr)]
    status = subprocess.call(cmd)
    if status != 0:
        # A negative status means the child was killed by that signal (POSIX).
        sys.stderr.write('train.py failed (exit %d) for %r\n' % (status, inputs))
    return status
# One entry per combination of the hyperparameter lists.
allinputs = list(itertools.product(models,layers,seq_lengths,learning_rates))

if __name__ == '__main__':
    # The guard is required under the "spawn" start method (default on Windows,
    # and on macOS since Python 3.8): workers re-import this file, and without
    # it each worker would try to start its own Pool and crash.
    pool = mp.Pool(processes=8)
    try:
        results = pool.map(train_char_rnn, allinputs)
    finally:
        # map() has returned or raised, so no work is pending; reap the workers.
        pool.terminate()
        pool.join()
    # Treat any truthy return value (e.g. a non-zero exit status) as a failed run.
    failed = [args for args, status in zip(allinputs, results) if status]
    if failed:
        sys.stderr.write('%d of %d runs failed:\n' % (len(failed), len(allinputs)))
        for args in failed:
            sys.stderr.write('  %r\n' % (args,))
        sys.exit(1)