# scripts/tutorials/7-synthseg+.py
"""

Very simple script to show how we trained SynthSeg+, which extends SynthSeg by building robustness to clinical
acquisitions. For more details, please look at our MICCAI 2022 paper:

Robust Segmentation of Brain MRI in the Wild with Hierarchical CNNs and no Retraining,
Billot, Magdamo, Das, Arnold, Iglesias
MICCAI 2022

If you use this code, please cite one of the SynthSeg papers:
https://github.com/BBillot/SynthSeg/blob/master/bibtex.bib

Copyright 2020 Benjamin Billot

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
compliance with the License. You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is
distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied. See the License for the specific language governing permissions and limitations under the
License.
"""
23
from SynthSeg.training import training as training_s1
24
from SynthSeg.training_denoiser import training as training_d
25
from SynthSeg.training_group import training as training_s2
26
27
import numpy as np
28
29
# ------------------ segmenter S1
# Here the purpose is to train a first network to produce preliminary segmentations of input scans with five general
# labels: 0-background, 1-white matter, 2-grey matter, 3-fluids, 4-cerebellum.

# As in tutorial 3, S1 is trained with synthetic images with randomised contrasts/resolution/artefacts such that it can
# readily segment a wide range of test scans without retraining. The synthetic scans are obtained from the same label
# maps and generative model as in the previous tutorials.
labels_dir_s1 = '../../data/training_label_maps'
path_generation_labels = '../../data/labels_classes_priors/generation_labels.npy'
path_generation_classes = '../../data/labels_classes_priors/generation_classes.npy'
# However, because we now wish to segment scans using only five labels, we use a different list of segmentation labels
# where all label values in generation_labels are assigned to a target value between [0, 4].
path_segmentation_labels_s1 = '../../data/tutorial_7/segmentation_labels_s1.npy'

model_dir_s1 = './outputs_tutorial_7/training_s1'  # folder where the models will be saved


# The synthesis settings below (target_res, output_shape, intensity priors, randomise_res) are the same values that are
# passed to training_s2 for the final segmenter S2; only the segmentation labels differ between the two segmenters.
training_s1(labels_dir=labels_dir_s1,
            model_dir=model_dir_s1,
            generation_labels=path_generation_labels,
            segmentation_labels=path_segmentation_labels_s1,  # five target labels in [0, 4]
            n_neutral_labels=18,
            generation_classes=path_generation_classes,
            target_res=1,
            output_shape=160,
            prior_distributions='uniform',  # presumably intensity means/stds are drawn uniformly from the ranges below
            prior_means=[0, 255],
            prior_stds=[0, 50],
            randomise_res=True)  # randomised resolution, as described above for S1

# ------------------ denoiser D
# The purpose of this network is to perform label-to-label correction in order to correct potential mistakes made by S1
# at test time. Therefore, D is trained with two sets of label maps: noisy segmentations from S1 (used as inputs to D),
# and their corresponding ground truth (used as target to train D). In order to obtain input segmentations
# representative of the mistakes of S1, these are obtained by degrading real images with extreme augmentation (spatial,
# intensity, resolution, etc.), and feeding them to S1.

# Obtaining the input/target segmentations is done offline by using the following script: sample_segmentation_pairs.py
# In practice we sample a lot of them (e.g. 10,000), but we give here 3 example pairs. Note that these segmentations
# have the same label values as the output of S1 (i.e. between [0, 4]).
# The i-th noisy input presumably pairs with the i-th target (file names match), so keep both lists aligned.
list_input_labels = ['../../data/tutorial_7/noisy_segmentations_d/0001.nii.gz',
                     '../../data/tutorial_7/noisy_segmentations_d/0002.nii.gz',
                     '../../data/tutorial_7/noisy_segmentations_d/0003.nii.gz']
list_target_labels = ['../../data/tutorial_7/target_segmentations_d/0001.nii.gz',
                      '../../data/tutorial_7/target_segmentations_d/0002.nii.gz',
                      '../../data/tutorial_7/target_segmentations_d/0003.nii.gz']

# Moreover, we perform spatial augmentation on the sampled pairs, in order to further increase the morphological
# variability seen by the network. Furthermore, the input "noisy" segmentations are further augmented with random
# erosion/dilation:
prob_erosion_dilation = 0.3  # probability of performing random erosion/dilation
min_erosion_dilation = 4     # minimum coefficient for erosion/dilation (no trailing comma: it would make this a tuple)
max_erosion_dilation = 5     # maximum coefficient for erosion/dilation

# This is the list of label values included in the input/GT label maps. This list must contain unique values.
input_segmentation_labels = np.array([0, 1, 2, 3, 4])

model_dir_d = './outputs_tutorial_7/training_d'  # folder where the models will be saved

training_d(list_paths_input_labels=list_input_labels,
           list_paths_target_labels=list_target_labels,
           model_dir=model_dir_d,
           input_segmentation_labels=input_segmentation_labels,
           output_shape=160,
           prob_erosion_dilation=prob_erosion_dilation,
           min_erosion_dilation=min_erosion_dilation,
           max_erosion_dilation=max_erosion_dilation,
           # the remaining arguments presumably configure the UNet architecture of D
           conv_size=5,
           unet_feat_count=16,
           skip_n_concatenations=2)

# ------------------ segmenter S2
# Final segmentations are obtained with a last segmenter S2, which takes as inputs an image as well as the preliminary
# segmentations of S1 that are corrected by D.

# Here S2 is trained with synthetic images sampled from the usual training label maps with associated generation labels
# and classes. Also, we now use the same segmentation labels as in tutorials 2, 3, and 4, as we now segment all the
# usual regions.
labels_dir_s2 = '../../data/training_label_maps'  # these are the same as for S1
path_generation_labels = '../../data/labels_classes_priors/generation_labels.npy'  # same file as for S1
path_generation_classes = '../../data/labels_classes_priors/generation_classes.npy'  # same file as for S1
path_segmentation_labels_s2 = '../../data/labels_classes_priors/synthseg_segmentation_labels.npy'

# The preliminary segmentations are given as soft probability maps and are directly derived from the ground truth.
# Specifically, we take the structures that were segmented by S1, and regroup them into the same "groups" as before.
# This is the same file as the segmentation labels of S1, so the groups match the five labels predicted by S1.
grouping_labels = '../../data/tutorial_7/segmentation_labels_s1.npy'
# However, in order to simulate test-time imperfections made by D, these soft probability maps are slightly
# augmented with spatial transforms, and sometimes undergo a random dilation/erosion.

model_dir_s2 = './outputs_tutorial_7/training_s2'  # folder where the models will be saved

# The synthesis settings (target_res, output_shape, intensity priors, randomise_res) use the same values as for S1.
training_s2(labels_dir=labels_dir_s2,
            model_dir=model_dir_s2,
            generation_labels=path_generation_labels,
            n_neutral_labels=18,
            segmentation_labels=path_segmentation_labels_s2,
            generation_classes=path_generation_classes,
            grouping_labels=grouping_labels,
            target_res=1,
            output_shape=160,
            prior_distributions='uniform',  # presumably intensity means/stds are drawn uniformly from the ranges below
            prior_means=[0, 255],
            prior_stds=[0, 50],
            randomise_res=True)