src/data_functions.py
import random

import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
from sklearn.model_selection import train_test_split, KFold
from utils import get_paths
import albumentations as A


class Covid19Dataset(Dataset):
    def __init__(self, paths, transform=None):
        self.paths = paths
        self.transform = transform
        self._len = len(self.paths)

    def __len__(self):
        return self._len

    def __getitem__(self, index):
        path = self.paths[index]
        loaded = np.load(path)
        image = loaded['image']
        mask = loaded['mask']
        if self.transform:
            transformed = self.transform(image=image, mask=mask)
            image = transformed['image']
            mask = transformed['mask']
        # use an explicit float32 dtype (the bare np.float alias has been removed
        # from NumPy); the FloatTensor cast below then leaves the dtype unchanged
        image = torch.from_numpy(np.array([image], dtype=np.float32))
        image = image.type(torch.FloatTensor)
        mask = torch.from_numpy(np.array([mask], dtype=np.uint8))
        return image, mask
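
# Each path handed to Covid19Dataset is expected to point to an .npz archive
# holding an 'image' array and a 'mask' array. A purely illustrative sketch of
# how one such sample could be written (the file name and array variables here
# are hypothetical, not part of this project):
#
#     np.savez("slice_0000.npz", image=image_array, mask=mask_array)

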
def data_generator(cfg):
    image_paths = get_paths(cfg)
    image_paths = np.asarray(image_paths)
    train_paths, val_paths = [], []

    if not cfg.kfold:
        _train_paths, _val_paths = train_test_split(
            image_paths, test_size=cfg.val_size, random_state=cfg.seed
        )
    else:
        kf = KFold(n_splits=cfg.n_splits)
        for i, (train_index, val_index) in enumerate(kf.split(image_paths)):
            if i + 1 == cfg.fold_number:
                _train_paths = image_paths[train_index]
                _val_paths = image_paths[val_index]

    # each element from get_paths is itself a group of sample paths, so flatten
    # the groups into a single list before shuffling
    for paths in _train_paths:
        train_paths.extend(paths)
    for paths in _val_paths:
        val_paths.extend(paths)
    random.shuffle(train_paths)
    random.shuffle(val_paths)
    return train_paths, val_paths


def get_transforms(cfg):
    # getting transforms from albumentations
    pre_transforms = [getattr(A, item["name"])(**item["params"]) for item in cfg.pre_transforms]
    augmentations = [getattr(A, item["name"])(**item["params"]) for item in cfg.augmentations]
    post_transforms = [getattr(A, item["name"])(**item["params"]) for item in cfg.post_transforms]

    # concatenate transforms
    train = A.Compose(pre_transforms + augmentations + post_transforms)
    test = A.Compose(pre_transforms + post_transforms)
    return train, test


def get_loaders(cfg):
    train_transforms, test_transforms = get_transforms(cfg)
    train_paths, val_paths = data_generator(cfg)
    train_ds = Covid19Dataset(train_paths, transform=train_transforms)
    # validation data gets the augmentation-free test-time transforms
    val_ds = Covid19Dataset(val_paths, transform=test_transforms)
    train_dl = DataLoader(train_ds, batch_size=cfg.batch_size, drop_last=True)
    val_dl = DataLoader(val_ds, batch_size=cfg.batch_size, drop_last=True)
    return train_dl, val_dl
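

# Illustrative usage sketch, not part of the original project: it shows the
# config fields this module reads and one plausible way to fill them. The
# SimpleNamespace stand-in, the concrete transform choices, and every value
# below are assumptions; utils.get_paths will also need whatever fields it
# reads (e.g. a data directory), which are not shown here.
if __name__ == "__main__":
    from types import SimpleNamespace

    cfg = SimpleNamespace(
        # train/validation split settings
        kfold=False,        # set True to pick one fold of a KFold split instead
        val_size=0.2,
        seed=42,
        n_splits=5,         # used only when kfold=True
        fold_number=1,      # 1-based fold index, used only when kfold=True
        # each transform entry names an albumentations class and its kwargs
        pre_transforms=[{"name": "Resize", "params": {"height": 256, "width": 256}}],
        augmentations=[{"name": "HorizontalFlip", "params": {"p": 0.5}}],
        post_transforms=[],
        # loader settings
        batch_size=8,
    )

    train_dl, val_dl = get_loaders(cfg)
    images, masks = next(iter(train_dl))
    print(images.shape, masks.shape)  # e.g. torch.Size([8, 1, 256, 256]) for both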