--- a
+++ b/train_rnn.py
@@ -0,0 +1,98 @@
+# -*- coding:utf-8 -*-
+import logging
+import time
+import tensorflow as tf
+from build_rnn import AFD_RNN
+from utils import parser_cfg_file
+from data_load import DataLoad
+
+class AFD_RNN_Train(object):
+
+    def __init__(self, train_config):
+
+        self.learing_rate = float(train_config['learning_rate'])
+        self.train_iterior = int(train_config['train_iteration'])
+        self._train_logger_init()
+
+        net_config = parser_cfg_file('./config/rnn_net.cfg')
+        self.rnn_net = AFD_RNN(net_config)
+        self.predict = self.rnn_net.build_net_graph()
+        self.label = tf.placeholder(tf.float32, [None, self.rnn_net.time_step, self.rnn_net.class_num])
+
+    def _compute_loss(self):
+        with tf.name_scope('loss'):
+            # [batchszie, time_step, class_num] ==> [time_step][batchsize, class_num]
+            predict = tf.unstack(self.predict, axis=0)
+            label = tf.unstack(self.label, axis=1)
+
+            loss = [tf.nn.softmax_cross_entropy_with_logits(labels=label[i], logits=predict[i]) for i in range(self.rnn_net.time_step) ]
+            loss = tf.reduce_mean(loss)
+            train_op = tf.train.AdamOptimizer(self.learing_rate).minimize(loss)
+        return loss, train_op
+
+    def train_rnn(self):
+
+        loss, train_op = self._compute_loss()
+
+        with tf.name_scope('accuracy'):
+            predict = tf.transpose(self.predict, [1,0,2])
+            correct_pred = tf.equal(tf.argmax(self.label, 2), tf.argmax(predict, axis=2))
+            accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
+
+        dataset = DataLoad('./dataset/train/', time_step=self.rnn_net.time_step, class_num= self.rnn_net.class_num)
+        saver = tf.train.Saver()
+
+        with tf.Session() as sess:
+            sess.run(tf.global_variables_initializer())
+
+            for step in range(1, self.train_iterior+1):
+                x, y = dataset.get_batch(self.rnn_net.batch_size)
+                if step == 1:
+                    feed_dict = {self.rnn_net.input_tensor: x, self.label: y}
+                else:
+                    feed_dict = {self.rnn_net.input_tensor: x, self.label: y, self.rnn_net.cell_state:state}
+                _, compute_loss, state = sess.run([train_op, loss, self.rnn_net.cell_state], feed_dict=feed_dict)
+
+                if step%10 == 0:
+                    compute_accuracy = sess.run(accuracy, feed_dict=feed_dict)
+                    self.train_logger.info('train step = %d,loss = %f,accuracy = %f'%(step, compute_loss, compute_accuracy))
+                if step%1000 == 0:
+                    save_path = saver.save(sess, './model/model.ckpt')
+                    self.train_logger.info("train step = %d ,model save to =%s" % (step, save_path))
+
+    def _train_logger_init(self):
+        """
+        初始化log日志
+        :return:
+        """
+        self.train_logger = logging.getLogger('train')
+        self.train_logger.setLevel(logging.DEBUG)
+
+        # 添加文件输出
+        log_file = './train_logs/' + time.strftime('%Y%m%d%H%M', time.localtime(time.time())) + '.logs'
+        file_handler = logging.FileHandler(log_file, mode='w')
+        file_handler.setLevel(logging.DEBUG)
+        file_formatter = logging.Formatter('%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s')
+        file_handler.setFormatter(file_formatter)
+        self.train_logger.addHandler(file_handler)
+
+        # 添加控制台输出
+        consol_handler = logging.StreamHandler()
+        consol_handler.setLevel(logging.DEBUG)
+        consol_formatter = logging.Formatter('%(message)s')
+        consol_handler.setFormatter(consol_formatter)
+        self.train_logger.addHandler(consol_handler)
+
+if __name__ == '__main__':
+    train_config = parser_cfg_file('./config/train.cfg')
+    train = AFD_RNN_Train(train_config)
+    train.train_rnn()
+
+    # a = tf.zeros([1,2,3])
+    # b = tf.unstack(a, axis=1)
+    # c = tf.zeros([2,1,3])
+    # sess = tf.Session()
+    # d = b[0]
+    # print(sess.run(b[0]))
+    #
+    # pass
\ No newline at end of file