Switch to unified view

a b/segmentation/calculate_test_metrics.py
1
#%%
2
'''
3
Copyright (c) Microsoft Corporation. All rights reserved.
4
Licensed under the MIT License.
5
'''
6
import numpy as np 
7
import pandas as pd 
8
import SimpleITK as sitk 
9
import os 
10
from glob import glob 
11
import sys 
12
import argparse
13
config_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..")
14
sys.path.append(config_dir)
15
from config import RESULTS_FOLDER
16
from metrics.metrics import (
17
    get_3darray_from_niftipath,
18
    calculate_patient_level_dice_score,
19
    calculate_patient_level_false_positive_volume,
20
    calculate_patient_level_false_negative_volume,
21
    calculate_patient_level_tp_fp_fn
22
)
23
24
def get_spacing_from_niftipath(path):
25
    image = sitk.ReadImage(path)
26
    return image.GetSpacing()
27
28
def get_column_statistics(col):
29
    mean = col.mean()
30
    std = col.std()
31
    median = col.median()
32
    quantile25 = col.quantile(q=0.25)
33
    quantile75 = col.quantile(q=0.75)
34
    return (mean, std, median, quantile25, quantile75)
35
36
def get_prediction_statistics(data_df):
37
    dsc_stats = get_column_statistics(data_df['DSC'].astype(float))
38
    fpv_stats = get_column_statistics(data_df['FPV'].astype(float))
39
    fnv_stats = get_column_statistics(data_df['FNV'].astype(float))
40
    
41
    c1_sensitivity = data_df[f'TP_C1']/(data_df[f'TP_C1'] + data_df[f'FN_C1'])
42
    c2_sensitivity = data_df[f'TP_C2']/(data_df[f'TP_C2'] + data_df[f'FN_C2'])
43
    c3_sensitivity = data_df[f'TP_C3']/(data_df[f'TP_C3'] + data_df[f'FN_C3'])
44
    sens_c1_stats = get_column_statistics(c1_sensitivity)
45
    sens_c2_stats = get_column_statistics(c2_sensitivity)
46
    sens_c3_stats = get_column_statistics(c3_sensitivity)
47
    
48
    fp_c1_stats = get_column_statistics(data_df['FP_M1'].astype(float))
49
    fp_c2_stats = get_column_statistics(data_df['FP_M2'].astype(float))
50
    fp_c3_stats = get_column_statistics(data_df['FP_M3'].astype(float))
51
    
52
    dsc_stats = [round(d, 2) for d in dsc_stats]
53
    fpv_stats = [round(d, 2) for d in fpv_stats]
54
    fnv_stats = [round(d, 2) for d in fnv_stats]
55
    sens_c1_stats = [round(d, 2) for d in sens_c1_stats]
56
    sens_c2_stats = [round(d, 2) for d in sens_c2_stats]
57
    sens_c3_stats = [round(d, 2) for d in sens_c3_stats]
58
    fp_c1_stats = [round(d, 0) for d in fp_c1_stats]
59
    fp_c2_stats = [round(d, 0) for d in fp_c2_stats]
60
    fp_c3_stats = [round(d, 0) for d in fp_c3_stats]
61
62
    print(f"DSC (Mean): {dsc_stats[0]} +/- {dsc_stats[1]}")
63
    print(f"DSC (Median): {dsc_stats[2]} [{dsc_stats[3]}, {dsc_stats[4]}]")
64
    print(f"FPV (Median): {fpv_stats[2]} [{fpv_stats[3]}, {fpv_stats[4]}]")
65
    print(f"FNV (Median): {fnv_stats[2]} [{fnv_stats[3]}, {fnv_stats[4]}]")
66
    print(f"Sensitivity - Criterion1 (Median): {sens_c1_stats[2]} [{sens_c1_stats[3]}, {sens_c1_stats[4]}]")
67
    print(f"FP - Criterion1 (Median): {fp_c1_stats[2]} [{fp_c1_stats[3]}, {fp_c1_stats[4]}]")
68
    print(f"Sensitivity - Criterion2 (Median): {sens_c2_stats[2]} [{sens_c2_stats[3]}, {sens_c2_stats[4]}]")
69
    print(f"FP - Criterion1 (Median): {fp_c2_stats[2]} [{fp_c2_stats[3]}, {fp_c2_stats[4]}]")
70
    print(f"Sensitivity - Criterion3 (Median): {sens_c3_stats[2]} [{sens_c3_stats[3]}, {sens_c3_stats[4]}]")
71
    print(f"FP - Criterion3 (Median): {fp_c3_stats[2]} [{fp_c3_stats[3]}, {fp_c3_stats[4]}]")
72
    print('\n')
73
    
74
#%%
75
def main(args):
76
    fold = args.fold
77
    network = args.network_name
78
    inputsize = args.input_patch_size
79
    experiment_code = f"{network}_fold{fold}_randcrop{inputsize}"
80
    preddir = os.path.join(RESULTS_FOLDER, 'predictions', f'fold{fold}', network, experiment_code)
81
    predpaths = sorted(glob(os.path.join(preddir, '*.nii.gz')))
82
    gtpaths = sorted(list(pd.read_csv('./../data_split/test_filepaths.csv')['GTPATH']))
83
    ptpaths = sorted(list(pd.read_csv('./../data_split/test_filepaths.csv')['PTPATH'])) # PET image paths (ptpaths) for calculating the detection metrics using criterion3 
84
    
85
    imageids = [os.path.basename(path)[:-7] for path in gtpaths]
86
    TEST_DSCs, TEST_FPVs, TEST_FNVs = [], [], []
87
    TEST_TP_criterion1, TEST_FP_criterion1, TEST_FN_criterion1 = [], [], []
88
    TEST_TP_criterion2, TEST_FP_criterion2, TEST_FN_criterion2 = [], [], []
89
    TEST_TP_criterion3, TEST_FP_criterion3, TEST_FN_criterion3 = [], [], []
90
91
        
92
    for i in range(len(gtpaths)):
93
        gtpath = gtpaths[i]
94
        ptpath = ptpaths[i]
95
        predpath = predpaths[i]
96
97
        gtarray = get_3darray_from_niftipath(gtpath)
98
        ptarray = get_3darray_from_niftipath(ptpath)
99
        predarray = get_3darray_from_niftipath(predpath)
100
        spacing = get_spacing_from_niftipath(gtpath)
101
102
        dsc = calculate_patient_level_dice_score(gtarray, predarray)
103
        fpv = calculate_patient_level_false_positive_volume(gtarray, predarray, spacing)
104
        fnv = calculate_patient_level_false_negative_volume(gtarray, predarray, spacing)
105
        tp_c1, fp_c1, fn_c1 = calculate_patient_level_tp_fp_fn(gtarray, predarray, criterion='criterion1')
106
        tp_c2, fp_c2, fn_c2 = calculate_patient_level_tp_fp_fn(gtarray, predarray, criterion='criterion2', threshold=0.5)
107
        tp_c3, fp_c3, fn_c3 = calculate_patient_level_tp_fp_fn(gtarray, predarray, criterion='criterion3', ptarray=ptarray)
108
        
109
        TEST_DSCs.append(dsc)
110
        TEST_FPVs.append(fpv)
111
        TEST_FNVs.append(fnv)
112
        TEST_TP_criterion1.append(tp_c1)
113
        TEST_FP_criterion1.append(fp_c1)
114
        TEST_FN_criterion1.append(fn_c1)
115
        
116
        TEST_TP_criterion2.append(tp_c2)
117
        TEST_FP_criterion2.append(fp_c2)
118
        TEST_FN_criterion2.append(fn_c2)
119
        
120
        TEST_TP_criterion3.append(tp_c3)
121
        TEST_FP_criterion3.append(fp_c3)
122
        TEST_FN_criterion3.append(fn_c3)
123
        print(f"{imageids[i]}: DSC = {round(dsc, 4)}\nFPV = {round(fpv, 4)} ml\nFNV = {round(fnv, 4)} ml")
124
125
    save_testmetrics_dir = os.path.join(RESULTS_FOLDER, 'test_metrics', 'fold'+str(fold), network, experiment_code)
126
    os.makedirs(save_testmetrics_dir, exist_ok=True)
127
    save_testmetrics_fpath = os.path.join(save_testmetrics_dir, 'testmetrics.csv')
128
129
    data = np.column_stack(
130
        (
131
            imageids, TEST_DSCs, TEST_FPVs, TEST_FNVs,
132
            TEST_TP_criterion1, TEST_FP_criterion1, TEST_FN_criterion1,
133
            TEST_TP_criterion2, TEST_FP_criterion2, TEST_FN_criterion2,
134
            TEST_TP_criterion3, TEST_FP_criterion3, TEST_FN_criterion3
135
        )
136
    )
137
    column_names = [
138
        'PatientID', 'DSC', 'FPV', 'FNV',
139
        'TP_C1', 'FP_C1', 'FN_C1',
140
        'TP_C2', 'FP_C2', 'FN_C2',
141
        'TP_C3', 'FP_C3', 'FN_C3',
142
    ]
143
    data_df = pd.DataFrame(data=data, columns=column_names)
144
    data_df.to_csv(save_testmetrics_fpath, index=False)
145
    
146
    
147
148
    
149
if __name__ == "__main__":  
150
    parser = argparse.ArgumentParser(description='Lymphoma PET/CT lesion segmentation using MONAI-PyTorch')
151
    parser.add_argument('--fold', type=int, default=0, metavar='fold',
152
                        help='validation fold (default: 0), remaining folds will be used for training')
153
    parser.add_argument('--network-name', type=str, default='unet', metavar='netname',
154
                        help='network name for training (default: unet)')
155
    parser.add_argument('--input-patch-size', type=int, default=192, metavar='inputsize',
156
                        help='size of cropped input patch for training (default: 192)')
157
    args = parser.parse_args()
158
    main(args)
159
    
160
# %%