# -*- coding: utf-8 -*-
"""
Created on Fri Oct 14 01:47:13 2016

@author: seeker105

Patch-extraction pipeline for training a CNN on multi-modality brain MRI
slices: 4 image modalities plus 1 ground-truth label volume per slice
(presumably BRATS-style data -- confirm against the dataset layout).
"""
import os.path
import random
import pylab
import numpy as np
from glob import glob
import SimpleITK as sitk
from Nyul import IntensityRangeStandardization
import sys
import timeit
from sklearn.feature_extraction.image import extract_patches_2d
from skimage import color


class Pipeline(object):
    '''Pipeline for loading images of all patients and all modalities.

    INPUT:
        1) path_train: directory containing the training image slices (.mha)
        2) path_test:  directory containing the test image slices (.mha)
        3) mx_train / mx_tst: upper bound on how many scans to load from each
    '''

    def __init__(self, path_train='', path_test='', mx_train=1000000, mx_tst=1000000):
        self.path_train = path_train
        self.path_test = path_test
        (self.scans_train, self.scans_test,
         self.train_im, self.test_im) = self.read_scans(mx_train, mx_tst)

    def read_scans(self, mx_train, mx_test):
        '''Load up to mx_train / mx_test .mha volumes as numpy arrays.

        OUTPUT: (train_paths, test_paths, train_images, test_images)
        '''
        scans_train = glob(self.path_train + r'/*.mha')
        scans_test = glob(self.path_test + r'/*.mha')
        train_im = [sitk.GetArrayFromImage(sitk.ReadImage(f))
                    for f in scans_train[:mx_train]]
        test_im = [sitk.GetArrayFromImage(sitk.ReadImage(f))
                   for f in scans_test[:mx_test]]
        return scans_train, scans_test, np.array(train_im), np.array(test_im)

    def n4itk(self, img):
        '''Apply ITK N4 bias-field correction to a SimpleITK image.'''
        img = sitk.Cast(img, sitk.sitkFloat32)
        # Mask spanning every non-zero voxel (the brain region), so the
        # filter is applied to the brain only.
        img_mask = sitk.BinaryNot(sitk.BinaryThreshold(img, 0, 0))
        corrected_img = sitk.N4BiasFieldCorrection(img, img_mask)
        return corrected_img

    def sample_training_patches(self, num_patches, class_nm, d=4, h=33, w=33):
        '''Create input patches and labels for one class, for CNN training.

        Patches are d x h x w (default 4x33x33); a patch's label is the
        ground-truth label of its central pixel.

        INPUT:
            1) num_patches: number of patches requested for this class.
            2) class_nm:    class label index to sample patches for.
            3) d, h, w:     channels, height and width of each patch.

        OUTPUT:
            1) patches: array of patches, each of shape (d, h, w).
            2) labels:  array of labels, one per patch actually found
                        (may be fewer than num_patches for rare classes).
        '''
        patches = []
        count = 0
        # Swap axes so axis 0 indexes the modality and axis 1 the slice;
        # channel 4 is the ground truth, channel 0 (FLAIR) is the brain mask.
        gt_im = np.swapaxes(self.train_im, 0, 1)[4]
        msk = np.swapaxes(self.train_im, 0, 1)[0]
        tmp_shp = gt_im.shape
        gt_im = gt_im.reshape(-1)
        msk = msk.reshape(-1)
        # 1-D indices of pixels of this class that lie inside the brain mask
        indices = np.squeeze(np.argwhere((gt_im == class_nm) & (msk != 0.)))
        st = timeit.default_timer()
        np.random.shuffle(indices)
        print('shuffling of label {} took :'.format(class_nm), timeit.default_timer() - st)
        # restore the 3-D shape for per-slice indexing below
        gt_im = gt_im.reshape(tmp_shp)
        st = timeit.default_timer()
        i = 0
        pix = len(indices)
        while (count < num_patches) and (pix > i):
            ind = indices[i]
            i += 1
            # recover the (slice, row, col) coordinates of the centre pixel
            ind = np.unravel_index(ind, tmp_shp)
            slice_idx = ind[0]
            l = gt_im[slice_idx]
            p = ind[1:]
            # Patch bounds around the centre pixel; '//' keeps the original
            # Python-2 truncating-division semantics for odd h/w.
            p_x = (p[0] - h // 2, p[0] + (h + 1) // 2)
            p_y = (p[1] - w // 2, p[1] + (w + 1) // 2)
            # skip candidates whose patch would fall outside the slice
            if p_x[0] < 0 or p_x[1] > l.shape[0] or p_y[0] < 0 or p_y[1] > l.shape[1]:
                continue
            # stack the same window from all 4 image modalities
            tmp = self.train_im[slice_idx][0:4, p_x[0]:p_x[1], p_y[0]:p_y[1]]
            patches.append(tmp)
            count += 1
        print('finding patches of label {} took :'.format(class_nm), timeit.default_timer() - st)
        # BUG FIX: size labels to the number of patches actually found --
        # the loop can exhaust `indices` before reaching num_patches,
        # and a fixed-size label array would then mislabel data downstream.
        labels = np.full(count, fill_value=class_nm, dtype=np.int32)
        return np.array(patches), labels

    def training_patches(self, num_patches, classes=5, d=4, h=33, w=33):
        '''Create class-balanced input patches and labels for CNN training.

        Patches are d x h x w (default 4x33x33); a patch's label is the
        ground-truth label of its central voxel. Each channel is normalised
        to zero mean / unit variance across the whole patch set.

        INPUT:
            1) num_patches: per-class patch counts (indexable by class).
            2) classes:     number of segmentation classes.
            3) d, h, w:     channels, height and width of the patches.

        OUTPUT:
            1) all patches (normalised), shape (N, d, h, w)
            2) their labels, shape (N,)
            3) per-channel means used for normalisation
            4) per-channel standard deviations used for normalisation
        '''
        patches, labels, mu, sigma = [], [], [], []
        for idx in range(classes):
            p, l = self.sample_training_patches(num_patches[idx], idx, d, h, w)
            patches.append(p)
            labels.append(l)
        patches = np.vstack(np.array(patches))
        # axis 0 -> channel, axis 1 -> patch, for per-channel normalisation
        patches_by_channel = np.swapaxes(patches, 0, 1)
        for i in range(d):
            avg = np.mean(patches_by_channel[i])
            std = np.std(patches_by_channel[i])
            patches_by_channel[i] = (patches_by_channel[i] - avg) / std
            mu.append(avg)
            sigma.append(std)
        patches = np.swapaxes(patches_by_channel, 0, 1)
        # np.concatenate handles per-class label arrays of unequal length
        # (equivalent to the old reshape(-1) when all lengths are equal)
        return patches, np.concatenate(labels), np.array(mu), np.array(sigma)


def test_patches(img, mu, sigma, d=4, h=33, w=33):
    '''Create normalised patches for every brain pixel of a test image.

    INPUT:
        1) img:       3-D array holding all modalities of one 2-D slice.
        2) mu, sigma: per-channel normalisation stats from training_patches.
        3) d, h, w:   see training_patches.
    OUTPUT:
        ndarray of shape (number_of_patches, modalities, h, w).
    '''
    p = []
    # Brain mask: any pixel where the four modalities do not sum to zero
    # (NOTE: uses the sum of all channels, not FLAIR alone).
    msk = (img[0] + img[1] + img[2] + img[3]) != 0.
    # crop the mask to conform to the rebuilt image after prediction
    msk = msk[16:-16, 16:-16]
    msk = msk.reshape(-1)
    for i in range(len(img)):
        plist = extract_patches_2d(img[i], (h, w))
        plist = (plist - mu[i]) / sigma[i]
        p.append(plist[msk])  # keep only patches centred inside the brain
    return np.array(p).swapaxes(0, 1)


def reconstruct_labels(msk, pred_list):
    '''Rebuild a 2-D label image from per-pixel predictions.

    INPUT:
        1) msk:       boolean brain mask of the original slice.
        2) pred_list: predicted labels for the masked pixels, in mask order.
    OUTPUT:
        240x240 label image (208x208 interior padded back with edge values).
    '''
    im = np.full((208, 208), 0.)
    msk = msk[16:-16, 16:-16]
    im[msk] = np.array(pred_list)
    im = np.pad(im, (16, 16), mode='edge')
    return im