
SynthSeg/training_denoiser.py
"""
If you use this code, please cite one of the SynthSeg papers:
https://github.com/BBillot/SynthSeg/blob/master/bibtex.bib

Copyright 2020 Benjamin Billot

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
compliance with the License. You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is
distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied. See the License for the specific language governing permissions and limitations under the
License.
"""


# python imports
import os
import numpy as np
import tensorflow as tf
from keras import models
from keras import layers as KL

# project imports
from SynthSeg import metrics_model as metrics
from SynthSeg.training import train_model
from SynthSeg.labels_to_image_model import get_shapes
from SynthSeg.training_supervised import build_model_inputs

# third-party imports
from ext.lab2im import utils, layers
from ext.neuron import models as nrn_models


def training(list_paths_input_labels,
             list_paths_target_labels,
             model_dir,
             input_segmentation_labels,
             target_segmentation_labels=None,
             subjects_prob=None,
             batchsize=1,
             output_shape=None,
             scaling_bounds=.2,
             rotation_bounds=15,
             shearing_bounds=.012,
             nonlin_std=3.,
             nonlin_scale=.04,
             prob_erosion_dilation=0.3,
             min_erosion_dilation=4,
             max_erosion_dilation=5,
             n_levels=5,
             nb_conv_per_level=2,
             conv_size=5,
             unet_feat_count=16,
             feat_multiplier=2,
             activation='elu',
             skip_n_concatenations=2,
             lr=1e-4,
             wl2_epochs=1,
             dice_epochs=50,
             steps_per_epoch=10000,
             checkpoint=None):
    """
64
65
    This function trains a UNet to segment MRI images with synthetic scans generated by sampling a GMM conditioned on
66
    label maps. We regroup the parameters in four categories: General, Augmentation, Architecture, Training.
67
68
    # IMPORTANT !!!
69
    # Each time we provide a parameter with separate values for each axis (e.g. with a numpy array or a sequence),
70
    # these values refer to the RAS axes.
71
72
    :param list_paths_input_labels: list of all the paths of the input label maps. These correspond to "noisy"
73
    segmentations that the denoiser will be trained to correct.
74
    :param list_paths_target_labels: list of all the paths of the output label maps. Must have the same order as
75
    list_paths_input_labels. These are the target label maps that the network will learn to produce given the "noisy"
76
    input label maps.
77
    :param model_dir: path of a directory where the models will be saved during training.
78
    :param input_segmentation_labels: list of all the label values present in the input label maps.
79
    :param target_segmentation_labels: list of all the label values present in the output label maps. By default (None)
80
    this will be taken to be the same as input_segmentation_labels.
81
82
    # ----------------------------------------------- General parameters -----------------------------------------------
83
    # label maps parameters
84
    :param subjects_prob: (optional) relative order of importance (doesn't have to be probabilistic), with which to pick
85
    the provided label maps at each minibatch. Can be a sequence, a 1D numpy array, or the path to such an array, and it
86
    must be as long as path_label_maps. By default, all label maps are chosen with the same importance.
87
88
    # output-related parameters
89
    :param batchsize: (optional) number of images to generate per mini-batch. Default is 1.
90
    :param output_shape: (optional) desired shape of the output image, obtained by randomly cropping the generated image
91
    Can be an integer (same size in all dimensions), a sequence, a 1d numpy array, or the path to a 1d numpy array.
92
    Default is None, where no cropping is performed.
93
94
    # --------------------------------------------- Augmentation parameters --------------------------------------------
95
    # spatial deformation parameters
96
    :param scaling_bounds: (optional) if apply_linear_trans is True, the scaling factor for each dimension is
97
    sampled from a uniform distribution of predefined bounds. Can either be:
98
    1) a number, in which case the scaling factor is independently sampled from the uniform distribution of bounds
99
    (1-scaling_bounds, 1+scaling_bounds) for each dimension.
100
    2) the path to a numpy array of shape (2, n_dims), in which case the scaling factor in dimension i is sampled from
101
    the uniform distribution of bounds (scaling_bounds[0, i], scaling_bounds[1, i]) for the i-th dimension.
102
    3) False, in which case scaling is completely turned off.
103
    Default is scaling_bounds = 0.2 (case 1)
104
    :param rotation_bounds: (optional) same as scaling bounds but for the rotation angle, except that for case 1 the
105
    bounds are centred on 0 rather than 1, i.e. (0+rotation_bounds[i], 0-rotation_bounds[i]).
106
    Default is rotation_bounds = 15.
107
    :param shearing_bounds: (optional) same as scaling bounds. Default is shearing_bounds = 0.012.
108
    :param nonlin_std: (optional) Standard deviation of the normal distribution from which we sample the first
109
    tensor for synthesising the deformation field. Set to 0 to completely deactivate elastic deformation.
110
    :param nonlin_scale: (optional) Ratio between the size of the input label maps and the size of the sampled
111
    tensor for synthesising the elastic deformation field.
112
    
113
    # degradation of the input labels
114
    :param prob_erosion_dilation: (optional) probability with which to degrade the input label maps with erosion or 
115
    dilation. If 0, then no erosion/dilation is applied to the label maps given as inputs to the network.
116
    :param min_erosion_dilation: (optional) when prob_erosion_dilation is not zero, erosion and dilation of random
117
    coefficients are applied. Set the minimum erosion/dilation coefficient here.
118
    :param max_erosion_dilation: (optional) Set the maximum erosion/dilation coefficient here.
119
120
    # ------------------------------------------ UNet architecture parameters ------------------------------------------
121
    :param n_levels: (optional) number of level for the Unet. Default is 5.
122
    :param nb_conv_per_level: (optional) number of convolutional layers per level. Default is 2.
123
    :param conv_size: (optional) size of the convolution kernels. Default is 2.
124
    :param unet_feat_count: (optional) number of feature for the first layer of the UNet. Default is 24.
125
    :param feat_multiplier: (optional) multiply the number of feature by this number at each new level. Default is 2.
126
    :param activation: (optional) activation function. Can be 'elu', 'relu'.
127
    :param skip_n_concatenations: (optional) number of levels for which to remove the traditional skip connections of
128
    the UNet architecture. default is zero, which corresponds to the classic UNet architecture. Example:
129
    If skip_n_concatenations = 2, then we will remove the concatenation link between the two top levels of the UNet.
130
131
    # ----------------------------------------------- Training parameters ----------------------------------------------
132
    :param lr: (optional) learning rate for the training. Default is 1e-4
133
    :param wl2_epochs: (optional) number of epochs for which the network (except the soft-max layer) is trained with L2
134
    norm loss function. Default is 1.
135
    :param dice_epochs: (optional) number of epochs with the soft Dice loss function. Default is 50.
136
    :param steps_per_epoch: (optional) number of steps per epoch. Default is 10000. Since no online validation is
137
    possible, this is equivalent to the frequency at which the models are saved.
138
    :param checkpoint: (optional) path of an already saved model to load before starting the training.
139
    """

    # check epochs
    assert (wl2_epochs > 0) | (dice_epochs > 0), \
        'either wl2_epochs or dice_epochs must be positive, had {0} and {1}'.format(wl2_epochs, dice_epochs)

    # prepare data files
    input_label_list, _ = utils.get_list_labels(label_list=input_segmentation_labels)
    if target_segmentation_labels is None:
        target_label_list = input_label_list
    else:
        target_label_list, _ = utils.get_list_labels(label_list=target_segmentation_labels)
    n_labels = np.size(target_label_list)

    # create augmentation model
    labels_shape, _, _, _, _, _ = utils.get_volume_info(list_paths_input_labels[0], aff_ref=np.eye(4))
    augmentation_model = build_augmentation_model(labels_shape,
                                                  input_label_list,
                                                  crop_shape=output_shape,
                                                  output_div_by_n=2 ** n_levels,
                                                  scaling_bounds=scaling_bounds,
                                                  rotation_bounds=rotation_bounds,
                                                  shearing_bounds=shearing_bounds,
                                                  nonlin_std=nonlin_std,
                                                  nonlin_scale=nonlin_scale,
                                                  prob_erosion_dilation=prob_erosion_dilation,
                                                  min_erosion_dilation=min_erosion_dilation,
                                                  max_erosion_dilation=max_erosion_dilation)
    unet_input_shape = augmentation_model.output[0].get_shape().as_list()[1:]

    # prepare the segmentation model
    l2l_model = nrn_models.unet(input_model=augmentation_model,
                                input_shape=unet_input_shape,
                                nb_labels=n_labels,
                                nb_levels=n_levels,
                                nb_conv_per_level=nb_conv_per_level,
                                conv_size=conv_size,
                                nb_features=unet_feat_count,
                                feat_mult=feat_multiplier,
                                activation=activation,
                                batch_norm=-1,
                                skip_n_concatenations=skip_n_concatenations,
                                name='l2l')

    # input generator
    model_inputs = build_model_inputs(path_inputs=list_paths_input_labels,
                                      path_outputs=list_paths_target_labels,
                                      batchsize=batchsize,
                                      subjects_prob=subjects_prob,
                                      dtype_input='int32')
    input_generator = utils.build_training_generator(model_inputs, batchsize)

    # pre-training with weighted L2, input is fit to the softmax rather than the probabilities
    if wl2_epochs > 0:
        wl2_model = models.Model(l2l_model.inputs, [l2l_model.get_layer('l2l_likelihood').output])
        wl2_model = metrics.metrics_model(wl2_model, target_label_list, 'wl2')
        train_model(wl2_model, input_generator, lr, wl2_epochs, steps_per_epoch, model_dir, 'wl2', checkpoint)
        checkpoint = os.path.join(model_dir, 'wl2_%03d.h5' % wl2_epochs)

    # fine-tuning with dice metric
    dice_model = metrics.metrics_model(l2l_model, target_label_list, 'dice')
    train_model(dice_model, input_generator, lr, dice_epochs, steps_per_epoch, model_dir, 'dice', checkpoint)
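

# Example call (a minimal usage sketch, assuming the noisy/target label maps already exist on disk; all paths and
# label values below are hypothetical placeholders):
#
# training(list_paths_input_labels=['/data/noisy_labels/subject_001.nii.gz'],
#          list_paths_target_labels=['/data/target_labels/subject_001.nii.gz'],
#          model_dir='/data/models_denoiser',
#          input_segmentation_labels=[0, 2, 3, 4, 41, 42, 43],
#          batchsize=1,
#          wl2_epochs=1,
#          dice_epochs=50,
#          steps_per_epoch=1000)

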
def build_augmentation_model(labels_shape,
                             segmentation_labels,
                             crop_shape=None,
                             output_div_by_n=None,
                             scaling_bounds=0.15,
                             rotation_bounds=15,
                             shearing_bounds=0.012,
                             translation_bounds=False,
                             nonlin_std=3.,
                             nonlin_scale=.0625,
                             prob_erosion_dilation=0.3,
                             min_erosion_dilation=4,
                             max_erosion_dilation=7):

    # reformat resolutions and get shapes
    labels_shape = utils.reformat_to_list(labels_shape)
    n_dims, _ = utils.get_dims(labels_shape)
    n_labels = len(segmentation_labels)

    # get shapes
    crop_shape, _ = get_shapes(labels_shape, crop_shape, np.array([1]*n_dims), np.array([1]*n_dims), output_div_by_n)

    # define model inputs
    net_input = KL.Input(shape=labels_shape + [1], name='l2l_noisy_labels_input', dtype='int32')
    target_input = KL.Input(shape=labels_shape + [1], name='l2l_target_input', dtype='int32')

    # deform labels
    noisy_labels, target = layers.RandomSpatialDeformation(scaling_bounds=scaling_bounds,
                                                           rotation_bounds=rotation_bounds,
                                                           shearing_bounds=shearing_bounds,
                                                           translation_bounds=translation_bounds,
                                                           nonlin_std=nonlin_std,
                                                           nonlin_scale=nonlin_scale,
                                                           inter_method='nearest')([net_input, target_input])

    # cropping
    if crop_shape != labels_shape:
        noisy_labels, target = layers.RandomCrop(crop_shape)([noisy_labels, target])

    # random erosion
    if prob_erosion_dilation > 0:
        noisy_labels = layers.RandomDilationErosion(min_erosion_dilation,
                                                    max_erosion_dilation,
                                                    prob=prob_erosion_dilation)(noisy_labels)

    # convert input labels (i.e. noisy_labels) to [0, ... N-1] and make them one-hot
    noisy_labels = layers.ConvertLabels(np.unique(segmentation_labels))(noisy_labels)
    target = KL.Lambda(lambda x: tf.cast(x[..., 0], 'int32'), name='labels_out')(target)
    noisy_labels = KL.Lambda(lambda x: tf.one_hot(x[0][..., 0], depth=n_labels),
                             name='noisy_labels_out')([noisy_labels, target])

    # build model and return
    brain_model = models.Model(inputs=[net_input, target_input], outputs=[noisy_labels, target])
    return brain_model
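

# Example: building the augmentation model on its own (a minimal sketch; the shape, label values and cropping
# settings below are hypothetical placeholders, and the model is only built and inspected, not trained):
#
# aug_model = build_augmentation_model(labels_shape=[160, 160, 160],
#                                      segmentation_labels=np.arange(20),
#                                      crop_shape=96,
#                                      output_div_by_n=32,
#                                      prob_erosion_dilation=0.3,
#                                      min_erosion_dilation=4,
#                                      max_erosion_dilation=5)
# aug_model.summary()  # inputs: noisy labels + target labels; outputs: one-hot noisy labels + int32 target labels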