|
a |
|
b/ext/neuron/models.py |
|
|
1 |
""" |
|
|
2 |
tensorflow/keras utilities for the neuron project |
|
|
3 |
|
|
|
4 |
If you use this code, please cite |
|
|
5 |
Dalca AV, Guttag J, Sabuncu MR |
|
|
6 |
Anatomical Priors in Convolutional Networks for Unsupervised Biomedical Segmentation, |
|
|
7 |
CVPR 2018 |
|
|
8 |
|
|
|
9 |
Contact: adalca [at] csail [dot] mit [dot] edu |
|
|
10 |
License: GPLv3 |
|
|
11 |
""" |
|
|
12 |
|
|
|
13 |
import sys |
|
|
14 |
|
|
|
15 |
from ext.neuron import layers |
|
|
16 |
|
|
|
17 |
# third party |
|
|
18 |
import numpy as np |
|
|
19 |
import tensorflow as tf |
|
|
20 |
import keras |
|
|
21 |
import keras.layers as KL |
|
|
22 |
from keras.models import Model |
|
|
23 |
import keras.backend as K |
|
|
24 |
|
|
|
25 |
|
|
|
26 |
def unet(nb_features,
         input_shape,
         nb_levels,
         conv_size,
         nb_labels,
         name='unet',
         prefix=None,
         feat_mult=1,
         pool_size=2,
         use_logp=True,
         padding='same',
         dilation_rate_mult=1,
         activation='elu',
         skip_n_concatenations=0,
         use_residuals=False,
         final_pred_activation='softmax',
         nb_conv_per_level=1,
         add_prior_layer=False,
         layer_nb_feats=None,
         conv_dropout=0,
         batch_norm=None,
         input_model=None):
    """
    Build a U-Net style keras model: a `conv_enc` encoder followed by a `conv_dec`
    decoder with skip connections, optionally followed by an `add_prior` layer.

    Parameters:
        nb_features: the number of features at the first convolutional level
            (see `feat_mult` and `layer_nb_feats` for modifiers to this number)
        input_shape: input layer shape, vector of size ndims + 1 (nb_channels)
        nb_levels: the number of levels in the encoder (the decoder upsamples
            nb_levels - 1 times)
        conv_size: the convolution kernel size
        nb_labels: number of output channels
        name (default: 'unet'): the name of the network
        prefix (default: `name` value): prefix added to layer names
        feat_mult (default: 1): multiplier for `nb_features` as we go down the levels,
            e.g. feat_mult of 2 and nb_features of 16 gives 32 features at the
            second level, 64 at the third, etc.
        pool_size (default: 2): max pooling size (integer, or one value per dimension)
        use_logp: forwarded to `add_prior` (only used when `add_prior_layer` is set)
        padding: padding mode of the convolution and pooling layers
        dilation_rate_mult: the dilation rate at level l is dilation_rate_mult ** l
        activation: activation of the convolution layers
        skip_n_concatenations (default: 0): skip the concatenation links between the
            contracting and expanding paths for the n top levels
        use_residuals: whether to add a residual connection around each level
        final_pred_activation: 'softmax' applies a softmax over the channel axis,
            any other value leaves the output linear
        nb_conv_per_level: number of convolutions per level
        add_prior_layer: whether to append a prior layer (see `add_prior`). In that case
            the decoder output is kept linear and `final_pred_activation` is applied
            after the prior has been merged.
        layer_nb_feats: list of the number of features for each convolution. When given,
            the first nb_levels * nb_conv_per_level entries are used by the encoder and
            the remaining ones by the decoder.
        conv_dropout: dropout probability
        batch_norm: axis given to BatchNormalization, or None for no batch normalisation
        input_model: model to prepend to this one. Only its first output is used.

    Returns:
        the decoder model, or the prior model wrapping it if `add_prior_layer` is set.
    """

    if prefix is None:
        prefix = name

    # a scalar pool size is expanded to one entry per spatial dimension
    ndims = len(input_shape) - 1
    if isinstance(pool_size, int):
        pool_size = (pool_size,) * ndims

    # keyword arguments that the encoder and the decoder receive verbatim
    shared_kwargs = dict(name=name,
                         prefix=prefix,
                         feat_mult=feat_mult,
                         pool_size=pool_size,
                         padding=padding,
                         dilation_rate_mult=dilation_rate_mult,
                         activation=activation,
                         use_residuals=use_residuals,
                         nb_conv_per_level=nb_conv_per_level,
                         conv_dropout=conv_dropout,
                         batch_norm=batch_norm)

    # contracting path
    encoder = conv_enc(nb_features,
                       input_shape,
                       nb_levels,
                       conv_size,
                       layer_nb_feats=layer_nb_feats,
                       input_model=input_model,
                       **shared_kwargs)

    # the decoder consumes the feature numbers left over by the encoder
    if layer_nb_feats is None:
        decoder_feats = None
    else:
        decoder_feats = layer_nb_feats[(nb_levels * nb_conv_per_level):]

    # expanding path; use_skip_connections=True is what makes it a u-net.
    # The input shape is left empty because it is read from the encoder output.
    decoder = conv_dec(nb_features,
                       [],
                       nb_levels,
                       conv_size,
                       nb_labels,
                       use_skip_connections=True,
                       skip_n_concatenations=skip_n_concatenations,
                       final_pred_activation='linear' if add_prior_layer else final_pred_activation,
                       layer_nb_feats=decoder_feats,
                       input_model=encoder,
                       **shared_kwargs)

    if not add_prior_layer:
        return decoder

    return add_prior(decoder,
                     [*input_shape[:-1], nb_labels],
                     name=name + '_prior',
                     use_logp=use_logp,
                     final_pred_activation=final_pred_activation)
|
|
146 |
|
|
|
147 |
|
|
|
148 |
def ae(nb_features,
       input_shape,
       nb_levels,
       conv_size,
       nb_labels,
       enc_size,
       name='ae',
       feat_mult=1,
       pool_size=2,
       padding='same',
       activation='elu',
       use_residuals=False,
       nb_conv_per_level=1,
       batch_norm=None,
       enc_batch_norm=None,
       ae_type='conv',  # 'dense', or 'conv'
       enc_lambda_layers=None,
       add_prior_layer=False,
       use_logp=True,
       conv_dropout=0,
       include_mu_shift_layer=False,
       single_model=False,  # whether to return a single model, or a tuple of models that can be stacked.
       final_pred_activation='softmax',
       do_vae=False,
       input_model=None):
    """
    Convolutional Auto-Encoder. Optionally Variational (if do_vae is set to True).

    Three stages are built in turn: a `conv_enc` encoder, a `single_ae` bottleneck and
    a `conv_dec` decoder without skip connections. With `single_model` each stage is
    attached to the previous one and only the final model is returned; otherwise each
    stage gets its own input layer and the tuple (decoder, bottleneck, encoder) is
    returned.

    The decoder is always built with a linear output: `final_pred_activation` is only
    forwarded to `add_prior`, i.e. it only has an effect when `add_prior_layer` is set.
    """

    # a scalar pool size is expanded to one entry per spatial dimension
    ndims = len(input_shape) - 1
    if isinstance(pool_size, int):
        pool_size = (pool_size,) * ndims

    def _next_stage_inputs(previous_model):
        """Return the (input_shape, input_model) pair used to build the next stage."""
        if single_model:
            return None, previous_model
        return previous_model.output.shape.as_list()[1:], None

    # keyword arguments that the encoder and the decoder receive verbatim
    shared_kwargs = dict(name=name,
                         feat_mult=feat_mult,
                         pool_size=pool_size,
                         padding=padding,
                         activation=activation,
                         use_residuals=use_residuals,
                         nb_conv_per_level=nb_conv_per_level,
                         conv_dropout=conv_dropout,
                         batch_norm=batch_norm)

    # encoder
    enc_model = conv_enc(nb_features,
                         input_shape,
                         nb_levels,
                         conv_size,
                         input_model=input_model,
                         **shared_kwargs)

    # middle AE structure
    mid_shape, mid_input_model = _next_stage_inputs(enc_model)
    mid_ae_model = single_ae(enc_size,
                             mid_shape,
                             conv_size=conv_size,
                             name=name,
                             ae_type=ae_type,
                             input_model=mid_input_model,
                             batch_norm=enc_batch_norm,
                             enc_lambda_layers=enc_lambda_layers,
                             include_mu_shift_layer=include_mu_shift_layer,
                             do_vae=do_vae)

    # decoder
    dec_shape, dec_input_model = _next_stage_inputs(mid_ae_model)
    dec_model = conv_dec(nb_features,
                         dec_shape,
                         nb_levels,
                         conv_size,
                         nb_labels,
                         use_skip_connections=False,
                         final_pred_activation='linear',
                         input_model=dec_input_model,
                         **shared_kwargs)

    # optional prior layer on top of the decoder
    if add_prior_layer:
        dec_model = add_prior(dec_model,
                              [*input_shape[:-1], nb_labels],
                              name=name,
                              prefix=name + '_prior',
                              use_logp=use_logp,
                              final_pred_activation=final_pred_activation)

    if single_model:
        return dec_model
    return dec_model, mid_ae_model, enc_model
|
|
254 |
|
|
|
255 |
|
|
|
256 |
def conv_enc(nb_features,
             input_shape,
             nb_levels,
             conv_size,
             name=None,
             prefix=None,
             feat_mult=1,
             pool_size=2,
             dilation_rate_mult=1,
             padding='same',
             activation='elu',
             layer_nb_feats=None,
             use_residuals=False,
             nb_conv_per_level=2,
             conv_dropout=0,
             batch_norm=None,
             input_model=None):
    """
    Fully Convolutional Encoder.

    Each of the `nb_levels` levels applies `nb_conv_per_level` convolutions (each
    optionally followed by dropout), an optional residual merge, optional batch
    normalisation and, except for the last level, max pooling.

    Parameters:
        nb_features: number of features of the first level
        input_shape: input shape, vector of size ndims + 1 (nb_channels). It is used to
            infer the number of spatial dimensions even when `input_model` is given.
        nb_levels: number of levels
        conv_size: the convolution kernel size
        name: the name of the model
        prefix (default: `name` value): prefix added to layer names
        feat_mult: level l uses round(nb_features * feat_mult ** l) features
        pool_size: max pooling size (integer, or one value per dimension)
        dilation_rate_mult: level l uses a dilation rate of dilation_rate_mult ** l
        padding: padding mode of the convolution and pooling layers
        activation: activation of the convolution layers
        layer_nb_feats: optional list with the number of features of every convolution,
            consumed in order; overrides `nb_features` / `feat_mult`
        use_residuals: whether to add the level input to the level output
        nb_conv_per_level: number of convolutions per level
        conv_dropout: dropout probability (no dropout layer is added if 0)
        batch_norm: axis given to BatchNormalization, or None for no batch normalisation
        input_model: optional model to build on top of; only its first output is used

    Returns:
        a keras Model going from the input tensor(s) to the output of the last level.
    """

    # naming
    model_name = name
    if prefix is None:
        prefix = model_name

    # first layer: input
    name = '%s_input' % prefix
    if input_model is None:
        input_tensor = KL.Input(shape=input_shape, name=name)
        last_tensor = input_tensor
    else:
        input_tensor = input_model.inputs
        last_tensor = input_model.outputs
        if isinstance(last_tensor, list):
            last_tensor = last_tensor[0]

    # volume size data
    ndims = len(input_shape) - 1
    if isinstance(pool_size, int):
        pool_size = (pool_size,) * ndims

    # prepare layers
    convL = getattr(KL, 'Conv%dD' % ndims)
    conv_kwargs = {'padding': padding, 'activation': activation, 'data_format': 'channels_last'}
    maxpool = getattr(KL, 'MaxPooling%dD' % ndims)

    # down arm:
    # nb_levels blocks of convolutions, with pooling after each of the first nb_levels - 1 blocks
    lfidx = 0  # index of the next entry to read in layer_nb_feats
    for level in range(nb_levels):
        lvl_first_tensor = last_tensor
        nb_lvl_feats = np.round(nb_features * feat_mult ** level).astype(int)
        conv_kwargs['dilation_rate'] = dilation_rate_mult ** level

        # several convolutions per level, max pooling is applied at the end
        for conv in range(nb_conv_per_level):
            if layer_nb_feats is not None:  # None or list of all the feature numbers
                nb_lvl_feats = layer_nb_feats[lfidx]
                lfidx += 1

            # NOTE: conv_dec looks these layers up by name to build its skip connections
            name = '%s_conv_downarm_%d_%d' % (prefix, level, conv)
            if conv < (nb_conv_per_level - 1) or (not use_residuals):
                last_tensor = convL(nb_lvl_feats, conv_size, **conv_kwargs, name=name)(last_tensor)
            else:
                # last convolution of a residual level: no activation here, it is applied
                # after the residual merge. NOTE(review): this call also leaves out the
                # level's dilation rate -- confirm this is intended.
                last_tensor = convL(nb_lvl_feats, conv_size, padding=padding, name=name)(last_tensor)

            if conv_dropout > 0:
                # conv dropout along feature space only
                name = '%s_dropout_downarm_%d_%d' % (prefix, level, conv)
                noise_shape = [None, *[1] * ndims, nb_lvl_feats]
                last_tensor = KL.Dropout(conv_dropout, noise_shape=noise_shape, name=name)(last_tensor)

        if use_residuals:
            convarm_layer = last_tensor

            # the "add" layer is the original input of the level.
            # However, it may not have the right number of features to be added
            add_layer = lvl_first_tensor
            nb_feats_in = lvl_first_tensor.get_shape()[-1]
            nb_feats_out = convarm_layer.get_shape()[-1]
            if nb_feats_in > 1 and nb_feats_out > 1 and (nb_feats_in != nb_feats_out):
                name = '%s_expand_down_merge_%d' % (prefix, level)
                add_layer = convL(nb_lvl_feats, conv_size, **conv_kwargs, name=name)(lvl_first_tensor)

                if conv_dropout > 0:
                    # The dropout is applied to the convolution arm, as in conv_dec.
                    # It used to be fed with the expanded input, which discarded the
                    # output of the level's convolutions from the residual sum.
                    noise_shape = [None, *[1] * ndims, nb_lvl_feats]
                    convarm_layer = KL.Dropout(conv_dropout, noise_shape=noise_shape)(convarm_layer)

            name = '%s_res_down_merge_%d' % (prefix, level)
            last_tensor = KL.add([add_layer, convarm_layer], name=name)

            name = '%s_res_down_merge_act_%d' % (prefix, level)
            last_tensor = KL.Activation(activation, name=name)(last_tensor)

        if batch_norm is not None:
            name = '%s_bn_down_%d' % (prefix, level)
            last_tensor = KL.BatchNormalization(axis=batch_norm, name=name)(last_tensor)

        # max pool if we're not at the last level
        if level < (nb_levels - 1):
            name = '%s_maxpool_%d' % (prefix, level)
            last_tensor = maxpool(pool_size=pool_size, name=name, padding=padding)(last_tensor)

    # create the model and return
    model = Model(inputs=input_tensor, outputs=[last_tensor], name=model_name)
    return model
|
|
361 |
|
|
|
362 |
|
|
|
363 |
def conv_dec(nb_features,
             input_shape,
             nb_levels,
             conv_size,
             nb_labels,
             name=None,
             prefix=None,
             feat_mult=1,
             pool_size=2,
             use_skip_connections=False,
             skip_n_concatenations=0,
             padding='same',
             dilation_rate_mult=1,
             activation='elu',
             use_residuals=False,
             final_pred_activation='softmax',
             nb_conv_per_level=2,
             layer_nb_feats=None,
             batch_norm=None,
             conv_dropout=0,
             input_model=None):
    """
    Fully Convolutional Decoder.

    Each of the nb_levels - 1 levels upsamples its input, optionally concatenates the
    matching encoder layer (skip connection), then applies `nb_conv_per_level`
    convolutions (each optionally followed by dropout), an optional residual merge and
    optional batch normalisation. A final convolution of size 1 maps to `nb_labels`
    channels.

    Parameters:
        nb_features: number of features of the top level
        input_shape: input shape, vector of size ndims + 1 (nb_channels). Ignored when
            `input_model` is given, in which case it is read from the model output.
        nb_levels: number of levels of the matching encoder
        conv_size: the convolution kernel size
        nb_labels: number of output channels
        name: the name of the model
        prefix (default: `name` value): prefix added to layer names
        feat_mult: multiplier for `nb_features` between levels
        pool_size: upsampling size (integer, or one value per dimension)
        use_skip_connections: whether to concatenate encoder layers taken from
            `input_model` (which is then required)
        skip_n_concatenations: number of top levels for which the concatenation is skipped
        padding: padding mode of the convolution layers
        dilation_rate_mult: base of the per-level dilation rate
        activation: activation of the convolution layers
        use_residuals: whether to add the upsampled level input to the level output
        final_pred_activation: 'softmax' applies a softmax over the channel axis;
            any other value leaves the output linear
        nb_conv_per_level: number of convolutions per level
        layer_nb_feats: optional list with the number of features of every convolution,
            consumed in order; overrides `nb_features` / `feat_mult`
        batch_norm: axis given to BatchNormalization, or None for no batch normalisation
        conv_dropout: dropout probability (no dropout layer is added if 0)
        input_model: optional model to build on top of

    Returns:
        a keras Model going from the input tensor(s) to the prediction tensor.
    """

    # naming
    model_name = name
    if prefix is None:
        prefix = model_name

    # skip connections are read from the layers of input_model, so it must be provided
    if use_skip_connections:
        assert input_model is not None, "is using skip connections, tensors dictionary is required"

    # first layer: input
    input_name = '%s_input' % prefix
    if input_model is None:
        input_tensor = KL.Input(shape=input_shape, name=input_name)
        last_tensor = input_tensor
    else:
        input_tensor = input_model.input
        last_tensor = input_model.output
        input_shape = last_tensor.shape.as_list()[1:]

    # vol size info (a scalar pool size is kept as is for 1D data)
    ndims = len(input_shape) - 1
    if isinstance(pool_size, int) and ndims > 1:
        pool_size = (pool_size,) * ndims

    # prepare layers
    convL = getattr(KL, 'Conv%dD' % ndims)
    conv_kwargs = {'padding': padding, 'activation': activation}
    upsample = getattr(KL, 'UpSampling%dD' % ndims)

    # up arm:
    # nb_levels - 1 blocks of upsampling + (optional) merge + convolutions
    lfidx = 0  # index of the next entry to read in layer_nb_feats
    for level in range(nb_levels - 1):
        nb_lvl_feats = np.round(nb_features * feat_mult ** (nb_levels - 2 - level)).astype(int)
        conv_kwargs['dilation_rate'] = dilation_rate_mult ** (nb_levels - 2 - level)

        # upsample matching the max pooling layers size
        name = '%s_up_%d' % (prefix, nb_levels + level)
        last_tensor = upsample(size=pool_size, name=name)(last_tensor)
        up_tensor = last_tensor

        # merge with the last convolution of the matching encoder level.
        # Logical `and` (not bitwise `&`) so that any truthy flag enables the merge.
        if use_skip_connections and (level < (nb_levels - skip_n_concatenations - 1)):
            conv_name = '%s_conv_downarm_%d_%d' % (prefix, nb_levels - 2 - level, nb_conv_per_level - 1)
            cat_tensor = input_model.get_layer(conv_name).output
            name = '%s_merge_%d' % (prefix, nb_levels + level)
            last_tensor = KL.concatenate([cat_tensor, last_tensor], axis=ndims + 1, name=name)

        # convolution layers
        for conv in range(nb_conv_per_level):
            if layer_nb_feats is not None:
                nb_lvl_feats = layer_nb_feats[lfidx]
                lfidx += 1

            name = '%s_conv_uparm_%d_%d' % (prefix, nb_levels + level, conv)
            if conv < (nb_conv_per_level - 1) or (not use_residuals):
                last_tensor = convL(nb_lvl_feats, conv_size, **conv_kwargs, name=name)(last_tensor)
            else:
                # last convolution of a residual level: no activation here, it is applied
                # after the residual merge. NOTE(review): this call also leaves out the
                # level's dilation rate -- confirm this is intended.
                last_tensor = convL(nb_lvl_feats, conv_size, padding=padding, name=name)(last_tensor)

            if conv_dropout > 0:
                # conv dropout along feature space only
                name = '%s_dropout_uparm_%d_%d' % (prefix, level, conv)
                noise_shape = [None, *[1] * ndims, nb_lvl_feats]
                last_tensor = KL.Dropout(conv_dropout, noise_shape=noise_shape, name=name)(last_tensor)

        # residual block
        if use_residuals:

            # the "add" layer is the upsampled input of the level.
            # However, it may not have the right number of features to be added
            add_layer = up_tensor
            nb_feats_in = add_layer.get_shape()[-1]
            nb_feats_out = last_tensor.get_shape()[-1]
            if nb_feats_in > 1 and nb_feats_out > 1 and (nb_feats_in != nb_feats_out):
                name = '%s_expand_up_merge_%d' % (prefix, level)
                add_layer = convL(nb_lvl_feats, conv_size, **conv_kwargs, name=name)(add_layer)

                if conv_dropout > 0:
                    noise_shape = [None, *[1] * ndims, nb_lvl_feats]
                    last_tensor = KL.Dropout(conv_dropout, noise_shape=noise_shape)(last_tensor)

            name = '%s_res_up_merge_%d' % (prefix, level)
            last_tensor = KL.add([last_tensor, add_layer], name=name)

            name = '%s_res_up_merge_act_%d' % (prefix, level)
            last_tensor = KL.Activation(activation, name=name)(last_tensor)

        if batch_norm is not None:
            name = '%s_bn_up_%d' % (prefix, level)
            last_tensor = KL.BatchNormalization(axis=batch_norm, name=name)(last_tensor)

    # Compute likelihood prediction (no activation yet)
    name = '%s_likelihood' % prefix
    last_tensor = convL(nb_labels, 1, activation=None, name=name)(last_tensor)

    # output prediction layer
    name = '%s_prediction' % prefix
    if final_pred_activation == 'softmax':
        # we use a softmax to compute P(L_x|I) where x is each location
        softmax_lambda_fcn = lambda x: keras.activations.softmax(x, axis=ndims + 1)
        pred_tensor = KL.Lambda(softmax_lambda_fcn, name=name)(last_tensor)
    else:
        # otherwise create a layer that does nothing, whatever the requested activation
        pred_tensor = KL.Activation('linear', name=name)(last_tensor)

    # create the model and return
    model = Model(inputs=input_tensor, outputs=pred_tensor, name=model_name)
    return model
|
|
499 |
|
|
|
500 |
|
|
|
501 |
def add_prior(input_model,
              prior_shape,
              name='prior_model',
              prefix=None,
              use_logp=True,
              final_pred_activation='softmax'):
    """
    Append a post-prior layer to a given model.

    A new input layer of shape `prior_shape` is created for the prior and merged with
    the output of `input_model` (the likelihood):
      - if `use_logp`, the two are added (a notice that log input is required is
        printed on stderr),
      - otherwise the sigmoid of the likelihood is multiplied by the prior.
    The merged tensor then goes through a softmax over the last axis if
    `final_pred_activation` is 'softmax' (only allowed with `use_logp`), and is left
    unchanged for any other value.

    Returns:
        a keras Model whose inputs are those of `input_model` followed by the prior
        input, and whose single output is the prediction.
    """

    layer_prefix = name if prefix is None else prefix

    # prior input layer and likelihood coming from the wrapped model
    prior_input = KL.Input(shape=prior_shape, name='%s-input' % layer_prefix)
    likelihood = input_model.output

    # operation varies depending on whether we log() prior or not.
    if not use_logp:
        # using sigmoid to get the likelihood values between 0 and 1
        # note: they won't add up to 1.
        sigmoid_name = '%s_likelihood_sigmoid' % layer_prefix
        likelihood = KL.Activation('sigmoid', name=sigmoid_name)(likelihood)
        combine = KL.multiply
    else:
        print("Breaking change: use_logp option now requires log input!", file=sys.stderr)
        combine = KL.add

    # merge the likelihood and prior layers into posterior layer
    posterior = combine([prior_input, likelihood], name='%s_posterior' % layer_prefix)

    # output prediction layer
    pred_name = '%s_prediction' % layer_prefix
    if final_pred_activation != 'softmax':
        prediction = KL.Activation('linear', name=pred_name)(posterior)
    else:
        # we use a softmax to compute P(L_x|I) where x is each location
        assert use_logp, 'cannot do softmax when adding prior via P()'
        print("using final_pred_activation %s for %s" % (final_pred_activation, name))
        prediction = KL.Lambda(lambda x: keras.activations.softmax(x, axis=-1), name=pred_name)(posterior)

    # create the model (it is returned uncompiled)
    return Model(inputs=[*input_model.inputs, prior_input], outputs=[prediction], name=name)
|
|
556 |
|
|
|
557 |
|
|
|
558 |
def single_ae(enc_size,
              input_shape,
              name='single_ae',
              prefix=None,
              ae_type='dense',  # 'dense', or 'conv'
              conv_size=None,
              input_model=None,
              enc_lambda_layers=None,
              batch_norm=True,
              padding='same',
              activation=None,
              include_mu_shift_layer=False,
              do_vae=False):
    """
    Single-layer Autoencoder (i.e. input - encoding - output).

    Parameters:
        enc_size: size of the encoding. A sequence of length 1 for a dense encoding;
            [nb_dim1, nb_dim2, ..., nb_feats] (same length as the input shape) for a
            convolutional one. In the convolutional case a last entry of None means
            that no bottleneck is imposed on the mean.
        input_shape: input shape, required when `input_model` is None; otherwise it is
            read from the output of `input_model`
        name: the name of the model
        prefix (default: `name` value): prefix added to layer names
        ae_type: 'dense' or 'conv'
        conv_size: the convolution kernel size, required when ae_type is 'conv'
        input_model: optional model to build on top of
        enc_lambda_layers: optional list of functions, each wrapped in a Lambda layer
            and applied in order to the encoding
        batch_norm: axis given to BatchNormalization, or None for no batch normalisation.
            NOTE(review): the default is True rather than an axis index -- confirm.
        padding: padding mode of the convolution layers
        activation: activation of the convolution layers
        include_mu_shift_layer: whether to add a LocalBias layer after the encoding
            (and after the sampling if do_vae)
        do_vae: whether to add a second ("sigma") encoding and sample from the two

    Returns:
        a keras Model going from the input tensor to the decoded tensor.

    Raises:
        ValueError: if `ae_type` is neither 'dense' nor 'conv'.
    """

    # an unknown type used to fail further down when calling a None layer constructor
    if ae_type not in ('dense', 'conv'):
        raise ValueError("ae_type should be 'dense' or 'conv', got %s" % (ae_type,))

    # naming
    model_name = name
    if prefix is None:
        prefix = model_name

    if enc_lambda_layers is None:
        enc_lambda_layers = []

    # prepare input
    input_name = '%s_input' % prefix
    if input_model is None:
        assert input_shape is not None, 'input_shape of input_model is necessary'
        input_tensor = KL.Input(shape=input_shape, name=input_name)
        last_tensor = input_tensor
    else:
        input_tensor = input_model.input
        last_tensor = input_model.output
        input_shape = last_tensor.shape.as_list()[1:]
    input_nb_feats = last_tensor.shape.as_list()[-1]

    # prepare conv type based on input
    ndims = len(input_shape) - 1
    if ae_type == 'conv':
        convL = getattr(KL, 'Conv%dD' % ndims)
        assert conv_size is not None, 'with conv ae, need conv_size'
        conv_kwargs = {'padding': padding, 'activation': activation}
        enc_size_str = None

    # if want to go through a dense layer in the middle of the U, need to:
    # - flatten last layer if not flat
    # - do dense encoding and decoding
    # - unflatten (reshape spatially) at end
    else:  # ae_type == 'dense'
        if len(input_shape) > 1:
            name = '%s_ae_%s_down_flat' % (prefix, ae_type)
            last_tensor = KL.Flatten(name=name)(last_tensor)
        convL = conv_kwargs = None
        assert len(enc_size) == 1, "enc_size should be of length 1 for dense layer"
        enc_size_str = '_'.join('%d' % d for d in enc_size)

    # recall this layer
    pre_enc_layer = last_tensor

    # whether the convolutional encoding has to be resized to the spatial size of enc_size
    needs_resize = False

    # encoding layer
    if ae_type == 'dense':
        name = '%s_ae_mu_enc_dense_%s' % (prefix, enc_size_str)
        last_tensor = KL.Dense(enc_size[0], name=name)(pre_enc_layer)

    else:  # convolution

        # convolve then resize. enc_size should be [nb_dim1, nb_dim2, ..., nb_feats]
        assert len(enc_size) == len(input_shape), \
            "encoding size does not match input shape %d %d" % (len(enc_size), len(input_shape))

        # resize only if the spatial sizes differ and are all known
        needs_resize = list(enc_size)[:-1] != list(input_shape)[:-1] and \
            all([f is not None for f in input_shape[:-1]]) and \
            all([f is not None for f in enc_size[:-1]])

        if needs_resize:
            name = '%s_ae_mu_enc_conv' % prefix
            last_tensor = convL(enc_size[-1], conv_size, name=name, **conv_kwargs)(pre_enc_layer)

            name = '%s_ae_mu_enc' % prefix
            zf = [enc_size[:-1][f] / last_tensor.shape.as_list()[1:-1][f] for f in range(len(enc_size) - 1)]
            last_tensor = layers.Resize(zoom_factor=zf, name=name)(last_tensor)

        elif enc_size[-1] is None:  # convolutional, but won't tell us bottleneck
            name = '%s_ae_mu_enc' % prefix
            last_tensor = KL.Lambda(lambda x: x, name=name)(pre_enc_layer)

        else:
            name = '%s_ae_mu_enc' % prefix
            last_tensor = convL(enc_size[-1], conv_size, name=name, **conv_kwargs)(pre_enc_layer)

    if include_mu_shift_layer:
        # shift
        name = '%s_ae_mu_shift' % prefix
        last_tensor = layers.LocalBias(name=name)(last_tensor)

    # encoding clean-up layers
    for layer_fcn in enc_lambda_layers:
        lambda_name = layer_fcn.__name__
        name = '%s_ae_mu_%s' % (prefix, lambda_name)
        last_tensor = KL.Lambda(layer_fcn, name=name)(last_tensor)

    if batch_norm is not None:
        name = '%s_ae_mu_bn' % prefix
        last_tensor = KL.BatchNormalization(axis=batch_norm, name=name)(last_tensor)

    # have a simple layer that does nothing to have a clear name before sampling
    name = '%s_ae_mu' % prefix
    last_tensor = KL.Lambda(lambda x: x, name=name)(last_tensor)

    # if doing variational AE, will need the sigma layer as well.
    if do_vae:
        mu_tensor = last_tensor

        # encoding layer
        if ae_type == 'dense':
            name = '%s_ae_sigma_enc_dense_%s' % (prefix, enc_size_str)
            last_tensor = KL.Dense(enc_size[0], name=name)(pre_enc_layer)

        else:
            if needs_resize:
                # unlike the mean, this branch resizes with a 2D-only tensorflow function
                assert len(enc_size) - 1 == 2, "Sorry, I have not yet implemented non-2D resizing..."
                name = '%s_ae_sigma_enc_conv' % prefix
                last_tensor = convL(enc_size[-1], conv_size, name=name, **conv_kwargs)(pre_enc_layer)

                name = '%s_ae_sigma_enc' % prefix
                resize_fn = lambda x: tf.image.resize_bilinear(x, enc_size[:-1])
                last_tensor = KL.Lambda(resize_fn, name=name)(last_tensor)

            elif enc_size[-1] is None:  # convolutional, but won't tell us bottleneck
                # cannot use an identity lambda here, or mu and sigma would be the same layer
                name = '%s_ae_sigma_enc' % prefix
                last_tensor = convL(pre_enc_layer.shape.as_list()[-1], conv_size, name=name, **conv_kwargs)(
                    pre_enc_layer)

            else:
                name = '%s_ae_sigma_enc' % prefix
                last_tensor = convL(enc_size[-1], conv_size, name=name, **conv_kwargs)(pre_enc_layer)

        # encoding clean-up layers
        for layer_fcn in enc_lambda_layers:
            lambda_name = layer_fcn.__name__
            name = '%s_ae_sigma_%s' % (prefix, lambda_name)
            last_tensor = KL.Lambda(layer_fcn, name=name)(last_tensor)

        if batch_norm is not None:
            name = '%s_ae_sigma_bn' % prefix
            last_tensor = KL.BatchNormalization(axis=batch_norm, name=name)(last_tensor)

        # have a simple layer that does nothing to have a clear name before sampling
        name = '%s_ae_sigma' % prefix
        last_tensor = KL.Lambda(lambda x: x, name=name)(last_tensor)

        logvar_tensor = last_tensor

        # VAE sampling
        sampler = _VAESample().sample_z

        name = '%s_ae_sample' % prefix
        last_tensor = KL.Lambda(sampler, name=name)([mu_tensor, logvar_tensor])

        if include_mu_shift_layer:
            # shift
            name = '%s_ae_sample_shift' % prefix
            last_tensor = layers.LocalBias(name=name)(last_tensor)

    # decoding layer
    if ae_type == 'dense':
        name = '%s_ae_%s_dec_flat_%s' % (prefix, ae_type, enc_size_str)
        last_tensor = KL.Dense(np.prod(input_shape), name=name)(last_tensor)

        # unflatten if dense method
        if len(input_shape) > 1:
            name = '%s_ae_%s_dec' % (prefix, ae_type)
            last_tensor = KL.Reshape(input_shape, name=name)(last_tensor)

    else:

        if needs_resize:
            # Zoom from the current spatial size back to the input size. The factor used
            # to be current size / enc_size, i.e. 1 right after the encoder resize, so
            # the encoding was never brought back to the input size.
            name = '%s_ae_mu_dec' % prefix
            cur_size = last_tensor.shape.as_list()[1:-1]
            zf = [target / (cur if cur is not None else enc)
                  for target, cur, enc in zip(input_shape[:-1], cur_size, enc_size[:-1])]
            last_tensor = layers.Resize(zoom_factor=zf, name=name)(last_tensor)

        name = '%s_ae_%s_dec' % (prefix, ae_type)
        last_tensor = convL(input_nb_feats, conv_size, name=name, **conv_kwargs)(last_tensor)

    if batch_norm is not None:
        name = '%s_bn_ae_%s_dec' % (prefix, ae_type)
        last_tensor = KL.BatchNormalization(axis=batch_norm, name=name)(last_tensor)

    # create the model and return
    model = Model(inputs=input_tensor, outputs=[last_tensor], name=model_name)
    return model
|
|
754 |
|
|
|
755 |
|
|
|
756 |
############################################################################### |
|
|
757 |
# Helper function |
|
|
758 |
############################################################################### |
|
|
759 |
|
|
|
760 |
class _VAESample:
    """Sampler whose `sample_z` method is used as the function of a keras Lambda layer."""

    def sample_z(self, args):
        """
        Draw a sample from a pair of tensors.

        Parameters:
            args: pair (mu, log_var) of tensors, the mean and the log-variance

        Returns:
            mu + exp(log_var / 2) * eps, where eps is standard normal noise with the
            shape of mu.
        """
        mean, log_variance = args
        noise = K.random_normal(shape=K.shape(mean), mean=0., stddev=1.)
        std = K.exp(log_variance / 2)
        return mean + std * noise