"""

This script shows how to perform inference after having trained your own model.
Importantly, it reuses some of the parameters used in tutorial 3-training.
We emphasise that this tutorial explains how to perform inference with your own trained models.
To predict segmentations with the distributed SynthSeg model, please refer to the README.md file.


If you use this code, please cite one of the SynthSeg papers:
https://github.com/BBillot/SynthSeg/blob/master/bibtex.bib

Copyright 2020 Benjamin Billot

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
compliance with the License. You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is
distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied. See the License for the specific language governing permissions and limitations under the
License.
"""

# project imports
from SynthSeg.predict import predict

# paths to input/output files
# Here we assume the availability of an image that we wish to segment with a model we have just trained.
# We emphasise that we do not provide such an image (this is just an example after all :))
# Input images must have a .nii, .nii.gz, or .mgz extension.
# Note that path_images can also be the path to an entire folder, in which case all the images within this folder will
# be segmented. In this case, please provide path_segm (and possibly path_posteriors and path_resampled) as folders.
path_images = '/a/path/to/an/image/im.nii.gz'
# path to the output segmentation
path_segm = './outputs_tutorial_4/predicted_segmentations/im_seg.nii.gz'
# we can also provide a path for an optional file containing the probability maps of all predicted labels
path_posteriors = './outputs_tutorial_4/predicted_information/im_post.nii.gz'
# and for a CSV file that will contain the volumes of each segmented structure
path_vol = './outputs_tutorial_4/predicted_information/volumes.csv'

# Of course, we need to provide the path to the trained model (here we use the main SynthSeg model).
path_model = '../../models/synthseg_1.0.h5'
# but we also need to provide the path to the segmentation labels used during training
path_segmentation_labels = '../../data/labels_classes_priors/synthseg_segmentation_labels.npy'
# optionally, we can give a numpy array with the names corresponding to the structures in path_segmentation_labels
path_segmentation_names = '../../data/labels_classes_priors/synthseg_segmentation_names.npy'
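# For reference (a sketch under assumptions, not files distributed with this tutorial): both files are plain 1-D numpy
# arrays of matching length, which you could create yourself along the lines of:
# import numpy as np
# np.save('my_segmentation_labels.npy', np.array([0, 2, 3, 4, 17]))                     # hypothetical label values
# np.save('my_segmentation_names.npy', np.array(['background', 'left white matter',
#                                                'left cortex', 'left lateral ventricle',
#                                                'left hippocampus']))                   # matching structure names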

# We can now provide various parameters to control the preprocessing of the input.
# First, we can play with the size of the input. Remember that the size of the input must be divisible by 2**n_levels,
# so the input image will be automatically padded to the nearest shape divisible by 2**n_levels (this is just for
# processing, the output will then be cropped back to the original image size).
# Alternatively, you can crop the input to a smaller shape for faster processing, or to make it fit on your GPU.
cropping = 192
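# A quick illustration of the padding rule (the numbers are hypothetical, not part of the tutorial): with the
# n_levels = 5 used below, shapes are padded up to the nearest multiple of 2**5 = 32, e.g.
# import numpy as np
# image_shape = np.array([160, 200, 190])                        # hypothetical input shape
# padded_shape = (32 * np.ceil(image_shape / 32)).astype(int)    # -> [160, 224, 192]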
# Finally, we finish preprocessing the input by resampling it to the resolution at which the network has been trained to
# produce predictions. If the input image has a resolution outside the range [target_res-0.05, target_res+0.05], it will
# automatically be resampled to target_res.
target_res = 1.
# Note that if the image is indeed resampled, you have the option to save the resampled image.
path_resampled = './outputs_tutorial_4/predicted_information/im_resampled_target_res.nii.gz'
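# For example (an illustration, not part of the original tutorial): with target_res = 1., an image at 1.03 mm isotropic
# falls inside [0.95, 1.05] and is left untouched, whereas an image at 0.7 mm or 1.2 mm would first be resampled to
# 1 mm before being fed to the network, i.e. roughly:
# resample_needed = abs(input_res - target_res) > 0.05            # hypothetical per-axis check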

# After the image has been processed by the network, there are again various options to postprocess it.
# First, we can apply some test-time augmentation by flipping the input along the right-left axis and segmenting
# the resulting image. In this case, and if the network has right/left-specific labels, it is also very important to
# provide the number of neutral (i.e., non-sided) labels. This must be exactly the same as the number used during
# training.
flip = True
n_neutral_labels = 18
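# A rough sketch of the idea (my own illustration of the mechanism, not SynthSeg's internal code): the image is also
# segmented after flipping it along the right-left axis; the resulting posteriors are flipped back and averaged with
# those of the original image. Because a right-left flip turns left structures into right ones (and vice versa), the
# corresponding label channels must be swapped before averaging, which is why the number of side-neutral labels matters.
# import numpy as np
# flipped_image = np.flip(image, axis=rl_axis)                                # image and rl_axis are hypothetical
# averaged_post = 0.5 * (post + np.flip(post_from_flipped, axis=rl_axis))    # after swapping left/right channels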
# Second, we can smooth the probability maps produced by the network. This doesn't change the results much, but helps
# to reduce high-frequency noise in the obtained segmentations.
sigma_smoothing = 0.5
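# Concretely (an illustration with scipy, not necessarily the exact implementation), this amounts to a channel-wise
# Gaussian blur of the posteriors before taking the argmax:
# from scipy.ndimage import gaussian_filter
# smoothed_channel = gaussian_filter(post[..., k], sigma=sigma_smoothing)    # k: hypothetical label channel index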
# Then we can apply a fancier version of biggest-connected-component filtering, by regrouping structures within
# so-called "topological classes". For each class we successively: 1) sum all the posteriors corresponding to the
# labels of this class, 2) obtain a mask for this class by thresholding the summed posteriors by a low value
# (arbitrarily set to 0.1), 3) keep the biggest connected component, and 4) individually apply the obtained mask to
# the posteriors of all the labels for this class. See the sketch below the example.
# Example: (continuing the previous one)  generation_labels = [0, 24, 507, 2, 3, 4, 17, 25, 41, 42, 43, 53, 57]
#                                             output_labels = [0,  0,  0,  2, 3, 4, 17,  2, 41, 42, 43, 53, 41]
#                                       topological_classes = [0,  0,  0,  1, 1, 2,  3,  1,  4,  4,  5,  6,  7]
# Here we regroup labels 2 and 3 in the same topological class, same for labels 41 and 42. The topological class of
# unsegmented structures must be set to 0 (like for 24 and 507).
topology_classes = '../../data/labels_classes_priors/synthseg_topological_classes.npy'
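# A minimal sketch of steps 1-4 for a single topological class (my own illustration, assuming posteriors of shape
# (*volume_shape, n_labels) and scipy being available; this is not SynthSeg's internal code):
# import numpy as np
# from scipy.ndimage import label as connected_components
#
# def mask_class_to_biggest_component(post, class_channels, threshold=0.1):
#     summed = post[..., class_channels].sum(axis=-1)             # 1) sum posteriors of the class
#     components, n = connected_components(summed > threshold)    # 2) threshold and label connected components
#     if n > 0:
#         sizes = np.bincount(components.ravel())[1:]             # component sizes (index 0 = background)
#         biggest = components == (np.argmax(sizes) + 1)          # 3) keep the biggest component
#         post[..., class_channels] *= biggest[..., None]         # 4) mask the posteriors of the class
#     return post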
# Finally, we can also apply a strict version of biggest connected component, to get rid of unwanted noisy label
# patches that can sometimes occur in the background. If so, we recommend using the smoothing option described above.
keep_biggest_component = True

# Regarding the architecture of the network, we must provide the predict function with the same parameters as during
# training.
n_levels = 5
nb_conv_per_level = 2
conv_size = 3
unet_feat_count = 24
activation = 'elu'
feat_multiplier = 2

# Finally, we can set up an evaluation step after all images have been segmented.
# For this purpose, we need to provide the path to the ground truth corresponding to the input image(s).
# This is done by using the "gt_folder" parameter, which must have the same type as path_images (i.e., the path to a
# single image or to a folder). If provided as a folder, ground truths must be sorted in the same order as images in
# path_images.
# Just set this to None if you do not want to run evaluation.
gt_folder = '/the/path/to/the/ground_truth/gt.nii.gz'
# Dice scores will be computed and saved as a numpy array in the folder containing the segmentation(s).
# This numpy array will be organised as follows: rows correspond to structures, and columns to subjects. Importantly,
# rows are given in sorted label order.
# Example: we segment 2 subjects, where output_labels = [0,  0,  0,  2, 3, 4, 17,  2, 41, 42, 43, 53, 41]
#                             so sorted output_labels = [0, 2, 3, 4, 17, 41, 42, 43, 53]
# dice = [[xxx, xxx],  # scores for label 0
#         [xxx, xxx],  # scores for label 2
#         [xxx, xxx],  # scores for label 3
#         [xxx, xxx],  # scores for label 4
#         [xxx, xxx],  # scores for label 17
#         [xxx, xxx],  # scores for label 41
#         [xxx, xxx],  # scores for label 42
#         [xxx, xxx],  # scores for label 43
#         [xxx, xxx]]  # scores for label 53
#         /       \
#   subject 1    subject 2
#
# We can also compute different surface distances (Hausdorff, Hausdorff99, Hausdorff95, and mean surface distance).
# The results will be saved in arrays similar to the Dice scores.
compute_distances = True
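# Once the evaluation has run, the scores can be reloaded with numpy, e.g. (the exact file name written by the
# evaluation step is an assumption here):
# import numpy as np
# dice = np.load('./outputs_tutorial_4/predicted_segmentations/dice.npy')
# print(dice.shape)    # (number of evaluated structures, number of subjects)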

# All right, we're ready to make predictions!!
predict(path_images,
        path_segm,
        path_model,
        path_segmentation_labels,
        n_neutral_labels=n_neutral_labels,
        path_posteriors=path_posteriors,
        path_resampled=path_resampled,
        path_volumes=path_vol,
        names_segmentation=path_segmentation_names,
        cropping=cropping,
        target_res=target_res,
        flip=flip,
        topology_classes=topology_classes,
        sigma_smoothing=sigma_smoothing,
        keep_biggest_component=keep_biggest_component,
        n_levels=n_levels,
        nb_conv_per_level=nb_conv_per_level,
        conv_size=conv_size,
        unet_feat_count=unet_feat_count,
        feat_multiplier=feat_multiplier,
        activation=activation,
        gt_folder=gt_folder,
        compute_distances=compute_distances)