# scripts/tutorials/3-training.py
"""
2
3
This script shows how we trained SynthSeg.
4
Importantly, it reuses numerous parameters seen in the previous tutorial about image generation
5
(i.e., 2-generation_explained.py), which we strongly recommend reading before this one.
6
7
8
9
If you use this code, please cite one of the SynthSeg papers:
10
https://github.com/BBillot/SynthSeg/blob/master/bibtex.bib
11
12
Copyright 2020 Benjamin Billot
13
14
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
15
compliance with the License. You may obtain a copy of the License at
16
https://www.apache.org/licenses/LICENSE-2.0
17
Unless required by applicable law or agreed to in writing, software distributed under the License is
18
distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
19
implied. See the License for the specific language governing permissions and limitations under the
20
License.
21
"""
22
23
24
# project imports
from SynthSeg.training import training


# ---------------------------------------------- General parameters ----------------------------------------------

# path of the folder containing the training label maps, and where the trained models will be saved
path_training_label_maps = '../../data/training_label_maps'
path_model_dir = './outputs_tutorial_3/'
batchsize = 1  # number of synthetic examples per training mini-batch

# architecture parameters
n_levels = 5           # number of resolution levels
nb_conv_per_level = 2  # number of convolution per level
conv_size = 3          # size of the convolution kernel (e.g. 3x3x3)
unet_feat_count = 24   # number of feature maps after the first convolution
activation = 'elu'     # activation for all convolution layers except the last, which will use softmax regardless
feat_multiplier = 2    # if feat_multiplier is set to 1, we will keep the number of feature maps constant throughout
#                        the network; 2 will double them (resp. halve) after each max-pooling (resp. upsampling);
#                        3 will triple them, etc.

# training parameters
lr = 1e-4               # learning rate
wl2_epochs = 1          # number of pre-training epochs with wl2 metric w.r.t. the layer before the softmax
dice_epochs = 100       # number of training epochs
steps_per_epoch = 5000  # number of iterations per epoch


# ---------- Generation parameters ----------
# these parameters are from the previous tutorial, and thus we do not explain them again here

# generation and segmentation labels
path_generation_labels = '../../data/labels_classes_priors/generation_labels.npy'
n_neutral_labels = 18
path_segmentation_labels = '../../data/labels_classes_priors/synthseg_segmentation_labels.npy'

# shape and resolution of the outputs
target_res = None
output_shape = 160
n_channels = 1

# GMM sampling
prior_distributions = 'uniform'
path_generation_classes = '../../data/labels_classes_priors/generation_classes.npy'

# spatial deformation parameters
flipping = True
scaling_bounds = .2
rotation_bounds = 15
shearing_bounds = .012
translation_bounds = False
nonlin_std = 4.
bias_field_std = .7

# acquisition resolution parameters
randomise_res = True

# ------------------------------------------------------ Training ------------------------------------------------------

# launch training: synthesises images from the label maps on the fly and trains the 3D UNet on them
training(path_training_label_maps,
         path_model_dir,
         generation_labels=path_generation_labels,
         segmentation_labels=path_segmentation_labels,
         n_neutral_labels=n_neutral_labels,
         batchsize=batchsize,
         n_channels=n_channels,
         target_res=target_res,
         output_shape=output_shape,
         prior_distributions=prior_distributions,
         generation_classes=path_generation_classes,
         flipping=flipping,
         scaling_bounds=scaling_bounds,
         rotation_bounds=rotation_bounds,
         shearing_bounds=shearing_bounds,
         translation_bounds=translation_bounds,
         nonlin_std=nonlin_std,
         randomise_res=randomise_res,
         bias_field_std=bias_field_std,
         n_levels=n_levels,
         nb_conv_per_level=nb_conv_per_level,
         conv_size=conv_size,
         unet_feat_count=unet_feat_count,
         feat_multiplier=feat_multiplier,
         activation=activation,
         lr=lr,
         wl2_epochs=wl2_epochs,
         dice_epochs=dice_epochs,
         steps_per_epoch=steps_per_epoch)