Diff of /src/utils.py [000000] .. [cbdc43]

Switch to unified view

a b/src/utils.py
1
import numpy as np
2
import torch
3
import os
4
import sys
5
6
7
def set_seed(seed=0xD153A53):
8
    np.random.seed(seed)
9
    torch.manual_seed(seed)
10
    torch.cuda.manual_seed(seed)
11
    os.environ['PYTHONHASHSEED'] = str(seed)
12
13
14
def get_metric(cfg):
15
    return getattr(sys.modules['custom.metrics'], cfg.metric)
16
17
18
def get_criterion(cfg):
19
    return getattr(sys.modules['custom.losses'], cfg.criterion)
20
21
22
def get_model(cfg):
23
    return getattr(sys.modules['custom.models'], cfg.model)
24
25
26
def get_optimizer(cfg):
27
    return getattr(sys.modules['torch.optim'], cfg.optimizer)
28
29
30
def get_scheduler(cfg):
31
    if cfg.scheduler:
32
        return getattr(sys.modules['torch.optim.lr_scheduler'], cfg.scheduler)
33
    return FakeScheduler
34
35
36
def get_paths_1_dataset(data_folder, dataset_name):
37
    paths_folder = os.path.join(data_folder, dataset_name)
38
    last_number = 0
39
    paths, _paths = [], []
40
41
    # MedSeg has only one NIftI for many patients
42
    if 'MedSeg' in dataset_name:
43
        for i, name in enumerate(sorted(os.listdir(paths_folder))):
44
            path = os.path.join(paths_folder, name)
45
            if i % 5 == 0:
46
                paths.append(_paths)
47
                _paths = []
48
            _paths.append(path)
49
        paths.append(_paths)
50
        return paths
51
52
    # adding paths by patients
53
    for name in sorted(os.listdir(paths_folder)):
54
        path = os.path.join(paths_folder, name)
55
        number_of_patient = int(name.split('_')[0])
56
        if last_number != number_of_patient:
57
            paths.append(_paths)
58
            _paths = []
59
            last_number = number_of_patient
60
        _paths.append(path)
61
    paths.append(_paths)
62
    return paths
63
64
65
def get_paths(cfg):
66
    # if list for adding paths from every dataset
67
    if isinstance(cfg.dataset_name, list):
68
        paths = []
69
        for name in cfg.dataset_name:
70
            paths.extend(get_paths_1_dataset(cfg.data_folder, name))
71
        return paths
72
73
    # if solo dataset
74
    paths = get_paths_1_dataset(cfg.data_folder, cfg.dataset_name)
75
    return paths
76
77
78
class OneHotEncoder:
79
    def __init__(self, cfg):
80
        self.zeros = [0] * cfg.num_classes
81
82
    def encode_num(self, x):
83
        zeros = self.zeros.copy()
84
        zeros[int(x)] = 1
85
        return zeros
86
87
    def __call__(self, y):
88
        y = np.array(y)
89
        y = np.expand_dims(y, -1)
90
        y = np.apply_along_axis(self.encode_num, -1, y)
91
        y = np.swapaxes(y, -1, 1)
92
        y = np.ascontiguousarray(y)
93
        return torch.Tensor(y)
94
95
96
class FakeScheduler:
97
    def __init__(self, *args, **kwargs):
98
        self.lr = kwargs['lr']
99
100
    def step(self, *args, **kwargs):
101
        pass
102
103
    def get_last_lr(self, *args, **kwargs):
104
        return [self.lr]