|
a |
|
b/scripts/commands/training.py |
|
|
1 |
""" |
|
|
2 |
If you use this code, please cite one of the SynthSeg papers: |
|
|
3 |
https://github.com/BBillot/SynthSeg/blob/master/bibtex.bib |
|
|
4 |
|
|
|
5 |
Copyright 2020 Benjamin Billot |
|
|
6 |
|
|
|
7 |
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in |
|
|
8 |
compliance with the License. You may obtain a copy of the License at |
|
|
9 |
https://www.apache.org/licenses/LICENSE-2.0 |
|
|
10 |
Unless required by applicable law or agreed to in writing, software distributed under the License is |
|
|
11 |
distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or |
|
|
12 |
implied. See the License for the specific language governing permissions and limitations under the |
|
|
13 |
License. |
|
|
14 |
""" |
|
|
15 |
|
|
|
16 |
|
|
|
17 |
from argparse import ArgumentParser |
|
|
18 |
from SynthSeg.training import training |
|
|
19 |
from ext.lab2im.utils import infer |
|
|
20 |
|
|
|
21 |
parser = ArgumentParser() |
|
|
22 |
|
|
|
23 |
# ------------------------------------------------- General parameters ------------------------------------------------- |
|
|
24 |
# Positional arguments |
|
|
25 |
parser.add_argument("labels_dir", type=str) |
|
|
26 |
parser.add_argument("model_dir", type=str) |
|
|
27 |
|
|
|
28 |
# ---------------------------------------------- Generation parameters ---------------------------------------------- |
|
|
29 |
# label maps parameters |
|
|
30 |
parser.add_argument("--generation_labels", type=str, dest="generation_labels", default=None) |
|
|
31 |
parser.add_argument("--neutral_labels", type=int, dest="n_neutral_labels", default=None) |
|
|
32 |
parser.add_argument("--segmentation_labels", type=str, dest="segmentation_labels", default=None) |
|
|
33 |
parser.add_argument("--subjects_prob", type=str, dest="subjects_prob", default=None) |
|
|
34 |
|
|
|
35 |
# output-related parameters |
|
|
36 |
parser.add_argument("--batch_size", type=int, dest="batchsize", default=1) |
|
|
37 |
parser.add_argument("--channels", type=int, dest="n_channels", default=1) |
|
|
38 |
parser.add_argument("--target_res", type=float, dest="target_res", default=None) |
|
|
39 |
parser.add_argument("--output_shape", type=int, dest="output_shape", default=None) |
|
|
40 |
|
|
|
41 |
# GMM-sampling parameters |
|
|
42 |
parser.add_argument("--generation_classes", type=str, dest="generation_classes", default=None) |
|
|
43 |
parser.add_argument("--prior_type", type=str, dest="prior_distributions", default='uniform') |
|
|
44 |
parser.add_argument("--prior_means", type=str, dest="prior_means", default=None) |
|
|
45 |
parser.add_argument("--prior_stds", type=str, dest="prior_stds", default=None) |
|
|
46 |
parser.add_argument("--specific_stats", action='store_true', dest="use_specific_stats_for_channel") |
|
|
47 |
parser.add_argument("--mix_prior_and_random", action='store_true', dest="mix_prior_and_random") |
|
|
48 |
|
|
|
49 |
# spatial deformation parameters |
|
|
50 |
parser.add_argument("--no_flipping", action='store_false', dest="flipping") |
|
|
51 |
parser.add_argument("--scaling", dest="scaling_bounds", type=infer, default=0.2) |
|
|
52 |
parser.add_argument("--rotation", dest="rotation_bounds", type=infer, default=15) |
|
|
53 |
parser.add_argument("--shearing", dest="shearing_bounds", type=infer, default=.012) |
|
|
54 |
parser.add_argument("--translation", dest="translation_bounds", type=infer, default=False) |
|
|
55 |
parser.add_argument("--nonlin_std", type=float, dest="nonlin_std", default=4.) |
|
|
56 |
parser.add_argument("--nonlin_scale", type=float, dest="nonlin_scale", default=.04) |
|
|
57 |
|
|
|
58 |
# blurring/resampling parameters |
|
|
59 |
parser.add_argument("--randomise_res", action='store_true', dest="randomise_res") |
|
|
60 |
parser.add_argument("--max_res_iso", type=float, dest="max_res_iso", default=4.) |
|
|
61 |
parser.add_argument("--max_res_aniso", type=float, dest="max_res_aniso", default=8.) |
|
|
62 |
parser.add_argument("--data_res", dest="data_res", type=infer, default=None) |
|
|
63 |
parser.add_argument("--thickness", dest="thickness", type=infer, default=None) |
|
|
64 |
|
|
|
65 |
# bias field parameters |
|
|
66 |
parser.add_argument("--bias_std", type=float, dest="bias_field_std", default=.7) |
|
|
67 |
parser.add_argument("--bias_scale", type=float, dest="bias_scale", default=.025) |
|
|
68 |
|
|
|
69 |
parser.add_argument("--gradients", action='store_true', dest="return_gradients") |
|
|
70 |
|
|
|
71 |
# -------------------------------------------- UNet architecture parameters -------------------------------------------- |
|
|
72 |
parser.add_argument("--n_levels", type=int, dest="n_levels", default=5) |
|
|
73 |
parser.add_argument("--conv_per_level", type=int, dest="nb_conv_per_level", default=2) |
|
|
74 |
parser.add_argument("--conv_size", type=int, dest="conv_size", default=3) |
|
|
75 |
parser.add_argument("--unet_feat", type=int, dest="unet_feat_count", default=24) |
|
|
76 |
parser.add_argument("--feat_mult", type=int, dest="feat_multiplier", default=2) |
|
|
77 |
parser.add_argument("--activation", type=str, dest="activation", default='elu') |
|
|
78 |
|
|
|
79 |
# ------------------------------------------------- Training parameters ------------------------------------------------ |
|
|
80 |
parser.add_argument("--lr", type=float, dest="lr", default=1e-4) |
|
|
81 |
parser.add_argument("--wl2_epochs", type=int, dest="wl2_epochs", default=1) |
|
|
82 |
parser.add_argument("--dice_epochs", type=int, dest="dice_epochs", default=50) |
|
|
83 |
parser.add_argument("--steps_per_epoch", type=int, dest="steps_per_epoch", default=10000) |
|
|
84 |
parser.add_argument("--checkpoint", type=str, dest="checkpoint", default=None) |
|
|
85 |
|
|
|
86 |
args = parser.parse_args() |
|
|
87 |
training(**vars(args)) |