Diff of /model.py [000000] .. [408896]


import tensorflow as tf

from layers.encoder import Encoder
from layers.decoder import Decoder
from layers.vae import VariationalAutoencoder


class Model(tf.keras.models.Model):
    def __init__(self,
                 data_format='channels_last',
                 groups=8,
                 reduction=2,
                 l2_scale=1e-5,
                 dropout=0.2,
                 downsampling='conv',
                 upsampling='conv',
                 base_filters=16,
                 depth=4,
                 in_ch=2,
                 out_ch=3):
        """ Initializes the model, a cross between the 3D U-Net and the
            top model of the 2018 BraTS Challenge, which adds VAE
            regularization.

            References:
                - [3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation](https://arxiv.org/pdf/1606.06650.pdf)
                - [3D MRI brain tumor segmentation using autoencoder regularization](https://arxiv.org/pdf/1810.11654.pdf)
        """
        super(Model, self).__init__()
        self.epoch = tf.Variable(0, name='epoch', trainable=False)
        self.encoder = Encoder(
                            data_format=data_format,
                            groups=groups,
                            reduction=reduction,
                            l2_scale=l2_scale,
                            dropout=dropout,
                            downsampling=downsampling,
                            base_filters=base_filters,
                            depth=depth)
        self.decoder = Decoder(
                            data_format=data_format,
                            groups=groups,
                            reduction=reduction,
                            l2_scale=l2_scale,
                            upsampling=upsampling,
                            base_filters=base_filters,
                            depth=depth,
                            out_ch=out_ch)
        # The VAE branch reconstructs the model input, so its output
        # channels are in_ch rather than out_ch.
        self.vae = VariationalAutoencoder(
                            data_format=data_format,
                            groups=groups,
                            reduction=reduction,
                            l2_scale=l2_scale,
                            upsampling=upsampling,
                            base_filters=base_filters,
                            depth=depth,
                            out_ch=in_ch)

    def call(self, inputs, training=None, inference=None):
        # Inference mode does not evaluate the VAE branch.
        assert (not inference or not training), \
            'Cannot run training and inference modes simultaneously.'

        # The encoder returns one feature map per depth: the deepest
        # seeds the decoder, the rest act as skip connections.
        inputs = self.encoder(inputs, training=training)

        y_pred = self.decoder((inputs[-1], inputs[:-1]), training=training)

        if inference:
            return (y_pred, None, None, None)

        y_vae, z_mean, z_logvar = self.vae(inputs[-1], training=training)

        return (y_pred, y_vae, z_mean, z_logvar)
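

# ---------------------------------------------------------------------------
# Illustrative sketches, not part of the original file: they show one way the
# four outputs of Model.call() can be consumed, following the composite loss
# described in the referenced Myronenko (2018) paper. The 0.1 weights come
# from that paper; the input shape, crop size, and this exact function are
# assumptions, not code from this repository.
# ---------------------------------------------------------------------------

def vae_regularized_loss(y_true, x_true, y_pred, y_vae, z_mean, z_logvar):
    """ Loss sketch per the referenced paper: soft Dice on the segmentation
        plus 0.1-weighted L2 reconstruction and KL divergence terms.
    """
    # Soft Dice loss over the predicted segmentation.
    intersection = tf.reduce_sum(y_true * y_pred)
    dice = 1.0 - 2.0 * intersection / (
        tf.reduce_sum(tf.square(y_true)) + tf.reduce_sum(tf.square(y_pred)) + 1e-8)
    # L2 loss between the VAE reconstruction and the input image.
    l2 = tf.reduce_mean(tf.square(x_true - y_vae))
    # Closed-form KL divergence of N(z_mean, exp(z_logvar)) from N(0, I),
    # in the unscaled form used by the paper.
    kl = tf.reduce_mean(tf.square(z_mean) + tf.exp(z_logvar) - z_logvar - 1.0)
    return dice + 0.1 * l2 + 0.1 * kl


if __name__ == '__main__':
    # Minimal smoke test assuming channels-last volumes of shape
    # (batch, H, W, D, in_ch); the 64^3 crop is a placeholder size.
    model = Model(in_ch=2, out_ch=3)
    x = tf.random.normal((1, 64, 64, 64, 2))
    # Training mode evaluates the VAE branch and returns all four outputs.
    y_pred, y_vae, z_mean, z_logvar = model(x, training=True)
    # Inference mode skips the VAE branch; its slots come back as None.
    y_pred, _, _, _ = model(x, inference=True)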