--- a
+++ b/jz-char-rnn-tensorflow/model.py
@@ -0,0 +1,129 @@
+import tensorflow as tf
+from tensorflow.models.rnn import rnn_cell,rnn
+from tensorflow.models.rnn import seq2seq
+from jz_rnn_cell import *
+
+import numpy as np
+
+class Model():
+    def __init__(self, args, infer=False):
+        self.args = args
+        if infer:
+            args.batch_size = 1
+            args.seq_length = 1
+
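+        # custom cell implementations (jzRNNCell/jzGRUCell/jzLSTMCell) pulled in via the
+        # jz_rnn_cell import above; each is built below with the hidden size and an activation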
+        if args.model == 'rnn': cell_fn = jzRNNCell
+        elif args.model == 'gru': cell_fn = jzGRUCell
+        elif args.model == 'lstm': cell_fn = jzLSTMCell
+        else: raise Exception("model type not supported: {}".format(args.model))
+
+        if args.activation == 'tanh': cell_af = tf.tanh
+        elif args.activation == 'sigmoid': cell_af = tf.sigmoid
+        elif args.activation == 'relu': cell_af = tf.nn.relu
+        else: raise Exception("activation function not supported: {}".format(args.activation))
+
+        self.input_data = tf.placeholder(tf.int32, [args.batch_size, args.seq_length])
+        self.targets = tf.placeholder(tf.int32, [args.batch_size, args.seq_length])
+
+        with tf.variable_scope('rnnlm'):
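+            # the bidirectional model concatenates forward and backward outputs, so its softmax matrix is twice as wide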
+            if not args.bidirectional:
+                softmax_w = tf.get_variable("softmax_w", [args.rnn_size, args.vocab_size])
+            else:
+                softmax_w = tf.get_variable("softmax_w", [args.rnn_size*2, args.vocab_size])
+            softmax_b = tf.get_variable("softmax_b", [args.vocab_size])
+            with tf.device("/cpu:0"):
+                embedding = tf.get_variable("embedding", [args.vocab_size, args.rnn_size])
+                inputs = tf.split(1, args.seq_length, tf.nn.embedding_lookup(embedding, self.input_data))
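+                # args.dropout is the keep probability passed to tf.nn.dropout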
+                inputs = [tf.nn.dropout(tf.squeeze(input_, [1]),args.dropout) for input_ in inputs]
+
+        # uni-directional RNN (unchanged from the original char-rnn model)
+        if not args.bidirectional:
+            cell = cell_fn(args.rnn_size,activation=cell_af)
+            self.cell = cell = rnn_cell.MultiRNNCell([cell] * args.num_layers)
+            self.initial_state = cell.zero_state(args.batch_size, tf.float32)
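+            # when sampling (infer=True), rnn_decoder calls loop() to feed each prediction back in:
+            # take the argmax of the previous logits and look up its embedding as the next input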
+            def loop(prev, _):
+                prev = tf.matmul(prev, softmax_w) + softmax_b
+                prev_symbol = tf.stop_gradient(tf.argmax(prev, 1))
+                return tf.nn.embedding_lookup(embedding, prev_symbol)
+            outputs, last_state = seq2seq.rnn_decoder(inputs, self.initial_state, cell, loop_function=loop if infer else None, scope='rnnlm')
+            output = tf.reshape(tf.concat(1, outputs), [-1, args.rnn_size])
+
+        # bi-directional RNN
+        else:
+            lstm_fw = cell_fn(args.rnn_size,activation=cell_af)
+            lstm_bw = cell_fn(args.rnn_size,activation=cell_af)
+            self.lstm_fw = lstm_fw = rnn_cell.MultiRNNCell([lstm_fw]*args.num_layers)
+            self.lstm_bw = lstm_bw = rnn_cell.MultiRNNCell([lstm_bw]*args.num_layers)
+            self.initial_state_fw = lstm_fw.zero_state(args.batch_size,tf.float32)
+            self.initial_state_bw = lstm_bw.zero_state(args.batch_size,tf.float32)
+            # every example in a batch spans exactly args.seq_length steps
+            outputs, output_state_fw, output_state_bw = rnn.bidirectional_rnn(
+                lstm_fw, lstm_bw, inputs,
+                initial_state_fw=self.initial_state_fw,
+                initial_state_bw=self.initial_state_bw,
+                sequence_length=[args.seq_length] * args.batch_size)
+            # keep the final forward/backward states so self.final_state is defined for this branch as well
+            last_state = (output_state_fw, output_state_bw)
+            output = tf.reshape(tf.concat(1, outputs), [-1, args.rnn_size*2])
+
+        self.logits = tf.matmul(tf.nn.dropout(output,args.dropout), softmax_w) + softmax_b
+        self.probs = tf.nn.softmax(self.logits)
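+        # average cross-entropy per predicted character, with every timestep weighted equally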
+        loss = seq2seq.sequence_loss_by_example([self.logits],
+                [tf.reshape(self.targets, [-1])],
+                [tf.ones([args.batch_size * args.seq_length])],
+                args.vocab_size)
+        self.cost = tf.reduce_sum(loss) / args.batch_size / args.seq_length
+        self.final_state = last_state
+        self.lr = tf.Variable(0.0, trainable=False)
+        tvars = tf.trainable_variables()
+        grads, _ = tf.clip_by_global_norm(tf.gradients(self.cost, tvars),
+                args.grad_clip)
+        optimizer = tf.train.AdamOptimizer(self.lr)
+        self.train_op = optimizer.apply_gradients(zip(grads, tvars))
+
+    def sample(self, sess, chars, vocab, num=200, prime='The ', sampling_type=1):
+        state = self.cell.zero_state(1, tf.float32).eval()
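+        # prime the network on all but the last character of the prompt, carrying the state forward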
+        for char in prime[:-1]:
+            x = np.zeros((1, 1))
+            x[0, 0] = vocab[char]
+            feed = {self.input_data: x, self.initial_state:state}
+            [state] = sess.run([self.final_state], feed)
+
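+        # draw an index in proportion to the weights (inverse-CDF sampling over their cumulative sum)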
+        def weighted_pick(weights):
+            t = np.cumsum(weights)
+            s = np.sum(weights)
+            return int(np.searchsorted(t, np.random.rand(1) * s))
+
+        ret = prime
+        char = prime[-1]
+        for n in range(num):
+            x = np.zeros((1, 1))
+            x[0, 0] = vocab[char]
+            feed = {self.input_data: x, self.initial_state:state}
+            [probs, state] = sess.run([self.probs, self.final_state], feed)
+            p = probs[0]
+
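+            # sampling_type 0: always take the argmax; 2: sample only after a space
+            # (word boundary), otherwise argmax; 1 (default): sample at every step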
+            if sampling_type == 0:
+                sample = np.argmax(p)
+            elif sampling_type == 2:
+                if char == ' ':
+                    sample = weighted_pick(p)
+                else:
+                    sample = np.argmax(p)
+            else: # sampling_type == 1 (default)
+                sample = weighted_pick(p)
+
+            pred = chars[sample]
+            ret += pred
+            char = pred
+        return ret