Switch to unified view

a b/BraTs18Challege/inference_Brats.py
1
from __future__ import print_function, division
2
import os
3
4
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
5
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
6
from tensorflow.python.client import device_lib
7
8
print(device_lib.list_local_devices())
9
10
from Vnet.model_vnet3d_multilabel import Vnet3dModuleMultiLabel
11
from Vnet.util import getLargestConnectedCompont
12
from dataprocess.utils import calcu_dice, file_name_path
13
from dataprocess.data3dprepare import normalize
14
import numpy as np
15
import SimpleITK as sitk
16
17
18
def inference():
19
    """
20
    Vnet network segmentation brats fine segmatation
21
    :return:
22
    """
23
    channel = 4
24
    numclass = 4
25
    flair_name = "_flair.nii.gz"
26
    t1_name = "_t1.nii.gz"
27
    t1ce_name = "_t1ce.nii.gz"
28
    t2_name = "_t2.nii.gz"
29
    mask_name = "_seg.nii.gz"
30
    out_mask_name = "_outseg.nii.gz"
31
    # step1 init vnet model
32
    depth_z = 48
33
    Vnet3d = Vnet3dModuleMultiLabel(240, 240, depth_z, channels=channel, numclass=numclass,
34
                                    costname=("categorical_dice",), inference=True,
35
                                    model_path="log\segmeation2mm\weighted_categorical_crossentropy\model\Vnet3d.pd-10000")
36
    brats_path = "D:\Data\\brats18\\test"
37
    # step2 get all test image path
38
    dice_values0 = []
39
    dice_values1 = []
40
    dice_values2 = []
41
    dice_values3 = []
42
    path_list = file_name_path(brats_path)
43
    # step3 get test image(4 model) and mask
44
    for subsetindex in range(len(path_list)):
45
        # step4 load test image(4 model) and mask as ndarray
46
        brats_subset_path = brats_path + "/" + str(path_list[subsetindex]) + "/"
47
        flair_image = brats_subset_path + str(path_list[subsetindex]) + flair_name
48
        t1_image = brats_subset_path + str(path_list[subsetindex]) + t1_name
49
        t1ce_image = brats_subset_path + str(path_list[subsetindex]) + t1ce_name
50
        t2_image = brats_subset_path + str(path_list[subsetindex]) + t2_name
51
        mask_image = brats_subset_path + str(path_list[subsetindex]) + mask_name
52
        flair_src = sitk.ReadImage(flair_image, sitk.sitkInt16)
53
        t1_src = sitk.ReadImage(t1_image, sitk.sitkInt16)
54
        t1ce_src = sitk.ReadImage(t1ce_image, sitk.sitkInt16)
55
        t2_src = sitk.ReadImage(t2_image, sitk.sitkInt16)
56
        mask = sitk.ReadImage(mask_image, sitk.sitkUInt8)
57
        flair_array = sitk.GetArrayFromImage(flair_src)
58
        t1_array = sitk.GetArrayFromImage(t1_src)
59
        t1ce_array = sitk.GetArrayFromImage(t1ce_src)
60
        t2_array = sitk.GetArrayFromImage(t2_src)
61
        label = sitk.GetArrayFromImage(mask)
62
        # step5 mormazalation test image(4 model) and merage to 4 channels ndarray
63
        flair_array = normalize(flair_array)
64
        t1_array = normalize(t1_array)
65
        t1ce_array = normalize(t1ce_array)
66
        t2_array = normalize(t2_array)
67
68
        imagez, height, width = np.shape(flair_array)[0], np.shape(flair_array)[1], np.shape(flair_array)[2]
69
        fourmodelimagearray = np.zeros((imagez, height, width, channel), np.float)
70
        fourmodelimagearray[:, :, :, 0] = flair_array
71
        fourmodelimagearray[:, :, :, 1] = t1_array
72
        fourmodelimagearray[:, :, :, 2] = t1ce_array
73
        fourmodelimagearray[:, :, :, 3] = t2_array
74
        ys_pd_array = np.zeros((imagez, height, width), np.uint8)
75
        # step6 predict test image(4 model)
76
        last_depth = 0
77
        for depth in range(0, imagez // depth_z, 1):
78
            patch_xs = fourmodelimagearray[depth * depth_z:(depth + 1) * depth_z, :, :, :]
79
            pathc_pd = Vnet3d.prediction(patch_xs)
80
            ys_pd_array[depth * depth_z:(depth + 1) * depth_z, :, :] = pathc_pd
81
            last_depth = depth
82
        if imagez != depth_z * last_depth:
83
            patch_xs = fourmodelimagearray[(imagez - depth_z):imagez, :, :, :]
84
            pathc_pd = Vnet3d.prediction(patch_xs)
85
            ys_pd_array[(imagez - depth_z):imagez, :, :] = pathc_pd
86
87
        ys_pd_array = np.clip(ys_pd_array, 0, 255).astype('uint8')
88
        all_ys_pd_array = ys_pd_array.copy()
89
        all_ys_pd_array[ys_pd_array != 0] = 1
90
        outmask = getLargestConnectedCompont(sitk.GetImageFromArray(all_ys_pd_array))
91
        ys_pd_array[outmask == 0] = 0
92
        # step7 calcu test mask and predict mask dice value
93
        batch_ys = label.copy()
94
        batch_ys[label == 4] = 3
95
        dice_value0 = 0
96
        dice_value1 = 0
97
        dice_value2 = 0
98
        dice_value3 = 0
99
        for num_class in range(4):
100
            ys_pd_array_tmp = ys_pd_array.copy()
101
            batch_ys_tmp = batch_ys.copy()
102
            ys_pd_array_tmp[ys_pd_array == num_class] = 1
103
            batch_ys_tmp[label == num_class] = 1
104
            if num_class == 0:
105
                dice_value0 = calcu_dice(ys_pd_array_tmp, batch_ys_tmp, 1)
106
            if num_class == 1:
107
                dice_value1 = calcu_dice(ys_pd_array_tmp, batch_ys_tmp, 1)
108
            if num_class == 2:
109
                dice_value2 = calcu_dice(ys_pd_array_tmp, batch_ys_tmp, 1)
110
            if num_class == 3:
111
                dice_value3 = calcu_dice(ys_pd_array_tmp, batch_ys_tmp, 1)
112
        print("index,dice:", (subsetindex, dice_value0, dice_value1, dice_value2, dice_value3))
113
        dice_values0.append(dice_value0)
114
        dice_values1.append(dice_value1)
115
        dice_values2.append(dice_value2)
116
        dice_values3.append(dice_value3)
117
        # step8 out put predict mask
118
        ys_pd_array = ys_pd_array.astype('float')
119
        outputmask = np.zeros((imagez, height, width), np.uint8)
120
        outputmask[ys_pd_array == 1] = 1
121
        outputmask[ys_pd_array == 2] = 2
122
        outputmask[ys_pd_array == 3] = 4
123
        ys_pd_itk = sitk.GetImageFromArray(outputmask)
124
        ys_pd_itk.SetSpacing(mask.GetSpacing())
125
        ys_pd_itk.SetOrigin(mask.GetOrigin())
126
        ys_pd_itk.SetDirection(mask.GetDirection())
127
        out_mask_image = brats_subset_path + str(path_list[subsetindex]) + out_mask_name
128
        sitk.WriteImage(ys_pd_itk, out_mask_image)
129
    average0 = sum(dice_values0) / len(dice_values0)
130
    average1 = sum(dice_values1) / len(dice_values1)
131
    average2 = sum(dice_values2) / len(dice_values2)
132
    average3 = sum(dice_values3) / len(dice_values3)
133
    print("average dice:", (average0, average1, average2, average3))
134
135
136
inference()