--- a +++ b/CaraNet/utils/dataloader.py @@ -0,0 +1,154 @@ +import os +from PIL import Image +import torch.utils.data as data +import torchvision.transforms as transforms +import numpy as np +import random +import torch + + +class PolypDataset(data.Dataset): + """ + dataloader for polyp segmentation tasks + """ + def __init__(self, image_root, gt_root, trainsize, augmentations): + self.trainsize = trainsize + self.augmentations = augmentations + print(self.augmentations) + self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg') or f.endswith('.png')] + self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.png')] + self.images = sorted(self.images) + self.gts = sorted(self.gts) + self.filter_files() + self.size = len(self.images) + if self.augmentations == True: + print('Using RandomRotation, RandomFlip') + self.img_transform = transforms.Compose([ + transforms.RandomRotation(90, resample=False, expand=False, center=None), + transforms.RandomVerticalFlip(p=0.5), + transforms.RandomHorizontalFlip(p=0.5), + transforms.Resize((self.trainsize, self.trainsize)), + transforms.ToTensor(), + transforms.Normalize([0.485, 0.456, 0.406], + [0.229, 0.224, 0.225])]) + self.gt_transform = transforms.Compose([ + transforms.RandomRotation(90, resample=False, expand=False, center=None), + transforms.RandomVerticalFlip(p=0.5), + transforms.RandomHorizontalFlip(p=0.5), + transforms.Resize((self.trainsize, self.trainsize)), + transforms.ToTensor()]) + + else: + print('no augmentation') + self.img_transform = transforms.Compose([ + transforms.Resize((self.trainsize, self.trainsize)), + transforms.ToTensor(), + transforms.Normalize([0.485, 0.456, 0.406], + [0.229, 0.224, 0.225])]) + + self.gt_transform = transforms.Compose([ + transforms.Resize((self.trainsize, self.trainsize)), + transforms.ToTensor()]) + + + def __getitem__(self, index): + + image = self.rgb_loader(self.images[index]) + gt = self.binary_loader(self.gts[index]) + + seed = np.random.randint(2147483647) # make a seed with numpy generator + random.seed(seed) # apply this seed to img tranfsorms + torch.manual_seed(seed) # needed for torchvision 0.7 + if self.img_transform is not None: + image = self.img_transform(image) + + random.seed(seed) # apply this seed to img tranfsorms + torch.manual_seed(seed) # needed for torchvision 0.7 + if self.gt_transform is not None: + gt = self.gt_transform(gt) + return image, gt + + def filter_files(self): + assert len(self.images) == len(self.gts) + images = [] + gts = [] + for img_path, gt_path in zip(self.images, self.gts): + img = Image.open(img_path) + gt = Image.open(gt_path) + if img.size == gt.size: + images.append(img_path) + gts.append(gt_path) + self.images = images + self.gts = gts + + def rgb_loader(self, path): + with open(path, 'rb') as f: + img = Image.open(f) + return img.convert('RGB') + + def binary_loader(self, path): + with open(path, 'rb') as f: + img = Image.open(f) + # return img.convert('1') + return img.convert('L') + + def resize(self, img, gt): + assert img.size == gt.size + w, h = img.size + if h < self.trainsize or w < self.trainsize: + h = max(h, self.trainsize) + w = max(w, self.trainsize) + return img.resize((w, h), Image.BILINEAR), gt.resize((w, h), Image.NEAREST) + else: + return img, gt + + def __len__(self): + return self.size + + +def get_loader(image_root, gt_root, batchsize, trainsize, shuffle=True, num_workers=4, pin_memory=True, augmentation=False): + + dataset = PolypDataset(image_root, gt_root, trainsize, augmentation) + data_loader = data.DataLoader(dataset=dataset, + batch_size=batchsize, + shuffle=shuffle, + num_workers=num_workers, + pin_memory=pin_memory) + return data_loader + + +class test_dataset: + def __init__(self, image_root, gt_root, testsize): + self.testsize = testsize + self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg') or f.endswith('.png')] + self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.tif') or f.endswith('.png')] + self.images = sorted(self.images) + self.gts = sorted(self.gts) + self.transform = transforms.Compose([ + transforms.Resize((self.testsize, self.testsize)), + transforms.ToTensor(), + transforms.Normalize([0.485, 0.456, 0.406], + [0.229, 0.224, 0.225])]) + self.gt_transform = transforms.ToTensor() + self.size = len(self.images) + self.index = 0 + + def load_data(self): + image = self.rgb_loader(self.images[self.index]) + image = self.transform(image).unsqueeze(0) + gt = self.binary_loader(self.gts[self.index]) + name = self.images[self.index].split('/')[-1] + if name.endswith('.jpg'): + name = name.split('.jpg')[0] + '.png' + self.index += 1 + return image, gt, name + + def rgb_loader(self, path): + with open(path, 'rb') as f: + img = Image.open(f) + return img.convert('RGB') + + def binary_loader(self, path): + with open(path, 'rb') as f: + img = Image.open(f) + return img.convert('L')