--- a
+++ b/layers/vae.py
@@ -0,0 +1,222 @@
+"""Contains custom variational autoencoder class."""
+import tensorflow as tf
+
+from layers.downsample import get_downsampling
+from layers.upsample import get_upsampling
+from layers.resnet import ResnetBlock
+
+
+def sample(inputs):
+    """Draws a latent sample with the reparameterization trick.
+
+    Args:
+        inputs: pair `(z_mean, z_logvar)` of same-shaped tensors holding the
+            mean and log-variance of a diagonal Gaussian.
+
+    Returns:
+        `z_mean + exp(0.5 * z_logvar) * eps` with `eps ~ N(0, I)`, which is
+        differentiable with respect to both `z_mean` and `z_logvar`.
+    """
+    z_mean, z_logvar = inputs
+    # tf.shape is the dynamic shape, so this also works while tracing with
+    # an unknown batch size (z_mean.shape would contain None, which
+    # tf.random.normal rejects). Following z_mean's dtype keeps the product
+    # below well-typed under mixed precision.
+    eps = tf.random.normal(shape=tf.shape(z_mean), dtype=z_mean.dtype)
+    return z_mean + tf.math.exp(0.5 * z_logvar) * eps
+
+
+class VariationalAutoencoder(tf.keras.layers.Layer):
+    """Variational autoencoder branch used for autoencoder regularization.
+
+    Pipeline: extra downsampling -> flatten -> dense projection split into
+    `(z_mean, z_logvar)` -> sampling -> dense projection back to a
+    one-channel volume -> extra upsampling -> `depth - 1` levels of
+    upsampling + ResnetBlock -> 3x3x3 output convolution to `out_ch`
+    channels.
+
+    `call` returns `(reconstruction, z_mean, z_logvar)`.
+
+    References:
+        - [3D MRI brain tumor segmentation using autoencoder regularization](https://arxiv.org/pdf/1810.11654.pdf)
+    """
+
+    def __init__(self,
+                 data_format='channels_last',
+                 groups=8,
+                 reduction=2,
+                 l2_scale=1e-5,
+                 downsampling='conv',
+                 upsampling='conv',
+                 base_filters=16,
+                 depth=4,
+                 out_ch=2,
+                 **kwargs):
+        """Initializes the variational autoencoder.
+
+        Args:
+            data_format: 'channels_last' or 'channels_first'.
+            groups: forwarded to the up/downsampling and residual blocks.
+            reduction: forwarded to every ResnetBlock.
+            l2_scale: L2 regularization factor for the layer kernels.
+            downsampling: method name resolved by `get_downsampling`.
+            upsampling: method name resolved by `get_upsampling`.
+            base_filters: filter count of the last level; every earlier
+                level doubles it.
+            depth: number of levels, at least 2; the latent vector has
+                `base_filters * 2**(depth - 2)` entries.
+            out_ch: number of channels of the reconstruction.
+            **kwargs: standard Keras layer arguments (`name`, `dtype`, ...).
+                `get_config` includes them, so accepting them here is what
+                lets `from_config` rebuild the layer.
+
+        Raises:
+            ValueError: if `depth < 2` (the latent size would not be an
+                integer).
+        """
+        super(VariationalAutoencoder, self).__init__(**kwargs)
+        if depth < 2:
+            raise ValueError('depth must be at least 2, got {}.'.format(depth))
+
+        # Set up config for self.get_config() to serialize later.
+        self.data_format = data_format
+        self.l2_scale = l2_scale
+        self.config = super(VariationalAutoencoder, self).get_config()
+        self.config.update({'groups': groups,
+                            'reduction': reduction,
+                            'downsampling': downsampling,
+                            'upsampling': upsampling,
+                            'base_filters': base_filters,
+                            'depth': depth,
+                            'out_ch': out_ch})
+
+        # Retrieve the configured downsampling and upsampling methods.
+        Downsample = get_downsampling(downsampling)
+        Upsample = get_upsampling(upsampling)
+
+        # Extra downsampling layer to reduce parameters.
+        self.downsample = Downsample(
+                            filters=base_filters//2,
+                            groups=groups,
+                            data_format=data_format,
+                            kernel_regularizer=tf.keras.regularizers.l2(l=l2_scale))
+
+        # Build sampling layers. The projection has 2 * latent_size units:
+        # the first half is z_mean, the second half z_logvar (see `call`).
+        self.flatten = tf.keras.layers.Flatten(data_format=data_format)
+        self.proj = tf.keras.layers.Dense(
+                            units=base_filters*(2**(depth-1)),
+                            kernel_regularizer=tf.keras.regularizers.l2(l=l2_scale),
+                            kernel_initializer='he_normal')
+        self.latent_size = base_filters*(2**(depth-2))
+        self.sample = tf.keras.layers.Lambda(sample)
+
+        # Extra upsampling layer to counter extra downsampling layer.
+        self.upsample = Upsample(
+                            filters=base_filters*(2**(depth-1)),
+                            groups=groups,
+                            data_format=data_format,
+                            l2_scale=l2_scale)
+
+        # One [upsample, ResnetBlock] pair per level, from
+        # base_filters * 2**(depth-2) filters down to base_filters.
+        self.levels = []
+        for i in range(depth-2, -1, -1):
+            upsample = Upsample(
+                        filters=base_filters*(2**i),
+                        groups=groups,
+                        data_format=data_format,
+                        l2_scale=l2_scale)
+            conv = ResnetBlock(
+                        filters=base_filters*(2**i),
+                        groups=groups,
+                        reduction=reduction,
+                        data_format=data_format,
+                        l2_scale=l2_scale)
+            self.levels.append([upsample, conv])
+
+        # Output layer convolution.
+        self.out = tf.keras.layers.Conv3D(
+                            filters=out_ch,
+                            kernel_size=3,
+                            strides=1,
+                            padding='same',
+                            data_format=data_format,
+                            kernel_regularizer=tf.keras.regularizers.l2(l=l2_scale),
+                            kernel_initializer='he_normal')
+
+    def build(self, input_shape):
+        """Creates the layers that depend on the input's spatial size.
+
+        The sampled latent vector is projected back to a one-channel volume
+        of size (h // 2, w // 2, d // 2), where (h, w, d) are the spatial
+        dimensions of `input_shape`.
+
+        Raises:
+            ValueError: if any spatial dimension of `input_shape` is unknown.
+        """
+        if self.data_format == 'channels_last':
+            h, w, d = input_shape[1:-1]
+        else:
+            h, w, d = input_shape[2:]
+        if h is None or w is None or d is None:
+            raise ValueError('VariationalAutoencoder needs fully defined spatial '
+                             'dimensions, got input shape {}.'.format(input_shape))
+        half = (h//2, w//2, d//2)
+
+        # Build reshaping layers after sampling. The Dense width is derived
+        # from the same floored sizes as the Reshape target so the two always
+        # agree (h*w*d//8 does not when a spatial dimension is odd).
+        self.unproj = tf.keras.layers.Dense(
+                                units=half[0]*half[1]*half[2],
+                                kernel_regularizer=tf.keras.regularizers.l2(l=self.l2_scale),
+                                kernel_initializer='he_normal',
+                                activation='relu')
+        self.unflatten = tf.keras.layers.Reshape(
+                                (half + (1,)) if self.data_format == 'channels_last' else ((1,) + half))
+        super(VariationalAutoencoder, self).build(input_shape)
+
+    def call(self, inputs, training=None):
+        """Encodes `inputs` into a sampled latent vector and decodes it.
+
+        Returns:
+            Tuple `(reconstruction, z_mean, z_logvar)`; `z_mean` and
+            `z_logvar` each have `self.latent_size` features.
+        """
+        # Downsample.
+        inputs = self.downsample(inputs)
+
+        # Flatten and project.
+        inputs = self.flatten(inputs)
+        inputs = self.proj(inputs)
+
+        # Split the projection into the Gaussian's parameters and sample.
+        z_mean = inputs[:, :self.latent_size]
+        z_logvar = inputs[:, self.latent_size:]
+        inputs = self.sample([z_mean, z_logvar])
+
+        # Restore the projection and reshape to a one-channel volume.
+        inputs = self.unproj(inputs)
+        inputs = self.unflatten(inputs)
+
+        # Upsample, forwarding `training` like the per-level upsamplers do.
+        inputs = self.upsample(inputs, training=training)
+
+        # Iterate through spatial levels.
+        for upsample, conv in self.levels:
+            inputs = upsample(inputs, training=training)
+            inputs = conv(inputs, training=training)
+
+        # Map to `out_ch` output channels.
+        inputs = self.out(inputs)
+
+        return inputs, z_mean, z_logvar
+
+    def get_config(self):
+        """Returns the layer's constructor arguments for Keras serialization."""
+        # Return a fresh plain dict rather than mutating and returning the
+        # tracked `self.config` attribute, which Keras may have wrapped.
+        config = dict(self.config)
+        config.update({'data_format': self.data_format,
+                       'l2_scale': self.l2_scale})
+        return config