Diff of /ext/neuron/models.py [000000] .. [e571d1]

Switch to side-by-side view

--- a
+++ b/ext/neuron/models.py
@@ -0,0 +1,768 @@
+"""
+tensorflow/keras utilities for the neuron project
+
+If you use this code, please cite 
+Dalca AV, Guttag J, Sabuncu MR
+Anatomical Priors in Convolutional Networks for Unsupervised Biomedical Segmentation, 
+CVPR 2018
+
+Contact: adalca [at] csail [dot] mit [dot] edu
+License: GPLv3
+"""
+
+import sys
+
+from ext.neuron import layers
+
+# third party
+import numpy as np
+import tensorflow as tf
+import keras
+import keras.layers as KL
+from keras.models import Model
+import keras.backend as K
+
+
+def unet(nb_features,
+         input_shape,
+         nb_levels,
+         conv_size,
+         nb_labels,
+         name='unet',
+         prefix=None,
+         feat_mult=1,
+         pool_size=2,
+         use_logp=True,
+         padding='same',
+         dilation_rate_mult=1,
+         activation='elu',
+         skip_n_concatenations=0,
+         use_residuals=False,
+         final_pred_activation='softmax',
+         nb_conv_per_level=1,
+         add_prior_layer=False,
+         layer_nb_feats=None,
+         conv_dropout=0,
+         batch_norm=None,
+         input_model=None):
+    """
+    unet-style keras model with an overdose of parametrization.
+
+    Parameters:
+        nb_features: the number of features at each convolutional level
+            see below for `feat_mult` and `layer_nb_feats` for modifiers to this number
+        input_shape: input layer shape, vector of size ndims + 1 (nb_channels)
+        conv_size: the convolution kernel size
+        nb_levels: the number of Unet levels (number of downsamples) in the "encoder" 
+            (e.g. 4 would give you 4 levels in encoder, 4 in decoder)
+        nb_labels: number of output channels
+        name (default: 'unet'): the name of the network
+        prefix (default: `name` value): prefix to be added to layer names
+        feat_mult (default: 1) multiple for `nb_features` as we go down the encoder levels.
+            e.g. feat_mult of 2 and nb_features of 16 would yield 32 features in the 
+            second layer, 64 features in the third layer, etc.
+        pool_size (default: 2): max pooling size (integer or list if specifying per dimension)
+        skip_n_concatenations=0: enabled to skip concatenation links between contracting and expanding paths for the n
+            top levels.
+        use_logp:
+        padding:
+        dilation_rate_mult:
+        activation:
+        use_residuals:
+        final_pred_activation:
+        nb_conv_per_level:
+        add_prior_layer:
+        skip_n_concatenations:
+        layer_nb_feats: list of the number of features for each layer. Automatically used if specified
+        conv_dropout: dropout probability
+        batch_norm:
+        input_model: concatenate the provided input_model to this current model.
+            Only the first output of input_model is used.
+    """
+
+    # naming
+    model_name = name
+    if prefix is None:
+        prefix = model_name
+
+    # volume size data
+    ndims = len(input_shape) - 1
+    if isinstance(pool_size, int):
+        pool_size = (pool_size,) * ndims
+
+    # get encoding model
+    enc_model = conv_enc(nb_features,
+                         input_shape,
+                         nb_levels,
+                         conv_size,
+                         name=model_name,
+                         prefix=prefix,
+                         feat_mult=feat_mult,
+                         pool_size=pool_size,
+                         padding=padding,
+                         dilation_rate_mult=dilation_rate_mult,
+                         activation=activation,
+                         use_residuals=use_residuals,
+                         nb_conv_per_level=nb_conv_per_level,
+                         layer_nb_feats=layer_nb_feats,
+                         conv_dropout=conv_dropout,
+                         batch_norm=batch_norm,
+                         input_model=input_model)
+
+    # get decoder
+    # use_skip_connections=True makes it a u-net
+    lnf = layer_nb_feats[(nb_levels * nb_conv_per_level):] if layer_nb_feats is not None else None
+    dec_model = conv_dec(nb_features,
+                         [],
+                         nb_levels,
+                         conv_size,
+                         nb_labels,
+                         name=model_name,
+                         prefix=prefix,
+                         feat_mult=feat_mult,
+                         pool_size=pool_size,
+                         use_skip_connections=True,
+                         skip_n_concatenations=skip_n_concatenations,
+                         padding=padding,
+                         dilation_rate_mult=dilation_rate_mult,
+                         activation=activation,
+                         use_residuals=use_residuals,
+                         final_pred_activation='linear' if add_prior_layer else final_pred_activation,
+                         nb_conv_per_level=nb_conv_per_level,
+                         batch_norm=batch_norm,
+                         layer_nb_feats=lnf,
+                         conv_dropout=conv_dropout,
+                         input_model=enc_model)
+    final_model = dec_model
+
+    if add_prior_layer:
+        final_model = add_prior(dec_model,
+                                [*input_shape[:-1], nb_labels],
+                                name=model_name + '_prior',
+                                use_logp=use_logp,
+                                final_pred_activation=final_pred_activation)
+
+    return final_model
+
+
+def ae(nb_features,
+       input_shape,
+       nb_levels,
+       conv_size,
+       nb_labels,
+       enc_size,
+       name='ae',
+       feat_mult=1,
+       pool_size=2,
+       padding='same',
+       activation='elu',
+       use_residuals=False,
+       nb_conv_per_level=1,
+       batch_norm=None,
+       enc_batch_norm=None,
+       ae_type='conv',  # 'dense', or 'conv'
+       enc_lambda_layers=None,
+       add_prior_layer=False,
+       use_logp=True,
+       conv_dropout=0,
+       include_mu_shift_layer=False,
+       single_model=False,  # whether to return a single model, or a tuple of models that can be stacked.
+       final_pred_activation='softmax',
+       do_vae=False,
+       input_model=None):
+    """Convolutional Auto-Encoder. Optionally Variational (if do_vae is set to True)."""
+
+    # naming
+    model_name = name
+
+    # volume size data
+    ndims = len(input_shape) - 1
+    if isinstance(pool_size, int):
+        pool_size = (pool_size,) * ndims
+
+    # get encoding model
+    enc_model = conv_enc(nb_features,
+                         input_shape,
+                         nb_levels,
+                         conv_size,
+                         name=model_name,
+                         feat_mult=feat_mult,
+                         pool_size=pool_size,
+                         padding=padding,
+                         activation=activation,
+                         use_residuals=use_residuals,
+                         nb_conv_per_level=nb_conv_per_level,
+                         conv_dropout=conv_dropout,
+                         batch_norm=batch_norm,
+                         input_model=input_model)
+
+    # middle AE structure
+    if single_model:
+        in_input_shape = None
+        in_model = enc_model
+    else:
+        in_input_shape = enc_model.output.shape.as_list()[1:]
+        in_model = None
+    mid_ae_model = single_ae(enc_size,
+                             in_input_shape,
+                             conv_size=conv_size,
+                             name=model_name,
+                             ae_type=ae_type,
+                             input_model=in_model,
+                             batch_norm=enc_batch_norm,
+                             enc_lambda_layers=enc_lambda_layers,
+                             include_mu_shift_layer=include_mu_shift_layer,
+                             do_vae=do_vae)
+
+    # decoder
+    if single_model:
+        in_input_shape = None
+        in_model = mid_ae_model
+    else:
+        in_input_shape = mid_ae_model.output.shape.as_list()[1:]
+        in_model = None
+    dec_model = conv_dec(nb_features,
+                         in_input_shape,
+                         nb_levels,
+                         conv_size,
+                         nb_labels,
+                         name=model_name,
+                         feat_mult=feat_mult,
+                         pool_size=pool_size,
+                         use_skip_connections=False,
+                         padding=padding,
+                         activation=activation,
+                         use_residuals=use_residuals,
+                         final_pred_activation='linear',
+                         nb_conv_per_level=nb_conv_per_level,
+                         batch_norm=batch_norm,
+                         conv_dropout=conv_dropout,
+                         input_model=in_model)
+
+    if add_prior_layer:
+        dec_model = add_prior(dec_model,
+                              [*input_shape[:-1], nb_labels],
+                              name=model_name,
+                              prefix=model_name + '_prior',
+                              use_logp=use_logp,
+                              final_pred_activation=final_pred_activation)
+
+    if single_model:
+        return dec_model
+    else:
+        return dec_model, mid_ae_model, enc_model
+
+
+def conv_enc(nb_features,
+             input_shape,
+             nb_levels,
+             conv_size,
+             name=None,
+             prefix=None,
+             feat_mult=1,
+             pool_size=2,
+             dilation_rate_mult=1,
+             padding='same',
+             activation='elu',
+             layer_nb_feats=None,
+             use_residuals=False,
+             nb_conv_per_level=2,
+             conv_dropout=0,
+             batch_norm=None,
+             input_model=None):
+    """Fully Convolutional Encoder"""
+
+    # naming
+    model_name = name
+    if prefix is None:
+        prefix = model_name
+
+    # first layer: input
+    name = '%s_input' % prefix
+    if input_model is None:
+        input_tensor = KL.Input(shape=input_shape, name=name)
+        last_tensor = input_tensor
+    else:
+        input_tensor = input_model.inputs
+        last_tensor = input_model.outputs
+        if isinstance(last_tensor, list):
+            last_tensor = last_tensor[0]
+
+    # volume size data
+    ndims = len(input_shape) - 1
+    if isinstance(pool_size, int):
+        pool_size = (pool_size,) * ndims
+
+    # prepare layers
+    convL = getattr(KL, 'Conv%dD' % ndims)
+    conv_kwargs = {'padding': padding, 'activation': activation, 'data_format': 'channels_last'}
+    maxpool = getattr(KL, 'MaxPooling%dD' % ndims)
+
+    # down arm:
+    # add nb_levels of conv + ReLu + conv + ReLu. Pool after each of first nb_levels - 1 layers
+    lfidx = 0  # level feature index
+    for level in range(nb_levels):
+        lvl_first_tensor = last_tensor
+        nb_lvl_feats = np.round(nb_features * feat_mult ** level).astype(int)
+        conv_kwargs['dilation_rate'] = dilation_rate_mult ** level
+
+        for conv in range(nb_conv_per_level):  # does several conv per level, max pooling applied at the end
+            if layer_nb_feats is not None:  # None or List of all the feature numbers
+                nb_lvl_feats = layer_nb_feats[lfidx]
+                lfidx += 1
+
+            name = '%s_conv_downarm_%d_%d' % (prefix, level, conv)
+            if conv < (nb_conv_per_level - 1) or (not use_residuals):
+                last_tensor = convL(nb_lvl_feats, conv_size, **conv_kwargs, name=name)(last_tensor)
+            else:  # no activation
+                last_tensor = convL(nb_lvl_feats, conv_size, padding=padding, name=name)(last_tensor)
+
+            if conv_dropout > 0:
+                # conv dropout along feature space only
+                name = '%s_dropout_downarm_%d_%d' % (prefix, level, conv)
+                noise_shape = [None, *[1] * ndims, nb_lvl_feats]
+                last_tensor = KL.Dropout(conv_dropout, noise_shape=noise_shape, name=name)(last_tensor)
+
+        if use_residuals:
+            convarm_layer = last_tensor
+
+            # the "add" layer is the original input
+            # However, it may not have the right number of features to be added
+            nb_feats_in = lvl_first_tensor.get_shape()[-1]
+            nb_feats_out = convarm_layer.get_shape()[-1]
+            add_layer = lvl_first_tensor
+            if nb_feats_in > 1 and nb_feats_out > 1 and (nb_feats_in != nb_feats_out):
+                name = '%s_expand_down_merge_%d' % (prefix, level)
+                last_tensor = convL(nb_lvl_feats, conv_size, **conv_kwargs, name=name)(lvl_first_tensor)
+                add_layer = last_tensor
+
+                if conv_dropout > 0:
+                    noise_shape = [None, *[1] * ndims, nb_lvl_feats]
+                    convarm_layer = KL.Dropout(conv_dropout, noise_shape=noise_shape)(last_tensor)
+
+            name = '%s_res_down_merge_%d' % (prefix, level)
+            last_tensor = KL.add([add_layer, convarm_layer], name=name)
+
+            name = '%s_res_down_merge_act_%d' % (prefix, level)
+            last_tensor = KL.Activation(activation, name=name)(last_tensor)
+
+        if batch_norm is not None:
+            name = '%s_bn_down_%d' % (prefix, level)
+            last_tensor = KL.BatchNormalization(axis=batch_norm, name=name)(last_tensor)
+
+        # max pool if we're not at the last level
+        if level < (nb_levels - 1):
+            name = '%s_maxpool_%d' % (prefix, level)
+            last_tensor = maxpool(pool_size=pool_size, name=name, padding=padding)(last_tensor)
+
+    # create the model and return
+    model = Model(inputs=input_tensor, outputs=[last_tensor], name=model_name)
+    return model
+
+
+def conv_dec(nb_features,
+             input_shape,
+             nb_levels,
+             conv_size,
+             nb_labels,
+             name=None,
+             prefix=None,
+             feat_mult=1,
+             pool_size=2,
+             use_skip_connections=False,
+             skip_n_concatenations=0,
+             padding='same',
+             dilation_rate_mult=1,
+             activation='elu',
+             use_residuals=False,
+             final_pred_activation='softmax',
+             nb_conv_per_level=2,
+             layer_nb_feats=None,
+             batch_norm=None,
+             conv_dropout=0,
+             input_model=None):
+    """Fully Convolutional Decoder"""
+
+    # naming
+    model_name = name
+    if prefix is None:
+        prefix = model_name
+
+    # if using skip connections, make sure need to use them.
+    if use_skip_connections:
+        assert input_model is not None, "is using skip connections, tensors dictionary is required"
+
+    # first layer: input
+    input_name = '%s_input' % prefix
+    if input_model is None:
+        input_tensor = KL.Input(shape=input_shape, name=input_name)
+        last_tensor = input_tensor
+    else:
+        input_tensor = input_model.input
+        last_tensor = input_model.output
+        input_shape = last_tensor.shape.as_list()[1:]
+
+    # vol size info
+    ndims = len(input_shape) - 1
+    if isinstance(pool_size, int):
+        if ndims > 1:
+            pool_size = (pool_size,) * ndims
+
+    # prepare layers
+    convL = getattr(KL, 'Conv%dD' % ndims)
+    conv_kwargs = {'padding': padding, 'activation': activation}
+    upsample = getattr(KL, 'UpSampling%dD' % ndims)
+
+    # up arm:
+    # nb_levels - 1 layers of Deconvolution3D
+    #    (approx via up + conv + ReLu) + merge + conv + ReLu + conv + ReLu
+    lfidx = 0
+    for level in range(nb_levels - 1):
+        nb_lvl_feats = np.round(nb_features * feat_mult ** (nb_levels - 2 - level)).astype(int)
+        conv_kwargs['dilation_rate'] = dilation_rate_mult ** (nb_levels - 2 - level)
+
+        # upsample matching the max pooling layers size
+        name = '%s_up_%d' % (prefix, nb_levels + level)
+        last_tensor = upsample(size=pool_size, name=name)(last_tensor)
+        up_tensor = last_tensor
+
+        # merge layers combining previous layer
+        if use_skip_connections & (level < (nb_levels - skip_n_concatenations - 1)):
+            conv_name = '%s_conv_downarm_%d_%d' % (prefix, nb_levels - 2 - level, nb_conv_per_level - 1)
+            cat_tensor = input_model.get_layer(conv_name).output
+            name = '%s_merge_%d' % (prefix, nb_levels + level)
+            last_tensor = KL.concatenate([cat_tensor, last_tensor], axis=ndims + 1, name=name)
+
+        # convolution layers
+        for conv in range(nb_conv_per_level):
+            if layer_nb_feats is not None:
+                nb_lvl_feats = layer_nb_feats[lfidx]
+                lfidx += 1
+
+            name = '%s_conv_uparm_%d_%d' % (prefix, nb_levels + level, conv)
+            if conv < (nb_conv_per_level - 1) or (not use_residuals):
+                last_tensor = convL(nb_lvl_feats, conv_size, **conv_kwargs, name=name)(last_tensor)
+            else:
+                last_tensor = convL(nb_lvl_feats, conv_size, padding=padding, name=name)(last_tensor)
+
+            if conv_dropout > 0:
+                name = '%s_dropout_uparm_%d_%d' % (prefix, level, conv)
+                noise_shape = [None, *[1] * ndims, nb_lvl_feats]
+                last_tensor = KL.Dropout(conv_dropout, noise_shape=noise_shape, name=name)(last_tensor)
+
+        # residual block
+        if use_residuals:
+
+            # the "add" layer is the original input
+            # However, it may not have the right number of features to be added
+            add_layer = up_tensor
+            nb_feats_in = add_layer.get_shape()[-1]
+            nb_feats_out = last_tensor.get_shape()[-1]
+            if nb_feats_in > 1 and nb_feats_out > 1 and (nb_feats_in != nb_feats_out):
+                name = '%s_expand_up_merge_%d' % (prefix, level)
+                add_layer = convL(nb_lvl_feats, conv_size, **conv_kwargs, name=name)(add_layer)
+
+                if conv_dropout > 0:
+                    noise_shape = [None, *[1] * ndims, nb_lvl_feats]
+                    last_tensor = KL.Dropout(conv_dropout, noise_shape=noise_shape)(last_tensor)
+
+            name = '%s_res_up_merge_%d' % (prefix, level)
+            last_tensor = KL.add([last_tensor, add_layer], name=name)
+
+            name = '%s_res_up_merge_act_%d' % (prefix, level)
+            last_tensor = KL.Activation(activation, name=name)(last_tensor)
+
+        if batch_norm is not None:
+            name = '%s_bn_up_%d' % (prefix, level)
+            last_tensor = KL.BatchNormalization(axis=batch_norm, name=name)(last_tensor)
+
+    # Compute likelihood prediction (no activation yet)
+    name = '%s_likelihood' % prefix
+    last_tensor = convL(nb_labels, 1, activation=None, name=name)(last_tensor)
+    like_tensor = last_tensor
+
+    # output prediction layer
+    # we use a softmax to compute P(L_x|I) where x is each location
+    if final_pred_activation == 'softmax':
+        name = '%s_prediction' % prefix
+        softmax_lambda_fcn = lambda x: keras.activations.softmax(x, axis=ndims + 1)
+        pred_tensor = KL.Lambda(softmax_lambda_fcn, name=name)(last_tensor)
+
+    # otherwise create a layer that does nothing.
+    else:
+        name = '%s_prediction' % prefix
+        pred_tensor = KL.Activation('linear', name=name)(like_tensor)
+
+    # create the model and return
+    model = Model(inputs=input_tensor, outputs=pred_tensor, name=model_name)
+    return model
+
+
+def add_prior(input_model,
+              prior_shape,
+              name='prior_model',
+              prefix=None,
+              use_logp=True,
+              final_pred_activation='softmax'):
+    """
+    Append post-prior layer to a given model
+    """
+
+    # naming
+    model_name = name
+    if prefix is None:
+        prefix = model_name
+
+    # prior input layer
+    prior_input_name = '%s-input' % prefix
+    prior_tensor = KL.Input(shape=prior_shape, name=prior_input_name)
+    prior_tensor_input = prior_tensor
+    like_tensor = input_model.output
+
+    # operation varies depending on whether we log() prior or not.
+    if use_logp:
+        print("Breaking change: use_logp option now requires log input!", file=sys.stderr)
+        merge_op = KL.add
+
+    else:
+        # using sigmoid to get the likelihood values between 0 and 1
+        # note: they won't add up to 1.
+        name = '%s_likelihood_sigmoid' % prefix
+        like_tensor = KL.Activation('sigmoid', name=name)(like_tensor)
+        merge_op = KL.multiply
+
+    # merge the likelihood and prior layers into posterior layer
+    name = '%s_posterior' % prefix
+    post_tensor = merge_op([prior_tensor, like_tensor], name=name)
+
+    # output prediction layer
+    # we use a softmax to compute P(L_x|I) where x is each location
+    pred_name = '%s_prediction' % prefix
+    if final_pred_activation == 'softmax':
+        assert use_logp, 'cannot do softmax when adding prior via P()'
+        print("using final_pred_activation %s for %s" % (final_pred_activation, model_name))
+        softmax_lambda_fcn = lambda x: keras.activations.softmax(x, axis=-1)
+        pred_tensor = KL.Lambda(softmax_lambda_fcn, name=pred_name)(post_tensor)
+
+    else:
+        pred_tensor = KL.Activation('linear', name=pred_name)(post_tensor)
+
+    # create the model
+    model_inputs = [*input_model.inputs, prior_tensor_input]
+    model = Model(inputs=model_inputs, outputs=[pred_tensor], name=model_name)
+
+    # compile
+    return model
+
+
+def single_ae(enc_size,
+              input_shape,
+              name='single_ae',
+              prefix=None,
+              ae_type='dense',  # 'dense', or 'conv'
+              conv_size=None,
+              input_model=None,
+              enc_lambda_layers=None,
+              batch_norm=True,
+              padding='same',
+              activation=None,
+              include_mu_shift_layer=False,
+              do_vae=False):
+    """single-layer Autoencoder (i.e. input - encoding - output"""
+
+    # naming
+    model_name = name
+    if prefix is None:
+        prefix = model_name
+
+    if enc_lambda_layers is None:
+        enc_lambda_layers = []
+
+    # prepare input
+    input_name = '%s_input' % prefix
+    if input_model is None:
+        assert input_shape is not None, 'input_shape of input_model is necessary'
+        input_tensor = KL.Input(shape=input_shape, name=input_name)
+        last_tensor = input_tensor
+    else:
+        input_tensor = input_model.input
+        last_tensor = input_model.output
+        input_shape = last_tensor.shape.as_list()[1:]
+    input_nb_feats = last_tensor.shape.as_list()[-1]
+
+    # prepare conv type based on input
+    ndims = len(input_shape) - 1
+    if ae_type == 'conv':
+        convL = getattr(KL, 'Conv%dD' % ndims)
+        assert conv_size is not None, 'with conv ae, need conv_size'
+        conv_kwargs = {'padding': padding, 'activation': activation}
+        enc_size_str = None
+
+    # if want to go through a dense layer in the middle of the U, need to:
+    # - flatten last layer if not flat
+    # - do dense encoding and decoding
+    # - unflatten (reshape spatially) at end
+    else:  # ae_type == 'dense'
+        if len(input_shape) > 1:
+            name = '%s_ae_%s_down_flat' % (prefix, ae_type)
+            last_tensor = KL.Flatten(name=name)(last_tensor)
+        convL = conv_kwargs = None
+        assert len(enc_size) == 1, "enc_size should be of length 1 for dense layer"
+        enc_size_str = ''.join(['%d_' % d for d in enc_size])[:-1]
+
+    # recall this layer
+    pre_enc_layer = last_tensor
+
+    # encoding layer
+    if ae_type == 'dense':
+        name = '%s_ae_mu_enc_dense_%s' % (prefix, enc_size_str)
+        last_tensor = KL.Dense(enc_size[0], name=name)(pre_enc_layer)
+
+    else:  # convolution
+
+        # convolve then resize. enc_size should be [nb_dim1, nb_dim2, ..., nb_feats]
+        assert len(enc_size) == len(input_shape), \
+            "encoding size does not match input shape %d %d" % (len(enc_size), len(input_shape))
+
+        if list(enc_size)[:-1] != list(input_shape)[:-1] and \
+                all([f is not None for f in input_shape[:-1]]) and \
+                all([f is not None for f in enc_size[:-1]]):
+
+            name = '%s_ae_mu_enc_conv' % prefix
+            last_tensor = convL(enc_size[-1], conv_size, name=name, **conv_kwargs)(pre_enc_layer)
+
+            name = '%s_ae_mu_enc' % prefix
+            zf = [enc_size[:-1][f] / last_tensor.shape.as_list()[1:-1][f] for f in range(len(enc_size) - 1)]
+            last_tensor = layers.Resize(zoom_factor=zf, name=name)(last_tensor)
+
+        elif enc_size[-1] is None:  # convolutional, but won't tell us bottleneck
+            name = '%s_ae_mu_enc' % prefix
+            last_tensor = KL.Lambda(lambda x: x, name=name)(pre_enc_layer)
+
+        else:
+            name = '%s_ae_mu_enc' % prefix
+            last_tensor = convL(enc_size[-1], conv_size, name=name, **conv_kwargs)(pre_enc_layer)
+
+    if include_mu_shift_layer:
+        # shift
+        name = '%s_ae_mu_shift' % prefix
+        last_tensor = layers.LocalBias(name=name)(last_tensor)
+
+    # encoding clean-up layers
+    for layer_fcn in enc_lambda_layers:
+        lambda_name = layer_fcn.__name__
+        name = '%s_ae_mu_%s' % (prefix, lambda_name)
+        last_tensor = KL.Lambda(layer_fcn, name=name)(last_tensor)
+
+    if batch_norm is not None:
+        name = '%s_ae_mu_bn' % prefix
+        last_tensor = KL.BatchNormalization(axis=batch_norm, name=name)(last_tensor)
+
+    # have a simple layer that does nothing to have a clear name before sampling
+    name = '%s_ae_mu' % prefix
+    last_tensor = KL.Lambda(lambda x: x, name=name)(last_tensor)
+
+    # if doing variational AE, will need the sigma layer as well.
+    if do_vae:
+        mu_tensor = last_tensor
+
+        # encoding layer
+        if ae_type == 'dense':
+            name = '%s_ae_sigma_enc_dense_%s' % (prefix, enc_size_str)
+            last_tensor = KL.Dense(enc_size[0], name=name)(pre_enc_layer)
+
+        else:
+            if list(enc_size)[:-1] != list(input_shape)[:-1] and \
+                    all([f is not None for f in input_shape[:-1]]) and \
+                    all([f is not None for f in enc_size[:-1]]):
+
+                assert len(enc_size) - 1 == 2, "Sorry, I have not yet implemented non-2D resizing..."
+                name = '%s_ae_sigma_enc_conv' % prefix
+                last_tensor = convL(enc_size[-1], conv_size, name=name, **conv_kwargs)(pre_enc_layer)
+
+                name = '%s_ae_sigma_enc' % prefix
+                resize_fn = lambda x: tf.image.resize_bilinear(x, enc_size[:-1])
+                last_tensor = KL.Lambda(resize_fn, name=name)(last_tensor)
+
+            elif enc_size[-1] is None:  # convolutional, but won't tell us bottleneck
+                name = '%s_ae_sigma_enc' % prefix
+                last_tensor = convL(pre_enc_layer.shape.as_list()[-1], conv_size, name=name, **conv_kwargs)(
+                    pre_enc_layer)
+                # cannot use lambda, then mu and sigma will be same layer.
+                # last_tensor = KL.Lambda(lambda x: x, name=name)(pre_enc_layer)
+
+            else:
+                name = '%s_ae_sigma_enc' % prefix
+                last_tensor = convL(enc_size[-1], conv_size, name=name, **conv_kwargs)(pre_enc_layer)
+
+        # encoding clean-up layers
+        for layer_fcn in enc_lambda_layers:
+            lambda_name = layer_fcn.__name__
+            name = '%s_ae_sigma_%s' % (prefix, lambda_name)
+            last_tensor = KL.Lambda(layer_fcn, name=name)(last_tensor)
+
+        if batch_norm is not None:
+            name = '%s_ae_sigma_bn' % prefix
+            last_tensor = KL.BatchNormalization(axis=batch_norm, name=name)(last_tensor)
+
+        # have a simple layer that does nothing to have a clear name before sampling
+        name = '%s_ae_sigma' % prefix
+        last_tensor = KL.Lambda(lambda x: x, name=name)(last_tensor)
+
+        logvar_tensor = last_tensor
+
+        # VAE sampling
+        sampler = _VAESample().sample_z
+
+        name = '%s_ae_sample' % prefix
+        last_tensor = KL.Lambda(sampler, name=name)([mu_tensor, logvar_tensor])
+
+    if include_mu_shift_layer:
+        # shift
+        name = '%s_ae_sample_shift' % prefix
+        last_tensor = layers.LocalBias(name=name)(last_tensor)
+
+    # decoding layer
+    if ae_type == 'dense':
+        name = '%s_ae_%s_dec_flat_%s' % (prefix, ae_type, enc_size_str)
+        last_tensor = KL.Dense(np.prod(input_shape), name=name)(last_tensor)
+
+        # unflatten if dense method
+        if len(input_shape) > 1:
+            name = '%s_ae_%s_dec' % (prefix, ae_type)
+            last_tensor = KL.Reshape(input_shape, name=name)(last_tensor)
+
+    else:
+
+        if list(enc_size)[:-1] != list(input_shape)[:-1] and \
+                all([f is not None for f in input_shape[:-1]]) and \
+                all([f is not None for f in enc_size[:-1]]):
+            name = '%s_ae_mu_dec' % prefix
+            zf = [last_tensor.shape.as_list()[1:-1][f] / enc_size[:-1][f] for f in range(len(enc_size) - 1)]
+            last_tensor = layers.Resize(zoom_factor=zf, name=name)(last_tensor)
+
+        name = '%s_ae_%s_dec' % (prefix, ae_type)
+        last_tensor = convL(input_nb_feats, conv_size, name=name, **conv_kwargs)(last_tensor)
+
+    if batch_norm is not None:
+        name = '%s_bn_ae_%s_dec' % (prefix, ae_type)
+        last_tensor = KL.BatchNormalization(axis=batch_norm, name=name)(last_tensor)
+
+    # create the model and return
+    model = Model(inputs=input_tensor, outputs=[last_tensor], name=model_name)
+    return model
+
+
+###############################################################################
+# Helper function
+###############################################################################
+
+class _VAESample:
+    def __init__(self):
+        pass
+
+    def sample_z(self, args):
+        mu, log_var = args
+        shape = K.shape(mu)
+        eps = K.random_normal(shape=shape, mean=0., stddev=1.)
+        return mu + K.exp(log_var / 2) * eps