[1bd6b2]: / BraTs18Challege / inference_Brats.py

Download this file

137 lines (127 with data), 6.4 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
# Python 2/3 compatibility: print() as a function and true division for "/".
from __future__ import print_function, division
import os
# GPU selection is configured here, before the TensorFlow import below.
# CUDA_DEVICE_ORDER=PCI_BUS_ID numbers devices by PCI bus id, and
# CUDA_VISIBLE_DEVICES="0" exposes only device 0 to this process.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
from tensorflow.python.client import device_lib
# Print the devices TensorFlow reports, as a sanity check of the GPU setup.
print(device_lib.list_local_devices())
# Project-local modules: the V-Net model wrapper and pre/post-processing helpers.
from Vnet.model_vnet3d_multilabel import Vnet3dModuleMultiLabel
from Vnet.util import getLargestConnectedCompont
from dataprocess.utils import calcu_dice, file_name_path
from dataprocess.data3dprepare import normalize
import numpy as np
import SimpleITK as sitk
def _predict_volume(model, volume, depth_z):
    """
    Segment a volume slab by slab along its first (depth) axis.
    :param model: object exposing ``prediction(patch)``; it is fed slabs of
        exactly ``depth_z`` slices
    :param volume: ndarray indexed as (z, y, x, channel)
    :param depth_z: number of slices per slab
    :return: uint8 ndarray of shape ``volume.shape[:3]`` with the predicted labels
    """
    imagez = volume.shape[0]
    if imagez < depth_z:
        # A negative slice start would silently select the wrong slices.
        raise ValueError("volume depth %d is smaller than the model depth %d"
                         % (imagez, depth_z))
    prediction = np.zeros(volume.shape[:3], np.uint8)
    # Non-overlapping slabs covering the largest multiple of depth_z.
    for start in range(0, imagez - depth_z + 1, depth_z):
        stop = start + depth_z
        prediction[start:stop] = model.prediction(volume[start:stop])
    # Leftover slices: predict one extra slab aligned to the end of the volume.
    # It overlaps the previous slab and overwrites the overlapping part.
    if imagez % depth_z != 0:
        prediction[imagez - depth_z:] = model.prediction(volume[imagez - depth_z:])
    return prediction


def _per_class_dice(prediction, truth, numclass):
    """
    Compute one dice value per class from two label volumes.
    :param prediction: predicted label ndarray (values 0..numclass-1)
    :param truth: reference label ndarray using the same label values
    :param numclass: number of classes to score
    :return: list of ``numclass`` dice values, index == class id
    """
    scores = []
    for class_id in range(numclass):
        # Strictly binary masks: only voxels of this class are 1, so voxels of
        # other classes can not leak into the score.
        pred_mask = (prediction == class_id).astype(np.uint8)
        truth_mask = (truth == class_id).astype(np.uint8)
        scores.append(calcu_dice(pred_mask, truth_mask, 1))
    return scores


def inference():
    """
    Run V-Net multi-label segmentation on every BraTS test case, print the
    per-class dice of each case and the average over all cases, and write the
    predicted mask next to the input images as ``<case>_outseg.nii.gz``.
    :return:
    """
    channel = 4
    numclass = 4
    # Channel order fed to the network: flair, t1, t1ce, t2.
    modality_names = ("_flair.nii.gz", "_t1.nii.gz", "_t1ce.nii.gz", "_t2.nii.gz")
    mask_name = "_seg.nii.gz"
    out_mask_name = "_outseg.nii.gz"
    # step1 init vnet model
    depth_z = 48
    # Raw strings keep the Windows backslashes literal without relying on
    # invalid escape sequences such as "\s" or "\D".
    Vnet3d = Vnet3dModuleMultiLabel(240, 240, depth_z, channels=channel, numclass=numclass,
                                    costname=("categorical_dice",), inference=True,
                                    model_path=r"log\segmeation2mm\weighted_categorical_crossentropy\model\Vnet3d.pd-10000")
    brats_path = r"D:\Data\brats18\test"
    # step2 get all test image path
    dice_values = [[] for _ in range(numclass)]
    # presumably one sub-directory name per test case -- verify against file_name_path
    path_list = file_name_path(brats_path)
    # step3 get test image(4 modalities) and mask
    for subsetindex, case_name in enumerate(path_list):
        case_name = str(case_name)
        # step4 load test image(4 modalities) and mask as ndarray
        brats_subset_path = brats_path + "/" + case_name + "/"
        case_prefix = brats_subset_path + case_name
        # step5 normalize each modality and merge them into a 4 channel ndarray
        modality_arrays = []
        for modality_name in modality_names:
            modality_src = sitk.ReadImage(case_prefix + modality_name, sitk.sitkInt16)
            modality_arrays.append(normalize(sitk.GetArrayFromImage(modality_src)))
        mask = sitk.ReadImage(case_prefix + mask_name, sitk.sitkUInt8)
        label = sitk.GetArrayFromImage(mask)
        imagez, height, width = np.shape(modality_arrays[0])[:3]
        # np.float64 is the dtype the removed np.float alias resolved to.
        fourmodelimagearray = np.zeros((imagez, height, width, channel), np.float64)
        for channel_index, modality_array in enumerate(modality_arrays):
            fourmodelimagearray[:, :, :, channel_index] = modality_array
        # step6 predict test image(4 modalities)
        ys_pd_array = _predict_volume(Vnet3d, fourmodelimagearray, depth_z)
        # Keep only predictions inside the largest connected foreground region.
        all_ys_pd_array = ys_pd_array.copy()
        all_ys_pd_array[ys_pd_array != 0] = 1
        # NOTE(review): assumes the helper returns a mask that can be compared
        # element-wise with ys_pd_array -- confirm in Vnet.util.
        outmask = getLargestConnectedCompont(sitk.GetImageFromArray(all_ys_pd_array))
        ys_pd_array[outmask == 0] = 0
        # step7 calcu test mask and predict mask dice value
        # The network uses contiguous class ids, so ground-truth label 4 becomes 3.
        batch_ys = label.copy()
        batch_ys[label == 4] = 3
        case_dice = _per_class_dice(ys_pd_array, batch_ys, numclass)
        print("index,dice:", (subsetindex,) + tuple(case_dice))
        for class_values, class_dice in zip(dice_values, case_dice):
            class_values.append(class_dice)
        # step8 output predict mask, mapping class id 3 back to label 4
        outputmask = np.zeros((imagez, height, width), np.uint8)
        outputmask[ys_pd_array == 1] = 1
        outputmask[ys_pd_array == 2] = 2
        outputmask[ys_pd_array == 3] = 4
        ys_pd_itk = sitk.GetImageFromArray(outputmask)
        # Copy the geometry of the reference mask so the output aligns with the inputs.
        ys_pd_itk.SetSpacing(mask.GetSpacing())
        ys_pd_itk.SetOrigin(mask.GetOrigin())
        ys_pd_itk.SetDirection(mask.GetDirection())
        out_mask_image = case_prefix + out_mask_name
        sitk.WriteImage(ys_pd_itk, out_mask_image)
    if not dice_values[0]:
        # No case was processed: averaging would divide by zero.
        print("average dice: no test cases found in", brats_path)
        return
    averages = tuple(sum(class_values) / len(class_values) for class_values in dice_values)
    print("average dice:", averages)
# Run the whole evaluation when this file is executed.
# NOTE(review): there is no `if __name__ == "__main__":` guard, so importing
# this module also triggers the run.
inference()