[408896]: / layers / encoder.py

Download this file

105 lines (84 with data), 3.9 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
"""Contains custom convolutional encoder class."""
import tensorflow as tf
from layers.resnet import ResnetBlock
from layers.downsample import get_downsampling
class Encoder(tf.keras.layers.Layer):
    """Convolutional encoder made of ResNet blocks with dense connections.

    The encoder is organised into `depth` spatial levels. Level `i` holds
    `i + 1` `ResnetBlock`s with `base_filters * 2**i` filters each; every
    block after the first in a level is preceded by a channel-wise
    concatenation. All levels except the deepest end with a downsampling
    layer obtained from `get_downsampling`.

    References:
        - [Densely Connected Convolutional Networks](https://arxiv.org/pdf/1608.06993.pdf)
    """

    def __init__(self,
                 data_format='channels_last',
                 groups=8,
                 reduction=2,
                 l2_scale=1e-5,
                 dropout=0.2,
                 downsampling='conv',
                 base_filters=16,
                 depth=4,
                 **kwargs):
        """Initializes the encoder and builds all of its sub-layers.

        Args:
            data_format: 'channels_last' concatenates on axis -1; any other
                value concatenates on axis 1. Also forwarded to the ResNet
                blocks and downsampling layers.
            groups: forwarded to every `ResnetBlock` and downsampling layer.
            reduction: forwarded to every `ResnetBlock`.
            l2_scale: forwarded to every `ResnetBlock` and downsampling layer.
            dropout: rate of the dropout layer applied to the encoder input.
            downsampling: identifier passed to `get_downsampling` to select
                the downsampling layer class.
            base_filters: number of filters at the first spatial level; the
                count doubles at every subsequent level.
            depth: number of spatial levels.
            **kwargs: standard Keras layer arguments (e.g. `name`,
                `trainable`, `dtype`) forwarded to the base class. Accepting
                them lets `from_config(get_config())` rebuild the layer,
                because the stored config includes those base-class keys.
        """
        super(Encoder, self).__init__(**kwargs)

        # Set up config for self.get_config() to serialize later. Every
        # constructor argument must appear here, otherwise a layer rebuilt
        # from its config silently falls back to the default value.
        self.config = super(Encoder, self).get_config()
        self.config.update({'data_format': data_format,
                            'groups': groups,
                            'reduction': reduction,
                            'l2_scale': l2_scale,
                            'dropout': dropout,
                            'downsampling': downsampling,
                            'base_filters': base_filters,
                            'depth': depth})

        # Retrieve downsampling method.
        Downsample = get_downsampling(downsampling)

        # Axis along which activations are concatenated.
        channel_axis = -1 if data_format == 'channels_last' else 1

        # Initial dropout layer (similar to denoised autoencoding).
        self.dropout = tf.keras.layers.Dropout(rate=dropout)

        # Build layers at all spatial levels.
        self.levels = []
        for i in range(depth):
            filters = base_filters * (2 ** i)

            convs = []
            for j in range(i + 1):
                conv = ResnetBlock(
                    filters=filters,
                    groups=groups,
                    reduction=reduction,
                    data_format=data_format,
                    l2_scale=l2_scale)
                # The first block of a level has nothing to concatenate with.
                dense = None
                if j > 0:
                    dense = tf.keras.layers.Concatenate(axis=channel_axis)
                convs.append([conv, dense])

            # Concatenate before downsampling. The first level has a single
            # block, so there is nothing to concatenate there.
            concat = None
            if i > 0:
                concat = tf.keras.layers.Concatenate(axis=channel_axis)

            # No downsampling at deepest spatial level.
            downsample = None
            if i < depth - 1:
                downsample = Downsample(
                    filters=filters,
                    groups=groups,
                    data_format=data_format,
                    l2_scale=l2_scale)

            self.levels.append([convs, concat, downsample])

    def call(self, inputs, training=None):
        """Runs the encoder and collects the output of every spatial level.

        Args:
            inputs: input tensor to encode.
            training: forwarded to the dropout, ResNet and downsampling
                layers.

        Returns:
            A list with one tensor per spatial level (in order from the first
            level to the deepest), each taken before that level's
            downsampling, for use by the decoder.
        """
        # Apply dropout.
        inputs = self.dropout(inputs, training=training)

        residuals = []

        # Iterate through spatial levels.
        for convs, concat, downsample in self.levels:
            # Cache intermediate activations for concatenation.
            cache = []

            # Iterate through convolutional blocks.
            for conv, dense in convs:
                if dense is not None:
                    # NOTE(review): here `inputs` is the same tensor as
                    # `cache[-1]`, so the latest activation is concatenated
                    # twice and the level's own input is not included. Left
                    # unchanged because altering it would change the block
                    # input channels of existing trained models; confirm
                    # whether this is intended.
                    inputs = dense([inputs] + cache)
                inputs = conv(inputs, training=training)
                cache.append(inputs)

            # Concatenate all activations in the layer.
            if concat is not None:
                inputs = concat(cache)

            # Store residuals for use in decoder.
            residuals.append(inputs)

            # No downsampling at bottom spatial level.
            if downsample is not None:
                inputs = downsample(inputs, training=training)

        # Return values after each spatial level for decoder.
        return residuals

    def get_config(self):
        """Returns the layer configuration as a plain dict.

        A shallow copy is returned so that callers mutating the result do
        not alter the configuration stored on the layer.
        """
        return dict(self.config)