|
a |
|
b/dataset.py |
|
|
1 |
import os |
|
|
2 |
from PIL import Image |
|
|
3 |
from torch.utils.data import Dataset |
|
|
4 |
import numpy as np |
|
|
5 |
|
|
|
6 |
class ChestDataset(Dataset):
    """Paired image/mask segmentation dataset read from two directories.

    The i-th entry of ``image_dir`` is paired with the i-th entry of
    ``mask_dir`` after both directory listings are sorted by filename.
    This assumes every image and its mask land at the same position once
    sorted (e.g. identical or parallel filenames, same file count).
    NOTE(review): confirm the on-disk naming scheme guarantees this.

    Args:
        image_dir: Directory containing the input images.
        mask_dir: Directory containing the segmentation masks.
        transform: Optional callable invoked as
            ``transform(image=..., mask=...)``. It must return a mapping with
            ``'image'`` and ``'mask'`` keys (albumentations-style).
        binarize_mask: Keyword-only. If true, mask pixels exactly equal to
            255 are set to 1.0 (other values are left untouched) before
            ``transform`` runs. Defaults to False, which keeps the raw
            8-bit grayscale values.
    """

    def __init__(self, image_dir, mask_dir, transform=None, *,
                 binarize_mask=False):
        super().__init__()
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.binarize_mask = binarize_mask
        # os.listdir returns entries in arbitrary, filesystem-dependent
        # order. Sort both listings so the index-based pairing in
        # __getitem__ matches each image with its own mask.
        self.images = sorted(os.listdir(image_dir))
        self.masked_images = sorted(os.listdir(mask_dir))

    def __len__(self):
        """Return the number of entries in ``image_dir``."""
        return len(self.images)

    def __getitem__(self, index):
        """Load the ``index``-th image/mask pair.

        Returns:
            ``(image, mask)``. Without a transform, ``image`` is an
            ``(H, W, 3)`` float32 RGB array and ``mask`` is an ``(H, W)``
            float32 array of grayscale values. With a transform, the values
            are whatever it returns under ``'image'`` and ``'mask'``.

        Raises:
            IndexError: if ``index`` is out of range for either listing.
        """
        img_path = os.path.join(self.image_dir, self.images[index])
        mask_path = os.path.join(self.mask_dir, self.masked_images[index])

        # convert() returns a fully loaded copy, so the source file can be
        # closed deterministically by the context manager.
        with Image.open(img_path) as img:
            image = np.array(img.convert('RGB'), dtype=np.float32)
        with Image.open(mask_path) as msk:
            mask = np.array(msk.convert('L'), dtype=np.float32)

        if self.binarize_mask:
            mask[mask == 255.0] = 1.0

        if self.transform is not None:
            augmentation = self.transform(image=image, mask=mask)
            image = augmentation['image']
            mask = augmentation['mask']

        return image, mask