--- /dev/null
+++ b/tf_layers.py
@@ -0,0 +1,280 @@
+import tensorflow as tf
+from tensorflow.contrib.keras.python.keras.layers import *
+
+
def Conv3DWithBN(x, filters, ksize, strides, name, padding='same', dilation_rate=1, center=True, scale=True, decay=0.99):
    """3D convolution (no bias) -> batch norm -> ReLU.

    Layers are named '<name>_conv', '<name>_bn' and '<name>_relu'.

    :param x: input tensor
    :param filters: number of output channels
    :param ksize: kernel size passed to Conv3D
    :param strides: convolution strides
    :param name: prefix for the three layer names
    :param decay: batch-norm momentum
    :return: activated tensor
    """
    conv = Conv3D(filters=filters,
                  kernel_size=ksize,
                  strides=strides,
                  padding=padding,
                  dilation_rate=dilation_rate,
                  use_bias=False,  # bias is redundant before batch norm
                  kernel_initializer='he_normal',
                  name=name + '_conv')
    out = conv(x)
    out = BatchNormalization(center=center, scale=scale, momentum=decay,
                             name=name + '_bn')(out)
    return Activation('relu', name=name + '_relu')(out)
+
+
def Conv2DWithBN(x, filters, ksize, strides, name, padding='same', dilation_rate=1, center=True, scale=True, decay=0.99):
    """2D convolution (no bias) -> batch norm -> ReLU.

    Layers are named '<name>_conv', '<name>_bn' and '<name>_relu'.

    :param x: input tensor
    :param filters: number of output channels
    :param ksize: kernel size passed to Conv2D
    :param strides: convolution strides
    :param name: prefix for the three layer names
    :param decay: batch-norm momentum
    :return: activated tensor
    """
    out = Conv2D(filters=filters,
                 kernel_size=ksize,
                 strides=strides,
                 padding=padding,
                 dilation_rate=dilation_rate,
                 use_bias=False,  # bias is redundant before batch norm
                 kernel_initializer='he_normal',
                 name=name + '_conv')(x)
    out = BatchNormalization(center=center, scale=scale, momentum=decay,
                             name=name + '_bn')(out)
    return Activation('relu', name=name + '_relu')(out)
+
+
def Conv1DWithBN(x, filters, ksize, strides, name, padding='same', dilation_rate=1, center=True, scale=True, decay=0.99):
    """1D convolution (no bias) -> batch norm -> ReLU.

    Layers are named '<name>_conv', '<name>_bn' and '<name>_relu'.

    :param x: input tensor
    :param filters: number of output channels
    :param ksize: kernel size passed to Conv1D
    :param strides: convolution strides
    :param name: prefix for the three layer names
    :param decay: batch-norm momentum
    :return: activated tensor
    """
    out = Conv1D(filters=filters,
                 kernel_size=ksize,
                 strides=strides,
                 padding=padding,
                 dilation_rate=dilation_rate,
                 use_bias=False,  # bias is redundant before batch norm
                 kernel_initializer='he_normal',
                 name=name + '_conv')(x)
    out = BatchNormalization(center=center, scale=scale, momentum=decay,
                             name=name + '_bn')(out)
    return Activation('relu', name=name + '_relu')(out)
+
+
def DenseWithBN(x, units, name, kernel_regularizer=None, center=True, scale=True, decay=0.99):
    """Fully-connected layer (no bias) -> batch norm -> ReLU.

    NOTE(review): the BN layer is named '<name>_bias' (it effectively supplies
    the bias term in place of the Dense bias); the name is kept exactly so
    existing checkpoint variable names still match.

    :param x: input tensor
    :param units: output dimensionality of the Dense layer
    :param name: prefix for the layer names
    :param kernel_regularizer: optional regularizer for the Dense kernel
    :param decay: batch-norm momentum
    :return: activated tensor
    """
    out = Dense(units=units,
                use_bias=False,
                kernel_regularizer=kernel_regularizer,
                name=name + '_weight')(x)
    out = BatchNormalization(center=center, scale=scale, momentum=decay,
                             name=name + '_bias')(out)
    return Activation('relu', name=name + '_relu')(out)
+
+
def ResNetUnit2D(x, filters, ksize, name, end=False, center=True, scale=True, decay=0.99):
    """Pre-activation 2D ResNet unit: (BN -> ReLU -> Conv) x2 + identity add.

    Consistency fixes versus the previous revision (aligning with ResNetUnit1D):
    - second conv renamed '<name>_conv_2' (was the inconsistent '<name>_conv2')
    - both convs use use_bias=False (bias is redundant before the BN that follows)
    - the optional trailing BN/ReLU are named '<name>_bn_3'/'<name>_relu_3'

    :param x: input tensor; must already have `filters` channels (identity add)
    :param filters: number of channels for both convolutions
    :param ksize: convolution kernel size
    :param name: prefix for layer names
    :param end: when True, append a final named BN + ReLU after the addition
    :param decay: batch-norm momentum
    :return: output tensor of the residual unit
    """
    identity = x

    x = BatchNormalization(center=center, scale=scale, momentum=decay, name=name + '_bn_1')(x)
    x = Activation('relu', name=name + '_relu_1')(x)
    x = Conv2D(filters, kernel_size=ksize, strides=1, padding='same',
               kernel_initializer='he_normal', use_bias=False, name=name + '_conv_1')(x)

    x = BatchNormalization(center=center, scale=scale, momentum=decay, name=name + '_bn_2')(x)
    x = Activation('relu', name=name + '_relu_2')(x)
    x = Conv2D(filters, kernel_size=ksize, strides=1, padding='same',
               kernel_initializer='he_normal', use_bias=False, name=name + '_conv_2')(x)

    x = add([x, identity])
    if end:
        x = BatchNormalization(center=center, scale=scale, momentum=decay, name=name + '_bn_3')(x)
        x = Activation('relu', name=name + '_relu_3')(x)
    return x
+
+
def ResNetUnitIncreasingDims2D(x, filters, ksize, strides, name, begin=False, center=True, scale=True, decay=0.99):
    """Pre-activation 2D ResNet unit that changes the channel count/stride.

    The shortcut is projected with a 1x1 conv so it matches the main path.

    NOTE(review): the projection uses only strides[0]; if strides[1] != 1 the
    two paths will have different spatial sizes — confirm callers only pass a
    non-unit stride in the first position.

    :param x: input tensor
    :param filters: output channels for both convs and the projection
    :param ksize: kernel size of the two main-path convs
    :param strides: two-element list, one stride per conv layer
    :param begin: skip the leading BN/ReLU (e.g. right after the stem)
    :param decay: batch-norm momentum
    :return: output tensor of the residual unit
    """
    shortcut = x

    if not begin:
        x = BatchNormalization(center=center, scale=scale, momentum=decay,
                               name=name + '_bn_1')(x)
        x = Activation('relu', name=name + '_relu_1')(x)
    x = Conv2D(filters, kernel_size=ksize, strides=strides[0], padding='same',
               kernel_initializer='he_normal', use_bias=False,
               name=name + '_conv_1')(x)

    x = BatchNormalization(center=center, scale=scale, momentum=decay,
                           name=name + '_bn_2')(x)
    x = Activation('relu', name=name + '_relu_2')(x)
    x = Conv2D(filters, kernel_size=ksize, strides=strides[1], padding='same',
               kernel_initializer='he_normal', use_bias=False,
               name=name + '_conv_2')(x)

    # 1x1 projection so the shortcut matches the new channel count / stride.
    shortcut = Conv2D(filters, kernel_size=1, strides=strides[0], padding='same',
                      kernel_initializer='he_normal',
                      name=name + '_conv_identity')(shortcut)
    return add([x, shortcut])
+
+
def ResNetUnit1D(x, filters, ksize, name, end=False, center=True, scale=True, decay=0.99):
    """Pre-activation 1D ResNet unit: (BN -> ReLU -> Conv) x2 + identity add.

    :param x: input tensor; must already have `filters` channels (identity add)
    :param filters: number of channels for both convolutions
    :param ksize: convolution kernel size
    :param name: prefix for layer names
    :param end: when True, append a final BN + ReLU after the addition
    :param decay: batch-norm momentum
    :return: output tensor of the residual unit
    """
    shortcut = x

    for step in ('1', '2'):
        x = BatchNormalization(center=center, scale=scale, momentum=decay,
                               name=name + '_bn_' + step)(x)
        x = Activation('relu', name=name + '_relu_' + step)(x)
        x = Conv1D(filters, kernel_size=ksize, strides=1, padding='same',
                   kernel_initializer='he_normal', use_bias=False,
                   name=name + '_conv_' + step)(x)

    x = add([x, shortcut])
    if end:
        x = BatchNormalization(center=center, scale=scale, momentum=decay,
                               name=name + '_bn_3')(x)
        x = Activation('relu', name=name + '_relu_3')(x)
    return x
+
+
def ResNetUnitIncreasingDims1D(x, filters, ksize, strides, name, begin=False, center=True, scale=True, decay=0.99):
    """Pre-activation 1D ResNet unit (no bottleneck, 2 conv layers) that
    changes the channel count and/or stride; the shortcut is projected with a
    kernel-size-1 conv to match.

    NOTE(review): the projection uses only strides[0]; if strides[1] != 1 the
    two paths end up with different lengths — confirm callers only pass a
    non-unit stride in the first position.

    :param x: input tensor
    :param filters: output channels for both convs and the projection
    :param ksize: kernel size of the two main-path convs
    :param strides: two-element list, one stride per conv layer
    :param begin: skip the leading BN/ReLU (e.g. right after the stem)
    :param decay: batch-norm momentum
    :return: output tensor of the residual unit
    """
    shortcut = x

    if not begin:
        x = BatchNormalization(center=center, scale=scale, momentum=decay,
                               name=name + '_bn_1')(x)
        x = Activation('relu', name=name + '_relu_1')(x)
    x = Conv1D(filters, kernel_size=ksize, strides=strides[0], padding='same',
               kernel_initializer='he_normal', use_bias=False,
               name=name + '_conv_1')(x)

    x = BatchNormalization(center=center, scale=scale, momentum=decay,
                           name=name + '_bn_2')(x)
    x = Activation('relu', name=name + '_relu_2')(x)
    x = Conv1D(filters, kernel_size=ksize, strides=strides[1], padding='same',
               kernel_initializer='he_normal', use_bias=False,
               name=name + '_conv_2')(x)

    # Size-1 projection so the shortcut matches the new channels / stride.
    shortcut = Conv1D(filters, kernel_size=1, strides=strides[0], padding='same',
                      kernel_initializer='he_normal',
                      name=name + '_conv_identity')(shortcut)
    return add([x, shortcut])
+
+
def ContextualAvgPooling(x, ksizes, strides):
    '''
    Concatenate the input with its average-pooled versions at several window sizes.

    Bug fix: the AvgPool1D layer was constructed but never applied to x, so the
    layer OBJECT (not a tensor) was fed to concatenate. The layer is now called
    on x. Also returns x unchanged for an empty `ksizes` instead of crashing on
    concatenate([x, None]).

    :param x: 3D tensor (batch, steps, channels)
    :param ksizes: iterable of pooling window sizes, one pooled branch each
    :param strides: stride shared by every pooling layer ('same' padding)
    :return: channel-wise concatenation of x and all pooled branches
    '''
    pooled = [AvgPool1D(pool_size=ks, strides=strides, padding='same')(x)
              for ks in ksizes]
    if not pooled:
        # Nothing to concatenate; keep the input as-is.
        return x
    return concatenate([x] + pooled, axis=-1)
+
+
def ContextualAtrousConv1D(x, filters, ksize, strides, dilation_rates, name, post_filters=256):
    """
    Retrieve contextual information with parallel atrous (dilated) 1D convolutions.

    The input itself is included in the concatenation, then a final conv fuses
    all branches. The previously hard-coded post-conv width (256) is now the
    `post_filters` keyword; its default keeps existing callers unchanged.

    :param x: input tensor
    :param filters: channels of each dilated branch
    :param ksize: kernel size for all convs
    :param strides: stride for all convs
    :param dilation_rates: iterable of dilation rates, one branch each
    :param name: prefix for layer names
    :param post_filters: channels of the fusing conv (default 256, as before)
    :return: fused context tensor
    """
    concat = x
    for dr in dilation_rates:
        branch = Conv1DWithBN(x, filters=filters, ksize=ksize, strides=strides,
                              dilation_rate=dr, name=name + 'a_conv_s' + str(dr))
        concat = concatenate([concat, branch], axis=-1)
    # Fuse the input + all dilated branches back to a fixed width.
    concat = Conv1DWithBN(concat, filters=post_filters, ksize=ksize,
                          strides=strides, name=name + 'atrou_post_conv1')
    return concat
+
+
def ContextualAtrousConv3D(x, filters, ksize, strides, dilation_rates, name):
    """
    Retrieve contextual information with parallel atrous (dilated) 3D convolutions.

    NOTE(review): unlike the 1D variant, the input x is NOT part of the
    concatenation here — only the dilated branches are fused — and the fusing
    conv uses `filters` rather than a fixed 256. Confirm this asymmetry is
    intentional.

    :param x: input tensor
    :param filters: channels of each dilated branch and of the fusing conv
    :param ksize: kernel size for all convs
    :param strides: stride for all convs
    :param dilation_rates: iterable of dilation rates, one branch each
    :param name: prefix for layer names
    :return: fused context tensor
    """
    branches = [Conv3DWithBN(x, filters=filters, ksize=ksize, strides=strides,
                             dilation_rate=dr, name=name + 'a_conv_s' + str(dr))
                for dr in dilation_rates]
    merged = branches[0] if len(branches) == 1 else concatenate(branches, axis=-1)
    return Conv3DWithBN(merged, filters=filters, ksize=ksize, strides=strides,
                        name=name + 'atrou_post_conv1')
+
+
def SharedAtrousConv1D(x, SharedConvs, PostConv):
    """Apply pre-built (weight-sharing) conv layers to x and fuse the results.

    Each shared conv is followed by an unnamed BN + ReLU; the input itself is
    part of the concatenation, and PostConv (also pre-built) fuses everything,
    again followed by BN + ReLU.

    :param x: input tensor
    :param SharedConvs: iterable of already-constructed Conv1D layers
    :param PostConv: already-constructed conv applied to the concatenation
    :return: fused tensor
    """
    def _bn_relu(t):
        # Fresh (auto-named) BN + ReLU after every conv application.
        t = BatchNormalization()(t)
        return Activation('relu')(t)

    fused = x
    for shared_conv in SharedConvs:
        fused = concatenate([fused, _bn_relu(shared_conv(x))], axis=-1)
    return _bn_relu(PostConv(fused))
+
+
def densenet_block3d(x, k, rep):
    """DenseNet-style 3D block: `rep` convs, each fed all previous feature maps.

    :param x: input tensor
    :param k: growth rate — channels added by each conv
    :param rep: number of conv layers in the block
    :return: concatenation of the input and all conv outputs
    """
    features = x
    for _ in range(rep):
        new_features = Conv3D(filters=k, kernel_size=3, strides=1,
                              padding='same', activation='relu')(features)
        # Dense connectivity: stack new features onto everything so far.
        features = concatenate([features, new_features])
    return features
+
+
def DenseNetTransit(x, rate=1, name=None):
    """DenseNet transition: optional 1x1x1 compression conv, then 2x average pool.

    Bug fix: `channels * rate` is a float for the typical fractional
    compression rate (e.g. 0.5) and Conv3D rejects non-integer filter counts —
    the product is now truncated with int(). Integer rates are unaffected.

    :param x: input tensor
    :param rate: channel compression/expansion factor; rate == 1 skips the conv
        (and `name` may then stay None)
    :param name: prefix for layer names; required when rate != 1
    :return: downsampled tensor
    """
    if rate != 1:
        out_features = int(x.get_shape().as_list()[-1] * rate)
        x = BatchNormalization(center=True, scale=True, name=name + '_bn')(x)
        x = Activation('relu', name=name + '_relu')(x)
        x = Conv3D(filters=out_features, kernel_size=1, strides=1, padding='same',
                   kernel_initializer='he_normal', use_bias=False,
                   name=name + '_conv')(x)
    x = AveragePooling3D(pool_size=2, strides=2, padding='same')(x)
    return x
+
+
def DenseNetUnit3D(x, growth_rate, ksize, rep, bn_decay=0.99, name=None):
    """Pre-activation DenseNet block: `rep` rounds of BN -> ReLU -> Conv3D,
    each output concatenated onto all previous feature maps.

    NOTE(review): the conv uses 'glorot_normal' while most layers in this file
    use 'he_normal' — confirm the asymmetry is intentional.

    :param x: input tensor
    :param growth_rate: channels added per round
    :param ksize: conv kernel size
    :param rep: number of rounds
    :param bn_decay: batch-norm momentum
    :param name: prefix for the named BN/conv layers
    :return: concatenation of the input and all round outputs
    """
    for step in range(rep):
        skip = x
        x = BatchNormalization(center=True, scale=True, momentum=bn_decay,
                               name=name + '_bn_rep_' + str(step))(x)
        x = Activation('relu')(x)
        x = Conv3D(filters=growth_rate, kernel_size=ksize, padding='same',
                   kernel_initializer='glorot_normal', use_bias=False,
                   name=name + '_conv_rep_' + str(step))(x)
        x = concatenate([skip, x])
    return x
+
+
class BilinearUpsampling3D(Layer):
    """
    Bilinearly upsample axis 1 of a rank-3 tensor (batch, dim, channels) to a
    fixed length, by inserting a dummy width axis, resizing with
    tf.image.resize_bilinear to (size, 1), and squeezing the axis back out.

    Fix: only the Keras-1 hook `get_output_shape_for` was implemented;
    tf.contrib.keras (Keras 2) calls `compute_output_shape`, so shape
    inference silently fell back to the input shape. `compute_output_shape`
    is now implemented, with the old name kept as an alias, and `get_config`
    added so the layer survives model serialization.

    NOTE(review): despite the '3D' in the class name this operates on rank-3
    tensors, identically to BilinearUpsampling1D — confirm the duplication is
    intentional.
    """
    def __init__(self, size, **kwargs):
        # size: output length of axis 1 after upsampling
        self.size = size
        super(BilinearUpsampling3D, self).__init__(**kwargs)

    def build(self, input_shape):
        super(BilinearUpsampling3D, self).build(input_shape)

    def call(self, x, mask=None):
        x = tf.expand_dims(x, axis=2)                  # (batch, dim, 1, ch)
        x = tf.image.resize_bilinear(x, [self.size, 1])
        x = tf.squeeze(x, axis=2)                      # (batch, size, ch)
        return x

    def compute_output_shape(self, input_shape):
        # Keras 2 shape-inference hook; without it downstream layers see the
        # pre-upsampling length.
        return (input_shape[0], self.size, input_shape[2])

    # Keras 1 name kept for backward compatibility with existing callers.
    get_output_shape_for = compute_output_shape

    def get_config(self):
        # Record `size` so the layer can be saved and reloaded.
        config = super(BilinearUpsampling3D, self).get_config()
        config['size'] = self.size
        return config
+
+
+# def SharedAtrousConv1D(x, SharedConvs, PostConv):
+#     concat = None
+#     for SharedConv in SharedConvs:
+#         x_atrous = SharedConv(x)
+#         x_atrous = BatchNormalization()(x_atrous)
+#         x_atrous = Activation('relu')(x_atrous)
+#         if concat is None:
+#             concat = x_atrous
+#         else:
+#             concat = concatenate([concat, x_atrous], axis=-1)
+#     concat = PostConv(concat)
+#     concat = BatchNormalization()(concat)
+#     concat = Activation('relu')(concat)
+#     return concat
+
+
class BilinearUpsampling1D(Layer):
    """
    Bilinearly upsample axis 1 of a rank-3 tensor (batch, dim, channels) to a
    fixed length, by inserting a dummy width axis, resizing with
    tf.image.resize_bilinear to (size, 1), and squeezing the axis back out.

    Fix: only the Keras-1 hook `get_output_shape_for` was implemented;
    tf.contrib.keras (Keras 2) calls `compute_output_shape`, so shape
    inference silently fell back to the input shape. `compute_output_shape`
    is now implemented, with the old name kept as an alias, and `get_config`
    added so the layer survives model serialization.
    """
    def __init__(self, size, **kwargs):
        # size: output length of axis 1 after upsampling
        self.size = size
        super(BilinearUpsampling1D, self).__init__(**kwargs)

    def build(self, input_shape):
        super(BilinearUpsampling1D, self).build(input_shape)

    def call(self, x, mask=None):
        x = tf.expand_dims(x, axis=2)                  # (batch, dim, 1, ch)
        x = tf.image.resize_bilinear(x, [self.size, 1])
        x = tf.squeeze(x, axis=2)                      # (batch, size, ch)
        return x

    def compute_output_shape(self, input_shape):
        # Keras 2 shape-inference hook; without it downstream layers see the
        # pre-upsampling length.
        return (input_shape[0], self.size, input_shape[2])

    # Keras 1 name kept for backward compatibility with existing callers.
    get_output_shape_for = compute_output_shape

    def get_config(self):
        # Record `size` so the layer can be saved and reloaded.
        config = super(BilinearUpsampling1D, self).get_config()
        config['size'] = self.size
        return config
+
+
+
+
+
+