--- a +++ b/ext/neuron/models.py @@ -0,0 +1,768 @@ +""" +tensorflow/keras utilities for the neuron project + +If you use this code, please cite +Dalca AV, Guttag J, Sabuncu MR +Anatomical Priors in Convolutional Networks for Unsupervised Biomedical Segmentation, +CVPR 2018 + +Contact: adalca [at] csail [dot] mit [dot] edu +License: GPLv3 +""" + +import sys + +from ext.neuron import layers + +# third party +import numpy as np +import tensorflow as tf +import keras +import keras.layers as KL +from keras.models import Model +import keras.backend as K + + +def unet(nb_features, + input_shape, + nb_levels, + conv_size, + nb_labels, + name='unet', + prefix=None, + feat_mult=1, + pool_size=2, + use_logp=True, + padding='same', + dilation_rate_mult=1, + activation='elu', + skip_n_concatenations=0, + use_residuals=False, + final_pred_activation='softmax', + nb_conv_per_level=1, + add_prior_layer=False, + layer_nb_feats=None, + conv_dropout=0, + batch_norm=None, + input_model=None): + """ + unet-style keras model with an overdose of parametrization. + + Parameters: + nb_features: the number of features at each convolutional level + see below for `feat_mult` and `layer_nb_feats` for modifiers to this number + input_shape: input layer shape, vector of size ndims + 1 (nb_channels) + conv_size: the convolution kernel size + nb_levels: the number of Unet levels (number of downsamples) in the "encoder" + (e.g. 4 would give you 4 levels in encoder, 4 in decoder) + nb_labels: number of output channels + name (default: 'unet'): the name of the network + prefix (default: `name` value): prefix to be added to layer names + feat_mult (default: 1) multiple for `nb_features` as we go down the encoder levels. + e.g. feat_mult of 2 and nb_features of 16 would yield 32 features in the + second layer, 64 features in the third layer, etc. + pool_size (default: 2): max pooling size (integer or list if specifying per dimension) + skip_n_concatenations=0: enabled to skip concatenation links between contracting and expanding paths for the n + top levels. + use_logp: + padding: + dilation_rate_mult: + activation: + use_residuals: + final_pred_activation: + nb_conv_per_level: + add_prior_layer: + skip_n_concatenations: + layer_nb_feats: list of the number of features for each layer. Automatically used if specified + conv_dropout: dropout probability + batch_norm: + input_model: concatenate the provided input_model to this current model. + Only the first output of input_model is used. + """ + + # naming + model_name = name + if prefix is None: + prefix = model_name + + # volume size data + ndims = len(input_shape) - 1 + if isinstance(pool_size, int): + pool_size = (pool_size,) * ndims + + # get encoding model + enc_model = conv_enc(nb_features, + input_shape, + nb_levels, + conv_size, + name=model_name, + prefix=prefix, + feat_mult=feat_mult, + pool_size=pool_size, + padding=padding, + dilation_rate_mult=dilation_rate_mult, + activation=activation, + use_residuals=use_residuals, + nb_conv_per_level=nb_conv_per_level, + layer_nb_feats=layer_nb_feats, + conv_dropout=conv_dropout, + batch_norm=batch_norm, + input_model=input_model) + + # get decoder + # use_skip_connections=True makes it a u-net + lnf = layer_nb_feats[(nb_levels * nb_conv_per_level):] if layer_nb_feats is not None else None + dec_model = conv_dec(nb_features, + [], + nb_levels, + conv_size, + nb_labels, + name=model_name, + prefix=prefix, + feat_mult=feat_mult, + pool_size=pool_size, + use_skip_connections=True, + skip_n_concatenations=skip_n_concatenations, + padding=padding, + dilation_rate_mult=dilation_rate_mult, + activation=activation, + use_residuals=use_residuals, + final_pred_activation='linear' if add_prior_layer else final_pred_activation, + nb_conv_per_level=nb_conv_per_level, + batch_norm=batch_norm, + layer_nb_feats=lnf, + conv_dropout=conv_dropout, + input_model=enc_model) + final_model = dec_model + + if add_prior_layer: + final_model = add_prior(dec_model, + [*input_shape[:-1], nb_labels], + name=model_name + '_prior', + use_logp=use_logp, + final_pred_activation=final_pred_activation) + + return final_model + + +def ae(nb_features, + input_shape, + nb_levels, + conv_size, + nb_labels, + enc_size, + name='ae', + feat_mult=1, + pool_size=2, + padding='same', + activation='elu', + use_residuals=False, + nb_conv_per_level=1, + batch_norm=None, + enc_batch_norm=None, + ae_type='conv', # 'dense', or 'conv' + enc_lambda_layers=None, + add_prior_layer=False, + use_logp=True, + conv_dropout=0, + include_mu_shift_layer=False, + single_model=False, # whether to return a single model, or a tuple of models that can be stacked. + final_pred_activation='softmax', + do_vae=False, + input_model=None): + """Convolutional Auto-Encoder. Optionally Variational (if do_vae is set to True).""" + + # naming + model_name = name + + # volume size data + ndims = len(input_shape) - 1 + if isinstance(pool_size, int): + pool_size = (pool_size,) * ndims + + # get encoding model + enc_model = conv_enc(nb_features, + input_shape, + nb_levels, + conv_size, + name=model_name, + feat_mult=feat_mult, + pool_size=pool_size, + padding=padding, + activation=activation, + use_residuals=use_residuals, + nb_conv_per_level=nb_conv_per_level, + conv_dropout=conv_dropout, + batch_norm=batch_norm, + input_model=input_model) + + # middle AE structure + if single_model: + in_input_shape = None + in_model = enc_model + else: + in_input_shape = enc_model.output.shape.as_list()[1:] + in_model = None + mid_ae_model = single_ae(enc_size, + in_input_shape, + conv_size=conv_size, + name=model_name, + ae_type=ae_type, + input_model=in_model, + batch_norm=enc_batch_norm, + enc_lambda_layers=enc_lambda_layers, + include_mu_shift_layer=include_mu_shift_layer, + do_vae=do_vae) + + # decoder + if single_model: + in_input_shape = None + in_model = mid_ae_model + else: + in_input_shape = mid_ae_model.output.shape.as_list()[1:] + in_model = None + dec_model = conv_dec(nb_features, + in_input_shape, + nb_levels, + conv_size, + nb_labels, + name=model_name, + feat_mult=feat_mult, + pool_size=pool_size, + use_skip_connections=False, + padding=padding, + activation=activation, + use_residuals=use_residuals, + final_pred_activation='linear', + nb_conv_per_level=nb_conv_per_level, + batch_norm=batch_norm, + conv_dropout=conv_dropout, + input_model=in_model) + + if add_prior_layer: + dec_model = add_prior(dec_model, + [*input_shape[:-1], nb_labels], + name=model_name, + prefix=model_name + '_prior', + use_logp=use_logp, + final_pred_activation=final_pred_activation) + + if single_model: + return dec_model + else: + return dec_model, mid_ae_model, enc_model + + +def conv_enc(nb_features, + input_shape, + nb_levels, + conv_size, + name=None, + prefix=None, + feat_mult=1, + pool_size=2, + dilation_rate_mult=1, + padding='same', + activation='elu', + layer_nb_feats=None, + use_residuals=False, + nb_conv_per_level=2, + conv_dropout=0, + batch_norm=None, + input_model=None): + """Fully Convolutional Encoder""" + + # naming + model_name = name + if prefix is None: + prefix = model_name + + # first layer: input + name = '%s_input' % prefix + if input_model is None: + input_tensor = KL.Input(shape=input_shape, name=name) + last_tensor = input_tensor + else: + input_tensor = input_model.inputs + last_tensor = input_model.outputs + if isinstance(last_tensor, list): + last_tensor = last_tensor[0] + + # volume size data + ndims = len(input_shape) - 1 + if isinstance(pool_size, int): + pool_size = (pool_size,) * ndims + + # prepare layers + convL = getattr(KL, 'Conv%dD' % ndims) + conv_kwargs = {'padding': padding, 'activation': activation, 'data_format': 'channels_last'} + maxpool = getattr(KL, 'MaxPooling%dD' % ndims) + + # down arm: + # add nb_levels of conv + ReLu + conv + ReLu. Pool after each of first nb_levels - 1 layers + lfidx = 0 # level feature index + for level in range(nb_levels): + lvl_first_tensor = last_tensor + nb_lvl_feats = np.round(nb_features * feat_mult ** level).astype(int) + conv_kwargs['dilation_rate'] = dilation_rate_mult ** level + + for conv in range(nb_conv_per_level): # does several conv per level, max pooling applied at the end + if layer_nb_feats is not None: # None or List of all the feature numbers + nb_lvl_feats = layer_nb_feats[lfidx] + lfidx += 1 + + name = '%s_conv_downarm_%d_%d' % (prefix, level, conv) + if conv < (nb_conv_per_level - 1) or (not use_residuals): + last_tensor = convL(nb_lvl_feats, conv_size, **conv_kwargs, name=name)(last_tensor) + else: # no activation + last_tensor = convL(nb_lvl_feats, conv_size, padding=padding, name=name)(last_tensor) + + if conv_dropout > 0: + # conv dropout along feature space only + name = '%s_dropout_downarm_%d_%d' % (prefix, level, conv) + noise_shape = [None, *[1] * ndims, nb_lvl_feats] + last_tensor = KL.Dropout(conv_dropout, noise_shape=noise_shape, name=name)(last_tensor) + + if use_residuals: + convarm_layer = last_tensor + + # the "add" layer is the original input + # However, it may not have the right number of features to be added + nb_feats_in = lvl_first_tensor.get_shape()[-1] + nb_feats_out = convarm_layer.get_shape()[-1] + add_layer = lvl_first_tensor + if nb_feats_in > 1 and nb_feats_out > 1 and (nb_feats_in != nb_feats_out): + name = '%s_expand_down_merge_%d' % (prefix, level) + last_tensor = convL(nb_lvl_feats, conv_size, **conv_kwargs, name=name)(lvl_first_tensor) + add_layer = last_tensor + + if conv_dropout > 0: + noise_shape = [None, *[1] * ndims, nb_lvl_feats] + convarm_layer = KL.Dropout(conv_dropout, noise_shape=noise_shape)(last_tensor) + + name = '%s_res_down_merge_%d' % (prefix, level) + last_tensor = KL.add([add_layer, convarm_layer], name=name) + + name = '%s_res_down_merge_act_%d' % (prefix, level) + last_tensor = KL.Activation(activation, name=name)(last_tensor) + + if batch_norm is not None: + name = '%s_bn_down_%d' % (prefix, level) + last_tensor = KL.BatchNormalization(axis=batch_norm, name=name)(last_tensor) + + # max pool if we're not at the last level + if level < (nb_levels - 1): + name = '%s_maxpool_%d' % (prefix, level) + last_tensor = maxpool(pool_size=pool_size, name=name, padding=padding)(last_tensor) + + # create the model and return + model = Model(inputs=input_tensor, outputs=[last_tensor], name=model_name) + return model + + +def conv_dec(nb_features, + input_shape, + nb_levels, + conv_size, + nb_labels, + name=None, + prefix=None, + feat_mult=1, + pool_size=2, + use_skip_connections=False, + skip_n_concatenations=0, + padding='same', + dilation_rate_mult=1, + activation='elu', + use_residuals=False, + final_pred_activation='softmax', + nb_conv_per_level=2, + layer_nb_feats=None, + batch_norm=None, + conv_dropout=0, + input_model=None): + """Fully Convolutional Decoder""" + + # naming + model_name = name + if prefix is None: + prefix = model_name + + # if using skip connections, make sure need to use them. + if use_skip_connections: + assert input_model is not None, "is using skip connections, tensors dictionary is required" + + # first layer: input + input_name = '%s_input' % prefix + if input_model is None: + input_tensor = KL.Input(shape=input_shape, name=input_name) + last_tensor = input_tensor + else: + input_tensor = input_model.input + last_tensor = input_model.output + input_shape = last_tensor.shape.as_list()[1:] + + # vol size info + ndims = len(input_shape) - 1 + if isinstance(pool_size, int): + if ndims > 1: + pool_size = (pool_size,) * ndims + + # prepare layers + convL = getattr(KL, 'Conv%dD' % ndims) + conv_kwargs = {'padding': padding, 'activation': activation} + upsample = getattr(KL, 'UpSampling%dD' % ndims) + + # up arm: + # nb_levels - 1 layers of Deconvolution3D + # (approx via up + conv + ReLu) + merge + conv + ReLu + conv + ReLu + lfidx = 0 + for level in range(nb_levels - 1): + nb_lvl_feats = np.round(nb_features * feat_mult ** (nb_levels - 2 - level)).astype(int) + conv_kwargs['dilation_rate'] = dilation_rate_mult ** (nb_levels - 2 - level) + + # upsample matching the max pooling layers size + name = '%s_up_%d' % (prefix, nb_levels + level) + last_tensor = upsample(size=pool_size, name=name)(last_tensor) + up_tensor = last_tensor + + # merge layers combining previous layer + if use_skip_connections & (level < (nb_levels - skip_n_concatenations - 1)): + conv_name = '%s_conv_downarm_%d_%d' % (prefix, nb_levels - 2 - level, nb_conv_per_level - 1) + cat_tensor = input_model.get_layer(conv_name).output + name = '%s_merge_%d' % (prefix, nb_levels + level) + last_tensor = KL.concatenate([cat_tensor, last_tensor], axis=ndims + 1, name=name) + + # convolution layers + for conv in range(nb_conv_per_level): + if layer_nb_feats is not None: + nb_lvl_feats = layer_nb_feats[lfidx] + lfidx += 1 + + name = '%s_conv_uparm_%d_%d' % (prefix, nb_levels + level, conv) + if conv < (nb_conv_per_level - 1) or (not use_residuals): + last_tensor = convL(nb_lvl_feats, conv_size, **conv_kwargs, name=name)(last_tensor) + else: + last_tensor = convL(nb_lvl_feats, conv_size, padding=padding, name=name)(last_tensor) + + if conv_dropout > 0: + name = '%s_dropout_uparm_%d_%d' % (prefix, level, conv) + noise_shape = [None, *[1] * ndims, nb_lvl_feats] + last_tensor = KL.Dropout(conv_dropout, noise_shape=noise_shape, name=name)(last_tensor) + + # residual block + if use_residuals: + + # the "add" layer is the original input + # However, it may not have the right number of features to be added + add_layer = up_tensor + nb_feats_in = add_layer.get_shape()[-1] + nb_feats_out = last_tensor.get_shape()[-1] + if nb_feats_in > 1 and nb_feats_out > 1 and (nb_feats_in != nb_feats_out): + name = '%s_expand_up_merge_%d' % (prefix, level) + add_layer = convL(nb_lvl_feats, conv_size, **conv_kwargs, name=name)(add_layer) + + if conv_dropout > 0: + noise_shape = [None, *[1] * ndims, nb_lvl_feats] + last_tensor = KL.Dropout(conv_dropout, noise_shape=noise_shape)(last_tensor) + + name = '%s_res_up_merge_%d' % (prefix, level) + last_tensor = KL.add([last_tensor, add_layer], name=name) + + name = '%s_res_up_merge_act_%d' % (prefix, level) + last_tensor = KL.Activation(activation, name=name)(last_tensor) + + if batch_norm is not None: + name = '%s_bn_up_%d' % (prefix, level) + last_tensor = KL.BatchNormalization(axis=batch_norm, name=name)(last_tensor) + + # Compute likelihood prediction (no activation yet) + name = '%s_likelihood' % prefix + last_tensor = convL(nb_labels, 1, activation=None, name=name)(last_tensor) + like_tensor = last_tensor + + # output prediction layer + # we use a softmax to compute P(L_x|I) where x is each location + if final_pred_activation == 'softmax': + name = '%s_prediction' % prefix + softmax_lambda_fcn = lambda x: keras.activations.softmax(x, axis=ndims + 1) + pred_tensor = KL.Lambda(softmax_lambda_fcn, name=name)(last_tensor) + + # otherwise create a layer that does nothing. + else: + name = '%s_prediction' % prefix + pred_tensor = KL.Activation('linear', name=name)(like_tensor) + + # create the model and return + model = Model(inputs=input_tensor, outputs=pred_tensor, name=model_name) + return model + + +def add_prior(input_model, + prior_shape, + name='prior_model', + prefix=None, + use_logp=True, + final_pred_activation='softmax'): + """ + Append post-prior layer to a given model + """ + + # naming + model_name = name + if prefix is None: + prefix = model_name + + # prior input layer + prior_input_name = '%s-input' % prefix + prior_tensor = KL.Input(shape=prior_shape, name=prior_input_name) + prior_tensor_input = prior_tensor + like_tensor = input_model.output + + # operation varies depending on whether we log() prior or not. + if use_logp: + print("Breaking change: use_logp option now requires log input!", file=sys.stderr) + merge_op = KL.add + + else: + # using sigmoid to get the likelihood values between 0 and 1 + # note: they won't add up to 1. + name = '%s_likelihood_sigmoid' % prefix + like_tensor = KL.Activation('sigmoid', name=name)(like_tensor) + merge_op = KL.multiply + + # merge the likelihood and prior layers into posterior layer + name = '%s_posterior' % prefix + post_tensor = merge_op([prior_tensor, like_tensor], name=name) + + # output prediction layer + # we use a softmax to compute P(L_x|I) where x is each location + pred_name = '%s_prediction' % prefix + if final_pred_activation == 'softmax': + assert use_logp, 'cannot do softmax when adding prior via P()' + print("using final_pred_activation %s for %s" % (final_pred_activation, model_name)) + softmax_lambda_fcn = lambda x: keras.activations.softmax(x, axis=-1) + pred_tensor = KL.Lambda(softmax_lambda_fcn, name=pred_name)(post_tensor) + + else: + pred_tensor = KL.Activation('linear', name=pred_name)(post_tensor) + + # create the model + model_inputs = [*input_model.inputs, prior_tensor_input] + model = Model(inputs=model_inputs, outputs=[pred_tensor], name=model_name) + + # compile + return model + + +def single_ae(enc_size, + input_shape, + name='single_ae', + prefix=None, + ae_type='dense', # 'dense', or 'conv' + conv_size=None, + input_model=None, + enc_lambda_layers=None, + batch_norm=True, + padding='same', + activation=None, + include_mu_shift_layer=False, + do_vae=False): + """single-layer Autoencoder (i.e. input - encoding - output""" + + # naming + model_name = name + if prefix is None: + prefix = model_name + + if enc_lambda_layers is None: + enc_lambda_layers = [] + + # prepare input + input_name = '%s_input' % prefix + if input_model is None: + assert input_shape is not None, 'input_shape of input_model is necessary' + input_tensor = KL.Input(shape=input_shape, name=input_name) + last_tensor = input_tensor + else: + input_tensor = input_model.input + last_tensor = input_model.output + input_shape = last_tensor.shape.as_list()[1:] + input_nb_feats = last_tensor.shape.as_list()[-1] + + # prepare conv type based on input + ndims = len(input_shape) - 1 + if ae_type == 'conv': + convL = getattr(KL, 'Conv%dD' % ndims) + assert conv_size is not None, 'with conv ae, need conv_size' + conv_kwargs = {'padding': padding, 'activation': activation} + enc_size_str = None + + # if want to go through a dense layer in the middle of the U, need to: + # - flatten last layer if not flat + # - do dense encoding and decoding + # - unflatten (reshape spatially) at end + else: # ae_type == 'dense' + if len(input_shape) > 1: + name = '%s_ae_%s_down_flat' % (prefix, ae_type) + last_tensor = KL.Flatten(name=name)(last_tensor) + convL = conv_kwargs = None + assert len(enc_size) == 1, "enc_size should be of length 1 for dense layer" + enc_size_str = ''.join(['%d_' % d for d in enc_size])[:-1] + + # recall this layer + pre_enc_layer = last_tensor + + # encoding layer + if ae_type == 'dense': + name = '%s_ae_mu_enc_dense_%s' % (prefix, enc_size_str) + last_tensor = KL.Dense(enc_size[0], name=name)(pre_enc_layer) + + else: # convolution + + # convolve then resize. enc_size should be [nb_dim1, nb_dim2, ..., nb_feats] + assert len(enc_size) == len(input_shape), \ + "encoding size does not match input shape %d %d" % (len(enc_size), len(input_shape)) + + if list(enc_size)[:-1] != list(input_shape)[:-1] and \ + all([f is not None for f in input_shape[:-1]]) and \ + all([f is not None for f in enc_size[:-1]]): + + name = '%s_ae_mu_enc_conv' % prefix + last_tensor = convL(enc_size[-1], conv_size, name=name, **conv_kwargs)(pre_enc_layer) + + name = '%s_ae_mu_enc' % prefix + zf = [enc_size[:-1][f] / last_tensor.shape.as_list()[1:-1][f] for f in range(len(enc_size) - 1)] + last_tensor = layers.Resize(zoom_factor=zf, name=name)(last_tensor) + + elif enc_size[-1] is None: # convolutional, but won't tell us bottleneck + name = '%s_ae_mu_enc' % prefix + last_tensor = KL.Lambda(lambda x: x, name=name)(pre_enc_layer) + + else: + name = '%s_ae_mu_enc' % prefix + last_tensor = convL(enc_size[-1], conv_size, name=name, **conv_kwargs)(pre_enc_layer) + + if include_mu_shift_layer: + # shift + name = '%s_ae_mu_shift' % prefix + last_tensor = layers.LocalBias(name=name)(last_tensor) + + # encoding clean-up layers + for layer_fcn in enc_lambda_layers: + lambda_name = layer_fcn.__name__ + name = '%s_ae_mu_%s' % (prefix, lambda_name) + last_tensor = KL.Lambda(layer_fcn, name=name)(last_tensor) + + if batch_norm is not None: + name = '%s_ae_mu_bn' % prefix + last_tensor = KL.BatchNormalization(axis=batch_norm, name=name)(last_tensor) + + # have a simple layer that does nothing to have a clear name before sampling + name = '%s_ae_mu' % prefix + last_tensor = KL.Lambda(lambda x: x, name=name)(last_tensor) + + # if doing variational AE, will need the sigma layer as well. + if do_vae: + mu_tensor = last_tensor + + # encoding layer + if ae_type == 'dense': + name = '%s_ae_sigma_enc_dense_%s' % (prefix, enc_size_str) + last_tensor = KL.Dense(enc_size[0], name=name)(pre_enc_layer) + + else: + if list(enc_size)[:-1] != list(input_shape)[:-1] and \ + all([f is not None for f in input_shape[:-1]]) and \ + all([f is not None for f in enc_size[:-1]]): + + assert len(enc_size) - 1 == 2, "Sorry, I have not yet implemented non-2D resizing..." + name = '%s_ae_sigma_enc_conv' % prefix + last_tensor = convL(enc_size[-1], conv_size, name=name, **conv_kwargs)(pre_enc_layer) + + name = '%s_ae_sigma_enc' % prefix + resize_fn = lambda x: tf.image.resize_bilinear(x, enc_size[:-1]) + last_tensor = KL.Lambda(resize_fn, name=name)(last_tensor) + + elif enc_size[-1] is None: # convolutional, but won't tell us bottleneck + name = '%s_ae_sigma_enc' % prefix + last_tensor = convL(pre_enc_layer.shape.as_list()[-1], conv_size, name=name, **conv_kwargs)( + pre_enc_layer) + # cannot use lambda, then mu and sigma will be same layer. + # last_tensor = KL.Lambda(lambda x: x, name=name)(pre_enc_layer) + + else: + name = '%s_ae_sigma_enc' % prefix + last_tensor = convL(enc_size[-1], conv_size, name=name, **conv_kwargs)(pre_enc_layer) + + # encoding clean-up layers + for layer_fcn in enc_lambda_layers: + lambda_name = layer_fcn.__name__ + name = '%s_ae_sigma_%s' % (prefix, lambda_name) + last_tensor = KL.Lambda(layer_fcn, name=name)(last_tensor) + + if batch_norm is not None: + name = '%s_ae_sigma_bn' % prefix + last_tensor = KL.BatchNormalization(axis=batch_norm, name=name)(last_tensor) + + # have a simple layer that does nothing to have a clear name before sampling + name = '%s_ae_sigma' % prefix + last_tensor = KL.Lambda(lambda x: x, name=name)(last_tensor) + + logvar_tensor = last_tensor + + # VAE sampling + sampler = _VAESample().sample_z + + name = '%s_ae_sample' % prefix + last_tensor = KL.Lambda(sampler, name=name)([mu_tensor, logvar_tensor]) + + if include_mu_shift_layer: + # shift + name = '%s_ae_sample_shift' % prefix + last_tensor = layers.LocalBias(name=name)(last_tensor) + + # decoding layer + if ae_type == 'dense': + name = '%s_ae_%s_dec_flat_%s' % (prefix, ae_type, enc_size_str) + last_tensor = KL.Dense(np.prod(input_shape), name=name)(last_tensor) + + # unflatten if dense method + if len(input_shape) > 1: + name = '%s_ae_%s_dec' % (prefix, ae_type) + last_tensor = KL.Reshape(input_shape, name=name)(last_tensor) + + else: + + if list(enc_size)[:-1] != list(input_shape)[:-1] and \ + all([f is not None for f in input_shape[:-1]]) and \ + all([f is not None for f in enc_size[:-1]]): + name = '%s_ae_mu_dec' % prefix + zf = [last_tensor.shape.as_list()[1:-1][f] / enc_size[:-1][f] for f in range(len(enc_size) - 1)] + last_tensor = layers.Resize(zoom_factor=zf, name=name)(last_tensor) + + name = '%s_ae_%s_dec' % (prefix, ae_type) + last_tensor = convL(input_nb_feats, conv_size, name=name, **conv_kwargs)(last_tensor) + + if batch_norm is not None: + name = '%s_bn_ae_%s_dec' % (prefix, ae_type) + last_tensor = KL.BatchNormalization(axis=batch_norm, name=name)(last_tensor) + + # create the model and return + model = Model(inputs=input_tensor, outputs=[last_tensor], name=model_name) + return model + + +############################################################################### +# Helper function +############################################################################### + +class _VAESample: + def __init__(self): + pass + + def sample_z(self, args): + mu, log_var = args + shape = K.shape(mu) + eps = K.random_normal(shape=shape, mean=0., stddev=1.) + return mu + K.exp(log_var / 2) * eps