[408896]: / layers / vae.py

Download this file

149 lines (123 with data), 5.8 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
"""Contains custom variational autoencoder class."""
import tensorflow as tf
from layers.downsample import get_downsampling
from layers.upsample import get_upsampling
from layers.resnet import ResnetBlock
def sample(inputs):
    """Reparameterization trick: draw z ~ N(z_mean, exp(z_logvar)).

    Args:
        inputs: pair ``(z_mean, z_logvar)`` of equally-shaped float tensors,
            the predicted Gaussian mean and log-variance.

    Returns:
        Tensor of the same shape as ``z_mean``:
        ``z_mean + exp(0.5 * z_logvar) * eps`` with ``eps ~ N(0, I)``.
    """
    z_mean, z_logvar = inputs
    # Use the dynamic shape (tf.shape) rather than the static z_mean.shape:
    # inside a traced Keras model the batch dimension is None, and
    # tf.random.normal cannot take a partially-unknown static shape.
    eps = tf.random.normal(shape=tf.shape(z_mean), dtype=tf.float32)
    return z_mean + tf.math.exp(0.5 * z_logvar) * eps
class VariationalAutoencoder(tf.keras.layers.Layer):
    """Variational-autoencoder branch of a segmentation network.

    Compresses the encoder output to a Gaussian latent code (mean and
    log-variance), samples from it via the reparameterization trick, and
    decodes back to image space through alternating upsampling and ResNet
    blocks. ``call`` returns ``(reconstruction, z_mean, z_logvar)`` so the
    caller can add the KL-divergence term to the training loss.
    """
    def __init__(self,
                 data_format='channels_last',
                 groups=8,
                 reduction=2,
                 l2_scale=1e-5,
                 downsampling='conv',
                 upsampling='conv',
                 base_filters=16,
                 depth=4,
                 out_ch=2):
        """ Initializes the variational autoencoder: consists of sampling
            then an alternating series of SENet blocks and upsampling.

            Args:
                data_format: 'channels_last' or 'channels_first'.
                groups: group count forwarded to the down/upsampling and
                    ResNet sub-layers (presumably group normalization —
                    defined in the project's layers package; verify there).
                reduction: squeeze-excitation reduction ratio for ResnetBlock.
                l2_scale: L2 kernel-regularization factor for all conv/dense.
                downsampling / upsampling: method names resolved by
                    get_downsampling / get_upsampling.
                base_filters: filter count at the finest spatial level.
                depth: number of spatial levels; filters scale as
                    base_filters * 2**i.
                out_ch: channels of the reconstructed output volume.

            References:
                - [3D MRI brain tumor segmentation using autoencoder regularization](https://arxiv.org/pdf/1810.11654.pdf)
        """
        super(VariationalAutoencoder, self).__init__()
        # Set up config for self.get_config() to serialize later.
        self.data_format = data_format
        self.l2_scale = l2_scale
        self.config = super(VariationalAutoencoder, self).get_config()
        self.config.update({'groups': groups,
                            'reduction': reduction,
                            'downsampling': downsampling,
                            'upsampling': upsampling,
                            'base_filters': base_filters,
                            'depth': depth,
                            'out_ch': out_ch})

        # Retrieve downsampling method.
        Downsample = get_downsampling(downsampling)
        # Retrieve upsampling method.
        Upsample = get_upsampling(upsampling)

        # Extra downsampling layer to reduce parameters before the dense
        # projection (the flatten below would otherwise be huge).
        self.downsample = Downsample(
            filters=base_filters//2,
            groups=groups,
            data_format=data_format,
            kernel_regularizer=tf.keras.regularizers.l2(l=l2_scale))

        # Build sampling layers. `proj` outputs 2 * latent_size units:
        # the first half is interpreted as z_mean, the second as z_logvar
        # (see the split in call()).
        self.flatten = tf.keras.layers.Flatten(data_format)
        self.proj = tf.keras.layers.Dense(
            units=base_filters*(2**(depth-1)),
            kernel_regularizer=tf.keras.regularizers.l2(l=l2_scale),
            kernel_initializer='he_normal')
        # Half of the projection width: base_filters*2**(depth-2).
        self.latent_size = base_filters*(2**(depth-2))
        self.sample = tf.keras.layers.Lambda(sample)

        # Extra upsampling layer to counter extra downsampling layer.
        self.upsample = Upsample(
            filters=base_filters*(2**(depth-1)),
            groups=groups,
            data_format=data_format,
            l2_scale=l2_scale)

        # Build decoder layers at all spatial levels, coarsest first
        # (i = depth-2 down to 0), halving the filter count per level.
        self.levels = []
        for i in range(depth-2, -1, -1):
            upsample = Upsample(
                filters=base_filters*(2**i),
                groups=groups,
                data_format=data_format,
                l2_scale=l2_scale)
            conv = ResnetBlock(
                filters=base_filters*(2**i),
                groups=groups,
                reduction=reduction,
                data_format=data_format,
                l2_scale=l2_scale)
            self.levels.append([upsample, conv])

        # Output layer convolution: maps to out_ch channels. Conv3D implies
        # 5-D inputs (batch + 3 spatial dims + channels).
        self.out = tf.keras.layers.Conv3D(
            filters=out_ch,
            kernel_size=3,
            strides=1,
            padding='same',
            data_format=data_format,
            kernel_regularizer=tf.keras.regularizers.l2(l=l2_scale),
            kernel_initializer='he_normal')

    def build(self, input_shape):
        """Creates the shape-dependent un-projection layers.

        Extracts the input's three spatial dims (h, w, d) and builds a Dense
        that restores a single-channel volume at half resolution, matching
        the extra `downsample` applied in call(). h*w*d*1//8 equals
        (h//2)*(w//2)*(d//2) only when h, w, d are all even — assumes even
        spatial dims; TODO confirm upstream shapes guarantee this.
        """
        h, w, d = input_shape[1:-1] if self.data_format == 'channels_last' else input_shape[2:]
        # Build reshaping layers after sampling.
        self.unproj = tf.keras.layers.Dense(
            units=h*w*d*1//8,
            kernel_regularizer=tf.keras.regularizers.l2(l=self.l2_scale),
            kernel_initializer='he_normal',
            activation='relu')
        self.unflatten = tf.keras.layers.Reshape(
            (h//2, w//2, d//2, 1) if self.data_format == 'channels_last' else (1, h//2, w//2, d//2))

    def call(self, inputs, training=None):
        """Runs encode→sample→decode; returns (output, z_mean, z_logvar)."""
        # Downsample.
        inputs = self.downsample(inputs)
        # Flatten and project to [z_mean ‖ z_logvar].
        inputs = self.flatten(inputs)
        inputs = self.proj(inputs)
        # Sample: split the projection into mean / log-variance halves and
        # draw z with the reparameterization trick.
        z_mean = inputs[:, :self.latent_size]
        z_logvar = inputs[:, self.latent_size:]
        inputs = self.sample([z_mean, z_logvar])
        # Restored projection and reshape to a half-resolution volume.
        inputs = self.unproj(inputs)
        inputs = self.unflatten(inputs)
        # Upsample back to the pre-downsample resolution.
        inputs = self.upsample(inputs)
        # Iterate through spatial levels, coarsest to finest.
        for level in self.levels:
            upsample, conv = level
            inputs = upsample(inputs, training=training)
            inputs = conv(inputs, training=training)
        # Map convolution to number of original input channels.
        inputs = self.out(inputs)
        return inputs, z_mean, z_logvar

    def get_config(self):
        """Returns the serialization config (base config + ctor args)."""
        self.config.update({'data_format': self.data_format,
                            'l2_scale': self.l2_scale})
        return self.config