[8eeb5a]: / data / endocv.py

Download this file

36 lines (29 with data), 1.2 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
from os import listdir
from os.path import join
import matplotlib.pyplot as plt
from PIL.Image import open
from torch.utils.data import Dataset
import data.augmentation as aug
class EndoCV2020(Dataset):
    """Polyp segmentation dataset laid out like the EndoCV2020 release.

    ``root_directory`` must contain ``masksPerClass/polyp`` (one mask file per
    sample) and ``originalImages``. The image paired with each mask is found by
    stripping ``"_polyp"`` from the mask file name and swapping ``".tif"`` for
    ``".jpg"``.

    Each item is ``(image, mask, mask_filename)``. Image and mask have each
    been passed through ``aug.pipeline_tranforms()``.
    """

    def __init__(self, root_directory):
        super().__init__()
        self.root = root_directory
        mask_dir = join(self.root, "masksPerClass", "polyp")
        # os.listdir returns entries in arbitrary, filesystem-dependent order.
        # Sort so a given index names the same sample on every machine/run.
        self.mask_fnames = sorted(listdir(mask_dir))
        self.mask_locs = [join(mask_dir, fname) for fname in self.mask_fnames]
        self.img_locs = [
            join(self.root, "originalImages", self._image_name_for_mask(fname))
            for fname in self.mask_fnames
        ]
        self.common_transforms = aug.pipeline_tranforms()

    @staticmethod
    def _image_name_for_mask(mask_fname):
        """Return the source-image file name that corresponds to a mask file name."""
        return mask_fname.replace("_polyp", "").replace(".tif", ".jpg")

    @staticmethod
    def _load(path):
        """Read the image at ``path`` into memory and close its file handle."""
        # ``open`` is PIL.Image.open (module-level import). It opens lazily and
        # keeps the file open until the image is closed, which leaks a handle
        # per sample. copy() forces a full load into an independent image, so
        # the file can be closed immediately on leaving the ``with``.
        with open(path) as img:
            return img.copy()

    def __getitem__(self, idx):
        # NOTE(review): image and mask go through the transform pipeline in two
        # separate calls. If aug.pipeline_tranforms() contains random
        # augmentations, they may receive different random parameters and fall
        # out of alignment. Confirm the pipeline is deterministic.
        mask = self.common_transforms(self._load(self.mask_locs[idx]))
        image = self.common_transforms(self._load(self.img_locs[idx]))
        return image, mask, self.mask_fnames[idx]

    def __len__(self):
        return len(self.mask_fnames)
if __name__ == '__main__':
    # Visual sanity check: overlay each mask on its image, one sample at a time.
    # An optional first command-line argument overrides the dataset root.
    import sys

    root = sys.argv[1] if len(sys.argv) > 1 else "Datasets/EndoCV2020"
    dataset = EndoCV2020(root)
    for img, mask, fname in dataset:
        # NOTE(review): assumes the transforms yield (C, H, W) torch tensors.
        # ``.T`` reverses *all* axes, producing (W, H, C), i.e. a transposed
        # picture, and is deprecated on non-2-D tensors. permute(1, 2, 0)
        # gives the (H, W, C) layout imshow expects. A single-channel mask is
        # squeezed to (H, W).
        plt.imshow(img.permute(1, 2, 0))
        plt.imshow(mask.permute(1, 2, 0).squeeze(-1), alpha=0.5)
        plt.title(fname)
        plt.show()
        input()
    print("done")