Diff of /run.py [000000] .. [978658]

Switch to unified view

a b/run.py
1
#!/usr/bin/env python
2
3
import argparse
4
import json
5
import os
6
import sys
7
from importlib.machinery import SourceFileLoader
8
from typing import Tuple
9
10
import tensorflow as tf
11
12
from utils.Evaluation import evaluate, determine_threshold_on_labeled_patients
13
from utils.default_config_setup import get_config, get_options, get_datasets, Dataset
14
15
base_path = os.path.dirname(os.path.abspath(__file__))
16
17
18
def main(args):
19
    # reset default graph
20
    tf.reset_default_graph()
21
    base_path_trainer = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'trainers', f'{args.trainer}.py')
22
    base_path_network = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'models', f'{args.model}.py')
23
    trainer = getattr(SourceFileLoader(args.trainer, base_path_trainer).load_module(), args.trainer)
24
    network = getattr(SourceFileLoader(args.model, base_path_network).load_module(), args.model)
25
26
    with open(os.path.join(base_path, args.config), 'r') as f:
27
        json_config = json.load(f)
28
29
    dataset = Dataset.BRAINWEB
30
    options = get_options(batchsize=args.batchsize, learningrate=args.lr, numEpochs=args.numEpochs, zDim=args.zDim, outputWidth=args.outputWidth,
31
                          outputHeight=args.outputHeight, slices_start=args.slices_start, slices_end=args.slices_end,
32
                          numMonteCarloSamples=args.numMonteCarloSamples, config=json_config)
33
    options['data']['dir'] = options["globals"][dataset.value]
34
    dataset_hc, dataset_pc = get_datasets(options, dataset=dataset)
35
    config = get_config(
36
        trainer=trainer,
37
        options=options,
38
        optimizer=args.optimizer,
39
        intermediateResolutions=args.intermediateResolutions,
40
        dropout_rate=0.2,
41
        dataset=dataset_hc
42
    )
43
44
    # handle additional Config parameters e.g. for GMVAE
45
    for arg in vars(args):
46
        if hasattr(config, arg):
47
            setattr(config, arg, getattr(args, arg))
48
49
    # Create an instance of the model and train it
50
    model = trainer(tf.Session(), config, network=network)
51
52
    # Train it
53
    model.train(dataset_hc)
54
55
    ########################
56
    #  Evaluate best dice  #
57
    #########################
58
    if not args.threshold:
59
        # if no threshold is given but a dataset => Best dice evaluation on specific dataset
60
        if args.ds:
61
            # evaluate specific dataset
62
            evaluate_optimal(model, options, args.ds)
63
            return
64
        else:
65
            # evaluate all datasets for best dice without hyper intensity prior
66
            options['applyHyperIntensityPrior'] = False
67
            evaluate_optimal(model, options, Dataset.Brainweb)
68
            evaluate_optimal(model, options, Dataset.MSLUB)
69
            evaluate_optimal(model, options, Dataset.MSISBI2015)
70
71
            # evaluate all datasets for best dice without hyper intensity prior
72
            options['applyHyperIntensityPrior'] = True
73
            evaluate_optimal(model, options, Dataset.Brainweb)
74
            evaluate_optimal(model, options, Dataset.MSLUB)
75
            evaluate_optimal(model, options, Dataset.MSISBI2015)
76
77
    ###############################################
78
    #  Evaluate generalization to other datasets  #
79
    ###############################################
80
    if args.threshold and args.ds:  # only threshold is invalid
81
        evaluate_with_threshold(model, options, args.threshold, args.ds)
82
    else:
83
        options['applyHyperIntensityPrior'] = False
84
        datasetBrainweb = get_evaluation_dataset(options, Dataset.Brainweb)
85
        _bestDiceVAL, _threshVAL = determine_threshold_on_labeled_patients([datasetBrainweb], model, options, description='VAL')
86
87
        print(f"Optimal threshold on MS Lesion Validation Set without optimal postprocessing: {_threshVAL} (Dice-Score {_bestDiceVAL})")
88
89
        # Re-evaluate with the previously determined threshold
90
        evaluate_with_threshold(model, options, _threshVAL, Dataset.Brainweb)
91
        evaluate_with_threshold(model, options, _threshVAL, Dataset.MSLUB)
92
        evaluate_with_threshold(model, options, _threshVAL, Dataset.MSISBI2015)
93
94
95
def evaluate_with_threshold(model, options, threshold, dataset):
96
    options['applyHyperIntensityPrior'] = False
97
    options['threshold'] = threshold
98
    description = lambda ds: f'{type(ds).__name__}-VALthresh_{options["threshold"]}'
99
    evaluation_dataset = get_evaluation_dataset(options, dataset)
100
    evaluate(evaluation_dataset, model, options, description=description(evaluation_dataset), epoch=str(options['train']['numEpochs']))
101
102
103
def evaluate_optimal(model, options, dataset):
104
    hyper_intensity_prior_str = ''
105
    if options['applyHyperIntensityPrior']:
106
        hyper_intensity_prior_str = "_wPrior"
107
    evaluation_dataset = get_evaluation_dataset(options, dataset)
108
    epochs = str(options['train']['numEpochs'])
109
    description = f'{type(evaluation_dataset).__name__}_upperbound_{options["threshold"]}{hyper_intensity_prior_str}'
110
    # Evaluate
111
    evaluate(evaluation_dataset, model, options, description=description, epoch=epochs)
112
113
114
def get_evaluation_dataset(options, dataset=Dataset.BRAINWEB):
115
    options['data']['dir'] = options["globals"][dataset.value]
116
    return get_datasets(options, dataset=dataset)[1]
117
118
119
if __name__ == '__main__':
120
    args = argparse.ArgumentParser(description='Framework')
121
    args.print_help(sys.stderr)
122
    args.add_argument('-c', '--config', default='config.default.json', type=str, help='config-path')
123
    args.add_argument('-b', '--batchsize', default=8, type=int, help='batchsize')
124
    args.add_argument('-l', '--lr', default=0.0001, type=float, help='learning rate')
125
    args.add_argument('-E', '--numEpochs', default=1000, type=int, help='how many epochs to train')
126
    args.add_argument('-z', '--zDim', default=128, type=int, help='Latent dimension')
127
    args.add_argument('-w', '--outputWidth', default=128, type=int, help='Output width')
128
    args.add_argument('-g', '--outputHeight', default=128, type=int, help='Output height')
129
    args.add_argument('-o', '--optimizer', default='ADAM', type=str, help='Can be either ADAM, SGD or RMSProp')
130
    args.add_argument('-i', '--intermediateResolutions', default=(8, 8), type=Tuple[int], help='Spatial Bottleneck resolution')
131
    args.add_argument('-s', '--slices_start', default=20, type=int, help='slices start')
132
    args.add_argument('-e', '--slices_end', default=130, type=int, help='slices end')
133
    args.add_argument('-t', '--trainer', default='AE', type=str, help='Can be every class from trainers directory')
134
    args.add_argument('-m', '--model', default='autoencoder', type=str, help='Can be every class from models directory')
135
    args.add_argument('-O', '--threshold', default=None, type=float, help='Use predefined ThreshOld')
136
    args.add_argument('-d', '--ds', default=None, type=Dataset, help='Only evaluate on given dataset')
137
138
    # following arguments are only relevant for specific architectures
139
    args.add_argument('-n', '--numMonteCarloSamples', default=0, type=int, help='Amount of Monte Carlos Samples during restoration')
140
    args.add_argument('-G', '--use_gradient_based_restoration', default=False, type=bool, help='only for ceVAE')
141
    args.add_argument('-K', '--kappa', default=1.0, type=float, help='only for GANs')
142
    args.add_argument('-M', '--scale', default=10.0, type=float, help='only for GANs')
143
    args.add_argument('-R', '--rho', default=1.0, type=float, help='only for ConstrainedAAE')
144
    args.add_argument('-C', '--dim_c', default=9, type=int, help='only for GMVAE')
145
    args.add_argument('-Z', '--dim_z', default=128, type=int, help='only for GMVAE')
146
    args.add_argument('-W', '--dim_w', default=1, type=int, help='only for GMVAE')
147
    args.add_argument('-A', '--c_lambda', default=1, type=int, help='only for GMVAE')
148
    args.add_argument('-L', '--restore_lr', default=1e-3, type=float, help='only for GMVAE')
149
    args.add_argument('-S', '--restore_steps', default=150, type=int, help='only for GMVAE')
150
    args.add_argument('-T', '--tv_lambda', default=-1.0, type=float, help='only for GMVAE')
151
152
    main(args.parse_args())