Diff of /common/dataset.py [000000] .. [f804b3]

Switch to unified view

a b/common/dataset.py
1
from __future__ import print_function, division
2
3
import os
4
import torch
5
import pandas as pd
6
from skimage import io, transform
7
import numpy as np
8
from torch.utils.data import Dataset, DataLoader
9
from torchvision import transforms, utils
10
from PIL import Image, ImageOps
11
from random import random, randint
12
13
# Ignore warnings
14
import warnings
15
import pdb
16
17
warnings.filterwarnings("ignore")
18
19
20
def make_dataset(root,mode):
21
22
  """   Takes in the root directory and mode(train or val or test) as inputs
23
  then joins the path with the folder of the specified mode.Applies normalize
24
  function to each img of the folder and returns a list of tuples containing
25
  image and its corresponding mask.
26
27
  Returns
28
  -------
29
  tuple : Normalized image and its annotation.
30
31
  """
32
  assert mode in ['train', 'val', 'test']
33
  items = []
34
35
  if mode == 'train':
36
      train_img_path = os.path.join(root, 'Train/train_image')
37
      train_mask_path = os.path.join(root, 'Train/train_mask')
38
39
      images = os.listdir(train_img_path)
40
      labels = os.listdir(train_mask_path)
41
42
      images.sort()
43
      labels.sort()
44
45
      for it_im, it_gt in zip(images, labels):
46
          item = (os.path.join(train_img_path, it_im), os.path.join(train_mask_path, it_gt))
47
          items.append(item)
48
  elif mode == 'val':
49
      val_img_path = os.path.join(root, 'Val/val_img')
50
      val_mask_path = os.path.join(root, 'Val/val_mask')
51
52
      images = os.listdir(val_img_path)
53
      labels = os.listdir(val_mask_path)
54
55
      images.sort()
56
      labels.sort()
57
58
59
      for it_im, it_gt in zip(images, labels):
60
          item = (os.path.join(val_img_path, it_im), os.path.join(val_mask_path, it_gt))
61
          items.append(item)
62
  else:
63
      test_img_path = os.path.join(root, 'Test/test_img')
64
     # test_mask_path = os.path.join(root, 'Test/test_mask')
65
66
      images = os.listdir(test_img_path)
67
      #labels = os.listdir(test_mask_path)
68
69
      images.sort()
70
      #labels.sort()
71
72
      for it_im in images:
73
          item = os.path.join(test_img_path, it_im)
74
          items.append(item)
75
76
  return items
77
78
79
80
class MedicalImageDataset(Dataset):
81
    """ GI dataset."""
82
83
    def __init__(self, mode, root_dir, transform=None):
84
        """
85
        Args:
86
            csv_file (string): Path to the csv file with annotations.
87
            root_dir (string): Directory with all the images.
88
            transform (callable, optional): Optional transform to be applied
89
                on a sample.
90
        """
91
        self.root_dir = root_dir
92
        self.mode=mode
93
        self.transform = transform
94
        self.imgs = make_dataset(root_dir, mode)
95
96
97
    def __len__(self):
98
        return len(self.imgs)
99
100
101
102
    def __getitem__(self,index):
103
      if self.mode== 'test':
104
         img_path=self.imgs[index]
105
         img =np.array( Image.open(img_path))
106
         img_shape=np.array(img).shape
107
108
         if self.transform:
109
            augmented = self.transform(image=img)
110
            image=augmented["image"]
111
112
113
         return image
114
115
      else:
116
        img_path, mask_path = self.imgs[index]
117
        # print("{} and {}".format(img_path,mask_path))
118
        img = np.array(Image.open(img_path))  # .convert('RGB')
119
        # mask = Image.open(mask_path)  # .convert('RGB')
120
        # img = Image.open(img_path).convert('L')
121
        mask = np.array(Image.open(mask_path).convert('L'))
122
123
        # print('{} and {}'.format(img_path,mask_path))
124
125
        if self.transform:
126
            augmented = self.transform(image=img,mask=mask)
127
            image=augmented["image"]
128
            mask=augmented["mask"]
129
130
        return [image, mask]