--- /dev/null
+++ b/model.py
@@ -0,0 +1,71 @@
+import tensorflow as tf
+
+from layers.encoder import Encoder
+from layers.decoder import Decoder
+from layers.vae import VariationalAutoencoder
+
+
class Model(tf.keras.models.Model):
    """Segmentation network: a cross between the 3D U-Net and the
    2018 BraTS Challenge top model, with VAE regularization.

    References:
        - [3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation](https://arxiv.org/pdf/1606.06650.pdf)
        - [3D MRI brain tumor segmentation using autoencoder regularization](https://arxiv.org/pdf/1810.11654.pdf)
    """

    def __init__(self,
                 data_format='channels_last',
                 groups=8,
                 reduction=2,
                 l2_scale=1e-5,
                 dropout=0.2,
                 downsampling='conv',
                 upsampling='conv',
                 base_filters=16,
                 depth=4,
                 in_ch=2,
                 out_ch=3):
        """Builds the encoder, decoder, and VAE sub-networks.

        Args:
            data_format: 'channels_last' or 'channels_first'.
            groups: group count forwarded to all sub-networks
                (presumably for group normalization — confirm in layers/).
            reduction: reduction factor forwarded to all sub-networks.
            l2_scale: L2 kernel-regularization scale.
            dropout: dropout rate, used by the encoder only.
            downsampling: encoder downsampling strategy (e.g. 'conv').
            upsampling: decoder/VAE upsampling strategy (e.g. 'conv').
            base_filters: number of filters at the shallowest level.
            depth: number of resolution levels in encoder/decoder/VAE.
            in_ch: input channel count; the VAE reconstructs this many
                channels.
            out_ch: output channel count of the segmentation decoder.
        """
        super().__init__()
        # Tracked as a non-trainable variable so it is saved/restored
        # with checkpoints, letting training resume at the right epoch.
        self.epoch = tf.Variable(0, name='epoch', trainable=False)
        self.encoder = Encoder(
                            data_format=data_format,
                            groups=groups,
                            reduction=reduction,
                            l2_scale=l2_scale,
                            dropout=dropout,
                            downsampling=downsampling,
                            base_filters=base_filters,
                            depth=depth)
        self.decoder = Decoder(
                            data_format=data_format,
                            groups=groups,
                            reduction=reduction,
                            l2_scale=l2_scale,
                            upsampling=upsampling,
                            base_filters=base_filters,
                            depth=depth,
                            out_ch=out_ch)
        # The VAE branch reconstructs the model *input*, hence out_ch=in_ch.
        self.vae = VariationalAutoencoder(
                            data_format=data_format,
                            groups=groups,
                            reduction=reduction,
                            l2_scale=l2_scale,
                            upsampling=upsampling,
                            base_filters=base_filters,
                            depth=depth,
                            out_ch=in_ch)

    def call(self, inputs, training=None, inference=None):
        """Runs the forward pass.

        Args:
            inputs: input batch fed to the encoder.
            training: whether layers run in training mode.
            inference: if truthy, skip the VAE branch. Mutually
                exclusive with `training`.

        Returns:
            Tuple `(y_pred, y_vae, z_mean, z_logvar)`. In inference
            mode the last three entries are None.

        Raises:
            ValueError: if both `training` and `inference` are set.
        """
        # Use an explicit raise instead of `assert`: asserts are
        # stripped under `python -O`, silently dropping this guard.
        if inference and training:
            raise ValueError(
                'Cannot run training and inference modes simultaneously.')

        # Encoder yields an indexable sequence of per-level features;
        # the deepest feature map is last (don't shadow `inputs`).
        features = self.encoder(inputs, training=training)

        # Decoder takes (deepest features, skip connections).
        y_pred = self.decoder((features[-1], features[:-1]),
                              training=training)

        if inference:
            # Inference mode does not evaluate the VAE branch.
            return (y_pred, None, None, None)

        # VAE regularization branch: reconstruct the input from the
        # deepest features and expose the latent statistics for the loss.
        y_vae, z_mean, z_logvar = self.vae(features[-1], training=training)

        return (y_pred, y_vae, z_mean, z_logvar)