|
a |
|
b/data/etis.py |
|
|
1 |
from os import listdir |
|
|
2 |
from os.path import join |
|
|
3 |
|
|
|
4 |
import matplotlib.pyplot as plt |
|
|
5 |
from PIL.Image import open |
|
|
6 |
from torch.utils.data import Dataset, DataLoader |
|
|
7 |
from torchvision import transforms |
|
|
8 |
import data.augmentation as aug |
|
|
9 |
|
|
|
10 |
|
|
|
11 |
class EtisDataset(Dataset): |
|
|
12 |
""" |
|
|
13 |
Dataset class that fetches Etis-LaribPolypDB images with the associated segmentation mask. |
|
|
14 |
Used for testing. |
|
|
15 |
""" |
|
|
16 |
|
|
|
17 |
def __init__(self, path): |
|
|
18 |
super(EtisDataset, self).__init__() |
|
|
19 |
self.path = path |
|
|
20 |
self.len = len(listdir(join(self.path, "ETIS-LaribPolypDB"))) |
|
|
21 |
self.common_transforms = aug.pipeline_tranforms() |
|
|
22 |
|
|
|
23 |
def __len__(self): |
|
|
24 |
return self.len |
|
|
25 |
|
|
|
26 |
def __getitem__(self, index): |
|
|
27 |
image = self.common_transforms( |
|
|
28 |
open(join(self.path, "ETIS-LaribPolypDB/{}.jpg".format(index + 1))).convert("RGB")) |
|
|
29 |
mask = self.common_transforms( |
|
|
30 |
open(join(self.path, "GroundTruth/p{}.jpg".format(index + 1))).convert("RGB")) |
|
|
31 |
mask = (mask > 0.5).float() |
|
|
32 |
return image, mask, index + 1 |
|
|
33 |
|
|
|
34 |
|
|
|
35 |
def test_etis(): |
|
|
36 |
for x, y, in DataLoader(EtisDataset("Datasets/ETIS-LaribPolypDB")): |
|
|
37 |
plt.imshow(x) |
|
|
38 |
plt.show() |