|
a |
|
b/data.py |
|
|
1 |
import torch |
|
|
2 |
import torchvision |
|
|
3 |
|
|
|
4 |
import pandas as pd |
|
|
5 |
import numpy as np |
|
|
6 |
|
|
|
7 |
from PIL import Image |
|
|
8 |
|
|
|
9 |
|
|
|
10 |
class LungDataset(torch.utils.data.Dataset): |
|
|
11 |
def __init__(self, origin_mask_list, origins_folder, masks_folder, transforms=None,dataset_type="png"): |
|
|
12 |
self.origin_mask_list = origin_mask_list |
|
|
13 |
self.origins_folder = origins_folder |
|
|
14 |
self.masks_folder = masks_folder |
|
|
15 |
self.transforms = transforms |
|
|
16 |
self.dataset_type = dataset_type |
|
|
17 |
def __getitem__(self, idx): |
|
|
18 |
origin_name, mask_name = self.origin_mask_list[idx] |
|
|
19 |
origin = Image.open(self.origins_folder / (origin_name + "."+self.dataset_type)).convert("P") |
|
|
20 |
mask = Image.open(self.masks_folder / (mask_name + "."+self.dataset_type)) |
|
|
21 |
if self.transforms is not None: |
|
|
22 |
origin, mask = self.transforms((origin, mask)) |
|
|
23 |
|
|
|
24 |
origin = torchvision.transforms.functional.to_tensor(origin) - 0.5 |
|
|
25 |
|
|
|
26 |
mask = np.array(mask) |
|
|
27 |
mask = (torch.tensor(mask) > 128).long() |
|
|
28 |
return origin, mask |
|
|
29 |
|
|
|
30 |
|
|
|
31 |
def __len__(self): |
|
|
32 |
return len(self.origin_mask_list) |
|
|
33 |
|
|
|
34 |
|
|
|
35 |
class Pad(): |
|
|
36 |
def __init__(self, max_padding): |
|
|
37 |
self.max_padding = max_padding |
|
|
38 |
|
|
|
39 |
def __call__(self, sample): |
|
|
40 |
origin, mask = sample |
|
|
41 |
padding = np.random.randint(0, self.max_padding) |
|
|
42 |
# origin = torchvision.transforms.functional.pad(origin, padding=padding, padding_mode="symmetric") |
|
|
43 |
origin = torchvision.transforms.functional.pad(origin, padding=padding, fill=0) |
|
|
44 |
mask = torchvision.transforms.functional.pad(mask, padding=padding, fill=0) |
|
|
45 |
return origin, mask |
|
|
46 |
|
|
|
47 |
|
|
|
48 |
class Crop(): |
|
|
49 |
def __init__(self, max_shift): |
|
|
50 |
self.max_shift = max_shift |
|
|
51 |
|
|
|
52 |
def __call__(self, sample): |
|
|
53 |
origin, mask = sample |
|
|
54 |
tl_shift = np.random.randint(0, self.max_shift) |
|
|
55 |
br_shift = np.random.randint(0, self.max_shift) |
|
|
56 |
origin_w, origin_h = origin.size |
|
|
57 |
crop_w = origin_w - tl_shift - br_shift |
|
|
58 |
crop_h = origin_h - tl_shift - br_shift |
|
|
59 |
|
|
|
60 |
origin = torchvision.transforms.functional.crop(origin, tl_shift, tl_shift, |
|
|
61 |
crop_h, crop_w) |
|
|
62 |
mask = torchvision.transforms.functional.crop(mask, tl_shift, tl_shift, |
|
|
63 |
crop_h, crop_w) |
|
|
64 |
return origin, mask |
|
|
65 |
|
|
|
66 |
|
|
|
67 |
class Resize(): |
|
|
68 |
def __init__(self, output_size): |
|
|
69 |
self.output_size = output_size |
|
|
70 |
|
|
|
71 |
def __call__(self, sample): |
|
|
72 |
origin, mask = sample |
|
|
73 |
origin = torchvision.transforms.functional.resize(origin, self.output_size) |
|
|
74 |
mask = torchvision.transforms.functional.resize(mask, self.output_size) |
|
|
75 |
|
|
|
76 |
return origin, mask |
|
|
77 |
|
|
|
78 |
|
|
|
79 |
def blend(origin, mask1=None, mask2=None,mask3=None): |
|
|
80 |
img = torchvision.transforms.functional.to_pil_image(origin + 0.5).convert("RGB") |
|
|
81 |
if mask1 is not None: |
|
|
82 |
mask1 = torchvision.transforms.functional.to_pil_image(torch.cat([ |
|
|
83 |
torch.zeros_like(origin), |
|
|
84 |
torch.stack([mask1.float()]), |
|
|
85 |
torch.zeros_like(origin) |
|
|
86 |
])) |
|
|
87 |
img = Image.blend(img, mask1, 0.2) |
|
|
88 |
|
|
|
89 |
if mask2 is not None: |
|
|
90 |
mask2 = torchvision.transforms.functional.to_pil_image(torch.cat([ |
|
|
91 |
torch.stack([mask2.float()]), |
|
|
92 |
torch.zeros_like(origin), |
|
|
93 |
torch.zeros_like(origin) |
|
|
94 |
])) |
|
|
95 |
img = Image.blend(img, mask2, 0.2) |
|
|
96 |
if mask3 is not None: |
|
|
97 |
mask3 = torchvision.transforms.functional.to_pil_image(torch.cat([ |
|
|
98 |
torch.stack([mask3.float()]), |
|
|
99 |
torch.zeros_like(origin), |
|
|
100 |
torch.zeros_like(origin) |
|
|
101 |
])) |
|
|
102 |
img = Image.blend(img, mask3, 0.2) |
|
|
103 |
|
|
|
104 |
return img |