Brain_pipeline.py

# -*- coding: utf-8 -*-
"""
Created on Fri Oct 14 01:47:13 2016
@author: seeker105
"""
import os.path
import random
import pylab
import numpy as np
from glob import glob
import SimpleITK as sitk
from Nyul import IntensityRangeStandardization
import sys
import timeit
from sklearn.feature_extraction.image import extract_patches_2d
from skimage import color


class Pipeline(object):
    ''' The Pipeline for loading images for all patients and all modalities.
    1) training_patches / sample_training_patches: find the training patches for each class
    INPUT:
        1) The filepath 'path': directory of the image database, containing the training image slices
    '''

    def __init__(self, path_train='', path_test='', mx_train=1000000, mx_tst=1000000):
        self.path_train = path_train
        self.path_test = path_test
        self.scans_train, self.scans_test, self.train_im, self.test_im = self.read_scans(mx_train, mx_tst)

    def read_scans(self, mx_train, mx_test):
        scans_train = glob(self.path_train + r'/*.mha')
        scans_test = glob(self.path_test + r'/*.mha')
        train_im = [sitk.GetArrayFromImage(sitk.ReadImage(scans_train[i])) for i in xrange(min(len(scans_train), mx_train))]
        test_im = [sitk.GetArrayFromImage(sitk.ReadImage(scans_test[i])) for i in xrange(min(len(scans_test), mx_test))]
        return scans_train, scans_test, np.array(train_im), np.array(test_im)
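
    # Note on the expected data layout: the indexing below (train_im[slice][0:4] for the
    # modalities, channel 4 for the ground truth, and the 16-pixel crop/pad used at test time)
    # suggests that each .mha file holds one 5-channel slice stack of shape (5, 240, 240):
    # four MR modalities plus the label map. This is an inference from the code, not something
    # the file states explicitly.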

    def n4itk(self, img):
        img = sitk.Cast(img, sitk.sitkFloat32)
        # create a mask spanning the part containing the brain, as we want to apply the filter to the brain image
        img_mask = sitk.BinaryNot(sitk.BinaryThreshold(img, 0, 0))
        corrected_img = sitk.N4BiasFieldCorrection(img, img_mask)
        return corrected_img
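
    # Sketch of how n4itk might be applied to a raw scan before converting it to an array
    # (the filename is hypothetical; nothing in this file calls n4itk directly):
    #   raw = sitk.ReadImage('patient_001_flair.mha')
    #   corrected = pipe.n4itk(raw)
    #   arr = sitk.GetArrayFromImage(corrected)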

    '''def find_all_train(self, classes, d = 4, h = 33, w = 33):
        mn = 300000000000000
        #load all labels
        im_ar = []
        for i in self.pathnames_train:
            im_ar.append([sitk.GetArrayFromImage(sitk.ReadImage(idx)) for idx in i])
        im_ar = np.array(im_ar)
        for i in xrange(classes):
            mn = min(mn, len(np.argwhere(im_ar[i]==i)))
    '''

    def sample_training_patches(self, num_patches, class_nm, d=4, h=33, w=33):
        ''' Creates the input patches and their labels for training the CNN. The patches are 4x33x33,
        and the label of a patch equals the label of the central pixel of the patch.
        INPUT:
            1) num_patches: the number of patches required of the class.
            2) class_nm: the index of the class label for which we are finding patches.
            3) d, h, w: number of channels, height and width of a patch.
        OUTPUT:
            1) patches: the list of all patches of dimensions d x h x w.
            2) labels: the list of labels for each patch. The label of a patch corresponds to the
               label of the central pixel of that patch.
        '''
        # find patches for training
        patches, labels = [], np.full(num_patches, fill_value=class_nm, dtype=np.int32)
        count = 0
        # swap axes so that axis 0 indexes the modality and axis 1 the slice; channel 4 is the ground truth
        gt_im = np.swapaxes(self.train_im, 0, 1)[4]
        # take the FLAIR image (channel 0) as the brain mask
        msk = np.swapaxes(self.train_im, 0, 1)[0]
        # flatten the ground truth and the mask to 1D, remembering the original shape
        tmp_shp = gt_im.shape
        gt_im = gt_im.reshape(-1)
        msk = msk.reshape(-1)
        # 1D indices where the label equals class_nm and the mask is non-zero
        indices = np.squeeze(np.argwhere((gt_im == class_nm) & (msk != 0.)))
        # shuffle the list of indices of the class
        st = timeit.default_timer()
        np.random.shuffle(indices)
        print 'shuffling of label {} took :'.format(class_nm), timeit.default_timer() - st
        # reshape gt_im back to (slice, height, width)
        gt_im = gt_im.reshape(tmp_shp)
        st = timeit.default_timer()
        # find the patches from the images
        i = 0
        pix = len(indices)
        while (count < num_patches) and (pix > i):
            #print (count, ' cl:', class_nm)
            #sys.stdout.flush()
            # take the next index from the shuffled list
            ind = indices[i]
            i += 1
            # convert the flat index back into a 3D (slice, row, column) index
            ind = np.unravel_index(ind, tmp_shp)
            #print ind
            #sys.stdout.flush()
            # the slice to sample from
            slice_idx = ind[0]
            # load the slice from the label
            l = gt_im[slice_idx]
            # the centre pixel and its coordinates
            p = ind[1:]
            # patch bounds: for h = w = 33 the patch spans 16 pixels on either side of the centre
            p_x = (p[0] - h/2, p[0] + (h+1)/2)
            p_y = (p[1] - w/2, p[1] + (w+1)/2)
            # skip patches that fall outside the slice
            if p_x[0] < 0 or p_x[1] > l.shape[0] or p_y[0] < 0 or p_y[1] > l.shape[1]:
                continue
            # take patches from all modalities and group them together
            tmp = self.train_im[slice_idx][0:4, p_x[0]:p_x[1], p_y[0]:p_y[1]]
            patches.append(tmp)
            count += 1
        print 'finding patches of label {} took :'.format(class_nm), timeit.default_timer() - st
        patches = np.array(patches)
        # if the class ran out of candidate pixels, fewer than num_patches were found;
        # trim the labels so they always match the patches
        labels = labels[:len(patches)]
        return patches, labels
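
    # sample_training_patches is normally driven by training_patches below; a direct call would
    # look roughly like (sketch): patches, labels = pipe.sample_training_patches(1000, class_nm=1)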
    def training_patches(self, num_patches, classes=5, d=4, h=33, w=33):
        '''Creates the input patches and their labels for training the CNN. The patches are 4x33x33,
        and the label of a patch corresponds to the label of the central voxel of the patch. The
        data is balanced by sampling the requested number of patches for each class.
        INPUT:
            1) classes: number of classes in the segmentation
            2) num_patches: number of patches for each class
            3) d, h, w: channels, height and width of the patches
        OUTPUT:
            1) all_patches: numpy array of all class patches of shape 4x33x33
            2) all_labels: numpy array of the corresponding labels
        '''
        patches, labels, mu, sigma = [], [], [], []
        for idx in xrange(classes):
            p, l = self.sample_training_patches(num_patches[idx], idx, d, h, w)
            patches.append(p)
            labels.append(l)
        # cast to float so the in-place normalisation below is not truncated to integers
        patches = np.vstack(patches).astype(np.float32)
        # normalise each channel to zero mean and unit variance, keeping the per-channel statistics
        patches_by_channel = np.swapaxes(patches, 0, 1)
        for seq, i in zip(patches_by_channel, xrange(d)):
            avg = np.mean(seq)
            std = np.std(seq)
            patches_by_channel[i] = (patches_by_channel[i] - avg) / std
            mu.append(avg)
            sigma.append(std)
        patches = np.swapaxes(patches_by_channel, 0, 1)
        # concatenation keeps patches and labels aligned even if a class yielded fewer patches than requested
        return patches, np.concatenate(labels), np.array(mu), np.array(sigma)
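

# training_patches returns the per-channel (mu, sigma) computed on the training patches; passing
# the same values to test_patches below keeps the test data on the same scale. Sketch (assuming
# `pipe` is a Pipeline instance and 1000 patches per class):
#   X, y, mu, sigma = pipe.training_patches([1000] * 5)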
def test_patches(img, mu, sigma, d=4, h=33, w=33):
    ''' Creates patches of an image. Returns a numpy array of dimension number_of_patches x d.
    INPUT:
        1) img: a 3D array containing all modalities of a 2D image.
        2) mu, sigma: per-channel mean and standard deviation used to normalise the training patches.
        3) d, h, w: see above.
    OUTPUT:
        tst_arr: ndarray of all patches of all modalities, of the form number_of_patches x modalities.
    '''
    # list of patches
    p = []
    # brain mask: pixels that are non-zero in at least one modality
    msk = (img[0] + img[1] + img[2] + img[3]) != 0.
    msk = msk[16:-16, 16:-16]   # crop the mask to conform to the rebuilt image after prediction
    msk = msk.reshape(-1)
    for i in xrange(len(img)):
        plist = extract_patches_2d(img[i], (h, w))
        plist = (plist - mu[i]) / sigma[i]
        p.append(plist[msk])    # only take patches whose centre pixel lies inside the brain mask
    return np.array(p).swapaxes(0, 1)


def reconstruct_labels(msk, pred_list):
    ''' Rebuilds a 2D label image from the per-patch predictions. Predictions exist only for the
    central 208x208 patch-centre region of a 240x240 slice, so the border is padded back. '''
    im = np.full((208, 208), 0.)
    msk = msk[16:-16, 16:-16]
    im[msk] = np.array(pred_list)
    im = np.pad(im, (16, 16), mode='edge')
    return im
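

# -------------------------------------------------------------------------------------------
# Minimal usage sketch. The directory paths, the per-class patch count of 1000 and the assumed
# (5, 240, 240) slice-stack layout are illustrative only; the prediction step would need a
# trained model that is not defined in this file.
# -------------------------------------------------------------------------------------------
if __name__ == '__main__':
    # load the training and test slices from hypothetical directories of .mha files
    pipe = Pipeline(path_train='./data/train', path_test='./data/test')

    # balanced training set: 1000 patches per class, plus the per-channel normalisation statistics
    X_train, y_train, mu, sigma = pipe.training_patches([1000] * 5)
    print 'training patches:', X_train.shape, 'labels:', y_train.shape

    # turn the modalities of one test slice into normalised patches for prediction
    test_slice = pipe.test_im[0]                  # assumed shape: (5, 240, 240)
    X_test = test_patches(test_slice[0:4], mu, sigma)
    print 'test patches:', X_test.shape

    # with a trained classifier (hypothetical), rebuild the full-size label image:
    #   pred = model.predict_classes(X_test)
    #   brain_msk = (test_slice[0] + test_slice[1] + test_slice[2] + test_slice[3]) != 0.
    #   label_img = reconstruct_labels(brain_msk, pred)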