# data/hyperkvasir.py
1
from os import listdir
2
from os.path import join
3
4
import PIL.Image
5
import matplotlib.pyplot as plt
6
import numpy as np
7
import torch.utils.data
8
from PIL.Image import open
9
from torch.nn.functional import one_hot
10
from torch.utils.data import Dataset
11
from torchvision import transforms
12
from perturbation.model import ModelOfNaturalVariation
13
import data.augmentation as aug
14
from utils.mask_generator import generate_a_mask
15
16
17
class KvasirClassificationDataset(Dataset):
    """
    Dataset class that fetches images with the associated pathological class labels for use in Pretraining
    """

    def __init__(self, path):
        """
        Index every labelled image below
        ``<path>/labeled-images/lower-gi-tract/pathological-findings``.

        :param path: root directory of the HyperKvasir dataset.
        """
        super(KvasirClassificationDataset, self).__init__()
        self.path = join(path, "labeled-images/lower-gi-tract/pathological-findings")
        # class_weights holds the raw number of files found per class (indexed
        # like label_names); despite the name it is not normalised here.
        self.label_names, self.fname_class_dict, self.class_weights = self._index_files(self.path)
        self.num_classes = len(self.label_names)
        self.index_dict = dict(zip(self.label_names, range(self.num_classes)))
        # Frozen (fname, label) pairs so that __getitem__ is an O(1) lookup
        # instead of rebuilding a list from the dict on every access.
        self._samples = list(self.fname_class_dict.items())

        self.common_transforms = transforms.Compose([transforms.Resize((400, 400)),
                                                     transforms.ToTensor()
                                                     ])

    @staticmethod
    def _index_files(root):
        """
        Scan ``root``, which must contain one sub-directory per class label.

        Directory listings are sorted because ``listdir`` returns entries in
        arbitrary order; sorting keeps class indices and sample order
        reproducible across machines.

        :param root: directory holding one folder per class.
        :return: tuple ``(label_names, fname_class_dict, class_counts)``.
            ``fname_class_dict`` is keyed by bare file name, so a name that
            occurs under two labels keeps only the last label while
            ``class_counts`` still counts both files.
        """
        label_names = sorted(listdir(root))
        class_counts = np.zeros(len(label_names))
        fname_class_dict = {}
        for class_idx, label in enumerate(label_names):
            for fname in sorted(listdir(join(root, label))):
                class_counts[class_idx] += 1
                fname_class_dict[fname] = label
        return label_names, fname_class_dict, class_counts

    def __len__(self):
        """Return the number of indexed images."""
        # return 256  # for debugging
        return len(self.fname_class_dict)

    def __getitem__(self, item):
        """
        Load one sample.

        :param item: integer index into the indexed images.
        :return: tuple ``(image, onehot, fname)`` where ``image`` is the
            resized image tensor, ``onehot`` is a float one-hot label vector
            of length ``num_classes`` and ``fname`` is the file name.
        """
        fname, label = self._samples[item]
        onehot = one_hot(torch.tensor(self.index_dict[label]), num_classes=self.num_classes)
        # Decode the file exactly once and release the handle right away.
        with open(join(self.path, label, fname)) as raw_image:
            image = self.common_transforms(raw_image.convert("RGB"))
        return image, onehot.float(), fname
53
54
55
class KvasirSegmentationDataset(Dataset):
    """
        Dataset class that fetches images with the associated segmentation mask.
        Employs "vanilla" augmentations
    """

    def __init__(self, path, split="train", augment=False):
        """
        :param path: root directory of the HyperKvasir dataset; images and
            masks are read from ``<path>/segmented-images/{images,masks}``.
        :param split: one of ``"train"``, ``"val"`` or ``"test"``.
        :param augment: when truthy, augmentations are applied to samples of
            the training split.
        :raises ValueError: if ``split`` is not one of the three choices.
        """
        super(KvasirSegmentationDataset, self).__init__()
        # Fail fast on a bad split name, before touching the filesystem.
        if split not in ("train", "val", "test"):
            raise ValueError("Choices are train/val/test")
        self.path = join(path, "segmented-images/")
        # listdir returns entries in arbitrary order; sorting is what actually
        # makes the partition below deterministic across machines.
        self.fnames = sorted(listdir(join(self.path, "images")))
        self.common_transforms = aug.pipeline_tranforms()
        self.pixeltrans = aug.albumentation_pixelwise_transforms()
        # NOTE(review): segtrans is built by the same factory as pixeltrans,
        # although it is the one applied to image *and* mask - confirm that
        # this is the intended transform.
        self.segtrans = aug.albumentation_pixelwise_transforms()
        # deterministic partition
        self.split = split
        self.augment = augment
        self.fnames_train, self.fnames_val, self.fnames_test = self._partition(self.fnames)
        # iterable for selected split
        self.split_fnames = {"train": self.fnames_train,
                             "val": self.fnames_val,
                             "test": self.fnames_test}[split]
        self.size = len(self.split_fnames)

    @staticmethod
    def _partition(fnames):
        """
        Cut ``fnames`` by position into train/val/test lists.

        The first ``int(0.8 * n)`` names form the training split, half of the
        remainder (rounded down) the validation split and the rest the test
        split.

        :param fnames: ordered list of file names.
        :return: tuple ``(train, val, test)`` of lists.
        """
        train_size = int(len(fnames) * 0.8)
        val_size = (len(fnames) - train_size) // 2
        val_end = train_size + val_size
        return fnames[:train_size], fnames[train_size:val_end], fnames[val_end:]

    def __len__(self):
        """Return the number of samples in the selected split."""
        return self.size

    def __getitem__(self, index):
        """
        Load one image/mask pair from the selected split.

        :param index: integer index into the selected split.
        :return: tuple ``(image, mask, fname)``; ``mask`` is binarised to
            0.0/1.0 by thresholding the transformed mask at 0.5.
        """
        fname = self.split_fnames[index]
        with open(join(self.path, "images/", fname)) as raw_image:
            image = np.array(raw_image.convert("RGB"))
        with open(join(self.path, "masks/", fname)) as raw_mask:
            mask = np.array(raw_mask.convert("L"))
        if self.split == "train" and self.augment:
            # First pass touches the image only, second pass gets image and mask.
            image = self.pixeltrans(image=image)["image"]
            segtransformed = self.segtrans(image=image, mask=mask)
            image, mask = segtransformed["image"], segtransformed["mask"]
        image = self.common_transforms(PIL.Image.fromarray(image))
        mask = self.common_transforms(PIL.Image.fromarray(mask))
        mask = (mask > 0.5).float()
        return image, mask, fname
106
107
108
class KvasirMNVset(KvasirSegmentationDataset):
    """
    Segmentation dataset whose training samples are, with probability ``p``,
    passed through a ``ModelOfNaturalVariation`` before being returned.
    """

    def __init__(self, path, split, inpaint=False):
        """
        :param path: root directory of the HyperKvasir dataset.
        :param split: one of ``"train"``, ``"val"`` or ``"test"``.
        :param inpaint: currently ignored.
        """
        super(KvasirMNVset, self).__init__(path, split, augment=False)
        # NOTE(review): `inpaint` is accepted but never read - the inpainter is
        # always enabled. Wiring it through would change the default behaviour
        # for existing callers, so it is left as is; confirm the intent.
        self.mnv = ModelOfNaturalVariation(1, use_inpainter=True)
        # Probability of perturbing a training sample; see set_prob().
        self.p = 0.5

    def __getitem__(self, index):
        """
        Load one image/mask pair, optionally perturbed by the MNV.

        :param index: integer index into the selected split.
        :return: tuple ``(image, mask, fname, flag)`` where ``flag`` is True
            iff the sample was passed through the MNV.
        """
        fname = self.split_fnames[index]
        # No numpy-based augmentation happens here, so the PIL images go
        # straight into the common transforms.
        with open(join(self.path, "images/", fname)) as raw_image:
            image = self.common_transforms(raw_image.convert("RGB"))
        with open(join(self.path, "masks/", fname)) as raw_mask:
            mask = self.common_transforms(raw_mask.convert("L"))
        mask = (mask > 0.5).float()
        flag = False
        if self.split == "train" and np.random.rand() < self.p:
            flag = True
            image, mask = self.mnv(image.unsqueeze(0), mask.unsqueeze(0))
            # Drop only the batch axis added above; a bare squeeze() would
            # also remove any other axis that happens to have size 1.
            image = image.squeeze(0)
            mask = mask.squeeze(0)
        return image, mask, fname, flag

    def set_prob(self, prob):
        """Set the probability with which training samples are perturbed."""
        self.p = prob
132
133
134
class KvasirInpaintingDataset(Dataset):
    """
    Dataset that returns each segmented image together with its mask, the
    image with the masked region blanked out, and the masked region itself.
    """

    def __init__(self, path, split="train"):
        """
        :param path: root directory of the HyperKvasir dataset; images and
            masks are read from ``<path>/segmented-images/{images,masks}``.
        :param split: one of ``"train"``, ``"val"`` or ``"test"``.
        :raises ValueError: if ``split`` is not one of the three choices.
        """
        super(KvasirInpaintingDataset, self).__init__()
        # Fail fast on a bad split name, before touching the filesystem.
        if split not in ("train", "val", "test"):
            raise ValueError("Choices are train/val/test")
        self.path = join(path, "segmented-images/")
        # listdir returns entries in arbitrary order; sort so the positional
        # split below is reproducible across machines.
        self.fnames = sorted(listdir(join(self.path, "images")))
        self.common_transforms = transforms.Compose([transforms.Resize((400, 400)),
                                                     transforms.ToTensor()
                                                     ])
        self.split = split
        self.fnames_train, self.fnames_val, self.fnames_test = self._partition(self.fnames)
        # iterable for selected split
        self.split_fnames = {"train": self.fnames_train,
                             "val": self.fnames_val,
                             "test": self.fnames_test}[split]
        self.size = len(self.split_fnames)

    @staticmethod
    def _partition(fnames):
        """
        Cut ``fnames`` by position into train/val/test lists.

        The first ``int(0.8 * n)`` names form the training split, half of the
        remainder (rounded down) the validation split and the rest the test
        split.

        :param fnames: ordered list of file names.
        :return: tuple ``(train, val, test)`` of lists.
        """
        train_size = int(len(fnames) * 0.8)
        val_size = (len(fnames) - train_size) // 2
        val_end = train_size + val_size
        return fnames[:train_size], fnames[train_size:val_end], fnames[val_end:]

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.split_fnames)

    def __getitem__(self, index):
        """
        Load one sample for inpainting.

        :param index: integer index into the selected split.
        :return: tuple ``(image, mask, masked_image, part, fname)`` where
            ``part = mask * image`` is the content under the mask and
            ``masked_image = image - part`` is the image with that region
            zeroed.
        """
        fname = self.split_fnames[index]
        with open(join(self.path, "images/", fname)) as raw_image:
            image = self.common_transforms(raw_image.convert("RGB"))
        with open(join(self.path, "masks/", fname)) as raw_mask:
            mask = self.common_transforms(raw_mask.convert("L"))
        mask = (mask > 0.5).float()

        part = mask * image
        masked_image = image - part

        return image, mask, masked_image, part, fname
176
177
178
class KvasirSyntheticDataset(Dataset):
    """
    Dataset that feeds unlabeled images, paired with an all-zero mask,
    through a ``ModelOfNaturalVariation`` and returns its output.
    """

    def __init__(self, path, split="train"):
        """
        :param path: root directory of the HyperKvasir dataset; images are
            read from ``<path>/unlabeled-images/images``.
        :param split: one of ``"train"``, ``"val"`` or ``"test"``.
        :raises ValueError: if ``split`` is not one of the three choices.
        """
        super(KvasirSyntheticDataset, self).__init__()
        # Fail fast on a bad split name, before touching the filesystem.
        if split not in ("train", "val", "test"):
            raise ValueError("Choices are train/val/test")
        self.path = join(path, "unlabeled-images")
        # listdir returns entries in arbitrary order; sort so the positional
        # split below is reproducible across machines.
        self.fnames = sorted(listdir(join(self.path, "images")))
        self.common_transforms = aug.pipeline_tranforms()
        self.split = split
        self.fnames_train, self.fnames_val, self.fnames_test = self._partition(self.fnames)
        # iterable for selected split
        self.split_fnames = {"train": self.fnames_train,
                             "val": self.fnames_val,
                             "test": self.fnames_test}[split]
        self.size = len(self.split_fnames)
        print("loading mnv")
        # NOTE(review): the model is moved to "cuda" while __getitem__ passes
        # it tensors straight from the transforms - confirm the model handles
        # device placement of its inputs.
        self.mnv = ModelOfNaturalVariation(0, use_inpainter=True).to("cuda")
        print("mnv loaded")

    @staticmethod
    def _partition(fnames):
        """
        Cut ``fnames`` by position into train/val/test lists.

        The first ``int(0.8 * n)`` names form the training split, half of the
        remainder (rounded down) the validation split and the rest the test
        split.

        :param fnames: ordered list of file names.
        :return: tuple ``(train, val, test)`` of lists.
        """
        train_size = int(len(fnames) * 0.8)
        val_size = (len(fnames) - train_size) // 2
        val_end = train_size + val_size
        return fnames[:train_size], fnames[train_size:val_end], fnames[val_end:]

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.split_fnames)

    def __getitem__(self, index):
        """
        Generate one synthetic sample.

        :param index: integer index into the selected split.
        :return: tuple ``(image, mask, fname)`` holding the MNV output for
            the image and an initially all-zero mask.
        """
        fname = self.split_fnames[index]
        with open(join(self.path, "images/", fname)) as raw_image:
            image = np.array(raw_image.convert("RGB"))
        # NOTE(review): the empty mask copies the RGB array's shape, so it has
        # three channels, unlike the "L"-mode masks used by the other datasets
        # in this file - confirm this is what the MNV expects.
        mask = np.zeros_like(image)
        image = self.common_transforms(PIL.Image.fromarray(image))
        mask = self.common_transforms(PIL.Image.fromarray(mask))
        image, mask = self.mnv(image.unsqueeze(0), mask.unsqueeze(0))
        # Drop only the batch axis added above.
        image = image.squeeze(0)
        mask = mask.squeeze(0)

        return image, mask, fname
222
223
224
def test_KvasirSegmentationDataset():
    """
    Visual smoke test: iterate the test split of the segmentation dataset,
    check the returned types and show every image with matplotlib.

    Expects the dataset under ``Datasets/HyperKvasir`` and blocks on each
    figure until its window is closed.
    """
    dataset = KvasirSegmentationDataset("Datasets/HyperKvasir", split="test")
    for x, y, fname in torch.utils.data.DataLoader(dataset):
        assert isinstance(x, torch.Tensor)
        assert isinstance(y, torch.Tensor)
        # `.T` on a 3-D tensor is deprecated and reverses every axis, which
        # displays the picture transposed. Move the channel axis last instead.
        # Assumes x is (1, C, H, W) from the default batch size of 1 - TODO confirm.
        plt.imshow(x.squeeze(0).permute(1, 2, 0))
        # plt.imshow(y.squeeze().T, alpha=0.5)
        plt.show()
    print("Segmentation evaluation passed")
234
235
236
def test_KvasirClassificationDataset():
    """
    Smoke test: iterate the whole classification dataset and check that
    images and labels are returned as tensors.

    Expects the dataset under ``Datasets/HyperKvasir``.
    """
    dataset = KvasirClassificationDataset("Datasets/HyperKvasir")
    for x, y, fname in torch.utils.data.DataLoader(dataset):
        assert isinstance(x, torch.Tensor)
        assert isinstance(y, torch.Tensor)
    print("Classification evaluation passed")
242
243
244
# Run the visual segmentation smoke test when executed as a script.
if __name__ == '__main__':
    test_KvasirSegmentationDataset()
    # test_KvasirClassificationDataset()