|
a |
|
b/src/config.py |
|
|
1 |
import sys |
|
|
2 |
sys.path.append('.') |
|
|
3 |
import yaml |
|
|
4 |
import os |
|
|
5 |
from shutil import rmtree |
|
|
6 |
|
|
|
7 |
import torch |
|
|
8 |
import torchvision.transforms as transforms |
|
|
9 |
|
|
|
10 |
from src.data import get_datasets, get_dataloaders |
|
|
11 |
from src.model import initialize_model |
|
|
12 |
from src.optimizer import get_optimizer |
|
|
13 |
from src.criterion import get_criterion |
|
|
14 |
from src.scheduler import get_scheduler |
|
|
15 |
import yaml |
|
|
16 |
|
|
|
17 |
|
|
|
18 |
|
|
|
19 |
def get_transforms(conf): |
|
|
20 |
model_name = conf['model']['name'] |
|
|
21 |
if 'efficientdet' in model_name: |
|
|
22 |
train_transform = transforms.Compose([transforms.Resize((512,512)), |
|
|
23 |
transforms.RandomVerticalFlip(p=0.5), |
|
|
24 |
transforms.RandomHorizontalFlip(p=0.5), |
|
|
25 |
transforms.ToTensor(), |
|
|
26 |
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) |
|
|
27 |
]) |
|
|
28 |
valid_transform = transforms.Compose([transforms.Resize((512,512)), |
|
|
29 |
transforms.ToTensor(), |
|
|
30 |
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) |
|
|
31 |
]) |
|
|
32 |
test_transform = transforms.Compose([transforms.Resize((512,512)), |
|
|
33 |
transforms.ToTensor(), |
|
|
34 |
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) |
|
|
35 |
]) |
|
|
36 |
|
|
|
37 |
elif 'resnet' in model_name: |
|
|
38 |
train_transform = transforms.Compose([transforms.Resize((224,224)), |
|
|
39 |
transforms.RandomVerticalFlip(p=0.5), |
|
|
40 |
transforms.RandomHorizontalFlip(p=0.5), |
|
|
41 |
transforms.ToTensor(), |
|
|
42 |
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) |
|
|
43 |
]) |
|
|
44 |
valid_transform = transforms.Compose([transforms.Resize((224,224)), |
|
|
45 |
transforms.ToTensor(), |
|
|
46 |
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) |
|
|
47 |
]) |
|
|
48 |
test_transform = transforms.Compose([transforms.Resize((224,224)), |
|
|
49 |
transforms.ToTensor(), |
|
|
50 |
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) |
|
|
51 |
]) |
|
|
52 |
|
|
|
53 |
return train_transform, valid_transform, test_transform |
|
|
54 |
|
|
|
55 |
|
|
|
56 |
def get_experiment_dir(conf): |
|
|
57 |
experiments_dir = conf['experiments_dir'] |
|
|
58 |
if not os.path.exists(experiments_dir): |
|
|
59 |
os.makedirs(experiments_dir) |
|
|
60 |
|
|
|
61 |
experiment_dir = os.path.join(conf['experiments_dir'], |
|
|
62 |
conf['experiment_code']) |
|
|
63 |
if conf['task'] == 'training': |
|
|
64 |
if not os.path.exists(experiment_dir): |
|
|
65 |
os.makedirs(experiment_dir) |
|
|
66 |
save_conf = os.path.join(experiment_dir, conf['task'] + '.yaml') |
|
|
67 |
with open(save_conf, 'w') as fp: |
|
|
68 |
yaml.dump(conf, fp) |
|
|
69 |
else: |
|
|
70 |
print(f'Experiment dir {experiment_dir} exists.') |
|
|
71 |
exit() |
|
|
72 |
return experiment_dir |
|
|
73 |
|
|
|
74 |
def get_test_config(conf): |
|
|
75 |
|
|
|
76 |
# Check if results dir exists and create it |
|
|
77 |
results_dir = conf['results_dir'] |
|
|
78 |
if os.path.exists(results_dir): rmtree(results_dir) |
|
|
79 |
os.makedirs(results_dir) |
|
|
80 |
|
|
|
81 |
# Get transforms |
|
|
82 |
transforms = get_transforms(conf) |
|
|
83 |
conf['train_transform'] = transforms[0] |
|
|
84 |
conf['valid_transform'] = transforms[1] |
|
|
85 |
conf['test_transform'] = transforms[2] |
|
|
86 |
|
|
|
87 |
# Get datasets |
|
|
88 |
dataset_name = conf['data']['name'] |
|
|
89 |
print(f'Dataset: {dataset_name}') |
|
|
90 |
_, _, conf['test_dataset'] = get_datasets(conf) |
|
|
91 |
|
|
|
92 |
# Check if GPU is available |
|
|
93 |
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') |
|
|
94 |
conf['device'] = device |
|
|
95 |
print(f'Device: {device}') |
|
|
96 |
|
|
|
97 |
# Initialize model |
|
|
98 |
model_name = conf['model']['name'] |
|
|
99 |
print(f'Model: {model_name}') |
|
|
100 |
conf['model'] = initialize_model(conf) |
|
|
101 |
|
|
|
102 |
return conf |
|
|
103 |
|
|
|
104 |
|
|
|
105 |
def get_config(yaml_path): |
|
|
106 |
# Load YAML conf |
|
|
107 |
conf = yaml.safe_load(open(yaml_path, 'r')) |
|
|
108 |
|
|
|
109 |
# Task |
|
|
110 |
task = conf['task'] |
|
|
111 |
if task in ['training', 'evaluation']: |
|
|
112 |
print(f'Task: {task}') |
|
|
113 |
|
|
|
114 |
elif task == 'testing': |
|
|
115 |
return get_test_config(conf) |
|
|
116 |
|
|
|
117 |
else: |
|
|
118 |
print(f'Task {task} not supported.') |
|
|
119 |
exit() |
|
|
120 |
|
|
|
121 |
# Get experiment directory |
|
|
122 |
experiment_dir = get_experiment_dir(conf) |
|
|
123 |
conf['experiment_dir'] = experiment_dir |
|
|
124 |
print(f'Experiment directory: {experiment_dir}') |
|
|
125 |
|
|
|
126 |
# Get transforms |
|
|
127 |
transforms = get_transforms(conf) |
|
|
128 |
conf['train_transform'] = transforms[0] |
|
|
129 |
conf['valid_transform'] = transforms[1] |
|
|
130 |
conf['test_transform'] = transforms[2] |
|
|
131 |
|
|
|
132 |
# Get datasets |
|
|
133 |
dataset_name = conf['data']['name'] |
|
|
134 |
print(f'Dataset: {dataset_name}') |
|
|
135 |
datasets = get_datasets(conf) |
|
|
136 |
conf['train_dataset'] = datasets[0] |
|
|
137 |
conf['valid_dataset'] = datasets[1] |
|
|
138 |
conf['test_dataset'] = datasets[2] |
|
|
139 |
conf['patients_dataset'] = datasets[3] |
|
|
140 |
|
|
|
141 |
# Get dataloaders |
|
|
142 |
dataloaders = get_dataloaders(conf) |
|
|
143 |
train_dataloader = dataloaders[0] |
|
|
144 |
valid_dataloader = dataloaders[1] |
|
|
145 |
test_dataloader = dataloaders[2] |
|
|
146 |
conf['dataloaders'] = {'train':train_dataloader, |
|
|
147 |
'valid':valid_dataloader, |
|
|
148 |
'test': test_dataloader} |
|
|
149 |
|
|
|
150 |
# Check if GPU is available |
|
|
151 |
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') |
|
|
152 |
conf['device'] = device |
|
|
153 |
print(f'Device: {device}') |
|
|
154 |
|
|
|
155 |
# Initialize model |
|
|
156 |
model_name = conf['model']['name'] |
|
|
157 |
print(f'Model: {model_name}') |
|
|
158 |
conf['model'] = initialize_model(conf) |
|
|
159 |
|
|
|
160 |
# Only in traning task |
|
|
161 |
if task == 'training': |
|
|
162 |
# Get optimizer |
|
|
163 |
optimizer_name = conf['optimizer']['name'] |
|
|
164 |
print(f'Optimizer: {optimizer_name}') |
|
|
165 |
conf['optimizer'] = get_optimizer(conf) |
|
|
166 |
|
|
|
167 |
# Get criterion |
|
|
168 |
criterion_name = conf['criterion']['name'] |
|
|
169 |
print(f'Criterion: {criterion_name}') |
|
|
170 |
conf['criterion'] = get_criterion(conf) |
|
|
171 |
|
|
|
172 |
# Get scheduler |
|
|
173 |
scheduler_name = conf['scheduler']['name'] |
|
|
174 |
print(f'Scheduler: {scheduler_name}') |
|
|
175 |
conf['scheduler'] = get_scheduler(conf) |
|
|
176 |
|
|
|
177 |
# Only in evaluation task |
|
|
178 |
elif task == 'evaluation': |
|
|
179 |
path = os.path.join(conf['experiment_dir'], 'best_weights.pt') |
|
|
180 |
if os.path.exists(path): |
|
|
181 |
print(f'Loading weights from {path}') |
|
|
182 |
conf['best_weights'] = torch.load(path) |
|
|
183 |
else: |
|
|
184 |
print(f'Experiment weights {path} not found.') |
|
|
185 |
exit() |
|
|
186 |
|
|
|
187 |
return conf |
|
|
188 |
|
|
|
189 |
|
|
|
190 |
|
|
|
191 |
|