|
a |
|
b/BraTs18Challege/inference_Brats.py |
|
|
1 |
from __future__ import print_function, division |
|
|
2 |
import os |
|
|
3 |
|
|
|
4 |
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" |
|
|
5 |
os.environ["CUDA_VISIBLE_DEVICES"] = "0" |
|
|
6 |
from tensorflow.python.client import device_lib |
|
|
7 |
|
|
|
8 |
print(device_lib.list_local_devices()) |
|
|
9 |
|
|
|
10 |
from Vnet.model_vnet3d_multilabel import Vnet3dModuleMultiLabel |
|
|
11 |
from Vnet.util import getLargestConnectedCompont |
|
|
12 |
from dataprocess.utils import calcu_dice, file_name_path |
|
|
13 |
from dataprocess.data3dprepare import normalize |
|
|
14 |
import numpy as np |
|
|
15 |
import SimpleITK as sitk |
|
|
16 |
|
|
|
17 |
|
|
|
18 |
def inference(): |
|
|
19 |
""" |
|
|
20 |
Vnet network segmentation brats fine segmatation |
|
|
21 |
:return: |
|
|
22 |
""" |
|
|
23 |
channel = 4 |
|
|
24 |
numclass = 4 |
|
|
25 |
flair_name = "_flair.nii.gz" |
|
|
26 |
t1_name = "_t1.nii.gz" |
|
|
27 |
t1ce_name = "_t1ce.nii.gz" |
|
|
28 |
t2_name = "_t2.nii.gz" |
|
|
29 |
mask_name = "_seg.nii.gz" |
|
|
30 |
out_mask_name = "_outseg.nii.gz" |
|
|
31 |
# step1 init vnet model |
|
|
32 |
depth_z = 48 |
|
|
33 |
Vnet3d = Vnet3dModuleMultiLabel(240, 240, depth_z, channels=channel, numclass=numclass, |
|
|
34 |
costname=("categorical_dice",), inference=True, |
|
|
35 |
model_path="log\segmeation2mm\weighted_categorical_crossentropy\model\Vnet3d.pd-10000") |
|
|
36 |
brats_path = "D:\Data\\brats18\\test" |
|
|
37 |
# step2 get all test image path |
|
|
38 |
dice_values0 = [] |
|
|
39 |
dice_values1 = [] |
|
|
40 |
dice_values2 = [] |
|
|
41 |
dice_values3 = [] |
|
|
42 |
path_list = file_name_path(brats_path) |
|
|
43 |
# step3 get test image(4 model) and mask |
|
|
44 |
for subsetindex in range(len(path_list)): |
|
|
45 |
# step4 load test image(4 model) and mask as ndarray |
|
|
46 |
brats_subset_path = brats_path + "/" + str(path_list[subsetindex]) + "/" |
|
|
47 |
flair_image = brats_subset_path + str(path_list[subsetindex]) + flair_name |
|
|
48 |
t1_image = brats_subset_path + str(path_list[subsetindex]) + t1_name |
|
|
49 |
t1ce_image = brats_subset_path + str(path_list[subsetindex]) + t1ce_name |
|
|
50 |
t2_image = brats_subset_path + str(path_list[subsetindex]) + t2_name |
|
|
51 |
mask_image = brats_subset_path + str(path_list[subsetindex]) + mask_name |
|
|
52 |
flair_src = sitk.ReadImage(flair_image, sitk.sitkInt16) |
|
|
53 |
t1_src = sitk.ReadImage(t1_image, sitk.sitkInt16) |
|
|
54 |
t1ce_src = sitk.ReadImage(t1ce_image, sitk.sitkInt16) |
|
|
55 |
t2_src = sitk.ReadImage(t2_image, sitk.sitkInt16) |
|
|
56 |
mask = sitk.ReadImage(mask_image, sitk.sitkUInt8) |
|
|
57 |
flair_array = sitk.GetArrayFromImage(flair_src) |
|
|
58 |
t1_array = sitk.GetArrayFromImage(t1_src) |
|
|
59 |
t1ce_array = sitk.GetArrayFromImage(t1ce_src) |
|
|
60 |
t2_array = sitk.GetArrayFromImage(t2_src) |
|
|
61 |
label = sitk.GetArrayFromImage(mask) |
|
|
62 |
# step5 mormazalation test image(4 model) and merage to 4 channels ndarray |
|
|
63 |
flair_array = normalize(flair_array) |
|
|
64 |
t1_array = normalize(t1_array) |
|
|
65 |
t1ce_array = normalize(t1ce_array) |
|
|
66 |
t2_array = normalize(t2_array) |
|
|
67 |
|
|
|
68 |
imagez, height, width = np.shape(flair_array)[0], np.shape(flair_array)[1], np.shape(flair_array)[2] |
|
|
69 |
fourmodelimagearray = np.zeros((imagez, height, width, channel), np.float) |
|
|
70 |
fourmodelimagearray[:, :, :, 0] = flair_array |
|
|
71 |
fourmodelimagearray[:, :, :, 1] = t1_array |
|
|
72 |
fourmodelimagearray[:, :, :, 2] = t1ce_array |
|
|
73 |
fourmodelimagearray[:, :, :, 3] = t2_array |
|
|
74 |
ys_pd_array = np.zeros((imagez, height, width), np.uint8) |
|
|
75 |
# step6 predict test image(4 model) |
|
|
76 |
last_depth = 0 |
|
|
77 |
for depth in range(0, imagez // depth_z, 1): |
|
|
78 |
patch_xs = fourmodelimagearray[depth * depth_z:(depth + 1) * depth_z, :, :, :] |
|
|
79 |
pathc_pd = Vnet3d.prediction(patch_xs) |
|
|
80 |
ys_pd_array[depth * depth_z:(depth + 1) * depth_z, :, :] = pathc_pd |
|
|
81 |
last_depth = depth |
|
|
82 |
if imagez != depth_z * last_depth: |
|
|
83 |
patch_xs = fourmodelimagearray[(imagez - depth_z):imagez, :, :, :] |
|
|
84 |
pathc_pd = Vnet3d.prediction(patch_xs) |
|
|
85 |
ys_pd_array[(imagez - depth_z):imagez, :, :] = pathc_pd |
|
|
86 |
|
|
|
87 |
ys_pd_array = np.clip(ys_pd_array, 0, 255).astype('uint8') |
|
|
88 |
all_ys_pd_array = ys_pd_array.copy() |
|
|
89 |
all_ys_pd_array[ys_pd_array != 0] = 1 |
|
|
90 |
outmask = getLargestConnectedCompont(sitk.GetImageFromArray(all_ys_pd_array)) |
|
|
91 |
ys_pd_array[outmask == 0] = 0 |
|
|
92 |
# step7 calcu test mask and predict mask dice value |
|
|
93 |
batch_ys = label.copy() |
|
|
94 |
batch_ys[label == 4] = 3 |
|
|
95 |
dice_value0 = 0 |
|
|
96 |
dice_value1 = 0 |
|
|
97 |
dice_value2 = 0 |
|
|
98 |
dice_value3 = 0 |
|
|
99 |
for num_class in range(4): |
|
|
100 |
ys_pd_array_tmp = ys_pd_array.copy() |
|
|
101 |
batch_ys_tmp = batch_ys.copy() |
|
|
102 |
ys_pd_array_tmp[ys_pd_array == num_class] = 1 |
|
|
103 |
batch_ys_tmp[label == num_class] = 1 |
|
|
104 |
if num_class == 0: |
|
|
105 |
dice_value0 = calcu_dice(ys_pd_array_tmp, batch_ys_tmp, 1) |
|
|
106 |
if num_class == 1: |
|
|
107 |
dice_value1 = calcu_dice(ys_pd_array_tmp, batch_ys_tmp, 1) |
|
|
108 |
if num_class == 2: |
|
|
109 |
dice_value2 = calcu_dice(ys_pd_array_tmp, batch_ys_tmp, 1) |
|
|
110 |
if num_class == 3: |
|
|
111 |
dice_value3 = calcu_dice(ys_pd_array_tmp, batch_ys_tmp, 1) |
|
|
112 |
print("index,dice:", (subsetindex, dice_value0, dice_value1, dice_value2, dice_value3)) |
|
|
113 |
dice_values0.append(dice_value0) |
|
|
114 |
dice_values1.append(dice_value1) |
|
|
115 |
dice_values2.append(dice_value2) |
|
|
116 |
dice_values3.append(dice_value3) |
|
|
117 |
# step8 out put predict mask |
|
|
118 |
ys_pd_array = ys_pd_array.astype('float') |
|
|
119 |
outputmask = np.zeros((imagez, height, width), np.uint8) |
|
|
120 |
outputmask[ys_pd_array == 1] = 1 |
|
|
121 |
outputmask[ys_pd_array == 2] = 2 |
|
|
122 |
outputmask[ys_pd_array == 3] = 4 |
|
|
123 |
ys_pd_itk = sitk.GetImageFromArray(outputmask) |
|
|
124 |
ys_pd_itk.SetSpacing(mask.GetSpacing()) |
|
|
125 |
ys_pd_itk.SetOrigin(mask.GetOrigin()) |
|
|
126 |
ys_pd_itk.SetDirection(mask.GetDirection()) |
|
|
127 |
out_mask_image = brats_subset_path + str(path_list[subsetindex]) + out_mask_name |
|
|
128 |
sitk.WriteImage(ys_pd_itk, out_mask_image) |
|
|
129 |
average0 = sum(dice_values0) / len(dice_values0) |
|
|
130 |
average1 = sum(dice_values1) / len(dice_values1) |
|
|
131 |
average2 = sum(dice_values2) / len(dice_values2) |
|
|
132 |
average3 = sum(dice_values3) / len(dice_values3) |
|
|
133 |
print("average dice:", (average0, average1, average2, average3)) |
|
|
134 |
|
|
|
135 |
|
|
|
136 |
inference() |