Diff of /ext/neuron/models.py [000000] .. [e571d1]

Switch to unified view

a b/ext/neuron/models.py
1
"""
2
tensorflow/keras utilities for the neuron project
3
4
If you use this code, please cite 
5
Dalca AV, Guttag J, Sabuncu MR
6
Anatomical Priors in Convolutional Networks for Unsupervised Biomedical Segmentation, 
7
CVPR 2018
8
9
Contact: adalca [at] csail [dot] mit [dot] edu
10
License: GPLv3
11
"""
12
13
import sys
14
15
from ext.neuron import layers
16
17
# third party
18
import numpy as np
19
import tensorflow as tf
20
import keras
21
import keras.layers as KL
22
from keras.models import Model
23
import keras.backend as K
24
25
26
def unet(nb_features,
27
         input_shape,
28
         nb_levels,
29
         conv_size,
30
         nb_labels,
31
         name='unet',
32
         prefix=None,
33
         feat_mult=1,
34
         pool_size=2,
35
         use_logp=True,
36
         padding='same',
37
         dilation_rate_mult=1,
38
         activation='elu',
39
         skip_n_concatenations=0,
40
         use_residuals=False,
41
         final_pred_activation='softmax',
42
         nb_conv_per_level=1,
43
         add_prior_layer=False,
44
         layer_nb_feats=None,
45
         conv_dropout=0,
46
         batch_norm=None,
47
         input_model=None):
48
    """
49
    unet-style keras model with an overdose of parametrization.
50
51
    Parameters:
52
        nb_features: the number of features at each convolutional level
53
            see below for `feat_mult` and `layer_nb_feats` for modifiers to this number
54
        input_shape: input layer shape, vector of size ndims + 1 (nb_channels)
55
        conv_size: the convolution kernel size
56
        nb_levels: the number of Unet levels (number of downsamples) in the "encoder" 
57
            (e.g. 4 would give you 4 levels in encoder, 4 in decoder)
58
        nb_labels: number of output channels
59
        name (default: 'unet'): the name of the network
60
        prefix (default: `name` value): prefix to be added to layer names
61
        feat_mult (default: 1) multiple for `nb_features` as we go down the encoder levels.
62
            e.g. feat_mult of 2 and nb_features of 16 would yield 32 features in the 
63
            second layer, 64 features in the third layer, etc.
64
        pool_size (default: 2): max pooling size (integer or list if specifying per dimension)
65
        skip_n_concatenations=0: enabled to skip concatenation links between contracting and expanding paths for the n
66
            top levels.
67
        use_logp:
68
        padding:
69
        dilation_rate_mult:
70
        activation:
71
        use_residuals:
72
        final_pred_activation:
73
        nb_conv_per_level:
74
        add_prior_layer:
75
        skip_n_concatenations:
76
        layer_nb_feats: list of the number of features for each layer. Automatically used if specified
77
        conv_dropout: dropout probability
78
        batch_norm:
79
        input_model: concatenate the provided input_model to this current model.
80
            Only the first output of input_model is used.
81
    """
82
83
    # naming
84
    model_name = name
85
    if prefix is None:
86
        prefix = model_name
87
88
    # volume size data
89
    ndims = len(input_shape) - 1
90
    if isinstance(pool_size, int):
91
        pool_size = (pool_size,) * ndims
92
93
    # get encoding model
94
    enc_model = conv_enc(nb_features,
95
                         input_shape,
96
                         nb_levels,
97
                         conv_size,
98
                         name=model_name,
99
                         prefix=prefix,
100
                         feat_mult=feat_mult,
101
                         pool_size=pool_size,
102
                         padding=padding,
103
                         dilation_rate_mult=dilation_rate_mult,
104
                         activation=activation,
105
                         use_residuals=use_residuals,
106
                         nb_conv_per_level=nb_conv_per_level,
107
                         layer_nb_feats=layer_nb_feats,
108
                         conv_dropout=conv_dropout,
109
                         batch_norm=batch_norm,
110
                         input_model=input_model)
111
112
    # get decoder
113
    # use_skip_connections=True makes it a u-net
114
    lnf = layer_nb_feats[(nb_levels * nb_conv_per_level):] if layer_nb_feats is not None else None
115
    dec_model = conv_dec(nb_features,
116
                         [],
117
                         nb_levels,
118
                         conv_size,
119
                         nb_labels,
120
                         name=model_name,
121
                         prefix=prefix,
122
                         feat_mult=feat_mult,
123
                         pool_size=pool_size,
124
                         use_skip_connections=True,
125
                         skip_n_concatenations=skip_n_concatenations,
126
                         padding=padding,
127
                         dilation_rate_mult=dilation_rate_mult,
128
                         activation=activation,
129
                         use_residuals=use_residuals,
130
                         final_pred_activation='linear' if add_prior_layer else final_pred_activation,
131
                         nb_conv_per_level=nb_conv_per_level,
132
                         batch_norm=batch_norm,
133
                         layer_nb_feats=lnf,
134
                         conv_dropout=conv_dropout,
135
                         input_model=enc_model)
136
    final_model = dec_model
137
138
    if add_prior_layer:
139
        final_model = add_prior(dec_model,
140
                                [*input_shape[:-1], nb_labels],
141
                                name=model_name + '_prior',
142
                                use_logp=use_logp,
143
                                final_pred_activation=final_pred_activation)
144
145
    return final_model
146
147
148
def ae(nb_features,
149
       input_shape,
150
       nb_levels,
151
       conv_size,
152
       nb_labels,
153
       enc_size,
154
       name='ae',
155
       feat_mult=1,
156
       pool_size=2,
157
       padding='same',
158
       activation='elu',
159
       use_residuals=False,
160
       nb_conv_per_level=1,
161
       batch_norm=None,
162
       enc_batch_norm=None,
163
       ae_type='conv',  # 'dense', or 'conv'
164
       enc_lambda_layers=None,
165
       add_prior_layer=False,
166
       use_logp=True,
167
       conv_dropout=0,
168
       include_mu_shift_layer=False,
169
       single_model=False,  # whether to return a single model, or a tuple of models that can be stacked.
170
       final_pred_activation='softmax',
171
       do_vae=False,
172
       input_model=None):
173
    """Convolutional Auto-Encoder. Optionally Variational (if do_vae is set to True)."""
174
175
    # naming
176
    model_name = name
177
178
    # volume size data
179
    ndims = len(input_shape) - 1
180
    if isinstance(pool_size, int):
181
        pool_size = (pool_size,) * ndims
182
183
    # get encoding model
184
    enc_model = conv_enc(nb_features,
185
                         input_shape,
186
                         nb_levels,
187
                         conv_size,
188
                         name=model_name,
189
                         feat_mult=feat_mult,
190
                         pool_size=pool_size,
191
                         padding=padding,
192
                         activation=activation,
193
                         use_residuals=use_residuals,
194
                         nb_conv_per_level=nb_conv_per_level,
195
                         conv_dropout=conv_dropout,
196
                         batch_norm=batch_norm,
197
                         input_model=input_model)
198
199
    # middle AE structure
200
    if single_model:
201
        in_input_shape = None
202
        in_model = enc_model
203
    else:
204
        in_input_shape = enc_model.output.shape.as_list()[1:]
205
        in_model = None
206
    mid_ae_model = single_ae(enc_size,
207
                             in_input_shape,
208
                             conv_size=conv_size,
209
                             name=model_name,
210
                             ae_type=ae_type,
211
                             input_model=in_model,
212
                             batch_norm=enc_batch_norm,
213
                             enc_lambda_layers=enc_lambda_layers,
214
                             include_mu_shift_layer=include_mu_shift_layer,
215
                             do_vae=do_vae)
216
217
    # decoder
218
    if single_model:
219
        in_input_shape = None
220
        in_model = mid_ae_model
221
    else:
222
        in_input_shape = mid_ae_model.output.shape.as_list()[1:]
223
        in_model = None
224
    dec_model = conv_dec(nb_features,
225
                         in_input_shape,
226
                         nb_levels,
227
                         conv_size,
228
                         nb_labels,
229
                         name=model_name,
230
                         feat_mult=feat_mult,
231
                         pool_size=pool_size,
232
                         use_skip_connections=False,
233
                         padding=padding,
234
                         activation=activation,
235
                         use_residuals=use_residuals,
236
                         final_pred_activation='linear',
237
                         nb_conv_per_level=nb_conv_per_level,
238
                         batch_norm=batch_norm,
239
                         conv_dropout=conv_dropout,
240
                         input_model=in_model)
241
242
    if add_prior_layer:
243
        dec_model = add_prior(dec_model,
244
                              [*input_shape[:-1], nb_labels],
245
                              name=model_name,
246
                              prefix=model_name + '_prior',
247
                              use_logp=use_logp,
248
                              final_pred_activation=final_pred_activation)
249
250
    if single_model:
251
        return dec_model
252
    else:
253
        return dec_model, mid_ae_model, enc_model
254
255
256
def conv_enc(nb_features,
257
             input_shape,
258
             nb_levels,
259
             conv_size,
260
             name=None,
261
             prefix=None,
262
             feat_mult=1,
263
             pool_size=2,
264
             dilation_rate_mult=1,
265
             padding='same',
266
             activation='elu',
267
             layer_nb_feats=None,
268
             use_residuals=False,
269
             nb_conv_per_level=2,
270
             conv_dropout=0,
271
             batch_norm=None,
272
             input_model=None):
273
    """Fully Convolutional Encoder"""
274
275
    # naming
276
    model_name = name
277
    if prefix is None:
278
        prefix = model_name
279
280
    # first layer: input
281
    name = '%s_input' % prefix
282
    if input_model is None:
283
        input_tensor = KL.Input(shape=input_shape, name=name)
284
        last_tensor = input_tensor
285
    else:
286
        input_tensor = input_model.inputs
287
        last_tensor = input_model.outputs
288
        if isinstance(last_tensor, list):
289
            last_tensor = last_tensor[0]
290
291
    # volume size data
292
    ndims = len(input_shape) - 1
293
    if isinstance(pool_size, int):
294
        pool_size = (pool_size,) * ndims
295
296
    # prepare layers
297
    convL = getattr(KL, 'Conv%dD' % ndims)
298
    conv_kwargs = {'padding': padding, 'activation': activation, 'data_format': 'channels_last'}
299
    maxpool = getattr(KL, 'MaxPooling%dD' % ndims)
300
301
    # down arm:
302
    # add nb_levels of conv + ReLu + conv + ReLu. Pool after each of first nb_levels - 1 layers
303
    lfidx = 0  # level feature index
304
    for level in range(nb_levels):
305
        lvl_first_tensor = last_tensor
306
        nb_lvl_feats = np.round(nb_features * feat_mult ** level).astype(int)
307
        conv_kwargs['dilation_rate'] = dilation_rate_mult ** level
308
309
        for conv in range(nb_conv_per_level):  # does several conv per level, max pooling applied at the end
310
            if layer_nb_feats is not None:  # None or List of all the feature numbers
311
                nb_lvl_feats = layer_nb_feats[lfidx]
312
                lfidx += 1
313
314
            name = '%s_conv_downarm_%d_%d' % (prefix, level, conv)
315
            if conv < (nb_conv_per_level - 1) or (not use_residuals):
316
                last_tensor = convL(nb_lvl_feats, conv_size, **conv_kwargs, name=name)(last_tensor)
317
            else:  # no activation
318
                last_tensor = convL(nb_lvl_feats, conv_size, padding=padding, name=name)(last_tensor)
319
320
            if conv_dropout > 0:
321
                # conv dropout along feature space only
322
                name = '%s_dropout_downarm_%d_%d' % (prefix, level, conv)
323
                noise_shape = [None, *[1] * ndims, nb_lvl_feats]
324
                last_tensor = KL.Dropout(conv_dropout, noise_shape=noise_shape, name=name)(last_tensor)
325
326
        if use_residuals:
327
            convarm_layer = last_tensor
328
329
            # the "add" layer is the original input
330
            # However, it may not have the right number of features to be added
331
            nb_feats_in = lvl_first_tensor.get_shape()[-1]
332
            nb_feats_out = convarm_layer.get_shape()[-1]
333
            add_layer = lvl_first_tensor
334
            if nb_feats_in > 1 and nb_feats_out > 1 and (nb_feats_in != nb_feats_out):
335
                name = '%s_expand_down_merge_%d' % (prefix, level)
336
                last_tensor = convL(nb_lvl_feats, conv_size, **conv_kwargs, name=name)(lvl_first_tensor)
337
                add_layer = last_tensor
338
339
                if conv_dropout > 0:
340
                    noise_shape = [None, *[1] * ndims, nb_lvl_feats]
341
                    convarm_layer = KL.Dropout(conv_dropout, noise_shape=noise_shape)(last_tensor)
342
343
            name = '%s_res_down_merge_%d' % (prefix, level)
344
            last_tensor = KL.add([add_layer, convarm_layer], name=name)
345
346
            name = '%s_res_down_merge_act_%d' % (prefix, level)
347
            last_tensor = KL.Activation(activation, name=name)(last_tensor)
348
349
        if batch_norm is not None:
350
            name = '%s_bn_down_%d' % (prefix, level)
351
            last_tensor = KL.BatchNormalization(axis=batch_norm, name=name)(last_tensor)
352
353
        # max pool if we're not at the last level
354
        if level < (nb_levels - 1):
355
            name = '%s_maxpool_%d' % (prefix, level)
356
            last_tensor = maxpool(pool_size=pool_size, name=name, padding=padding)(last_tensor)
357
358
    # create the model and return
359
    model = Model(inputs=input_tensor, outputs=[last_tensor], name=model_name)
360
    return model
361
362
363
def conv_dec(nb_features,
364
             input_shape,
365
             nb_levels,
366
             conv_size,
367
             nb_labels,
368
             name=None,
369
             prefix=None,
370
             feat_mult=1,
371
             pool_size=2,
372
             use_skip_connections=False,
373
             skip_n_concatenations=0,
374
             padding='same',
375
             dilation_rate_mult=1,
376
             activation='elu',
377
             use_residuals=False,
378
             final_pred_activation='softmax',
379
             nb_conv_per_level=2,
380
             layer_nb_feats=None,
381
             batch_norm=None,
382
             conv_dropout=0,
383
             input_model=None):
384
    """Fully Convolutional Decoder"""
385
386
    # naming
387
    model_name = name
388
    if prefix is None:
389
        prefix = model_name
390
391
    # if using skip connections, make sure need to use them.
392
    if use_skip_connections:
393
        assert input_model is not None, "is using skip connections, tensors dictionary is required"
394
395
    # first layer: input
396
    input_name = '%s_input' % prefix
397
    if input_model is None:
398
        input_tensor = KL.Input(shape=input_shape, name=input_name)
399
        last_tensor = input_tensor
400
    else:
401
        input_tensor = input_model.input
402
        last_tensor = input_model.output
403
        input_shape = last_tensor.shape.as_list()[1:]
404
405
    # vol size info
406
    ndims = len(input_shape) - 1
407
    if isinstance(pool_size, int):
408
        if ndims > 1:
409
            pool_size = (pool_size,) * ndims
410
411
    # prepare layers
412
    convL = getattr(KL, 'Conv%dD' % ndims)
413
    conv_kwargs = {'padding': padding, 'activation': activation}
414
    upsample = getattr(KL, 'UpSampling%dD' % ndims)
415
416
    # up arm:
417
    # nb_levels - 1 layers of Deconvolution3D
418
    #    (approx via up + conv + ReLu) + merge + conv + ReLu + conv + ReLu
419
    lfidx = 0
420
    for level in range(nb_levels - 1):
421
        nb_lvl_feats = np.round(nb_features * feat_mult ** (nb_levels - 2 - level)).astype(int)
422
        conv_kwargs['dilation_rate'] = dilation_rate_mult ** (nb_levels - 2 - level)
423
424
        # upsample matching the max pooling layers size
425
        name = '%s_up_%d' % (prefix, nb_levels + level)
426
        last_tensor = upsample(size=pool_size, name=name)(last_tensor)
427
        up_tensor = last_tensor
428
429
        # merge layers combining previous layer
430
        if use_skip_connections & (level < (nb_levels - skip_n_concatenations - 1)):
431
            conv_name = '%s_conv_downarm_%d_%d' % (prefix, nb_levels - 2 - level, nb_conv_per_level - 1)
432
            cat_tensor = input_model.get_layer(conv_name).output
433
            name = '%s_merge_%d' % (prefix, nb_levels + level)
434
            last_tensor = KL.concatenate([cat_tensor, last_tensor], axis=ndims + 1, name=name)
435
436
        # convolution layers
437
        for conv in range(nb_conv_per_level):
438
            if layer_nb_feats is not None:
439
                nb_lvl_feats = layer_nb_feats[lfidx]
440
                lfidx += 1
441
442
            name = '%s_conv_uparm_%d_%d' % (prefix, nb_levels + level, conv)
443
            if conv < (nb_conv_per_level - 1) or (not use_residuals):
444
                last_tensor = convL(nb_lvl_feats, conv_size, **conv_kwargs, name=name)(last_tensor)
445
            else:
446
                last_tensor = convL(nb_lvl_feats, conv_size, padding=padding, name=name)(last_tensor)
447
448
            if conv_dropout > 0:
449
                name = '%s_dropout_uparm_%d_%d' % (prefix, level, conv)
450
                noise_shape = [None, *[1] * ndims, nb_lvl_feats]
451
                last_tensor = KL.Dropout(conv_dropout, noise_shape=noise_shape, name=name)(last_tensor)
452
453
        # residual block
454
        if use_residuals:
455
456
            # the "add" layer is the original input
457
            # However, it may not have the right number of features to be added
458
            add_layer = up_tensor
459
            nb_feats_in = add_layer.get_shape()[-1]
460
            nb_feats_out = last_tensor.get_shape()[-1]
461
            if nb_feats_in > 1 and nb_feats_out > 1 and (nb_feats_in != nb_feats_out):
462
                name = '%s_expand_up_merge_%d' % (prefix, level)
463
                add_layer = convL(nb_lvl_feats, conv_size, **conv_kwargs, name=name)(add_layer)
464
465
                if conv_dropout > 0:
466
                    noise_shape = [None, *[1] * ndims, nb_lvl_feats]
467
                    last_tensor = KL.Dropout(conv_dropout, noise_shape=noise_shape)(last_tensor)
468
469
            name = '%s_res_up_merge_%d' % (prefix, level)
470
            last_tensor = KL.add([last_tensor, add_layer], name=name)
471
472
            name = '%s_res_up_merge_act_%d' % (prefix, level)
473
            last_tensor = KL.Activation(activation, name=name)(last_tensor)
474
475
        if batch_norm is not None:
476
            name = '%s_bn_up_%d' % (prefix, level)
477
            last_tensor = KL.BatchNormalization(axis=batch_norm, name=name)(last_tensor)
478
479
    # Compute likelihood prediction (no activation yet)
480
    name = '%s_likelihood' % prefix
481
    last_tensor = convL(nb_labels, 1, activation=None, name=name)(last_tensor)
482
    like_tensor = last_tensor
483
484
    # output prediction layer
485
    # we use a softmax to compute P(L_x|I) where x is each location
486
    if final_pred_activation == 'softmax':
487
        name = '%s_prediction' % prefix
488
        softmax_lambda_fcn = lambda x: keras.activations.softmax(x, axis=ndims + 1)
489
        pred_tensor = KL.Lambda(softmax_lambda_fcn, name=name)(last_tensor)
490
491
    # otherwise create a layer that does nothing.
492
    else:
493
        name = '%s_prediction' % prefix
494
        pred_tensor = KL.Activation('linear', name=name)(like_tensor)
495
496
    # create the model and return
497
    model = Model(inputs=input_tensor, outputs=pred_tensor, name=model_name)
498
    return model
499
500
501
def add_prior(input_model,
502
              prior_shape,
503
              name='prior_model',
504
              prefix=None,
505
              use_logp=True,
506
              final_pred_activation='softmax'):
507
    """
508
    Append post-prior layer to a given model
509
    """
510
511
    # naming
512
    model_name = name
513
    if prefix is None:
514
        prefix = model_name
515
516
    # prior input layer
517
    prior_input_name = '%s-input' % prefix
518
    prior_tensor = KL.Input(shape=prior_shape, name=prior_input_name)
519
    prior_tensor_input = prior_tensor
520
    like_tensor = input_model.output
521
522
    # operation varies depending on whether we log() prior or not.
523
    if use_logp:
524
        print("Breaking change: use_logp option now requires log input!", file=sys.stderr)
525
        merge_op = KL.add
526
527
    else:
528
        # using sigmoid to get the likelihood values between 0 and 1
529
        # note: they won't add up to 1.
530
        name = '%s_likelihood_sigmoid' % prefix
531
        like_tensor = KL.Activation('sigmoid', name=name)(like_tensor)
532
        merge_op = KL.multiply
533
534
    # merge the likelihood and prior layers into posterior layer
535
    name = '%s_posterior' % prefix
536
    post_tensor = merge_op([prior_tensor, like_tensor], name=name)
537
538
    # output prediction layer
539
    # we use a softmax to compute P(L_x|I) where x is each location
540
    pred_name = '%s_prediction' % prefix
541
    if final_pred_activation == 'softmax':
542
        assert use_logp, 'cannot do softmax when adding prior via P()'
543
        print("using final_pred_activation %s for %s" % (final_pred_activation, model_name))
544
        softmax_lambda_fcn = lambda x: keras.activations.softmax(x, axis=-1)
545
        pred_tensor = KL.Lambda(softmax_lambda_fcn, name=pred_name)(post_tensor)
546
547
    else:
548
        pred_tensor = KL.Activation('linear', name=pred_name)(post_tensor)
549
550
    # create the model
551
    model_inputs = [*input_model.inputs, prior_tensor_input]
552
    model = Model(inputs=model_inputs, outputs=[pred_tensor], name=model_name)
553
554
    # compile
555
    return model
556
557
558
def single_ae(enc_size,
559
              input_shape,
560
              name='single_ae',
561
              prefix=None,
562
              ae_type='dense',  # 'dense', or 'conv'
563
              conv_size=None,
564
              input_model=None,
565
              enc_lambda_layers=None,
566
              batch_norm=True,
567
              padding='same',
568
              activation=None,
569
              include_mu_shift_layer=False,
570
              do_vae=False):
571
    """single-layer Autoencoder (i.e. input - encoding - output"""
572
573
    # naming
574
    model_name = name
575
    if prefix is None:
576
        prefix = model_name
577
578
    if enc_lambda_layers is None:
579
        enc_lambda_layers = []
580
581
    # prepare input
582
    input_name = '%s_input' % prefix
583
    if input_model is None:
584
        assert input_shape is not None, 'input_shape of input_model is necessary'
585
        input_tensor = KL.Input(shape=input_shape, name=input_name)
586
        last_tensor = input_tensor
587
    else:
588
        input_tensor = input_model.input
589
        last_tensor = input_model.output
590
        input_shape = last_tensor.shape.as_list()[1:]
591
    input_nb_feats = last_tensor.shape.as_list()[-1]
592
593
    # prepare conv type based on input
594
    ndims = len(input_shape) - 1
595
    if ae_type == 'conv':
596
        convL = getattr(KL, 'Conv%dD' % ndims)
597
        assert conv_size is not None, 'with conv ae, need conv_size'
598
        conv_kwargs = {'padding': padding, 'activation': activation}
599
        enc_size_str = None
600
601
    # if want to go through a dense layer in the middle of the U, need to:
602
    # - flatten last layer if not flat
603
    # - do dense encoding and decoding
604
    # - unflatten (reshape spatially) at end
605
    else:  # ae_type == 'dense'
606
        if len(input_shape) > 1:
607
            name = '%s_ae_%s_down_flat' % (prefix, ae_type)
608
            last_tensor = KL.Flatten(name=name)(last_tensor)
609
        convL = conv_kwargs = None
610
        assert len(enc_size) == 1, "enc_size should be of length 1 for dense layer"
611
        enc_size_str = ''.join(['%d_' % d for d in enc_size])[:-1]
612
613
    # recall this layer
614
    pre_enc_layer = last_tensor
615
616
    # encoding layer
617
    if ae_type == 'dense':
618
        name = '%s_ae_mu_enc_dense_%s' % (prefix, enc_size_str)
619
        last_tensor = KL.Dense(enc_size[0], name=name)(pre_enc_layer)
620
621
    else:  # convolution
622
623
        # convolve then resize. enc_size should be [nb_dim1, nb_dim2, ..., nb_feats]
624
        assert len(enc_size) == len(input_shape), \
625
            "encoding size does not match input shape %d %d" % (len(enc_size), len(input_shape))
626
627
        if list(enc_size)[:-1] != list(input_shape)[:-1] and \
628
                all([f is not None for f in input_shape[:-1]]) and \
629
                all([f is not None for f in enc_size[:-1]]):
630
631
            name = '%s_ae_mu_enc_conv' % prefix
632
            last_tensor = convL(enc_size[-1], conv_size, name=name, **conv_kwargs)(pre_enc_layer)
633
634
            name = '%s_ae_mu_enc' % prefix
635
            zf = [enc_size[:-1][f] / last_tensor.shape.as_list()[1:-1][f] for f in range(len(enc_size) - 1)]
636
            last_tensor = layers.Resize(zoom_factor=zf, name=name)(last_tensor)
637
638
        elif enc_size[-1] is None:  # convolutional, but won't tell us bottleneck
639
            name = '%s_ae_mu_enc' % prefix
640
            last_tensor = KL.Lambda(lambda x: x, name=name)(pre_enc_layer)
641
642
        else:
643
            name = '%s_ae_mu_enc' % prefix
644
            last_tensor = convL(enc_size[-1], conv_size, name=name, **conv_kwargs)(pre_enc_layer)
645
646
    if include_mu_shift_layer:
647
        # shift
648
        name = '%s_ae_mu_shift' % prefix
649
        last_tensor = layers.LocalBias(name=name)(last_tensor)
650
651
    # encoding clean-up layers
652
    for layer_fcn in enc_lambda_layers:
653
        lambda_name = layer_fcn.__name__
654
        name = '%s_ae_mu_%s' % (prefix, lambda_name)
655
        last_tensor = KL.Lambda(layer_fcn, name=name)(last_tensor)
656
657
    if batch_norm is not None:
658
        name = '%s_ae_mu_bn' % prefix
659
        last_tensor = KL.BatchNormalization(axis=batch_norm, name=name)(last_tensor)
660
661
    # have a simple layer that does nothing to have a clear name before sampling
662
    name = '%s_ae_mu' % prefix
663
    last_tensor = KL.Lambda(lambda x: x, name=name)(last_tensor)
664
665
    # if doing variational AE, will need the sigma layer as well.
666
    if do_vae:
667
        mu_tensor = last_tensor
668
669
        # encoding layer
670
        if ae_type == 'dense':
671
            name = '%s_ae_sigma_enc_dense_%s' % (prefix, enc_size_str)
672
            last_tensor = KL.Dense(enc_size[0], name=name)(pre_enc_layer)
673
674
        else:
675
            if list(enc_size)[:-1] != list(input_shape)[:-1] and \
676
                    all([f is not None for f in input_shape[:-1]]) and \
677
                    all([f is not None for f in enc_size[:-1]]):
678
679
                assert len(enc_size) - 1 == 2, "Sorry, I have not yet implemented non-2D resizing..."
680
                name = '%s_ae_sigma_enc_conv' % prefix
681
                last_tensor = convL(enc_size[-1], conv_size, name=name, **conv_kwargs)(pre_enc_layer)
682
683
                name = '%s_ae_sigma_enc' % prefix
684
                resize_fn = lambda x: tf.image.resize_bilinear(x, enc_size[:-1])
685
                last_tensor = KL.Lambda(resize_fn, name=name)(last_tensor)
686
687
            elif enc_size[-1] is None:  # convolutional, but won't tell us bottleneck
688
                name = '%s_ae_sigma_enc' % prefix
689
                last_tensor = convL(pre_enc_layer.shape.as_list()[-1], conv_size, name=name, **conv_kwargs)(
690
                    pre_enc_layer)
691
                # cannot use lambda, then mu and sigma will be same layer.
692
                # last_tensor = KL.Lambda(lambda x: x, name=name)(pre_enc_layer)
693
694
            else:
695
                name = '%s_ae_sigma_enc' % prefix
696
                last_tensor = convL(enc_size[-1], conv_size, name=name, **conv_kwargs)(pre_enc_layer)
697
698
        # encoding clean-up layers
699
        for layer_fcn in enc_lambda_layers:
700
            lambda_name = layer_fcn.__name__
701
            name = '%s_ae_sigma_%s' % (prefix, lambda_name)
702
            last_tensor = KL.Lambda(layer_fcn, name=name)(last_tensor)
703
704
        if batch_norm is not None:
705
            name = '%s_ae_sigma_bn' % prefix
706
            last_tensor = KL.BatchNormalization(axis=batch_norm, name=name)(last_tensor)
707
708
        # have a simple layer that does nothing to have a clear name before sampling
709
        name = '%s_ae_sigma' % prefix
710
        last_tensor = KL.Lambda(lambda x: x, name=name)(last_tensor)
711
712
        logvar_tensor = last_tensor
713
714
        # VAE sampling
715
        sampler = _VAESample().sample_z
716
717
        name = '%s_ae_sample' % prefix
718
        last_tensor = KL.Lambda(sampler, name=name)([mu_tensor, logvar_tensor])
719
720
    if include_mu_shift_layer:
721
        # shift
722
        name = '%s_ae_sample_shift' % prefix
723
        last_tensor = layers.LocalBias(name=name)(last_tensor)
724
725
    # decoding layer
726
    if ae_type == 'dense':
727
        name = '%s_ae_%s_dec_flat_%s' % (prefix, ae_type, enc_size_str)
728
        last_tensor = KL.Dense(np.prod(input_shape), name=name)(last_tensor)
729
730
        # unflatten if dense method
731
        if len(input_shape) > 1:
732
            name = '%s_ae_%s_dec' % (prefix, ae_type)
733
            last_tensor = KL.Reshape(input_shape, name=name)(last_tensor)
734
735
    else:
736
737
        if list(enc_size)[:-1] != list(input_shape)[:-1] and \
738
                all([f is not None for f in input_shape[:-1]]) and \
739
                all([f is not None for f in enc_size[:-1]]):
740
            name = '%s_ae_mu_dec' % prefix
741
            zf = [last_tensor.shape.as_list()[1:-1][f] / enc_size[:-1][f] for f in range(len(enc_size) - 1)]
742
            last_tensor = layers.Resize(zoom_factor=zf, name=name)(last_tensor)
743
744
        name = '%s_ae_%s_dec' % (prefix, ae_type)
745
        last_tensor = convL(input_nb_feats, conv_size, name=name, **conv_kwargs)(last_tensor)
746
747
    if batch_norm is not None:
748
        name = '%s_bn_ae_%s_dec' % (prefix, ae_type)
749
        last_tensor = KL.BatchNormalization(axis=batch_norm, name=name)(last_tensor)
750
751
    # create the model and return
752
    model = Model(inputs=input_tensor, outputs=[last_tensor], name=model_name)
753
    return model
754
755
756
###############################################################################
757
# Helper function
758
###############################################################################
759
760
class _VAESample:
761
    def __init__(self):
762
        pass
763
764
    def sample_z(self, args):
765
        mu, log_var = args
766
        shape = K.shape(mu)
767
        eps = K.random_normal(shape=shape, mean=0., stddev=1.)
768
        return mu + K.exp(log_var / 2) * eps