Diff of /layers/vae.py [000000] .. [408896]

Switch to unified view

a b/layers/vae.py
1
"""Contains custom variational autoencoder class."""
2
import tensorflow as tf
3
4
from layers.downsample import get_downsampling
5
from layers.upsample import get_upsampling
6
from layers.resnet import ResnetBlock
7
8
9
def sample(inputs):
    """Draws a reparameterized sample from a diagonal Gaussian.

    Args:
        inputs: a pair `(z_mean, z_logvar)` of tensors with the same shape,
            holding the mean and the log-variance of the Gaussian.

    Returns:
        A tensor shaped like `z_mean`, computed as
        `z_mean + exp(0.5 * z_logvar) * eps` with `eps ~ N(0, 1)`.
    """
    z_mean, z_logvar = inputs
    # Use the dynamic shape: the static `z_mean.shape` contains `None` for an
    # unknown batch dimension, which `tf.random.normal` cannot accept.
    # The noise follows the dtype of `z_mean` so non-float32 inputs still work.
    eps = tf.random.normal(shape=tf.shape(z_mean), dtype=z_mean.dtype)
    return z_mean + tf.math.exp(0.5 * z_logvar) * eps
14
15
16
class VariationalAutoencoder(tf.keras.layers.Layer):
    """Variational autoencoder branch built from project sampling layers.

    The layer downsamples its input, flattens and projects it to the mean and
    log-variance of a Gaussian, draws a latent sample, projects the sample back
    to a single-channel volume and decodes it through a series of
    (upsample, ResnetBlock) levels followed by a 3D convolution.
    """

    def __init__(self,
                 data_format='channels_last',
                 groups=8,
                 reduction=2,
                 l2_scale=1e-5,
                 downsampling='conv',
                 upsampling='conv',
                 base_filters=16,
                 depth=4,
                 out_ch=2,
                 **kwargs):
        """ Initializes the variational autoencoder: consists of sampling
            then an alternating series of upsampling and ResNet blocks.

            Args:
                data_format: 'channels_last' or 'channels_first'.
                groups: forwarded to the downsampling, upsampling and
                    ResnetBlock layers.
                reduction: forwarded to the ResnetBlock layers.
                l2_scale: L2 regularization factor for the kernels.
                downsampling: key passed to `get_downsampling`.
                upsampling: key passed to `get_upsampling`.
                base_filters: number of filters at the finest decoder level.
                depth: number of spatial levels; must be at least 2.
                out_ch: number of channels produced by the output convolution.
                **kwargs: standard Keras layer arguments (`name`, `dtype`,
                    `trainable`, ...). Accepting them is what allows the
                    default `from_config` to rebuild the layer from
                    `get_config()`, which includes those base keys.

            Raises:
                ValueError: if `depth` is smaller than 2.

            References:
                - [3D MRI brain tumor segmentation using autoencoder regularization](https://arxiv.org/pdf/1810.11654.pdf)
        """
        super(VariationalAutoencoder, self).__init__(**kwargs)

        # With depth < 2 the latent size below is no longer a whole number and
        # slicing in `call` would fail later with an obscure error.
        if depth < 2:
            raise ValueError(
                'depth must be at least 2, got {}.'.format(depth))

        # Set up config for self.get_config() to serialize later.
        self.data_format = data_format
        self.l2_scale = l2_scale
        self.config = super(VariationalAutoencoder, self).get_config()
        self.config.update({'groups': groups,
                            'reduction': reduction,
                            'downsampling': downsampling,
                            'upsampling': upsampling,
                            'base_filters': base_filters,
                            'depth': depth,
                            'out_ch': out_ch})

        # Retrieve downsampling method.
        Downsample = get_downsampling(downsampling)

        # Retrieve upsampling method.
        Upsample = get_upsampling(upsampling)

        # Extra downsampling layer to reduce parameters.
        # The regularization factor is passed positionally because its keyword
        # name differs between Keras versions (`l` vs. `l2`).
        self.downsample = Downsample(
                            filters=base_filters//2,
                            groups=groups,
                            data_format=data_format,
                            kernel_regularizer=tf.keras.regularizers.l2(l2_scale))

        # Build sampling layers. The projection emits mean and log-variance
        # side by side, hence twice as many units as `latent_size`.
        self.flatten = tf.keras.layers.Flatten(data_format)
        self.proj = tf.keras.layers.Dense(
                            units=base_filters*(2**(depth-1)),
                            kernel_regularizer=tf.keras.regularizers.l2(l2_scale),
                            kernel_initializer='he_normal')
        self.latent_size = base_filters*(2**(depth-2))
        self.sample = tf.keras.layers.Lambda(sample)

        # Extra upsampling layer to counter extra downsampling layer.
        self.upsample = Upsample(
                            filters=base_filters*(2**(depth-1)),
                            groups=groups,
                            data_format=data_format,
                            l2_scale=l2_scale)

        # Build layers at all spatial levels, coarsest first.
        self.levels = []
        for i in range(depth-2, -1, -1):
            upsample = Upsample(
                        filters=base_filters*(2**i),
                        groups=groups,
                        data_format=data_format,
                        l2_scale=l2_scale)
            conv = ResnetBlock(
                        filters=base_filters*(2**i),
                        groups=groups,
                        reduction=reduction,
                        data_format=data_format,
                        l2_scale=l2_scale)
            self.levels.append([upsample, conv])

        # Output layer convolution.
        self.out = tf.keras.layers.Conv3D(
                            filters=out_ch,
                            kernel_size=3,
                            strides=1,
                            padding='same',
                            data_format=data_format,
                            kernel_regularizer=tf.keras.regularizers.l2(l2_scale),
                            kernel_initializer='he_normal')

    def build(self, input_shape):
        """Creates the layers whose sizes depend on the input's spatial shape.

        Args:
            input_shape: shape of the 5D input; the three spatial dimensions
                must be statically known.

        Raises:
            ValueError: if the input does not have exactly three spatial
                dimensions or any of them is unknown.
        """
        dims = tf.TensorShape(input_shape).as_list()
        spatial = dims[1:-1] if self.data_format == 'channels_last' else dims[2:]
        if len(spatial) != 3 or any(size is None for size in spatial):
            raise ValueError(
                'VariationalAutoencoder needs three statically known spatial '
                'dimensions, got input shape {}.'.format(dims))

        # Halve each dimension first so that the Dense width always equals the
        # number of elements of the Reshape target (the original `h*w*d//8`
        # only matched it when every dimension was even).
        h, w, d = (size // 2 for size in spatial)

        # Build reshaping layers after sampling.
        self.unproj = tf.keras.layers.Dense(
                                units=h*w*d,
                                kernel_regularizer=tf.keras.regularizers.l2(self.l2_scale),
                                kernel_initializer='he_normal',
                                activation='relu')
        self.unflatten = tf.keras.layers.Reshape(
                                (h, w, d, 1) if self.data_format == 'channels_last' else (1, h, w, d))

        super(VariationalAutoencoder, self).build(input_shape)

    def call(self, inputs, training=None):
        """Encodes `inputs` to a latent Gaussian, samples it and decodes.

        Args:
            inputs: 5D input tensor laid out according to `data_format`.
            training: Keras training flag, forwarded to the upsampling and
                ResnetBlock layers.

        Returns:
            A tuple `(outputs, z_mean, z_logvar)`: the output of the final
            convolution plus the mean and log-variance used for sampling.
        """
        # Downsample.
        inputs = self.downsample(inputs)

        # Flatten and project
        inputs = self.flatten(inputs)
        inputs = self.proj(inputs)

        # Sample: first half of the projection is the mean, the rest the
        # log-variance.
        z_mean = inputs[:, :self.latent_size]
        z_logvar = inputs[:, self.latent_size:]
        inputs = self.sample([z_mean, z_logvar])

        # Restored projection and reshape
        inputs = self.unproj(inputs)
        inputs = self.unflatten(inputs)

        # Upsample. `training` is forwarded explicitly, matching how the
        # per-level instances of the same upsampling class are called below.
        inputs = self.upsample(inputs, training=training)

        # Iterate through spatial levels.
        for upsample, conv in self.levels:
            inputs = upsample(inputs, training=training)
            inputs = conv(inputs, training=training)

        # Map convolution to number of original input channels.
        inputs = self.out(inputs)

        return inputs, z_mean, z_logvar

    def get_config(self):
        """Returns the constructor arguments needed to re-create the layer."""
        self.config.update({'data_format': self.data_format,
                            'l2_scale': self.l2_scale})
        # Hand out a plain copy so callers cannot mutate the layer's own state.
        return dict(self.config)