--- a +++ b/SegNet/SegNetCMR/layers.py @@ -0,0 +1,49 @@ +import numpy as np + +import tensorflow as tf + +from tensorflow.python.ops import gen_nn_ops +from tensorflow.python.framework import ops + + +def unpool_with_argmax(pool, ind, name = None, ksize=[1, 2, 2, 1]): + + """ + Unpooling layer after max_pool_with_argmax. + Args: + pool: max pooled output tensor + ind: argmax indices + ksize: ksize is the same as for the pool + Return: + unpool: unpooling tensor + """ + with tf.variable_scope(name): + input_shape = pool.get_shape().as_list() + output_shape = (input_shape[0], input_shape[1] * ksize[1], input_shape[2] * ksize[2], input_shape[3]) + + flat_input_size = np.prod(input_shape) + flat_output_shape = [output_shape[0], output_shape[1] * output_shape[2] * output_shape[3]] + + pool_ = tf.reshape(pool, [flat_input_size]) + batch_range = tf.reshape(tf.range(output_shape[0], dtype=ind.dtype), shape=[input_shape[0], 1, 1, 1]) + b = tf.ones_like(ind) * batch_range + b = tf.reshape(b, [flat_input_size, 1]) + ind_ = tf.reshape(ind, [flat_input_size, 1]) + ind_ = tf.concat([b, ind_], 1) #交换了两个参数--remove!! + + ret = tf.scatter_nd(ind_, pool_, shape=flat_output_shape) + ret = tf.reshape(ret, output_shape) + return ret + +try: + @ops.RegisterGradient("MaxPoolWithArgmax") + def _MaxPoolGradWithArgmax(op, grad, unused_argmax_grad): + return gen_nn_ops._max_pool_grad_with_argmax(op.inputs[0], + grad, + op.outputs[1], + op.get_attr("ksize"), + op.get_attr("strides"), + padding=op.get_attr("padding")) +except Exception as e: + print(f"Could not add gradient for MaxPoolWithArgMax, Likely installed already (tf 1.4)") + print(e)