|
a |
|
b/rvseg/opts.py |
|
|
1 |
from __future__ import division, print_function |
|
|
2 |
|
|
|
3 |
import os |
|
|
4 |
import argparse |
|
|
5 |
import configparser |
|
|
6 |
import logging |
|
|
7 |
|
|
|
8 |
# Command-line option definitions, consumed by parse_arguments().
#
# Each entry is (name, spec). A tuple spec is shorthand for
# (type, default, help); a dict spec is passed to
# ArgumentParser.add_argument() unchanged. Every option is exposed as
# `--<name>`; argparse stores it under its dest name (hyphens become
# underscores), which is also the key a config file must use.
#
# Defaults of float-typed options are deliberately written as float
# literals: update_from_configfile() chooses the config-file converter from
# the type of the default, so an int default (e.g. 180) would make an entry
# like `rotation_range = 0.5` fail inside ConfigParser.getint().
definitions = [
    # model
    ('model', (str, 'unet', "Model: unet, dilated-unet, dilated-densenet")),
    ('features', (int, 64, "Number of features maps after first convolutional layer.")),
    ('depth', (int, 4, "Number of downsampled convolutional blocks.")),
    ('temperature', (float, 1.0, "Temperature of final softmax layer in model.")),
    ('padding', (str, 'same', "Padding in convolutional layers. Either `same' or `valid'.")),
    ('dropout', (float, 0.0, "Rate for dropout of activation units.")),
    ('classes', (str, 'inner', "One of `inner' (endocardium), `outer' (epicardium), or `both'.")),
    ('batchnorm', {'default': False, 'action': 'store_true',
                   'help': "Apply batch normalization before nonlinearities."}),

    # loss
    ('loss', (str, 'pixel', "Loss function: `pixel' for pixel-wise cross entropy, `dice' for dice coefficient.")),
    # List-valued (nargs='+'); a config file may give several floats.
    ('loss-weights', {'type': float, 'nargs': '+', 'default': [0.1, 0.9],
                      'help': "When using dice or jaccard loss, how much to weight each output class."}),

    # training
    ('epochs', (int, 20, "Number of epochs to train.")),
    ('batch-size', (int, 32, "Mini-batch size for training.")),
    ('validation-split', (float, 0.2, "Percentage of training data to hold out for validation.")),
    ('optimizer', (str, 'adam', "Optimizer: sgd, rmsprop, adagrad, adadelta, adam, adamax, or nadam.")),
    # None defaults carry no type; their config-file getters are looked up
    # in `noninitialized` instead.
    ('learning-rate', (float, None, "Optimizer learning rate.")),
    ('momentum', (float, None, "Momentum for SGD optimizer.")),
    ('decay', (float, None, "Learning rate decay (not applicable for nadam).")),
    # NOTE: spelled with underscores, so the flag is `--shuffle_train_val`.
    ('shuffle_train_val', {'default': False, 'action': 'store_true',
                           'help': "Shuffle images before splitting into train vs. val."}),
    ('shuffle', {'default': False, 'action': 'store_true',
                 'help': "Shuffle images before each training epoch."}),
    ('seed', (int, None, "Seed for numpy RandomState")),

    # files
    ('datadir', (str, '.', "Directory containing patientXX/ directories.")),
    ('outdir', (str, '.', "Directory to write output data.")),
    ('outfile', (str, 'weights-final.hdf5', "File to write final model weights.")),
    ('load-weights', (str, '', "Load model weights from specified file to initialize training.")),
    ('checkpoint', {'default': False, 'action': 'store_true',
                    'help': "Write model weights after each epoch if validation accuracy improves."}),

    # augmentation
    ('augment-training', {'default': False, 'action': 'store_true',
                          'help': "Whether to apply image augmentation to training set."}),
    ('augment-validation', {'default': False, 'action': 'store_true',
                            'help': "Whether to apply image augmentation to validation set."}),
    ('rotation-range', (float, 180.0, "Rotation range (0-180 degrees)")),
    ('width-shift-range', (float, 0.1, "Width shift range, as a float fraction of the width")),
    ('height-shift-range', (float, 0.1, "Height shift range, as a float fraction of the height")),
    ('shear-range', (float, 0.1, "Shear intensity (in radians)")),
    # NOTE(review): declared as a single float, so the "pair of floats" form
    # mentioned in the help cannot actually be given through this option.
    ('zoom-range', (float, 0.05, "Amount of zoom. If a scalar z, zoom in [1-z, 1+z]. Can also pass a pair of floats as the zoom range.")),
    ('fill-mode', (str, 'nearest', "Points outside boundaries are filled according to mode: constant, nearest, reflect, or wrap")),
    ('alpha', (float, 500.0, "Random elastic distortion: magnitude of distortion")),
    ('sigma', (float, 20.0, "Random elastic distortion: length scale")),
    ('normalize', {'default': False, 'action': 'store_true',
                   'help': "Subtract mean and divide by std dev from each image."}),
]
|
|
63 |
|
|
|
64 |
# Options whose parser default is None. With no default value to infer a
# type from, update_from_configfile() looks up here the name of the
# ConfigParser getter used to convert the config-file string.
# Keys are argparse dest names (underscores, not hyphens).
noninitialized = dict(
    learning_rate='getfloat',
    momentum='getfloat',
    decay='getfloat',
    seed='getint',
)
|
|
70 |
|
|
|
71 |
def update_from_configfile(args, default, config, section, key):
    """Copy one config-file option into the parsed-arguments namespace.

    The value is written with ``setattr(args, key, value)`` only when the
    config entry is non-empty and the current value in ``args`` still equals
    the parser default, so command-line arguments take precedence. Config
    files store every value as a string; it is converted according to the
    type of ``default``.

    Args:
        args: argparse.Namespace to update in place.
        default: The parser default for ``key``. Its type selects the
            conversion. A ``None`` default is resolved through the
            module-level ``noninitialized`` table, and falls back to the raw
            string if the key is not listed there.
        config: configparser.ConfigParser holding the file contents.
        section: Section of ``config`` to read from.
        key: Option name; must match an attribute of ``args``.

    Raises:
        ValueError: If the config value cannot be converted to the
            required type.
    """
    raw = config.get(section, key)
    if raw is None or raw == '':
        # `key =` (empty) or a bare `key` (allow_no_value) leaves args as-is.
        return

    # Command-line arguments override config file values.
    if getattr(args, key) != default:
        return

    # bool must be tested before int because bool is a subclass of int.
    if isinstance(default, bool):
        value = config.getboolean(section, key)
    elif isinstance(default, int):
        value = config.getint(section, key)
    elif isinstance(default, float):
        value = config.getfloat(section, key)
    elif isinstance(default, list):
        # Special case: list-valued options (loss-weights) hold floats.
        # Accept whitespace- and/or comma-separated entries.
        value = [float(x) for x in raw.replace(',', ' ').split()]
    elif default is None and key in noninitialized:
        # No default to infer the type from; use the registered getter.
        value = getattr(config, noninitialized[key])(section, key)
    else:
        # str defaults, and None-default options without a registered getter
        # (e.g. the positional `configfile`), keep the raw string.
        value = raw
    setattr(args, key, value)
|
|
99 |
|
|
|
100 |
def parse_arguments():
    """Parse command-line options, optionally filled in from a config file.

    Every entry in the module-level ``definitions`` list becomes a
    ``--<name>`` option. An optional positional ``configfile`` names an
    INI-style file. For each entry in it, update_from_configfile() sets the
    option unless the user already changed it on the command line. The
    final option values are logged at INFO level.

    Returns:
        argparse.Namespace with one attribute per option.

    Raises:
        Exception: If the config file contains an option the parser does
            not define.
        SystemExit: Via ``parser.error`` if the config file cannot be read.
    """
    parser = argparse.ArgumentParser(
        description="Train U-Net to segment right ventricles from cardiac "
                    "MRI images.")

    for argname, spec in definitions:
        # A tuple is shorthand for (type, default, help); a dict is passed
        # to add_argument unchanged.
        if isinstance(spec, tuple):
            spec = dict(zip(['type', 'default', 'help'], spec))
        parser.add_argument('--' + argname, **spec)

    # allow user to input configuration file
    parser.add_argument(
        'configfile', nargs='?', type=str, help="Load options from config "
        "file (command line arguments take precedence).")

    args = parser.parse_args()

    if args.configfile:
        logging.info("Loading options from config file: {}".format(args.configfile))
        config = configparser.ConfigParser(
            inline_comment_prefixes=['#', ';'], allow_no_value=True)
        # ConfigParser.read() silently skips files it cannot open. Without
        # this check a mistyped path would run with defaults and no warning.
        if not config.read(args.configfile):
            parser.error("could not read config file: {}".format(args.configfile))
        for section in config:
            for key in config[section]:
                if key not in args:
                    raise Exception("Unknown option {} in config file.".format(key))
                update_from_configfile(args, parser.get_default(key),
                                       config, section, key)

    for name, value in vars(args).items():
        logging.info("{:20s} = {}".format(name, value))

    return args