Diff of /common/dataset.py [000000] .. [f804b3]

Switch to side-by-side view

--- a
+++ b/common/dataset.py
@@ -0,0 +1,130 @@
+from __future__ import print_function, division
+
+import os
+import torch
+import pandas as pd
+from skimage import io, transform
+import numpy as np
+from torch.utils.data import Dataset, DataLoader
+from torchvision import transforms, utils
+from PIL import Image, ImageOps
+from random import random, randint
+
+# Ignore warnings
+import warnings
+import pdb
+
+warnings.filterwarnings("ignore")
+
+
+def make_dataset(root,mode):
+
+  """   Takes in the root directory and mode(train or val or test) as inputs
+  then joins the path with the folder of the specified mode.Applies normalize
+  function to each img of the folder and returns a list of tuples containing
+  image and its corresponding mask.
+
+  Returns
+  -------
+  tuple : Normalized image and its annotation.
+
+  """
+  assert mode in ['train', 'val', 'test']
+  items = []
+
+  if mode == 'train':
+      train_img_path = os.path.join(root, 'Train/train_image')
+      train_mask_path = os.path.join(root, 'Train/train_mask')
+
+      images = os.listdir(train_img_path)
+      labels = os.listdir(train_mask_path)
+
+      images.sort()
+      labels.sort()
+
+      for it_im, it_gt in zip(images, labels):
+          item = (os.path.join(train_img_path, it_im), os.path.join(train_mask_path, it_gt))
+          items.append(item)
+  elif mode == 'val':
+      val_img_path = os.path.join(root, 'Val/val_img')
+      val_mask_path = os.path.join(root, 'Val/val_mask')
+
+      images = os.listdir(val_img_path)
+      labels = os.listdir(val_mask_path)
+
+      images.sort()
+      labels.sort()
+
+
+      for it_im, it_gt in zip(images, labels):
+          item = (os.path.join(val_img_path, it_im), os.path.join(val_mask_path, it_gt))
+          items.append(item)
+  else:
+      test_img_path = os.path.join(root, 'Test/test_img')
+     # test_mask_path = os.path.join(root, 'Test/test_mask')
+
+      images = os.listdir(test_img_path)
+      #labels = os.listdir(test_mask_path)
+
+      images.sort()
+      #labels.sort()
+
+      for it_im in images:
+          item = os.path.join(test_img_path, it_im)
+          items.append(item)
+
+  return items
+
+
+
+class MedicalImageDataset(Dataset):
+    """ GI dataset."""
+
+    def __init__(self, mode, root_dir, transform=None):
+        """
+        Args:
+            csv_file (string): Path to the csv file with annotations.
+            root_dir (string): Directory with all the images.
+            transform (callable, optional): Optional transform to be applied
+                on a sample.
+        """
+        self.root_dir = root_dir
+        self.mode=mode
+        self.transform = transform
+        self.imgs = make_dataset(root_dir, mode)
+
+
+    def __len__(self):
+        return len(self.imgs)
+
+
+
+    def __getitem__(self,index):
+      if self.mode== 'test':
+         img_path=self.imgs[index]
+         img =np.array( Image.open(img_path))
+         img_shape=np.array(img).shape
+
+         if self.transform:
+            augmented = self.transform(image=img)
+            image=augmented["image"]
+
+
+         return image
+
+      else:
+        img_path, mask_path = self.imgs[index]
+        # print("{} and {}".format(img_path,mask_path))
+        img = np.array(Image.open(img_path))  # .convert('RGB')
+        # mask = Image.open(mask_path)  # .convert('RGB')
+        # img = Image.open(img_path).convert('L')
+        mask = np.array(Image.open(mask_path).convert('L'))
+
+        # print('{} and {}'.format(img_path,mask_path))
+
+        if self.transform:
+            augmented = self.transform(image=img,mask=mask)
+            image=augmented["image"]
+            mask=augmented["mask"]
+
+        return [image, mask]