a b/dataprocess/segdataloader.py
1
import os
2
import numpy as np
3
from glob import glob
4
from PIL import Image
5
import torch
6
from torchvision.transforms import Compose, Normalize, ToTensor
7
from Transforms import Scale
8
import cv2
9
10
# width, height = 96, 96
11
12
13
def att_compare(a, b=3):
    """Return a one-element integer array flagging whether ``a`` exceeds ``b``.

    Args:
        a: value to test.
        b: threshold (defaults to 3).

    Returns:
        ``np.array([1])`` when ``a > b``, otherwise ``np.array([0])``.
    """
    flag = 1 if a > b else 0
    return np.array([flag])
18
19
20
def makedirs(path):
    """Create directory ``path`` (including missing parents) if it does not exist.

    Uses ``exist_ok=True`` instead of an exists-then-create check, so a
    directory created concurrently (e.g. by another DataLoader worker or
    process) between the check and the creation no longer raises.

    Raises:
        FileExistsError: if ``path`` exists but is not a directory.
    """
    os.makedirs(path, exist_ok=True)
23
24
25
def normalazation(image_array):
    """Min-max scale ``image_array`` to [0, 1], then subtract the scaled mean.

    The result has the same shape as the input and (up to rounding) zero mean.

    Integer/bool inputs are converted to float64 before any arithmetic, so
    ``x - min`` and ``max - min`` cannot wrap around for narrow signed dtypes.
    Floating inputs keep their dtype.

    A constant array (``max == min``) has no range to scale by. Instead of
    dividing by zero and producing NaNs, it maps to an all-zero array.

    Args:
        image_array: array-like of any shape.

    Returns:
        A new floating-point numpy array; the input is not modified.
    """
    arr = np.asarray(image_array)
    if not np.issubdtype(arr.dtype, np.floating):
        # Avoid integer overflow/wrap-around in the subtractions below.
        arr = arr.astype(np.float64)
    lo = arr.min()
    span = arr.max() - lo
    if span == 0:
        # Every element equals the minimum: scaled value 0, mean 0 -> all zeros.
        return np.zeros_like(arr)
    scaled = (arr - lo) / span
    return scaled - scaled.mean()
32
33
34
class Dataset(torch.utils.data.Dataset):
    """Training dataset built from six parallel sequences of file paths.

    For a given index, ``__getitem__`` reads:

    * ``datas``  -- ``.npy`` array, used as-is
    * ``lungs``  -- ``.npy`` array, passed through ``normalazation``
    * ``medias`` -- ``.npy`` array, passed through ``normalazation``
    * ``inters``, ``unions``, ``masks`` -- images opened with PIL

    Every item is resized to ``(width, height)`` and normalised with
    ``img_transform_gray``. The three image targets are then thresholded to
    {0, 1}.

    Args:
        datas, lungs, medias, inters, unions, masks: sequences of file paths,
            indexed in parallel.
        label: unused; kept for interface compatibility.
        width, height: output size passed to ``cv2.resize`` and ``Scale``.
    """

    def __init__(self, datas, lungs, medias, inters, unions, masks, label=None, width=64, height=64):
        self.size = (width, height)
        self.datas = datas
        self.lungs = lungs
        self.medias = medias
        self.inters = inters
        self.unions = unions
        self.masks = masks
        self.img_resize = Compose([
            Scale(self.size, Image.BILINEAR),
        ])
        # NOTE(review): label_resize (nearest-neighbour) is never used; the
        # inter/union/mask targets are resized with img_resize (bilinear) and
        # thresholded afterwards. Confirm which was intended.
        self.label_resize = Compose([
            Scale(self.size, Image.NEAREST),
        ])
        # Single-channel normalisation; the original author noted it improves
        # Dice by almost 1 point.
        self.img_transform_gray = Compose([
            ToTensor(),
            Normalize(mean=[0.448749], std=[0.399953]),
        ])
        # The widely used ImageNet RGB mean/std; not used by __getitem__.
        self.img_transform_rgb = Compose([
            ToTensor(),
            Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

        self.input_paths = self.datas
        self.lung_paths = self.lungs
        self.media_paths = self.medias

        self.inter_path = self.inters
        self.union_path = self.unions
        self.mask_path = self.masks

        print('Training data:')
        print(len(self.datas))

    @staticmethod
    def _binarize(tensor, threshold=0.5):
        """Return a tensor of ``tensor``'s dtype: 1 where ``> threshold``, else 0.

        Values exactly equal to ``threshold`` map to 0, so the result is
        strictly binary. The previous in-place ``t[t > 0.5] = 1;
        t[t < 0.5] = 0`` pair left such values unchanged.
        """
        return (tensor > threshold).to(tensor.dtype)

    def __getitem__(self, index):
        """Load, resize, normalise and binarise sample ``index``.

        Returns:
            Tuple ``(input, lung, media, union, inter, mask)`` of tensors.
            Note that ``union`` precedes ``inter`` in the tuple.
        """
        image = np.load(self.input_paths[index])
        lung = normalazation(np.load(self.lung_paths[index]))
        media = normalazation(np.load(self.media_paths[index]))

        inter = Image.open(self.inter_path[index])
        union = Image.open(self.union_path[index])
        mask = Image.open(self.mask_path[index])

        image = cv2.resize(image, self.size)
        lung = cv2.resize(lung, self.size)
        media = cv2.resize(media, self.size)
        inter = self.img_resize(inter)
        union = self.img_resize(union)
        mask = self.img_resize(mask)

        # Normalisation (ToTensor + Normalize) for every item.
        image = self.img_transform_gray(image)
        lung = self.img_transform_gray(lung)
        media = self.img_transform_gray(media)
        inter = self.img_transform_gray(inter)
        union = self.img_transform_gray(union)
        mask = self.img_transform_gray(mask)

        # Binarise the targets. The 0.5 threshold is applied after Normalize,
        # as in the original code.
        inter = self._binarize(inter)
        union = self._binarize(union)
        mask = self._binarize(mask)

        return image, lung, media, union, inter, mask

    def __len__(self):
        """Number of samples, taken from the ``datas`` path sequence."""
        return len(self.input_paths)
113
114
115
class Dataset_val(torch.utils.data.Dataset):
    """Validation/test dataset built from five parallel sequences of file paths.

    For a given index, ``__getitem__`` reads:

    * ``datas``, ``medias`` -- ``.npy`` arrays loaded with ``np.load``
    * ``inters``, ``unions``, ``masks`` -- images opened with PIL

    Every item is resized to ``(width, height)`` and normalised with
    ``img_transform_gray``. The three image targets are then thresholded to
    {0, 1}.

    Args:
        datas, medias, inters, unions, masks: sequences of file paths,
            indexed in parallel.
        label: unused; kept for interface compatibility.
        width, height: output size passed to ``cv2.resize`` and ``Scale``.
    """

    def __init__(self, datas, medias, inters, unions, masks, label=None, width=64, height=64):
        self.size = (width, height)
        self.data = datas
        self.medias = medias
        self.inters = inters
        self.unions = unions
        self.mask = masks
        self.img_resize = Compose([
            Scale(self.size, Image.BILINEAR),
        ])
        # NOTE(review): label_resize (nearest-neighbour) is never used; the
        # targets are resized with img_resize (bilinear) and thresholded
        # afterwards. Confirm which was intended.
        self.label_resize = Compose([
            Scale(self.size, Image.NEAREST),
        ])
        # Single-channel normalisation; the original author noted it improves
        # Dice by almost 1 point.
        self.img_transform_gray = Compose([
            ToTensor(),
            Normalize(mean=[0.448749], std=[0.399953]),
        ])
        # The widely used ImageNet RGB mean/std; not used by __getitem__.
        self.img_transform_rgb = Compose([
            ToTensor(),
            Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

        self.input_paths = self.data
        self.media_paths = self.medias
        self.inter_path = self.inters
        self.union_path = self.unions
        self.mask_paths = self.mask

        print('Testing data:')
        print(len(self.input_paths))

    @staticmethod
    def _binarize(tensor, threshold=0.5):
        """Return a tensor of ``tensor``'s dtype: 1 where ``> threshold``, else 0.

        Values exactly equal to ``threshold`` map to 0, so the result is
        strictly binary. The previous in-place ``t[t > 0.5] = 1;
        t[t < 0.5] = 0`` pair left such values unchanged.
        """
        return (tensor > threshold).to(tensor.dtype)

    def __getitem__(self, index):
        """Load, resize, normalise and binarise sample ``index``.

        Returns:
            Tuple ``(input, media, union, inter, mask)`` of tensors.
            Note that ``union`` precedes ``inter`` in the tuple.
        """
        # NOTE(review): the training ``Dataset`` passes its media arrays
        # through ``normalazation`` before resizing; this class does not.
        # Confirm that this train/eval preprocessing difference is intended.
        image = np.load(self.input_paths[index])
        media = np.load(self.media_paths[index])
        inter = Image.open(self.inter_path[index])
        union = Image.open(self.union_path[index])
        mask = Image.open(self.mask_paths[index])

        image = cv2.resize(image, self.size)
        media = cv2.resize(media, self.size)
        inter = self.img_resize(inter)
        union = self.img_resize(union)
        mask = self.img_resize(mask)

        # Normalisation (ToTensor + Normalize) for every item.
        image = self.img_transform_gray(image)
        media = self.img_transform_gray(media)
        inter = self.img_transform_gray(inter)
        union = self.img_transform_gray(union)
        mask = self.img_transform_gray(mask)

        # Binarise the targets. The 0.5 threshold is applied after Normalize,
        # as in the original code.
        inter = self._binarize(inter)
        union = self._binarize(union)
        mask = self._binarize(mask)

        return image, media, union, inter, mask

    def __len__(self):
        """Number of samples, taken from the ``datas`` path sequence."""
        return len(self.input_paths)
184
185
186
class RowDataset(torch.utils.data.Dataset):
    """Dataset of (image, mask) pairs read directly from image files.

    Both members of a pair go through the same pipeline: a bilinear ``Scale``
    to ``(width, height)``, then ``ToTensor`` and single-channel ``Normalize``.
    The mask is not thresholded here.

    Args:
        data: sequence of input image paths.
        mask: sequence of mask image paths, parallel to ``data``.
        label: stored on the instance but not read by ``__getitem__``.
        width, height: target size handed to ``Scale``.
    """

    def __init__(self, data, mask, label=None, width=64, height=64):
        self.size = (width, height)
        self.data, self.mask, self.label = data, mask, label

        bilinear = Scale(self.size, Image.BILINEAR)
        nearest = Scale(self.size, Image.NEAREST)
        self.img_resize = Compose([bilinear])
        self.label_resize = Compose([nearest])

        # Single-channel statistics; the original author noted this gains
        # almost 1 Dice point.
        gray_norm = Normalize(mean=[0.448749], std=[0.399953])
        self.img_transform_gray = Compose([ToTensor(), gray_norm])

        # The widely used ImageNet RGB mean/std; not used by __getitem__.
        rgb_norm = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        self.img_transform_rgb = Compose([ToTensor(), rgb_norm])

        # Path aliases consumed by __getitem__ / __len__.
        self.input_paths, self.mask_paths, self.label_paths = self.data, self.mask, self.label

    def __getitem__(self, index):
        """Return the resized, normalised ``(image, mask)`` tensors for ``index``."""
        raw_pair = (Image.open(self.input_paths[index]), Image.open(self.mask_paths[index]))
        resized_pair = [self.img_resize(picture) for picture in raw_pair]
        image_tensor, mask_tensor = (self.img_transform_gray(picture) for picture in resized_pair)
        return image_tensor, mask_tensor

    def __len__(self):
        """Number of samples, taken from the ``data`` path sequence."""
        return len(self.input_paths)
232
233
234
def loader(dataset, batch_size, num_workers=8, shuffle=False):
    """Wrap ``dataset`` in a ``torch.utils.data.DataLoader``.

    Args:
        dataset: any map-style dataset accepted by ``DataLoader``.
        batch_size: samples per batch.
        num_workers: worker subprocesses used for loading (default 8).
        shuffle: reshuffle the data every epoch when True.

    Returns:
        The configured ``DataLoader``.
    """
    options = dict(batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
    return torch.utils.data.DataLoader(dataset=dataset, **options)