--- a +++ b/SynthSeg/training_denoiser.py @@ -0,0 +1,256 @@ +""" +If you use this code, please cite one of the SynthSeg papers: +https://github.com/BBillot/SynthSeg/blob/master/bibtex.bib + +Copyright 2020 Benjamin Billot + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License at +https://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software distributed under the License is +distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +implied. See the License for the specific language governing permissions and limitations under the +License. +""" + + +# python imports +import os +import numpy as np +import tensorflow as tf +from keras import models +from keras import layers as KL + +# project imports +from SynthSeg import metrics_model as metrics +from SynthSeg.training import train_model +from SynthSeg.labels_to_image_model import get_shapes +from SynthSeg.training_supervised import build_model_inputs + +# third-party imports +from ext.lab2im import utils, layers +from ext.neuron import models as nrn_models + + +def training(list_paths_input_labels, + list_paths_target_labels, + model_dir, + input_segmentation_labels, + target_segmentation_labels=None, + subjects_prob=None, + batchsize=1, + output_shape=None, + scaling_bounds=.2, + rotation_bounds=15, + shearing_bounds=.012, + nonlin_std=3., + nonlin_scale=.04, + prob_erosion_dilation=0.3, + min_erosion_dilation=4, + max_erosion_dilation=5, + n_levels=5, + nb_conv_per_level=2, + conv_size=5, + unet_feat_count=16, + feat_multiplier=2, + activation='elu', + skip_n_concatenations=2, + lr=1e-4, + wl2_epochs=1, + dice_epochs=50, + steps_per_epoch=10000, + checkpoint=None): + """ + + This function trains a UNet to segment MRI images with synthetic scans generated by sampling a GMM conditioned on + label maps. We regroup the parameters in four categories: General, Augmentation, Architecture, Training. + + # IMPORTANT !!! + # Each time we provide a parameter with separate values for each axis (e.g. with a numpy array or a sequence), + # these values refer to the RAS axes. + + :param list_paths_input_labels: list of all the paths of the input label maps. These correspond to "noisy" + segmentations that the denoiser will be trained to correct. + :param list_paths_target_labels: list of all the paths of the output label maps. Must have the same order as + list_paths_input_labels. These are the target label maps that the network will learn to produce given the "noisy" + input label maps. + :param model_dir: path of a directory where the models will be saved during training. + :param input_segmentation_labels: list of all the label values present in the input label maps. + :param target_segmentation_labels: list of all the label values present in the output label maps. By default (None) + this will be taken to be the same as input_segmentation_labels. + + # ----------------------------------------------- General parameters ----------------------------------------------- + # label maps parameters + :param subjects_prob: (optional) relative order of importance (doesn't have to be probabilistic), with which to pick + the provided label maps at each minibatch. Can be a sequence, a 1D numpy array, or the path to such an array, and it + must be as long as path_label_maps. By default, all label maps are chosen with the same importance. + + # output-related parameters + :param batchsize: (optional) number of images to generate per mini-batch. Default is 1. + :param output_shape: (optional) desired shape of the output image, obtained by randomly cropping the generated image + Can be an integer (same size in all dimensions), a sequence, a 1d numpy array, or the path to a 1d numpy array. + Default is None, where no cropping is performed. + + # --------------------------------------------- Augmentation parameters -------------------------------------------- + # spatial deformation parameters + :param scaling_bounds: (optional) if apply_linear_trans is True, the scaling factor for each dimension is + sampled from a uniform distribution of predefined bounds. Can either be: + 1) a number, in which case the scaling factor is independently sampled from the uniform distribution of bounds + (1-scaling_bounds, 1+scaling_bounds) for each dimension. + 2) the path to a numpy array of shape (2, n_dims), in which case the scaling factor in dimension i is sampled from + the uniform distribution of bounds (scaling_bounds[0, i], scaling_bounds[1, i]) for the i-th dimension. + 3) False, in which case scaling is completely turned off. + Default is scaling_bounds = 0.2 (case 1) + :param rotation_bounds: (optional) same as scaling bounds but for the rotation angle, except that for case 1 the + bounds are centred on 0 rather than 1, i.e. (0+rotation_bounds[i], 0-rotation_bounds[i]). + Default is rotation_bounds = 15. + :param shearing_bounds: (optional) same as scaling bounds. Default is shearing_bounds = 0.012. + :param nonlin_std: (optional) Standard deviation of the normal distribution from which we sample the first + tensor for synthesising the deformation field. Set to 0 to completely deactivate elastic deformation. + :param nonlin_scale: (optional) Ratio between the size of the input label maps and the size of the sampled + tensor for synthesising the elastic deformation field. + + # degradation of the input labels + :param prob_erosion_dilation: (optional) probability with which to degrade the input label maps with erosion or + dilation. If 0, then no erosion/dilation is applied to the label maps given as inputs to the network. + :param min_erosion_dilation: (optional) when prob_erosion_dilation is not zero, erosion and dilation of random + coefficients are applied. Set the minimum erosion/dilation coefficient here. + :param max_erosion_dilation: (optional) Set the maximum erosion/dilation coefficient here. + + # ------------------------------------------ UNet architecture parameters ------------------------------------------ + :param n_levels: (optional) number of level for the Unet. Default is 5. + :param nb_conv_per_level: (optional) number of convolutional layers per level. Default is 2. + :param conv_size: (optional) size of the convolution kernels. Default is 2. + :param unet_feat_count: (optional) number of feature for the first layer of the UNet. Default is 24. + :param feat_multiplier: (optional) multiply the number of feature by this number at each new level. Default is 2. + :param activation: (optional) activation function. Can be 'elu', 'relu'. + :param skip_n_concatenations: (optional) number of levels for which to remove the traditional skip connections of + the UNet architecture. default is zero, which corresponds to the classic UNet architecture. Example: + If skip_n_concatenations = 2, then we will remove the concatenation link between the two top levels of the UNet. + + # ----------------------------------------------- Training parameters ---------------------------------------------- + :param lr: (optional) learning rate for the training. Default is 1e-4 + :param wl2_epochs: (optional) number of epochs for which the network (except the soft-max layer) is trained with L2 + norm loss function. Default is 1. + :param dice_epochs: (optional) number of epochs with the soft Dice loss function. Default is 50. + :param steps_per_epoch: (optional) number of steps per epoch. Default is 10000. Since no online validation is + possible, this is equivalent to the frequency at which the models are saved. + :param checkpoint: (optional) path of an already saved model to load before starting the training. + """ + + # check epochs + assert (wl2_epochs > 0) | (dice_epochs > 0), \ + 'either wl2_epochs or dice_epochs must be positive, had {0} and {1}'.format(wl2_epochs, dice_epochs) + + # prepare data files + input_label_list, _ = utils.get_list_labels(label_list=input_segmentation_labels) + if target_segmentation_labels is None: + target_label_list = input_label_list + else: + target_label_list, _ = utils.get_list_labels(label_list=target_segmentation_labels) + n_labels = np.size(target_label_list) + + # create augmentation model + labels_shape, _, _, _, _, _ = utils.get_volume_info(list_paths_input_labels[0], aff_ref=np.eye(4)) + augmentation_model = build_augmentation_model(labels_shape, + input_label_list, + crop_shape=output_shape, + output_div_by_n=2 ** n_levels, + scaling_bounds=scaling_bounds, + rotation_bounds=rotation_bounds, + shearing_bounds=shearing_bounds, + nonlin_std=nonlin_std, + nonlin_scale=nonlin_scale, + prob_erosion_dilation=prob_erosion_dilation, + min_erosion_dilation=min_erosion_dilation, + max_erosion_dilation=max_erosion_dilation) + unet_input_shape = augmentation_model.output[0].get_shape().as_list()[1:] + + # prepare the segmentation model + l2l_model = nrn_models.unet(input_model=augmentation_model, + input_shape=unet_input_shape, + nb_labels=n_labels, + nb_levels=n_levels, + nb_conv_per_level=nb_conv_per_level, + conv_size=conv_size, + nb_features=unet_feat_count, + feat_mult=feat_multiplier, + activation=activation, + batch_norm=-1, + skip_n_concatenations=skip_n_concatenations, + name='l2l') + + # input generator + model_inputs = build_model_inputs(path_inputs=list_paths_input_labels, + path_outputs=list_paths_target_labels, + batchsize=batchsize, + subjects_prob=subjects_prob, + dtype_input='int32') + input_generator = utils.build_training_generator(model_inputs, batchsize) + + # pre-training with weighted L2, input is fit to the softmax rather than the probabilities + if wl2_epochs > 0: + wl2_model = models.Model(l2l_model.inputs, [l2l_model.get_layer('l2l_likelihood').output]) + wl2_model = metrics.metrics_model(wl2_model, target_label_list, 'wl2') + train_model(wl2_model, input_generator, lr, wl2_epochs, steps_per_epoch, model_dir, 'wl2', checkpoint) + checkpoint = os.path.join(model_dir, 'wl2_%03d.h5' % wl2_epochs) + + # fine-tuning with dice metric + dice_model = metrics.metrics_model(l2l_model, target_label_list, 'dice') + train_model(dice_model, input_generator, lr, dice_epochs, steps_per_epoch, model_dir, 'dice', checkpoint) + + +def build_augmentation_model(labels_shape, + segmentation_labels, + crop_shape=None, + output_div_by_n=None, + scaling_bounds=0.15, + rotation_bounds=15, + shearing_bounds=0.012, + translation_bounds=False, + nonlin_std=3., + nonlin_scale=.0625, + prob_erosion_dilation=0.3, + min_erosion_dilation=4, + max_erosion_dilation=7): + + # reformat resolutions and get shapes + labels_shape = utils.reformat_to_list(labels_shape) + n_dims, _ = utils.get_dims(labels_shape) + n_labels = len(segmentation_labels) + + # get shapes + crop_shape, _ = get_shapes(labels_shape, crop_shape, np.array([1]*n_dims), np.array([1]*n_dims), output_div_by_n) + + # define model inputs + net_input = KL.Input(shape=labels_shape + [1], name='l2l_noisy_labels_input', dtype='int32') + target_input = KL.Input(shape=labels_shape + [1], name='l2l_target_input', dtype='int32') + + # deform labels + noisy_labels, target = layers.RandomSpatialDeformation(scaling_bounds=scaling_bounds, + rotation_bounds=rotation_bounds, + shearing_bounds=shearing_bounds, + translation_bounds=translation_bounds, + nonlin_std=nonlin_std, + nonlin_scale=nonlin_scale, + inter_method='nearest')([net_input, target_input]) + + # cropping + if crop_shape != labels_shape: + noisy_labels, target = layers.RandomCrop(crop_shape)([noisy_labels, target]) + + # random erosion + if prob_erosion_dilation > 0: + noisy_labels = layers.RandomDilationErosion(min_erosion_dilation, + max_erosion_dilation, + prob=prob_erosion_dilation)(noisy_labels) + + # convert input labels (i.e. noisy_labels) to [0, ... N-1] and make them one-hot + noisy_labels = layers.ConvertLabels(np.unique(segmentation_labels))(noisy_labels) + target = KL.Lambda(lambda x: tf.cast(x[..., 0], 'int32'), name='labels_out')(target) + noisy_labels = KL.Lambda(lambda x: tf.one_hot(x[0][..., 0], depth=n_labels), + name='noisy_labels_out')([noisy_labels, target]) + + # build model and return + brain_model = models.Model(inputs=[net_input, target_input], outputs=[noisy_labels, target]) + return brain_model