Switch to unified view

a b/CaraNet/utils/dataloader.py
1
import os
2
from PIL import Image
3
import torch.utils.data as data
4
import torchvision.transforms as transforms
5
import numpy as np
6
import random
7
import torch
8
9
10
class PolypDataset(data.Dataset):
    """Paired image / ground-truth mask dataset for polyp segmentation.

    Images (``.jpg`` / ``.png``) are listed from ``image_root`` and masks
    (``.png``) from ``gt_root``. Both listings are sorted and paired by
    position; pairs whose image and mask sizes differ are dropped.

    Args:
        image_root: Directory containing the input images.
        gt_root: Directory containing the ground-truth masks.
        trainsize: Side length both image and mask are resized to.
        augmentations: When equal to ``True``, random rotation and random
            vertical/horizontal flips are applied before resizing.
    """

    def __init__(self, image_root, gt_root, trainsize, augmentations):
        self.trainsize = trainsize
        self.augmentations = augmentations
        print(self.augmentations)
        # os.path.join works whether or not the roots end with a separator;
        # plain string concatenation silently produced wrong paths otherwise.
        self.images = sorted(
            os.path.join(image_root, f)
            for f in os.listdir(image_root)
            if f.endswith(('.jpg', '.png'))
        )
        self.gts = sorted(
            os.path.join(gt_root, f)
            for f in os.listdir(gt_root)
            if f.endswith('.png')
        )
        self.filter_files()
        self.size = len(self.images)

        resize_size = (self.trainsize, self.trainsize)
        normalize = transforms.Normalize([0.485, 0.456, 0.406],
                                         [0.229, 0.224, 0.225])

        # Kept as an equality test (not truthiness) so that non-bool flags
        # such as the string 'True' keep selecting the same branch as before.
        if self.augmentations == True:  # noqa: E712
            print('Using RandomRotation, RandomFlip')
            self.img_transform = transforms.Compose(
                self._random_geometry()
                + [transforms.Resize(resize_size),
                   transforms.ToTensor(),
                   normalize])
            self.gt_transform = transforms.Compose(
                self._random_geometry()
                + [transforms.Resize(resize_size),
                   transforms.ToTensor()])
        else:
            print('no augmentation')
            self.img_transform = transforms.Compose([
                transforms.Resize(resize_size),
                transforms.ToTensor(),
                normalize])
            self.gt_transform = transforms.Compose([
                transforms.Resize(resize_size),
                transforms.ToTensor()])

    @staticmethod
    def _random_geometry():
        """Return a fresh list of the random rotation / flip transforms."""
        # The former `resample=False, expand=False, center=None` arguments
        # were all defaults; `resample` has been removed from newer
        # torchvision releases, so relying on the defaults keeps the same
        # behaviour while working across versions.
        return [
            transforms.RandomRotation(90),
            transforms.RandomVerticalFlip(p=0.5),
            transforms.RandomHorizontalFlip(p=0.5),
        ]

    def __getitem__(self, index):
        """Load and transform the image / mask pair at ``index``.

        Returns:
            Tuple ``(image, gt)`` after ``img_transform`` / ``gt_transform``.
        """
        image = self.rgb_loader(self.images[index])
        gt = self.binary_loader(self.gts[index])

        # Draw one seed and re-apply it before each transform pipeline so
        # the random transforms for image and mask start from identical RNG
        # state. Note this reseeds the global `random` and torch generators.
        seed = np.random.randint(2147483647)
        random.seed(seed)
        torch.manual_seed(seed)  # needed for torchvision 0.7
        if self.img_transform is not None:
            image = self.img_transform(image)

        random.seed(seed)
        torch.manual_seed(seed)  # needed for torchvision 0.7
        if self.gt_transform is not None:
            gt = self.gt_transform(gt)
        return image, gt

    def filter_files(self):
        """Keep only the pairs whose image and mask have the same size.

        Raises:
            AssertionError: If the number of images and masks differs.
        """
        assert len(self.images) == len(self.gts), (
            'number of images ({}) and ground truths ({}) differ'.format(
                len(self.images), len(self.gts)))
        images = []
        gts = []
        for img_path, gt_path in zip(self.images, self.gts):
            # Context managers close the underlying files; the previous
            # version left two handles open per pair.
            with Image.open(img_path) as img, Image.open(gt_path) as gt:
                same_size = img.size == gt.size
            if same_size:
                images.append(img_path)
                gts.append(gt_path)
        self.images = images
        self.gts = gts

    def rgb_loader(self, path):
        """Open ``path`` and return it converted to an 'RGB' PIL image."""
        with open(path, 'rb') as f:
            img = Image.open(f)
            return img.convert('RGB')

    def binary_loader(self, path):
        """Open ``path`` and return it converted to an 'L' PIL image."""
        with open(path, 'rb') as f:
            img = Image.open(f)
            return img.convert('L')

    def resize(self, img, gt):
        """Upscale ``img`` and ``gt`` so neither side is below ``trainsize``.

        Pairs that are already large enough are returned unchanged. The
        image is resized bilinearly and the mask with nearest-neighbour.

        Raises:
            AssertionError: If ``img`` and ``gt`` differ in size.
        """
        assert img.size == gt.size
        w, h = img.size
        if h < self.trainsize or w < self.trainsize:
            h = max(h, self.trainsize)
            w = max(w, self.trainsize)
            return img.resize((w, h), Image.BILINEAR), gt.resize((w, h), Image.NEAREST)
        else:
            return img, gt

    def __len__(self):
        """Return the number of image / mask pairs kept after filtering."""
        return self.size
107
108
109
def get_loader(image_root, gt_root, batchsize, trainsize, shuffle=True, num_workers=4, pin_memory=True, augmentation=False):
    """Create a ``DataLoader`` over a ``PolypDataset``.

    Args:
        image_root: Directory of input images, forwarded to the dataset.
        gt_root: Directory of ground-truth masks, forwarded to the dataset.
        batchsize: Batch size handed to the loader.
        trainsize: Resize side length, forwarded to the dataset.
        shuffle: Forwarded to the loader.
        num_workers: Forwarded to the loader.
        pin_memory: Forwarded to the loader.
        augmentation: Augmentation flag, forwarded to the dataset.

    Returns:
        The configured ``torch.utils.data.DataLoader``.
    """
    polyp_data = PolypDataset(image_root, gt_root, trainsize, augmentation)
    loader_options = {
        'batch_size': batchsize,
        'shuffle': shuffle,
        'num_workers': num_workers,
        'pin_memory': pin_memory,
    }
    return data.DataLoader(dataset=polyp_data, **loader_options)
118
119
120
class test_dataset:
    """Sequential reader of image / ground-truth pairs for evaluation.

    Images (``.jpg`` / ``.png``) are listed from ``image_root`` and masks
    (``.tif`` / ``.png``) from ``gt_root``; both listings are sorted and
    paired by position. Each call to ``load_data`` returns the next pair.

    Args:
        image_root: Directory containing the input images.
        gt_root: Directory containing the ground-truth masks.
        testsize: Side length the input image is resized to.
    """

    def __init__(self, image_root, gt_root, testsize):
        self.testsize = testsize
        # os.path.join works whether or not the roots end with a separator.
        self.images = sorted(
            os.path.join(image_root, f)
            for f in os.listdir(image_root)
            if f.endswith(('.jpg', '.png'))
        )
        self.gts = sorted(
            os.path.join(gt_root, f)
            for f in os.listdir(gt_root)
            if f.endswith(('.tif', '.png'))
        )
        self.transform = transforms.Compose([
            transforms.Resize((self.testsize, self.testsize)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406],
                                 [0.229, 0.224, 0.225])])
        self.gt_transform = transforms.ToTensor()
        self.size = len(self.images)
        self.index = 0

    def load_data(self):
        """Return the next ``(image, gt, name)`` triple and advance the cursor.

        ``image`` is the transformed input with a leading dimension added via
        ``unsqueeze(0)``; ``gt`` is the 'L'-mode PIL image as loaded (it is
        not passed through ``gt_transform`` here); ``name`` is the image file
        name with a trailing ``.jpg`` replaced by ``.png``.

        Raises:
            IndexError: When called after all pairs have been consumed.
        """
        image_path = self.images[self.index]
        image = self.rgb_loader(image_path)
        image = self.transform(image).unsqueeze(0)
        gt = self.binary_loader(self.gts[self.index])
        name = self._prediction_name(image_path)
        self.index += 1
        return image, gt, name

    @staticmethod
    def _prediction_name(image_path):
        """Map an image path to its file name with a ``.png`` extension."""
        name = os.path.basename(image_path)
        if name.endswith('.jpg'):
            # Strip only the trailing suffix; splitting on '.jpg' truncated
            # names that contain '.jpg' earlier in the stem.
            name = name[:-len('.jpg')] + '.png'
        return name

    def rgb_loader(self, path):
        """Open ``path`` and return it converted to an 'RGB' PIL image."""
        with open(path, 'rb') as f:
            img = Image.open(f)
            return img.convert('RGB')

    def binary_loader(self, path):
        """Open ``path`` and return it converted to an 'L' PIL image."""
        with open(path, 'rb') as f:
            img = Image.open(f)
            return img.convert('L')