Switch to unified view

a b/scripts/commands/training.py
1
"""
2
If you use this code, please cite one of the SynthSeg papers:
3
https://github.com/BBillot/SynthSeg/blob/master/bibtex.bib
4
5
Copyright 2020 Benjamin Billot
6
7
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
8
compliance with the License. You may obtain a copy of the License at
9
https://www.apache.org/licenses/LICENSE-2.0
10
Unless required by applicable law or agreed to in writing, software distributed under the License is
11
distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
12
implied. See the License for the specific language governing permissions and limitations under the
13
License.
14
"""
15
16
17
from argparse import ArgumentParser
18
from SynthSeg.training import training
19
from ext.lab2im.utils import infer
20
21
parser = ArgumentParser()
22
23
# ------------------------------------------------- General parameters -------------------------------------------------
24
# Positional arguments
25
parser.add_argument("labels_dir", type=str)
26
parser.add_argument("model_dir", type=str)
27
28
# ---------------------------------------------- Generation parameters ----------------------------------------------
29
# label maps parameters
30
parser.add_argument("--generation_labels", type=str, dest="generation_labels", default=None)
31
parser.add_argument("--neutral_labels", type=int, dest="n_neutral_labels", default=None)
32
parser.add_argument("--segmentation_labels", type=str, dest="segmentation_labels", default=None)
33
parser.add_argument("--subjects_prob", type=str, dest="subjects_prob", default=None)
34
35
# output-related parameters
36
parser.add_argument("--batch_size", type=int, dest="batchsize", default=1)
37
parser.add_argument("--channels", type=int, dest="n_channels", default=1)
38
parser.add_argument("--target_res", type=float, dest="target_res", default=None)
39
parser.add_argument("--output_shape", type=int, dest="output_shape", default=None)
40
41
# GMM-sampling parameters
42
parser.add_argument("--generation_classes", type=str, dest="generation_classes", default=None)
43
parser.add_argument("--prior_type", type=str, dest="prior_distributions", default='uniform')
44
parser.add_argument("--prior_means", type=str, dest="prior_means", default=None)
45
parser.add_argument("--prior_stds", type=str, dest="prior_stds", default=None)
46
parser.add_argument("--specific_stats", action='store_true', dest="use_specific_stats_for_channel")
47
parser.add_argument("--mix_prior_and_random", action='store_true', dest="mix_prior_and_random")
48
49
# spatial deformation parameters
50
parser.add_argument("--no_flipping", action='store_false', dest="flipping")
51
parser.add_argument("--scaling", dest="scaling_bounds", type=infer, default=0.2)
52
parser.add_argument("--rotation", dest="rotation_bounds", type=infer, default=15)
53
parser.add_argument("--shearing", dest="shearing_bounds", type=infer, default=.012)
54
parser.add_argument("--translation", dest="translation_bounds", type=infer, default=False)
55
parser.add_argument("--nonlin_std", type=float, dest="nonlin_std", default=4.)
56
parser.add_argument("--nonlin_scale", type=float, dest="nonlin_scale", default=.04)
57
58
# blurring/resampling parameters
59
parser.add_argument("--randomise_res", action='store_true', dest="randomise_res")
60
parser.add_argument("--max_res_iso", type=float, dest="max_res_iso", default=4.)
61
parser.add_argument("--max_res_aniso", type=float, dest="max_res_aniso", default=8.)
62
parser.add_argument("--data_res", dest="data_res", type=infer, default=None)
63
parser.add_argument("--thickness", dest="thickness", type=infer, default=None)
64
65
# bias field parameters
66
parser.add_argument("--bias_std", type=float, dest="bias_field_std", default=.7)
67
parser.add_argument("--bias_scale", type=float, dest="bias_scale", default=.025)
68
69
parser.add_argument("--gradients", action='store_true', dest="return_gradients")
70
71
# -------------------------------------------- UNet architecture parameters --------------------------------------------
72
parser.add_argument("--n_levels", type=int, dest="n_levels", default=5)
73
parser.add_argument("--conv_per_level", type=int, dest="nb_conv_per_level", default=2)
74
parser.add_argument("--conv_size", type=int, dest="conv_size", default=3)
75
parser.add_argument("--unet_feat", type=int, dest="unet_feat_count", default=24)
76
parser.add_argument("--feat_mult", type=int, dest="feat_multiplier", default=2)
77
parser.add_argument("--activation", type=str, dest="activation", default='elu')
78
79
# ------------------------------------------------- Training parameters ------------------------------------------------
80
parser.add_argument("--lr", type=float, dest="lr", default=1e-4)
81
parser.add_argument("--wl2_epochs", type=int, dest="wl2_epochs", default=1)
82
parser.add_argument("--dice_epochs", type=int, dest="dice_epochs", default=50)
83
parser.add_argument("--steps_per_epoch", type=int, dest="steps_per_epoch", default=10000)
84
parser.add_argument("--checkpoint", type=str, dest="checkpoint", default=None)
85
86
args = parser.parse_args()
87
training(**vars(args))