Diff of /data/cvc.py [000000] .. [92cc18]

Switch to side-by-side view

--- a
+++ b/data/cvc.py
@@ -0,0 +1,37 @@
+from os import listdir
+from os.path import join
+import matplotlib.pyplot as plt
+from PIL.Image import open
+from torch.utils.data import Dataset
+import data.augmentation as aug
+
+from PIL import Image
+
+
+class CVC_ClinicDB(Dataset):
+    def __init__(self, root_directory):
+        super(CVC_ClinicDB, self).__init__()
+        self.root = root_directory
+        self.mask_fnames = listdir(join(self.root, "Ground Truth"))
+        self.mask_locs = [join(self.root, "Ground Truth", i) for i in self.mask_fnames]
+        self.img_locs = [join(self.root, "Original", i) for i in
+                         self.mask_fnames]
+        self.common_transforms = aug.pipeline_tranforms()
+
+    def __getitem__(self, idx):
+        mask = self.common_transforms(open(self.mask_locs[idx]))
+        image = self.common_transforms(open(self.img_locs[idx]))
+        return image, mask, self.mask_fnames[idx]
+
+    def __len__(self):
+        return len(self.mask_fnames)
+
+
+if __name__ == '__main__':
+    dataset = CVC_ClinicDB("Datasets/CVC-ClinicDB")
+    for img, mask, fname in dataset:
+        plt.imshow(img.T)
+        plt.imshow(mask.T, alpha=0.5)
+        plt.show()
+        input()
+    print("done")