"""

Very simple script to show how we trained SynthSeg+, which extends SynthSeg by building robustness to clinical
acquisitions. For more details, please look at our MICCAI 2022 paper:

Robust Segmentation of Brain MRI in the Wild with Hierarchical CNNs and no Retraining,
Billot, Magdamo, Das, Arnold, Iglesias
MICCAI 2022

If you use this code, please cite one of the SynthSeg papers:
https://github.com/BBillot/SynthSeg/blob/master/bibtex.bib

Copyright 2020 Benjamin Billot

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
compliance with the License. You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is
distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied. See the License for the specific language governing permissions and limitations under the
License.
"""
from SynthSeg.training import training as training_s1
from SynthSeg.training_denoiser import training as training_d
from SynthSeg.training_group import training as training_s2

import numpy as np

# ------------------ segmenter S1
# Here the purpose is to train a first network to produce preliminary segmentations of input scans with five general
# labels: 0-background, 1-white matter, 2-grey matter, 3-fluids, 4-cerebellum.

# As in tutorial 3, S1 is trained with synthetic images with randomised contrasts/resolution/artefacts such that it can
# readily segment a wide range of test scans without retraining. The synthetic scans are obtained from the same label
# maps and generative model as in the previous tutorials.
labels_dir_s1 = '../../data/training_label_maps'
path_generation_labels = '../../data/labels_classes_priors/generation_labels.npy'
path_generation_classes = '../../data/labels_classes_priors/generation_classes.npy'
# However, because we now wish to segment scans using only five labels, we use a different list of segmentation labels,
# where all label values in generation_labels are mapped to a target value in [0, 4].
path_segmentation_labels_s1 = '../../data/tutorial_7/segmentation_labels_s1.npy'
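
# Purely for illustration: the two arrays have the same length, and the i-th value of generation_labels is mapped onto
# the i-th value of segmentation_labels_s1, which lies in [0, 4].
print('generation labels:', np.load(path_generation_labels))
print('S1 target labels: ', np.load(path_segmentation_labels_s1))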

model_dir_s1 = './outputs_tutorial_7/training_s1'  # folder where the models will be saved


training_s1(labels_dir=labels_dir_s1,
            model_dir=model_dir_s1,
            generation_labels=path_generation_labels,
            segmentation_labels=path_segmentation_labels_s1,
            n_neutral_labels=18,
            generation_classes=path_generation_classes,
            target_res=1,
            output_shape=160,
            prior_distributions='uniform',
            prior_means=[0, 255],
            prior_stds=[0, 50],
            randomise_res=True)

# ------------------ denoiser D
# The purpose of this network is to perform label-to-label correction, i.e. to fix potential mistakes made by S1 at
# test time. Therefore, D is trained with two sets of label maps: noisy segmentations from S1 (used as inputs to D),
# and their corresponding ground truth (used as targets to train D). To make the input segmentations representative of
# the mistakes of S1, they are obtained by degrading real images with extreme augmentations (spatial, intensity,
# resolution, etc.) and feeding them to S1.

# Obtaining the input/target segmentations is done offline with the function in sample_segmentation_pairs.py.
# In practice we sample a lot of them (e.g. 10,000), but we only give a few example pairs here (see also the commented
# sketch after these lists). Note that these segmentations have the same label values as the outputs of S1 (i.e. in
# [0, 4]).
list_input_labels = ['../../data/tutorial_7/noisy_segmentations_d/0001.nii.gz',
                     '../../data/tutorial_7/noisy_segmentations_d/0002.nii.gz',
                     '../../data/tutorial_7/noisy_segmentations_d/0003.nii.gz']
list_target_labels = ['../../data/tutorial_7/target_segmentations_d/0001.nii.gz',
                      '../../data/tutorial_7/target_segmentations_d/0002.nii.gz',
                      '../../data/tutorial_7/target_segmentations_d/0003.nii.gz']
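
# A minimal sketch of how the full set of sampled pairs could be listed once many pairs have been generated. This is
# only an illustration (left commented out so that the example pairs above are the ones actually used), and it simply
# assumes that all pairs live in the two directories above with matching file names:
# import glob
# list_input_labels = sorted(glob.glob('../../data/tutorial_7/noisy_segmentations_d/*.nii.gz'))
# list_target_labels = sorted(glob.glob('../../data/tutorial_7/target_segmentations_d/*.nii.gz'))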

# Moreover, we apply spatial augmentation to the sampled pairs, in order to further increase the morphological
# variability seen by the network. In addition, the input "noisy" segmentations are corrupted with random
# erosion/dilation, controlled by the following parameters (a toy illustration is given right after them):
prob_erosion_dilation = 0.3  # probability of performing random erosion/dilation
min_erosion_dilation = 4     # minimum coefficient for erosion/dilation
max_erosion_dilation = 5     # maximum coefficient for erosion/dilation
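
# Toy illustration of the erosion part of this augmentation, for intuition only: eroding a binary mask shrinks it by
# one voxel per iteration. This snippet is not used anywhere in training, and SynthSeg's internal implementation of
# the erosion/dilation augmentation may differ; it is left commented out.
# from scipy.ndimage import binary_erosion
# toy_mask = np.zeros((8, 8), dtype=bool)
# toy_mask[2:6, 2:6] = True
# eroded_mask = binary_erosion(toy_mask, iterations=1)  # the 4x4 square shrinks to a 2x2 square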

# This is the list of label values included in the input/GT label maps. This list must contain unique values.
input_segmentation_labels = np.array([0, 1, 2, 3, 4])

model_dir_d = './outputs_tutorial_7/training_d'  # folder where the models will be saved

training_d(list_paths_input_labels=list_input_labels,
           list_paths_target_labels=list_target_labels,
           model_dir=model_dir_d,
           input_segmentation_labels=input_segmentation_labels,
           output_shape=160,
           prob_erosion_dilation=prob_erosion_dilation,
           min_erosion_dilation=min_erosion_dilation,
           max_erosion_dilation=max_erosion_dilation,
           conv_size=5,
           unet_feat_count=16,
           skip_n_concatenations=2)

# ------------------ segmenter S2
# Final segmentations are obtained with a last segmenter S2, which takes as inputs an image as well as the preliminary
# segmentation produced by S1 and corrected by D.

# Here S2 is trained with synthetic images sampled from the usual training label maps, along with the associated
# generation labels and classes. Also, we now use the same segmentation labels as in tutorials 2, 3, and 4, since we
# segment all the usual regions again.
labels_dir_s2 = '../../data/training_label_maps'  # these are the same as for S1
path_generation_labels = '../../data/labels_classes_priors/generation_labels.npy'
path_generation_classes = '../../data/labels_classes_priors/generation_classes.npy'
path_segmentation_labels_s2 = '../../data/labels_classes_priors/synthseg_segmentation_labels.npy'

# The preliminary segmentations are given as soft probability maps and are directly derived from the ground truth.
# Specifically, we take the structures segmented by S1 and regroup them into the same five groups as before
# (a toy illustration of this grouping is given below).
grouping_labels = '../../data/tutorial_7/segmentation_labels_s1.npy'
# However, in order to simulate test-time imperfections of D, these soft probability maps are slightly augmented with
# spatial transforms, and sometimes undergo a random dilation/erosion.
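
# Toy illustration of the grouping step, for intuition only (the label values and the lookup table below are made up;
# the actual correspondence is the one stored in the grouping_labels file): regrouping a ground-truth patch into the
# five S1 groups and turning it into soft probability maps amounts to a label lookup followed by a one-hot encoding.
toy_gt = np.array([[0, 3, 17],
                   [2, 41, 0]])         # toy ground-truth label patch
toy_lut = np.zeros(42, dtype=np.int32)  # made-up label-to-group lookup table
toy_lut[[2, 41]] = 1                    # e.g. white-matter labels -> group 1 (white matter)
toy_lut[[3, 17]] = 2                    # e.g. cortex/hippocampus -> group 2 (grey matter)
toy_grouped = toy_lut[toy_gt]           # grouped labels, all in [0, 4]
toy_soft_maps = np.eye(5)[toy_grouped]  # one-hot "soft" probability maps (before any augmentation)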

model_dir_s2 = './outputs_tutorial_7/training_s2'  # folder where the models will be saved

training_s2(labels_dir=labels_dir_s2,
            model_dir=model_dir_s2,
            generation_labels=path_generation_labels,
            n_neutral_labels=18,
            segmentation_labels=path_segmentation_labels_s2,
            generation_classes=path_generation_classes,
            grouping_labels=grouping_labels,
            target_res=1,
            output_shape=160,
            prior_distributions='uniform',
            prior_means=[0, 255],
            prior_stds=[0, 50],
            randomise_res=True)