|
a |
|
b/src/utils.py |
|
|
1 |
import os
import random
import sys

import numpy as np
import torch
|
|
5 |
|
|
|
6 |
|
|
|
7 |
def set_seed(seed=0xD153A53):
    """Seed all RNGs used by the project so runs are reproducible.

    Args:
        seed: integer seed; defaults to the project's magic constant.
    """
    random.seed(seed)                 # Python's own RNG (shuffles, sampling)
    np.random.seed(seed)
    torch.manual_seed(seed)           # CPU RNG
    torch.cuda.manual_seed_all(seed)  # every GPU, not just the current device
    # NOTE: setting PYTHONHASHSEED here only influences *future* interpreter
    # processes (e.g. spawned workers); it cannot change hash randomization
    # of the process that is already running.
    os.environ['PYTHONHASHSEED'] = str(seed)
|
|
12 |
|
|
|
13 |
|
|
|
14 |
def get_metric(cfg):
    """Resolve ``cfg.metric`` to the attribute of that name in ``custom.metrics``.

    The module is looked up in ``sys.modules``, so it must already have been
    imported elsewhere before this is called.
    """
    metrics_module = sys.modules['custom.metrics']
    return getattr(metrics_module, cfg.metric)
|
|
16 |
|
|
|
17 |
|
|
|
18 |
def get_criterion(cfg):
    """Resolve ``cfg.criterion`` to the attribute of that name in ``custom.losses``.

    Assumes ``custom.losses`` was imported earlier (it is fetched from
    ``sys.modules`` rather than re-imported).
    """
    losses_module = sys.modules['custom.losses']
    return getattr(losses_module, cfg.criterion)
|
|
20 |
|
|
|
21 |
|
|
|
22 |
def get_model(cfg):
    """Resolve ``cfg.model`` to the attribute of that name in ``custom.models``.

    Assumes ``custom.models`` was imported earlier (looked up in
    ``sys.modules``, not imported here).
    """
    models_module = sys.modules['custom.models']
    return getattr(models_module, cfg.model)
|
|
24 |
|
|
|
25 |
|
|
|
26 |
def get_optimizer(cfg):
    """Resolve ``cfg.optimizer`` to the torch optimizer class of that name.

    Example: ``cfg.optimizer == 'Adam'`` returns ``torch.optim.Adam``.
    """
    optim_module = sys.modules['torch.optim']
    return getattr(optim_module, cfg.optimizer)
|
|
28 |
|
|
|
29 |
|
|
|
30 |
def get_scheduler(cfg):
    """Resolve ``cfg.scheduler`` to a torch LR-scheduler class.

    Falls back to ``FakeScheduler`` (a no-op stand-in defined in this
    module) when the config leaves the scheduler unset/empty.
    """
    if not cfg.scheduler:
        return FakeScheduler
    lr_module = sys.modules['torch.optim.lr_scheduler']
    return getattr(lr_module, cfg.scheduler)
|
|
34 |
|
|
|
35 |
|
|
|
36 |
def get_paths_1_dataset(data_folder, dataset_name):
    """Collect the files of one dataset, grouped by patient.

    Args:
        data_folder: root folder containing one subfolder per dataset.
        dataset_name: subfolder name. Filenames are expected to start with
            ``<patient_number>_`` unless the dataset is a MedSeg variant.

    Returns:
        list[list[str]]: one sublist of file paths per patient, in sorted
        filename order. (Bug fix: previously both branches emitted a
        spurious empty first group — MedSeg at i == 0, the generic branch
        whenever the first patient number was not 0 — and an empty
        directory returned ``[[]]`` instead of ``[]``.)
    """
    paths_folder = os.path.join(data_folder, dataset_name)
    paths, group = [], []

    # MedSeg has only one NIfTI for many patients: split the sorted file
    # list into fixed-size groups of 5 instead of grouping by name prefix.
    if 'MedSeg' in dataset_name:
        for i, name in enumerate(sorted(os.listdir(paths_folder))):
            if i % 5 == 0 and group:  # guard: never append an empty group
                paths.append(group)
                group = []
            group.append(os.path.join(paths_folder, name))
        if group:
            paths.append(group)
        return paths

    # Other datasets: group by the patient number encoded in the filename
    # prefix ("<number>_...").
    last_number = None  # sentinel so the first file always starts a group
    for name in sorted(os.listdir(paths_folder)):
        number_of_patient = int(name.split('_')[0])
        if last_number != number_of_patient:
            if group:  # guard: skip the initial empty group
                paths.append(group)
                group = []
            last_number = number_of_patient
        group.append(os.path.join(paths_folder, name))
    if group:
        paths.append(group)
    return paths
|
|
63 |
|
|
|
64 |
|
|
|
65 |
def get_paths(cfg):
    """Gather patient-grouped paths for one dataset or a list of datasets.

    ``cfg.dataset_name`` may be a single name or a list of names; either
    way the per-dataset groups are concatenated into one flat list.
    """
    names = cfg.dataset_name
    if not isinstance(names, list):
        # solo dataset: treat it as a one-element list
        names = [names]

    paths = []
    for dataset in names:
        paths += get_paths_1_dataset(cfg.data_folder, dataset)
    return paths
|
|
76 |
|
|
|
77 |
|
|
|
78 |
class OneHotEncoder:
    """Turn integer class labels into one-hot channel tensors.

    The class axis ends up in position 1, so input of shape (B, W)
    becomes (B, num_classes, W).
    NOTE(review): for inputs with more than 2 dims, ``swapaxes(-1, 1)``
    also moves the last input axis into position -1, i.e. (B, H, W)
    becomes (B, C, W, H). This mirrors the original behavior — confirm
    it is intended (``moveaxis(-1, 1)`` would give (B, C, H, W)).
    """

    def __init__(self, cfg):
        # Template row of zeros, one slot per class.
        self.zeros = [0] * cfg.num_classes

    def encode_num(self, x):
        """One-hot encode a single label ``x`` as a plain Python list."""
        zeros = self.zeros.copy()
        zeros[int(x)] = 1
        return zeros

    def __call__(self, y):
        """Encode an array-like of integer labels into a float tensor."""
        y = np.asarray(y)
        # Vectorized one-hot: indexing an identity matrix replaces the
        # original np.apply_along_axis(self.encode_num, ...), which ran a
        # Python-level callback per element. Output values are identical.
        encoded = np.eye(len(self.zeros), dtype=np.int64)[y.astype(np.int64)]
        encoded = np.swapaxes(encoded, -1, 1)  # class axis -> position 1
        return torch.Tensor(np.ascontiguousarray(encoded))
|
|
94 |
|
|
|
95 |
|
|
|
96 |
class FakeScheduler:
    """No-op stand-in for a torch LR scheduler; the rate never changes.

    Accepts the same call pattern as real schedulers (positional args are
    ignored); only the ``lr`` keyword is kept, so ``get_last_lr`` can
    report it.
    """

    def __init__(self, *args, **kwargs):
        # Remember the configured learning rate; everything else is ignored.
        self.lr = kwargs['lr']

    def step(self, *args, **kwargs):
        """Do nothing — the fake scheduler never updates anything."""

    def get_last_lr(self, *args, **kwargs):
        """Mimic the torch scheduler API: a list holding the fixed LR."""
        return [self.lr]