Diff of /gait_nn.py [000000] .. [968c76]

b/gait_nn.py
import settings
import os

import tensorflow as tf
import tensorflow.contrib.layers as layers
import numpy as np

from abc import abstractmethod

slim = tf.contrib.slim

SUMMARY_PATH = settings.LOGDIR_GAIT_PATH
KEY_SUMMARIES = tf.GraphKeys.SUMMARIES

SEED = 0
np.random.seed(SEED)


class GaitNN(object):
    def __init__(self, name, input_tensor, features, num_of_persons, reuse = False, is_train = True,
                 count_of_training_examples = 1000):
        self.input_tensor = input_tensor
        self.is_train = is_train
        self.name = name

        self.FEATURES = features

        net = self.pre_process(input_tensor)
        net, gait_signature, state = self.get_network(net, is_train, reuse)

        self.network = net
        self.gait_signature = gait_signature
        self.state = state

        if is_train:
            # Initialize placeholders
            self.desired_person = tf.placeholder(
                dtype = tf.int32,
                shape = [],
                name = 'desired_person')

            self.desired_person_one_hot = tf.one_hot(self.desired_person, num_of_persons, dtype = tf.float32)
            self.loss = self._sigm_ce_loss()

            self.global_step = tf.Variable(0, name = 'global_step', trainable = False)

            self.learning_rate = tf.placeholder(
                dtype = tf.float32,
                shape = [],
                name = 'learning_rate')

            def _learning_rate_decay_fn(learning_rate, global_step):
                return tf.train.exponential_decay(
                    learning_rate,
                    global_step,
                    decay_steps = count_of_training_examples * 2,
                    decay_rate = 0.96,
                    staircase = True)

            self.optimize = layers.optimize_loss(loss = self.loss,
                                                 global_step = self.global_step,
                                                 learning_rate = self.learning_rate,
                                                 summaries = layers.optimizers.OPTIMIZER_SUMMARIES,
                                                 optimizer = tf.train.RMSPropOptimizer,
                                                 learning_rate_decay_fn = _learning_rate_decay_fn,
                                                 clip_gradients = 0.1,
                                                 )
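
            # Note on the schedule above: with staircase = True,
            # tf.train.exponential_decay yields
            #     learning_rate * 0.96 ** (global_step // decay_steps)
            # so the rate drops by 4% after every two passes over the
            # training set (assuming one optimizer step per example).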

        self.sess = tf.Session()
        self.sess.run(tf.global_variables_initializer())

        # Initialize summaries
        if name is not None:
            if is_train:
                logdir = os.path.join(SUMMARY_PATH, self.name, 'train')
                # tf.summary.FileWriter / tf.summary.merge_all are the TF 1.x
                # replacements for the removed tf.train.SummaryWriter and
                # tf.merge_all_summaries
                self.summary_writer = tf.summary.FileWriter(logdir)

                self.ALL_SUMMARIES = tf.summary.merge_all(KEY_SUMMARIES)
            else:
                self.summary_writer_d = {}

                for t in ['avg', 'n', 'b', 's']:
                    logdir = os.path.join(SUMMARY_PATH, self.name, 'val_%s' % t)
                    self.summary_writer_d[t] = tf.summary.FileWriter(logdir)

        tf.set_random_seed(SEED)

    @staticmethod
    def pre_process(inp):
        return inp / 100.0

    @staticmethod
    def get_arg_scope(is_training):
        weight_decay_l2 = 0.1
        batch_norm_decay = 0.999
        batch_norm_epsilon = 0.0001

        with slim.arg_scope([slim.conv2d, slim.fully_connected, layers.separable_convolution2d],
                            weights_regularizer = slim.l2_regularizer(weight_decay_l2),
                            biases_regularizer = slim.l2_regularizer(weight_decay_l2),
                            weights_initializer = layers.variance_scaling_initializer(),
                            ):
            batch_norm_params = {
                'decay': batch_norm_decay,
                'epsilon': batch_norm_epsilon
            }
            with slim.arg_scope([slim.batch_norm, slim.dropout],
                                is_training = is_training):
                with slim.arg_scope([slim.batch_norm],
                                    **batch_norm_params):
                    with slim.arg_scope([slim.conv2d, layers.separable_convolution2d, layers.fully_connected],
                                        activation_fn = tf.nn.elu,
                                        normalizer_fn = slim.batch_norm,
                                        normalizer_params = batch_norm_params) as scope:
                        return scope

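    # The scope above is consumed in get_network as
    #     with slim.arg_scope(self.get_arg_scope(is_training)):
    # so every conv / separable-conv / fully-connected layer picks up ELU,
    # batch norm and L2 regularization without restating the arguments.
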
    def _sigm_ce_loss(self):
        # Despite the historical "sigm" name, this is softmax cross-entropy
        # over the one-hot person label, reduced to a scalar loss
        ce = tf.nn.softmax_cross_entropy_with_logits(logits = self.network, labels = self.desired_person_one_hot)
        loss = tf.reduce_mean(ce)

        return loss

    def train(self, input_tensor, desired_person, learning_rate):
        if not self.is_train:
            raise Exception('Network is not in training mode!')

        self.sess.run(self.optimize, feed_dict = {
            self.input_tensor: input_tensor,
            self.desired_person: desired_person,
            self.learning_rate: learning_rate
        })

    def feed_forward(self, x):
        out, states = self.sess.run([self.gait_signature, self.state], feed_dict = {self.input_tensor: x})

        return out, states

    def write_test_summary(self, err, epoch, t = 'all'):
        loss_summ = tf.Summary()
        loss_summ.value.add(
            tag = 'Classification in percent',
            simple_value = float(err))

        self.summary_writer_d[t].add_summary(loss_summ, epoch)
        self.summary_writer_d[t].flush()

    def write_summary(self, inputs, desired_person, learning_rate, write_frequency = 50):
        step = tf.train.global_step(self.sess, self.global_step)

        if step % write_frequency == 0:
            feed_dict = {
                self.input_tensor: inputs,
                self.desired_person: desired_person,
                self.learning_rate: learning_rate,
            }

            summary, loss = self.sess.run([self.ALL_SUMMARIES, self.loss], feed_dict = feed_dict)
            self.summary_writer.add_summary(summary, step)
            self.summary_writer.flush()

    def save(self, checkpoint_path, name):
        if not os.path.exists(checkpoint_path):
            os.mkdir(checkpoint_path)

        checkpoint_name_path = os.path.join(checkpoint_path, '%s.ckpt' % name)
        all_vars = tf.get_collection(tf.GraphKeys.MODEL_VARIABLES, scope = 'GaitNN')

        saver = tf.train.Saver(all_vars)
        saver.save(self.sess, checkpoint_name_path)

    def restore(self, checkpoint_path):
        all_vars = tf.get_collection(tf.GraphKeys.MODEL_VARIABLES, scope = 'GaitNN')

        saver = tf.train.Saver(all_vars)
        saver.restore(self.sess, checkpoint_path)

    @staticmethod
    def residual_block(net, ch = 256, ch_inner = 128, scope = None, reuse = None, stride = 1):
        """
        Pre-activation bottleneck residual block (ResNet v2)
        """

        with slim.arg_scope([layers.convolution2d],
                            activation_fn = None,
                            normalizer_fn = None):
            with tf.variable_scope(scope, 'ResidualBlock', reuse = reuse):
                in_net = net

                if stride > 1:
                    # Project the shortcut with a strided 1x1 convolution so
                    # its shape matches the downsampled residual branch
                    net = layers.convolution2d(net, ch, kernel_size = 1, stride = stride)

                in_net = layers.batch_norm(in_net)
                in_net = tf.nn.relu(in_net)
                in_net = layers.convolution2d(in_net, ch_inner, 1)

                in_net = layers.batch_norm(in_net)
                in_net = tf.nn.relu(in_net)
                in_net = layers.convolution2d(in_net, ch_inner, 3, stride = stride)

                in_net = layers.batch_norm(in_net)
                in_net = tf.nn.relu(in_net)
                in_net = layers.convolution2d(in_net, ch, 1, activation_fn = None)

                net = tf.nn.relu(in_net + net)

        return net

    @abstractmethod
    def get_network(self, input_tensor, is_training, reuse = False):
        pass


class GaitNetwork(GaitNN):
    FEATURES = 512

    def __init__(self, name = None, num_of_persons = 0, recurrent_unit = 'GRU', rnn_layers = 1,
                 reuse = False, is_training = False, input_net = None):
        tf.set_random_seed(SEED)

        if num_of_persons <= 0 and is_training:
            raise Exception('Parameter num_of_persons has to be greater than zero when training')

        self.num_of_persons = num_of_persons
        self.rnn_layers = rnn_layers
        self.recurrent_unit = recurrent_unit

        if input_net is None:
            input_tensor = tf.placeholder(
                dtype = tf.float32,
                shape = (None, 17, 17, 32),
                name = 'input_image')
        else:
            input_tensor = input_net

        super().__init__(name, input_tensor, self.FEATURES, num_of_persons, reuse, is_training)

    def get_network(self, input_tensor, is_training, reuse = False):
        net = input_tensor

        with tf.variable_scope('GaitNN', reuse = reuse):
            with slim.arg_scope(self.get_arg_scope(is_training)):
                with tf.variable_scope('DownSampling'):
                    with tf.variable_scope('17x17'):
                        net = layers.convolution2d(net, num_outputs = 256, kernel_size = 1)
                        # slim.repeat returns the stacked output; it must be
                        # assigned back, otherwise the residual blocks are
                        # dead graph
                        net = slim.repeat(net, 3, self.residual_block, ch = 256, ch_inner = 64)

                    with tf.variable_scope('8x8'):
                        net = self.residual_block(net, ch = 512, ch_inner = 64, stride = 2)
                        net = slim.repeat(net, 2, self.residual_block, ch = 512, ch_inner = 128)

                    with tf.variable_scope('4x4'):
                        net = self.residual_block(net, ch = 512, ch_inner = 128, stride = 2)
                        net = slim.repeat(net, 1, self.residual_block, ch = 512, ch_inner = 256)

                        net = layers.convolution2d(net, num_outputs = 256, kernel_size = 1)
                        net = layers.convolution2d(net, num_outputs = 256, kernel_size = 3)

                with tf.variable_scope('FullyConnected'):
                    # net = tf.reduce_mean(net, [1, 2], name = 'GlobalPool')
                    net = layers.flatten(net)
                    net = layers.fully_connected(net, 512, activation_fn = None, normalizer_fn = None)

                with tf.variable_scope('Recurrent', initializer = tf.contrib.layers.xavier_initializer()):
                    cell_type = {
                        'GRU': tf.nn.rnn_cell.GRUCell,
                        'LSTM': tf.nn.rnn_cell.LSTMCell
                    }

                    # Build a fresh cell per layer; reusing one cell object in
                    # MultiRNNCell is rejected by TF >= 1.2
                    cells = [cell_type[self.recurrent_unit](self.FEATURES) for _ in range(self.rnn_layers)]
                    cell = tf.nn.rnn_cell.MultiRNNCell(cells, state_is_tuple = True)

                    # The whole frame sequence is fed as a single batch entry
                    net = tf.expand_dims(net, 0)
                    net, state = tf.nn.dynamic_rnn(cell, net, initial_state = cell.zero_state(1, dtype = tf.float32))
                    net = tf.reshape(net, [-1, self.FEATURES])

                    # Temporal Avg-Pooling
                    gait_signature = tf.reduce_mean(net, 0)

                if is_training:
                    net = tf.expand_dims(gait_signature, 0)
                    net = layers.dropout(net, 0.7)

                    with tf.variable_scope('Logits'):
                        net = layers.fully_connected(net, self.num_of_persons, activation_fn = None,
                                                     normalizer_fn = None)

                return net, gait_signature, state
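

# Minimal smoke-test sketch, not part of the original module: it assumes a
# local `settings` module providing LOGDIR_GAIT_PATH, and feeds random arrays
# standing in for the real 17x17x32 per-frame inputs.
if __name__ == '__main__':
    net = GaitNetwork(num_of_persons = 10, recurrent_unit = 'GRU', rnn_layers = 1,
                      is_training = True)

    # A fake walking sequence: 20 frames of 17x17 maps with 32 channels
    frames = np.random.rand(20, 17, 17, 32).astype(np.float32)

    net.train(frames, desired_person = 3, learning_rate = 1e-4)

    # The temporally averaged RNN output is the fixed-length gait signature
    signature, _ = net.feed_forward(frames)
    print(signature.shape)  # (512,)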