Diff of /data/cvc.py [000000] .. [92cc18]

Switch to unified view

a b/data/cvc.py
1
from os import listdir
2
from os.path import join
3
import matplotlib.pyplot as plt
4
from PIL.Image import open
5
from torch.utils.data import Dataset
6
import data.augmentation as aug
7
8
from PIL import Image
9
10
11
class CVC_ClinicDB(Dataset):
12
    def __init__(self, root_directory):
13
        super(CVC_ClinicDB, self).__init__()
14
        self.root = root_directory
15
        self.mask_fnames = listdir(join(self.root, "Ground Truth"))
16
        self.mask_locs = [join(self.root, "Ground Truth", i) for i in self.mask_fnames]
17
        self.img_locs = [join(self.root, "Original", i) for i in
18
                         self.mask_fnames]
19
        self.common_transforms = aug.pipeline_tranforms()
20
21
    def __getitem__(self, idx):
22
        mask = self.common_transforms(open(self.mask_locs[idx]))
23
        image = self.common_transforms(open(self.img_locs[idx]))
24
        return image, mask, self.mask_fnames[idx]
25
26
    def __len__(self):
27
        return len(self.mask_fnames)
28
29
30
if __name__ == '__main__':
31
    dataset = CVC_ClinicDB("Datasets/CVC-ClinicDB")
32
    for img, mask, fname in dataset:
33
        plt.imshow(img.T)
34
        plt.imshow(mask.T, alpha=0.5)
35
        plt.show()
36
        input()
37
    print("done")