a b/data/endocv.py
1
from os import listdir
2
from os.path import join
3
import matplotlib.pyplot as plt
4
from PIL.Image import open
5
from torch.utils.data import Dataset
6
import data.augmentation as aug
7
8
9
class EndoCV2020(Dataset):
    """Polyp segmentation samples from an EndoCV2020-style directory tree.

    Masks are read from ``<root>/masksPerClass/polyp/``. Each mask is paired
    with an image in ``<root>/originalImages/`` whose name is derived from the
    mask name by removing ``"_polyp"`` and replacing ``".tif"`` with ``".jpg"``.

    Indexing returns ``(image, mask, mask_fname)``. Both image and mask have
    been passed through the pipeline built by ``aug.pipeline_tranforms()``.

    Args:
        root_directory: Path to the dataset root.
    """

    def __init__(self, root_directory):
        super().__init__()
        self.root = root_directory
        mask_dir = join(self.root, "masksPerClass", "polyp")
        self.mask_fnames = self._list_mask_fnames(mask_dir)
        self.mask_locs = [join(mask_dir, fname) for fname in self.mask_fnames]
        self.img_locs = [
            join(self.root, "originalImages", self._image_fname_for_mask(fname))
            for fname in self.mask_fnames
        ]
        self.common_transforms = aug.pipeline_tranforms()

    @staticmethod
    def _list_mask_fnames(mask_dir):
        """Return the non-hidden file names in *mask_dir*, sorted."""
        # listdir() order is arbitrary and filesystem-dependent. Sorting keeps
        # the index -> sample mapping stable across runs and machines. Hidden
        # entries (e.g. .DS_Store) are not masks and are skipped.
        return sorted(fname for fname in listdir(mask_dir) if not fname.startswith("."))

    @staticmethod
    def _image_fname_for_mask(mask_fname):
        """Return the image file name that corresponds to *mask_fname*."""
        return mask_fname.replace("_polyp", "").replace(".tif", ".jpg")

    def _load(self, path):
        """Open *path* and return it passed through ``self.common_transforms``."""
        # PIL's open is lazy: the file stays open until the pixel data is read.
        # Running the transform inside ``with`` guarantees the handle is closed
        # afterwards. This assumes the pipeline materialises the data (the
        # ``__main__`` demo treats the result as an array/tensor) — TODO confirm.
        with open(path) as handle:
            return self.common_transforms(handle)

    def __getitem__(self, idx):
        # The mask is loaded before the image, as originally. This preserves
        # call order in case the pipeline draws random numbers.
        mask = self._load(self.mask_locs[idx])
        image = self._load(self.img_locs[idx])
        # NOTE(review): image and mask go through the pipeline in two separate
        # calls. If it contains random augmentations, they will not be applied
        # identically to both. Confirm aug.pipeline_tranforms() is deterministic.
        return image, mask, self.mask_fnames[idx]

    def __len__(self):
        return len(self.mask_fnames)
26
27
28
if __name__ == '__main__':
    # Manual visual check: overlay each mask on its image, one sample at a time.
    demo_set = EndoCV2020("Datasets/EndoCV2020")
    for sample_img, sample_mask, _ in demo_set:
        # NOTE(review): `.T` reverses every axis. For a (C, H, W) tensor this
        # gives (W, H, C), i.e. a transposed picture. Confirm the layout
        # produced by aug.pipeline_tranforms() before relying on this view.
        plt.imshow(sample_img.T)
        plt.imshow(sample_mask.T, alpha=0.5)
        plt.show()
        input()  # wait for Enter before moving to the next sample
    print("done")