Switch to side-by-side view

--- a
+++ b/SegNet/SegNetCMR/layers.py
@@ -0,0 +1,49 @@
+import numpy as np
+
+import tensorflow as tf
+
+from tensorflow.python.ops import gen_nn_ops
+from tensorflow.python.framework import ops
+
+
+def unpool_with_argmax(pool, ind, name = None, ksize=[1, 2, 2, 1]):
+
+    """
+       Unpooling layer after max_pool_with_argmax.
+       Args:
+           pool:   max pooled output tensor
+           ind:      argmax indices
+           ksize:     ksize is the same as for the pool
+       Return:
+           unpool:    unpooling tensor
+    """
+    with tf.variable_scope(name):
+        input_shape = pool.get_shape().as_list()
+        output_shape = (input_shape[0], input_shape[1] * ksize[1], input_shape[2] * ksize[2], input_shape[3])
+
+        flat_input_size = np.prod(input_shape)
+        flat_output_shape = [output_shape[0], output_shape[1] * output_shape[2] * output_shape[3]]
+
+        pool_ = tf.reshape(pool, [flat_input_size])
+        batch_range = tf.reshape(tf.range(output_shape[0], dtype=ind.dtype), shape=[input_shape[0], 1, 1, 1])
+        b = tf.ones_like(ind) * batch_range
+        b = tf.reshape(b, [flat_input_size, 1])
+        ind_ = tf.reshape(ind, [flat_input_size, 1])
+        ind_ = tf.concat([b, ind_], 1) #交换了两个参数--remove!!
+
+        ret = tf.scatter_nd(ind_, pool_, shape=flat_output_shape)
+        ret = tf.reshape(ret, output_shape)
+        return ret
+
+try:
+    @ops.RegisterGradient("MaxPoolWithArgmax")
+    def _MaxPoolGradWithArgmax(op, grad, unused_argmax_grad):
+        return gen_nn_ops._max_pool_grad_with_argmax(op.inputs[0],
+                                                     grad,
+                                                     op.outputs[1],
+                                                     op.get_attr("ksize"),
+                                                     op.get_attr("strides"),
+                                                     padding=op.get_attr("padding"))
+except Exception as e:
+    print(f"Could not add gradient for MaxPoolWithArgMax, Likely installed already (tf 1.4)")
+    print(e)