Diff of /layers/upsample.py [000000] .. [408896]

Switch to side-by-side view

--- a
+++ b/layers/upsample.py
@@ -0,0 +1,79 @@
+"""Contains custom upsampling classes."""
+import tensorflow as tf
+
+from layers.group_norm import GroupNormalization
+
+
+def get_upsampling(upsampling):
+    if upsampling == 'linear':
+        return LinearUpsample
+    elif upsampling == 'conv':
+        return ConvUpsample
+
+
+class ConvUpsample(tf.keras.layers.Layer):
+    def __init__(self,
+                 filters,
+                 groups=8,
+                 data_format='channels_last',
+                 l2_scale=1e-5,
+                 **kwargs):
+        super(ConvUpsample, self).__init__()
+        self.config = super(ConvUpsample, self).get_config()
+        self.config.update({'filters': filters,
+                            'data_format': data_format,
+                            'groups': groups,
+                            'l2_scale': l2_scale})
+
+        self.conv = tf.keras.layers.Conv3DTranspose(
+                            filters=filters,
+                            kernel_size=3,
+                            strides=2,
+                            padding='same',
+                            data_format=data_format)
+        self.norm = GroupNormalization(
+                            groups=groups,
+                            axis=-1 if data_format == 'channels_last' else 1)
+        self.relu = tf.keras.layers.Activation('relu')
+
+    def __call__(self, inputs, training=None):
+        inputs = self.conv(inputs)
+        inputs = self.norm(inputs, training=training)
+        inputs = self.relu(inputs)
+        return inputs
+
+    def get_config(self):
+        return self.config
+
+
+class LinearUpsample(tf.keras.layers.Layer):
+    def __init__(self,
+                 filters,
+                 data_format='channels_last',
+                 l2_scale=1e-5,
+                 **kwargs):
+        super(LinearUpsample, self).__init__()
+        self.config = super(LinearUpsample, self).get_config()
+        self.config.update({'filters': filters,
+                            'data_format': data_format,
+                            'l2_scale': l2_scale})
+
+        self.ptwise = tf.keras.layers.Conv3D(
+                                filters=filters,
+                                kernel_size=1,
+                                strides=1,
+                                padding='same',
+                                data_format=data_format,
+                                kernel_regularizer=tf.keras.regularizers.l2(l=l2_scale),
+                                kernel_initializer='he_normal')
+        self.linear = tf.keras.layers.UpSampling3D(
+                                size=2,
+                                data_format=data_format)
+
+    def __call__(self, inputs, training=None):
+        inputs = self.ptwise(inputs)
+        inputs = self.linear(inputs)
+        return inputs
+
+    def get_config(self):
+        return self.config