--- a +++ b/Segmenation.py @@ -0,0 +1,220 @@ +import numpy as np +import torch +from scipy import ndimage as nd +import torch.nn.parallel +from torch.autograd import Variable + +def Generator_multichannels(image, sizeofchunk, sizeofchunk_expand,numofchannels): + sizeofimage = np.shape(image)[1:4] + + nb_chunks = (np.ceil(np.array(sizeofimage)/float(sizeofchunk))).astype(int) + + pad_image = np.zeros(([numofchannels,nb_chunks[0]*sizeofchunk,nb_chunks[1]*sizeofchunk,nb_chunks[2]*sizeofchunk]), dtype='float32') + pad_image[:,:sizeofimage[0], :sizeofimage[1], :sizeofimage[2]] = image + + width = int(np.ceil((sizeofchunk_expand-sizeofchunk)/2.0)) + + size_pad_im = np.shape(pad_image)[1:4] + size_expand_im = np.array(size_pad_im) + 2 * width + expand_image = np.zeros(([numofchannels,size_expand_im[0],size_expand_im[1],size_expand_im[2]]), dtype='float32') + expand_image[:,width:-width, width:-width, width:-width] = pad_image + + + batchsize = np.prod(nb_chunks) + idx_chunk = 0 + chunk_batch = np.zeros((batchsize,numofchannels,sizeofchunk_expand,sizeofchunk_expand,sizeofchunk_expand),dtype='float32') + idx_xyz = np.zeros((batchsize,3),dtype='uint16') + for x_idx in range(nb_chunks[0]): + for y_idx in range(nb_chunks[1]): + for z_idx in range(nb_chunks[2]): + + idx_xyz[idx_chunk,:] = [x_idx,y_idx,z_idx] + + + + chunk_batch[idx_chunk,:,...] = expand_image[:,x_idx*sizeofchunk:x_idx*sizeofchunk+sizeofchunk_expand,\ + y_idx*sizeofchunk:y_idx*sizeofchunk+sizeofchunk_expand,\ + z_idx*sizeofchunk:z_idx*sizeofchunk+sizeofchunk_expand] + + idx_chunk += 1 + + return chunk_batch, nb_chunks, idx_xyz, sizeofimage + +def Chunks_Image(segment_chunks, nb_chunks, sizeofchunk, sizeofchunk_expand, idx_xyz, sizeofimage): + + batchsize = np.size(segment_chunks,0) + + segment_image = np.zeros((nb_chunks[0]*sizeofchunk,nb_chunks[1]*sizeofchunk,nb_chunks[2]*sizeofchunk)) + + for idx_chunk in range(batchsize): + + idx_low = idx_xyz[idx_chunk,:] * sizeofchunk + idx_upp = (idx_xyz[idx_chunk,:]+1) * sizeofchunk + + segment_image[idx_low[0]:idx_upp[0],idx_low[1]:idx_upp[1],idx_low[2]:idx_upp[2]] = \ + segment_chunks[idx_chunk,0,...] + + + segment_image = segment_image[:sizeofimage[0], :sizeofimage[1], :sizeofimage[2]] + return segment_image + + + +def BreastSeg(image,scale_subject,model,opt): + modelpath = "Models/" + modelname = modelpath+"model_breast.pth" + checkpoint = torch.load(modelname) + model.load_state_dict(checkpoint) + + + numofseg =1 + + commonspacing = [1.5,1.5,1.5] + + imageshape = image.shape + + scale_subject = scale_subject[::-1]/commonspacing + image = nd.interpolation.zoom(image,scale_subject,order=1) + imagesize = image.shape + + sizeofchunk = 20 + sizeofchunk_expand = 108 + if opt.cuda: + sizeofchunk = 132 + sizeofchunk_expand = 220 + image_one = np.zeros((1,imagesize[0],imagesize[1],imagesize[2]),dtype='float32') + image_one[0,...] =image + chunk_batch, nb_chunks, idx_xyz, sizeofimage = Generator_multichannels(image_one,sizeofchunk,sizeofchunk_expand,1) + + seg_batch = np.zeros((np.size(chunk_batch,0),numofseg,sizeofchunk,sizeofchunk,sizeofchunk),dtype='float32') + for i_chunk in range(np.size(chunk_batch,0)): + input = Variable(torch.from_numpy(chunk_batch[i_chunk:i_chunk+1,...]),volatile=True) + model.eval() + if opt.cuda: + input = input.cuda() + prediction = model(input) + + seg_batch[i_chunk,0,...] = (prediction.data).cpu().numpy() + + + for i_seg in range(numofseg): + prob_image = Chunks_Image(seg_batch[:,i_seg:i_seg+1,...], nb_chunks, sizeofchunk, sizeofchunk_expand, idx_xyz, sizeofimage) + up_image = nd.interpolation.zoom(prob_image,1/scale_subject,order=1) + up_image_norm = np.zeros(imageshape,dtype='float32') + temp_image = up_image[0:imageshape[0],0:imageshape[1],0:imageshape[2]] + shape_tempimage =np.shape(temp_image) + up_image_norm[0:shape_tempimage[0],0:shape_tempimage[1],0:shape_tempimage[2]] = temp_image + + threshold = 0.5 + idx = up_image_norm > threshold + up_image_norm[idx] = 1 + up_image_norm[~idx] = 0 + seg_img = up_image_norm.astype('uint8') + return seg_img + + + + +def BreastTumor(image_sub,image_post,image_mask,scale_subject,model1st,model2nd,opt): + modelpath = "Models/" + modelname = modelpath+"/model_tumor_1st.pth" + checkpoint = torch.load(modelname) + model1st.load_state_dict(checkpoint) + + + numofseg =1 + + commonspacing = [0.7,0.7,0.7] + imageshape = image_sub.shape + + + scale_subject = scale_subject[::-1]/commonspacing + + + image_sub = nd.interpolation.zoom(image_sub,scale_subject,order=1) + image_post = nd.interpolation.zoom(image_post,scale_subject,order=1) + image_mask = nd.interpolation.zoom(image_mask,scale_subject,order=1) + + imagesize = np.shape(image_sub) + + image_one = np.zeros((3,imagesize[0],imagesize[1],imagesize[2]),dtype='float32') + + image_one[0,...] = 1.0*image_sub + + image_one[1,...] = 1.0*image_post + image_one[2,...] = 1.0*image_mask + + + + + sizeofchunk = 12 + sizeofchunk_expand = 52 + if opt.cuda: +# sizeofchunk = 148 +# sizeofchunk_expand = 188 + sizeofchunk = 108 + sizeofchunk_expand = 148 + + + chunk_batch, nb_chunks, idx_xyz, sizeofimage = Generator_multichannels(image_one,sizeofchunk,sizeofchunk_expand,3) + + seg_batch = np.zeros((np.size(chunk_batch,0),numofseg,sizeofchunk,sizeofchunk,sizeofchunk),dtype='float32') + for i_chunk in range(np.size(chunk_batch,0)): + input = Variable(torch.from_numpy(chunk_batch[i_chunk:i_chunk+1,...]),volatile=True) + model1st.eval() + if opt.cuda: + input = input.cuda() + prediction = model1st(input) + + seg_batch[i_chunk,0,...] = (prediction.data).cpu().numpy() + + + for i_seg in range(numofseg): + prob_image = Chunks_Image(seg_batch[:,i_seg:i_seg+1,...], nb_chunks, sizeofchunk, sizeofchunk_expand, idx_xyz, sizeofimage) + prob_image[prob_image<0.01] =0 + + image_one[2,...] = prob_image +# Just for saving output of 1st stage + up_image = nd.interpolation.zoom(prob_image,1/scale_subject,order=1) + up_image_norm = np.zeros(imageshape,dtype='float32') + temp_image = up_image[0:imageshape[0],0:imageshape[1],0:imageshape[2]] + shape_tempimage =np.shape(temp_image) + up_image_norm[0:shape_tempimage[0],0:shape_tempimage[1],0:shape_tempimage[2]] = temp_image + prob_output = up_image_norm +# + + + del model1st + modelname = modelpath+"/model_tumor_2nd.pth" + checkpoint = torch.load(modelname) + model2nd.load_state_dict(checkpoint) + + + chunk_batch, nb_chunks, idx_xyz, sizeofimage = Generator_multichannels(image_one,sizeofchunk,sizeofchunk_expand,3) + + seg_batch = np.zeros((np.size(chunk_batch,0),numofseg,sizeofchunk,sizeofchunk,sizeofchunk),dtype='float32') + for i_chunk in range(np.size(chunk_batch,0)): + input = Variable(torch.from_numpy(chunk_batch[i_chunk:i_chunk+1,...]),volatile=True) + model2nd.eval() + if opt.cuda: + input = input.cuda() + prediction = model2nd(input) + + seg_batch[i_chunk,0,...] = (prediction.data).cpu().numpy() + + + for i_seg in range(numofseg): + prob_image = Chunks_Image(seg_batch[:,i_seg:i_seg+1,...], nb_chunks, sizeofchunk, sizeofchunk_expand, idx_xyz, sizeofimage) + up_image = nd.interpolation.zoom(prob_image,1/scale_subject,order=1) + up_image_norm = np.zeros(imageshape,dtype='float32') + temp_image = up_image[0:imageshape[0],0:imageshape[1],0:imageshape[2]] + shape_tempimage =np.shape(temp_image) + up_image_norm[0:shape_tempimage[0],0:shape_tempimage[1],0:shape_tempimage[2]] = temp_image + threshold = 0.5 + idx = up_image_norm > threshold + up_image_norm[idx] = 1 + up_image_norm[~idx] = 0 + seg_img = up_image_norm.astype('uint8') + return prob_output,seg_img + +