Diff of /Brain_pipeline.py [000000] .. [271336]

--- a
+++ b/Brain_pipeline.py
@@ -0,0 +1,188 @@
+# -*- coding: utf-8 -*-
+"""
+Created on Fri Oct 14 01:47:13 2016
+
+@author: seeker105
+"""
+import os.path
+import random
+import pylab
+import numpy as np
+from glob import glob
+import SimpleITK as sitk
+from Nyul import IntensityRangeStandardization
+import sys
+import timeit
+from sklearn.feature_extraction.image import extract_patches_2d
+from skimage import color
+
+class Pipeline(object):
+    ''' The Pipeline for loading images for all patients and all modalities.
+        1) sample_training_patches: finds the training patches for a particular class
+
+        INPUT:
+            1) path_train / path_test: directories of the image database, containing the training and test image slices (.mha files)
+
+    '''
+    
+    def __init__(self, path_train = '', path_test = '', mx_train = 1000000, mx_tst = 1000000):
+        self.path_train = path_train
+        self.path_test = path_test
+        self.scans_train, self.scans_test, self.train_im, self.test_im = self.read_scans(mx_train, mx_tst)
+        
+        
+    def read_scans(self, mx_train, mx_test):
+       
+        scans_train = glob(self.path_train + r'/*.mha')
+        scans_test = glob(self.path_test + r'/*.mha')
+        train_im = [sitk.GetArrayFromImage(sitk.ReadImage(scans_train[i])) for i in xrange(min(len(scans_train), mx_train))]
+        test_im = [sitk.GetArrayFromImage(sitk.ReadImage(scans_test[i])) for i in xrange(min(len(scans_test), mx_test))]
+        return scans_train, scans_test, np.array(train_im), np.array(test_im)
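+
+    # Minimal usage sketch (the paths are hypothetical, kept as a comment so importing the
+    # module stays side-effect free):
+    #
+    #   pipe = Pipeline(path_train='./train_slices', path_test='./test_slices')
+    #   print pipe.train_im.shape   # e.g. (num_slices, 5, 240, 240): 4 modalities + ground truth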
+    
+    
+    def n4itk(self, img):
+        img = sitk.Cast(img, sitk.sitkFloat32)
+        img_mask = sitk.BinaryNot(sitk.BinaryThreshold(img, 0, 0))   ## Create a mask spanning the part containing the brain, as we want to apply the filter to the brain image
+        corrected_img = sitk.N4BiasFieldCorrection(img, img_mask)
+        return corrected_img
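+
+    # Hedged example: bias-field correction of one raw scan before patch extraction
+    # (n4itk expects a SimpleITK image, not the numpy arrays stored in train_im):
+    #
+    #   raw = sitk.ReadImage(pipe.scans_train[0])
+    #   corrected_arr = sitk.GetArrayFromImage(pipe.n4itk(raw))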
+    
+    '''def find_all_train(self, classes, d = 4, h = 33, w = 33):
+        mn = 300000000000000
+        #load all labels
+        im_ar = []
+        for i in self.pathnames_train:
+            im_ar.append([sitk.GetArrayFromImage(sitk.ReadImage(idx)) for idx in i])
+        im_ar = np.array(im_ar)
+        for i in xrange(classes):
+            mn = min(mn, len(np.argwhere(im_ar[i]==i)))
+    '''
+    def sample_training_patches(self, num_patches, class_nm, d = 4, h = 33, w = 33):
+        ''' Creates the input patches and their labels for training the CNN. The patches are d x h x w
+        (4x33x33 by default), and the label of a patch is the label of the central pixel of that patch.
+        
+        INPUT:
+            1) num_patches: The number of patches required for the class.
+            
+            2) class_nm:    The index of class label for which we are finding patches.
+            
+            3) d, h, w:     number of channels, height and width of patch
+        
+        OUTPUT:
+            1) patches:     The list of all patches of dimensions d, h, w. 
+            
+            2) labels:      The list of labels for each patch. Label for a patch corresponds to the label
+                            of the central pixel of that patch.
+                            
+        '''
+        
+        #find patches for training
+        patches, labels = [], np.full(num_patches, fill_value = class_nm,  dtype = np.int32)
+        count = 0
+        # convert gt_im to 1D and save shape
+        gt_im = np.swapaxes(self.train_im, 0, 1)[4]   #swap axes to make axis 0 represent the modality and axis 1 represent the slice. take the ground truth
+        #take flair image as mask
+        msk = np.swapaxes(self.train_im, 0, 1)[0]
+        tmp_shp = gt_im.shape
+        gt_im = gt_im.reshape(-1)
+        msk = msk.reshape(-1)
+        # maintain list of 1D indices where label = class_nm
+        indices = np.squeeze(np.argwhere((gt_im == class_nm) & (msk != 0.)))
+        # shuffle the list of indices of the class
+        st = timeit.default_timer()
+        np.random.shuffle(indices)
+        print 'shuffling of label {} took :'.format(class_nm), timeit.default_timer()-st
+        #reshape gt_im
+        gt_im = gt_im.reshape(tmp_shp)
+        st = timeit.default_timer()
+        #find the patches from the images
+        i = 0
+        pix = len(indices)
+        while (count<num_patches) and (pix>i):
+            #print (count, ' cl:' ,class_nm)
+            #sys.stdout.flush()
+            #take the next index from the shuffled list
+            ind = indices[i]
+            i += 1
+            #reshape ind to 3D index
+            ind = np.unravel_index(ind, tmp_shp)
+            #print ind
+            #sys.stdout.flush()
+            #find the slice index to choose from
+            slice_idx = ind[0]
+            #load the slice from the label
+            l = gt_im[slice_idx]
+            # the centre pixel and its coordinates
+            p = ind[1:]
+            #construct the patch by defining the coordinates
+            p_x = (p[0] - h/2, p[0] + (h+1)/2)
+            p_y = (p[1] - w/2, p[1] + (w+1)/2)
+            #check if the pixels are in range
+            if p_x[0]<0 or p_x[1]>l.shape[0] or p_y[0]<0 or p_y[1]>l.shape[1]:
+                continue
+            #take patches from all modalities and group them together
+            tmp = self.train_im[slice_idx][0:4, p_x[0]:p_x[1], p_y[0]:p_y[1]]
+            patches.append(tmp)
+            count+=1
+        print 'finding patches of label {} took :'.format(class_nm), timeit.default_timer()-st
+        patches = np.array(patches)
+        labels = labels[:count]   # trim in case fewer than num_patches valid patches were found
+        return patches, labels
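+
+    # Illustrative call (assumes a Pipeline instance named 'pipe'; label 1 is arbitrary):
+    #
+    #   p, l = pipe.sample_training_patches(1000, class_nm=1)
+    #   # p.shape -> (n, 4, 33, 33), l.shape -> (n,), with n <= 1000 when the class is rare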
+        
+
+    def training_patches(self, num_patches, classes = 5, d = 4, h = 33, w = 33):
+        '''Creates the input patches and their labels for training the CNN. The patches are 4x33x33
+    and the label for a patch corresponds to the label of the central pixel of the patch. The
+    data will be balanced, with the number of patches being the same for each class.
+
+            INPUT:
+                    1) num_patches: list of the number of patches required for each class
+                    2) classes: number of classes in the segmentation
+                    3) d, h, w : channels, height and width of the patches
+            OUTPUT:
+                    1) patches: numpy array of all class patches, each of shape d x h x w
+                    2) labels : numpy array of the corresponding labels
+                    3) mu, sigma: per-channel mean and standard deviation used to normalise the patches
+        '''
+        
+        patches, labels, mu, sigma = [], [], [], []
+        for idx in xrange(classes):
+            p, l = self.sample_training_patches(num_patches[idx], idx, d, h, w)
+            patches.append(p)
+            labels.append(l)
+        patches = np.vstack(patches).astype(np.float32)   # cast to float so the per-channel normalisation below is not truncated on integer volumes
+        patches_by_channel = np.swapaxes(patches, 0, 1)
+        for i in xrange(d):
+            avg = np.mean(patches_by_channel[i])
+            std = np.std(patches_by_channel[i])
+            patches_by_channel[i] = (patches_by_channel[i] - avg)/std
+            mu.append(avg)
+            sigma.append(std)
+        patches = np.swapaxes(patches_by_channel, 0, 1)
+        return patches, np.concatenate(labels), np.array(mu), np.array(sigma)
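+
+    # Hedged example of building a balanced training set for 5 classes
+    # (the per-class budget of 1000 is illustrative, not fixed by this module):
+    #
+    #   X, y, mu, sigma = pipe.training_patches([1000]*5)
+    #   # X: normalised patches of shape (n_total, 4, 33, 33), y: matching labels,
+    #   # mu/sigma: per-channel statistics to reuse on the test slices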
+     
+def test_patches(img, mu, sigma, d = 4, h = 33, w = 33):
+    ''' Creates patches of a 2D slice. Returns a numpy array of shape number_of_patches x d x h x w.
+
+            INPUT:
+                    1) img: a 3D array (modalities x height x width) holding all modalities of a 2D slice
+                    2) mu, sigma: per-channel normalisation statistics returned by training_patches
+                    3) d, h, w: see above
+            OUTPUT:
+                    tst_arr: ndarray of all patches of all modalities, of shape number_of_patches x modalities x h x w
+    '''
+    
+    #list of patches
+    p = []
+    msk = (img[0]+img[1]+img[2]+img[3])!=0.   #mask of pixels where at least one modality is non-zero
+    msk = msk[16:-16, 16:-16]      #crop the mask to match the image rebuilt after prediction (extract_patches_2d only covers the interior, dropping a 16-pixel border)
+    msk = msk.reshape(-1)
+    for i in xrange(len(img)):
+        plist = extract_patches_2d(img[i], (h, w))
+        plist = (plist - mu[i])/sigma[i]
+        p.append(plist[msk])              #only take patches with brain mask!=0
+    return np.array(p).swapaxes(0,1)
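+
+# Hedged end-to-end sketch for one test slice ('model' is a hypothetical trained classifier,
+# not part of this module; mu/sigma come from training_patches):
+#
+#   img = pipe.test_im[0][:4]                          # the 4 modality channels of one slice
+#   x = test_patches(img, mu, sigma)                   # (n_masked_patches, 4, 33, 33)
+#   pred = np.argmax(model.predict(x), axis=-1)        # one label per masked patch
+#   full_slice = reconstruct_labels(img.sum(axis=0) != 0., pred)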
+    
+
+def reconstruct_labels(msk, pred_list):
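+    # Assumption made explicit: slices are 240x240 (BRATS), so the patch predictions only cover
+    # the 208x208 interior; the 16-pixel border is restored by padding after the masked pixels
+    # are filled in with the predicted labels.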
+    im = np.full((208, 208), 0.)
+    msk = msk[16:-16, 16:-16]
+    im[msk] = np.array(pred_list)
+    im = np.pad(im, (16, 16), mode='edge')
+    return im    
\ No newline at end of file