--- a +++ b/dataloaders/dataset2d.py @@ -0,0 +1,34 @@ +import json + +import cv2 +from albumentations import Compose, Normalize +from albumentations.pytorch.transforms import ToTensorV2 +from torch.utils.data import DataLoader, Dataset + +augment = Compose([Normalize(), ToTensorV2()]) + + +class EcgDataset2D(Dataset): + def __init__(self, ann_path, mapping_path): + super().__init__() + self.data = json.load(open(ann_path)) + self.mapper = json.load(open(mapping_path)) + + def __getitem__(self, index): + img = cv2.imread(self.data[index]["path"]) + img = augment(**{"image": img})["image"] + + return {"image": img, "class": self.mapper[self.data[index]["label"]]} + + def get_dataloader(self, num_workers=4, batch_size=16, shuffle=True): + data_loader = DataLoader( + self, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, + ) + return data_loader + + def __len__(self): + return len(self.data) + + +def callback_get_label(dataset, idx): + return dataset[idx]["class"]