Diff of /human_pose_nn.py [000000] .. [968c76]

import tensorflow as tf
import numpy as np
import part_detector
import settings
import utils
import os

from abc import abstractmethod
from functools import lru_cache
from scipy.stats import norm

from inception_resnet_v2 import inception_resnet_v2_arg_scope, inception_resnet_v2

import tensorflow.contrib.layers as layers

slim = tf.contrib.slim

SUMMARY_PATH = settings.LOGDIR_PATH

KEY_SUMMARIES = tf.GraphKeys.SUMMARIES
KEY_SUMMARIES_PER_JOINT = ['summary_joint_%02d' % i for i in range(16)]


class HumanPoseNN(object):
    """
    The neural network used for pose estimation.
    """

    def __init__(self, log_name, heatmap_size, image_size, loss_type = 'SCE', is_training = True):
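        """
        Args:
            log_name: name of the log directory; when None, no summaries are written
            heatmap_size: width and height of the square output heatmaps
            image_size: width and height of the square input images
            loss_type: 'MSE' (mean squared error) or 'SCE' (sigmoid cross-entropy)
            is_training: whether to build the optimizer and training summaries
        """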
        tf.set_random_seed(0)

        if loss_type not in { 'MSE', 'SCE' }:
            raise NotImplementedError('Loss function should be either MSE or SCE!')

        self.log_name = log_name
        self.heatmap_size = heatmap_size
        self.image_size = image_size
        self.is_train = is_training
        self.loss_type = loss_type

        # Initialize placeholders
        self.input_tensor = tf.placeholder(
            dtype = tf.float32,
            shape = (None, image_size, image_size, 3),
            name = 'input_image')

        self.present_joints = tf.placeholder(
            dtype = tf.float32,
            shape = (None, 16),
            name = 'present_joints')

        self.inside_box_joints = tf.placeholder(
            dtype = tf.float32,
            shape = (None, 16),
            name = 'inside_box_joints')

        self.desired_heatmap = tf.placeholder(
            dtype = tf.float32,
            shape = (None, heatmap_size, heatmap_size, 16),
            name = 'desired_heatmap')

        self.desired_points = tf.placeholder(
            dtype = tf.float32,
            shape = (None, 2, 16),
            name = 'desired_points')

        self.network = self.pre_process(self.input_tensor)
        self.network, self.feature_tensor = self.get_network(self.network, is_training)

        self.sigm_network = tf.sigmoid(self.network)
        self.smoothed_sigm_network = self._get_gauss_smoothing_net(self.sigm_network, std = 0.7)

        self.loss_err = self._get_loss_function(loss_type)
        self.euclidean_dist = self._euclidean_dist_err()
        self.euclidean_dist_per_joint = self._euclidean_dist_per_joint_err()

        if is_training:
            self.global_step = tf.Variable(0, name = 'global_step', trainable = False)

            self.learning_rate = tf.placeholder(
                dtype = tf.float32,
                shape = [],
                name = 'learning_rate')

            self.optimize = layers.optimize_loss(loss = self.loss_err,
                                                 global_step = self.global_step,
                                                 learning_rate = self.learning_rate,
                                                 optimizer = tf.train.RMSPropOptimizer(self.learning_rate),
                                                 clip_gradients = 2.0
                                                 )

        self.sess = tf.Session()
        self.sess.run(tf.global_variables_initializer())

        if log_name is not None:
            self._init_summaries()

    def _init_summaries(self):
        if self.is_train:
            logdir = os.path.join(SUMMARY_PATH, self.log_name, 'train')

            self.summary_writer = tf.summary.FileWriter(logdir)
            self.summary_writer_by_points = [tf.summary.FileWriter(os.path.join(logdir, 'point_%02d' % i))
                                             for i in range(16)]

            tf.summary.scalar('Average euclidean distance', self.euclidean_dist, collections = [KEY_SUMMARIES])

            for i in range(16):
                tf.summary.scalar('Joint euclidean distance', self.euclidean_dist_per_joint[i],
                                  collections = [KEY_SUMMARIES_PER_JOINT[i]])

            self.create_summary_from_weights()

            self.ALL_SUMMARIES = tf.summary.merge_all(KEY_SUMMARIES)
            self.SUMMARIES_PER_JOINT = [tf.summary.merge_all(KEY_SUMMARIES_PER_JOINT[i]) for i in range(16)]
        else:
            logdir = os.path.join(SUMMARY_PATH, self.log_name, 'test')
            self.summary_writer = tf.summary.FileWriter(logdir)

    def _get_loss_function(self, loss_type):
        loss_dict = {
            'MSE': self._loss_mse(),
            'SCE': self._loss_cross_entropy()
        }

        return loss_dict[loss_type]

    @staticmethod
    @lru_cache()
    def _get_gauss_filter(size = 15, std = 1.0, kernel_sum = 1.0):
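        """
        Constructs a horizontal 1-D Gaussian kernel of shape [1, size, 16, 1] - one copy per joint
        channel - normalized so that one horizontal and one vertical pass together sum to kernel_sum.
        The lru_cache decorator ensures the kernel is built only once per argument combination.
        """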
        samples = norm.pdf(np.linspace(-2, 2, size), 0, std)
        samples /= np.sum(samples)
        samples *= kernel_sum ** 0.5

        samples = np.expand_dims(samples, 0)
        weights = np.zeros(shape = (1, size, 16, 1), dtype = np.float32)

        for i in range(16):
            weights[:, :, i, 0] = samples

        return weights

    @staticmethod
    def _get_gauss_smoothing_net(net, size = 15, std = 1.0, kernel_sum = 1.0):
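        """
        Smooths all 16 heatmap channels with a separable Gaussian filter: a horizontal and a vertical
        depthwise convolution that together are equivalent to a full 2-D Gaussian blur.
        """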
        filter_h = HumanPoseNN._get_gauss_filter(size, std, kernel_sum)
        filter_v = filter_h.swapaxes(0, 1)

        net = tf.nn.depthwise_conv2d(net, filter = filter_h, strides = [1, 1, 1, 1], padding = 'SAME',
                                     name = 'SmoothingHorizontal')

        net = tf.nn.depthwise_conv2d(net, filter = filter_v, strides = [1, 1, 1, 1], padding = 'SAME',
                                     name = 'SmoothingVertical')

        return net

    def generate_output(self, shape, presented_parts, labels, sigma):
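        """
        Generates the target heatmaps used during training - Gaussian peaks around the joint positions
        for the MSE loss, binary maps of the given diameter around them for the SCE loss.
        """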
        heatmap_dict = {
            'MSE': utils.get_gauss_heat_map(
                shape = shape, is_present = presented_parts,
                mean = labels, sigma = sigma),
            'SCE': utils.get_binary_heat_map(
                shape = shape, is_present = presented_parts,
                centers = labels, diameter = sigma)
        }

        return heatmap_dict[self.loss_type]

    def _adjust_loss(self, loss_err):
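        """
        Masks out the loss of joints that are not present and averages the rest over the visible joints.
        """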
        # Sum over the heatmap dimensions; resulting shape: [batch, joints]
        loss = tf.reduce_sum(loss_err, [1, 2])

        # Stop error propagation from joints that are not present
        loss = tf.multiply(loss, self.present_joints)

        # Compute the average loss over the present joints
        num_of_visible_joints = tf.reduce_sum(self.present_joints)
        loss = tf.reduce_sum(loss) / num_of_visible_joints

        return loss

    def _loss_mse(self):
        sq = tf.squared_difference(self.sigm_network, self.desired_heatmap)
        loss = self._adjust_loss(sq)

        return loss

    def _loss_cross_entropy(self):
        ce = tf.nn.sigmoid_cross_entropy_with_logits(logits = self.network, labels = self.desired_heatmap)
        loss = self._adjust_loss(ce)

        return loss

    def _joint_highest_activations(self):
        highest_activation = tf.reduce_max(self.smoothed_sigm_network, [1, 2])

        return highest_activation

    def _joint_positions(self):
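        """
        Returns a tensor [y, x, activation] for each joint: the argmax position of the smoothed
        heatmap, scaled from heatmap coordinates into image coordinates, plus its highest activation.
        """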
        highest_activation = tf.reduce_max(self.sigm_network, [1, 2])
        x = tf.argmax(tf.reduce_max(self.smoothed_sigm_network, 1), 1)
        y = tf.argmax(tf.reduce_max(self.smoothed_sigm_network, 2), 1)

        x = tf.cast(x, tf.float32)
        y = tf.cast(y, tf.float32)
        a = tf.cast(highest_activation, tf.float32)

        scale_coef = (self.image_size / self.heatmap_size)
        x *= scale_coef
        y *= scale_coef

        out = tf.stack([y, x, a])

        return out

    def _euclidean_dist_err(self):
        # Consider only joints that are present inside the frame
        l2_dist = tf.multiply(self.euclidean_distance(), self.inside_box_joints)

        # Compute the average distance over the present joints
        num_of_visible_joints = tf.reduce_sum(self.inside_box_joints)
        l2_dist = tf.reduce_sum(l2_dist) / num_of_visible_joints

        return l2_dist

    def _euclidean_dist_per_joint_err(self):
        # Consider only joints that are present inside the frame
        l2_dist = tf.multiply(self.euclidean_distance(), self.inside_box_joints)

        # Average euclidean distance of the present joints, computed per joint
        present_joints = tf.reduce_sum(self.inside_box_joints, 0)
        err = tf.reduce_sum(l2_dist, 0) / present_joints

        return err

    def _restore(self, checkpoint_path, variables):
        saver = tf.train.Saver(variables)
        saver.restore(self.sess, checkpoint_path)

    def _save(self, checkpoint_path, name, variables):
        if not os.path.exists(checkpoint_path):
            os.mkdir(checkpoint_path)

        checkpoint_name_path = os.path.join(checkpoint_path, '%s.ckpt' % name)

        saver = tf.train.Saver(variables)
        saver.save(self.sess, checkpoint_name_path)

    def euclidean_distance(self):
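        """
        Euclidean (L2) distance in pixels between the estimated and the desired joint positions,
        computed per joint and per image in the batch.
        """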
        x = tf.argmax(tf.reduce_max(self.smoothed_sigm_network, 1), 1)
        y = tf.argmax(tf.reduce_max(self.smoothed_sigm_network, 2), 1)

        x = tf.cast(x, tf.float32)
        y = tf.cast(y, tf.float32)

        dy = tf.squeeze(self.desired_points[:, 0, :])
        dx = tf.squeeze(self.desired_points[:, 1, :])

        sx = tf.squared_difference(x, dx)
        sy = tf.squared_difference(y, dy)

        l2_dist = tf.sqrt(sx + sy)

        return l2_dist

    def feed_forward(self, x):
        out = self.sess.run(self.sigm_network, feed_dict = {
            self.input_tensor: x
        })

        return out

    def heat_maps(self, x):
        out = self.sess.run(self.smoothed_sigm_network, feed_dict = {
            self.input_tensor: x
        })

        return out

    def feed_forward_pure(self, x):
        out = self.sess.run(self.network, feed_dict = {
            self.input_tensor: x
        })

        return out

    def feed_forward_features(self, x):
        out = self.sess.run(self.feature_tensor, feed_dict = {
            self.input_tensor: x,
        })

        return out

    def test_euclidean_distance(self, x, points, present_joints, inside_box_joints):
        err = self.sess.run(self.euclidean_dist, feed_dict = {
            self.input_tensor: x,
            self.desired_points: points,
            self.present_joints: present_joints,
            self.inside_box_joints: inside_box_joints
        })

        return err

    def test_joint_distances(self, x, y):
        err = self.sess.run(self.euclidean_distance(), feed_dict = {
            self.input_tensor: x,
            self.desired_points: y
        })

        return err

    def test_joint_activations(self, x):
        err = self.sess.run(self._joint_highest_activations(), feed_dict = {
            self.input_tensor: x
        })

        return err

    def estimate_joints(self, x):
        out = self.sess.run(self._joint_positions(), feed_dict = {
            self.input_tensor: x
        })

        return out

    def train(self, x, heatmaps, present_joints, learning_rate, is_inside_box):
        if not self.is_train:
            raise Exception('Network is not in train mode!')

        self.sess.run(self.optimize, feed_dict = {
            self.input_tensor: x,
            self.desired_heatmap: heatmaps,
            self.present_joints: present_joints,
            self.learning_rate: learning_rate,
            self.inside_box_joints: is_inside_box
        })

    def write_test_summary(self, epoch, loss):
        loss_sum = tf.Summary()
        loss_sum.value.add(
            tag = 'Average Euclidean Distance',
            simple_value = float(loss))
        self.summary_writer.add_summary(loss_sum, epoch)
        self.summary_writer.flush()

    def write_summary(self, inp, desired_points, heatmaps, present_joints, learning_rate, is_inside_box,
                      write_frequency = 20, write_per_joint_frequency = 100):
        step = tf.train.global_step(self.sess, self.global_step)

        if step % write_frequency == 0:
            feed_dict = {
                self.input_tensor: inp,
                self.desired_points: desired_points,
                self.desired_heatmap: heatmaps,
                self.present_joints: present_joints,
                self.learning_rate: learning_rate,
                self.inside_box_joints: is_inside_box
            }

            summary, loss = self.sess.run([self.ALL_SUMMARIES, self.loss_err], feed_dict = feed_dict)
            self.summary_writer.add_summary(summary, step)

            if step % write_per_joint_frequency == 0:
                summaries = self.sess.run(self.SUMMARIES_PER_JOINT, feed_dict = feed_dict)

                for i in range(16):
                    self.summary_writer_by_points[i].add_summary(summaries[i], step)

                for i in range(16):
                    self.summary_writer_by_points[i].flush()

            self.summary_writer.flush()

    @abstractmethod
    def pre_process(self, inp):
        pass

    @abstractmethod
    def get_network(self, input_tensor, is_training):
        pass

    @abstractmethod
    def create_summary_from_weights(self):
        pass


class HumanPoseIRNetwork(HumanPoseNN):
    """
    The first part of our network that serves as an extractor of spatial features. It is derived from the
    Inception-Resnet-v2 architecture and modified to generate heatmaps - i.e. dense predictions of body joints.
    """

    FEATURES = 32
    IMAGE_SIZE = 299
    HEATMAP_SIZE = 289
    POINT_DIAMETER = 15
    SMOOTH_SIZE = 21

    def __init__(self, log_name = None, loss_type = 'SCE', is_training = False):
        super().__init__(log_name, self.HEATMAP_SIZE, self.IMAGE_SIZE, loss_type, is_training)

    def pre_process(self, inp):
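        # Scale RGB values from [0, 255] into the range [-1, 1] that the Inception-ResNet-v2 stem expects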
        return ((inp / 255) - 0.5) * 2.0

    def get_network(self, input_tensor, is_training):
        # Load the pre-trained Inception-ResNet-v2 model
        with slim.arg_scope(inception_resnet_v2_arg_scope(batch_norm_decay = 0.999, weight_decay = 0.0001)):
            net, end_points = inception_resnet_v2(input_tensor, is_training = is_training)

        # Modification of the original InceptionResnetV2 - replace the scoring of the auxiliary tower
        weight_decay = 0.0005
        with tf.variable_scope('NewInceptionResnetV2'):
            with tf.variable_scope('AuxiliaryScoring'):
                with slim.arg_scope([layers.convolution2d, layers.convolution2d_transpose],
                                    weights_regularizer = slim.l2_regularizer(weight_decay),
                                    biases_regularizer = slim.l2_regularizer(weight_decay),
                                    activation_fn = None):
                    tf.summary.histogram('Last_layer/activations', net, [KEY_SUMMARIES])

                    # Scoring
                    net = slim.dropout(net, 0.7, is_training = is_training, scope = 'Dropout')
                    net = layers.convolution2d(net, num_outputs = self.FEATURES, kernel_size = 1, stride = 1,
                                               scope = 'Scoring_layer')
                    feature = net
                    tf.summary.histogram('Scoring_layer/activations', net, [KEY_SUMMARIES])

                    # Upsampling
                    net = layers.convolution2d_transpose(net, num_outputs = 16, kernel_size = 17, stride = 17,
                                                         padding = 'VALID', scope = 'Upsampling_layer')

                    tf.summary.histogram('Upsampling_layer/activations', net, [KEY_SUMMARIES])

            # Smoothing layer - separable gaussian filters
            net = super()._get_gauss_smoothing_net(net, size = self.SMOOTH_SIZE, std = 1.0, kernel_sum = 0.2)

            return net, feature

    def restore(self, checkpoint_path, is_pre_trained_imagenet_checkpoint = False):
        all_vars = tf.get_collection(tf.GraphKeys.MODEL_VARIABLES, scope = 'InceptionResnetV2')
        if not is_pre_trained_imagenet_checkpoint:
            all_vars += tf.get_collection(tf.GraphKeys.MODEL_VARIABLES, scope = 'NewInceptionResnetV2/AuxiliaryScoring')

        super()._restore(checkpoint_path, all_vars)

    def save(self, checkpoint_path, name):
        all_vars = tf.get_collection(tf.GraphKeys.MODEL_VARIABLES, scope = 'InceptionResnetV2')
        all_vars += tf.get_collection(tf.GraphKeys.MODEL_VARIABLES, scope = 'NewInceptionResnetV2/AuxiliaryScoring')

        super()._save(checkpoint_path, name, all_vars)

    def create_summary_from_weights(self):
        with tf.variable_scope('NewInceptionResnetV2/AuxiliaryScoring', reuse = True):
            tf.summary.histogram('Scoring_layer/biases', tf.get_variable('Scoring_layer/biases'), [KEY_SUMMARIES])
            tf.summary.histogram('Upsampling_layer/biases', tf.get_variable('Upsampling_layer/biases'), [KEY_SUMMARIES])
            tf.summary.histogram('Scoring_layer/weights', tf.get_variable('Scoring_layer/weights'), [KEY_SUMMARIES])
            tf.summary.histogram('Upsampling_layer/weights', tf.get_variable('Upsampling_layer/weights'),
                                 [KEY_SUMMARIES])

        with tf.variable_scope('InceptionResnetV2/AuxLogits', reuse = True):
            tf.summary.histogram('Last_layer/weights', tf.get_variable('Conv2d_2a_5x5/weights'), [KEY_SUMMARIES])
            tf.summary.histogram('Last_layer/beta', tf.get_variable('Conv2d_2a_5x5/BatchNorm/beta'), [KEY_SUMMARIES])
            tf.summary.histogram('Last_layer/moving_mean', tf.get_variable('Conv2d_2a_5x5/BatchNorm/moving_mean'),
                                 [KEY_SUMMARIES])


class PartDetector(HumanPoseNN):
    """
    Architecture of the Part Detector network, as described in https://arxiv.org/abs/1609.01743
    """

    IMAGE_SIZE = 256
    HEATMAP_SIZE = 256
    POINT_DIAMETER = 11

    def __init__(self, log_name = None, init_from_checkpoint = None, loss_type = 'SCE', is_training = False):
        if init_from_checkpoint is not None:
            part_detector.init_model_variables(init_from_checkpoint, is_training)
            self.reuse = True
        else:
            self.reuse = False

        super().__init__(log_name, self.HEATMAP_SIZE, self.IMAGE_SIZE, loss_type, is_training)

    def pre_process(self, inp):
        return inp / 255

    def create_summary_from_weights(self):
        pass

    def restore(self, checkpoint_path):
        # tf.GraphKeys.VARIABLES is a deprecated alias of GLOBAL_VARIABLES in TensorFlow 1.x
        all_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope = 'HumanPoseResnet')
        all_vars += tf.get_collection(tf.GraphKeys.MODEL_VARIABLES, scope = 'NewHumanPoseResnet/Scoring')

        super()._restore(checkpoint_path, all_vars)

    def save(self, checkpoint_path, name):
        all_vars = tf.get_collection(tf.GraphKeys.MODEL_VARIABLES, scope = 'HumanPoseResnet')
        all_vars += tf.get_collection(tf.GraphKeys.MODEL_VARIABLES, scope = 'NewHumanPoseResnet/Scoring')

        super()._save(checkpoint_path, name, all_vars)

    def get_network(self, input_tensor, is_training):
        net_end, end_points = part_detector.human_pose_resnet(input_tensor, reuse = self.reuse, training = is_training)

        return net_end, end_points['features']
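

# A minimal usage sketch, not part of the original file. The checkpoint path below is hypothetical;
# estimate_joints returns an array of shape [3, batch, 16] holding y, x and the peak activation of
# each of the 16 joints.
if __name__ == '__main__':
    net = HumanPoseIRNetwork()
    net.restore('models/MPII+LSP.ckpt')  # hypothetical path to a trained checkpoint

    dummy_img = np.zeros((1, 299, 299, 3), dtype = np.float32)  # batch with a single 299x299 RGB image
    y, x, activation = net.estimate_joints(dummy_img)           # each has shape [batch, 16]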