Diff of /data.py [000000] .. [2507a0]

import torch
import torchvision

import pandas as pd
import numpy as np

from PIL import Image
class LungDataset(torch.utils.data.Dataset):
    """Yields (image tensor, binary lung-mask tensor) pairs loaded from image files."""

    def __init__(self, origin_mask_list, origins_folder, masks_folder, transforms=None, dataset_type="png"):
        self.origin_mask_list = origin_mask_list
        self.origins_folder = origins_folder
        self.masks_folder = masks_folder
        self.transforms = transforms
        self.dataset_type = dataset_type

    def __getitem__(self, idx):
        origin_name, mask_name = self.origin_mask_list[idx]
        origin = Image.open(self.origins_folder / (origin_name + "." + self.dataset_type)).convert("P")
        mask = Image.open(self.masks_folder / (mask_name + "." + self.dataset_type))
        if self.transforms is not None:
            origin, mask = self.transforms((origin, mask))

        # Shift pixel values from [0, 1] to [-0.5, 0.5].
        origin = torchvision.transforms.functional.to_tensor(origin) - 0.5

        # Binarize the mask: pixels above 128 become class 1, the rest class 0.
        mask = np.array(mask)
        mask = (torch.tensor(mask) > 128).long()
        return origin, mask

    def __len__(self):
        return len(self.origin_mask_list)

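A minimal usage sketch of LungDataset with a standard DataLoader; the folder layout, file-name pairs, and batch size below are assumptions for illustration, not part of this file.

# Sketch only: paths and (origin_name, mask_name) pairs are hypothetical.
from pathlib import Path
from torch.utils.data import DataLoader

origins_folder = Path("data/origins")
masks_folder = Path("data/masks")
pairs = [("img_0001", "img_0001_mask")]          # one (origin_name, mask_name) tuple per sample

dataset = LungDataset(pairs, origins_folder, masks_folder)
loader = DataLoader(dataset, batch_size=4, shuffle=True)
for origin, mask in loader:
    # origin: float tensor in [-0.5, 0.5], shape (B, 1, H, W)
    # mask:   long tensor of 0/1 labels, shape (B, H, W)
    pass
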
class Pad():
    """Randomly zero-pad image and mask by the same number of pixels."""

    def __init__(self, max_padding):
        self.max_padding = max_padding

    def __call__(self, sample):
        origin, mask = sample
        padding = np.random.randint(0, self.max_padding)
        # origin = torchvision.transforms.functional.pad(origin, padding=padding, padding_mode="symmetric")
        origin = torchvision.transforms.functional.pad(origin, padding=padding, fill=0)
        mask = torchvision.transforms.functional.pad(mask, padding=padding, fill=0)
        return origin, mask


class Crop():
    """Randomly crop the same region from image and mask."""

    def __init__(self, max_shift):
        self.max_shift = max_shift

    def __call__(self, sample):
        origin, mask = sample
        # tl_shift offsets the top and left edges; br_shift trims the bottom-right corner.
        tl_shift = np.random.randint(0, self.max_shift)
        br_shift = np.random.randint(0, self.max_shift)
        origin_w, origin_h = origin.size
        crop_w = origin_w - tl_shift - br_shift
        crop_h = origin_h - tl_shift - br_shift

        origin = torchvision.transforms.functional.crop(origin, tl_shift, tl_shift,
                                                        crop_h, crop_w)
        mask = torchvision.transforms.functional.crop(mask, tl_shift, tl_shift,
                                                      crop_h, crop_w)
        return origin, mask


class Resize():
    """Resize image and mask to a fixed output size."""

    def __init__(self, output_size):
        self.output_size = output_size

    def __call__(self, sample):
        origin, mask = sample
        # Bilinear by default; mask grey levels are re-binarized later by the > 128 threshold in LungDataset.
        origin = torchvision.transforms.functional.resize(origin, self.output_size)
        mask = torchvision.transforms.functional.resize(mask, self.output_size)

        return origin, mask

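Because each of these transforms consumes and returns an (origin, mask) pair, they compose directly with torchvision.transforms.Compose. A minimal sketch, reusing the hypothetical pairs and folders from the sketch above; the padding, shift, and size values are illustrative only.

train_transforms = torchvision.transforms.Compose([
    Pad(max_padding=50),
    Crop(max_shift=10),
    Resize(output_size=(512, 512)),
])
dataset = LungDataset(pairs, origins_folder, masks_folder, transforms=train_transforms)
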
def blend(origin, mask1=None, mask2=None, mask3=None):
    """Overlay up to three binary masks on the origin image for visualization."""
    img = torchvision.transforms.functional.to_pil_image(origin + 0.5).convert("RGB")
    if mask1 is not None:
        # mask1 is drawn in the green channel.
        mask1 = torchvision.transforms.functional.to_pil_image(torch.cat([
            torch.zeros_like(origin),
            torch.stack([mask1.float()]),
            torch.zeros_like(origin)
        ]))
        img = Image.blend(img, mask1, 0.2)

    if mask2 is not None:
        # mask2 is drawn in the red channel.
        mask2 = torchvision.transforms.functional.to_pil_image(torch.cat([
            torch.stack([mask2.float()]),
            torch.zeros_like(origin),
            torch.zeros_like(origin)
        ]))
        img = Image.blend(img, mask2, 0.2)
    if mask3 is not None:
        # mask3 is also drawn in the red channel, like mask2.
        mask3 = torchvision.transforms.functional.to_pil_image(torch.cat([
            torch.stack([mask3.float()]),
            torch.zeros_like(origin),
            torch.zeros_like(origin)
        ]))
        img = Image.blend(img, mask3, 0.2)

    return img
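
A minimal sketch of using blend to compare a ground-truth mask (green channel) with a predicted mask (red channel); `model` and its output layout are assumptions, not defined in this file.

origin, mask = dataset[0]
with torch.no_grad():
    pred = model(origin.unsqueeze(0)).argmax(dim=1).squeeze(0)   # hypothetical segmentation model

overlay = blend(origin, mask1=mask, mask2=pred)   # ground truth in green, prediction in red
overlay.save("overlay.png")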