--- a +++ b/BraTs18Challege/inference_Brats.py @@ -0,0 +1,136 @@ +from __future__ import print_function, division +import os + +os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" +os.environ["CUDA_VISIBLE_DEVICES"] = "0" +from tensorflow.python.client import device_lib + +print(device_lib.list_local_devices()) + +from Vnet.model_vnet3d_multilabel import Vnet3dModuleMultiLabel +from Vnet.util import getLargestConnectedCompont +from dataprocess.utils import calcu_dice, file_name_path +from dataprocess.data3dprepare import normalize +import numpy as np +import SimpleITK as sitk + + +def inference(): + """ + Vnet network segmentation brats fine segmatation + :return: + """ + channel = 4 + numclass = 4 + flair_name = "_flair.nii.gz" + t1_name = "_t1.nii.gz" + t1ce_name = "_t1ce.nii.gz" + t2_name = "_t2.nii.gz" + mask_name = "_seg.nii.gz" + out_mask_name = "_outseg.nii.gz" + # step1 init vnet model + depth_z = 48 + Vnet3d = Vnet3dModuleMultiLabel(240, 240, depth_z, channels=channel, numclass=numclass, + costname=("categorical_dice",), inference=True, + model_path="log\segmeation2mm\weighted_categorical_crossentropy\model\Vnet3d.pd-10000") + brats_path = "D:\Data\\brats18\\test" + # step2 get all test image path + dice_values0 = [] + dice_values1 = [] + dice_values2 = [] + dice_values3 = [] + path_list = file_name_path(brats_path) + # step3 get test image(4 model) and mask + for subsetindex in range(len(path_list)): + # step4 load test image(4 model) and mask as ndarray + brats_subset_path = brats_path + "/" + str(path_list[subsetindex]) + "/" + flair_image = brats_subset_path + str(path_list[subsetindex]) + flair_name + t1_image = brats_subset_path + str(path_list[subsetindex]) + t1_name + t1ce_image = brats_subset_path + str(path_list[subsetindex]) + t1ce_name + t2_image = brats_subset_path + str(path_list[subsetindex]) + t2_name + mask_image = brats_subset_path + str(path_list[subsetindex]) + mask_name + flair_src = sitk.ReadImage(flair_image, sitk.sitkInt16) + t1_src = sitk.ReadImage(t1_image, sitk.sitkInt16) + t1ce_src = sitk.ReadImage(t1ce_image, sitk.sitkInt16) + t2_src = sitk.ReadImage(t2_image, sitk.sitkInt16) + mask = sitk.ReadImage(mask_image, sitk.sitkUInt8) + flair_array = sitk.GetArrayFromImage(flair_src) + t1_array = sitk.GetArrayFromImage(t1_src) + t1ce_array = sitk.GetArrayFromImage(t1ce_src) + t2_array = sitk.GetArrayFromImage(t2_src) + label = sitk.GetArrayFromImage(mask) + # step5 mormazalation test image(4 model) and merage to 4 channels ndarray + flair_array = normalize(flair_array) + t1_array = normalize(t1_array) + t1ce_array = normalize(t1ce_array) + t2_array = normalize(t2_array) + + imagez, height, width = np.shape(flair_array)[0], np.shape(flair_array)[1], np.shape(flair_array)[2] + fourmodelimagearray = np.zeros((imagez, height, width, channel), np.float) + fourmodelimagearray[:, :, :, 0] = flair_array + fourmodelimagearray[:, :, :, 1] = t1_array + fourmodelimagearray[:, :, :, 2] = t1ce_array + fourmodelimagearray[:, :, :, 3] = t2_array + ys_pd_array = np.zeros((imagez, height, width), np.uint8) + # step6 predict test image(4 model) + last_depth = 0 + for depth in range(0, imagez // depth_z, 1): + patch_xs = fourmodelimagearray[depth * depth_z:(depth + 1) * depth_z, :, :, :] + pathc_pd = Vnet3d.prediction(patch_xs) + ys_pd_array[depth * depth_z:(depth + 1) * depth_z, :, :] = pathc_pd + last_depth = depth + if imagez != depth_z * last_depth: + patch_xs = fourmodelimagearray[(imagez - depth_z):imagez, :, :, :] + pathc_pd = Vnet3d.prediction(patch_xs) + ys_pd_array[(imagez - depth_z):imagez, :, :] = pathc_pd + + ys_pd_array = np.clip(ys_pd_array, 0, 255).astype('uint8') + all_ys_pd_array = ys_pd_array.copy() + all_ys_pd_array[ys_pd_array != 0] = 1 + outmask = getLargestConnectedCompont(sitk.GetImageFromArray(all_ys_pd_array)) + ys_pd_array[outmask == 0] = 0 + # step7 calcu test mask and predict mask dice value + batch_ys = label.copy() + batch_ys[label == 4] = 3 + dice_value0 = 0 + dice_value1 = 0 + dice_value2 = 0 + dice_value3 = 0 + for num_class in range(4): + ys_pd_array_tmp = ys_pd_array.copy() + batch_ys_tmp = batch_ys.copy() + ys_pd_array_tmp[ys_pd_array == num_class] = 1 + batch_ys_tmp[label == num_class] = 1 + if num_class == 0: + dice_value0 = calcu_dice(ys_pd_array_tmp, batch_ys_tmp, 1) + if num_class == 1: + dice_value1 = calcu_dice(ys_pd_array_tmp, batch_ys_tmp, 1) + if num_class == 2: + dice_value2 = calcu_dice(ys_pd_array_tmp, batch_ys_tmp, 1) + if num_class == 3: + dice_value3 = calcu_dice(ys_pd_array_tmp, batch_ys_tmp, 1) + print("index,dice:", (subsetindex, dice_value0, dice_value1, dice_value2, dice_value3)) + dice_values0.append(dice_value0) + dice_values1.append(dice_value1) + dice_values2.append(dice_value2) + dice_values3.append(dice_value3) + # step8 out put predict mask + ys_pd_array = ys_pd_array.astype('float') + outputmask = np.zeros((imagez, height, width), np.uint8) + outputmask[ys_pd_array == 1] = 1 + outputmask[ys_pd_array == 2] = 2 + outputmask[ys_pd_array == 3] = 4 + ys_pd_itk = sitk.GetImageFromArray(outputmask) + ys_pd_itk.SetSpacing(mask.GetSpacing()) + ys_pd_itk.SetOrigin(mask.GetOrigin()) + ys_pd_itk.SetDirection(mask.GetDirection()) + out_mask_image = brats_subset_path + str(path_list[subsetindex]) + out_mask_name + sitk.WriteImage(ys_pd_itk, out_mask_image) + average0 = sum(dice_values0) / len(dice_values0) + average1 = sum(dice_values1) / len(dice_values1) + average2 = sum(dice_values2) / len(dice_values2) + average3 = sum(dice_values3) / len(dice_values3) + print("average dice:", (average0, average1, average2, average3)) + + +inference()