Diff of /lungs/inflate.py [000000] .. [eac570]

Switch to side-by-side view

--- a
+++ b/lungs/inflate.py
@@ -0,0 +1,49 @@
+import tensorflow as tf
+import numpy as np
+# pylint: disable=no-name-in-module
+from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file
+
+from lungs import i3d
+
+def assign(global_vars, model_path):
+    reader = tf.pywrap_tensorflow.NewCheckpointReader(model_path)
+    for var_3d in global_vars:
+        if 'Logits' not in var_3d.op.name:
+            var_2d_name = var_3d.op.name.replace('RGB/inception_i3d', 'InceptionV1')
+            var_2d_name = var_2d_name.replace('Conv3d', 'Conv2d')
+            var_2d_name = var_2d_name.replace('conv_3d/w', 'weights')
+            var_2d_name = var_2d_name.replace('conv_3d/b', 'biases')
+            var_2d_name = var_2d_name.replace('batch_norm', 'BatchNorm')
+            
+            var_value = reader.get_tensor(var_2d_name)
+
+            if len(var_3d.get_shape()) - 1 == len(var_value.shape):
+                n = var_value.shape[0]
+                inflated = np.tile(np.expand_dims(var_value / n, 0), [n,1,1,1,1])
+                var_3d.assign(tf.convert_to_tensor(inflated))
+            else:
+                var_3d.assign(tf.convert_to_tensor(var_value.reshape(var_3d.get_shape())))
+
+def inflate_inception_v1_checkpoint_to_i3d(ckpt_2d, ckpt_3d, num_slices=145, volume_size=224, num_classes=2):
+    '''
+    Bootstrap the filters from a pre-trained two-dimensional 
+    Inception-v1 checkpoint into a three-dimensional I3D checkpoint.
+    
+    Usage example:
+    inflate_inception_v1_checkpoint_to_i3d('path/to/inception_v1.ckpt', 'path/to/new/i3d_model.ckpt')
+    '''
+
+    rgb_input = tf.placeholder(tf.float32,
+        shape=(1, num_slices, volume_size, volume_size, 3))
+    with tf.variable_scope('RGB'):
+        i3d.InceptionI3d(num_classes, spatial_squeeze=True, final_endpoint='Logits')\
+                        (rgb_input, is_training=False)
+    
+    init_op = tf.compat.v1.global_variables_initializer()
+    with tf.Session() as sess:
+        sess.run(init_op)
+        assign(tf.global_variables(), ckpt_2d)
+        tf.train.Saver().save(sess, ckpt_3d)
+
+def inspect_checkpoint(ckpt):
+    print_tensors_in_checkpoint_file(ckpt, all_tensors=False, tensor_name='')