[e571d1]: / ext / neuron / models.py

Download this file

769 lines (648 with data), 29.6 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
"""
tensorflow/keras utilities for the neuron project
If you use this code, please cite
Dalca AV, Guttag J, Sabuncu MR
Anatomical Priors in Convolutional Networks for Unsupervised Biomedical Segmentation,
CVPR 2018
Contact: adalca [at] csail [dot] mit [dot] edu
License: GPLv3
"""
import sys
from ext.neuron import layers
# third party
import numpy as np
import tensorflow as tf
import keras
import keras.layers as KL
from keras.models import Model
import keras.backend as K
def unet(nb_features,
         input_shape,
         nb_levels,
         conv_size,
         nb_labels,
         name='unet',
         prefix=None,
         feat_mult=1,
         pool_size=2,
         use_logp=True,
         padding='same',
         dilation_rate_mult=1,
         activation='elu',
         skip_n_concatenations=0,
         use_residuals=False,
         final_pred_activation='softmax',
         nb_conv_per_level=1,
         add_prior_layer=False,
         layer_nb_feats=None,
         conv_dropout=0,
         batch_norm=None,
         input_model=None):
    """
    unet-style keras model with an overdose of parametrization.

    Builds the encoder via `conv_enc`, the decoder via `conv_dec` (with skip
    connections enabled, which is what makes this a u-net), and optionally
    appends a prior-merging model via `add_prior`.

    Parameters:
        nb_features: the number of features at each convolutional level
            see below for `feat_mult` and `layer_nb_feats` for modifiers to this number
        input_shape: input layer shape, vector of size ndims + 1 (nb_channels)
        nb_levels: the number of Unet levels (number of downsamples) in the "encoder"
            (e.g. 4 would give you 4 levels in encoder, 4 in decoder)
        conv_size: the convolution kernel size
        nb_labels: number of output channels
        name (default: 'unet'): the name of the network
        prefix (default: `name` value): prefix to be added to layer names
        feat_mult (default: 1) multiple for `nb_features` as we go down the encoder levels.
            e.g. feat_mult of 2 and nb_features of 16 would yield 32 features in the
            second layer, 64 features in the third layer, etc.
        pool_size (default: 2): max pooling size (integer or list if specifying per dimension)
        skip_n_concatenations=0: enabled to skip concatenation links between contracting and expanding paths for the n
            top levels.
        use_logp: forwarded to `add_prior` (only relevant when add_prior_layer is True)
        padding: convolution padding, forwarded to encoder and decoder
        dilation_rate_mult: per-level multiplier for the convolution dilation rate
        activation: activation used by the convolution layers
        use_residuals: whether to add residual connections within each level
        final_pred_activation: activation of the final prediction layer
            (forced to 'linear' in the decoder when a prior layer is appended,
            so the activation is applied after the prior merge instead)
        nb_conv_per_level: number of convolution layers per level
        add_prior_layer: whether to append a prior-merging model to the decoder output
        layer_nb_feats: list of the number of features for each layer. Automatically used if specified
            (the decoder consumes the entries after the first nb_levels * nb_conv_per_level)
        conv_dropout: dropout probability
        batch_norm: axis passed to BatchNormalization, or None to disable it
        input_model: concatenate the provided input_model to this current model.
            Only the first output of input_model is used.
    """

    # naming
    model_name = name
    if prefix is None:
        prefix = model_name

    # volume size data
    ndims = len(input_shape) - 1
    if isinstance(pool_size, int):
        pool_size = (pool_size,) * ndims

    # get encoding model
    enc_model = conv_enc(nb_features,
                         input_shape,
                         nb_levels,
                         conv_size,
                         name=model_name,
                         prefix=prefix,
                         feat_mult=feat_mult,
                         pool_size=pool_size,
                         padding=padding,
                         dilation_rate_mult=dilation_rate_mult,
                         activation=activation,
                         use_residuals=use_residuals,
                         nb_conv_per_level=nb_conv_per_level,
                         layer_nb_feats=layer_nb_feats,
                         conv_dropout=conv_dropout,
                         batch_norm=batch_norm,
                         input_model=input_model)

    # get decoder
    # use_skip_connections=True makes it a u-net
    # the encoder consumed the first (nb_levels * nb_conv_per_level) per-layer feature counts
    lnf = layer_nb_feats[(nb_levels * nb_conv_per_level):] if layer_nb_feats is not None else None
    dec_model = conv_dec(nb_features,
                         [],
                         nb_levels,
                         conv_size,
                         nb_labels,
                         name=model_name,
                         prefix=prefix,
                         feat_mult=feat_mult,
                         pool_size=pool_size,
                         use_skip_connections=True,
                         skip_n_concatenations=skip_n_concatenations,
                         padding=padding,
                         dilation_rate_mult=dilation_rate_mult,
                         activation=activation,
                         use_residuals=use_residuals,
                         final_pred_activation='linear' if add_prior_layer else final_pred_activation,
                         nb_conv_per_level=nb_conv_per_level,
                         batch_norm=batch_norm,
                         layer_nb_feats=lnf,
                         conv_dropout=conv_dropout,
                         input_model=enc_model)
    final_model = dec_model

    if add_prior_layer:
        final_model = add_prior(dec_model,
                                [*input_shape[:-1], nb_labels],
                                name=model_name + '_prior',
                                use_logp=use_logp,
                                final_pred_activation=final_pred_activation)

    return final_model
def ae(nb_features,
       input_shape,
       nb_levels,
       conv_size,
       nb_labels,
       enc_size,
       name='ae',
       feat_mult=1,
       pool_size=2,
       padding='same',
       activation='elu',
       use_residuals=False,
       nb_conv_per_level=1,
       batch_norm=None,
       enc_batch_norm=None,
       ae_type='conv',  # 'dense', or 'conv'
       enc_lambda_layers=None,
       add_prior_layer=False,
       use_logp=True,
       conv_dropout=0,
       include_mu_shift_layer=False,
       single_model=False,  # whether to return a single model, or a tuple of models that can be stacked.
       final_pred_activation='softmax',
       do_vae=False,
       input_model=None):
    """Convolutional Auto-Encoder. Optionally Variational (if do_vae is set to True).

    Assembled as encoder (`conv_enc`) -> bottleneck (`single_ae`) -> decoder
    (`conv_dec`, without skip connections). When `single_model` is True the
    three stages are chained into one model (each stage is built on top of the
    previous one); otherwise three separate stackable models are returned as
    (decoder, bottleneck, encoder), each taking as input the output shape of
    the previous stage.
    """

    # naming
    model_name = name

    # volume size data
    ndims = len(input_shape) - 1
    if isinstance(pool_size, int):
        pool_size = (pool_size,) * ndims

    # get encoding model
    enc_model = conv_enc(nb_features,
                         input_shape,
                         nb_levels,
                         conv_size,
                         name=model_name,
                         feat_mult=feat_mult,
                         pool_size=pool_size,
                         padding=padding,
                         activation=activation,
                         use_residuals=use_residuals,
                         nb_conv_per_level=nb_conv_per_level,
                         conv_dropout=conv_dropout,
                         batch_norm=batch_norm,
                         input_model=input_model)

    # middle AE structure
    if single_model:
        in_input_shape = None
        in_model = enc_model
    else:
        # standalone bottleneck model: its input shape is the encoder's output shape
        in_input_shape = enc_model.output.shape.as_list()[1:]
        in_model = None
    mid_ae_model = single_ae(enc_size,
                             in_input_shape,
                             conv_size=conv_size,
                             name=model_name,
                             ae_type=ae_type,
                             input_model=in_model,
                             batch_norm=enc_batch_norm,
                             enc_lambda_layers=enc_lambda_layers,
                             include_mu_shift_layer=include_mu_shift_layer,
                             do_vae=do_vae)

    # decoder
    if single_model:
        in_input_shape = None
        in_model = mid_ae_model
    else:
        in_input_shape = mid_ae_model.output.shape.as_list()[1:]
        in_model = None
    dec_model = conv_dec(nb_features,
                         in_input_shape,
                         nb_levels,
                         conv_size,
                         nb_labels,
                         name=model_name,
                         feat_mult=feat_mult,
                         pool_size=pool_size,
                         use_skip_connections=False,
                         padding=padding,
                         activation=activation,
                         use_residuals=use_residuals,
                         final_pred_activation='linear',
                         nb_conv_per_level=nb_conv_per_level,
                         batch_norm=batch_norm,
                         conv_dropout=conv_dropout,
                         input_model=in_model)

    if add_prior_layer:
        dec_model = add_prior(dec_model,
                              [*input_shape[:-1], nb_labels],
                              name=model_name,
                              prefix=model_name + '_prior',
                              use_logp=use_logp,
                              final_pred_activation=final_pred_activation)

    if single_model:
        return dec_model
    else:
        return dec_model, mid_ae_model, enc_model
def conv_enc(nb_features,
             input_shape,
             nb_levels,
             conv_size,
             name=None,
             prefix=None,
             feat_mult=1,
             pool_size=2,
             dilation_rate_mult=1,
             padding='same',
             activation='elu',
             layer_nb_feats=None,
             use_residuals=False,
             nb_conv_per_level=2,
             conv_dropout=0,
             batch_norm=None,
             input_model=None):
    """Fully Convolutional Encoder.

    Builds nb_levels of [nb_conv_per_level convolutions (+ optional dropout,
    residual merge and batch-norm)], with max-pooling between levels (no pool
    after the deepest level). Layer names encode prefix, level and
    convolution index so the matching decoder can retrieve skip tensors by
    name ('%s_conv_downarm_%d_%d').
    """

    # naming
    model_name = name
    if prefix is None:
        prefix = model_name

    # first layer: input
    name = '%s_input' % prefix
    if input_model is None:
        input_tensor = KL.Input(shape=input_shape, name=name)
        last_tensor = input_tensor
    else:
        # build on top of the provided model; only its first output is extended
        input_tensor = input_model.inputs
        last_tensor = input_model.outputs
        if isinstance(last_tensor, list):
            last_tensor = last_tensor[0]

    # volume size data
    ndims = len(input_shape) - 1
    if isinstance(pool_size, int):
        pool_size = (pool_size,) * ndims

    # prepare layers (dimensionality-appropriate conv / pooling classes)
    convL = getattr(KL, 'Conv%dD' % ndims)
    conv_kwargs = {'padding': padding, 'activation': activation, 'data_format': 'channels_last'}
    maxpool = getattr(KL, 'MaxPooling%dD' % ndims)

    # down arm:
    # add nb_levels of conv + ReLu + conv + ReLu. Pool after each of first nb_levels - 1 layers
    lfidx = 0  # level feature index (into layer_nb_feats, when given)
    for level in range(nb_levels):
        lvl_first_tensor = last_tensor
        nb_lvl_feats = np.round(nb_features * feat_mult ** level).astype(int)
        conv_kwargs['dilation_rate'] = dilation_rate_mult ** level

        for conv in range(nb_conv_per_level):  # does several conv per level, max pooling applied at the end
            if layer_nb_feats is not None:  # None or List of all the feature numbers
                nb_lvl_feats = layer_nb_feats[lfidx]
                lfidx += 1

            name = '%s_conv_downarm_%d_%d' % (prefix, level, conv)
            if conv < (nb_conv_per_level - 1) or (not use_residuals):
                last_tensor = convL(nb_lvl_feats, conv_size, **conv_kwargs, name=name)(last_tensor)
            else:  # no activation: the residual add comes first, activation is applied after the merge
                last_tensor = convL(nb_lvl_feats, conv_size, padding=padding, name=name)(last_tensor)

            if conv_dropout > 0:
                # conv dropout along feature space only (same mask at every spatial location)
                name = '%s_dropout_downarm_%d_%d' % (prefix, level, conv)
                noise_shape = [None, *[1] * ndims, nb_lvl_feats]
                last_tensor = KL.Dropout(conv_dropout, noise_shape=noise_shape, name=name)(last_tensor)

        if use_residuals:
            convarm_layer = last_tensor

            # the "add" layer is the original input
            # However, it may not have the right number of features to be added
            nb_feats_in = lvl_first_tensor.get_shape()[-1]
            nb_feats_out = convarm_layer.get_shape()[-1]
            add_layer = lvl_first_tensor
            if nb_feats_in > 1 and nb_feats_out > 1 and (nb_feats_in != nb_feats_out):
                # 1-conv projection of the level input so feature counts match
                name = '%s_expand_down_merge_%d' % (prefix, level)
                last_tensor = convL(nb_lvl_feats, conv_size, **conv_kwargs, name=name)(lvl_first_tensor)
                add_layer = last_tensor

                if conv_dropout > 0:
                    noise_shape = [None, *[1] * ndims, nb_lvl_feats]
                    # NOTE(review): this overwrites convarm_layer with dropout applied to the
                    # projected input (last_tensor), not to the conv-arm output — confirm intended
                    convarm_layer = KL.Dropout(conv_dropout, noise_shape=noise_shape)(last_tensor)

            name = '%s_res_down_merge_%d' % (prefix, level)
            last_tensor = KL.add([add_layer, convarm_layer], name=name)

            name = '%s_res_down_merge_act_%d' % (prefix, level)
            last_tensor = KL.Activation(activation, name=name)(last_tensor)

        if batch_norm is not None:
            name = '%s_bn_down_%d' % (prefix, level)
            last_tensor = KL.BatchNormalization(axis=batch_norm, name=name)(last_tensor)

        # max pool if we're not at the last level
        if level < (nb_levels - 1):
            name = '%s_maxpool_%d' % (prefix, level)
            last_tensor = maxpool(pool_size=pool_size, name=name, padding=padding)(last_tensor)

    # create the model and return
    model = Model(inputs=input_tensor, outputs=[last_tensor], name=model_name)
    return model
def conv_dec(nb_features,
             input_shape,
             nb_levels,
             conv_size,
             nb_labels,
             name=None,
             prefix=None,
             feat_mult=1,
             pool_size=2,
             use_skip_connections=False,
             skip_n_concatenations=0,
             padding='same',
             dilation_rate_mult=1,
             activation='elu',
             use_residuals=False,
             final_pred_activation='softmax',
             nb_conv_per_level=2,
             layer_nb_feats=None,
             batch_norm=None,
             conv_dropout=0,
             input_model=None):
    """Fully Convolutional Decoder.

    Mirrors `conv_enc`: nb_levels - 1 blocks of [upsample (+ optional skip
    concatenation) + nb_conv_per_level convolutions (+ optional dropout,
    residual merge and batch-norm)], followed by a 1x1 "likelihood"
    convolution to nb_labels channels and a final prediction activation
    (channel-wise softmax, or linear pass-through otherwise).

    Parameters:
        nb_features: base number of features (scaled by feat_mult per level)
        input_shape: input layer shape; ignored if input_model is given
            (then the shape of input_model's output is used instead)
        nb_levels: number of levels, matching the encoder
        conv_size: convolution kernel size
        nb_labels: number of output channels
        name / prefix: model name and layer-name prefix (prefix defaults to name)
        feat_mult: per-level feature multiplier
        pool_size: upsampling factor, matching the encoder's pooling size
        use_skip_connections: concatenate encoder tensors (u-net);
            requires input_model, whose layers are looked up by name
        skip_n_concatenations: skip the concatenation links for the n top levels
        padding, dilation_rate_mult, activation, use_residuals,
        nb_conv_per_level, layer_nb_feats, batch_norm, conv_dropout: as in conv_enc
        final_pred_activation: 'softmax' for channel-wise softmax, else linear
        input_model: model to build on top of (e.g. the encoder)
    """

    # naming
    model_name = name
    if prefix is None:
        prefix = model_name

    # if using skip connections, make sure need to use them.
    if use_skip_connections:
        assert input_model is not None, "is using skip connections, tensors dictionary is required"

    # first layer: input
    input_name = '%s_input' % prefix
    if input_model is None:
        input_tensor = KL.Input(shape=input_shape, name=input_name)
        last_tensor = input_tensor
    else:
        input_tensor = input_model.input
        last_tensor = input_model.output
        input_shape = last_tensor.shape.as_list()[1:]

    # vol size info
    ndims = len(input_shape) - 1
    if isinstance(pool_size, int):
        if ndims > 1:
            pool_size = (pool_size,) * ndims

    # prepare layers (dimensionality-appropriate conv / upsampling classes)
    convL = getattr(KL, 'Conv%dD' % ndims)
    conv_kwargs = {'padding': padding, 'activation': activation}
    upsample = getattr(KL, 'UpSampling%dD' % ndims)

    # up arm:
    # nb_levels - 1 layers of Deconvolution3D
    # (approx via up + conv + ReLu) + merge + conv + ReLu + conv + ReLu
    lfidx = 0
    for level in range(nb_levels - 1):
        nb_lvl_feats = np.round(nb_features * feat_mult ** (nb_levels - 2 - level)).astype(int)
        conv_kwargs['dilation_rate'] = dilation_rate_mult ** (nb_levels - 2 - level)

        # upsample matching the max pooling layers size
        name = '%s_up_%d' % (prefix, nb_levels + level)
        last_tensor = upsample(size=pool_size, name=name)(last_tensor)
        up_tensor = last_tensor

        # merge layers combining previous layer
        # fix: logical `and`, not bitwise `&`, on the boolean flag
        if use_skip_connections and (level < (nb_levels - skip_n_concatenations - 1)):
            # fetch the matching encoder tensor by its layer name
            conv_name = '%s_conv_downarm_%d_%d' % (prefix, nb_levels - 2 - level, nb_conv_per_level - 1)
            cat_tensor = input_model.get_layer(conv_name).output
            name = '%s_merge_%d' % (prefix, nb_levels + level)
            last_tensor = KL.concatenate([cat_tensor, last_tensor], axis=ndims + 1, name=name)

        # convolution layers
        for conv in range(nb_conv_per_level):
            if layer_nb_feats is not None:
                nb_lvl_feats = layer_nb_feats[lfidx]
                lfidx += 1

            name = '%s_conv_uparm_%d_%d' % (prefix, nb_levels + level, conv)
            if conv < (nb_conv_per_level - 1) or (not use_residuals):
                last_tensor = convL(nb_lvl_feats, conv_size, **conv_kwargs, name=name)(last_tensor)
            else:  # no activation: applied after the residual merge below
                last_tensor = convL(nb_lvl_feats, conv_size, padding=padding, name=name)(last_tensor)

            if conv_dropout > 0:
                # conv dropout along feature space only
                # NOTE(review): named with `level` while the convs use `nb_levels + level`;
                # kept as-is to preserve layer names for saved weights
                name = '%s_dropout_uparm_%d_%d' % (prefix, level, conv)
                noise_shape = [None, *[1] * ndims, nb_lvl_feats]
                last_tensor = KL.Dropout(conv_dropout, noise_shape=noise_shape, name=name)(last_tensor)

        # residual block
        if use_residuals:
            # the "add" layer is the original input
            # However, it may not have the right number of features to be added
            add_layer = up_tensor
            nb_feats_in = add_layer.get_shape()[-1]
            nb_feats_out = last_tensor.get_shape()[-1]
            if nb_feats_in > 1 and nb_feats_out > 1 and (nb_feats_in != nb_feats_out):
                name = '%s_expand_up_merge_%d' % (prefix, level)
                add_layer = convL(nb_lvl_feats, conv_size, **conv_kwargs, name=name)(add_layer)

                if conv_dropout > 0:
                    noise_shape = [None, *[1] * ndims, nb_lvl_feats]
                    last_tensor = KL.Dropout(conv_dropout, noise_shape=noise_shape)(last_tensor)

            name = '%s_res_up_merge_%d' % (prefix, level)
            last_tensor = KL.add([last_tensor, add_layer], name=name)

            name = '%s_res_up_merge_act_%d' % (prefix, level)
            last_tensor = KL.Activation(activation, name=name)(last_tensor)

        if batch_norm is not None:
            name = '%s_bn_up_%d' % (prefix, level)
            last_tensor = KL.BatchNormalization(axis=batch_norm, name=name)(last_tensor)

    # Compute likelihood prediction (no activation yet)
    name = '%s_likelihood' % prefix
    last_tensor = convL(nb_labels, 1, activation=None, name=name)(last_tensor)
    like_tensor = last_tensor

    # output prediction layer
    # we use a softmax to compute P(L_x|I) where x is each location
    if final_pred_activation == 'softmax':
        name = '%s_prediction' % prefix
        softmax_lambda_fcn = lambda x: keras.activations.softmax(x, axis=ndims + 1)
        pred_tensor = KL.Lambda(softmax_lambda_fcn, name=name)(last_tensor)

    # otherwise create a layer that does nothing.
    else:
        name = '%s_prediction' % prefix
        pred_tensor = KL.Activation('linear', name=name)(like_tensor)

    # create the model and return
    model = Model(inputs=input_tensor, outputs=pred_tensor, name=model_name)
    return model
def add_prior(input_model,
              prior_shape,
              name='prior_model',
              prefix=None,
              use_logp=True,
              final_pred_activation='softmax'):
    """
    Append post-prior layer to a given model

    Adds a second input holding a spatial prior and merges it with the
    model's output: by addition when working in log space (use_logp), or by
    sigmoid-then-multiply when working in probability space. A softmax or
    linear activation then produces the final prediction.
    """

    # naming
    model_name = name
    prefix = model_name if prefix is None else prefix

    # extra input tensor carrying the prior
    prior_tensor_input = KL.Input(shape=prior_shape, name='%s-input' % prefix)
    prior_tensor = prior_tensor_input
    like_tensor = input_model.output

    # operation varies depending on whether we log() prior or not.
    if use_logp:
        print("Breaking change: use_logp option now requires log input!", file=sys.stderr)
        merge_op = KL.add
    else:
        # using sigmoid to get the likelihood values between 0 and 1
        # note: they won't add up to 1.
        sig_name = '%s_likelihood_sigmoid' % prefix
        like_tensor = KL.Activation('sigmoid', name=sig_name)(like_tensor)
        merge_op = KL.multiply

    # merge the likelihood and prior layers into posterior layer
    post_tensor = merge_op([prior_tensor, like_tensor], name='%s_posterior' % prefix)

    # output prediction layer
    # we use a softmax to compute P(L_x|I) where x is each location
    pred_name = '%s_prediction' % prefix
    if final_pred_activation == 'softmax':
        assert use_logp, 'cannot do softmax when adding prior via P()'
        print("using final_pred_activation %s for %s" % (final_pred_activation, model_name))
        pred_tensor = KL.Lambda(lambda x: keras.activations.softmax(x, axis=-1), name=pred_name)(post_tensor)
    else:
        pred_tensor = KL.Activation('linear', name=pred_name)(post_tensor)

    # create the model (prior input appended after the original inputs)
    return Model(inputs=[*input_model.inputs, prior_tensor_input],
                 outputs=[pred_tensor],
                 name=model_name)
def single_ae(enc_size,
              input_shape,
              name='single_ae',
              prefix=None,
              ae_type='dense',  # 'dense', or 'conv'
              conv_size=None,
              input_model=None,
              enc_lambda_layers=None,
              batch_norm=True,
              padding='same',
              activation=None,
              include_mu_shift_layer=False,
              do_vae=False):
    """single-layer Autoencoder (i.e. input - encoding - output)

    'dense' mode flattens the input, applies a Dense bottleneck of size
    enc_size[0], then a Dense decode and reshape back. 'conv' mode keeps the
    spatial layout, using a convolution (and a resize when the spatial part
    of enc_size differs from input_shape) to reach the bottleneck. When
    do_vae is set, a second (sigma/log-variance) branch is built from the
    same pre-encoding tensor and the output is sampled via the
    reparameterization trick (_VAESample).
    """

    # naming
    model_name = name
    if prefix is None:
        prefix = model_name

    if enc_lambda_layers is None:
        enc_lambda_layers = []

    # prepare input
    input_name = '%s_input' % prefix
    if input_model is None:
        assert input_shape is not None, 'input_shape of input_model is necessary'
        input_tensor = KL.Input(shape=input_shape, name=input_name)
        last_tensor = input_tensor
    else:
        input_tensor = input_model.input
        last_tensor = input_model.output
        input_shape = last_tensor.shape.as_list()[1:]
    input_nb_feats = last_tensor.shape.as_list()[-1]

    # prepare conv type based on input
    ndims = len(input_shape) - 1
    if ae_type == 'conv':
        convL = getattr(KL, 'Conv%dD' % ndims)
        assert conv_size is not None, 'with conv ae, need conv_size'
        conv_kwargs = {'padding': padding, 'activation': activation}
        enc_size_str = None
    # if want to go through a dense layer in the middle of the U, need to:
    # - flatten last layer if not flat
    # - do dense encoding and decoding
    # - unflatten (reshape spatially) at end
    else:  # ae_type == 'dense'
        if len(input_shape) > 1:
            name = '%s_ae_%s_down_flat' % (prefix, ae_type)
            last_tensor = KL.Flatten(name=name)(last_tensor)
        convL = conv_kwargs = None
        assert len(enc_size) == 1, "enc_size should be of length 1 for dense layer"
        enc_size_str = ''.join(['%d_' % d for d in enc_size])[:-1]

    # recall this layer (both mu and sigma branches encode from here)
    pre_enc_layer = last_tensor

    # encoding layer
    if ae_type == 'dense':
        name = '%s_ae_mu_enc_dense_%s' % (prefix, enc_size_str)
        last_tensor = KL.Dense(enc_size[0], name=name)(pre_enc_layer)

    else:  # convolution
        # convolve then resize. enc_size should be [nb_dim1, nb_dim2, ..., nb_feats]
        assert len(enc_size) == len(input_shape), \
            "encoding size does not match input shape %d %d" % (len(enc_size), len(input_shape))

        if list(enc_size)[:-1] != list(input_shape)[:-1] and \
                all([f is not None for f in input_shape[:-1]]) and \
                all([f is not None for f in enc_size[:-1]]):
            # spatial bottleneck differs from input: conv then zoom to target size
            name = '%s_ae_mu_enc_conv' % prefix
            last_tensor = convL(enc_size[-1], conv_size, name=name, **conv_kwargs)(pre_enc_layer)

            name = '%s_ae_mu_enc' % prefix
            zf = [enc_size[:-1][f] / last_tensor.shape.as_list()[1:-1][f] for f in range(len(enc_size) - 1)]
            last_tensor = layers.Resize(zoom_factor=zf, name=name)(last_tensor)

        elif enc_size[-1] is None:  # convolutional, but won't tell us bottleneck
            name = '%s_ae_mu_enc' % prefix
            last_tensor = KL.Lambda(lambda x: x, name=name)(pre_enc_layer)

        else:
            name = '%s_ae_mu_enc' % prefix
            last_tensor = convL(enc_size[-1], conv_size, name=name, **conv_kwargs)(pre_enc_layer)

    if include_mu_shift_layer:
        # shift
        name = '%s_ae_mu_shift' % prefix
        last_tensor = layers.LocalBias(name=name)(last_tensor)

    # encoding clean-up layers
    for layer_fcn in enc_lambda_layers:
        lambda_name = layer_fcn.__name__
        name = '%s_ae_mu_%s' % (prefix, lambda_name)
        last_tensor = KL.Lambda(layer_fcn, name=name)(last_tensor)

    if batch_norm is not None:
        name = '%s_ae_mu_bn' % prefix
        last_tensor = KL.BatchNormalization(axis=batch_norm, name=name)(last_tensor)

    # have a simple layer that does nothing to have a clear name before sampling
    name = '%s_ae_mu' % prefix
    last_tensor = KL.Lambda(lambda x: x, name=name)(last_tensor)

    # if doing variational AE, will need the sigma layer as well.
    if do_vae:
        mu_tensor = last_tensor

        # encoding layer
        if ae_type == 'dense':
            name = '%s_ae_sigma_enc_dense_%s' % (prefix, enc_size_str)
            last_tensor = KL.Dense(enc_size[0], name=name)(pre_enc_layer)

        else:
            if list(enc_size)[:-1] != list(input_shape)[:-1] and \
                    all([f is not None for f in input_shape[:-1]]) and \
                    all([f is not None for f in enc_size[:-1]]):
                assert len(enc_size) - 1 == 2, "Sorry, I have not yet implemented non-2D resizing..."
                name = '%s_ae_sigma_enc_conv' % prefix
                last_tensor = convL(enc_size[-1], conv_size, name=name, **conv_kwargs)(pre_enc_layer)

                name = '%s_ae_sigma_enc' % prefix
                # NOTE(review): tf.image.resize_bilinear is a deprecated TF1 API — confirm TF version
                resize_fn = lambda x: tf.image.resize_bilinear(x, enc_size[:-1])
                last_tensor = KL.Lambda(resize_fn, name=name)(last_tensor)

            elif enc_size[-1] is None:  # convolutional, but won't tell us bottleneck
                name = '%s_ae_sigma_enc' % prefix
                last_tensor = convL(pre_enc_layer.shape.as_list()[-1], conv_size, name=name, **conv_kwargs)(
                    pre_enc_layer)
                # cannot use lambda, then mu and sigma will be same layer.
                # last_tensor = KL.Lambda(lambda x: x, name=name)(pre_enc_layer)

            else:
                name = '%s_ae_sigma_enc' % prefix
                last_tensor = convL(enc_size[-1], conv_size, name=name, **conv_kwargs)(pre_enc_layer)

        # encoding clean-up layers
        for layer_fcn in enc_lambda_layers:
            lambda_name = layer_fcn.__name__
            name = '%s_ae_sigma_%s' % (prefix, lambda_name)
            last_tensor = KL.Lambda(layer_fcn, name=name)(last_tensor)

        if batch_norm is not None:
            name = '%s_ae_sigma_bn' % prefix
            last_tensor = KL.BatchNormalization(axis=batch_norm, name=name)(last_tensor)

        # have a simple layer that does nothing to have a clear name before sampling
        name = '%s_ae_sigma' % prefix
        last_tensor = KL.Lambda(lambda x: x, name=name)(last_tensor)

        logvar_tensor = last_tensor

        # VAE sampling: z = mu + exp(logvar/2) * eps
        sampler = _VAESample().sample_z
        name = '%s_ae_sample' % prefix
        last_tensor = KL.Lambda(sampler, name=name)([mu_tensor, logvar_tensor])

        if include_mu_shift_layer:
            # shift
            name = '%s_ae_sample_shift' % prefix
            last_tensor = layers.LocalBias(name=name)(last_tensor)

    # decoding layer
    if ae_type == 'dense':
        name = '%s_ae_%s_dec_flat_%s' % (prefix, ae_type, enc_size_str)
        last_tensor = KL.Dense(np.prod(input_shape), name=name)(last_tensor)

        # unflatten if dense method
        if len(input_shape) > 1:
            name = '%s_ae_%s_dec' % (prefix, ae_type)
            last_tensor = KL.Reshape(input_shape, name=name)(last_tensor)

    else:
        if list(enc_size)[:-1] != list(input_shape)[:-1] and \
                all([f is not None for f in input_shape[:-1]]) and \
                all([f is not None for f in enc_size[:-1]]):
            # zoom back to the input spatial size before the decoding conv
            name = '%s_ae_mu_dec' % prefix
            zf = [last_tensor.shape.as_list()[1:-1][f] / enc_size[:-1][f] for f in range(len(enc_size) - 1)]
            last_tensor = layers.Resize(zoom_factor=zf, name=name)(last_tensor)

        name = '%s_ae_%s_dec' % (prefix, ae_type)
        last_tensor = convL(input_nb_feats, conv_size, name=name, **conv_kwargs)(last_tensor)

    if batch_norm is not None:
        name = '%s_bn_ae_%s_dec' % (prefix, ae_type)
        last_tensor = KL.BatchNormalization(axis=batch_norm, name=name)(last_tensor)

    # create the model and return
    model = Model(inputs=input_tensor, outputs=[last_tensor], name=model_name)
    return model
###############################################################################
# Helper function
###############################################################################
class _VAESample:
    """Stateless helper implementing the VAE reparameterization trick."""

    def sample_z(self, args):
        """Given [mu, log_var], return mu + exp(log_var / 2) * eps, eps ~ N(0, I)."""
        mu, log_var = args
        eps = K.random_normal(shape=K.shape(mu), mean=0., stddev=1.)
        return mu + K.exp(log_var / 2) * eps