"""Data loading utilities: datasets pairing images with segmentation masks for U-Net."""

import logging
import numpy as np
import torch
from PIL import Image
from functools import lru_cache
from functools import partial
from itertools import repeat
from multiprocessing import Pool
from os import listdir
from os.path import splitext, isfile, join
from pathlib import Path
from torch.utils.data import Dataset
from tqdm import tqdm


def load_image(filename):
    """Load *filename* as a PIL Image.

    Supports raw numpy arrays (.npy), serialized torch tensors (.pt/.pth)
    and any format PIL can open directly.
    """
    ext = splitext(filename)[1]
    if ext == '.npy':
        return Image.fromarray(np.load(filename))
    elif ext in ['.pt', '.pth']:
        # NOTE(review): torch.load unpickles arbitrary objects; if these files
        # can come from untrusted sources, pass weights_only=True.
        return Image.fromarray(torch.load(filename).numpy())
    else:
        return Image.open(filename)


def unique_mask_values(idx, mask_dir, mask_suffix):
    """Return the unique mask values for sample *idx*.

    A 2-D (grayscale) mask yields its unique scalar values; a 3-D (color)
    mask yields its unique colors as rows of shape (N, channels).

    Raises IndexError if no mask file matches, ValueError for masks that are
    neither 2-D nor 3-D.

    NOTE(review): *mask_suffix* is accepted but never used -- the glob is
    hard-coded to '<idx>_mask.png'. Confirm whether it should be
    ``idx + mask_suffix + '.*'`` as in upstream milesial/Pytorch-UNet.
    """
    mask_file = list(mask_dir.glob(idx + '_mask.png'))[0]
    mask = np.asarray(load_image(mask_file))
    if mask.ndim == 2:
        return np.unique(mask)
    elif mask.ndim == 3:
        # Flatten H x W x C to (H*W, C) so unique rows == unique colors.
        mask = mask.reshape(-1, mask.shape[-1])
        return np.unique(mask, axis=0)
    else:
        raise ValueError(f'Loaded masks should have 2 or 3 dimensions, found {mask.ndim}')


class BasicDataset(Dataset):
    """Image/mask segmentation dataset.

    Pairs every image in *images_dir* with the mask '<id>_mask.png' in
    *mask_dir*, scans all masks once (in a process pool) to build the sorted
    list of unique mask values -- used to map raw values/colors to class
    indices -- and returns ``{'image': FloatTensor[C,H,W],
    'mask': LongTensor[H,W]}`` per item.
    """

    def __init__(self, images_dir: str, mask_dir: str, scale: float = 1.0, mask_suffix: str = '', transform=None):
        self.images_dir = Path(images_dir)
        self.mask_dir = Path(mask_dir)
        self.transform = transform
        assert 0 < scale <= 1, 'Scale must be between 0 and 1'
        self.scale = scale
        # NOTE(review): mask_suffix is stored and forwarded to
        # unique_mask_values, but the glob there (and in __getitem__) is
        # hard-coded to '_mask.png', so the suffix currently has no effect.
        self.mask_suffix = mask_suffix

        # IDs are image filenames without extension; hidden files are skipped.
        self.ids = [splitext(file)[0] for file in listdir(images_dir)
                    if isfile(join(images_dir, file)) and not file.startswith('.')]
        if not self.ids:
            raise RuntimeError(f'No input file found in {images_dir}, make sure you put your images there')

        logging.info(f'Creating dataset with {len(self.ids)} examples')
        logging.info('Scanning mask files to determine unique values')
        # Scan masks in parallel; this is the one expensive step of __init__.
        with Pool() as p:
            unique = list(tqdm(
                p.imap(partial(unique_mask_values, mask_dir=self.mask_dir,
                               mask_suffix=self.mask_suffix), self.ids),
                total=len(self.ids)
            ))

        # np.unique over the concatenated per-mask results gives the global
        # value set; sorted() already returns a list, no extra list() needed.
        self.mask_values = sorted(np.unique(np.concatenate(unique), axis=0).tolist())
        logging.info(f'Unique mask values: {self.mask_values}')

    def __len__(self):
        return len(self.ids)

    @staticmethod
    def preprocess(mask_values, pil_img, scale, is_mask):
        """Resize *pil_img* by *scale* and convert it to a numpy array.

        Masks are resized with NEAREST (no value interpolation) and mapped to
        class indices via the position of each value in *mask_values*; images
        are resized with BICUBIC and returned channel-first, un-normalized.
        """
        w, h = pil_img.size
        newW, newH = int(scale * w), int(scale * h)
        assert newW > 0 and newH > 0, 'Scale is too small, resized images would have no pixel'
        pil_img = pil_img.resize((newW, newH), resample=Image.NEAREST if is_mask else Image.BICUBIC)
        img = np.asarray(pil_img)

        if is_mask:
            mask = np.zeros((newH, newW), dtype=np.int64)
            for i, v in enumerate(mask_values):
                if img.ndim == 2:
                    mask[img == v] = i
                else:
                    # Color mask: a pixel belongs to class i when all its
                    # channels equal color v.
                    mask[(img == v).all(-1)] = i

            return mask

        else:
            if img.ndim == 2:
                img = img[np.newaxis, ...]  # add channel axis for grayscale
            else:
                img = img.transpose((2, 0, 1))  # HWC -> CHW

            # NOTE(review): pixel values are returned as-is (no /255 scaling);
            # normalization is expected to happen elsewhere in the pipeline.
            return img

    def __getitem__(self, idx):
        name = self.ids[idx]
        # NOTE(review): like unique_mask_values, this ignores
        # self.mask_suffix and hard-codes '_mask.png'.
        mask_file = list(self.mask_dir.glob(name + '_mask.png'))
        img_file = list(self.images_dir.glob(name + '.*'))

        assert len(img_file) == 1, f'Either no image or multiple images found for the ID {name}: {img_file}'
        assert len(mask_file) == 1, f'Either no mask or multiple masks found for the ID {name}: {mask_file}'
        mask = load_image(mask_file[0])
        img = load_image(img_file[0])

        assert img.size == mask.size, \
            f'Image and mask {name} should be the same size, but are {img.size} and {mask.size}'

        if self.transform is not None:
            # Albumentations-style callable: image=/mask= keywords, dict out.
            # NOTE(review): preprocess() is applied to the transform outputs
            # below and requires PIL-like objects (.size/.resize); most
            # albumentations pipelines return numpy arrays -- confirm the
            # transform used here returns PIL-compatible images.
            augmentations = self.transform(image=img, mask=mask)
            img = augmentations['image']
            mask = augmentations['mask']

        img = self.preprocess(self.mask_values, img, self.scale, is_mask=False)
        mask = self.preprocess(self.mask_values, mask, self.scale, is_mask=True)

        return {
            'image': torch.as_tensor(img.copy()).float().contiguous(),
            'mask': torch.as_tensor(mask.copy()).long().contiguous()
        }


class CarvanaDataset(BasicDataset):
    """BasicDataset preconfigured for Carvana-style '<id>_mask' mask names."""

    def __init__(self, images_dir, mask_dir, scale=1):
        super().__init__(images_dir, mask_dir, scale, mask_suffix='_mask')