|
a |
|
b/segmentation/calculate_test_metrics.py |
|
|
1 |
#%% |
|
|
2 |
''' |
|
|
3 |
Copyright (c) Microsoft Corporation. All rights reserved. |
|
|
4 |
Licensed under the MIT License. |
|
|
5 |
''' |
|
|
6 |
import numpy as np |
|
|
7 |
import pandas as pd |
|
|
8 |
import SimpleITK as sitk |
|
|
9 |
import os |
|
|
10 |
from glob import glob |
|
|
11 |
import sys |
|
|
12 |
import argparse |
|
|
13 |
config_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..") |
|
|
14 |
sys.path.append(config_dir) |
|
|
15 |
from config import RESULTS_FOLDER |
|
|
16 |
from metrics.metrics import ( |
|
|
17 |
get_3darray_from_niftipath, |
|
|
18 |
calculate_patient_level_dice_score, |
|
|
19 |
calculate_patient_level_false_positive_volume, |
|
|
20 |
calculate_patient_level_false_negative_volume, |
|
|
21 |
calculate_patient_level_tp_fp_fn |
|
|
22 |
) |
|
|
23 |
|
|
|
24 |
def get_spacing_from_niftipath(path): |
|
|
25 |
image = sitk.ReadImage(path) |
|
|
26 |
return image.GetSpacing() |
|
|
27 |
|
|
|
28 |
def get_column_statistics(col): |
|
|
29 |
mean = col.mean() |
|
|
30 |
std = col.std() |
|
|
31 |
median = col.median() |
|
|
32 |
quantile25 = col.quantile(q=0.25) |
|
|
33 |
quantile75 = col.quantile(q=0.75) |
|
|
34 |
return (mean, std, median, quantile25, quantile75) |
|
|
35 |
|
|
|
36 |
def get_prediction_statistics(data_df): |
|
|
37 |
dsc_stats = get_column_statistics(data_df['DSC'].astype(float)) |
|
|
38 |
fpv_stats = get_column_statistics(data_df['FPV'].astype(float)) |
|
|
39 |
fnv_stats = get_column_statistics(data_df['FNV'].astype(float)) |
|
|
40 |
|
|
|
41 |
c1_sensitivity = data_df[f'TP_C1']/(data_df[f'TP_C1'] + data_df[f'FN_C1']) |
|
|
42 |
c2_sensitivity = data_df[f'TP_C2']/(data_df[f'TP_C2'] + data_df[f'FN_C2']) |
|
|
43 |
c3_sensitivity = data_df[f'TP_C3']/(data_df[f'TP_C3'] + data_df[f'FN_C3']) |
|
|
44 |
sens_c1_stats = get_column_statistics(c1_sensitivity) |
|
|
45 |
sens_c2_stats = get_column_statistics(c2_sensitivity) |
|
|
46 |
sens_c3_stats = get_column_statistics(c3_sensitivity) |
|
|
47 |
|
|
|
48 |
fp_c1_stats = get_column_statistics(data_df['FP_M1'].astype(float)) |
|
|
49 |
fp_c2_stats = get_column_statistics(data_df['FP_M2'].astype(float)) |
|
|
50 |
fp_c3_stats = get_column_statistics(data_df['FP_M3'].astype(float)) |
|
|
51 |
|
|
|
52 |
dsc_stats = [round(d, 2) for d in dsc_stats] |
|
|
53 |
fpv_stats = [round(d, 2) for d in fpv_stats] |
|
|
54 |
fnv_stats = [round(d, 2) for d in fnv_stats] |
|
|
55 |
sens_c1_stats = [round(d, 2) for d in sens_c1_stats] |
|
|
56 |
sens_c2_stats = [round(d, 2) for d in sens_c2_stats] |
|
|
57 |
sens_c3_stats = [round(d, 2) for d in sens_c3_stats] |
|
|
58 |
fp_c1_stats = [round(d, 0) for d in fp_c1_stats] |
|
|
59 |
fp_c2_stats = [round(d, 0) for d in fp_c2_stats] |
|
|
60 |
fp_c3_stats = [round(d, 0) for d in fp_c3_stats] |
|
|
61 |
|
|
|
62 |
print(f"DSC (Mean): {dsc_stats[0]} +/- {dsc_stats[1]}") |
|
|
63 |
print(f"DSC (Median): {dsc_stats[2]} [{dsc_stats[3]}, {dsc_stats[4]}]") |
|
|
64 |
print(f"FPV (Median): {fpv_stats[2]} [{fpv_stats[3]}, {fpv_stats[4]}]") |
|
|
65 |
print(f"FNV (Median): {fnv_stats[2]} [{fnv_stats[3]}, {fnv_stats[4]}]") |
|
|
66 |
print(f"Sensitivity - Criterion1 (Median): {sens_c1_stats[2]} [{sens_c1_stats[3]}, {sens_c1_stats[4]}]") |
|
|
67 |
print(f"FP - Criterion1 (Median): {fp_c1_stats[2]} [{fp_c1_stats[3]}, {fp_c1_stats[4]}]") |
|
|
68 |
print(f"Sensitivity - Criterion2 (Median): {sens_c2_stats[2]} [{sens_c2_stats[3]}, {sens_c2_stats[4]}]") |
|
|
69 |
print(f"FP - Criterion1 (Median): {fp_c2_stats[2]} [{fp_c2_stats[3]}, {fp_c2_stats[4]}]") |
|
|
70 |
print(f"Sensitivity - Criterion3 (Median): {sens_c3_stats[2]} [{sens_c3_stats[3]}, {sens_c3_stats[4]}]") |
|
|
71 |
print(f"FP - Criterion3 (Median): {fp_c3_stats[2]} [{fp_c3_stats[3]}, {fp_c3_stats[4]}]") |
|
|
72 |
print('\n') |
|
|
73 |
|
|
|
74 |
#%% |
|
|
75 |
def main(args): |
|
|
76 |
fold = args.fold |
|
|
77 |
network = args.network_name |
|
|
78 |
inputsize = args.input_patch_size |
|
|
79 |
experiment_code = f"{network}_fold{fold}_randcrop{inputsize}" |
|
|
80 |
preddir = os.path.join(RESULTS_FOLDER, 'predictions', f'fold{fold}', network, experiment_code) |
|
|
81 |
predpaths = sorted(glob(os.path.join(preddir, '*.nii.gz'))) |
|
|
82 |
gtpaths = sorted(list(pd.read_csv('./../data_split/test_filepaths.csv')['GTPATH'])) |
|
|
83 |
ptpaths = sorted(list(pd.read_csv('./../data_split/test_filepaths.csv')['PTPATH'])) # PET image paths (ptpaths) for calculating the detection metrics using criterion3 |
|
|
84 |
|
|
|
85 |
imageids = [os.path.basename(path)[:-7] for path in gtpaths] |
|
|
86 |
TEST_DSCs, TEST_FPVs, TEST_FNVs = [], [], [] |
|
|
87 |
TEST_TP_criterion1, TEST_FP_criterion1, TEST_FN_criterion1 = [], [], [] |
|
|
88 |
TEST_TP_criterion2, TEST_FP_criterion2, TEST_FN_criterion2 = [], [], [] |
|
|
89 |
TEST_TP_criterion3, TEST_FP_criterion3, TEST_FN_criterion3 = [], [], [] |
|
|
90 |
|
|
|
91 |
|
|
|
92 |
for i in range(len(gtpaths)): |
|
|
93 |
gtpath = gtpaths[i] |
|
|
94 |
ptpath = ptpaths[i] |
|
|
95 |
predpath = predpaths[i] |
|
|
96 |
|
|
|
97 |
gtarray = get_3darray_from_niftipath(gtpath) |
|
|
98 |
ptarray = get_3darray_from_niftipath(ptpath) |
|
|
99 |
predarray = get_3darray_from_niftipath(predpath) |
|
|
100 |
spacing = get_spacing_from_niftipath(gtpath) |
|
|
101 |
|
|
|
102 |
dsc = calculate_patient_level_dice_score(gtarray, predarray) |
|
|
103 |
fpv = calculate_patient_level_false_positive_volume(gtarray, predarray, spacing) |
|
|
104 |
fnv = calculate_patient_level_false_negative_volume(gtarray, predarray, spacing) |
|
|
105 |
tp_c1, fp_c1, fn_c1 = calculate_patient_level_tp_fp_fn(gtarray, predarray, criterion='criterion1') |
|
|
106 |
tp_c2, fp_c2, fn_c2 = calculate_patient_level_tp_fp_fn(gtarray, predarray, criterion='criterion2', threshold=0.5) |
|
|
107 |
tp_c3, fp_c3, fn_c3 = calculate_patient_level_tp_fp_fn(gtarray, predarray, criterion='criterion3', ptarray=ptarray) |
|
|
108 |
|
|
|
109 |
TEST_DSCs.append(dsc) |
|
|
110 |
TEST_FPVs.append(fpv) |
|
|
111 |
TEST_FNVs.append(fnv) |
|
|
112 |
TEST_TP_criterion1.append(tp_c1) |
|
|
113 |
TEST_FP_criterion1.append(fp_c1) |
|
|
114 |
TEST_FN_criterion1.append(fn_c1) |
|
|
115 |
|
|
|
116 |
TEST_TP_criterion2.append(tp_c2) |
|
|
117 |
TEST_FP_criterion2.append(fp_c2) |
|
|
118 |
TEST_FN_criterion2.append(fn_c2) |
|
|
119 |
|
|
|
120 |
TEST_TP_criterion3.append(tp_c3) |
|
|
121 |
TEST_FP_criterion3.append(fp_c3) |
|
|
122 |
TEST_FN_criterion3.append(fn_c3) |
|
|
123 |
print(f"{imageids[i]}: DSC = {round(dsc, 4)}\nFPV = {round(fpv, 4)} ml\nFNV = {round(fnv, 4)} ml") |
|
|
124 |
|
|
|
125 |
save_testmetrics_dir = os.path.join(RESULTS_FOLDER, 'test_metrics', 'fold'+str(fold), network, experiment_code) |
|
|
126 |
os.makedirs(save_testmetrics_dir, exist_ok=True) |
|
|
127 |
save_testmetrics_fpath = os.path.join(save_testmetrics_dir, 'testmetrics.csv') |
|
|
128 |
|
|
|
129 |
data = np.column_stack( |
|
|
130 |
( |
|
|
131 |
imageids, TEST_DSCs, TEST_FPVs, TEST_FNVs, |
|
|
132 |
TEST_TP_criterion1, TEST_FP_criterion1, TEST_FN_criterion1, |
|
|
133 |
TEST_TP_criterion2, TEST_FP_criterion2, TEST_FN_criterion2, |
|
|
134 |
TEST_TP_criterion3, TEST_FP_criterion3, TEST_FN_criterion3 |
|
|
135 |
) |
|
|
136 |
) |
|
|
137 |
column_names = [ |
|
|
138 |
'PatientID', 'DSC', 'FPV', 'FNV', |
|
|
139 |
'TP_C1', 'FP_C1', 'FN_C1', |
|
|
140 |
'TP_C2', 'FP_C2', 'FN_C2', |
|
|
141 |
'TP_C3', 'FP_C3', 'FN_C3', |
|
|
142 |
] |
|
|
143 |
data_df = pd.DataFrame(data=data, columns=column_names) |
|
|
144 |
data_df.to_csv(save_testmetrics_fpath, index=False) |
|
|
145 |
|
|
|
146 |
|
|
|
147 |
|
|
|
148 |
|
|
|
149 |
if __name__ == "__main__": |
|
|
150 |
parser = argparse.ArgumentParser(description='Lymphoma PET/CT lesion segmentation using MONAI-PyTorch') |
|
|
151 |
parser.add_argument('--fold', type=int, default=0, metavar='fold', |
|
|
152 |
help='validation fold (default: 0), remaining folds will be used for training') |
|
|
153 |
parser.add_argument('--network-name', type=str, default='unet', metavar='netname', |
|
|
154 |
help='network name for training (default: unet)') |
|
|
155 |
parser.add_argument('--input-patch-size', type=int, default=192, metavar='inputsize', |
|
|
156 |
help='size of cropped input patch for training (default: 192)') |
|
|
157 |
args = parser.parse_args() |
|
|
158 |
main(args) |
|
|
159 |
|
|
|
160 |
# %% |