# rvseg/opts.py
1
from __future__ import division, print_function
2
3
import os
4
import argparse
5
import configparser
6
import logging
7
8
# Command-line / config-file option specifications, consumed by
# parse_arguments().  Each entry is (name, spec) where spec is either
#   * a (type, default, help) tuple, expanded into the argparse keyword
#     arguments of those names, or
#   * a dict of argparse ``add_argument`` keyword arguments, used verbatim.
# Each name becomes a ``--name`` flag.  argparse stores it under the name
# with hyphens replaced by underscores, and that underscored name is the
# key a config file must use.
#
# NOTE: update_from_configfile() chooses how to parse a config-file value
# from the *type of the default*.  Float options therefore need float
# defaults (180.0, not 180); otherwise a value such as ``alpha = 250.5`` is
# read with getint() and rejected.  Options whose default is None must
# also be listed in ``noninitialized``.
definitions = [
    # model               type   default help
    ('model',            (str,   'unet', "Model: unet, dilated-unet, dilated-densenet")),
    ('features',         (int,   64,     "Number of features maps after first convolutional layer.")),
    ('depth',            (int,   4,      "Number of downsampled convolutional blocks.")),
    ('temperature',      (float, 1.0,    "Temperature of final softmax layer in model.")),
    ('padding',          (str,   'same', "Padding in convolutional layers. Either `same' or `valid'.")),
    ('dropout',          (float, 0.0,    "Rate for dropout of activation units.")),
    ('classes',          (str,   'inner', "One of `inner' (endocardium), `outer' (epicardium), or `both'.")),
    ('batchnorm',        {'default': False, 'action': 'store_true',
                          'help': "Apply batch normalization before nonlinearities."}),

    # loss
    ('loss',             (str,   'pixel', "Loss function: `pixel' for pixel-wise cross entropy, `dice' for dice coefficient.")),
    ('loss-weights',     {'type': float, 'nargs': '+', 'default': [0.1, 0.9],
                          'help': "When using dice or jaccard loss, how much to weight each output class."}),

    # training
    ('epochs',           (int,   20,     "Number of epochs to train.")),
    ('batch-size',       (int,   32,     "Mini-batch size for training.")),
    ('validation-split', (float, 0.2,    "Percentage of training data to hold out for validation.")),
    ('optimizer',        (str,   'adam', "Optimizer: sgd, rmsprop, adagrad, adadelta, adam, adamax, or nadam.")),
    # None defaults: no type can be inferred from them, so their config-file
    # getters are listed in `noninitialized`.
    ('learning-rate',    (float, None,   "Optimizer learning rate.")),
    ('momentum',         (float, None,   "Momentum for SGD optimizer.")),
    ('decay',            (float, None,   "Learning rate decay (not applicable for nadam).")),
    ('shuffle_train_val', {'default': False, 'action': 'store_true',
                           'help': "Shuffle images before splitting into train vs. val."}),
    ('shuffle',          {'default': False, 'action': 'store_true',
                          'help': "Shuffle images before each training epoch."}),
    ('seed',             (int,   None,   "Seed for numpy RandomState")),

    # files
    ('datadir',          (str,   '.',    "Directory containing patientXX/ directories.")),
    ('outdir',           (str,   '.',    "Directory to write output data.")),
    ('outfile',          (str,   'weights-final.hdf5', "File to write final model weights.")),
    ('load-weights',     (str,   '',     "Load model weights from specified file to initialize training.")),
    ('checkpoint',       {'default': False, 'action': 'store_true',
                          'help': "Write model weights after each epoch if validation accuracy improves."}),

    # augmentation
    ('augment-training', {'default': False, 'action': 'store_true',
                          'help': "Whether to apply image augmentation to training set."}),
    ('augment-validation', {'default': False, 'action': 'store_true',
                            'help': "Whether to apply image augmentation to validation set."}),
    # Float-typed options carry float defaults (see NOTE above).
    ('rotation-range',     (float, 180.0,  "Rotation range (0-180 degrees)")),
    ('width-shift-range',  (float, 0.1,    "Width shift range, as a float fraction of the width")),
    ('height-shift-range', (float, 0.1,    "Height shift range, as a float fraction of the height")),
    ('shear-range',        (float, 0.1,    "Shear intensity (in radians)")),
    # NOTE(review): declared as a single float, so only the scalar form
    # mentioned in the help text can be given on the command line.
    ('zoom-range',         (float, 0.05,   "Amount of zoom. If a scalar z, zoom in [1-z, 1+z]. Can also pass a pair of floats as the zoom range.")),
    ('fill-mode',          (str,   'nearest', "Points outside boundaries are filled according to mode: constant, nearest, reflect, or wrap")),
    ('alpha',              (float, 500.0,  "Random elastic distortion: magnitude of distortion")),
    ('sigma',              (float, 20.0,   "Random elastic distortion: length scale")),
    ('normalize', {'default': False, 'action': 'store_true',
                   'help': "Subtract mean and divide by std dev from each image."}),
]


# Options whose default is None.  update_from_configfile() cannot infer a
# type from a None default, so this table maps each such option (keyed by
# its argparse dest name) to the name of the ConfigParser getter that reads
# it from a config file.
noninitialized = dict(
    learning_rate='getfloat',
    momentum='getfloat',
    decay='getfloat',
    seed='getint',
)


def update_from_configfile(args, default, config, section, key):
    """Copy one config-file option into the parsed argument namespace.

    ``args`` is mutated in place.  The config value is applied only when it
    is non-empty and the current value on ``args`` still equals ``default``,
    so explicit command-line arguments take precedence over the config file.
    (A command-line value that happens to equal the default cannot be told
    apart from "not given" and is therefore overridden.)

    Config files store strings, so the value is converted according to the
    type of ``default``:

    * bool  -> ``config.getboolean``
    * int   -> ``config.getint``
    * float -> ``config.getfloat``
    * list  -> list of floats (used by ``loss_weights``).  Items may be
      separated by whitespace and/or commas and optionally wrapped in
      brackets, e.g. ``0.1 0.9``, ``0.1, 0.9`` or ``[0.1, 0.9]``.
    * None  -> the getter named by the module-level ``noninitialized``
      mapping for ``key``
    * str or any other type -> the raw string

    Args:
        args: ``argparse.Namespace`` to update.
        default: the parser's default value for ``key``.
        config: ``configparser.ConfigParser`` holding the file contents.
        section: section of ``config`` to read from.
        key: option name; used both as the config key and as the attribute
            name on ``args``.

    Raises:
        ValueError: if the value cannot be converted to the required type.
        KeyError: if ``default`` is None and ``key`` is not listed in
            ``noninitialized``.
    """
    raw = config.get(section, key)
    # Keys written without a value (allow_no_value) or with an empty value
    # leave the current setting untouched.
    if raw is None or raw == '':
        return

    # Command-line arguments override config file values
    if getattr(args, key) != default:
        return

    # bool must be checked before int because bool is a subclass of int.
    if isinstance(default, bool):
        value = config.getboolean(section, key)
    elif isinstance(default, int):
        value = config.getint(section, key)
    elif isinstance(default, float):
        value = config.getfloat(section, key)
    elif isinstance(default, list):
        # special case (HACK): loss_weights is a list of floats.  Accept
        # whitespace, commas and an optional surrounding [...] so that both
        # "0.1 0.9" and "[0.1, 0.9]" work.
        tokens = raw.strip().strip('[]').replace(',', ' ').split()
        value = [float(token) for token in tokens]
    elif default is None:
        # values which aren't initialized: type comes from `noninitialized`
        getter = getattr(config, noninitialized[key])
        value = getter(section, key)
    else:
        # str (or any other default type): keep the string as read
        value = raw
    setattr(args, key, value)


def parse_arguments(argv=None):
    """Build the option parser, parse the command line, and merge a config file.

    Every entry in ``definitions`` becomes a ``--name`` option.  An optional
    positional ``configfile`` may name an INI-style file.  Its keys (argparse
    dest names, e.g. ``batch_size``) fill in any option that the command line
    left at its default.  The final option values are logged at INFO level.

    Args:
        argv: argument list to parse.  ``None`` (the default) makes argparse
            parse ``sys.argv[1:]``, as before.

    Returns:
        argparse.Namespace holding the final option values.

    Raises:
        IOError: if ``configfile`` is given but cannot be read.
        ValueError: if the config file contains a key that is not a known
            option.
    """
    parser = argparse.ArgumentParser(
        description="Train U-Net to segment right ventricles from cardiac "
        "MRI images.")

    for argname, spec in definitions:
        # Tuple specs are shorthand for (type, default, help).
        if isinstance(spec, tuple):
            spec = dict(zip(['type', 'default', 'help'], spec))
        parser.add_argument('--' + argname, **spec)

    # allow user to input configuration file
    parser.add_argument(
        'configfile', nargs='?', type=str, help="Load options from config "
        "file (command line arguments take precedence).")

    args = parser.parse_args(argv)

    if args.configfile:
        logging.info("Loading options from config file: %s", args.configfile)
        config = configparser.ConfigParser(
            inline_comment_prefixes=['#', ';'], allow_no_value=True)
        # ConfigParser.read() silently skips files it cannot open and returns
        # the list of files it did read.  Fail loudly instead of training
        # with unintended defaults because of a mistyped path.
        if not config.read(args.configfile):
            raise IOError(
                "Could not read config file: {}".format(args.configfile))
        for section in config:
            for key in config[section]:
                if key not in args:
                    raise ValueError(
                        "Unknown option {} in config file.".format(key))
                update_from_configfile(args, parser.get_default(key),
                                       config, section, key)

    for name, value in vars(args).items():
        logging.info("%-20s = %s", name, value)

    return args