--- a
+++ b/ext/lab2im/lab2im_model.py
@@ -0,0 +1,174 @@
+"""
+If you use this code, please cite the first SynthSeg paper:
+https://github.com/BBillot/lab2im/blob/master/bibtex.bib
+
+Copyright 2020 Benjamin Billot
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
+compliance with the License. You may obtain a copy of the License at
+https://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software distributed under the License is
+distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+implied. See the License for the specific language governing permissions and limitations under the
+License.
+"""
+
+
+# python imports
+import numpy as np
+import keras.layers as KL
+from keras.models import Model
+
+# project imports
+from ext.lab2im import utils
+from ext.lab2im import layers
+from ext.lab2im.edit_tensors import resample_tensor, blurring_sigma_for_downsampling
+
+
+def lab2im_model(labels_shape,
+                 n_channels,
+                 generation_labels,
+                 output_labels,
+                 atlas_res,
+                 target_res,
+                 output_shape=None,
+                 output_div_by_n=None,
+                 blur_range=1.15):
+    """
+    This function builds a keras/tensorflow model to generate images from provided label maps.
+    The images are generated by sampling a Gaussian Mixture Model (of given parameters), conditioned on the label map.
+    The model will take as inputs:
+        -a label map
+        -a vector containing the means of the Gaussian Mixture Model for each label,
+        -a vector containing the standard deviations of the Gaussian Mixture Model for each label,
+        -an array of size batch*(n_dims+1)*(n_dims+1) representing a linear transformation
+    The model returns:
+        -the generated image normalised between 0 and 1.
+        -the corresponding label map, with only the labels present in output_labels (the other are reset to zero).
+    :param labels_shape: shape of the input label maps. Can be a sequence or a 1d numpy array.
+    :param n_channels: number of channels to be synthetised.
+    :param generation_labels: list of all possible label values in the input label maps.
+    Can be a sequence or a 1d numpy array.
+    :param output_labels: list of the same length as generation_labels to indicate which values to use in the label maps
+    returned by this model, i.e. all occurrences of generation_labels[i] in the input label maps will be converted to
+    output_labels[i] in the returned label maps. Examples:
+    Set output_labels[i] to zero if you wish to erase the value generation_labels[i] from the returned label maps.
+    Set output_labels[i]=generation_labels[i] if you wish to keep the value generation_labels[i] in the returned maps.
+    Can be a list or a 1d numpy array. By default output_labels is equal to generation_labels.
+    :param atlas_res: resolution of the input label maps.
+    Can be a number (isotropic resolution), a sequence, or a 1d numpy array.
+    :param target_res: target resolution of the generated images and corresponding label maps.
+    Can be a number (isotropic resolution), a sequence, or a 1d numpy array.
+    :param output_shape: (optional) desired shape of the output images.
+    If the atlas and target resolutions are the same, the output will be cropped to output_shape, and if the two
+    resolutions are different, the output will be resized with trilinear interpolation to output_shape.
+    Can be an integer (same size in all dimensions), a sequence, or a 1d numpy array.
+    :param output_div_by_n: (optional) forces the output shape to be divisible by this value. It overwrites output_shape
+    if necessary. Can be an integer (same size in all dimensions), a sequence, or a 1d numpy array.
+    :param blur_range: (optional) Randomise the standard deviation of the blurring kernels, (whether data_res is given
+    or not). At each mini_batch, the standard deviation of the blurring kernels are multiplied by a coefficient sampled
+    from a uniform distribution with bounds [1/blur_range, blur_range]. If None, no randomisation. Default is 1.15.
+    """
+
+    # reformat resolutions
+    labels_shape = utils.reformat_to_list(labels_shape)
+    n_dims, _ = utils.get_dims(labels_shape)
+    atlas_res = utils.reformat_to_n_channels_array(atlas_res, n_dims=n_dims)[0]
+    target_res = atlas_res if (target_res is None) else utils.reformat_to_n_channels_array(target_res, n_dims)[0]
+
+    # get shapes
+    crop_shape, output_shape = get_shapes(labels_shape, output_shape, atlas_res, target_res, output_div_by_n)
+
+    # define model inputs
+    labels_input = KL.Input(shape=labels_shape+[1], name='labels_input', dtype='int32')
+    means_input = KL.Input(shape=list(generation_labels.shape) + [n_channels], name='means_input')
+    stds_input = KL.Input(shape=list(generation_labels.shape) + [n_channels], name='stds_input')
+
+    # deform labels
+    labels = layers.RandomSpatialDeformation(inter_method='nearest')(labels_input)
+
+    # cropping
+    if crop_shape != labels_shape:
+        labels._keras_shape = tuple(labels.get_shape().as_list())
+        labels = layers.RandomCrop(crop_shape)(labels)
+
+    # build synthetic image
+    labels._keras_shape = tuple(labels.get_shape().as_list())
+    image = layers.SampleConditionalGMM(generation_labels)([labels, means_input, stds_input])
+
+    # apply bias field
+    image._keras_shape = tuple(image.get_shape().as_list())
+    image = layers.BiasFieldCorruption(.3, .025, same_bias_for_all_channels=False)(image)
+
+    # intensity augmentation
+    image._keras_shape = tuple(image.get_shape().as_list())
+    image = layers.IntensityAugmentation(clip=300, normalise=True, gamma_std=.2)(image)
+
+    # blur image
+    sigma = blurring_sigma_for_downsampling(atlas_res, target_res)
+    image._keras_shape = tuple(image.get_shape().as_list())
+    image = layers.GaussianBlur(sigma=sigma, random_blur_range=blur_range)(image)
+
+    # resample to target res
+    if crop_shape != output_shape:
+        image = resample_tensor(image, output_shape, interp_method='linear')
+        labels = resample_tensor(labels, output_shape, interp_method='nearest')
+
+    # reset unwanted labels to zero
+    labels = layers.ConvertLabels(generation_labels, dest_values=output_labels, name='labels_out')(labels)
+
+    # build model (dummy layer enables to keep the labels when plugging this model to other models)
+    image = KL.Lambda(lambda x: x[0], name='image_out')([image, labels])
+    brain_model = Model(inputs=[labels_input, means_input, stds_input], outputs=[image, labels])
+
+    return brain_model
+
+
+def get_shapes(labels_shape, output_shape, atlas_res, target_res, output_div_by_n):
+
+    n_dims = len(atlas_res)
+
+    # get resampling factor
+    if atlas_res.tolist() != target_res.tolist():
+        resample_factor = [atlas_res[i] / float(target_res[i]) for i in range(n_dims)]
+    else:
+        resample_factor = None
+
+    # output shape specified, need to get cropping shape, and resample shape if necessary
+    if output_shape is not None:
+        output_shape = utils.reformat_to_list(output_shape, length=n_dims, dtype='int')
+
+        # make sure that output shape is smaller or equal to label shape
+        if resample_factor is not None:
+            output_shape = [min(int(labels_shape[i] * resample_factor[i]), output_shape[i]) for i in range(n_dims)]
+        else:
+            output_shape = [min(labels_shape[i], output_shape[i]) for i in range(n_dims)]
+
+        # make sure output shape is divisible by output_div_by_n
+        if output_div_by_n is not None:
+            tmp_shape = [utils.find_closest_number_divisible_by_m(s, output_div_by_n)
+                         for s in output_shape]
+            if output_shape != tmp_shape:
+                print('output shape {0} not divisible by {1}, changed to {2}'.format(output_shape, output_div_by_n,
+                                                                                     tmp_shape))
+                output_shape = tmp_shape
+
+        # get cropping and resample shape
+        if resample_factor is not None:
+            cropping_shape = [int(np.around(output_shape[i]/resample_factor[i], 0)) for i in range(n_dims)]
+        else:
+            cropping_shape = output_shape
+
+    # no output shape specified, so no cropping unless label_shape is not divisible by output_div_by_n
+    else:
+        cropping_shape = labels_shape
+        if resample_factor is not None:
+            output_shape = [int(np.around(cropping_shape[i]*resample_factor[i], 0)) for i in range(n_dims)]
+        else:
+            output_shape = cropping_shape
+        # make sure output shape is divisible by output_div_by_n
+        if output_div_by_n is not None:
+            output_shape = [utils.find_closest_number_divisible_by_m(s, output_div_by_n, answer_type='closer')
+                            for s in output_shape]
+
+    return cropping_shape, output_shape