--- a
+++ b/Segmenation.py
@@ -0,0 +1,220 @@
+import numpy as np
+import torch
+from scipy import ndimage as nd  
+import torch.nn.parallel
+from torch.autograd import Variable
+
+def Generator_multichannels(image, sizeofchunk, sizeofchunk_expand, numofchannels):
+    """Cut a multi-channel 3-D volume into overlapping cubic chunks.
+
+    The volume is zero-padded to a whole number of ``sizeofchunk`` cubes and
+    wrapped in a zero margin, so every cube can be read together with its
+    surroundings as a ``sizeofchunk_expand`` cube.
+
+    image: array indexed as (channel, x, y, z); axes 1-3 give the size.
+    sizeofchunk: edge length of the core cube each chunk accounts for.
+    sizeofchunk_expand: edge length of the extracted cube (core + margin).
+    numofchannels: number of channels allocated in the output buffers.
+
+    Returns ``(chunk_batch, nb_chunks, idx_xyz, sizeofimage)``: the float32
+    chunks stacked along axis 0, the chunk count per axis, the uint16 grid
+    index of every chunk, and the original spatial size of ``image``.
+    """
+    sizeofimage = np.shape(image)[1:4]
+    nb_chunks = np.ceil(np.array(sizeofimage) / float(sizeofchunk)).astype(int)
+    # Half of the extra context, rounded up; 0 when both sizes are equal.
+    width = int(np.ceil((sizeofchunk_expand - sizeofchunk) / 2.0))
+    expand_image = np.zeros([numofchannels] + list(nb_chunks * sizeofchunk + 2 * width), dtype='float32')
+    # Explicit end offsets instead of ``width:-width``: with width == 0 the
+    # negative form selects an empty slice and the assignment would fail.
+    expand_image[:, width:width + sizeofimage[0], width:width + sizeofimage[1],
+                 width:width + sizeofimage[2]] = image
+
+    batchsize = int(np.prod(nb_chunks))
+    chunk_batch = np.zeros((batchsize, numofchannels) + (sizeofchunk_expand,) * 3, dtype='float32')
+    idx_xyz = np.zeros((batchsize, 3), dtype='uint16')
+    # np.ndindex walks the grid with z fastest, like three nested loops.
+    for idx_chunk, grid_pos in enumerate(np.ndindex(*nb_chunks)):
+        idx_xyz[idx_chunk, :] = grid_pos
+        box = tuple(slice(g * sizeofchunk, g * sizeofchunk + sizeofchunk_expand) for g in grid_pos)
+        chunk_batch[idx_chunk] = expand_image[(slice(None),) + box]
+    return chunk_batch, nb_chunks, idx_xyz, sizeofimage
+
+def Chunks_Image(segment_chunks, nb_chunks, sizeofchunk, sizeofchunk_expand, idx_xyz, sizeofimage):
+    """Assemble per-chunk outputs into one volume cropped to ``sizeofimage``.
+
+    Channel 0 of ``segment_chunks[i]`` fills the ``sizeofchunk`` cube at grid
+    position ``idx_xyz[i]``; ``sizeofchunk_expand`` is accepted but unused.
+    Returns a float64 array.
+    """
+    grid_extent = tuple(int(n) * sizeofchunk for n in nb_chunks[:3])
+    segment_image = np.zeros(grid_extent)
+
+    # Widen the grid indices first: the chunk generator in this file emits
+    # uint16 indices, and multiplying those by the chunk size can wrap around.
+    corners = np.asarray(idx_xyz).astype(np.int64) * sizeofchunk
+    for chunk, (x0, y0, z0) in zip(segment_chunks, corners):
+        segment_image[x0:x0 + sizeofchunk, y0:y0 + sizeofchunk, z0:z0 + sizeofchunk] = chunk[0, ...]
+
+    return segment_image[:sizeofimage[0], :sizeofimage[1], :sizeofimage[2]]
+
+
+
+def BreastSeg(image, scale_subject, model, opt):
+    """Segment one 3-D volume with chunk-wise inference and return a mask.
+
+    Weights are read from ``Models/model_breast.pth`` into ``model``. The
+    volume is resampled by ``scale_subject[::-1] / 1.5`` per axis, run
+    through the network chunk by chunk, resampled back and cut at 0.5.
+
+    image: 3-D array.
+    scale_subject: NumPy array with one value per axis, given in reverse
+        axis order (presumably the voxel spacing - confirm with the caller).
+    model: torch module; its weights are overwritten and it is set to eval.
+    opt: options object; ``opt.cuda`` selects GPU use and the chunk sizes.
+
+    Returns a uint8 array of 0/1 values with the shape of ``image``.
+    """
+    # Deserialize onto the CPU so GPU-saved weights also load without CUDA;
+    # load_state_dict then copies them to wherever the model lives.
+    checkpoint = torch.load("Models/" + "model_breast.pth",
+                            map_location=lambda storage, loc: storage)
+    model.load_state_dict(checkpoint)
+    model.eval()
+
+    commonspacing = [1.5, 1.5, 1.5]
+    imageshape = image.shape
+    scale_subject = scale_subject[::-1] / commonspacing
+    # nd.zoom is the public name; the nd.interpolation namespace is deprecated.
+    image = nd.zoom(image, scale_subject, order=1)
+
+    sizeofchunk, sizeofchunk_expand = (132, 220) if opt.cuda else (20, 108)
+    image_one = np.asarray(image, dtype='float32')[np.newaxis, ...]
+    chunk_batch, nb_chunks, idx_xyz, sizeofimage = Generator_multichannels(
+        image_one, sizeofchunk, sizeofchunk_expand, 1)
+
+    seg_batch = np.zeros((np.size(chunk_batch, 0), 1) + (sizeofchunk,) * 3, dtype='float32')
+    # ``volatile=True`` is ignored by PyTorch >= 0.4, so autograd has to be
+    # switched off explicitly to keep inference from recording a graph.
+    with torch.no_grad():
+        for i_chunk in range(np.size(chunk_batch, 0)):
+            net_input = torch.from_numpy(chunk_batch[i_chunk:i_chunk + 1, ...])
+            if opt.cuda:
+                net_input = net_input.cuda()
+            seg_batch[i_chunk, 0, ...] = model(net_input).detach().cpu().numpy()
+
+    prob_image = Chunks_Image(seg_batch, nb_chunks, sizeofchunk, sizeofchunk_expand, idx_xyz, sizeofimage)
+    up_image = nd.zoom(prob_image, 1 / scale_subject, order=1)
+    # Crop or zero-pad the upsampled volume to the shape of the input.
+    up_image_norm = np.zeros(imageshape, dtype='float32')
+    temp_image = up_image[0:imageshape[0], 0:imageshape[1], 0:imageshape[2]]
+    tx, ty, tz = np.shape(temp_image)
+    up_image_norm[0:tx, 0:ty, 0:tz] = temp_image
+    return (up_image_norm > 0.5).astype('uint8')
+
+
+
+
+def BreastTumor(image_sub, image_post, image_mask, scale_subject, model1st, model2nd, opt):
+    """Run the two-stage tumor segmentation and return both stage outputs.
+
+    Stage 1 feeds three channels (``image_sub``, ``image_post``,
+    ``image_mask``) to ``model1st``. Its output, with values below 0.01 set
+    to zero, replaces the third channel for stage 2, which runs ``model2nd``
+    on the same chunk grid. Weights come from ``Models//model_tumor_1st.pth``
+    and ``Models//model_tumor_2nd.pth``.
+
+    All volumes are resampled by ``scale_subject[::-1] / 0.7`` per axis
+    before inference and both results are resampled back afterwards.
+    Chunks are 12 voxels wide (52 with margin) on the CPU and 108 (148 with
+    margin) when ``opt.cuda`` is set.
+
+    image_sub, image_post, image_mask: 3-D arrays (expected to share a shape).
+    scale_subject: NumPy array with one value per axis, given in reverse
+        axis order (presumably the voxel spacing - confirm with the caller).
+    model1st, model2nd: torch modules; their weights are overwritten and
+        they are switched to eval mode.
+    opt: options object; ``opt.cuda`` selects GPU use and the chunk sizes.
+
+    Returns ``(prob_output, seg_img)``: the stage-1 output as float32 and
+    the stage-2 output thresholded at 0.5 as uint8, both with the shape of
+    ``image_sub``.
+    """
+    # The file names below start with "/", so the paths hold a double slash.
+    modelpath = "Models/"
+    numofseg = 1
+    commonspacing = [0.7, 0.7, 0.7]
+    imageshape = image_sub.shape
+    # Core / expanded chunk edge lengths; the larger pair is used with CUDA.
+    sizeofchunk, sizeofchunk_expand = (108, 148) if opt.cuda else (12, 52)
+
+    def _load_weights(model, filename):
+        """Load ``filename`` from the model folder into ``model``; set eval mode."""
+        # Deserialize onto the CPU so GPU-saved weights also load without
+        # CUDA; load_state_dict copies them to wherever the model lives.
+        checkpoint = torch.load(modelpath + filename,
+                                map_location=lambda storage, loc: storage)
+        model.load_state_dict(checkpoint)
+        model.eval()
+
+    def _predict(model, volume):
+        """Run ``model`` over every chunk of ``volume``; return the stitched output."""
+        chunk_batch, nb_chunks, idx_xyz, sizeofimage = Generator_multichannels(
+            volume, sizeofchunk, sizeofchunk_expand, 3)
+        seg_batch = np.zeros((np.size(chunk_batch, 0), numofseg) + (sizeofchunk,) * 3,
+                             dtype='float32')
+        # ``volatile=True`` is ignored by PyTorch >= 0.4, so autograd has to
+        # be switched off explicitly to keep inference from recording a graph.
+        with torch.no_grad():
+            # One forward pass per chunk, each as a batch of one.
+            for i_chunk in range(np.size(chunk_batch, 0)):
+                net_input = torch.from_numpy(chunk_batch[i_chunk:i_chunk + 1, ...])
+                if opt.cuda:
+                    net_input = net_input.cuda()
+                seg_batch[i_chunk, 0, ...] = model(net_input).detach().cpu().numpy()
+        return Chunks_Image(seg_batch, nb_chunks, sizeofchunk, sizeofchunk_expand,
+                            idx_xyz, sizeofimage)
+
+    def _to_input_shape(prob_image):
+        """Resample ``prob_image`` back and crop or zero-pad it to ``imageshape``."""
+        up_image = nd.zoom(prob_image, 1 / zoom_factor, order=1)
+        # Crop, then zero-pad, in case the zoomed shape differs from the input.
+        up_image_norm = np.zeros(imageshape, dtype='float32')
+        temp_image = up_image[0:imageshape[0], 0:imageshape[1], 0:imageshape[2]]
+        tx, ty, tz = np.shape(temp_image)
+        up_image_norm[0:tx, 0:ty, 0:tz] = temp_image
+        return up_image_norm
+
+    zoom_factor = scale_subject[::-1] / commonspacing
+    # nd.zoom is the public name; the nd.interpolation namespace is deprecated.
+    image_sub = nd.zoom(image_sub, zoom_factor, order=1)
+    image_post = nd.zoom(image_post, zoom_factor, order=1)
+    image_mask = nd.zoom(image_mask, zoom_factor, order=1)
+
+    # Channel order: image_sub, image_post, image_mask.
+    image_one = np.zeros((3,) + tuple(np.shape(image_sub)[:3]), dtype='float32')
+    image_one[0, ...] = image_sub
+    image_one[1, ...] = image_post
+    image_one[2, ...] = image_mask
+
+    # Stage 1.
+    _load_weights(model1st, "/model_tumor_1st.pth")
+    prob_image = _predict(model1st, image_one)
+    # Values below 0.01 are zeroed before they are reused or returned.
+    prob_image[prob_image < 0.01] = 0
+
+    # Stage-1 output on the input grid; returned to the caller unthresholded.
+    prob_output = _to_input_shape(prob_image)
+
+    # Stage 2: the stage-1 output takes the place of the mask channel.
+    image_one[2, ...] = prob_image
+    _load_weights(model2nd, "/model_tumor_2nd.pth")
+    stage2_prob = _predict(model2nd, image_one)
+    stage2_image = _to_input_shape(stage2_prob)
+
+    # Binarize the stage-2 output at 0.5.
+    seg_img = (stage2_image > 0.5).astype('uint8')
+
+    return prob_output, seg_img
+
+