Switch to unified view

a b/SegNet/SegNetCMR/layers.py
1
import numpy as np
2
3
import tensorflow as tf
4
5
from tensorflow.python.ops import gen_nn_ops
6
from tensorflow.python.framework import ops
7
8
9
def unpool_with_argmax(pool, ind, name = None, ksize=[1, 2, 2, 1]):
10
11
    """
12
       Unpooling layer after max_pool_with_argmax.
13
       Args:
14
           pool:   max pooled output tensor
15
           ind:      argmax indices
16
           ksize:     ksize is the same as for the pool
17
       Return:
18
           unpool:    unpooling tensor
19
    """
20
    with tf.variable_scope(name):
21
        input_shape = pool.get_shape().as_list()
22
        output_shape = (input_shape[0], input_shape[1] * ksize[1], input_shape[2] * ksize[2], input_shape[3])
23
24
        flat_input_size = np.prod(input_shape)
25
        flat_output_shape = [output_shape[0], output_shape[1] * output_shape[2] * output_shape[3]]
26
27
        pool_ = tf.reshape(pool, [flat_input_size])
28
        batch_range = tf.reshape(tf.range(output_shape[0], dtype=ind.dtype), shape=[input_shape[0], 1, 1, 1])
29
        b = tf.ones_like(ind) * batch_range
30
        b = tf.reshape(b, [flat_input_size, 1])
31
        ind_ = tf.reshape(ind, [flat_input_size, 1])
32
        ind_ = tf.concat([b, ind_], 1) #交换了两个参数--remove!!
33
34
        ret = tf.scatter_nd(ind_, pool_, shape=flat_output_shape)
35
        ret = tf.reshape(ret, output_shape)
36
        return ret
37
38
try:
39
    @ops.RegisterGradient("MaxPoolWithArgmax")
40
    def _MaxPoolGradWithArgmax(op, grad, unused_argmax_grad):
41
        return gen_nn_ops._max_pool_grad_with_argmax(op.inputs[0],
42
                                                     grad,
43
                                                     op.outputs[1],
44
                                                     op.get_attr("ksize"),
45
                                                     op.get_attr("strides"),
46
                                                     padding=op.get_attr("padding"))
47
except Exception as e:
48
    print(f"Could not add gradient for MaxPoolWithArgMax, Likely installed already (tf 1.4)")
49
    print(e)