Diff of /run.py [000000] .. [6f9c00]

Switch to side-by-side view

--- a
+++ b/run.py
@@ -0,0 +1,218 @@
+import argparse
+import os
+import torch
+import utils.evaluator as eu
+from quicknat import QuickNat
+from settings import Settings
+from solver import Solver
+from utils.data_utils import get_imdb_dataset
+from utils.log_utils import LogWriter
+import logging
+import shutil
+
+torch.set_default_tensor_type('torch.FloatTensor')
+
+
+def load_data(data_params):
+    print("Loading dataset")
+    train_data, test_data = get_imdb_dataset(data_params)
+    print("Train size: %i" % len(train_data))
+    print("Test size: %i" % len(test_data))
+    return train_data, test_data
+
+
+def train(train_params, common_params, data_params, net_params):
+
+    train_data, test_data = load_data(data_params)
+
+    train_loader = torch.utils.data.DataLoader(train_data, batch_size=train_params['train_batch_size'], shuffle=True,
+                                               num_workers=4, pin_memory=True)
+    val_loader = torch.utils.data.DataLoader(test_data, batch_size=train_params['val_batch_size'], shuffle=False,
+                                             num_workers=4, pin_memory=True)
+
+    net_params_ = net_params.copy()
+    empty_model = QuickNat(net_params_)
+    if train_params['use_pre_trained']:
+        quicknat_model = torch.load(train_params['pre_trained_path'])
+    else:
+        quicknat_model = QuickNat(net_params)
+
+    solver = Solver(quicknat_model,
+                    device=common_params['device'],
+                    num_class=net_params['num_class'],
+                    optim_args={"lr": train_params['learning_rate'],
+                                "betas": train_params['optim_betas'],
+                                "eps": train_params['optim_eps'],
+                                "weight_decay": train_params['optim_weight_decay']},
+                    model_name=common_params['model_name'],
+                    exp_name=train_params['exp_name'],
+                    labels=data_params['labels'],
+                    log_nth=train_params['log_nth'],
+                    num_epochs=train_params['num_epochs'],
+                    lr_scheduler_step_size=train_params['lr_scheduler_step_size'],
+                    lr_scheduler_gamma=train_params['lr_scheduler_gamma'],
+                    use_last_checkpoint=train_params['use_last_checkpoint'],
+                    log_dir=common_params['log_dir'],
+                    exp_dir=common_params['exp_dir'])
+
+    solver.train(train_loader, val_loader)
+    final_model_path = os.path.join(common_params['save_model_dir'], train_params['final_model_file'])
+    # quicknat_model.save(final_model_path)
+    solver.model = empty_model
+    solver.save_best_model(final_model_path)
+    print("final model saved @ " + str(final_model_path))
+
+
+def evaluate(eval_params, net_params, data_params, common_params, train_params):
+    eval_model_path = eval_params['eval_model_path']
+    num_classes = net_params['num_class']
+    labels = data_params['labels']
+    data_dir = eval_params['data_dir']
+    label_dir = eval_params['label_dir']
+    volumes_txt_file = eval_params['volumes_txt_file']
+    remap_config = eval_params['remap_config']
+    device = common_params['device']
+    log_dir = common_params['log_dir']
+    exp_dir = common_params['exp_dir']
+    exp_name = train_params['exp_name']
+    save_predictions_dir = eval_params['save_predictions_dir']
+    prediction_path = os.path.join(exp_dir, exp_name, save_predictions_dir)
+    orientation = eval_params['orientation']
+    data_id = eval_params['data_id']
+
+    logWriter = LogWriter(num_classes, log_dir, exp_name, labels=labels)
+
+    avg_dice_score, class_dist = eu.evaluate_dice_score(eval_model_path,
+                                                        num_classes,
+                                                        data_dir,
+                                                        label_dir,
+                                                        volumes_txt_file,
+                                                        remap_config,
+                                                        orientation,
+                                                        prediction_path,
+                                                        data_id,
+                                                        device,
+                                                        logWriter)
+    logWriter.close()
+
+
+def evaluate_bulk(eval_bulk):
+    data_dir = eval_bulk['data_dir']
+    prediction_path = eval_bulk['save_predictions_dir']
+    volumes_txt_file = eval_bulk['volumes_txt_file']
+    device = eval_bulk['device']
+    label_names = ["vol_ID", "Background", "Left WM", "Left Cortex", "Left Lateral ventricle", "Left Inf LatVentricle",
+                   "Left Cerebellum WM", "Left Cerebellum Cortex", "Left Thalamus", "Left Caudate", "Left Putamen",
+                   "Left Pallidum", "3rd Ventricle", "4th Ventricle", "Brain Stem", "Left Hippocampus", "Left Amygdala",
+                   "CSF (Cranial)", "Left Accumbens", "Left Ventral DC", "Right WM", "Right Cortex",
+                   "Right Lateral Ventricle", "Right Inf LatVentricle", "Right Cerebellum WM",
+                   "Right Cerebellum Cortex", "Right Thalamus", "Right Caudate", "Right Putamen", "Right Pallidum",
+                   "Right Hippocampus", "Right Amygdala", "Right Accumbens", "Right Ventral DC"]
+    batch_size = eval_bulk['batch_size']
+    need_unc = eval_bulk['estimate_uncertainty']
+    mc_samples = eval_bulk['mc_samples']
+    dir_struct = eval_bulk['directory_struct']
+    if 'exit_on_error' in eval_bulk.keys():
+        exit_on_error = eval_bulk['exit_on_error']
+    else:
+        exit_on_error = False
+
+    if eval_bulk['view_agg'] == 'True':
+        coronal_model_path = eval_bulk['coronal_model_path']
+        axial_model_path = eval_bulk['axial_model_path']
+        eu.evaluate2view(coronal_model_path,
+                         axial_model_path,
+                         volumes_txt_file,
+                         data_dir, device,
+                         prediction_path,
+                         batch_size,
+                         label_names,
+                         dir_struct,
+                         need_unc,
+                         mc_samples,
+                         exit_on_error=exit_on_error)
+    else:
+        coronal_model_path = eval_bulk['coronal_model_path']
+        eu.evaluate(coronal_model_path,
+                    volumes_txt_file,
+                    data_dir,
+                    device,
+                    prediction_path,
+                    batch_size,
+                    "COR",
+                    label_names,
+                    dir_struct,
+                    need_unc,
+                    mc_samples,
+                    exit_on_error=exit_on_error)
+
+def compute_vol(eval_bulk):
+    prediction_path = eval_bulk['save_predictions_dir']
+    label_names = ["vol_ID", "Background", "Left WM", "Left Cortex", "Left Lateral ventricle", "Left Inf LatVentricle",
+                   "Left Cerebellum WM", "Left Cerebellum Cortex", "Left Thalamus", "Left Caudate", "Left Putamen",
+                   "Left Pallidum", "3rd Ventricle", "4th Ventricle", "Brain Stem", "Left Hippocampus", "Left Amygdala",
+                   "CSF (Cranial)", "Left Accumbens", "Left Ventral DC", "Right WM", "Right Cortex",
+                   "Right Lateral Ventricle", "Right Inf LatVentricle", "Right Cerebellum WM",
+                   "Right Cerebellum Cortex", "Right Thalamus", "Right Caudate", "Right Putamen", "Right Pallidum",
+                   "Right Hippocampus", "Right Amygdala", "Right Accumbens", "Right Ventral DC"]
+    volumes_txt_file = eval_bulk['volumes_txt_file']
+
+    eu.compute_vol_bulk(prediction_path, "Linear", label_names, volumes_txt_file)
+
+
+
+def delete_contents(folder):
+    for the_file in os.listdir(folder):
+        file_path = os.path.join(folder, the_file)
+        try:
+            if os.path.isfile(file_path):
+                os.unlink(file_path)
+            elif os.path.isdir(file_path):
+                shutil.rmtree(file_path)
+        except Exception as e:
+            print(e)
+
+
+if __name__ == '__main__':
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--mode', '-m', required=True, help='run mode, valid values are train and eval')
+    parser.add_argument('--setting_path', '-sp', required=False, help='optional path to settings_eval.ini')
+    args = parser.parse_args()
+
+    settings = Settings('settings.ini')
+    common_params, data_params, net_params, train_params, eval_params = settings['COMMON'], settings['DATA'], \
+                                                                        settings[
+                                                                            'NETWORK'], settings['TRAINING'], \
+                                                                        settings['EVAL']
+    if args.mode == 'train':
+        train(train_params, common_params, data_params, net_params)
+    elif args.mode == 'eval':
+        evaluate(eval_params, net_params, data_params, common_params, train_params)
+    elif args.mode == 'eval_bulk':
+        logging.basicConfig(filename='error.log')
+        if args.setting_path is not None:
+            settings_eval = Settings(args.setting_path)
+        else:
+            settings_eval = Settings('settings_eval.ini')
+        evaluate_bulk(settings_eval['EVAL_BULK'])
+    elif args.mode == 'clear':
+        shutil.rmtree(os.path.join(common_params['exp_dir'], train_params['exp_name']))
+        print("Cleared current experiment directory successfully!!")
+        shutil.rmtree(os.path.join(common_params['log_dir'], train_params['exp_name']))
+        print("Cleared current log directory successfully!!")
+
+    elif args.mode == 'clear-all':
+        delete_contents(common_params['exp_dir'])
+        print("Cleared experiments directory successfully!!")
+        delete_contents(common_params['log_dir'])
+        print("Cleared logs directory successfully!!")
+
+    elif args.mode == 'compute_vol':
+        if args.setting_path is not None:
+            settings_eval = Settings(args.setting_path)
+        else:
+            settings_eval = Settings('settings_eval.ini')
+        compute_vol(settings_eval['EVAL_BULK'])
+    else:
+        raise ValueError('Invalid value for mode. only support values are train, eval and clear')