SynthSeg/training_denoiser.py

"""
If you use this code, please cite one of the SynthSeg papers:
https://github.com/BBillot/SynthSeg/blob/master/bibtex.bib

Copyright 2020 Benjamin Billot

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
compliance with the License. You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is
distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied. See the License for the specific language governing permissions and limitations under the
License.
"""


# python imports
import os
import numpy as np
import tensorflow as tf
from keras import models
from keras import layers as KL

# project imports
from SynthSeg import metrics_model as metrics
from SynthSeg.training import train_model
from SynthSeg.labels_to_image_model import get_shapes
from SynthSeg.training_supervised import build_model_inputs

# third-party imports
from ext.lab2im import utils, layers
from ext.neuron import models as nrn_models


def training(list_paths_input_labels,
             list_paths_target_labels,
             model_dir,
             input_segmentation_labels,
             target_segmentation_labels=None,
             subjects_prob=None,
             batchsize=1,
             output_shape=None,
             scaling_bounds=.2,
             rotation_bounds=15,
             shearing_bounds=.012,
             nonlin_std=3.,
             nonlin_scale=.04,
             prob_erosion_dilation=0.3,
             min_erosion_dilation=4,
             max_erosion_dilation=5,
             n_levels=5,
             nb_conv_per_level=2,
             conv_size=5,
             unet_feat_count=16,
             feat_multiplier=2,
             activation='elu',
             skip_n_concatenations=2,
             lr=1e-4,
             wl2_epochs=1,
             dice_epochs=50,
             steps_per_epoch=10000,
             checkpoint=None):
    """
    This function trains a denoising network (a "label-to-label" UNet) to correct imperfect segmentations: given the
    "noisy" input label maps, the network learns to predict the corresponding target label maps.
    We regroup the parameters in four categories: General, Augmentation, Architecture, Training.

    # IMPORTANT !!!
    # Each time we provide a parameter with separate values for each axis (e.g. with a numpy array or a sequence),
    # these values refer to the RAS axes.

    :param list_paths_input_labels: list of all the paths of the input label maps. These correspond to "noisy"
    segmentations that the denoiser will be trained to correct.
    :param list_paths_target_labels: list of all the paths of the output label maps. Must have the same order as
    list_paths_input_labels. These are the target label maps that the network will learn to produce given the "noisy"
    input label maps.
    :param model_dir: path of a directory where the models will be saved during training.
    :param input_segmentation_labels: list of all the label values present in the input label maps.
    :param target_segmentation_labels: list of all the label values present in the output label maps. By default (None)
    this will be taken to be the same as input_segmentation_labels.

    # ----------------------------------------------- General parameters -----------------------------------------------
    # label maps parameters
    :param subjects_prob: (optional) relative order of importance (doesn't have to be probabilistic), with which to pick
    the provided label maps at each minibatch. Can be a sequence, a 1D numpy array, or the path to such an array, and it
    must be as long as list_paths_input_labels. By default, all label maps are chosen with the same importance.

    # output-related parameters
    :param batchsize: (optional) number of label maps in each mini-batch. Default is 1.
    :param output_shape: (optional) desired shape of the output label maps, obtained by randomly cropping the deformed
    label maps. Can be an integer (same size in all dimensions), a sequence, a 1d numpy array, or the path to a 1d numpy
    array. Default is None, where no cropping is performed.

    # --------------------------------------------- Augmentation parameters --------------------------------------------
    # spatial deformation parameters
    :param scaling_bounds: (optional) the scaling factor for each dimension is sampled from a uniform distribution of
    predefined bounds. Can either be:
    1) a number, in which case the scaling factor is independently sampled from the uniform distribution of bounds
    (1-scaling_bounds, 1+scaling_bounds) for each dimension.
    2) the path to a numpy array of shape (2, n_dims), in which case the scaling factor in dimension i is sampled from
    the uniform distribution of bounds (scaling_bounds[0, i], scaling_bounds[1, i]).
    3) False, in which case scaling is completely turned off.
    Default is scaling_bounds = 0.2 (case 1).
    :param rotation_bounds: (optional) same as scaling_bounds but for the rotation angle, except that for case 1 the
    bounds are centred on 0 rather than 1, i.e. (-rotation_bounds, rotation_bounds).
    Default is rotation_bounds = 15.
    :param shearing_bounds: (optional) same as scaling_bounds. Default is shearing_bounds = 0.012.
    :param nonlin_std: (optional) standard deviation of the normal distribution from which we sample the first tensor
    for synthesising the elastic deformation field. Set to 0 to completely deactivate elastic deformation. Default is 3.
    :param nonlin_scale: (optional) ratio between the size of the input label maps and the size of the sampled tensor
    for synthesising the elastic deformation field. Default is 0.04.

    # degradation of the input labels
    :param prob_erosion_dilation: (optional) probability with which to degrade the input label maps with random erosion
    or dilation. If 0, no erosion/dilation is applied to the label maps given as inputs to the network. Default is 0.3.
    :param min_erosion_dilation: (optional) when prob_erosion_dilation is not zero, erosion or dilation is applied with
    a randomly sampled coefficient. Minimum erosion/dilation coefficient. Default is 4.
    :param max_erosion_dilation: (optional) maximum erosion/dilation coefficient. Default is 5.

    # ------------------------------------------ UNet architecture parameters ------------------------------------------
    :param n_levels: (optional) number of levels for the UNet. Default is 5.
    :param nb_conv_per_level: (optional) number of convolutional layers per level. Default is 2.
    :param conv_size: (optional) size of the convolution kernels. Default is 5.
    :param unet_feat_count: (optional) number of features for the first layer of the UNet. Default is 16.
    :param feat_multiplier: (optional) multiply the number of features by this number at each new level. Default is 2.
    :param activation: (optional) activation function. Can be 'elu' or 'relu'. Default is 'elu'.
    :param skip_n_concatenations: (optional) number of levels for which to remove the traditional skip connections of
    the UNet architecture (0 corresponds to the classic UNet). Default is 2. Example: if skip_n_concatenations = 2, we
    remove the concatenation links between the two top levels of the UNet.

    # ----------------------------------------------- Training parameters ----------------------------------------------
    :param lr: (optional) learning rate for the training. Default is 1e-4.
    :param wl2_epochs: (optional) number of epochs for which the network (except the soft-max layer) is trained with a
    weighted L2 norm loss function. Default is 1.
    :param dice_epochs: (optional) number of epochs with the soft Dice loss function. Default is 50.
    :param steps_per_epoch: (optional) number of steps per epoch. Default is 10000. Since no online validation is
    possible, this is equivalent to the frequency at which the models are saved.
    :param checkpoint: (optional) path of an already saved model to load before starting the training.
    """

    # check epochs
    assert (wl2_epochs > 0) | (dice_epochs > 0), \
        'either wl2_epochs or dice_epochs must be positive, had {0} and {1}'.format(wl2_epochs, dice_epochs)

    # prepare data files
    input_label_list, _ = utils.get_list_labels(label_list=input_segmentation_labels)
    if target_segmentation_labels is None:
        target_label_list = input_label_list
    else:
        target_label_list, _ = utils.get_list_labels(label_list=target_segmentation_labels)
    n_labels = np.size(target_label_list)

    # create augmentation model
    labels_shape, _, _, _, _, _ = utils.get_volume_info(list_paths_input_labels[0], aff_ref=np.eye(4))
    augmentation_model = build_augmentation_model(labels_shape,
                                                  input_label_list,
                                                  crop_shape=output_shape,
                                                  output_div_by_n=2 ** n_levels,
                                                  scaling_bounds=scaling_bounds,
                                                  rotation_bounds=rotation_bounds,
                                                  shearing_bounds=shearing_bounds,
                                                  nonlin_std=nonlin_std,
                                                  nonlin_scale=nonlin_scale,
                                                  prob_erosion_dilation=prob_erosion_dilation,
                                                  min_erosion_dilation=min_erosion_dilation,
                                                  max_erosion_dilation=max_erosion_dilation)
    unet_input_shape = augmentation_model.output[0].get_shape().as_list()[1:]

    # prepare the denoising model (label-to-label UNet)
    l2l_model = nrn_models.unet(input_model=augmentation_model,
                                input_shape=unet_input_shape,
                                nb_labels=n_labels,
                                nb_levels=n_levels,
                                nb_conv_per_level=nb_conv_per_level,
                                conv_size=conv_size,
                                nb_features=unet_feat_count,
                                feat_mult=feat_multiplier,
                                activation=activation,
                                batch_norm=-1,
                                skip_n_concatenations=skip_n_concatenations,
                                name='l2l')

    # input generator
    model_inputs = build_model_inputs(path_inputs=list_paths_input_labels,
                                      path_outputs=list_paths_target_labels,
                                      batchsize=batchsize,
                                      subjects_prob=subjects_prob,
                                      dtype_input='int32')
    input_generator = utils.build_training_generator(model_inputs, batchsize)

    # pre-training with weighted L2 loss, fitted on the pre-softmax activations rather than the probabilities
    if wl2_epochs > 0:
        wl2_model = models.Model(l2l_model.inputs, [l2l_model.get_layer('l2l_likelihood').output])
        wl2_model = metrics.metrics_model(wl2_model, target_label_list, 'wl2')
        train_model(wl2_model, input_generator, lr, wl2_epochs, steps_per_epoch, model_dir, 'wl2', checkpoint)
        checkpoint = os.path.join(model_dir, 'wl2_%03d.h5' % wl2_epochs)

    # fine-tuning with dice metric
    dice_model = metrics.metrics_model(l2l_model, target_label_list, 'dice')
    train_model(dice_model, input_generator, lr, dice_epochs, steps_per_epoch, model_dir, 'dice', checkpoint)


def build_augmentation_model(labels_shape,
                             segmentation_labels,
                             crop_shape=None,
                             output_div_by_n=None,
                             scaling_bounds=0.15,
                             rotation_bounds=15,
                             shearing_bounds=0.012,
                             translation_bounds=False,
                             nonlin_std=3.,
                             nonlin_scale=.0625,
                             prob_erosion_dilation=0.3,
                             min_erosion_dilation=4,
                             max_erosion_dilation=7):
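    """
    Build the augmentation model used by the denoiser: the input ("noisy") and target label maps are deformed with the
    same random spatial transform (linear and elastic), optionally cropped to crop_shape, and the noisy labels are
    further degraded with random erosion/dilation. The model returns the noisy labels one-hot encoded (after conversion
    to [0, ..., N-1]) and the target labels as an integer tensor.
    """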
|
|
    # reformat shapes and get dimensions
    labels_shape = utils.reformat_to_list(labels_shape)
    n_dims, _ = utils.get_dims(labels_shape)
    n_labels = len(segmentation_labels)

    # get cropping shape
    crop_shape, _ = get_shapes(labels_shape, crop_shape, np.array([1]*n_dims), np.array([1]*n_dims), output_div_by_n)

    # define model inputs
    net_input = KL.Input(shape=labels_shape + [1], name='l2l_noisy_labels_input', dtype='int32')
    target_input = KL.Input(shape=labels_shape + [1], name='l2l_target_input', dtype='int32')

    # deform labels
    noisy_labels, target = layers.RandomSpatialDeformation(scaling_bounds=scaling_bounds,
                                                           rotation_bounds=rotation_bounds,
                                                           shearing_bounds=shearing_bounds,
                                                           translation_bounds=translation_bounds,
                                                           nonlin_std=nonlin_std,
                                                           nonlin_scale=nonlin_scale,
                                                           inter_method='nearest')([net_input, target_input])

    # cropping
    if crop_shape != labels_shape:
        noisy_labels, target = layers.RandomCrop(crop_shape)([noisy_labels, target])

    # random erosion/dilation of the input labels
    if prob_erosion_dilation > 0:
        noisy_labels = layers.RandomDilationErosion(min_erosion_dilation,
                                                    max_erosion_dilation,
                                                    prob=prob_erosion_dilation)(noisy_labels)

    # convert input labels (i.e. noisy_labels) to [0, ..., N-1] and make them one-hot
    noisy_labels = layers.ConvertLabels(np.unique(segmentation_labels))(noisy_labels)
    target = KL.Lambda(lambda x: tf.cast(x[..., 0], 'int32'), name='labels_out')(target)
    noisy_labels = KL.Lambda(lambda x: tf.one_hot(x[0][..., 0], depth=n_labels),
                             name='noisy_labels_out')([noisy_labels, target])

    # build model and return
    brain_model = models.Model(inputs=[net_input, target_input], outputs=[noisy_labels, target])
    return brain_model
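

# Example usage: a minimal sketch of how `training` might be called. The paths, label values, and output directory
# below are hypothetical placeholders (not part of the original module); replace them with your own data. The two
# lists must pair corresponding noisy/target label maps in the same order.
if __name__ == '__main__':

    example_input_maps = ['./data/noisy_labels/subject_001.nii.gz',
                          './data/noisy_labels/subject_002.nii.gz']
    example_target_maps = ['./data/target_labels/subject_001.nii.gz',
                           './data/target_labels/subject_002.nii.gz']

    training(list_paths_input_labels=example_input_maps,
             list_paths_target_labels=example_target_maps,
             model_dir='./models/denoiser',
             input_segmentation_labels=[0, 2, 3, 4, 17, 41, 42, 43, 53],  # hypothetical label values
             batchsize=1,
             wl2_epochs=1,
             dice_epochs=50,
             steps_per_epoch=1000)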