|
a |
|
b/SegNet/SegNetCMR/layers.py |
|
|
1 |
import numpy as np |
|
|
2 |
|
|
|
3 |
import tensorflow as tf |
|
|
4 |
|
|
|
5 |
from tensorflow.python.ops import gen_nn_ops |
|
|
6 |
from tensorflow.python.framework import ops |
|
|
7 |
|
|
|
8 |
|
|
|
9 |
def unpool_with_argmax(pool, ind, name = None, ksize=[1, 2, 2, 1]): |
|
|
10 |
|
|
|
11 |
""" |
|
|
12 |
Unpooling layer after max_pool_with_argmax. |
|
|
13 |
Args: |
|
|
14 |
pool: max pooled output tensor |
|
|
15 |
ind: argmax indices |
|
|
16 |
ksize: ksize is the same as for the pool |
|
|
17 |
Return: |
|
|
18 |
unpool: unpooling tensor |
|
|
19 |
""" |
|
|
20 |
with tf.variable_scope(name): |
|
|
21 |
input_shape = pool.get_shape().as_list() |
|
|
22 |
output_shape = (input_shape[0], input_shape[1] * ksize[1], input_shape[2] * ksize[2], input_shape[3]) |
|
|
23 |
|
|
|
24 |
flat_input_size = np.prod(input_shape) |
|
|
25 |
flat_output_shape = [output_shape[0], output_shape[1] * output_shape[2] * output_shape[3]] |
|
|
26 |
|
|
|
27 |
pool_ = tf.reshape(pool, [flat_input_size]) |
|
|
28 |
batch_range = tf.reshape(tf.range(output_shape[0], dtype=ind.dtype), shape=[input_shape[0], 1, 1, 1]) |
|
|
29 |
b = tf.ones_like(ind) * batch_range |
|
|
30 |
b = tf.reshape(b, [flat_input_size, 1]) |
|
|
31 |
ind_ = tf.reshape(ind, [flat_input_size, 1]) |
|
|
32 |
ind_ = tf.concat([b, ind_], 1) #交换了两个参数--remove!! |
|
|
33 |
|
|
|
34 |
ret = tf.scatter_nd(ind_, pool_, shape=flat_output_shape) |
|
|
35 |
ret = tf.reshape(ret, output_shape) |
|
|
36 |
return ret |
|
|
37 |
|
|
|
38 |
try: |
|
|
39 |
@ops.RegisterGradient("MaxPoolWithArgmax") |
|
|
40 |
def _MaxPoolGradWithArgmax(op, grad, unused_argmax_grad): |
|
|
41 |
return gen_nn_ops._max_pool_grad_with_argmax(op.inputs[0], |
|
|
42 |
grad, |
|
|
43 |
op.outputs[1], |
|
|
44 |
op.get_attr("ksize"), |
|
|
45 |
op.get_attr("strides"), |
|
|
46 |
padding=op.get_attr("padding")) |
|
|
47 |
except Exception as e: |
|
|
48 |
print(f"Could not add gradient for MaxPoolWithArgMax, Likely installed already (tf 1.4)") |
|
|
49 |
print(e) |