# data_loader/data_loader_18.py
1
import os
2
import torch
3
import numpy as np
4
import math
5
import random
6
import cv2 as cv
7
import nibabel as nib
8
import torch
9
from torch.utils import data
10
import torchvision.transforms as transforms
11
import matplotlib.pyplot as plt
12
import pandas as pd
13
14
from data_loader.preprocess import readVol,to_uint8,IR_to_uint8,histeq,preprocessed,get_stacked,rotate,calc_crop_region,calc_max_region_list,crop,get_edge
15
16
class MR18loader_CV(data.Dataset):
    """Leave-one-subject-out dataset over seven MR subjects.

    For every subject a T1 (``reg_T1``), an IR and a FLAIR volume (stored
    under the ``T2`` names) plus a label volume are read from
    ``<root>training/<subject>/``.  The subject selected by ``val_num`` is
    held out: with ``is_val=False`` the other six subjects are loaded (and
    optionally augmented), with ``is_val=True`` only the held-out one is.

    Each item is the tuple ``(region, T1, IR, T2, label)``.  ``region`` comes
    from the cropping step, so items can only be fetched from a loader built
    with ``is_crop=True``.

    Args:
        root: Data directory prefix; it is concatenated as-is, so it must end
            with a path separator.
        val_num: 1-based position of the held-out subject in ``SUBJECTS``.
        is_val: Load the held-out subject instead of the training subjects.
        is_transform: Convert every stack to torch tensors via ``transform``.
        is_flip: Training only - append a ``cv.flip(..., 1)`` copy of every
            volume.
        is_rotate: Training only - append one rotated copy of every volume
            per angle in ``ROTATE_ANGLES``.
        is_crop: Crop every stack to the region computed from the T1 stacks.
        is_histeq: Histogram-equalise the T1 volumes (IR and FLAIR are left
            untouched).
        forest: Forwarded to ``get_stacked`` / ``calc_max_region_list``
            (presumably the number of slices per stack - confirm against
            ``data_loader.preprocess``).
    """

    # Sub-directory names of the subjects below ``<root>training/``.
    SUBJECTS = ('1', '4', '5', '7', '14', '070', '148')
    # Angles passed to ``rotate`` when ``is_rotate`` is enabled.
    ROTATE_ANGLES = (5, -5, 10, -10, 15, -15)

    def __init__(self, root='../../data/', val_num=5, is_val=False,
                 is_transform=False, is_flip=False, is_rotate=False,
                 is_crop=False, is_histeq=False, forest=5):
        super().__init__()
        self.root = root
        self.val_num = val_num
        self.is_val = is_val
        self.is_transform = is_transform
        self.is_flip = is_flip
        self.is_rotate = is_rotate
        self.is_crop = is_crop
        self.is_histeq = is_histeq
        self.forest = forest
        self.n_classes = 11
        # Back: Background
        # GM:   Cortical GM(red), Basal ganglia(green)
        # WM:   WM(yellow), WM lesions(blue)
        # CSF:  CSF(pink), Ventricles(light blue)
        # Back: Cerebellum(white), Brainstem(dark red)
        self.color = np.asarray([
            [0, 0, 0], [0, 0, 255], [0, 255, 0], [0, 255, 255], [255, 0, 0],
            [255, 0, 255], [255, 255, 0], [255, 255, 255], [0, 0, 128],
            [0, 128, 0], [128, 0, 0],
        ]).astype(np.uint8)
        # Test-time classes for labels 0..8: Back , CSF , GM , WM
        self.label_test = [0, 2, 2, 3, 3, 1, 1, 0, 0]

        # nii paths, one entry per subject
        train_dir = self.root + 'training/'
        self.T1path = [train_dir + name + '/pre/reg_T1.nii.gz' for name in self.SUBJECTS]
        self.IRpath = [train_dir + name + '/pre/IR.nii.gz' for name in self.SUBJECTS]
        self.T2path = [train_dir + name + '/pre/FLAIR.nii.gz' for name in self.SUBJECTS]
        self.lblpath = [train_dir + name + '/segm.nii.gz' for name in self.SUBJECTS]

        # val path
        held_out = self.val_num - 1
        self.val_T1path = self.T1path[held_out]
        self.val_IRpath = self.IRpath[held_out]
        self.val_T2path = self.T2path[held_out]
        self.val_lblpath = self.lblpath[held_out]
        # train path: everything except the held-out subject
        self.train_T1path = [p for p in self.T1path if p != self.val_T1path]
        self.train_IRpath = [p for p in self.IRpath if p != self.val_IRpath]
        self.train_T2path = [p for p in self.T2path if p != self.val_T2path]
        self.train_lblpath = [p for p in self.lblpath if p != self.val_lblpath]

        if self.is_val:
            prepared = self._load_validation()
        else:
            prepared = self._load_training()

        # data ready
        (self.T1_stack_lists, self.IR_stack_lists, self.T2_stack_lists,
         self.lbl_stack_lists, self.edge_stack_lists) = prepared

    @staticmethod
    def _mean_of_stack_means(stacks):
        """Return the average of ``np.mean(stack)`` over an iterable of stacks."""
        total, count = 0.0, 0
        for stack in stacks:
            total = total + np.mean(stack)
            count += 1
        return total / count

    def _transform_volume(self, T1_stacks, IR_stacks, T2_stacks, lbl_stacks, edge_stacks):
        """Replace every stack of one volume, in place, by its tensor form."""
        for pos in range(len(T1_stacks)):
            (T1_stacks[pos], IR_stacks[pos], T2_stacks[pos],
             lbl_stacks[pos], edge_stacks[pos]) = self.transform(
                T1_stacks[pos], IR_stacks[pos], T2_stacks[pos],
                lbl_stacks[pos], edge_stacks[pos])

    def _load_training(self):
        """Load, augment and stack the training subjects.

        Returns ``(T1, IR, T2, label, edge)``, each a list with one entry per
        (possibly augmented) volume.
        """
        print('training data')
        T1_vols = [to_uint8(readVol(path)) for path in self.train_T1path]
        IR_vols = [IR_to_uint8(readVol(path)) for path in self.train_IRpath]
        T2_vols = [to_uint8(readVol(path)) for path in self.train_T2path]
        lbl_vols = [readVol(path) for path in self.train_lblpath]

        if self.is_flip:
            # Originals stay first; their flipped copies follow in the same order.
            for vols in (T1_vols, IR_vols, T2_vols, lbl_vols):
                vols.extend([np.array([cv.flip(slice_, 1) for slice_ in vol])
                             for vol in list(vols)])

        if self.is_histeq:
            print('hist equalizing......')
            # Only T1 is equalised; IR and FLAIR are used as loaded.
            T1_vols = [histeq(vol) for vol in T1_vols]

        print('get stacking......')
        T1_stacks = [get_stacked(vol, self.forest) for vol in T1_vols]
        IR_stacks = [get_stacked(vol, self.forest) for vol in IR_vols]
        T2_stacks = [get_stacked(vol, self.forest) for vol in T2_vols]
        lbl_stacks = [get_stacked(vol, self.forest) for vol in lbl_vols]

        if self.is_rotate:
            print('rotating......')
            # Only the volumes present before rotation are used as sources.
            source_count = len(T1_stacks)
            for angle in self.ROTATE_ANGLES:
                for src in range(source_count):
                    T1_stacks.append(rotate(T1_stacks[src], angle, interp=cv.INTER_CUBIC).copy())
                    IR_stacks.append(rotate(IR_stacks[src], angle, interp=cv.INTER_CUBIC).copy())
                    T2_stacks.append(rotate(T2_stacks[src], angle, interp=cv.INTER_CUBIC).copy())
                    # Nearest-neighbour keeps label values discrete.
                    lbl_stacks.append(rotate(lbl_stacks[src], angle, interp=cv.INTER_NEAREST).copy())

        if self.is_crop:
            print('cropping......')
            # Regions are derived from T1 only and shared by all modalities.
            # NOTE(review): meaning of the constants 50 and 5 is defined by
            # calc_crop_region - confirm there.
            regions = [calc_max_region_list(calc_crop_region(volume, 50, 5), self.forest)
                       for volume in T1_stacks]
            self.region_lists = regions
            T1_stacks = [crop(volume, region) for volume, region in zip(T1_stacks, regions)]
            IR_stacks = [crop(volume, region) for volume, region in zip(IR_stacks, regions)]
            T2_stacks = [crop(volume, region) for volume, region in zip(T2_stacks, regions)]
            lbl_stacks = [crop(volume, region) for volume, region in zip(lbl_stacks, regions)]

        # get means (must precede transform, which subtracts them)
        self.T1mean = self._mean_of_stack_means(
            stack for volume in T1_stacks for stack in volume)
        print('T1 mean = ', self.T1mean)
        self.IRmean = self._mean_of_stack_means(
            stack for volume in IR_stacks for stack in volume)
        print('IR mean = ', self.IRmean)
        self.T2mean = self._mean_of_stack_means(
            stack for volume in T2_stacks for stack in volume)
        print('T2 mean = ', self.T2mean)

        # get edges
        print('getting edges')
        edge_stacks = [get_edge(volume) for volume in lbl_stacks]

        # transform
        if self.is_transform:
            print('transforming')
            for volume_stacks in zip(T1_stacks, IR_stacks, T2_stacks, lbl_stacks, edge_stacks):
                self._transform_volume(*volume_stacks)

        return T1_stacks, IR_stacks, T2_stacks, lbl_stacks, edge_stacks

    def _load_validation(self):
        """Load and stack the held-out subject (no flip / rotate augmentation).

        Returns ``(T1, IR, T2, label, edge)`` for that single volume.
        """
        print('validating data')
        T1_vol = to_uint8(readVol(self.val_T1path))
        IR_vol = IR_to_uint8(readVol(self.val_IRpath))
        T2_vol = to_uint8(readVol(self.val_T2path))
        lbl_vol = readVol(self.val_lblpath)

        if self.is_histeq:
            print('hist equalizing......')
            # Only T1 is equalised; IR and FLAIR are used as loaded.
            T1_vol = histeq(T1_vol)

        print('get stacking......')
        T1_stacks = get_stacked(T1_vol, self.forest)
        IR_stacks = get_stacked(IR_vol, self.forest)
        T2_stacks = get_stacked(T2_vol, self.forest)
        lbl_stacks = get_stacked(lbl_vol, self.forest)

        if self.is_crop:
            print('cropping......')
            regions = calc_max_region_list(calc_crop_region(T1_stacks, 50, 5), self.forest)
            self.region_lists = regions
            T1_stacks = crop(T1_stacks, regions)
            IR_stacks = crop(IR_stacks, regions)
            T2_stacks = crop(T2_stacks, regions)
            lbl_stacks = crop(lbl_stacks, regions)

        # get means (must precede transform, which subtracts them)
        self.T1mean = self._mean_of_stack_means(T1_stacks)
        print('T1 mean = ', self.T1mean)
        self.IRmean = self._mean_of_stack_means(IR_stacks)
        print('IR mean = ', self.IRmean)
        self.T2mean = self._mean_of_stack_means(T2_stacks)
        print('T2 mean = ', self.T2mean)

        # get edges
        print('getting edges')
        edge_stacks = get_edge(lbl_stacks)

        # transform
        if self.is_transform:
            print('transforming')
            self._transform_volume(T1_stacks, IR_stacks, T2_stacks, lbl_stacks, edge_stacks)

        return T1_stacks, IR_stacks, T2_stacks, lbl_stacks, edge_stacks

    def __len__(self):
        """Return the number of stacks actually held by the loader."""
        if self.is_val:
            return len(self.T1_stack_lists)
        return sum(len(volume) for volume in self.T1_stack_lists)

    def __getitem__(self, index):
        """Return ``(region, T1, IR, T2, label)`` for the flat position ``index``.

        In training mode the flat index runs through the volumes in order,
        stack by stack.

        Raises:
            AttributeError: The loader was built without ``is_crop=True``, so
                no crop regions exist.
            IndexError: ``index`` is outside ``[-len(self), len(self))``.
        """
        regions = getattr(self, 'region_lists', None)
        if regions is None:
            raise AttributeError(
                'region_lists is unavailable; build the loader with is_crop=True')

        total = len(self)
        index = int(index)
        if index < 0:
            index += total
        if not 0 <= index < total:
            # Also terminates plain ``for item in dataset`` iteration.
            raise IndexError('index out of range')

        # get train or validation data
        if self.is_val:
            return (regions[index],
                    self.T1_stack_lists[index],
                    self.IR_stack_lists[index],
                    self.T2_stack_lists[index],
                    self.lbl_stack_lists[index])

        # Walk the volumes to turn the flat index into (volume, stack).
        volume_index = 0
        for volume_index, volume in enumerate(self.T1_stack_lists):
            if index < len(volume):
                break
            index -= len(volume)
        return (regions[volume_index][index],
                self.T1_stack_lists[volume_index][index],
                self.IR_stack_lists[volume_index][index],
                self.T2_stack_lists[volume_index][index],
                self.lbl_stack_lists[volume_index][index])

    def transform(self, imgT1, imgIR, imgT2, lbl, edge):
        """Convert one stack of each kind from numpy to channel-first tensors.

        Every input is transposed with ``(2, 0, 1)``.  Images have the stored
        modality mean subtracted and are divided by 255 (float32), labels
        become int64, edges are divided by 255 (float32).
        """
        # np.float64 replaces the ``np.float`` alias removed in NumPy 1.24.
        imgT1 = torch.from_numpy((imgT1.transpose(2, 0, 1).astype(np.float64) - self.T1mean) / 255.0).float()
        imgIR = torch.from_numpy((imgIR.transpose(2, 0, 1).astype(np.float64) - self.IRmean) / 255.0).float()
        imgT2 = torch.from_numpy((imgT2.transpose(2, 0, 1).astype(np.float64) - self.T2mean) / 255.0).float()
        lbl = torch.from_numpy(lbl.transpose(2, 0, 1)).long()
        edge = torch.from_numpy(edge.transpose(2, 0, 1) / 255).float()
        return imgT1, imgIR, imgT2, lbl, edge

    def decode_segmap(self, label_mask):
        """Colourise a 2-D label mask using ``self.color``.

        Output channel ``c`` takes column ``2 - c`` of the colour table, i.e.
        the table's column order is reversed.  Pixels whose label is not in
        ``range(n_classes)`` keep their raw label value in every channel.
        Returns a float array of shape ``(H, W, 3)``.
        """
        rgb = np.zeros((label_mask.shape[0], label_mask.shape[1], 3))
        for channel in range(3):
            plane = label_mask.copy()
            for ll in range(self.n_classes):
                plane[label_mask == ll] = self.color[ll, 2 - channel]
            rgb[:, :, channel] = plane
        return rgb

    def lbl_totest(self, pred):
        """Map labels 0..8 of a 2-D prediction through ``self.label_test``.

        Labels without an entry in ``label_test`` map to 0.  Returns uint8.
        """
        pred_test = np.zeros((pred.shape[0], pred.shape[1]), np.uint8)
        for ll, test_label in enumerate(self.label_test):
            pred_test[pred == ll] = test_label
        return pred_test
313
314
if __name__ == '__main__':
    # Manual smoke test: build the fully augmented training set with subject 7
    # held out and walk over every item, printing its running index.
    options = dict(
        root='../../../../data/',
        val_num=7,
        is_val=False,
        is_transform=True,
        is_flip=True,
        is_rotate=True,
        is_crop=True,
        is_histeq=True,
        forest=3,
    )
    MRloader = MR18loader_CV(**options)
    # NOTE(review): the DataLoader is constructed but the loop below iterates
    # the dataset itself, so items arrive un-batched and in order - confirm
    # whether iterating ``loader`` was intended.
    loader = data.DataLoader(MRloader, batch_size=1, num_workers=1, shuffle=True)
    for i, (regions, T1s, IRs, T2s, lbls) in enumerate(MRloader):
        print(i)
337