--- a +++ b/data.py @@ -0,0 +1,104 @@ +import torch +import torchvision + +import pandas as pd +import numpy as np + +from PIL import Image + + +class LungDataset(torch.utils.data.Dataset): + def __init__(self, origin_mask_list, origins_folder, masks_folder, transforms=None,dataset_type="png"): + self.origin_mask_list = origin_mask_list + self.origins_folder = origins_folder + self.masks_folder = masks_folder + self.transforms = transforms + self.dataset_type = dataset_type + def __getitem__(self, idx): + origin_name, mask_name = self.origin_mask_list[idx] + origin = Image.open(self.origins_folder / (origin_name + "."+self.dataset_type)).convert("P") + mask = Image.open(self.masks_folder / (mask_name + "."+self.dataset_type)) + if self.transforms is not None: + origin, mask = self.transforms((origin, mask)) + + origin = torchvision.transforms.functional.to_tensor(origin) - 0.5 + + mask = np.array(mask) + mask = (torch.tensor(mask) > 128).long() + return origin, mask + + + def __len__(self): + return len(self.origin_mask_list) + + +class Pad(): + def __init__(self, max_padding): + self.max_padding = max_padding + + def __call__(self, sample): + origin, mask = sample + padding = np.random.randint(0, self.max_padding) +# origin = torchvision.transforms.functional.pad(origin, padding=padding, padding_mode="symmetric") + origin = torchvision.transforms.functional.pad(origin, padding=padding, fill=0) + mask = torchvision.transforms.functional.pad(mask, padding=padding, fill=0) + return origin, mask + + +class Crop(): + def __init__(self, max_shift): + self.max_shift = max_shift + + def __call__(self, sample): + origin, mask = sample + tl_shift = np.random.randint(0, self.max_shift) + br_shift = np.random.randint(0, self.max_shift) + origin_w, origin_h = origin.size + crop_w = origin_w - tl_shift - br_shift + crop_h = origin_h - tl_shift - br_shift + + origin = torchvision.transforms.functional.crop(origin, tl_shift, tl_shift, + crop_h, crop_w) + mask = torchvision.transforms.functional.crop(mask, tl_shift, tl_shift, + crop_h, crop_w) + return origin, mask + + +class Resize(): + def __init__(self, output_size): + self.output_size = output_size + + def __call__(self, sample): + origin, mask = sample + origin = torchvision.transforms.functional.resize(origin, self.output_size) + mask = torchvision.transforms.functional.resize(mask, self.output_size) + + return origin, mask + + +def blend(origin, mask1=None, mask2=None,mask3=None): + img = torchvision.transforms.functional.to_pil_image(origin + 0.5).convert("RGB") + if mask1 is not None: + mask1 = torchvision.transforms.functional.to_pil_image(torch.cat([ + torch.zeros_like(origin), + torch.stack([mask1.float()]), + torch.zeros_like(origin) + ])) + img = Image.blend(img, mask1, 0.2) + + if mask2 is not None: + mask2 = torchvision.transforms.functional.to_pil_image(torch.cat([ + torch.stack([mask2.float()]), + torch.zeros_like(origin), + torch.zeros_like(origin) + ])) + img = Image.blend(img, mask2, 0.2) + if mask3 is not None: + mask3 = torchvision.transforms.functional.to_pil_image(torch.cat([ + torch.stack([mask3.float()]), + torch.zeros_like(origin), + torch.zeros_like(origin) + ])) + img = Image.blend(img, mask3, 0.2) + + return img