Diff of /src/config.py [000000] .. [f45789]

Switch to side-by-side view

--- a
+++ b/src/config.py
@@ -0,0 +1,191 @@
+import sys
+sys.path.append('.')
+import yaml
+import os
+from shutil import rmtree
+
+import torch
+import torchvision.transforms as transforms
+
+from src.data import get_datasets, get_dataloaders
+from src.model import initialize_model
+from src.optimizer import get_optimizer
+from src.criterion import get_criterion
+from src.scheduler import get_scheduler
+import yaml
+
+
+
+def get_transforms(conf):
+    model_name = conf['model']['name']
+    if 'efficientdet' in model_name:
+        train_transform = transforms.Compose([transforms.Resize((512,512)),
+                                    transforms.RandomVerticalFlip(p=0.5),
+                                    transforms.RandomHorizontalFlip(p=0.5),
+                                    transforms.ToTensor(),
+                                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
+                                    ])
+        valid_transform = transforms.Compose([transforms.Resize((512,512)),
+                                    transforms.ToTensor(),
+                                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
+                                    ])
+        test_transform = transforms.Compose([transforms.Resize((512,512)),
+                                    transforms.ToTensor(),
+                                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
+                                    ])
+
+    elif 'resnet' in model_name:
+        train_transform = transforms.Compose([transforms.Resize((224,224)),
+                                    transforms.RandomVerticalFlip(p=0.5),
+                                    transforms.RandomHorizontalFlip(p=0.5),
+                                    transforms.ToTensor(),
+                                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
+                                    ])
+        valid_transform = transforms.Compose([transforms.Resize((224,224)),
+                                    transforms.ToTensor(),
+                                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
+                                    ])
+        test_transform = transforms.Compose([transforms.Resize((224,224)),
+                                    transforms.ToTensor(),
+                                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
+                                    ])
+
+    return train_transform, valid_transform, test_transform
+
+
+def get_experiment_dir(conf):
+    experiments_dir = conf['experiments_dir']
+    if not os.path.exists(experiments_dir):
+        os.makedirs(experiments_dir)
+
+    experiment_dir = os.path.join(conf['experiments_dir'],
+                                  conf['experiment_code'])
+    if conf['task'] == 'training':
+        if not os.path.exists(experiment_dir):
+            os.makedirs(experiment_dir)
+            save_conf = os.path.join(experiment_dir, conf['task'] + '.yaml')
+            with open(save_conf, 'w') as fp:
+                yaml.dump(conf, fp)
+        else:
+            print(f'Experiment dir {experiment_dir} exists.')
+            exit()
+    return experiment_dir
+
+def get_test_config(conf):
+
+    # Check if results dir exists and create it
+    results_dir = conf['results_dir']
+    if os.path.exists(results_dir): rmtree(results_dir)
+    os.makedirs(results_dir)
+
+    # Get transforms
+    transforms = get_transforms(conf)
+    conf['train_transform'] = transforms[0]
+    conf['valid_transform'] = transforms[1]
+    conf['test_transform'] = transforms[2]
+
+    # Get datasets
+    dataset_name = conf['data']['name']
+    print(f'Dataset: {dataset_name}')
+    _, _, conf['test_dataset'] = get_datasets(conf)
+
+    # Check if GPU is available
+    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+    conf['device'] = device
+    print(f'Device: {device}')
+
+    # Initialize model
+    model_name = conf['model']['name']
+    print(f'Model: {model_name}')
+    conf['model'] = initialize_model(conf)
+
+    return conf
+
+
+def get_config(yaml_path):
+    # Load YAML conf
+    conf = yaml.safe_load(open(yaml_path, 'r'))
+
+    # Task
+    task = conf['task']
+    if task in ['training', 'evaluation']:
+        print(f'Task: {task}')
+
+    elif task == 'testing':
+        return get_test_config(conf)
+
+    else:
+        print(f'Task {task} not supported.')
+        exit()
+
+    # Get experiment directory
+    experiment_dir = get_experiment_dir(conf)
+    conf['experiment_dir'] = experiment_dir
+    print(f'Experiment directory: {experiment_dir}')
+
+    # Get transforms
+    transforms = get_transforms(conf)
+    conf['train_transform'] = transforms[0]
+    conf['valid_transform'] = transforms[1]
+    conf['test_transform'] = transforms[2]
+
+    # Get datasets
+    dataset_name = conf['data']['name']
+    print(f'Dataset: {dataset_name}')
+    datasets = get_datasets(conf)
+    conf['train_dataset'] = datasets[0]
+    conf['valid_dataset'] = datasets[1]
+    conf['test_dataset'] = datasets[2]
+    conf['patients_dataset'] = datasets[3]
+
+    # Get dataloaders
+    dataloaders = get_dataloaders(conf)
+    train_dataloader = dataloaders[0]
+    valid_dataloader = dataloaders[1]
+    test_dataloader = dataloaders[2]
+    conf['dataloaders'] = {'train':train_dataloader,
+                           'valid':valid_dataloader,
+                           'test': test_dataloader}
+
+    # Check if GPU is available
+    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+    conf['device'] = device
+    print(f'Device: {device}')
+
+    # Initialize model
+    model_name = conf['model']['name']
+    print(f'Model: {model_name}')
+    conf['model'] = initialize_model(conf)
+
+    # Only in traning task
+    if task == 'training':
+        # Get optimizer
+        optimizer_name = conf['optimizer']['name']
+        print(f'Optimizer: {optimizer_name}')
+        conf['optimizer'] = get_optimizer(conf)
+
+        # Get criterion
+        criterion_name = conf['criterion']['name']
+        print(f'Criterion: {criterion_name}')
+        conf['criterion'] = get_criterion(conf)
+
+        # Get scheduler
+        scheduler_name = conf['scheduler']['name']
+        print(f'Scheduler: {scheduler_name}')
+        conf['scheduler'] = get_scheduler(conf)
+
+    # Only in evaluation task
+    elif task == 'evaluation':
+        path = os.path.join(conf['experiment_dir'], 'best_weights.pt')
+        if os.path.exists(path):
+            print(f'Loading weights from {path}')
+            conf['best_weights'] = torch.load(path)
+        else:
+            print(f'Experiment weights {path} not found.')
+            exit()
+
+    return conf
+
+
+
+