[c128d9]: / dataloaders / dataset2d.py

Download this file

35 lines (24 with data), 999 Bytes

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import json
import cv2
from albumentations import Compose, Normalize
from albumentations.pytorch.transforms import ToTensorV2
from torch.utils.data import DataLoader, Dataset
# Module-wide preprocessing pipeline, applied to every image that
# EcgDataset2D.__getitem__ loads: albumentations Normalize() with its default
# parameters, followed by ToTensorV2() to turn the numpy image into a torch
# tensor (HWC -> CHW, per the albumentations docs).
# NOTE(review): images arrive from cv2.imread in BGR order, while Normalize()'s
# default mean/std are documented per RGB channel — confirm this is intended.
augment = Compose(
    [
        Normalize(),
        ToTensorV2(),
    ]
)
class EcgDataset2D(Dataset):
    """Map-style dataset of ECG images listed in a JSON annotation file.

    Args:
        ann_path: Path to a JSON file holding an array (indexed by sample
            position) of records; each record is read for a ``"path"`` key
            (image file given to ``cv2.imread``) and a ``"label"`` key.
        mapping_path: Path to a JSON object mapping each ``"label"`` value to
            the value returned under ``"class"`` by ``__getitem__``.
    """

    def __init__(self, ann_path, mapping_path):
        super().__init__()
        self.data = self._load_json(ann_path)
        self.mapper = self._load_json(mapping_path)

    @staticmethod
    def _load_json(path):
        """Parse and return the JSON document stored at *path*."""
        # Context manager closes the handle deterministically; a bare
        # json.load(open(path)) leaves closing to the garbage collector.
        with open(path, encoding="utf-8") as fh:
            return json.load(fh)

    def __getitem__(self, index):
        """Return ``{"image": tensor, "class": mapped_label}`` for *index*.

        Raises:
            FileNotFoundError: if OpenCV cannot read the record's image file.
        """
        record = self.data[index]
        img = cv2.imread(record["path"])
        if img is None:
            # cv2.imread reports a missing/unreadable file by returning None
            # rather than raising; fail here with the offending path.
            raise FileNotFoundError(f"Could not read image: {record['path']}")
        img = augment(image=img)["image"]
        return {"image": img, "class": self.mapper[record["label"]]}

    def get_dataloader(self, num_workers=4, batch_size=16, shuffle=True, **loader_kwargs):
        """Wrap this dataset in a ``torch.utils.data.DataLoader``.

        Extra keyword arguments (e.g. ``sampler``, ``pin_memory``,
        ``drop_last``) are forwarded to ``DataLoader`` unchanged. When passing
        a ``sampler``, also pass ``shuffle=False``; DataLoader rejects both.
        """
        return DataLoader(
            self,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            **loader_kwargs,
        )

    def __len__(self):
        """Return the number of annotation records."""
        return len(self.data)
def callback_get_label(dataset, idx):
    """Return the class label of sample *idx* in *dataset*.

    Intended as a ``callback_get_label(dataset, idx)`` hook for code that
    needs every sample's label up front (presumably a class-balancing
    sampler — confirm against the caller).

    If *dataset* exposes both ``data`` (records with a ``"label"`` key) and
    ``mapper`` (label -> class), as ``EcgDataset2D`` does, the label is
    looked up directly so no image is read from disk or normalized. Any other
    dataset falls back to ``dataset[idx]["class"]``.
    """
    records = getattr(dataset, "data", None)
    mapper = getattr(dataset, "mapper", None)
    if records is not None and mapper is not None:
        # Same value __getitem__ would put under "class", without image I/O.
        return mapper[records[idx]["label"]]
    return dataset[idx]["class"]