|
a |
|
b/data/hyperkvasir.py |
|
|
1 |
from os import listdir |
|
|
2 |
from os.path import join |
|
|
3 |
|
|
|
4 |
import PIL.Image |
|
|
5 |
import matplotlib.pyplot as plt |
|
|
6 |
import numpy as np |
|
|
7 |
import torch.utils.data |
|
|
8 |
from PIL.Image import open |
|
|
9 |
from torch.nn.functional import one_hot |
|
|
10 |
from torch.utils.data import Dataset |
|
|
11 |
from torchvision import transforms |
|
|
12 |
from perturbation.model import ModelOfNaturalVariation |
|
|
13 |
import data.augmentation as aug |
|
|
14 |
from utils.mask_generator import generate_a_mask |
|
|
15 |
|
|
|
16 |
|
|
|
17 |
class KvasirClassificationDataset(Dataset):
    """
    Dataset class that fetches images with the associated pathological class labels for use in Pretraining.

    Every entry of ``<path>/labeled-images/lower-gi-tract/pathological-findings`` is treated as
    one class directory and every file inside it as one sample of that class.
    Indexing yields ``(image, onehot, fname)``: the image resized to 400x400 and converted with
    ``ToTensor``, the one-hot label as a float tensor, and the bare file name.
    """

    def __init__(self, path):
        """
        :param path: dataset root; the class directories are looked up below it.
        """
        super(KvasirClassificationDataset, self).__init__()
        self.path = join(path, "labeled-images/lower-gi-tract/pathological-findings")
        # NOTE(review): os.listdir returns entries in arbitrary order, so the class -> index
        # mapping below follows whatever order the filesystem reports.
        self.label_names = listdir(self.path)
        self.num_classes = len(self.label_names)
        self.index_dict = dict(zip(self.label_names, range(self.num_classes)))
        self.fname_class_dict = {}
        # Holds the number of files found per class (indexed like label_names).
        self.class_weights = np.zeros(self.num_classes)
        for i, label in enumerate(self.label_names):
            for fname in listdir(join(self.path, label)):
                self.class_weights[i] += 1
                # Keyed by bare file name: a name that occurs in two class directories
                # keeps only the label seen last.
                self.fname_class_dict[fname] = label
        # Filled lazily by _sample_list() so __getitem__ does not rebuild it per access.
        self._samples = None

        self.common_transforms = transforms.Compose([transforms.Resize((400, 400)),
                                                     transforms.ToTensor()
                                                     ])

    def __len__(self):
        """Return the number of distinct file names collected."""
        return len(self.fname_class_dict)

    def _sample_list(self):
        """Return the cached ``(fname, label)`` pairs, rebuilt whenever the dict size changes."""
        cached = getattr(self, "_samples", None)
        if cached is None or len(cached) != len(self.fname_class_dict):
            cached = self._samples = list(self.fname_class_dict.items())
        return cached

    def __getitem__(self, item):
        """
        :param item: integer position in the insertion order of ``fname_class_dict``.
        :return: ``(image, onehot, fname)``.
        """
        fname, label = self._sample_list()[item]
        onehot = one_hot(torch.tensor(self.index_dict[label]), num_classes=self.num_classes)
        # `open` is PIL.Image.open (imported at file level). The image is decoded once;
        # convert() returns an independent copy, so the source handle can be released here.
        with open(join(self.path, label, fname)) as raw:
            image = self.common_transforms(raw.convert("RGB"))
        return image, onehot.float(), fname
|
|
53 |
|
|
|
54 |
|
|
|
55 |
class KvasirSegmentationDataset(Dataset):
    """
    Dataset class that fetches images with the associated segmentation mask.
    Employs "vanilla" augmentations

    Files are listed from ``<path>/segmented-images/images`` and the mask of the same name is
    read from ``<path>/segmented-images/masks``. The listing is cut by position into
    80% train, then half of the remainder as val, and the rest as test.
    """

    def __init__(self, path, split="train", augment=False):
        """
        :param path: dataset root directory.
        :param split: one of "train", "val" or "test".
        :param augment: when equal to True, the albumentation transforms run on train items.
        :raises ValueError: if ``split`` is none of the three choices.
        """
        super().__init__()
        self.split = split
        self.augment = augment
        self.path = join(path, "segmented-images/")
        self.fnames = listdir(join(self.path, "images"))
        self.common_transforms = aug.pipeline_tranforms()
        self.pixeltrans = aug.albumentation_pixelwise_transforms()
        # NOTE(review): built by the same factory as pixeltrans - confirm that is intended.
        self.segtrans = aug.albumentation_pixelwise_transforms()

        # Positional partition of the directory listing.
        # NOTE(review): os.listdir order is arbitrary, so the partition is only as stable
        # as the order the filesystem reports.
        n_total = len(self.fnames)
        n_train = int(n_total * 0.8)
        n_val = (n_total - n_train) // 2
        self.fnames_train = self.fnames[:n_train]
        self.fnames_val = self.fnames[n_train:n_train + n_val]
        self.fnames_test = self.fnames[n_train + n_val:]

        self.split_fnames = None  # iterable for selected split
        partitions = (
            ("train", self.fnames_train),
            ("val", self.fnames_val),
            ("test", self.fnames_test),
        )
        for candidate, names in partitions:
            if self.split == candidate:
                self.split_fnames = names
                self.size = len(names)
                break
        else:
            raise ValueError("Choices are train/val/test")

    def __len__(self):
        """Return the number of files in the selected split."""
        return self.size

    def __getitem__(self, index):
        """
        :param index: position inside the selected split.
        :return: ``(image, mask, fname)``; the mask is thresholded at 0.5 into a float tensor.
        """
        fname = self.split_fnames[index]
        image = np.array(open(join(self.path, "images/", fname)).convert("RGB"))
        mask = np.array(open(join(self.path, "masks/", fname)).convert("L"))

        # Explicit comparison with True is kept on purpose: other truthy values do not augment.
        if self.split == "train" and self.augment == True:
            # First transform sees the image only, the second one image and mask together.
            image = self.pixeltrans(image=image)["image"]
            joint = self.segtrans(image=image, mask=mask)
            image, mask = joint["image"], joint["mask"]

        image = self.common_transforms(PIL.Image.fromarray(image))
        mask = self.common_transforms(PIL.Image.fromarray(mask))
        mask = (mask > 0.5).float()
        return image, mask, fname
|
|
106 |
|
|
|
107 |
|
|
|
108 |
class KvasirMNVset(KvasirSegmentationDataset):
    """
    Segmentation dataset whose train items are, with probability ``p``, passed through a
    ``ModelOfNaturalVariation`` before being returned.

    Items are ``(image, mask, fname, flag)`` where ``flag`` tells whether the model was applied.
    """

    def __init__(self, path, split, inpaint=False):
        """
        :param path: dataset root directory, forwarded to the parent class.
        :param split: one of "train", "val" or "test", forwarded to the parent class.
        :param inpaint: NOTE(review): accepted but never read here; the model is always
            built with ``use_inpainter=True`` - confirm whether it should be forwarded.
        """
        super().__init__(path, split, augment=False)
        self.p = 0.5
        self.mnv = ModelOfNaturalVariation(1, use_inpainter=True)

    def __getitem__(self, index):
        """
        :param index: position inside the selected split.
        :return: ``(image, mask, fname, flag)``.
        """
        fname = self.split_fnames[index]
        raw_image = np.array(open(join(self.path, "images/", fname)).convert("RGB"))
        raw_mask = np.array(open(join(self.path, "masks/", fname)).convert("L"))
        image = self.common_transforms(PIL.Image.fromarray(raw_image))
        mask = self.common_transforms(PIL.Image.fromarray(raw_mask))
        mask = (mask > 0.5).float()

        # The random draw only happens for the train split (short-circuit evaluation).
        flag = bool(self.split == "train" and np.random.rand() < self.p)
        if flag:
            # The model is called on batched input; the batch axis is dropped again afterwards.
            image, mask = self.mnv(image.unsqueeze(0), mask.unsqueeze(0))
            image = image.squeeze()
            mask = mask.squeeze(0)  # todo make this less ugly
        return image, mask, fname, flag

    def set_prob(self, prob):
        """Set the probability with which a train item is passed through the model."""
        self.p = prob
|
|
132 |
|
|
|
133 |
|
|
|
134 |
class KvasirInpaintingDataset(Dataset):
    """
    Dataset yielding an image, its binary mask, the image with the masked region removed and
    the removed region itself.

    Files come from ``<path>/segmented-images/images`` with same-named masks in
    ``<path>/segmented-images/masks``; the listing is cut by position into 80% train, half of
    the remainder as val, and the rest as test.
    """

    def __init__(self, path, split="train"):
        """
        :param path: dataset root directory.
        :param split: one of "train", "val" or "test".
        :raises ValueError: if ``split`` is none of the three choices.
        """
        super().__init__()
        self.split = split
        self.path = join(path, "segmented-images/")
        self.fnames = listdir(join(self.path, "images"))
        resize_and_tensorize = [transforms.Resize((400, 400)), transforms.ToTensor()]
        self.common_transforms = transforms.Compose(resize_and_tensorize)

        # NOTE(review): os.listdir order is arbitrary, so the partition is only as stable
        # as the order the filesystem reports.
        n_total = len(self.fnames)
        n_train = int(n_total * 0.8)
        n_val = (n_total - n_train) // 2
        self.fnames_train = self.fnames[:n_train]
        self.fnames_val = self.fnames[n_train:n_train + n_val]
        self.fnames_test = self.fnames[n_train + n_val:]

        self.split_fnames = None  # iterable for selected split
        partitions = (
            ("train", self.fnames_train),
            ("val", self.fnames_val),
            ("test", self.fnames_test),
        )
        for candidate, names in partitions:
            if self.split == candidate:
                self.split_fnames = names
                self.size = len(names)
                break
        else:
            raise ValueError("Choices are train/val/test")

    def __len__(self):
        """Return the number of files in the selected split."""
        return len(self.split_fnames)

    def __getitem__(self, index):
        """
        :param index: position inside the selected split.
        :return: ``(image, mask, masked_image, part, fname)`` where ``part = mask * image``
            and ``masked_image = image - part``.
        """
        fname = self.split_fnames[index]
        image_file = join(join(self.path, "images/"), fname)
        mask_file = join(join(self.path, "masks/"), fname)

        image = self.common_transforms(open(image_file).convert("RGB"))
        mask = self.common_transforms(open(mask_file).convert("L"))
        mask = (mask > 0.5).float()

        part = mask * image  # pixels under the mask
        masked_image = image - part  # image with the masked pixels zeroed

        return image, mask, masked_image, part, fname
|
|
176 |
|
|
|
177 |
|
|
|
178 |
class KvasirSyntheticDataset(Dataset):
    """
    Dataset that feeds unlabeled images together with an all-zero mask through a
    ``ModelOfNaturalVariation`` and returns whatever image/mask pair the model produces.

    Files are listed from ``<path>/unlabeled-images/images`` and cut by position into
    80% train, half of the remainder as val, and the rest as test.
    """

    def __init__(self, path, split="train"):
        """
        :param path: dataset root directory.
        :param split: one of "train", "val" or "test".
        :raises ValueError: if ``split`` is none of the three choices.
        """
        super().__init__()
        self.split = split
        self.path = join(path, "unlabeled-images")
        self.fnames = listdir(join(self.path, "images"))
        self.common_transforms = aug.pipeline_tranforms()

        # NOTE(review): os.listdir order is arbitrary, so the partition is only as stable
        # as the order the filesystem reports.
        n_total = len(self.fnames)
        n_train = int(n_total * 0.8)
        n_val = (n_total - n_train) // 2
        self.fnames_train = self.fnames[:n_train]
        self.fnames_val = self.fnames[n_train:n_train + n_val]
        self.fnames_test = self.fnames[n_train + n_val:]

        self.split_fnames = None  # iterable for selected split
        partitions = (
            ("train", self.fnames_train),
            ("val", self.fnames_val),
            ("test", self.fnames_test),
        )
        for candidate, names in partitions:
            if self.split == candidate:
                self.split_fnames = names
                self.size = len(names)
                break
        else:
            raise ValueError("Choices are train/val/test")

        print("loading mnv")
        # The model is moved to "cuda" unconditionally.
        self.mnv = ModelOfNaturalVariation(0, use_inpainter=True).to("cuda")
        print("mnv loaded")

    def __len__(self):
        """Return the number of files in the selected split."""
        return len(self.split_fnames)

    def __getitem__(self, index):
        """
        :param index: position inside the selected split.
        :return: ``(image, mask, fname)`` as produced by the model.
        """
        print(f"getting {index}")
        fname = self.split_fnames[index]
        pixels = np.array(open(join(self.path, "images/", fname)).convert("RGB"))
        # The starting mask is all zeros with the same array shape as the RGB image.
        blank = np.zeros_like(pixels)

        image = self.common_transforms(PIL.Image.fromarray(pixels))
        mask = self.common_transforms(PIL.Image.fromarray(blank))

        # The model is called on batched input; the batch axis is dropped again afterwards.
        image, mask = self.mnv(image.unsqueeze(0), mask.unsqueeze(0))
        image = image.squeeze()
        mask = mask.squeeze(0)

        return image, mask, fname
|
|
222 |
|
|
|
223 |
|
|
|
224 |
def test_KvasirSegmentationDataset():
    """
    Visual smoke check: iterate the test split of ``Datasets/HyperKvasir`` one item at a time,
    display each image and assert that image and mask arrive as tensors.
    """
    dataset = KvasirSegmentationDataset("Datasets/HyperKvasir", split="test")
    for x, y, fname in torch.utils.data.DataLoader(dataset):
        frame = x.squeeze()
        # Tensor.T is deprecated for tensors that are not 2-D; this permute is the explicit
        # spelling of the same full dimension reversal, so the displayed picture is unchanged.
        # NOTE(review): reversing all axes shows the picture transposed - if frame is
        # (C, H, W), permute(1, 2, 0) would show it upright; confirm which is wanted.
        plt.imshow(frame.permute(*reversed(range(frame.ndim))))
        plt.show()

        assert isinstance(x, torch.Tensor)
        assert isinstance(y, torch.Tensor)
    print("Segmentation evaluation passed")
|
|
234 |
|
|
|
235 |
|
|
|
236 |
def test_KvasirClassificationDataset():
    """
    Smoke check: iterate the classification dataset under ``Datasets/HyperKvasir`` one item
    at a time and assert that image and label arrive as tensors.
    """
    loader = torch.utils.data.DataLoader(KvasirClassificationDataset("Datasets/HyperKvasir"))
    for image, label, _ in loader:
        for value in (image, label):
            assert isinstance(value, torch.Tensor)
    print("Classification evaluation passed")
|
|
242 |
|
|
|
243 |
|
|
|
244 |
if __name__ == "__main__":
    # Running the module directly performs the segmentation check only.
    test_KvasirSegmentationDataset()
    # test_KvasirClassificationDataset()