[bc8010]: / SegNet / SegNetCMR / layers.py

Download this file

50 lines (40 with data), 2.0 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import numpy as np
import tensorflow as tf
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.framework import ops
def unpool_with_argmax(pool, ind, name=None, ksize=[1, 2, 2, 1]):
    """Unpooling layer to be used after ``max_pool_with_argmax``.

    Every value of ``pool`` is scattered into a larger tensor at the flat
    position stored in ``ind``; positions that receive no value are left at
    the ``tf.scatter_nd`` fill value.

    Args:
        pool: max pooled output tensor. It is indexed as a rank-4 tensor and
            its static shape must be fully known, because the sizes below are
            computed in Python/NumPy while the graph is being built.
        ind: argmax indices with as many elements as ``pool``. They are used
            as offsets into one flattened batch element; the batch index is
            prepended here. (Assumes the indices do NOT already include the
            batch offset -- TODO confirm against the pooling call.)
        name: optional scope name. When None, a unique scope derived from
            "unpool_with_argmax" is used.
        ksize: ksize is the same as for the pool. Only ``ksize[1]`` and
            ``ksize[2]`` are read; the default list is never mutated.

    Returns:
        unpool: unpooling tensor of shape
            ``(N, H * ksize[1], W * ksize[2], C)`` where ``(N, H, W, C)`` is
            the static shape of ``pool``.
    """
    # `default_name` makes the advertised `name=None` default usable:
    # tf.variable_scope(None) without a default_name raises.
    with tf.variable_scope(name, default_name="unpool_with_argmax"):
        input_shape = pool.get_shape().as_list()
        batch_size = input_shape[0]
        output_shape = (
            batch_size,
            input_shape[1] * ksize[1],
            input_shape[2] * ksize[2],
            input_shape[3],
        )

        flat_input_size = np.prod(input_shape)
        flat_output_shape = [
            batch_size,
            output_shape[1] * output_shape[2] * output_shape[3],
        ]

        # Flatten the pooled values into a 1-D list of updates.
        pool_ = tf.reshape(pool, [flat_input_size])

        # Build, for every pooled value, the index of the batch element it
        # belongs to (broadcast of [N, 1, 1, 1] against the shape of `ind`).
        batch_range = tf.reshape(
            tf.range(batch_size, dtype=ind.dtype),
            shape=[batch_size, 1, 1, 1],
        )
        b = tf.ones_like(ind) * batch_range
        b = tf.reshape(b, [flat_input_size, 1])

        # Pair each batch index with its flat argmax offset: rows of
        # (batch, offset) addressing the [N, H*W*C] scatter target.
        ind_ = tf.reshape(ind, [flat_input_size, 1])
        # NOTE: the two arguments were swapped here (translated from the
        # original author's comment, which also said "remove!!").
        ind_ = tf.concat([b, ind_], 1)

        ret = tf.scatter_nd(ind_, pool_, shape=flat_output_shape)
        ret = tf.reshape(ret, output_shape)
        return ret
# Best-effort registration of a gradient for the "MaxPoolWithArgmax" op.
# Any failure is reported on stdout and otherwise ignored; the printed
# message suggests the usual cause is that the gradient is already installed.
try:
    @ops.RegisterGradient("MaxPoolWithArgmax")
    def _MaxPoolGradWithArgmax(op, grad, unused_argmax_grad):
        """Gradient function registered for the MaxPoolWithArgmax op.

        Forwards the op's first input, the incoming gradient ``grad`` and the
        op's second output to ``gen_nn_ops._max_pool_grad_with_argmax``,
        reusing the "ksize", "strides" and "padding" attributes of the
        forward op. ``unused_argmax_grad`` is ignored.
        """
        forward_input = op.inputs[0]
        # Second output of the forward op (presumably the argmax indices,
        # going by the op name).
        second_output = op.outputs[1]
        window = op.get_attr("ksize")
        step = op.get_attr("strides")
        pad_mode = op.get_attr("padding")
        return gen_nn_ops._max_pool_grad_with_argmax(
            forward_input,
            grad,
            second_output,
            window,
            step,
            padding=pad_mode,
        )
except Exception as err:
    print(f"Could not add gradient for MaxPoolWithArgMax, Likely installed already (tf 1.4)")
    print(err)