Diff of /dataloaders/dataset2d.py [000000] .. [fbbdf8]

Switch to unified view

a b/dataloaders/dataset2d.py
1
import json
2
3
import cv2
4
from albumentations import Compose, Normalize
5
from albumentations.pytorch.transforms import ToTensorV2
6
from torch.utils.data import DataLoader, Dataset
7
8
augment = Compose([Normalize(), ToTensorV2()])
9
10
11
class EcgDataset2D(Dataset):
12
    def __init__(self, ann_path, mapping_path):
13
        super().__init__()
14
        self.data = json.load(open(ann_path))
15
        self.mapper = json.load(open(mapping_path))
16
17
    def __getitem__(self, index):
18
        img = cv2.imread(self.data[index]["path"])
19
        img = augment(**{"image": img})["image"]
20
21
        return {"image": img, "class": self.mapper[self.data[index]["label"]]}
22
23
    def get_dataloader(self, num_workers=4, batch_size=16, shuffle=True):
24
        data_loader = DataLoader(
25
            self, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers,
26
        )
27
        return data_loader
28
29
    def __len__(self):
30
        return len(self.data)
31
32
33
def callback_get_label(dataset, idx):
34
    return dataset[idx]["class"]