--- a +++ b/run.py @@ -0,0 +1,152 @@ +#!/usr/bin/env python + +import argparse +import json +import os +import sys +from importlib.machinery import SourceFileLoader +from typing import Tuple + +import tensorflow as tf + +from utils.Evaluation import evaluate, determine_threshold_on_labeled_patients +from utils.default_config_setup import get_config, get_options, get_datasets, Dataset + +base_path = os.path.dirname(os.path.abspath(__file__)) + + +def main(args): + # reset default graph + tf.reset_default_graph() + base_path_trainer = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'trainers', f'{args.trainer}.py') + base_path_network = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'models', f'{args.model}.py') + trainer = getattr(SourceFileLoader(args.trainer, base_path_trainer).load_module(), args.trainer) + network = getattr(SourceFileLoader(args.model, base_path_network).load_module(), args.model) + + with open(os.path.join(base_path, args.config), 'r') as f: + json_config = json.load(f) + + dataset = Dataset.BRAINWEB + options = get_options(batchsize=args.batchsize, learningrate=args.lr, numEpochs=args.numEpochs, zDim=args.zDim, outputWidth=args.outputWidth, + outputHeight=args.outputHeight, slices_start=args.slices_start, slices_end=args.slices_end, + numMonteCarloSamples=args.numMonteCarloSamples, config=json_config) + options['data']['dir'] = options["globals"][dataset.value] + dataset_hc, dataset_pc = get_datasets(options, dataset=dataset) + config = get_config( + trainer=trainer, + options=options, + optimizer=args.optimizer, + intermediateResolutions=args.intermediateResolutions, + dropout_rate=0.2, + dataset=dataset_hc + ) + + # handle additional Config parameters e.g. for GMVAE + for arg in vars(args): + if hasattr(config, arg): + setattr(config, arg, getattr(args, arg)) + + # Create an instance of the model and train it + model = trainer(tf.Session(), config, network=network) + + # Train it + model.train(dataset_hc) + + ######################## + # Evaluate best dice # + ######################### + if not args.threshold: + # if no threshold is given but a dataset => Best dice evaluation on specific dataset + if args.ds: + # evaluate specific dataset + evaluate_optimal(model, options, args.ds) + return + else: + # evaluate all datasets for best dice without hyper intensity prior + options['applyHyperIntensityPrior'] = False + evaluate_optimal(model, options, Dataset.Brainweb) + evaluate_optimal(model, options, Dataset.MSLUB) + evaluate_optimal(model, options, Dataset.MSISBI2015) + + # evaluate all datasets for best dice without hyper intensity prior + options['applyHyperIntensityPrior'] = True + evaluate_optimal(model, options, Dataset.Brainweb) + evaluate_optimal(model, options, Dataset.MSLUB) + evaluate_optimal(model, options, Dataset.MSISBI2015) + + ############################################### + # Evaluate generalization to other datasets # + ############################################### + if args.threshold and args.ds: # only threshold is invalid + evaluate_with_threshold(model, options, args.threshold, args.ds) + else: + options['applyHyperIntensityPrior'] = False + datasetBrainweb = get_evaluation_dataset(options, Dataset.Brainweb) + _bestDiceVAL, _threshVAL = determine_threshold_on_labeled_patients([datasetBrainweb], model, options, description='VAL') + + print(f"Optimal threshold on MS Lesion Validation Set without optimal postprocessing: {_threshVAL} (Dice-Score {_bestDiceVAL})") + + # Re-evaluate with the previously determined threshold + evaluate_with_threshold(model, options, _threshVAL, Dataset.Brainweb) + evaluate_with_threshold(model, options, _threshVAL, Dataset.MSLUB) + evaluate_with_threshold(model, options, _threshVAL, Dataset.MSISBI2015) + + +def evaluate_with_threshold(model, options, threshold, dataset): + options['applyHyperIntensityPrior'] = False + options['threshold'] = threshold + description = lambda ds: f'{type(ds).__name__}-VALthresh_{options["threshold"]}' + evaluation_dataset = get_evaluation_dataset(options, dataset) + evaluate(evaluation_dataset, model, options, description=description(evaluation_dataset), epoch=str(options['train']['numEpochs'])) + + +def evaluate_optimal(model, options, dataset): + hyper_intensity_prior_str = '' + if options['applyHyperIntensityPrior']: + hyper_intensity_prior_str = "_wPrior" + evaluation_dataset = get_evaluation_dataset(options, dataset) + epochs = str(options['train']['numEpochs']) + description = f'{type(evaluation_dataset).__name__}_upperbound_{options["threshold"]}{hyper_intensity_prior_str}' + # Evaluate + evaluate(evaluation_dataset, model, options, description=description, epoch=epochs) + + +def get_evaluation_dataset(options, dataset=Dataset.BRAINWEB): + options['data']['dir'] = options["globals"][dataset.value] + return get_datasets(options, dataset=dataset)[1] + + +if __name__ == '__main__': + args = argparse.ArgumentParser(description='Framework') + args.print_help(sys.stderr) + args.add_argument('-c', '--config', default='config.default.json', type=str, help='config-path') + args.add_argument('-b', '--batchsize', default=8, type=int, help='batchsize') + args.add_argument('-l', '--lr', default=0.0001, type=float, help='learning rate') + args.add_argument('-E', '--numEpochs', default=1000, type=int, help='how many epochs to train') + args.add_argument('-z', '--zDim', default=128, type=int, help='Latent dimension') + args.add_argument('-w', '--outputWidth', default=128, type=int, help='Output width') + args.add_argument('-g', '--outputHeight', default=128, type=int, help='Output height') + args.add_argument('-o', '--optimizer', default='ADAM', type=str, help='Can be either ADAM, SGD or RMSProp') + args.add_argument('-i', '--intermediateResolutions', default=(8, 8), type=Tuple[int], help='Spatial Bottleneck resolution') + args.add_argument('-s', '--slices_start', default=20, type=int, help='slices start') + args.add_argument('-e', '--slices_end', default=130, type=int, help='slices end') + args.add_argument('-t', '--trainer', default='AE', type=str, help='Can be every class from trainers directory') + args.add_argument('-m', '--model', default='autoencoder', type=str, help='Can be every class from models directory') + args.add_argument('-O', '--threshold', default=None, type=float, help='Use predefined ThreshOld') + args.add_argument('-d', '--ds', default=None, type=Dataset, help='Only evaluate on given dataset') + + # following arguments are only relevant for specific architectures + args.add_argument('-n', '--numMonteCarloSamples', default=0, type=int, help='Amount of Monte Carlos Samples during restoration') + args.add_argument('-G', '--use_gradient_based_restoration', default=False, type=bool, help='only for ceVAE') + args.add_argument('-K', '--kappa', default=1.0, type=float, help='only for GANs') + args.add_argument('-M', '--scale', default=10.0, type=float, help='only for GANs') + args.add_argument('-R', '--rho', default=1.0, type=float, help='only for ConstrainedAAE') + args.add_argument('-C', '--dim_c', default=9, type=int, help='only for GMVAE') + args.add_argument('-Z', '--dim_z', default=128, type=int, help='only for GMVAE') + args.add_argument('-W', '--dim_w', default=1, type=int, help='only for GMVAE') + args.add_argument('-A', '--c_lambda', default=1, type=int, help='only for GMVAE') + args.add_argument('-L', '--restore_lr', default=1e-3, type=float, help='only for GMVAE') + args.add_argument('-S', '--restore_steps', default=150, type=int, help='only for GMVAE') + args.add_argument('-T', '--tv_lambda', default=-1.0, type=float, help='only for GMVAE') + + main(args.parse_args())