a b/layers/downsample.py
1
"""Contains custom downsampling classes."""
2
import tensorflow as tf
3
4
from layers.group_norm import GroupNormalization
5
6
7
def get_downsampling(downsampling):
    """Return the downsampling layer class selected by name.

    Args:
        downsampling: ``'max'`` selects :class:`MaxDownsample`; ``'conv'``
            selects :class:`ConvDownsample`.

    Returns:
        The selected layer class (not an instance).

    Raises:
        ValueError: if ``downsampling`` is not one of the recognised names.
            Previously an unknown name silently returned ``None``, which only
            failed later when the caller tried to instantiate it.
    """
    if downsampling == 'max':
        return MaxDownsample
    if downsampling == 'conv':
        return ConvDownsample
    raise ValueError(
        "Unknown downsampling {!r}; expected 'max' or 'conv'.".format(
            downsampling))
12
13
14
class ConvDownsample(tf.keras.layers.Layer):
    """Downsample with a stride-2 3-D convolution, group norm and ReLU.

    The layer applies ``Conv3D(kernel_size=3, strides=2, padding='same')``,
    then :class:`GroupNormalization`, then a ReLU activation.

    Args:
        filters: number of output channels of the convolution.
        data_format: ``'channels_last'`` or ``'channels_first'``. It is
            passed to ``Conv3D`` and selects the group-normalization axis
            (``-1`` for channels_last, ``1`` otherwise).
        groups: number of groups for :class:`GroupNormalization`.
        l2_scale: L2 regularization factor for the convolution kernel.
        **kwargs: the standard base-``Layer`` arguments listed in
            ``_FORWARDED_BASE_KWARGS`` (``name``, ``dtype``, ``trainable``,
            ...) are forwarded to ``tf.keras.layers.Layer``. Any other
            keyword is accepted and ignored, as before. Presumably this is
            so callers can pass one shared argument set to either
            downsampling class.
    """

    # Keyword arguments understood by tf.keras.layers.Layer.__init__ that
    # we forward. Includes the keys emitted by the base get_config, so
    # from_config(get_config()) restores name/dtype/trainable.
    _FORWARDED_BASE_KWARGS = frozenset(
        ('name', 'dtype', 'trainable', 'dynamic',
         'input_shape', 'batch_input_shape'))

    def __init__(self,
                 filters,
                 data_format='channels_last',
                 groups=8,
                 l2_scale=1e-5,
                 **kwargs):
        base_kwargs = {key: value for key, value in kwargs.items()
                       if key in ConvDownsample._FORWARDED_BASE_KWARGS}
        super(ConvDownsample, self).__init__(**base_kwargs)
        self.filters = filters
        self.data_format = data_format
        self.groups = groups
        self.l2_scale = l2_scale

        self.conv = tf.keras.layers.Conv3D(
            filters=filters,
            kernel_size=3,
            strides=2,
            padding='same',
            data_format=data_format,
            # Passed positionally: the keyword is `l` in older Keras and
            # `l2` in newer releases, so positional works with both.
            kernel_regularizer=tf.keras.regularizers.l2(l2_scale),
            kernel_initializer='he_normal')
        self.norm = GroupNormalization(
            groups=groups,
            axis=-1 if data_format == 'channels_last' else 1)
        self.relu = tf.keras.layers.Activation('relu')

    def call(self, inputs, training=None):
        """Apply conv -> group norm -> ReLU to ``inputs``.

        This implements ``call`` rather than overriding ``__call__``.
        Keras' ``Layer.__call__`` therefore still runs its building, name
        scoping and functional-API tracking before delegating here.
        """
        outputs = self.conv(inputs)
        outputs = self.norm(outputs, training=training)
        return self.relu(outputs)

    def get_config(self):
        """Return the base Layer config extended with this layer's arguments."""
        config = super(ConvDownsample, self).get_config()
        config.update({'filters': self.filters,
                       'data_format': self.data_format,
                       'groups': self.groups,
                       'l2_scale': self.l2_scale})
        return config

    @property
    def config(self):
        """Serialization config. Kept for callers that read ``layer.config``."""
        return self.get_config()
49
50
51
class MaxDownsample(tf.keras.layers.Layer):
    """Downsample with 3-D max pooling (pool size 2, stride 2, 'same' padding).

    Args:
        data_format: ``'channels_last'`` or ``'channels_first'``, passed to
            ``MaxPooling3D``.
        **kwargs: the standard base-``Layer`` arguments listed in
            ``_FORWARDED_BASE_KWARGS`` (``name``, ``dtype``, ``trainable``,
            ...) are forwarded to ``tf.keras.layers.Layer``. Any other
            keyword (e.g. ``filters``, ``groups``, ``l2_scale``) is accepted
            and ignored, as before. Presumably this is so callers can pass
            one shared argument set to either downsampling class.
    """

    # Keyword arguments understood by tf.keras.layers.Layer.__init__ that
    # we forward. Unrelated kwargs must NOT reach the base class, because
    # it raises TypeError on unknown keywords.
    _FORWARDED_BASE_KWARGS = frozenset(
        ('name', 'dtype', 'trainable', 'dynamic',
         'input_shape', 'batch_input_shape'))

    def __init__(self,
                 data_format='channels_last',
                 **kwargs):
        base_kwargs = {key: value for key, value in kwargs.items()
                       if key in MaxDownsample._FORWARDED_BASE_KWARGS}
        super(MaxDownsample, self).__init__(**base_kwargs)
        self.data_format = data_format

        self.maxpool = tf.keras.layers.MaxPooling3D(
            pool_size=2,
            strides=2,
            padding='same',
            data_format=data_format)

    def call(self, inputs, training=None):
        """Max-pool ``inputs``. ``training`` is accepted but unused.

        This implements ``call`` rather than overriding ``__call__``, so
        Keras' ``Layer.__call__`` machinery still runs.
        """
        return self.maxpool(inputs)

    def get_config(self):
        """Return the base Layer config extended with ``data_format``."""
        config = super(MaxDownsample, self).get_config()
        config.update({'data_format': self.data_format})
        return config

    @property
    def config(self):
        """Serialization config. Kept for callers that read ``layer.config``."""
        return self.get_config()