[96354c]: / src / compute_metric_results.py

Download this file

98 lines (71 with data), 3.6 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import csv
import sys
import os
import numpy as np
from src.config import BratsConfiguration
from src.dataset.utils import dataset, nifi_volume as nifi_utils
from src.dataset import brats_labels
from src.metrics import evaluation_metrics as eval
from tqdm import tqdm
def compute(volume_pred, volume_gt, roi_mask):
    """Score one predicted volume against its ground-truth volume.

    The confusion-matrix counts come from ``eval.get_confusion_matrix`` and
    feed every count-based metric. When both volumes contain nothing but
    zeros, dice is fixed to 1.0 and the Hausdorff distance to 0.0 without
    calling the corresponding metric functions.

    Args:
        volume_pred: predicted volume (anything ``np.unique`` accepts).
        volume_gt: ground-truth volume for the same region.
        roi_mask: mask forwarded unchanged to ``eval.get_confusion_matrix``.

    Returns:
        ``(dice, hausdorff, recall, precision, accuracy, f1, (tp, fp, tn, fn))``.
    """

    def only_zeros(volume):
        # True when the volume holds exactly one distinct value and it is 0.
        distinct = np.unique(volume)
        return len(distinct) == 1 and distinct[0] == 0

    tp, fp, tn, fn = eval.get_confusion_matrix(volume_pred, volume_gt, roi_mask)

    if only_zeros(volume_pred) and only_zeros(volume_gt):
        # Nothing to segment and nothing segmented: report a perfect match.
        print("There is no tumor for this region")
        dice_score, hausdorff_distance = 1.0, 0.0
    else:
        dice_score = eval.dice(tp, fp, fn)
        hausdorff_distance = eval.hausdorff(volume_pred, volume_gt)

    return (
        dice_score,
        hausdorff_distance,
        eval.recall(tp, fn),
        eval.precision(tp, fp),
        eval.accuracy(tp, fp, tn, fn),
        eval.fscore(tp, fp, tn, fn),
        (tp, fp, tn, fn),
    )
def compute_wt_tc_et(prediction, reference, flair):
    """Collect the metrics for the "wt", "tc" and "et" regions of one case.

    Each region is extracted from both label volumes with the matching
    ``brats_labels.get_*`` helper and scored with ``compute``. The ROI mask
    is built once from ``flair`` and shared by all three regions.

    Args:
        prediction: predicted label volume.
        reference: ground-truth label volume.
        flair: volume passed to ``dataset.create_roi_mask``.

    Returns:
        A flat list of 15 values: ``[dice, hausdorff, recall, precision, f1]``
        for "wt", then "tc", then "et". Accuracy and the confusion matrix
        returned by ``compute`` are not included.
    """
    # Region order defines the column order of the returned list.
    region_extractors = (
        ("wt", brats_labels.get_wt),
        ("tc", brats_labels.get_tc),
        ("et", brats_labels.get_et),
    )
    roi = dataset.create_roi_mask(flair)

    scores = []
    for region, extract in region_extractors:
        region_gt = extract(reference)
        region_pred = extract(prediction)

        if region == "et":
            # Only a notice: the region is still scored below.
            gt_values = np.unique(region_gt)
            if len(gt_values) == 1 and gt_values[0] == 0:
                print("there is no enchancing tumor region in the ground truth")

        dice, hausdorff, recall, precision, _accuracy, f1, _confusion = compute(
            region_pred, region_gt, roi
        )
        scores += [dice, hausdorff, recall, precision, f1]

    return scores
if __name__ == "__main__":
    # Script entry point: score every test patient that has a saved
    # prediction and write one CSV row per patient to results_test.csv.
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} <config file>")

    config = BratsConfiguration(sys.argv[1])
    # model_config and basic_config are not used below; the calls are kept
    # in case they validate the configuration file -- TODO confirm and drop.
    model_config = config.get_model_config()
    dataset_config = config.get_dataset_config()
    basic_config = config.get_basic_config()

    data_train, data_test = dataset.read_brats(dataset_config.get("train_csv"))
    data = data_test

    # newline="" lets the csv module control line endings itself, which
    # avoids blank rows between records on platforms that translate "\n".
    with open("results_test.csv", "w", newline="") as results_file:
        writer = csv.writer(results_file)
        writer.writerow(["subject_ID", "Grade", "Center", "Size",
                         "Dice WT", "HD WT", "Recall WT", "Precision WT", "F1 WT",
                         "Dice TC", "HD TC", "Recall TC", "Precision TC", "F1 TC",
                         "Dice ET", "HD ET", "Recall ET", "Precision ET", "F1 ET"
                         ])

        for patient in tqdm(data, total=len(data)):
            patient_dir = os.path.join(patient.data_path, patient.patient)
            gt_path = os.path.join(patient_dir, str(patient.seg))
            data_path = os.path.join(patient_dir, str(patient.flair))
            prediction_path = os.path.join(patient_dir, f"{patient.patient}_prediction.nii.gz")

            if not os.path.exists(prediction_path):
                print(f"{prediction_path} not found")
                continue

            volume_gt_all, _ = nifi_utils.load_nifi_volume(gt_path)
            volume_pred_all, _ = nifi_utils.load_nifi_volume(prediction_path)
            volume, _ = nifi_utils.load_nifi_volume(data_path)

            # The row is the four identifying columns followed by the 15
            # metrics. Previously the metrics list replaced the identifying
            # columns, so rows did not line up with the header.
            patient_data = [patient.patient, patient.grade, patient.center, patient.size]
            patient_data.extend(compute_wt_tc_et(volume_pred_all, volume_gt_all, volume))
            writer.writerow(patient_data)