|
a |
|
b/lungs/inflate.py |
|
|
1 |
import tensorflow as tf |
|
|
2 |
import numpy as np |
|
|
3 |
# pylint: disable=no-name-in-module |
|
|
4 |
from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file |
|
|
5 |
|
|
|
6 |
from lungs import i3d |
|
|
7 |
|
|
|
8 |
def assign(global_vars, model_path): |
|
|
9 |
reader = tf.pywrap_tensorflow.NewCheckpointReader(model_path) |
|
|
10 |
for var_3d in global_vars: |
|
|
11 |
if 'Logits' not in var_3d.op.name: |
|
|
12 |
var_2d_name = var_3d.op.name.replace('RGB/inception_i3d', 'InceptionV1') |
|
|
13 |
var_2d_name = var_2d_name.replace('Conv3d', 'Conv2d') |
|
|
14 |
var_2d_name = var_2d_name.replace('conv_3d/w', 'weights') |
|
|
15 |
var_2d_name = var_2d_name.replace('conv_3d/b', 'biases') |
|
|
16 |
var_2d_name = var_2d_name.replace('batch_norm', 'BatchNorm') |
|
|
17 |
|
|
|
18 |
var_value = reader.get_tensor(var_2d_name) |
|
|
19 |
|
|
|
20 |
if len(var_3d.get_shape()) - 1 == len(var_value.shape): |
|
|
21 |
n = var_value.shape[0] |
|
|
22 |
inflated = np.tile(np.expand_dims(var_value / n, 0), [n,1,1,1,1]) |
|
|
23 |
var_3d.assign(tf.convert_to_tensor(inflated)) |
|
|
24 |
else: |
|
|
25 |
var_3d.assign(tf.convert_to_tensor(var_value.reshape(var_3d.get_shape()))) |
|
|
26 |
|
|
|
27 |
def inflate_inception_v1_checkpoint_to_i3d(ckpt_2d, ckpt_3d, num_slices=145, volume_size=224, num_classes=2): |
|
|
28 |
''' |
|
|
29 |
Bootstrap the filters from a pre-trained two-dimensional |
|
|
30 |
Inception-v1 checkpoint into a three-dimensional I3D checkpoint. |
|
|
31 |
|
|
|
32 |
Usage example: |
|
|
33 |
inflate_inception_v1_checkpoint_to_i3d('path/to/inception_v1.ckpt', 'path/to/new/i3d_model.ckpt') |
|
|
34 |
''' |
|
|
35 |
|
|
|
36 |
rgb_input = tf.placeholder(tf.float32, |
|
|
37 |
shape=(1, num_slices, volume_size, volume_size, 3)) |
|
|
38 |
with tf.variable_scope('RGB'): |
|
|
39 |
i3d.InceptionI3d(num_classes, spatial_squeeze=True, final_endpoint='Logits')\ |
|
|
40 |
(rgb_input, is_training=False) |
|
|
41 |
|
|
|
42 |
init_op = tf.compat.v1.global_variables_initializer() |
|
|
43 |
with tf.Session() as sess: |
|
|
44 |
sess.run(init_op) |
|
|
45 |
assign(tf.global_variables(), ckpt_2d) |
|
|
46 |
tf.train.Saver().save(sess, ckpt_3d) |
|
|
47 |
|
|
|
48 |
def inspect_checkpoint(ckpt): |
|
|
49 |
print_tensors_in_checkpoint_file(ckpt, all_tensors=False, tensor_name='') |