--- a +++ b/dataset.py @@ -0,0 +1,30 @@ +import os +from PIL import Image +from torch.utils.data import Dataset +import numpy as np + +class ChestDataset(Dataset): + def __init__(self, image_dir, mask_dir, transform=None): + super(ChestDataset, self).__init__() + self.image_dir = image_dir + self.mask_dir = mask_dir + self.transform = transform + self.images = os.listdir(image_dir) + self.masked_images = os.listdir(mask_dir) + + def __len__(self): + return len(self.images) + + def __getitem__(self, index): + img_path = os.path.join(self.image_dir, self.images[index]) + mask_path = os.path.join(self.mask_dir, self.masked_images[index]) + image = np.array(Image.open(img_path).convert('RGB'), dtype=np.float32) + mask = np.array(Image.open(mask_path).convert('L'), dtype=np.float32) + #mask[mask == 255.0] = 1.0 + + if self.transform is not None: + augmentation = self.transform(image=image, mask=mask) + image = augmentation['image'] + mask = augmentation['mask'] + + return image, mask