--- a
+++ b/SynthSeg/training_denoiser.py
@@ -0,0 +1,256 @@
+"""
+If you use this code, please cite one of the SynthSeg papers:
+https://github.com/BBillot/SynthSeg/blob/master/bibtex.bib
+
+Copyright 2020 Benjamin Billot
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
+compliance with the License. You may obtain a copy of the License at
+https://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software distributed under the License is
+distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+implied. See the License for the specific language governing permissions and limitations under the
+License.
+"""
+
+
+# python imports
+import os
+import numpy as np
+import tensorflow as tf
+from keras import models
+from keras import layers as KL
+
+# project imports
+from SynthSeg import metrics_model as metrics
+from SynthSeg.training import train_model
+from SynthSeg.labels_to_image_model import get_shapes
+from SynthSeg.training_supervised import build_model_inputs
+
+# third-party imports
+from ext.lab2im import utils, layers
+from ext.neuron import models as nrn_models
+
+
+def training(list_paths_input_labels,
+             list_paths_target_labels,
+             model_dir,
+             input_segmentation_labels,
+             target_segmentation_labels=None,
+             subjects_prob=None,
+             batchsize=1,
+             output_shape=None,
+             scaling_bounds=.2,
+             rotation_bounds=15,
+             shearing_bounds=.012,
+             nonlin_std=3.,
+             nonlin_scale=.04,
+             prob_erosion_dilation=0.3,
+             min_erosion_dilation=4,
+             max_erosion_dilation=5,
+             n_levels=5,
+             nb_conv_per_level=2,
+             conv_size=5,
+             unet_feat_count=16,
+             feat_multiplier=2,
+             activation='elu',
+             skip_n_concatenations=2,
+             lr=1e-4,
+             wl2_epochs=1,
+             dice_epochs=50,
+             steps_per_epoch=10000,
+             checkpoint=None):
+    """
+
+    This function trains a UNet to denoise label maps, i.e. to map "noisy" input segmentations to corrected target
+    segmentations. We regroup the parameters in four categories: General, Augmentation, Architecture, Training.
+
+    # IMPORTANT !!!
+    # Each time we provide a parameter with separate values for each axis (e.g. with a numpy array or a sequence),
+    # these values refer to the RAS axes.
+
+    :param list_paths_input_labels: list of all the paths of the input label maps. These correspond to "noisy"
+    segmentations that the denoiser will be trained to correct.
+    :param list_paths_target_labels: list of all the paths of the output label maps. Must have the same order as
+    list_paths_input_labels. These are the target label maps that the network will learn to produce given the "noisy"
+    input label maps.
+    :param model_dir: path of a directory where the models will be saved during training.
+    :param input_segmentation_labels: list of all the label values present in the input label maps.
+    :param target_segmentation_labels: list of all the label values present in the output label maps. By default (None)
+    this will be taken to be the same as input_segmentation_labels.
+
+    # ----------------------------------------------- General parameters -----------------------------------------------
+    # label maps parameters
+    :param subjects_prob: (optional) relative order of importance (doesn't have to be probabilistic), with which to pick
+    the provided label maps at each minibatch. Can be a sequence, a 1D numpy array, or the path to such an array, and it
+    must be as long as list_paths_input_labels. By default, all label maps are chosen with the same importance.
+
+    # output-related parameters
+    :param batchsize: (optional) number of images to generate per mini-batch. Default is 1.
+    :param output_shape: (optional) desired shape of the label maps given to the network, obtained by randomly cropping
+    the deformed label maps. Can be an integer (same size in all dimensions), a sequence, a 1d numpy array, or the path
+    to a 1d numpy array. Default is None. The crop shape is computed by get_shapes with output_div_by_n=2 ** n_levels.
+
+    # --------------------------------------------- Augmentation parameters --------------------------------------------
+    # spatial deformation parameters
+    :param scaling_bounds: (optional) the scaling factor for each dimension is sampled from a uniform distribution of
+    predefined bounds. Can either be:
+    1) a number, in which case the scaling factor is independently sampled from the uniform distribution of bounds
+    (1-scaling_bounds, 1+scaling_bounds) for each dimension.
+    2) the path to a numpy array of shape (2, n_dims), in which case the scaling factor in dimension i is sampled from
+    the uniform distribution of bounds (scaling_bounds[0, i], scaling_bounds[1, i]) for the i-th dimension.
+    3) False, in which case scaling is completely turned off.
+    Default is scaling_bounds = 0.2 (case 1)
+    :param rotation_bounds: (optional) same as scaling bounds but for the rotation angle, except that for case 1 the
+    bounds are centred on 0 rather than 1, i.e. (0-rotation_bounds, 0+rotation_bounds).
+    Default is rotation_bounds = 15.
+    :param shearing_bounds: (optional) same as scaling bounds. Default is shearing_bounds = 0.012.
+    :param nonlin_std: (optional) Standard deviation of the normal distribution from which we sample the first
+    tensor for synthesising the deformation field. Set to 0 to deactivate elastic deformation. Default is 3.
+    :param nonlin_scale: (optional) Ratio between the size of the input label maps and the size of the sampled
+    tensor for synthesising the elastic deformation field. Default is 0.04.
+
+    # degradation of the input labels
+    :param prob_erosion_dilation: (optional) probability with which to degrade the input label maps with erosion or
+    dilation. If 0, no erosion/dilation is applied. Only the input label maps are degraded. Default is 0.3.
+    :param min_erosion_dilation: (optional) when prob_erosion_dilation is not zero, erosion and dilation of random
+    coefficients are applied. Set the minimum erosion/dilation coefficient here. Default is 4.
+    :param max_erosion_dilation: (optional) Set the maximum erosion/dilation coefficient here. Default is 5.
+
+    # ------------------------------------------ UNet architecture parameters ------------------------------------------
+    :param n_levels: (optional) number of levels for the UNet. Default is 5.
+    :param nb_conv_per_level: (optional) number of convolutional layers per level. Default is 2.
+    :param conv_size: (optional) size of the convolution kernels. Default is 5.
+    :param unet_feat_count: (optional) number of features for the first layer of the UNet. Default is 16.
+    :param feat_multiplier: (optional) multiply the number of features by this number at each new level. Default is 2.
+    :param activation: (optional) activation function. Can be 'elu', 'relu'. Default is 'elu'.
+    :param skip_n_concatenations: (optional) number of levels for which to remove the traditional skip connections of
+    the UNet architecture. Default is 2. Setting it to 0 gives the classic UNet architecture. Example:
+    If skip_n_concatenations = 2, then we will remove the concatenation link between the two top levels of the UNet.
+
+    # ----------------------------------------------- Training parameters ----------------------------------------------
+    :param lr: (optional) learning rate for the training. Default is 1e-4.
+    :param wl2_epochs: (optional) number of epochs for which the network (except the soft-max layer) is trained with a
+    weighted L2 loss. Set to 0 to skip it (wl2_epochs and dice_epochs cannot both be 0). Default is 1.
+    :param dice_epochs: (optional) number of epochs with the soft Dice loss function. Set to 0 to skip. Default is 50.
+    :param steps_per_epoch: (optional) number of steps per epoch. Default is 10000. Since no online validation is
+    possible, this is equivalent to the frequency at which the models are saved.
+    :param checkpoint: (optional) path of an already saved model to load before starting the training.
+    """
+
+    # check epochs
+    assert (wl2_epochs > 0) | (dice_epochs > 0), \
+        'either wl2_epochs or dice_epochs must be positive, had {0} and {1}'.format(wl2_epochs, dice_epochs)
+
+    # prepare data files
+    input_label_list, _ = utils.get_list_labels(label_list=input_segmentation_labels)
+    if target_segmentation_labels is None:
+        target_label_list = input_label_list
+    else:
+        target_label_list, _ = utils.get_list_labels(label_list=target_segmentation_labels)
+    n_labels = np.size(target_label_list)
+
+    # create augmentation model
+    labels_shape, _, _, _, _, _ = utils.get_volume_info(list_paths_input_labels[0], aff_ref=np.eye(4))
+    augmentation_model = build_augmentation_model(labels_shape,
+                                                  input_label_list,
+                                                  crop_shape=output_shape,
+                                                  output_div_by_n=2 ** n_levels,
+                                                  scaling_bounds=scaling_bounds,
+                                                  rotation_bounds=rotation_bounds,
+                                                  shearing_bounds=shearing_bounds,
+                                                  nonlin_std=nonlin_std,
+                                                  nonlin_scale=nonlin_scale,
+                                                  prob_erosion_dilation=prob_erosion_dilation,
+                                                  min_erosion_dilation=min_erosion_dilation,
+                                                  max_erosion_dilation=max_erosion_dilation)
+    unet_input_shape = augmentation_model.output[0].get_shape().as_list()[1:]
+
+    # prepare the segmentation model
+    l2l_model = nrn_models.unet(input_model=augmentation_model,
+                                input_shape=unet_input_shape,
+                                nb_labels=n_labels,
+                                nb_levels=n_levels,
+                                nb_conv_per_level=nb_conv_per_level,
+                                conv_size=conv_size,
+                                nb_features=unet_feat_count,
+                                feat_mult=feat_multiplier,
+                                activation=activation,
+                                batch_norm=-1,
+                                skip_n_concatenations=skip_n_concatenations,
+                                name='l2l')
+
+    # input generator
+    model_inputs = build_model_inputs(path_inputs=list_paths_input_labels,
+                                      path_outputs=list_paths_target_labels,
+                                      batchsize=batchsize,
+                                      subjects_prob=subjects_prob,
+                                      dtype_input='int32')
+    input_generator = utils.build_training_generator(model_inputs, batchsize)
+
+    # pre-training with weighted L2, input is fit to the softmax rather than the probabilities
+    if wl2_epochs > 0:
+        wl2_model = models.Model(l2l_model.inputs, [l2l_model.get_layer('l2l_likelihood').output])
+        wl2_model = metrics.metrics_model(wl2_model, target_label_list, 'wl2')
+        train_model(wl2_model, input_generator, lr, wl2_epochs, steps_per_epoch, model_dir, 'wl2', checkpoint)
+        checkpoint = os.path.join(model_dir, 'wl2_%03d.h5' % wl2_epochs)
+
+    if dice_epochs > 0:  # fine-tuning with dice metric (dice_epochs may be 0 when wl2_epochs > 0)
+        dice_model = metrics.metrics_model(l2l_model, target_label_list, 'dice')
+        train_model(dice_model, input_generator, lr, dice_epochs, steps_per_epoch, model_dir, 'dice', checkpoint)
+
+
+def build_augmentation_model(labels_shape, segmentation_labels, crop_shape=None, output_div_by_n=None,
+                             scaling_bounds=0.15, rotation_bounds=15, shearing_bounds=0.012, translation_bounds=False,
+                             nonlin_std=3., nonlin_scale=.0625, prob_erosion_dilation=0.3, min_erosion_dilation=4,
+                             max_erosion_dilation=7):
+    """Build the keras model that augments (noisy input, target) label map pairs for denoiser training.
+
+    Both label maps go through the same RandomSpatialDeformation layer call (nearest interpolation) and, when the
+    crop shape differs from labels_shape, the same RandomCrop layer call. Only the noisy input is then eroded/dilated
+    (if prob_erosion_dilation > 0), mapped to [0, ..., N-1] with N the number of unique segmentation_labels, and
+    one-hot encoded. The target is only cast to int32 and stripped of its channel dimension.
+    :param labels_shape: spatial shape of the label maps (the model inputs get an extra trailing channel of size 1).
+    :param segmentation_labels: label values present in the noisy input label maps.
+    :param crop_shape: (optional) desired output shape, passed with output_div_by_n to get_shapes (unit resolutions).
+    All remaining parameters are forwarded to RandomSpatialDeformation and RandomDilationErosion.
+    :return: keras model with int32 inputs [noisy labels, target labels] and outputs [one-hot noisy labels, target].
+    """
+
+    # reformat shape and get number of dimensions / labels
+    labels_shape = utils.reformat_to_list(labels_shape)
+    n_dims, _ = utils.get_dims(labels_shape)
+    unique_labels = np.unique(segmentation_labels)
+    n_labels = len(unique_labels)  # one-hot depth must match the number of values mapped by ConvertLabels below
+
+    # get shapes
+    crop_shape, _ = get_shapes(labels_shape, crop_shape, np.array([1]*n_dims), np.array([1]*n_dims), output_div_by_n)
+
+    # define model inputs
+    net_input = KL.Input(shape=labels_shape + [1], name='l2l_noisy_labels_input', dtype='int32')
+    target_input = KL.Input(shape=labels_shape + [1], name='l2l_target_input', dtype='int32')
+
+    # deform labels (both maps go through the same layer call)
+    deform = layers.RandomSpatialDeformation(scaling_bounds=scaling_bounds, rotation_bounds=rotation_bounds,
+                                             shearing_bounds=shearing_bounds, translation_bounds=translation_bounds,
+                                             nonlin_std=nonlin_std, nonlin_scale=nonlin_scale, inter_method='nearest')
+    noisy_labels, target = deform([net_input, target_input])
+
+    # cropping
+    if crop_shape != labels_shape:
+        noisy_labels, target = layers.RandomCrop(crop_shape)([noisy_labels, target])
+
+    # random erosion/dilation, applied to the noisy input only
+    if prob_erosion_dilation > 0:
+        noisy_labels = layers.RandomDilationErosion(min_erosion_dilation, max_erosion_dilation,
+                                                    prob=prob_erosion_dilation)(noisy_labels)
+
+    # convert noisy labels to [0, ... N-1] and make them one-hot (target is fed to the Lambda but unused there)
+    noisy_labels = layers.ConvertLabels(unique_labels)(noisy_labels)
+    target = KL.Lambda(lambda x: tf.cast(x[..., 0], 'int32'), name='labels_out')(target)
+    noisy_labels = KL.Lambda(lambda x: tf.one_hot(x[0][..., 0], depth=n_labels),
+                             name='noisy_labels_out')([noisy_labels, target])
+
+    # build model and return
+    augmentation_model = models.Model(inputs=[net_input, target_input], outputs=[noisy_labels, target])
+    return augmentation_model