Diff of /layers/upsample.py [000000] .. [408896]

Switch to unified view

a b/layers/upsample.py
1
"""Contains custom upsampling classes."""
2
import tensorflow as tf
3
4
from layers.group_norm import GroupNormalization
5
6
7
def get_upsampling(upsampling):
8
    if upsampling == 'linear':
9
        return LinearUpsample
10
    elif upsampling == 'conv':
11
        return ConvUpsample
12
13
14
class ConvUpsample(tf.keras.layers.Layer):
15
    def __init__(self,
16
                 filters,
17
                 groups=8,
18
                 data_format='channels_last',
19
                 l2_scale=1e-5,
20
                 **kwargs):
21
        super(ConvUpsample, self).__init__()
22
        self.config = super(ConvUpsample, self).get_config()
23
        self.config.update({'filters': filters,
24
                            'data_format': data_format,
25
                            'groups': groups,
26
                            'l2_scale': l2_scale})
27
28
        self.conv = tf.keras.layers.Conv3DTranspose(
29
                            filters=filters,
30
                            kernel_size=3,
31
                            strides=2,
32
                            padding='same',
33
                            data_format=data_format)
34
        self.norm = GroupNormalization(
35
                            groups=groups,
36
                            axis=-1 if data_format == 'channels_last' else 1)
37
        self.relu = tf.keras.layers.Activation('relu')
38
39
    def __call__(self, inputs, training=None):
40
        inputs = self.conv(inputs)
41
        inputs = self.norm(inputs, training=training)
42
        inputs = self.relu(inputs)
43
        return inputs
44
45
    def get_config(self):
46
        return self.config
47
48
49
class LinearUpsample(tf.keras.layers.Layer):
50
    def __init__(self,
51
                 filters,
52
                 data_format='channels_last',
53
                 l2_scale=1e-5,
54
                 **kwargs):
55
        super(LinearUpsample, self).__init__()
56
        self.config = super(LinearUpsample, self).get_config()
57
        self.config.update({'filters': filters,
58
                            'data_format': data_format,
59
                            'l2_scale': l2_scale})
60
61
        self.ptwise = tf.keras.layers.Conv3D(
62
                                filters=filters,
63
                                kernel_size=1,
64
                                strides=1,
65
                                padding='same',
66
                                data_format=data_format,
67
                                kernel_regularizer=tf.keras.regularizers.l2(l=l2_scale),
68
                                kernel_initializer='he_normal')
69
        self.linear = tf.keras.layers.UpSampling3D(
70
                                size=2,
71
                                data_format=data_format)
72
73
    def __call__(self, inputs, training=None):
74
        inputs = self.ptwise(inputs)
75
        inputs = self.linear(inputs)
76
        return inputs
77
78
    def get_config(self):
79
        return self.config