Diff of /Brain_pipeline.py [000000] .. [271336]

Switch to unified view

a b/Brain_pipeline.py
1
# -*- coding: utf-8 -*-
2
"""
3
Created on Fri Oct 14 01:47:13 2016
4
5
@author: seeker105
6
"""
7
import os.path
8
import random
9
import pylab
10
import numpy as np
11
from glob import glob
12
import SimpleITK as sitk
13
from Nyul import IntensityRangeStandardization
14
import sys
15
import timeit
16
from sklearn.feature_extraction.image import extract_patches_2d
17
from skimage import color
18
19
class Pipeline(object):
    ''' Loads the scans of all patients and builds CNN training patches.

        The main entry points are:
            1) sample_training_patches: patches centred on voxels of one class
            2) training_patches: patches for every class, normalised per channel

        INPUT:
            1) path_train: directory holding the training '*.mha' files
            2) path_test:  directory holding the test '*.mha' files
            3) mx_train, mx_tst: maximum number of files to load from each

        NOTE: constructing a Pipeline reads the scans immediately.
    '''

    def __init__(self, path_train = '', path_test = '' , mx_train = 1000000, mx_tst = 1000000):
        self.path_train = path_train
        self.path_test = path_test
        (self.scans_train, self.scans_test,
         self.train_im, self.test_im) = self.read_scans(mx_train, mx_tst)

    def read_scans(self, mx_train, mx_test):
        ''' Globs the '*.mha' files of both directories and loads them.

        INPUT:
            1) mx_train: maximum number of training files to load
            2) mx_test:  maximum number of test files to load

        OUTPUT:
            1) scans_train: all training file paths found (not truncated)
            2) scans_test:  all test file paths found (not truncated)
            3) train_im:    numpy array of the loaded training files
            4) test_im:     numpy array of the loaded test files
        '''
        scans_train = glob(self.path_train + r'/*.mha')
        scans_test = glob(self.path_test + r'/*.mha')
        # max(..., 0) keeps a negative limit meaning "load nothing"
        train_im = [sitk.GetArrayFromImage(sitk.ReadImage(name))
                    for name in scans_train[:max(mx_train, 0)]]
        test_im = [sitk.GetArrayFromImage(sitk.ReadImage(name))
                   for name in scans_test[:max(mx_test, 0)]]
        return scans_train, scans_test, np.array(train_im), np.array(test_im)

    def n4itk(self, img):
        ''' Applies SimpleITK's N4 bias field correction to 'img'.

        The image is cast to float32 and corrected inside a mask made of
        every pixel that is not exactly 0 (intended to span the brain).
        '''
        img = sitk.Cast(img, sitk.sitkFloat32)
        img_mask = sitk.BinaryNot(sitk.BinaryThreshold(img, 0, 0))
        corrected_img = sitk.N4BiasFieldCorrection(img, img_mask)
        return corrected_img

    def sample_training_patches(self, num_patches, class_nm, d = 4, h = 33, w = 33):
        ''' Creates the input patches and their labels for training CNN.
        The label of a patch is the label of its central pixel.

        self.train_im is indexed as [slice][channel][row][col]. Channel 4 is
        read as the ground truth and channel 0 as the mask (FLAIR according
        to the original comments); patches take channels 0..3.

        INPUT:
            1) num_patches: The maximum number of patches wanted of the class.

            2) class_nm:    The index of class label for which we are finding patches.

            3) d, h, w:     number of channels, height and width of patch.
                            'd' is accepted for symmetry with training_patches
                            but the patches always take the first 4 channels.

        OUTPUT:
            1) patches:     numpy array of shape (n, 4, h, w), n <= num_patches.
                            Fewer patches are returned when the class does not
                            have enough voxels whose patch fits in the slice.

            2) labels:      int32 array of length n filled with class_nm.
        '''
        train_im = np.asarray(self.train_im)
        # swap axes so axis 0 is the channel and axis 1 the slice
        by_channel = np.swapaxes(train_im, 0, 1)
        gt_im = by_channel[4]
        msk = by_channel[0]
        vol_shape = gt_im.shape

        # flat indices of the voxels of this class inside the mask;
        # flatnonzero always gives a 1-D array, even for a single match
        indices = np.flatnonzero((gt_im == class_nm) & (msk != 0.))

        st = timeit.default_timer()
        np.random.shuffle(indices)
        print('shuffling of label {} took : {}'.format(class_nm, timeit.default_timer() - st))

        st = timeit.default_timer()
        slice_ids, rows, cols = np.unravel_index(indices, vol_shape)
        # patch bounds around every candidate centre (floor division keeps
        # them integers on Python 3 as well)
        row_lo, row_hi = rows - h // 2, rows + (h + 1) // 2
        col_lo, col_hi = cols - w // 2, cols + (w + 1) // 2
        in_bounds = ((row_lo >= 0) & (row_hi <= vol_shape[1]) &
                     (col_lo >= 0) & (col_hi <= vol_shape[2]))
        # keep the shuffled order, drop centres too close to the border
        keep = np.flatnonzero(in_bounds)[:max(int(num_patches), 0)]

        patches = [train_im[slice_ids[k]][0:4, row_lo[k]:row_hi[k], col_lo[k]:col_hi[k]]
                   for k in keep]
        print('finding patches of label {} took : {}'.format(class_nm, timeit.default_timer() - st))

        if patches:
            patches = np.array(patches)
        else:
            # keep the trailing dimensions so results can still be concatenated
            patches = np.empty((0, 4, h, w), dtype = train_im.dtype)
        # one label per patch actually found, not per patch requested
        labels = np.full(len(patches), fill_value = class_nm, dtype = np.int32)
        return patches, labels

    def training_patches(self, num_patches, classes = 5, d = 4, h = 33, w = 33):
        '''Creates the input patches and their labels for training CNN. The
        label of a patch is the label of its central voxel. Each of the first
        'd' channels is normalised to zero mean and unit standard deviation
        over all returned patches.

            INPUT:
                    1) num_patches: number of patches for each class; either
                                    one number used for every class or a
                                    sequence indexed by class
                    2) classes:  number of all classes in the segmentation
                    3) d, h, w : channels, height and width of the patches
            OUTPUT:
                    1) patches: float numpy array of all class patches, shape
                                (total, 4, h, w)
                    2) labels : 1-D int32 array, one label per patch
                    3) mu     : per-channel mean used for the normalisation
                    4) sigma  : per-channel standard deviation used for the
                                normalisation (a zero deviation is replaced
                                by 1 so the division stays defined)
        '''
        if np.isscalar(num_patches):
            num_patches = [num_patches] * classes

        patches, labels = [], []
        for idx in range(classes):
            p, l = self.sample_training_patches(num_patches[idx], idx, d, h, w)
            patches.append(p)
            labels.append(l)
        # concatenate copes with classes that returned different counts
        patches = np.concatenate(patches)
        labels = np.concatenate(labels)

        # integer scans would truncate the normalised values on assignment
        if not np.issubdtype(patches.dtype, np.floating):
            patches = patches.astype(np.float64)

        mu, sigma = [], []
        for i in range(min(d, patches.shape[1])):
            channel = patches[:, i]
            avg = np.mean(channel)
            std = np.std(channel)
            if std == 0:
                std = 1.0
            patches[:, i] = (channel - avg) / std
            mu.append(avg)
            sigma.append(std)
        return patches, labels, np.array(mu), np.array(sigma)
160
     
161
def test_patches(img , mu, sigma, d = 4, h = 33, w = 33):
    ''' Creates the normalised patches of an image that lie inside the brain mask.

            INPUT:
                    1)img: a 3D array containing the all modalities of a 2D image
                           (indexed as [channel][row][col]; at least 4 channels).
                    2)mu, sigma: per-channel mean and standard deviation; every
                           patch of channel i becomes (patch - mu[i]) / sigma[i].
                    3)d, h, w: channels, height and width of the patches. 'd' is
                           not used here: every channel of 'img' is processed.
            OUTPUT:
                    ndarray of shape (number of patches, len(img), h, w). Only
                    patches whose centre pixel is inside the mask are kept.
    '''
    # mask of the pixels where the sum of the first four channels is non-zero
    msk = (img[0]+img[1]+img[2]+img[3])!=0.
    # extract_patches_2d returns one patch per position where a full (h, w)
    # window fits, so crop the mask to the centres of those windows. With the
    # default 33x33 patches this is the original msk[16:-16, 16:-16].
    top, left = h // 2, w // 2
    n_rows = msk.shape[0] - h + 1
    n_cols = msk.shape[1] - w + 1
    msk = msk[top:top + n_rows, left:left + n_cols]
    msk = msk.reshape(-1)

    p = []
    for i in range(len(img)):
        plist = extract_patches_2d(img[i], (h, w))
        plist = (plist - mu[i])/sigma[i]
        p.append(plist[msk])              #only take patches with brain mask!=0
    # move the patch axis in front of the channel axis
    return np.array(p).swapaxes(0,1)
181
    
182
183
def reconstruct_labels(msk, pred_list, h = 33, w = 33):
    ''' Rebuilds a full-size label image from the per-patch predictions.

            INPUT:
                    1)msk: 2D brain mask of the original image (True where a
                           patch was extracted once cropped to patch centres).
                    2)pred_list: one prediction per True pixel of the cropped
                           mask, in row-major order.
                    3)h, w: height and width of the patches the predictions
                           were made on; they fix the border that is cropped
                           from the mask and padded back afterwards.
            OUTPUT:
                    float array with the shape of 'msk': predictions at the
                    masked pixels, 0 elsewhere, border filled with the values
                    of the nearest edge pixel.
    '''
    # boolean indexing below needs a real boolean mask
    msk = np.asarray(msk, dtype = bool)
    top, left = h // 2, w // 2
    bottom, right = h - 1 - top, w - 1 - left
    n_rows = msk.shape[0] - h + 1
    n_cols = msk.shape[1] - w + 1
    # same crop as msk[16:-16, 16:-16] for the default 33x33 patches
    inner = msk[top:top + n_rows, left:left + n_cols]
    im = np.zeros(inner.shape)
    im[inner] = np.array(pred_list)
    im = np.pad(im, ((top, bottom), (left, right)), mode='edge')
    return im