--- a +++ b/data/etis.py @@ -0,0 +1,38 @@ +from os import listdir +from os.path import join + +import matplotlib.pyplot as plt +from PIL.Image import open +from torch.utils.data import Dataset, DataLoader +from torchvision import transforms +import data.augmentation as aug + + +class EtisDataset(Dataset): + """ + Dataset class that fetches Etis-LaribPolypDB images with the associated segmentation mask. + Used for testing. + """ + + def __init__(self, path): + super(EtisDataset, self).__init__() + self.path = path + self.len = len(listdir(join(self.path, "ETIS-LaribPolypDB"))) + self.common_transforms = aug.pipeline_tranforms() + + def __len__(self): + return self.len + + def __getitem__(self, index): + image = self.common_transforms( + open(join(self.path, "ETIS-LaribPolypDB/{}.jpg".format(index + 1))).convert("RGB")) + mask = self.common_transforms( + open(join(self.path, "GroundTruth/p{}.jpg".format(index + 1))).convert("RGB")) + mask = (mask > 0.5).float() + return image, mask, index + 1 + + +def test_etis(): + for x, y, in DataLoader(EtisDataset("Datasets/ETIS-LaribPolypDB")): + plt.imshow(x) + plt.show()