Switch to side-by-side view

--- a
+++ b/RefineNet & SESNet/nets/convlstm_cell.py
@@ -0,0 +1,138 @@
+import tensorflow as tf
+
+
+class ConvLSTMCell(tf.nn.rnn_cell.RNNCell):
+	"""A LSTM cell with convolutions instead of multiplications.
+
+	Reference:
+	  Xingjian, S. H. I., et al. "Convolutional LSTM network: A machine learning approach for precipitation nowcasting." Advances in Neural Information Processing Systems. 2015.
+	"""
+
+	def __init__(self, shape, filters, kernel, forget_bias=1.0, activation=tf.tanh, normalize=True, peephole=True,
+	             data_format='channels_last', reuse=None):
+		super(ConvLSTMCell, self).__init__(_reuse=reuse)
+		self._kernel = kernel
+		self._filters = filters
+		self._forget_bias = forget_bias
+		self._activation = activation
+		self._normalize = normalize
+		self._peephole = peephole
+		if data_format == 'channels_last':
+			self._size = tf.TensorShape(shape + [self._filters])
+			self._feature_axis = self._size.ndims
+			self._data_format = None
+		elif data_format == 'channels_first':
+			self._size = tf.TensorShape([self._filters] + shape)
+			self._feature_axis = 0
+			self._data_format = 'NC'
+		else:
+			raise ValueError('Unknown data_format')
+
+	@property
+	def state_size(self):
+		return tf.nn.rnn_cell.LSTMStateTuple(self._size, self._size)
+
+	@property
+	def output_size(self):
+		return self._size
+
+	def call(self, x, state):
+		c, h = state
+
+		x = tf.concat([x, h], axis=self._feature_axis)
+		n = x.shape[-1].value
+		m = 4 * self._filters if self._filters > 1 else 4
+		W = tf.get_variable('kernel', self._kernel + [n, m])
+		y = tf.nn.convolution(x, W, 'SAME', data_format=self._data_format)
+		if not self._normalize:
+			y += tf.get_variable('bias', [m], initializer=tf.zeros_initializer())
+		j, i, f, o = tf.split(y, 4, axis=self._feature_axis)
+
+		if self._peephole:
+			i += tf.get_variable('W_ci', c.shape[1:]) * c
+			f += tf.get_variable('W_cf', c.shape[1:]) * c
+
+		if self._normalize:
+			j = tf.contrib.layers.layer_norm(j)
+			i = tf.contrib.layers.layer_norm(i)
+			f = tf.contrib.layers.layer_norm(f)
+
+		f = tf.sigmoid(f + self._forget_bias)
+		i = tf.sigmoid(i)
+		c = c * f + i * self._activation(j)
+
+		if self._peephole:
+			o += tf.get_variable('W_co', c.shape[1:]) * c
+
+		if self._normalize:
+			o = tf.contrib.layers.layer_norm(o)
+			c = tf.contrib.layers.layer_norm(c)
+
+		o = tf.sigmoid(o)
+		h = o * self._activation(c)
+
+		state = tf.nn.rnn_cell.LSTMStateTuple(c, h)
+
+		return h, state
+
+
+class ConvGRUCell(tf.nn.rnn_cell.RNNCell):
+	"""A GRU cell with convolutions instead of multiplications."""
+
+	def __init__(self, shape, filters, kernel, activation=tf.tanh, normalize=True, data_format='channels_last',
+	             reuse=None):
+		super(ConvGRUCell, self).__init__(_reuse=reuse)
+		self._filters = filters
+		self._kernel = kernel
+		self._activation = activation
+		self._normalize = normalize
+		if data_format == 'channels_last':
+			self._size = tf.TensorShape(shape + [self._filters])
+			self._feature_axis = self._size.ndims
+			self._data_format = None
+		elif data_format == 'channels_first':
+			self._size = tf.TensorShape([self._filters] + shape)
+			self._feature_axis = 0
+			self._data_format = 'NC'
+		else:
+			raise ValueError('Unknown data_format')
+
+	@property
+	def state_size(self):
+		return self._size
+
+	@property
+	def output_size(self):
+		return self._size
+
+	def call(self, x, h):
+		channels = x.shape[self._feature_axis].value
+
+		with tf.variable_scope('gates'):
+			inputs = tf.concat([x, h], axis=self._feature_axis)
+			n = channels + self._filters
+			m = 2 * self._filters if self._filters > 1 else 2
+			W = tf.get_variable('kernel', self._kernel + [n, m])
+			y = tf.nn.convolution(inputs, W, 'SAME', data_format=self._data_format)
+			if self._normalize:
+				r, u = tf.split(y, 2, axis=self._feature_axis)
+				r = tf.contrib.layers.layer_norm(r)
+				u = tf.contrib.layers.layer_norm(u)
+			else:
+				y += tf.get_variable('bias', [m], initializer=tf.ones_initializer())
+				r, u = tf.split(y, 2, axis=self._feature_axis)
+			r, u = tf.sigmoid(r), tf.sigmoid(u)
+
+		with tf.variable_scope('candidate'):
+			inputs = tf.concat([x, r * h], axis=self._feature_axis)
+			n = channels + self._filters
+			m = self._filters
+			W = tf.get_variable('kernel', self._kernel + [n, m])
+			y = tf.nn.convolution(inputs, W, 'SAME', data_format=self._data_format)
+			if self._normalize:
+				y = tf.contrib.layers.layer_norm(y)
+			else:
+				y += tf.get_variable('bias', [m], initializer=tf.zeros_initializer())
+			h = u * h + (1 - u) * self._activation(y)
+
+		return h, h