Diff of /train_rnn.py [000000] .. [857d1b]

Switch to unified view

a b/train_rnn.py
1
# -*- coding:utf-8 -*-
2
import logging
3
import time
4
import tensorflow as tf
5
from build_rnn import AFD_RNN
6
from utils import parser_cfg_file
7
from data_load import DataLoad
8
9
class AFD_RNN_Train(object):
10
11
    def __init__(self, train_config):
12
13
        self.learing_rate = float(train_config['learning_rate'])
14
        self.train_iterior = int(train_config['train_iteration'])
15
        self._train_logger_init()
16
17
        net_config = parser_cfg_file('./config/rnn_net.cfg')
18
        self.rnn_net = AFD_RNN(net_config)
19
        self.predict = self.rnn_net.build_net_graph()
20
        self.label = tf.placeholder(tf.float32, [None, self.rnn_net.time_step, self.rnn_net.class_num])
21
22
    def _compute_loss(self):
23
        with tf.name_scope('loss'):
24
            # [batchszie, time_step, class_num] ==> [time_step][batchsize, class_num]
25
            predict = tf.unstack(self.predict, axis=0)
26
            label = tf.unstack(self.label, axis=1)
27
28
            loss = [tf.nn.softmax_cross_entropy_with_logits(labels=label[i], logits=predict[i]) for i in range(self.rnn_net.time_step) ]
29
            loss = tf.reduce_mean(loss)
30
            train_op = tf.train.AdamOptimizer(self.learing_rate).minimize(loss)
31
        return loss, train_op
32
33
    def train_rnn(self):
34
35
        loss, train_op = self._compute_loss()
36
37
        with tf.name_scope('accuracy'):
38
            predict = tf.transpose(self.predict, [1,0,2])
39
            correct_pred = tf.equal(tf.argmax(self.label, 2), tf.argmax(predict, axis=2))
40
            accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
41
42
        dataset = DataLoad('./dataset/train/', time_step=self.rnn_net.time_step, class_num= self.rnn_net.class_num)
43
        saver = tf.train.Saver()
44
45
        with tf.Session() as sess:
46
            sess.run(tf.global_variables_initializer())
47
48
            for step in range(1, self.train_iterior+1):
49
                x, y = dataset.get_batch(self.rnn_net.batch_size)
50
                if step == 1:
51
                    feed_dict = {self.rnn_net.input_tensor: x, self.label: y}
52
                else:
53
                    feed_dict = {self.rnn_net.input_tensor: x, self.label: y, self.rnn_net.cell_state:state}
54
                _, compute_loss, state = sess.run([train_op, loss, self.rnn_net.cell_state], feed_dict=feed_dict)
55
56
                if step%10 == 0:
57
                    compute_accuracy = sess.run(accuracy, feed_dict=feed_dict)
58
                    self.train_logger.info('train step = %d,loss = %f,accuracy = %f'%(step, compute_loss, compute_accuracy))
59
                if step%1000 == 0:
60
                    save_path = saver.save(sess, './model/model.ckpt')
61
                    self.train_logger.info("train step = %d ,model save to =%s" % (step, save_path))
62
63
    def _train_logger_init(self):
64
        """
65
        初始化log日志
66
        :return:
67
        """
68
        self.train_logger = logging.getLogger('train')
69
        self.train_logger.setLevel(logging.DEBUG)
70
71
        # 添加文件输出
72
        log_file = './train_logs/' + time.strftime('%Y%m%d%H%M', time.localtime(time.time())) + '.logs'
73
        file_handler = logging.FileHandler(log_file, mode='w')
74
        file_handler.setLevel(logging.DEBUG)
75
        file_formatter = logging.Formatter('%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s')
76
        file_handler.setFormatter(file_formatter)
77
        self.train_logger.addHandler(file_handler)
78
79
        # 添加控制台输出
80
        consol_handler = logging.StreamHandler()
81
        consol_handler.setLevel(logging.DEBUG)
82
        consol_formatter = logging.Formatter('%(message)s')
83
        consol_handler.setFormatter(consol_formatter)
84
        self.train_logger.addHandler(consol_handler)
85
86
if __name__ == '__main__':
87
    train_config = parser_cfg_file('./config/train.cfg')
88
    train = AFD_RNN_Train(train_config)
89
    train.train_rnn()
90
91
    # a = tf.zeros([1,2,3])
92
    # b = tf.unstack(a, axis=1)
93
    # c = tf.zeros([2,1,3])
94
    # sess = tf.Session()
95
    # d = b[0]
96
    # print(sess.run(b[0]))
97
    #
98
    # pass