Diff of /lungs/inflate.py [000000] .. [eac570]

Switch to unified view

a b/lungs/inflate.py
1
import tensorflow as tf
2
import numpy as np
3
# pylint: disable=no-name-in-module
4
from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file
5
6
from lungs import i3d
7
8
def assign(global_vars, model_path):
9
    reader = tf.pywrap_tensorflow.NewCheckpointReader(model_path)
10
    for var_3d in global_vars:
11
        if 'Logits' not in var_3d.op.name:
12
            var_2d_name = var_3d.op.name.replace('RGB/inception_i3d', 'InceptionV1')
13
            var_2d_name = var_2d_name.replace('Conv3d', 'Conv2d')
14
            var_2d_name = var_2d_name.replace('conv_3d/w', 'weights')
15
            var_2d_name = var_2d_name.replace('conv_3d/b', 'biases')
16
            var_2d_name = var_2d_name.replace('batch_norm', 'BatchNorm')
17
            
18
            var_value = reader.get_tensor(var_2d_name)
19
20
            if len(var_3d.get_shape()) - 1 == len(var_value.shape):
21
                n = var_value.shape[0]
22
                inflated = np.tile(np.expand_dims(var_value / n, 0), [n,1,1,1,1])
23
                var_3d.assign(tf.convert_to_tensor(inflated))
24
            else:
25
                var_3d.assign(tf.convert_to_tensor(var_value.reshape(var_3d.get_shape())))
26
27
def inflate_inception_v1_checkpoint_to_i3d(ckpt_2d, ckpt_3d, num_slices=145, volume_size=224, num_classes=2):
28
    '''
29
    Bootstrap the filters from a pre-trained two-dimensional 
30
    Inception-v1 checkpoint into a three-dimensional I3D checkpoint.
31
    
32
    Usage example:
33
    inflate_inception_v1_checkpoint_to_i3d('path/to/inception_v1.ckpt', 'path/to/new/i3d_model.ckpt')
34
    '''
35
36
    rgb_input = tf.placeholder(tf.float32,
37
        shape=(1, num_slices, volume_size, volume_size, 3))
38
    with tf.variable_scope('RGB'):
39
        i3d.InceptionI3d(num_classes, spatial_squeeze=True, final_endpoint='Logits')\
40
                        (rgb_input, is_training=False)
41
    
42
    init_op = tf.compat.v1.global_variables_initializer()
43
    with tf.Session() as sess:
44
        sess.run(init_op)
45
        assign(tf.global_variables(), ckpt_2d)
46
        tf.train.Saver().save(sess, ckpt_3d)
47
48
def inspect_checkpoint(ckpt):
49
    print_tensors_in_checkpoint_file(ckpt, all_tensors=False, tensor_name='')