a b/Segmenation.py
1
import numpy as np
2
import torch
3
from scipy import ndimage as nd  
4
import torch.nn.parallel
5
from torch.autograd import Variable
6
7
def Generator_multichannels(image, sizeofchunk, sizeofchunk_expand, numofchannels):
    """Split a multi-channel 3-D volume into overlapping cubic chunks.

    The three spatial axes are zero-padded up to a multiple of
    ``sizeofchunk``. They are then padded by a further
    ``width = ceil((sizeofchunk_expand - sizeofchunk) / 2)`` zeros on every
    side. Finally, a cube of edge ``sizeofchunk_expand`` is cut starting at
    every multiple of ``sizeofchunk``. Each chunk therefore holds a
    ``sizeofchunk`` core (at offset ``width``) plus surrounding context.

    Args:
        image: array indexed as (channel, x, y, z); its leading axis must have
            ``numofchannels`` entries.
        sizeofchunk: edge length of the core region of each chunk.
        sizeofchunk_expand: edge length of each extracted chunk (core plus
            context). Must be >= ``sizeofchunk``.
        numofchannels: number of channels on the leading axis of ``image``.

    Returns:
        chunk_batch: float32 array of shape
            (n_chunks, numofchannels, E, E, E) with E = sizeofchunk_expand.
        nb_chunks: int array with the number of chunks along each spatial axis.
        idx_xyz: uint16 array (n_chunks, 3) with the grid index of every chunk,
            in x-major (then y, then z) order.
        sizeofimage: the spatial shape of ``image`` (``np.shape(image)[1:4]``).

    Raises:
        ValueError: if ``sizeofchunk_expand < sizeofchunk``.
    """
    if sizeofchunk_expand < sizeofchunk:
        raise ValueError(
            "sizeofchunk_expand (%r) must be >= sizeofchunk (%r)"
            % (sizeofchunk_expand, sizeofchunk))

    sizeofimage = np.shape(image)[1:4]
    nb_chunks = np.ceil(np.array(sizeofimage) / float(sizeofchunk)).astype(int)
    padded_size = nb_chunks * sizeofchunk

    # Context margin added on every side of the chunk-aligned volume.
    width = int(np.ceil((sizeofchunk_expand - sizeofchunk) / 2.0))

    # Zero-pad straight into the expanded canvas. Explicit end indices are
    # used because ``width:-width`` is an empty slice when width == 0.
    expand_image = np.zeros((numofchannels,) + tuple(padded_size + 2 * width),
                            dtype='float32')
    core = tuple(slice(width, width + n) for n in sizeofimage)
    expand_image[(slice(None),) + core] = image

    batchsize = int(np.prod(nb_chunks))
    chunk_batch = np.zeros((batchsize, numofchannels, sizeofchunk_expand,
                            sizeofchunk_expand, sizeofchunk_expand), dtype='float32')
    idx_xyz = np.zeros((batchsize, 3), dtype='uint16')

    # np.ndindex walks the grid in C order (z fastest), i.e. x-major.
    grid = tuple(int(n) for n in nb_chunks)
    for idx_chunk, (x_idx, y_idx, z_idx) in enumerate(np.ndindex(*grid)):
        idx_xyz[idx_chunk, :] = [x_idx, y_idx, z_idx]
        x0 = x_idx * sizeofchunk
        y0 = y_idx * sizeofchunk
        z0 = z_idx * sizeofchunk
        chunk_batch[idx_chunk] = expand_image[:, x0:x0 + sizeofchunk_expand,
                                              y0:y0 + sizeofchunk_expand,
                                              z0:z0 + sizeofchunk_expand]

    return chunk_batch, nb_chunks, idx_xyz, sizeofimage
42
43
def Chunks_Image(segment_chunks, nb_chunks, sizeofchunk, sizeofchunk_expand, idx_xyz, sizeofimage):
    """Stitch per-chunk core predictions back into one 3-D volume.

    Channel 0 of every chunk in ``segment_chunks`` (a cube of edge
    ``sizeofchunk``) is written at grid position ``idx_xyz[k]``, scaled by
    ``sizeofchunk``. The chunk-aligned canvas is then cropped to
    ``sizeofimage``.

    Args:
        segment_chunks: array indexed (chunk, channel, x, y, z); only
            channel 0 is used.
        nb_chunks: number of chunks along each spatial axis.
        sizeofchunk: edge length of each core chunk.
        sizeofchunk_expand: unused here; kept so the signature mirrors
            ``Generator_multichannels``.
        idx_xyz: (n_chunks, 3) grid index of each chunk.
        sizeofimage: spatial shape to crop the result to.

    Returns:
        float64 array of spatial shape ``sizeofimage``.
    """
    canvas = np.zeros([nb_chunks[axis] * sizeofchunk for axis in range(3)])

    n_chunks = np.shape(segment_chunks)[0]
    for k in range(n_chunks):
        lower = idx_xyz[k, :] * sizeofchunk
        upper = (idx_xyz[k, :] + 1) * sizeofchunk
        region = tuple(slice(lower[axis], upper[axis]) for axis in range(3))
        canvas[region] = segment_chunks[k, 0, ...]

    crop = tuple(slice(None, sizeofimage[axis]) for axis in range(3))
    return canvas[crop]
60
61
62
63
def BreastSeg(image, scale_subject, model, opt):
    """Segment the breast region of a 3-D volume with a chunked CNN.

    Steps:
      1. Load weights from ``Models/model_breast.pth`` into ``model``.
      2. Resample ``image`` to a common spacing of 1.5 per axis.
      3. Split the volume into overlapping chunks with
         ``Generator_multichannels`` and run ``model`` chunk by chunk.
      4. Stitch the outputs with ``Chunks_Image`` and resample back to
         ``image.shape``.
      5. Threshold at 0.5.

    Args:
        image: 3-D array to segment.
        scale_subject: per-axis voxel spacing of ``image``. It is reversed
            before being divided by the common spacing. NOTE(review):
            presumably given in (x, y, z) order while ``image`` is indexed
            (z, y, x) — confirm against the caller.
        model: torch module. For one input chunk of shape
            (1, 1, E, E, E) its output must fit a (C, C, C) core slot, where
            C/E are the chunk sizes chosen below.
        opt: options object; only ``opt.cuda`` is read.

    Returns:
        uint8 array of shape ``image.shape``: 1 where the resampled
        probability is > 0.5, else 0.
    """
    modelpath = "Models/"
    modelname = modelpath + "model_breast.pth"
    # map_location="cpu" lets a GPU-saved checkpoint load on a CPU-only host;
    # load_state_dict then copies the weights onto the model's own device.
    checkpoint = torch.load(modelname, map_location="cpu")
    model.load_state_dict(checkpoint)
    model.eval()

    commonspacing = [1.5, 1.5, 1.5]
    threshold = 0.5
    imageshape = image.shape

    # Zoom factors that bring the volume to the common spacing.
    scale_subject = np.asarray(scale_subject, dtype=float)[::-1] / commonspacing
    image = nd.zoom(image, scale_subject, order=1)

    # (core edge, core + context edge); larger chunks when running on GPU.
    if opt.cuda:
        sizeofchunk, sizeofchunk_expand = 132, 220
    else:
        sizeofchunk, sizeofchunk_expand = 20, 108

    image_one = np.asarray(image, dtype='float32')[np.newaxis]
    chunk_batch, nb_chunks, idx_xyz, sizeofimage = Generator_multichannels(
        image_one, sizeofchunk, sizeofchunk_expand, 1)

    n_chunks = np.size(chunk_batch, 0)
    seg_batch = np.zeros((n_chunks, 1, sizeofchunk, sizeofchunk, sizeofchunk), dtype='float32')
    # Inference only: no_grad stops autograd from recording a graph per chunk
    # (the old Variable(volatile=True) flag is ignored since PyTorch 0.4).
    with torch.no_grad():
        for i_chunk in range(n_chunks):
            chunk = torch.from_numpy(chunk_batch[i_chunk:i_chunk + 1, ...])
            if opt.cuda:
                chunk = chunk.cuda()
            prediction = model(chunk)
            seg_batch[i_chunk, 0, ...] = prediction.cpu().numpy()

    prob_image = Chunks_Image(seg_batch, nb_chunks, sizeofchunk, sizeofchunk_expand,
                              idx_xyz, sizeofimage)
    up_image = nd.zoom(prob_image, 1 / scale_subject, order=1)

    # Zoom rounding can leave up_image a voxel larger or smaller than the
    # input, so crop it and then zero-pad it to the original shape.
    up_image = up_image[:imageshape[0], :imageshape[1], :imageshape[2]]
    up_image_norm = np.zeros(imageshape, dtype='float32')
    up_image_norm[:up_image.shape[0], :up_image.shape[1], :up_image.shape[2]] = up_image

    seg_img = (up_image_norm > threshold).astype('uint8')
    return seg_img
114
115
116
117
118
def BreastTumor(image_sub, image_post, image_mask, scale_subject, model1st, model2nd, opt):
    """Two-stage chunked CNN tumour segmentation.

    The three input volumes are resampled to a common spacing of 0.7 per axis
    and stacked as channels [image_sub, image_post, image_mask].

    Stage 1 (``model1st``, weights ``Models/model_tumor_1st.pth``) predicts a
    probability map; values < 0.01 are zeroed. That map replaces the mask
    channel as input to stage 2 (``model2nd``, weights
    ``Models/model_tumor_2nd.pth``). Both maps are resampled back to
    ``image_sub.shape``.

    Args:
        image_sub, image_post, image_mask: 3-D arrays. They are assumed to
            share one shape and grid — TODO confirm; the stacking below
            requires equal zoomed shapes.
        scale_subject: per-axis voxel spacing, reversed before use.
            NOTE(review): presumably (x, y, z) vs array (z, y, x) — confirm.
        model1st, model2nd: torch modules taking a (1, 3, E, E, E) chunk.
        opt: options object; only ``opt.cuda`` is read.

    Returns:
        prob_output: float32 stage-1 probabilities on the input grid.
        seg_img: uint8 stage-2 segmentation (1 where probability > 0.5).
    """
    modelpath = "Models/"
    commonspacing = [0.7, 0.7, 0.7]
    threshold = 0.5
    imageshape = image_sub.shape

    # (core edge, core + context edge); larger chunks when running on GPU.
    if opt.cuda:
        sizeofchunk, sizeofchunk_expand = 108, 148
    else:
        sizeofchunk, sizeofchunk_expand = 12, 52

    def _load_weights(model, filename):
        # Load a state dict into `model` and switch it to inference mode.
        # map_location="cpu" allows GPU-saved checkpoints on CPU-only hosts.
        checkpoint = torch.load(modelpath + filename, map_location="cpu")
        model.load_state_dict(checkpoint)
        model.eval()

    def _predict(model, volume):
        # Run `model` chunk by chunk over a (3, X, Y, Z) volume; return the
        # stitched (X, Y, Z) map.
        chunk_batch, nb_chunks, idx_xyz, sizeofimage = Generator_multichannels(
            volume, sizeofchunk, sizeofchunk_expand, 3)
        n_chunks = np.size(chunk_batch, 0)
        seg_batch = np.zeros((n_chunks, 1, sizeofchunk, sizeofchunk, sizeofchunk),
                             dtype='float32')
        # Inference only: avoid building an autograd graph per chunk.
        with torch.no_grad():
            for i_chunk in range(n_chunks):
                chunk = torch.from_numpy(chunk_batch[i_chunk:i_chunk + 1, ...])
                if opt.cuda:
                    chunk = chunk.cuda()
                seg_batch[i_chunk, 0, ...] = model(chunk).cpu().numpy()
        return Chunks_Image(seg_batch, nb_chunks, sizeofchunk, sizeofchunk_expand,
                            idx_xyz, sizeofimage)

    def _to_input_grid(prob):
        # Resample back and crop/zero-pad to exactly `imageshape`, because
        # zoom rounding can be off by a voxel.
        up = nd.zoom(prob, 1 / scale_subject, order=1)
        up = up[:imageshape[0], :imageshape[1], :imageshape[2]]
        out = np.zeros(imageshape, dtype='float32')
        out[:up.shape[0], :up.shape[1], :up.shape[2]] = up
        return out

    _load_weights(model1st, "model_tumor_1st.pth")

    # Zoom factors that bring the volumes to the common spacing.
    scale_subject = np.asarray(scale_subject, dtype=float)[::-1] / commonspacing
    zoomed = [nd.zoom(vol, scale_subject, order=1)
              for vol in (image_sub, image_post, image_mask)]
    image_one = np.zeros((3,) + zoomed[0].shape, dtype='float32')
    for channel, vol in enumerate(zoomed):
        image_one[channel, ...] = vol

    # Stage 1: suppress near-zero probabilities, then keep a copy on the input
    # grid as the first output.
    prob_image = _predict(model1st, image_one)
    prob_image[prob_image < 0.01] = 0
    prob_output = _to_input_grid(prob_image)

    # Stage 2 sees the stage-1 probabilities in place of the mask channel.
    image_one[2, ...] = prob_image
    _load_weights(model2nd, "model_tumor_2nd.pth")
    prob_image = _predict(model2nd, image_one)
    seg_img = (_to_input_grid(prob_image) > threshold).astype('uint8')

    return prob_output, seg_img
219
220