--- a +++ b/Dataset/datasetloader.py @@ -0,0 +1,70 @@ +import numpy as np +import cv2 +from torch.utils.data import Dataset +import torch +import random +IGNORED = ['.DS_Store'] + + +class MRIDataset(Dataset): + def __init__(self, imgpath, labelpath, preprocessors=None, verbose=-1): + super(MRIDataset, self).__init__() + # store the image preprocessor + self.preprocessors = preprocessors + self.imgpath = imgpath + self.labelpath = labelpath + + # if the preprocessors are None, initialize them as an + # empty list + if self.preprocessors is None: + self.preprocessors = [] + + self.images = [] + self.masks = [] + + for (i, path) in enumerate(self.imgpath): + image = cv2.imread(path,0) + + if self.preprocessors is not None: + for p in self.preprocessors: + image = p.preprocess(image) + + image = torch.from_numpy(image) + image = image.unsqueeze(0) + + self.images.append(image) + + if verbose > 0 and i > 0 and (i + 1) % verbose == 0: + print("[INFO] processed {}/{}".format(i + 1, len(path))) + + for (i, path) in enumerate(self.labelpath): + label = cv2.imread(path) + + if self.preprocessors is not None: + for p in self.preprocessors: + label = p.preprocess(label) + label = np.sum(label, axis=2) + label = label > 0.5 + label = torch.from_numpy(label) + label = label.unsqueeze(0) + + self.masks.append(label) + + if verbose > 0 and i > 0 and (i + 1) % verbose == 0: + print("[INFO] processed {}/{}".format(i + 1, len(path))) + + def __len__(self): + return len(self.images) + + def __getitem__(self, idx): + + image = self.images[idx] + mask = self.masks[idx] + + # Flip image for data augmentation + if random.random() > 0.5: + image = torch.flip(image, [0]) + mask = torch.flip(mask, [0]) + + return image.float(), mask.float() +