a b/data_loader/preprocess.py
1
import os
2
import numpy as np
3
import math
4
import random
5
import cv2 as cv
6
import nibabel as nib
7
import torch
8
9
# in: volume path
10
# out: volume data in array
11
def readVol(volpath):
    """Load a NIfTI (or other nibabel-readable) image and return its voxel array.

    Uses ``np.asanyarray(img.dataobj)``, the documented replacement for the
    deprecated ``img.get_data()`` (removed in nibabel 5.0). Like
    ``get_data()``, it keeps the stored dtype unless the header requests
    intensity scaling.
    """
    img = nib.load(volpath)
    return np.asanyarray(img.dataobj)
13
14
# in: volume array
15
# out: volume converted to uint8, with values < 0 set to 0 first
16
def to_uint8(vol):
    """Rescale a volume to uint8 after setting its negative values to 0.

    Args:
        vol: numeric array of any shape. It is not modified; a float64 copy
            is made first.

    Returns:
        uint8 array of the same shape, computed as
        ``(v - min) * 255 / max`` on the clamped copy. If the clamped copy is
        all zeros, an all-zero uint8 array is returned.
    """
    # np.float (an alias of the builtin float, i.e. float64) was removed in
    # NumPy 1.24; use float64 explicitly.
    vol = vol.astype(np.float64)
    vol[vol < 0] = 0
    vmax = vol.max()
    if vmax == 0:
        # Avoid 0/0 -> NaN, whose cast to uint8 is undefined.
        return np.zeros(vol.shape, np.uint8)
    # NOTE: the scale divides by max (not max - min), so the output only
    # reaches the full 0..255 range when the clamped minimum is 0.
    return ((vol - vol.min()) * 255.0 / vmax).astype(np.uint8)
20
21
# in: volume array
22
# out: volume converted to uint8, with values < 800 mapped to 0
23
def IR_to_uint8(vol):
    """Rescale an IR volume to uint8, mapping every value below 800 to 0.

    Args:
        vol: numeric array of any shape. It is not modified; a float64 copy
            is made first.

    Returns:
        uint8 array of the same shape, computed as ``(v - 800) * 255 / max``
        after values below 800 are raised to 800.
    """
    floor = 800
    # np.float was removed in NumPy 1.24; float64 is what it aliased.
    vol = vol.astype(np.float64)
    # Clamp at the offset (previously at 0): anything below 800 made
    # (vol - 800) negative, and negative floats cast to uint8 wrap around
    # or are undefined instead of becoming 0.
    vol[vol < floor] = floor
    # After the clamp vol.max() >= 800, so the division is always safe.
    return ((vol - floor) * 255.0 / vol.max()).astype(np.uint8)
27
28
# in: volume array
29
# out: histogram-equalized volume array
30
def histeq(vol):
    """Histogram-equalize each slice ``vol[:, :, k]`` in place and return vol.

    Every slice along axis 2 is replaced by ``cv.equalizeHist`` of itself.
    The input array is therefore modified. ``cv.equalizeHist`` only accepts
    8-bit single-channel images, so vol is expected to be uint8.
    """
    depth = vol.shape[2]
    for k in range(depth):
        equalized = cv.equalizeHist(vol[:, :, k])
        vol[:, :, k] = equalized
    return vol
34
35
# in: volume array
36
# out: preprocessed array
37
def preprocessed(vol):
    """Sharpen every slice along axis 2 by adding half its Sobel gradient, in place.

    For each slice ``vol[:, :, k]``:
      * x and y Sobel derivatives are taken in 16-bit signed precision;
      * ``cv.convertScaleAbs`` turns each into a saturated uint8 magnitude;
      * the two magnitudes are averaged with equal weights (``addWeighted``);
      * ``slice + 0.5 * gradient`` is written back into ``vol``.

    For integer volumes the result is clipped to the dtype's range before
    being stored. Previously, values above the maximum (e.g. > 255 for
    uint8) were cast unsafely and wrapped around.

    Returns:
        vol, modified in place.
    """
    if np.issubdtype(vol.dtype, np.integer):
        info = np.iinfo(vol.dtype)
        lo, hi = info.min, info.max
    else:
        lo = hi = None
    for slice_index in range(vol.shape[2]):
        cur_slice = vol[:, :, slice_index]
        sob_x = cv.Sobel(cur_slice, cv.CV_16S, 1, 0)
        sob_y = cv.Sobel(cur_slice, cv.CV_16S, 0, 1)
        abs_x = cv.convertScaleAbs(sob_x)
        abs_y = cv.convertScaleAbs(sob_y)
        sob = cv.addWeighted(abs_x, 0.5, abs_y, 0.5, 0)
        dst = cur_slice + 0.5 * sob
        if lo is not None:
            # Saturate instead of letting the float -> int cast wrap.
            dst = np.clip(dst, lo, hi)
        vol[:, :, slice_index] = dst
    return vol
48
49
# in: index of slice, stack number, slice number
50
# out: which slice should be stacked
51
def get_stackindex(slice_index, stack_num, slice_num):
    """Return the slice indices forming the stack centred on ``slice_index``.

    The stack holds ``stack_num`` consecutive indices. It runs from
    ``slice_index - stack_num // 2`` to ``slice_index + stack_num // 2`` and
    wraps modulo ``slice_num``, so stacks near either end of the volume
    borrow slices from the opposite end.

    Raises:
        AssertionError: if ``stack_num`` is even. This is an ``assert``, so
            the check is skipped when Python runs with ``-O``.
    """
    assert stack_num % 2 == 1, 'stack numbers must be odd!'
    half = int(stack_num / 2)
    return [(slice_index + offset) % slice_num
            for offset in range(-half, stack_num - half)]
57
58
# in: volume array, stack number
59
# out: stacked img in list
60
def get_stacked(vol, stack_num):
    """Build one multi-channel uint8 image per slice of ``vol`` (axis 2).

    For slice k, the indices ``get_stackindex(k, stack_num, vol.shape[2])``
    select ``stack_num`` neighbouring slices. Each is transposed and stored
    as one channel of an array of shape
    ``(vol.shape[1], vol.shape[0], stack_num)``.

    Returns:
        list with one independent array per slice of ``vol``.
    """
    depth = vol.shape[2]
    # Slices are transposed below, so the buffer must be
    # (shape[1], shape[0]). Allocating (shape[0], shape[1]) raised a
    # broadcast error for non-square slices.
    stacked_slice = np.zeros((vol.shape[1], vol.shape[0], stack_num), np.uint8)
    stack_list = []
    for slice_index in range(depth):
        query_list = get_stackindex(slice_index, stack_num, depth)
        for channel, src_index in enumerate(query_list):
            stacked_slice[:, :, channel] = vol[:, :, src_index].transpose()
        # The buffer is reused across iterations, so store a copy.
        stack_list.append(stacked_slice.copy())
    return stack_list
69
70
# in: list of stacked imgs, rotation angle in degrees, OpenCV interpolation flag
71
# out: rotated imgs
72
def rotate(stack_list, angle, interp=None):
    """Rotate every image in ``stack_list`` about its centre, in place.

    Args:
        stack_list: list of images. Each entry is replaced by its rotated
            version, keeping the original width and height (corners that
            rotate out of frame are cut off).
        angle: rotation angle in degrees, using the
            ``cv.getRotationMatrix2D`` convention.
        interp: OpenCV interpolation flag passed as ``flags`` to
            ``cv.warpAffine``. ``None`` (the default) selects
            ``cv.INTER_LINEAR``. This lets two-argument calls such as
            ``rotate(stack_list, angle)`` work instead of raising TypeError.

    Returns:
        The same list object, now holding the rotated images.
    """
    flags = cv.INTER_LINEAR if interp is None else interp
    for idx, stacked in enumerate(stack_list):
        rows, cols = stacked.shape[:2]
        center = ((cols - 1) / 2.0, (rows - 1) / 2.0)
        matrix = cv.getRotationMatrix2D(center, angle, 1)
        stack_list[idx] = cv.warpAffine(stacked, matrix, (cols, rows), flags=flags)
    return stack_list
78
79
# in: list of stacked T1 imgs, foreground threshold, border margin in pixels
80
# out: which region should be cropped
81
def calc_crop_region(stack_list_T1, thre, pix):
    """Compute a foreground bounding box for every stacked T1 image.

    Foreground pixels are those in the middle channel (index
    ``shape[2] // 2``) whose value is ``> thre`` and ``> 0``. This is the
    same set ``cv.threshold(..., thre, 255, cv.THRESH_TOZERO)`` followed by
    ``> 0`` selects for integer images.

    Args:
        stack_list_T1: list of stacked images, each with at least 3 dims.
        thre: foreground threshold.
        pix: border margin. A bound within ``pix`` pixels of the image
            border is snapped to that border (0 or ``shape - 1``).

    Returns:
        list of ``[y_min, y_max, x_min, x_max]`` with inclusive indices, one
        per image. If an image has no foreground, its box starts as
        ``[pix, pix, pix, pix]`` before snapping.
    """
    crop_region = []
    for stacked in stack_list_T1:
        height, width = stacked.shape[:2]
        mid = stacked[:, :, stacked.shape[2] // 2]
        rows, cols = np.nonzero((mid > thre) & (mid > 0))
        if rows.size:
            y_min, y_max = int(rows.min()), int(rows.max())
            x_min, x_max = int(cols.min()), int(cols.max())
        else:
            y_min = y_max = x_min = x_max = pix
        # Conditional expressions replace the old `cond and 0 or v` idiom.
        # Because 0 is falsy, that idiom always produced `v`, so the min
        # bounds were never snapped to 0.
        y_min = 0 if y_min <= pix else y_min
        y_max = height - 1 if y_max >= height - 1 - pix else y_max
        x_min = 0 if x_min <= pix else x_min
        x_max = width - 1 if x_max >= width - 1 - pix else x_max
        crop_region.append([y_min, y_max, x_min, x_max])
    return crop_region
98
99
# in: crop region for each slice, how many slices in a stack
100
# out: max region in a stacked img
101
def calc_max_region_list(region_list, stack_num):
    """Merge each slice's crop region with those of its stack neighbours.

    For slice k the neighbours are
    ``get_stackindex(k, stack_num, len(region_list))``, which wraps
    circularly. The merged region is their union bounding box: the minimum
    of ``y_min``/``x_min`` and the maximum of ``y_max``/``x_max``.

    Args:
        region_list: list of ``[y_min, y_max, x_min, x_max]``, one per slice.
        stack_num: number of slices per stack (odd).

    Returns:
        list of merged ``[y_min, y_max, x_min, x_max]``, one per slice.
    """
    num_regions = len(region_list)
    max_region_list = []
    for region_index in range(num_regions):
        # The neighbour indices depend only on region_index. Compute them
        # once here; they used to be recomputed for every stack member.
        query_list = get_stackindex(region_index, stack_num, num_regions)
        neighbours = [region_list[q] for q in query_list]
        max_region_list.append([
            min(r[0] for r in neighbours),
            max(r[1] for r in neighbours),
            min(r[2] for r in neighbours),
            max(r[3] for r in neighbours),
        ])
    return max_region_list
114
115
# in: size, divisor (devider)
116
# out: size rounded up to the nearest multiple of devider
117
def calc_ceil_pad(x, devider):
    """Round ``x`` up to the nearest multiple of ``devider``.

    For example, with ``devider=16``: 0 -> 0, 1 -> 16, 16 -> 16, 17 -> 32.
    """
    blocks = math.ceil(x / float(devider))
    return blocks * devider
119
120
# in: stack img list, maxed region list
121
# out: cropped img list
122
def crop(stack_list, region_list):
    """Cut each stacked image to its region and zero-pad it to multiples of 16.

    ``region_list[i] = [y_min, y_max, x_min, x_max]`` applies to
    ``stack_list[i]``. Rows ``y_min:y_max`` and columns ``x_min:x_max``
    (upper bounds exclusive) are copied into the top-left corner of a
    zero-filled uint8 array. Its height and width are
    ``calc_ceil_pad(..., 16)`` of the cropped size, and it keeps all channels.

    NOTE(review): ``calc_crop_region`` derives ``y_max``/``x_max`` from the
    maximum foreground index (inclusive), so that last row/column is not
    copied here. Confirm this is intended.
    """
    padded_list = []
    for idx, stacked in enumerate(stack_list):
        y_min, y_max, x_min, x_max = region_list[idx]
        h = y_max - y_min
        w = x_max - x_min
        canvas = np.zeros(
            (calc_ceil_pad(h, 16), calc_ceil_pad(w, 16), stacked.shape[2]),
            np.uint8,
        )
        canvas[:h, :w, :] = stacked[y_min:y_max, x_min:x_max, :]
        padded_list.append(canvas)
    return padded_list
130
131
# in: stacked label list, Gaussian blur kernel size and sigmaX
132
# out: stack edge list
133
def get_edge(stack_list, kernel_size=(3, 3), sigmaX=0):
    """Return a Gaussian-blurred Canny edge map for every stacked image.

    For each image and each channel k, ``cv.Canny(channel, 1, 1)`` is
    computed. Its very low thresholds presumably mark almost every
    intensity change, e.g. label boundaries (unconfirmed). The result is
    then smoothed with ``cv.GaussianBlur(..., kernel_size, sigmaX)``.

    Returns:
        list of uint8 arrays, each shaped like the first three dimensions
        of the corresponding input. The inputs are not modified.
    """
    edge_list = []
    for stacked in stack_list:
        depth = stacked.shape[2]
        edges = np.zeros(stacked.shape[:3], np.uint8)
        for k in range(depth):
            canny = cv.Canny(stacked[:, :, k], 1, 1)
            edges[:, :, k] = cv.GaussianBlur(canny, kernel_size, sigmaX)
        edge_list.append(edges)
    return edge_list
142
143
144
145
146
147
if __name__ == '__main__':
    # Manual smoke test of the preprocessing pipeline on one training volume.
    T1_path = '../../data/training/1/pre/reg_T1.nii.gz'
    stack_num = 5
    vol = to_uint8(readVol(T1_path))
    print(vol.shape)
    print('vol[100,100,20]= ', vol[100, 100, 20])
    # histeq equalizes in place and returns its argument, so `vol` and
    # `histeqed` are the same array (the second print shows the new value).
    histeqed = histeq(vol)
    print('vol[100,100,20]= ', vol[100, 100, 20])
    print('query list: ', get_stackindex(1, stack_num, histeqed.shape[2]))
    stack_list = get_stacked(histeqed, stack_num)
    print(len(stack_list))
    print(stack_list[0].shape)
    angle = random.uniform(-15, 15)
    print('angle= ', angle)
    # rotate() takes the interpolation flag as its third argument; the old
    # two-argument call raised TypeError.
    rotated = rotate(stack_list, angle, cv.INTER_LINEAR)
    print(len(rotated))
    region = calc_crop_region(rotated, 50, 5)
    max_region = calc_max_region_list(region, stack_num)
    print(region)
    print(max_region)
    cropped = crop(rotated, max_region)
    # Iterate over what was produced instead of assuming 48 slices
    # (range(48) raised IndexError on shorter volumes).
    for cropped_stack in cropped:
        print(cropped_stack.shape)