Diff of /src/config.py [000000] .. [f45789]

Switch to unified view

a b/src/config.py
1
import sys
2
sys.path.append('.')
3
import yaml
4
import os
5
from shutil import rmtree
6
7
import torch
8
import torchvision.transforms as transforms
9
10
from src.data import get_datasets, get_dataloaders
11
from src.model import initialize_model
12
from src.optimizer import get_optimizer
13
from src.criterion import get_criterion
14
from src.scheduler import get_scheduler
15
import yaml
16
17
18
19
def get_transforms(conf):
20
    model_name = conf['model']['name']
21
    if 'efficientdet' in model_name:
22
        train_transform = transforms.Compose([transforms.Resize((512,512)),
23
                                    transforms.RandomVerticalFlip(p=0.5),
24
                                    transforms.RandomHorizontalFlip(p=0.5),
25
                                    transforms.ToTensor(),
26
                                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
27
                                    ])
28
        valid_transform = transforms.Compose([transforms.Resize((512,512)),
29
                                    transforms.ToTensor(),
30
                                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
31
                                    ])
32
        test_transform = transforms.Compose([transforms.Resize((512,512)),
33
                                    transforms.ToTensor(),
34
                                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
35
                                    ])
36
37
    elif 'resnet' in model_name:
38
        train_transform = transforms.Compose([transforms.Resize((224,224)),
39
                                    transforms.RandomVerticalFlip(p=0.5),
40
                                    transforms.RandomHorizontalFlip(p=0.5),
41
                                    transforms.ToTensor(),
42
                                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
43
                                    ])
44
        valid_transform = transforms.Compose([transforms.Resize((224,224)),
45
                                    transforms.ToTensor(),
46
                                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
47
                                    ])
48
        test_transform = transforms.Compose([transforms.Resize((224,224)),
49
                                    transforms.ToTensor(),
50
                                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
51
                                    ])
52
53
    return train_transform, valid_transform, test_transform
54
55
56
def get_experiment_dir(conf):
57
    experiments_dir = conf['experiments_dir']
58
    if not os.path.exists(experiments_dir):
59
        os.makedirs(experiments_dir)
60
61
    experiment_dir = os.path.join(conf['experiments_dir'],
62
                                  conf['experiment_code'])
63
    if conf['task'] == 'training':
64
        if not os.path.exists(experiment_dir):
65
            os.makedirs(experiment_dir)
66
            save_conf = os.path.join(experiment_dir, conf['task'] + '.yaml')
67
            with open(save_conf, 'w') as fp:
68
                yaml.dump(conf, fp)
69
        else:
70
            print(f'Experiment dir {experiment_dir} exists.')
71
            exit()
72
    return experiment_dir
73
74
def get_test_config(conf):
75
76
    # Check if results dir exists and create it
77
    results_dir = conf['results_dir']
78
    if os.path.exists(results_dir): rmtree(results_dir)
79
    os.makedirs(results_dir)
80
81
    # Get transforms
82
    transforms = get_transforms(conf)
83
    conf['train_transform'] = transforms[0]
84
    conf['valid_transform'] = transforms[1]
85
    conf['test_transform'] = transforms[2]
86
87
    # Get datasets
88
    dataset_name = conf['data']['name']
89
    print(f'Dataset: {dataset_name}')
90
    _, _, conf['test_dataset'] = get_datasets(conf)
91
92
    # Check if GPU is available
93
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
94
    conf['device'] = device
95
    print(f'Device: {device}')
96
97
    # Initialize model
98
    model_name = conf['model']['name']
99
    print(f'Model: {model_name}')
100
    conf['model'] = initialize_model(conf)
101
102
    return conf
103
104
105
def get_config(yaml_path):
106
    # Load YAML conf
107
    conf = yaml.safe_load(open(yaml_path, 'r'))
108
109
    # Task
110
    task = conf['task']
111
    if task in ['training', 'evaluation']:
112
        print(f'Task: {task}')
113
114
    elif task == 'testing':
115
        return get_test_config(conf)
116
117
    else:
118
        print(f'Task {task} not supported.')
119
        exit()
120
121
    # Get experiment directory
122
    experiment_dir = get_experiment_dir(conf)
123
    conf['experiment_dir'] = experiment_dir
124
    print(f'Experiment directory: {experiment_dir}')
125
126
    # Get transforms
127
    transforms = get_transforms(conf)
128
    conf['train_transform'] = transforms[0]
129
    conf['valid_transform'] = transforms[1]
130
    conf['test_transform'] = transforms[2]
131
132
    # Get datasets
133
    dataset_name = conf['data']['name']
134
    print(f'Dataset: {dataset_name}')
135
    datasets = get_datasets(conf)
136
    conf['train_dataset'] = datasets[0]
137
    conf['valid_dataset'] = datasets[1]
138
    conf['test_dataset'] = datasets[2]
139
    conf['patients_dataset'] = datasets[3]
140
141
    # Get dataloaders
142
    dataloaders = get_dataloaders(conf)
143
    train_dataloader = dataloaders[0]
144
    valid_dataloader = dataloaders[1]
145
    test_dataloader = dataloaders[2]
146
    conf['dataloaders'] = {'train':train_dataloader,
147
                           'valid':valid_dataloader,
148
                           'test': test_dataloader}
149
150
    # Check if GPU is available
151
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
152
    conf['device'] = device
153
    print(f'Device: {device}')
154
155
    # Initialize model
156
    model_name = conf['model']['name']
157
    print(f'Model: {model_name}')
158
    conf['model'] = initialize_model(conf)
159
160
    # Only in traning task
161
    if task == 'training':
162
        # Get optimizer
163
        optimizer_name = conf['optimizer']['name']
164
        print(f'Optimizer: {optimizer_name}')
165
        conf['optimizer'] = get_optimizer(conf)
166
167
        # Get criterion
168
        criterion_name = conf['criterion']['name']
169
        print(f'Criterion: {criterion_name}')
170
        conf['criterion'] = get_criterion(conf)
171
172
        # Get scheduler
173
        scheduler_name = conf['scheduler']['name']
174
        print(f'Scheduler: {scheduler_name}')
175
        conf['scheduler'] = get_scheduler(conf)
176
177
    # Only in evaluation task
178
    elif task == 'evaluation':
179
        path = os.path.join(conf['experiment_dir'], 'best_weights.pt')
180
        if os.path.exists(path):
181
            print(f'Loading weights from {path}')
182
            conf['best_weights'] = torch.load(path)
183
        else:
184
            print(f'Experiment weights {path} not found.')
185
            exit()
186
187
    return conf
188
189
190
191