--- a +++ b/common/dataset.py @@ -0,0 +1,130 @@ +from __future__ import print_function, division + +import os +import torch +import pandas as pd +from skimage import io, transform +import numpy as np +from torch.utils.data import Dataset, DataLoader +from torchvision import transforms, utils +from PIL import Image, ImageOps +from random import random, randint + +# Ignore warnings +import warnings +import pdb + +warnings.filterwarnings("ignore") + + +def make_dataset(root,mode): + + """ Takes in the root directory and mode(train or val or test) as inputs + then joins the path with the folder of the specified mode.Applies normalize + function to each img of the folder and returns a list of tuples containing + image and its corresponding mask. + + Returns + ------- + tuple : Normalized image and its annotation. + + """ + assert mode in ['train', 'val', 'test'] + items = [] + + if mode == 'train': + train_img_path = os.path.join(root, 'Train/train_image') + train_mask_path = os.path.join(root, 'Train/train_mask') + + images = os.listdir(train_img_path) + labels = os.listdir(train_mask_path) + + images.sort() + labels.sort() + + for it_im, it_gt in zip(images, labels): + item = (os.path.join(train_img_path, it_im), os.path.join(train_mask_path, it_gt)) + items.append(item) + elif mode == 'val': + val_img_path = os.path.join(root, 'Val/val_img') + val_mask_path = os.path.join(root, 'Val/val_mask') + + images = os.listdir(val_img_path) + labels = os.listdir(val_mask_path) + + images.sort() + labels.sort() + + + for it_im, it_gt in zip(images, labels): + item = (os.path.join(val_img_path, it_im), os.path.join(val_mask_path, it_gt)) + items.append(item) + else: + test_img_path = os.path.join(root, 'Test/test_img') + # test_mask_path = os.path.join(root, 'Test/test_mask') + + images = os.listdir(test_img_path) + #labels = os.listdir(test_mask_path) + + images.sort() + #labels.sort() + + for it_im in images: + item = os.path.join(test_img_path, it_im) + items.append(item) + + return items + + + +class MedicalImageDataset(Dataset): + """ GI dataset.""" + + def __init__(self, mode, root_dir, transform=None): + """ + Args: + csv_file (string): Path to the csv file with annotations. + root_dir (string): Directory with all the images. + transform (callable, optional): Optional transform to be applied + on a sample. + """ + self.root_dir = root_dir + self.mode=mode + self.transform = transform + self.imgs = make_dataset(root_dir, mode) + + + def __len__(self): + return len(self.imgs) + + + + def __getitem__(self,index): + if self.mode== 'test': + img_path=self.imgs[index] + img =np.array( Image.open(img_path)) + img_shape=np.array(img).shape + + if self.transform: + augmented = self.transform(image=img) + image=augmented["image"] + + + return image + + else: + img_path, mask_path = self.imgs[index] + # print("{} and {}".format(img_path,mask_path)) + img = np.array(Image.open(img_path)) # .convert('RGB') + # mask = Image.open(mask_path) # .convert('RGB') + # img = Image.open(img_path).convert('L') + mask = np.array(Image.open(mask_path).convert('L')) + + # print('{} and {}'.format(img_path,mask_path)) + + if self.transform: + augmented = self.transform(image=img,mask=mask) + image=augmented["image"] + mask=augmented["mask"] + + return [image, mask]