Diff of /run.py [000000] .. [978658]

Switch to side-by-side view

--- a
+++ b/run.py
@@ -0,0 +1,152 @@
+#!/usr/bin/env python
+
+import argparse
+import json
+import os
+import sys
+from importlib.machinery import SourceFileLoader
+from typing import Tuple
+
+import tensorflow as tf
+
+from utils.Evaluation import evaluate, determine_threshold_on_labeled_patients
+from utils.default_config_setup import get_config, get_options, get_datasets, Dataset
+
+base_path = os.path.dirname(os.path.abspath(__file__))
+
+
+def main(args):
+    # reset default graph
+    tf.reset_default_graph()
+    base_path_trainer = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'trainers', f'{args.trainer}.py')
+    base_path_network = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'models', f'{args.model}.py')
+    trainer = getattr(SourceFileLoader(args.trainer, base_path_trainer).load_module(), args.trainer)
+    network = getattr(SourceFileLoader(args.model, base_path_network).load_module(), args.model)
+
+    with open(os.path.join(base_path, args.config), 'r') as f:
+        json_config = json.load(f)
+
+    dataset = Dataset.BRAINWEB
+    options = get_options(batchsize=args.batchsize, learningrate=args.lr, numEpochs=args.numEpochs, zDim=args.zDim, outputWidth=args.outputWidth,
+                          outputHeight=args.outputHeight, slices_start=args.slices_start, slices_end=args.slices_end,
+                          numMonteCarloSamples=args.numMonteCarloSamples, config=json_config)
+    options['data']['dir'] = options["globals"][dataset.value]
+    dataset_hc, dataset_pc = get_datasets(options, dataset=dataset)
+    config = get_config(
+        trainer=trainer,
+        options=options,
+        optimizer=args.optimizer,
+        intermediateResolutions=args.intermediateResolutions,
+        dropout_rate=0.2,
+        dataset=dataset_hc
+    )
+
+    # handle additional Config parameters e.g. for GMVAE
+    for arg in vars(args):
+        if hasattr(config, arg):
+            setattr(config, arg, getattr(args, arg))
+
+    # Create an instance of the model and train it
+    model = trainer(tf.Session(), config, network=network)
+
+    # Train it
+    model.train(dataset_hc)
+
+    ########################
+    #  Evaluate best dice  #
+    #########################
+    if not args.threshold:
+        # if no threshold is given but a dataset => Best dice evaluation on specific dataset
+        if args.ds:
+            # evaluate specific dataset
+            evaluate_optimal(model, options, args.ds)
+            return
+        else:
+            # evaluate all datasets for best dice without hyper intensity prior
+            options['applyHyperIntensityPrior'] = False
+            evaluate_optimal(model, options, Dataset.Brainweb)
+            evaluate_optimal(model, options, Dataset.MSLUB)
+            evaluate_optimal(model, options, Dataset.MSISBI2015)
+
+            # evaluate all datasets for best dice without hyper intensity prior
+            options['applyHyperIntensityPrior'] = True
+            evaluate_optimal(model, options, Dataset.Brainweb)
+            evaluate_optimal(model, options, Dataset.MSLUB)
+            evaluate_optimal(model, options, Dataset.MSISBI2015)
+
+    ###############################################
+    #  Evaluate generalization to other datasets  #
+    ###############################################
+    if args.threshold and args.ds:  # only threshold is invalid
+        evaluate_with_threshold(model, options, args.threshold, args.ds)
+    else:
+        options['applyHyperIntensityPrior'] = False
+        datasetBrainweb = get_evaluation_dataset(options, Dataset.Brainweb)
+        _bestDiceVAL, _threshVAL = determine_threshold_on_labeled_patients([datasetBrainweb], model, options, description='VAL')
+
+        print(f"Optimal threshold on MS Lesion Validation Set without optimal postprocessing: {_threshVAL} (Dice-Score {_bestDiceVAL})")
+
+        # Re-evaluate with the previously determined threshold
+        evaluate_with_threshold(model, options, _threshVAL, Dataset.Brainweb)
+        evaluate_with_threshold(model, options, _threshVAL, Dataset.MSLUB)
+        evaluate_with_threshold(model, options, _threshVAL, Dataset.MSISBI2015)
+
+
+def evaluate_with_threshold(model, options, threshold, dataset):
+    options['applyHyperIntensityPrior'] = False
+    options['threshold'] = threshold
+    description = lambda ds: f'{type(ds).__name__}-VALthresh_{options["threshold"]}'
+    evaluation_dataset = get_evaluation_dataset(options, dataset)
+    evaluate(evaluation_dataset, model, options, description=description(evaluation_dataset), epoch=str(options['train']['numEpochs']))
+
+
+def evaluate_optimal(model, options, dataset):
+    hyper_intensity_prior_str = ''
+    if options['applyHyperIntensityPrior']:
+        hyper_intensity_prior_str = "_wPrior"
+    evaluation_dataset = get_evaluation_dataset(options, dataset)
+    epochs = str(options['train']['numEpochs'])
+    description = f'{type(evaluation_dataset).__name__}_upperbound_{options["threshold"]}{hyper_intensity_prior_str}'
+    # Evaluate
+    evaluate(evaluation_dataset, model, options, description=description, epoch=epochs)
+
+
+def get_evaluation_dataset(options, dataset=Dataset.BRAINWEB):
+    options['data']['dir'] = options["globals"][dataset.value]
+    return get_datasets(options, dataset=dataset)[1]
+
+
+if __name__ == '__main__':
+    args = argparse.ArgumentParser(description='Framework')
+    args.print_help(sys.stderr)
+    args.add_argument('-c', '--config', default='config.default.json', type=str, help='config-path')
+    args.add_argument('-b', '--batchsize', default=8, type=int, help='batchsize')
+    args.add_argument('-l', '--lr', default=0.0001, type=float, help='learning rate')
+    args.add_argument('-E', '--numEpochs', default=1000, type=int, help='how many epochs to train')
+    args.add_argument('-z', '--zDim', default=128, type=int, help='Latent dimension')
+    args.add_argument('-w', '--outputWidth', default=128, type=int, help='Output width')
+    args.add_argument('-g', '--outputHeight', default=128, type=int, help='Output height')
+    args.add_argument('-o', '--optimizer', default='ADAM', type=str, help='Can be either ADAM, SGD or RMSProp')
+    args.add_argument('-i', '--intermediateResolutions', default=(8, 8), type=Tuple[int], help='Spatial Bottleneck resolution')
+    args.add_argument('-s', '--slices_start', default=20, type=int, help='slices start')
+    args.add_argument('-e', '--slices_end', default=130, type=int, help='slices end')
+    args.add_argument('-t', '--trainer', default='AE', type=str, help='Can be every class from trainers directory')
+    args.add_argument('-m', '--model', default='autoencoder', type=str, help='Can be every class from models directory')
+    args.add_argument('-O', '--threshold', default=None, type=float, help='Use predefined ThreshOld')
+    args.add_argument('-d', '--ds', default=None, type=Dataset, help='Only evaluate on given dataset')
+
+    # following arguments are only relevant for specific architectures
+    args.add_argument('-n', '--numMonteCarloSamples', default=0, type=int, help='Amount of Monte Carlos Samples during restoration')
+    args.add_argument('-G', '--use_gradient_based_restoration', default=False, type=bool, help='only for ceVAE')
+    args.add_argument('-K', '--kappa', default=1.0, type=float, help='only for GANs')
+    args.add_argument('-M', '--scale', default=10.0, type=float, help='only for GANs')
+    args.add_argument('-R', '--rho', default=1.0, type=float, help='only for ConstrainedAAE')
+    args.add_argument('-C', '--dim_c', default=9, type=int, help='only for GMVAE')
+    args.add_argument('-Z', '--dim_z', default=128, type=int, help='only for GMVAE')
+    args.add_argument('-W', '--dim_w', default=1, type=int, help='only for GMVAE')
+    args.add_argument('-A', '--c_lambda', default=1, type=int, help='only for GMVAE')
+    args.add_argument('-L', '--restore_lr', default=1e-3, type=float, help='only for GMVAE')
+    args.add_argument('-S', '--restore_steps', default=150, type=int, help='only for GMVAE')
+    args.add_argument('-T', '--tv_lambda', default=-1.0, type=float, help='only for GMVAE')
+
+    main(args.parse_args())