|
a |
|
b/train_rnn.py |
|
|
1 |
# -*- coding:utf-8 -*- |
|
|
2 |
import logging |
|
|
3 |
import time |
|
|
4 |
import tensorflow as tf |
|
|
5 |
from build_rnn import AFD_RNN |
|
|
6 |
from utils import parser_cfg_file |
|
|
7 |
from data_load import DataLoad |
|
|
8 |
|
|
|
9 |
class AFD_RNN_Train(object): |
|
|
10 |
|
|
|
11 |
def __init__(self, train_config): |
|
|
12 |
|
|
|
13 |
self.learing_rate = float(train_config['learning_rate']) |
|
|
14 |
self.train_iterior = int(train_config['train_iteration']) |
|
|
15 |
self._train_logger_init() |
|
|
16 |
|
|
|
17 |
net_config = parser_cfg_file('./config/rnn_net.cfg') |
|
|
18 |
self.rnn_net = AFD_RNN(net_config) |
|
|
19 |
self.predict = self.rnn_net.build_net_graph() |
|
|
20 |
self.label = tf.placeholder(tf.float32, [None, self.rnn_net.time_step, self.rnn_net.class_num]) |
|
|
21 |
|
|
|
22 |
def _compute_loss(self): |
|
|
23 |
with tf.name_scope('loss'): |
|
|
24 |
# [batchszie, time_step, class_num] ==> [time_step][batchsize, class_num] |
|
|
25 |
predict = tf.unstack(self.predict, axis=0) |
|
|
26 |
label = tf.unstack(self.label, axis=1) |
|
|
27 |
|
|
|
28 |
loss = [tf.nn.softmax_cross_entropy_with_logits(labels=label[i], logits=predict[i]) for i in range(self.rnn_net.time_step) ] |
|
|
29 |
loss = tf.reduce_mean(loss) |
|
|
30 |
train_op = tf.train.AdamOptimizer(self.learing_rate).minimize(loss) |
|
|
31 |
return loss, train_op |
|
|
32 |
|
|
|
33 |
def train_rnn(self): |
|
|
34 |
|
|
|
35 |
loss, train_op = self._compute_loss() |
|
|
36 |
|
|
|
37 |
with tf.name_scope('accuracy'): |
|
|
38 |
predict = tf.transpose(self.predict, [1,0,2]) |
|
|
39 |
correct_pred = tf.equal(tf.argmax(self.label, 2), tf.argmax(predict, axis=2)) |
|
|
40 |
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32)) |
|
|
41 |
|
|
|
42 |
dataset = DataLoad('./dataset/train/', time_step=self.rnn_net.time_step, class_num= self.rnn_net.class_num) |
|
|
43 |
saver = tf.train.Saver() |
|
|
44 |
|
|
|
45 |
with tf.Session() as sess: |
|
|
46 |
sess.run(tf.global_variables_initializer()) |
|
|
47 |
|
|
|
48 |
for step in range(1, self.train_iterior+1): |
|
|
49 |
x, y = dataset.get_batch(self.rnn_net.batch_size) |
|
|
50 |
if step == 1: |
|
|
51 |
feed_dict = {self.rnn_net.input_tensor: x, self.label: y} |
|
|
52 |
else: |
|
|
53 |
feed_dict = {self.rnn_net.input_tensor: x, self.label: y, self.rnn_net.cell_state:state} |
|
|
54 |
_, compute_loss, state = sess.run([train_op, loss, self.rnn_net.cell_state], feed_dict=feed_dict) |
|
|
55 |
|
|
|
56 |
if step%10 == 0: |
|
|
57 |
compute_accuracy = sess.run(accuracy, feed_dict=feed_dict) |
|
|
58 |
self.train_logger.info('train step = %d,loss = %f,accuracy = %f'%(step, compute_loss, compute_accuracy)) |
|
|
59 |
if step%1000 == 0: |
|
|
60 |
save_path = saver.save(sess, './model/model.ckpt') |
|
|
61 |
self.train_logger.info("train step = %d ,model save to =%s" % (step, save_path)) |
|
|
62 |
|
|
|
63 |
def _train_logger_init(self): |
|
|
64 |
""" |
|
|
65 |
初始化log日志 |
|
|
66 |
:return: |
|
|
67 |
""" |
|
|
68 |
self.train_logger = logging.getLogger('train') |
|
|
69 |
self.train_logger.setLevel(logging.DEBUG) |
|
|
70 |
|
|
|
71 |
# 添加文件输出 |
|
|
72 |
log_file = './train_logs/' + time.strftime('%Y%m%d%H%M', time.localtime(time.time())) + '.logs' |
|
|
73 |
file_handler = logging.FileHandler(log_file, mode='w') |
|
|
74 |
file_handler.setLevel(logging.DEBUG) |
|
|
75 |
file_formatter = logging.Formatter('%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s') |
|
|
76 |
file_handler.setFormatter(file_formatter) |
|
|
77 |
self.train_logger.addHandler(file_handler) |
|
|
78 |
|
|
|
79 |
# 添加控制台输出 |
|
|
80 |
consol_handler = logging.StreamHandler() |
|
|
81 |
consol_handler.setLevel(logging.DEBUG) |
|
|
82 |
consol_formatter = logging.Formatter('%(message)s') |
|
|
83 |
consol_handler.setFormatter(consol_formatter) |
|
|
84 |
self.train_logger.addHandler(consol_handler) |
|
|
85 |
|
|
|
86 |
if __name__ == '__main__': |
|
|
87 |
train_config = parser_cfg_file('./config/train.cfg') |
|
|
88 |
train = AFD_RNN_Train(train_config) |
|
|
89 |
train.train_rnn() |
|
|
90 |
|
|
|
91 |
# a = tf.zeros([1,2,3]) |
|
|
92 |
# b = tf.unstack(a, axis=1) |
|
|
93 |
# c = tf.zeros([2,1,3]) |
|
|
94 |
# sess = tf.Session() |
|
|
95 |
# d = b[0] |
|
|
96 |
# print(sess.run(b[0])) |
|
|
97 |
# |
|
|
98 |
# pass |