[408896]: / layers / upsample.py

Download this file

80 lines (66 with data), 2.8 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
"""Contains custom upsampling classes."""
import tensorflow as tf
from layers.group_norm import GroupNormalization
def get_upsampling(upsampling):
    """Return the upsampling layer class selected by name.

    Args:
        upsampling: 'linear' to select `LinearUpsample`, or 'conv' to
            select `ConvUpsample`.

    Returns:
        The layer class itself (not an instance); the caller is
        responsible for constructing it.

    Raises:
        ValueError: if `upsampling` is not one of the supported names.
    """
    if upsampling == 'linear':
        return LinearUpsample
    if upsampling == 'conv':
        return ConvUpsample
    # Fail at the point of the mistake instead of implicitly returning
    # None, which would only surface later as "'NoneType' object is not
    # callable" when the caller tries to build the layer.
    raise ValueError(
        "Unknown upsampling {!r}; expected 'linear' or 'conv'.".format(
            upsampling))
class ConvUpsample(tf.keras.layers.Layer):
    """Learned upsampling: transposed 3-D conv -> group norm -> ReLU.

    The transposed convolution uses a kernel size of 3, a stride of 2 and
    'same' padding, so the upsampling factor comes from the stride.

    Args:
        filters: number of output filters of the transposed convolution.
        groups: number of groups handed to `GroupNormalization`.
        data_format: 'channels_last' or 'channels_first'. Also selects the
            normalization axis (-1 for 'channels_last', otherwise 1).
        l2_scale: recorded in the layer config.
            NOTE(review): unlike `LinearUpsample`, this value is not
            applied as a kernel regularizer here -- confirm whether that
            is intentional before wiring it in, since doing so changes
            the training loss.
        **kwargs: `name`, `trainable` and `dtype` are forwarded to the
            base `Layer`; any other extra keyword is ignored so callers
            can build every upsampling class with the same arguments.
    """

    # Base-layer options that are honoured when present in **kwargs.
    _BASE_LAYER_KWARGS = ('name', 'trainable', 'dtype')

    def __init__(self,
                 filters,
                 groups=8,
                 data_format='channels_last',
                 l2_scale=1e-5,
                 **kwargs):
        # Previously every extra keyword was dropped, so a requested
        # `name` never reached Keras and was lost again when rebuilding
        # the layer from its config. Forward the standard options only;
        # unknown extras stay ignored as before.
        base_kwargs = {key: kwargs[key]
                       for key in self._BASE_LAYER_KWARGS
                       if key in kwargs}
        super(ConvUpsample, self).__init__(**base_kwargs)

        # Base config plus the constructor arguments of this layer.
        self.config = super(ConvUpsample, self).get_config()
        self.config.update({'filters': filters,
                            'data_format': data_format,
                            'groups': groups,
                            'l2_scale': l2_scale})

        self.conv = tf.keras.layers.Conv3DTranspose(
            filters=filters,
            kernel_size=3,
            strides=2,
            padding='same',
            data_format=data_format)
        self.norm = GroupNormalization(
            groups=groups,
            axis=-1 if data_format == 'channels_last' else 1)
        self.relu = tf.keras.layers.Activation('relu')

    def call(self, inputs, training=None):
        """Apply transposed convolution, normalization and ReLU in order.

        Implemented as `call` (not `__call__`) so that invoking the layer
        goes through the standard Keras `Layer.__call__` wrapper.

        Args:
            inputs: input tensor for the transposed convolution.
            training: forwarded to the normalization layer.

        Returns:
            The activated, normalized, upsampled tensor.
        """
        outputs = self.conv(inputs)
        outputs = self.norm(outputs, training=training)
        return self.relu(outputs)

    def get_config(self):
        """Return the layer configuration captured at construction time."""
        # Hand out a shallow copy so callers cannot mutate the stored
        # configuration of this layer.
        return dict(self.config)
class LinearUpsample(tf.keras.layers.Layer):
    """Fixed upsampling: pointwise 3-D conv followed by `UpSampling3D`.

    A 1x1x1 convolution (L2-regularized, 'he_normal' initialized) first
    maps the input to `filters` channels; `UpSampling3D(size=2)` then
    enlarges the spatial dimensions.

    NOTE(review): despite the class name, `UpSampling3D` is documented as
    repeating values rather than interpolating -- confirm that this is
    the intended behaviour.

    Args:
        filters: number of output filters of the pointwise convolution.
        data_format: 'channels_last' or 'channels_first', passed to both
            sub-layers.
        l2_scale: L2 regularization factor for the convolution kernel.
        **kwargs: `name`, `trainable` and `dtype` are forwarded to the
            base `Layer`; any other extra keyword (for example arguments
            that only other upsampling classes understand) is ignored.
    """

    # Base-layer options that are honoured when present in **kwargs.
    _BASE_LAYER_KWARGS = ('name', 'trainable', 'dtype')

    def __init__(self,
                 filters,
                 data_format='channels_last',
                 l2_scale=1e-5,
                 **kwargs):
        # Previously every extra keyword was dropped, so a requested
        # `name` never reached Keras and was lost again when rebuilding
        # the layer from its config. Forward the standard options only;
        # unknown extras must stay ignored because this class accepts
        # fewer arguments than its siblings.
        base_kwargs = {key: kwargs[key]
                       for key in self._BASE_LAYER_KWARGS
                       if key in kwargs}
        super(LinearUpsample, self).__init__(**base_kwargs)

        # Base config plus the constructor arguments of this layer.
        self.config = super(LinearUpsample, self).get_config()
        self.config.update({'filters': filters,
                            'data_format': data_format,
                            'l2_scale': l2_scale})

        self.ptwise = tf.keras.layers.Conv3D(
            filters=filters,
            kernel_size=1,
            strides=1,
            padding='same',
            data_format=data_format,
            kernel_regularizer=tf.keras.regularizers.l2(l=l2_scale),
            kernel_initializer='he_normal')
        self.linear = tf.keras.layers.UpSampling3D(
            size=2,
            data_format=data_format)

    def call(self, inputs, training=None):
        """Apply the pointwise convolution, then the upsampling.

        Implemented as `call` (not `__call__`) so that invoking the layer
        goes through the standard Keras `Layer.__call__` wrapper.

        Args:
            inputs: input tensor for the pointwise convolution.
            training: accepted for interface parity with the other
                upsampling layers; neither sub-layer receives it.

        Returns:
            The channel-projected, upsampled tensor.
        """
        outputs = self.ptwise(inputs)
        return self.linear(outputs)

    def get_config(self):
        """Return the layer configuration captured at construction time."""
        # Hand out a shallow copy so callers cannot mutate the stored
        # configuration of this layer.
        return dict(self.config)