[8eeb5a]: / data / cvc.py

Download this file

38 lines (30 with data), 1.2 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
from os import listdir
from os.path import join
import matplotlib.pyplot as plt
from PIL.Image import open
from torch.utils.data import Dataset
import data.augmentation as aug
from PIL import Image
class CVC_ClinicDB(Dataset):
    """Dataset of (image, mask, file name) samples read from one root folder.

    The root folder must contain a ``Ground Truth`` sub-folder (masks) and an
    ``Original`` sub-folder (images). Samples are indexed by the file names
    found in ``Ground Truth``; the image for each mask is looked up under the
    same file name in ``Original``.
    """

    def __init__(self, root_directory):
        """Index the mask files below *root_directory* and build the transform.

        Args:
            root_directory: Path of the folder holding the ``Ground Truth``
                and ``Original`` sub-folders.

        Raises:
            OSError: If ``<root_directory>/Ground Truth`` cannot be listed.
        """
        super().__init__()
        self.root = root_directory

        # os.listdir() yields entries in arbitrary, platform-dependent order.
        # Sorting keeps the index -> sample mapping reproducible between runs
        # and machines.
        self.mask_fnames = sorted(listdir(join(self.root, "Ground Truth")))

        # Masks and images share file names, so both path lists are derived
        # from the mask listing and stay aligned index by index.
        self.mask_locs = [
            join(self.root, "Ground Truth", fname) for fname in self.mask_fnames
        ]
        self.img_locs = [
            join(self.root, "Original", fname) for fname in self.mask_fnames
        ]

        self.common_transforms = aug.pipeline_tranforms()

    def __getitem__(self, idx):
        """Load and transform the sample at position *idx*.

        Args:
            idx: Integer position in the sorted mask listing.

        Returns:
            A tuple ``(image, mask, file_name)`` where ``image`` and ``mask``
            are the outputs of the transform pipeline and ``file_name`` is the
            shared file name of the pair.

        Raises:
            IndexError: If *idx* is outside the listing. This is also what
                ends a plain ``for sample in dataset`` loop.
        """
        # `open` here is PIL.Image.open (imported at module level), not the
        # builtin open().
        # NOTE(review): the pipeline is invoked separately for mask and image.
        # If it contains random augmentations the two could be transformed
        # differently - confirm against data.augmentation.
        mask = self.common_transforms(open(self.mask_locs[idx]))
        image = self.common_transforms(open(self.img_locs[idx]))
        return image, mask, self.mask_fnames[idx]

    def __len__(self):
        """Return the number of mask files, i.e. the number of samples."""
        return len(self.mask_fnames)
# Manual visual check: when the module is run as a script, step through the
# dataset and overlay each mask on its image.
# NOTE(review): the indentation of this block was lost in the copy under
# review; the nesting below is the most plausible reading - confirm against
# the original file.
if __name__ == '__main__':
    dataset = CVC_ClinicDB("Datasets/CVC-ClinicDB")
    # Iterating the dataset directly goes through __getitem__ with
    # idx = 0, 1, 2, ... until it raises IndexError. `fname` is unpacked but
    # not used below.
    for img, mask, fname in dataset:
        # presumably img/mask are channel-first arrays or tensors and .T
        # reverses the axes for imshow - TODO confirm (this would also swap
        # height and width).
        plt.imshow(img.T)
        # Draw the mask half-transparent on top of the image.
        plt.imshow(mask.T, alpha=0.5)
        plt.show()
        # Wait for Enter on stdin before moving on to the next sample.
        input()
    print("done")