--- a
+++ b/2D/data_handling_2d_patch.py
@@ -0,0 +1,244 @@
+
+from __future__ import print_function
+
+# import packages
+import time, os, random, cv2
+import numpy as np
+import nibabel as nib
+from sklearn.feature_extraction.image import extract_patches
+
+# import configurations
+import configs
+
+# init configs
+image_rows = configs.VOLUME_ROWS
+image_cols = configs.VOLUME_COLS
+image_depth = configs.VOLUME_DEPS
+num_classes = configs.NUM_CLASSES
+
+# patch extraction parameters
+patch_size = configs.PATCH_SIZE
+extraction_step = configs.EXTRACTTION_STEP
+extraction_step_csf_only = configs.EXTRACTTION_STEP_CSF
+
+# create the .npy patch database for training or validation
+def create_npy_data(train_imgs_path, is_extract_more_csf, is_train):
+    # empty matrices to hold patches
+    patches_training_imgs_2d = np.empty(shape=[0, patch_size, patch_size], dtype='int16')
+    patches_training_gtruth_2d = np.empty(shape=[0, patch_size, patch_size, num_classes], dtype='int16')
+
+    images_train_dir = os.listdir(train_imgs_path)
+    start_time = time.time()
+
+    j = 0
+    print('-' * 30)
+    print('Creating training 2d_patches...')
+    print('-' * 30)
+
+    # for each volume do:
+    for img_dir_name in images_train_dir:
+        patches_training_imgs_2d_temp = np.empty(shape=[0, patch_size, patch_size], dtype='int16')
+        patches_training_gtruth_2d_temp = np.empty(shape=[0, patch_size, patch_size, num_classes], dtype='int16')
+        print('Processing volume {0} / {1}'.format(j + 1, len(images_train_dir)))
+
+        # volume
+        img_name = img_dir_name + '_hist.nii.gz'
+        img_name = os.path.join(train_imgs_path, img_dir_name, img_name)
+
+        # ground truth
+        img_seg_name = img_dir_name + '_seg.nii.gz'
+        img_seg_name = os.path.join(train_imgs_path, img_dir_name, img_seg_name)
+
+        # mask
+        img_mask_name = img_dir_name + '_mask.nii.gz'
+        img_mask_name = os.path.join(train_imgs_path, img_dir_name, img_mask_name)
+
+        # load volume, ground truth and mask
+        img = nib.load(img_name)
+        img_data = img.get_data()
+        img_data = np.squeeze(img_data)
+
+        img_gtruth = nib.load(img_seg_name)
+        img_gtruth_data = img_gtruth.get_data()
+        img_gtruth_data = np.squeeze(img_gtruth_data)
+
+        img_mask = nib.load(img_mask_name)
+        img_mask_data = img_mask.get_data()
+        img_mask_data = np.squeeze(img_mask_data)
+
+        # for each slice do:
+        for slice_idx in range(img_gtruth_data.shape[2]):
+            patches_training_imgs_2d_slice_temp = np.empty(shape=[0, patch_size, patch_size], dtype='int16')
+            patches_training_gtruth_2d_slice_temp = np.empty(shape=[0, patch_size, patch_size, num_classes], dtype='int16')
+            if np.count_nonzero(img_gtruth_data[:, :, slice_idx]) and np.count_nonzero(img_data[:, :, slice_idx]):
+
+                # extract patches of the jth volume image
+                imgs_patches, gt_patches = extract_2d_patches(img_data[:, :, slice_idx],
+                                                              img_gtruth_data[:, :, slice_idx],
+                                                              img_mask_data[:, :, slice_idx],
+                                                              is_extract_more_csf)
+
+                # update database
+                patches_training_imgs_2d_slice_temp = np.append(patches_training_imgs_2d_slice_temp, imgs_patches, axis=0)
+                patches_training_gtruth_2d_slice_temp = np.append(patches_training_gtruth_2d_slice_temp, gt_patches, axis=0)
+
+            patches_training_imgs_2d_temp = np.append(patches_training_imgs_2d_temp, patches_training_imgs_2d_slice_temp, axis=0)
+            patches_training_gtruth_2d_temp = np.append(patches_training_gtruth_2d_temp, patches_training_gtruth_2d_slice_temp, axis=0)
+
+        patches_training_imgs_2d = np.append(patches_training_imgs_2d, patches_training_imgs_2d_temp, axis=0)
+        patches_training_gtruth_2d = np.append(patches_training_gtruth_2d, patches_training_gtruth_2d_temp, axis=0)
+        j += 1
+        X = patches_training_imgs_2d.shape
+        Y = patches_training_gtruth_2d.shape
+        print('shape im: [{0} , {1} , {2}]'.format(X[0], X[1], X[2]))
+        print('shape gt: [{0} , {1} , {2}, {3}]'.format(Y[0], Y[1], Y[2], Y[3]))
+
+    # convert to single precision and add a channel dimension
+    patches_training_imgs_2d = patches_training_imgs_2d.astype('float32')
+    patches_training_imgs_2d = np.expand_dims(patches_training_imgs_2d, axis=3)
+
+    end_time = time.time()
+    print("Elapsed time was %g seconds" % (end_time - start_time))
+
+    X = patches_training_imgs_2d.shape
+    Y = patches_training_gtruth_2d.shape
+
+    print('-' * 30)
+    print('Training set detail...')
+    print('-' * 30)
+    print('shape im: [{0} , {1} , {2}, {3}]'.format(X[0], X[1], X[2], X[3]))
+    print('shape gt: [{0} , {1} , {2}, {3}]'.format(Y[0], Y[1], Y[2], Y[3]))
+
+    S = patches_training_imgs_2d.shape
+    print('Done: {0} 2d patches added from {1} volume images'.format(S[0], j))
+    print('Loading done.')
+
+    # save train or validation set
+    if is_train:
+        np.save('imdbs_2d_patch/patches_training_imgs_2d.npy', patches_training_imgs_2d)
+        np.save('imdbs_2d_patch/patches_training_gtruth_2d.npy', patches_training_gtruth_2d)
+    else:
+        np.save('imdbs_2d_patch/patches_val_imgs_2d.npy', patches_training_imgs_2d)
+        np.save('imdbs_2d_patch/patches_val_gtruth_2d.npy', patches_training_gtruth_2d)
+    print('Saving to .npy files done.')
+
+# extract 2d patches from one slice
+def extract_2d_patches(img_data, gt_data, mask_data, is_extract_more_csf):
+    patch_shape = (patch_size, patch_size)
+    # empty matrices to hold patches
+    imgs_patches_per_slice = np.empty(shape=[0, patch_size, patch_size], dtype='int16')
+    gt_patches_per_slice = np.empty(shape=[0, patch_size, patch_size], dtype='int16')
+
+    img_patches = extract_patches(img_data, patch_shape, extraction_step)
+    gt_patches = extract_patches(gt_data, patch_shape, extraction_step)
+    mask_patches = extract_patches(mask_data, patch_shape, extraction_step)
+
+    # keep patches whose center pixels lie inside the mask
+    rows = []
+    cols = []
+    for i in range(0, mask_patches.shape[0]):
+        for j in range(0, mask_patches.shape[1]):
+            a1 = mask_patches.item((i, j, int(patch_size / 2 - 1), int(patch_size / 2 - 1)))
+            a2 = mask_patches.item((i, j, int(patch_size / 2 - 1), int(patch_size / 2)))
+            a3 = mask_patches.item((i, j, int(patch_size / 2), int(patch_size / 2 - 1)))
+            a4 = mask_patches.item((i, j, int(patch_size / 2), int(patch_size / 2)))
+            Sum = a1 + a2 + a3 + a4
+            if Sum > 0:
+                rows.append(i)
+                cols.append(j)
+
+    # number of non-zero patches
+    N = len(rows)
+
+    # select the non-zero patches by index
+    selected_img_patches = img_patches[rows, cols, :, :]
+    selected_gt_patches = gt_patches[rows, cols, :, :]
+
+    # update database
+    imgs_patches_per_slice = np.append(imgs_patches_per_slice, selected_img_patches, axis=0)
+    gt_patches_per_slice = np.append(gt_patches_per_slice, selected_gt_patches, axis=0)
+
+    # extract more patches for CSF
+    if is_extract_more_csf:
+        # re-extract with a smaller step so CSF is sampled more densely
+        extraction_step_csf = extraction_step_csf_only
+        img_patches_csf = extract_patches(img_data, patch_shape, extraction_step_csf)
+        gt_patches_csf = extract_patches(gt_data, patch_shape, extraction_step_csf)
+
+        # keep patches whose center pixels contain CSF (label 1)
+        rows = []
+        cols = []
+        for i in range(0, gt_patches_csf.shape[0]):
+            for j in range(0, gt_patches_csf.shape[1]):
+                a1 = gt_patches_csf.item((i, j, int(patch_size / 2 - 1), int(patch_size / 2 - 1)))
+                a2 = gt_patches_csf.item((i, j, int(patch_size / 2 - 1), int(patch_size / 2)))
+                a3 = gt_patches_csf.item((i, j, int(patch_size / 2), int(patch_size / 2 - 1)))
+                a4 = gt_patches_csf.item((i, j, int(patch_size / 2), int(patch_size / 2)))
+                has_csf = (a1 == 1 or a2 == 1 or a3 == 1 or a4 == 1)
+                if has_csf:
+                    rows.append(i)
+                    cols.append(j)
+
+        N = len(rows)
+        if N != 0:
+            csf_more_img_patches = img_patches_csf[rows, cols, :, :]
+            csf_more_gt_patches = gt_patches_csf[rows, cols, :, :]
+
+            # update database
+            imgs_patches_per_slice = np.append(imgs_patches_per_slice, csf_more_img_patches, axis=0)
+            gt_patches_per_slice = np.append(gt_patches_per_slice, csf_more_gt_patches, axis=0)
+
+    # convert the ground truth to categorical (one channel per class)
+    gt_patches_per_slice = separate_labels(gt_patches_per_slice)
+    return imgs_patches_per_slice, gt_patches_per_slice
+
+
+# separate labels into one binary channel per class
+def separate_labels(patch_3d_volume):
+    result = np.empty(shape=[0, patch_size, patch_size, num_classes], dtype='int16')
+    N = patch_3d_volume.shape[0]
+    # for each patch do:
+    for V in range(N):
+        V_patch = patch_3d_volume[V, :, :]
+        unique_values = list(np.unique(V_patch))
+        result_v = np.empty(shape=[patch_size, patch_size, 0], dtype='int16')
+        if num_classes == 3:
+            start_point = 1
+        else:
+            start_point = 0
+        for label in range(start_point, 4):
+            if label in unique_values:
+                im_patch = (V_patch == label) * 1
+            else:
+                im_patch = np.zeros(V_patch.shape)
+
+            im_patch = np.expand_dims(im_patch, axis=2)
+            result_v = np.append(result_v, im_patch, axis=2)
+        result_v = np.expand_dims(result_v, axis=0)
+        result = np.append(result, result_v, axis=0)
+    return result
+
+# load train npy
+def load_train_data():
+    imgs_train = np.load('imdbs_2d_patch/patches_training_imgs_2d.npy')
+    imgs_gtruth_train = np.load('imdbs_2d_patch/patches_training_gtruth_2d.npy')
+    return imgs_train, imgs_gtruth_train
+
+# load validation npy
+def load_validatation_data():
+    imgs_validation = np.load('imdbs_2d_patch/patches_val_imgs_2d.npy')
+    gtruth_validation = np.load('imdbs_2d_patch/patches_val_gtruth_2d.npy')
+    return imgs_validation, gtruth_validation
+
+# main
+if __name__ == '__main__':
+    if 'imdbs_2d_patch' not in os.listdir(os.curdir):
+        os.mkdir('imdbs_2d_patch')
+    train_imgs_path = '../data_new2/Training_Set'
+    val_imgs_path = '../data_new2/Validation_Set'
+    print(train_imgs_path)
+    print(val_imgs_path)
+    is_extract_more_csf = 0
+    create_npy_data(train_imgs_path, is_extract_more_csf, 1)
+    is_extract_more_csf = 0
+    create_npy_data(val_imgs_path, is_extract_more_csf, 0)
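
Note on dependencies: the script imports a `configs` module that is not part of this patch. Below is a minimal sketch of the fields it reads, so the patch can be tried in isolation; the values and comments are illustrative assumptions, not the project's actual configuration, and the misspelled `EXTRACTTION_*` names are kept exactly as this file references them.

# configs.py -- minimal sketch for testing this patch; values are assumptions
VOLUME_ROWS = 144           # in-plane rows of each volume (assumed)
VOLUME_COLS = 192           # in-plane columns of each volume (assumed)
VOLUME_DEPS = 256           # number of slices per volume (assumed)
NUM_CLASSES = 4             # e.g. background + CSF + GM + WM (assumed)
PATCH_SIZE = 32             # side length of the square 2D patches (assumed)
EXTRACTTION_STEP = 8        # stride between patches (name kept as referenced above)
EXTRACTTION_STEP_CSF = 4    # smaller stride used when oversampling CSF (assumed)

With such a module on the path and the expected `*_hist.nii.gz`, `*_seg.nii.gz` and `*_mask.nii.gz` files under the training and validation directories, running the script writes the four `.npy` databases into `imdbs_2d_patch/`, which `load_train_data()` and `load_validatation_data()` then read back for training.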