Switch to unified view

a b/segmentation/generate_lesion_measures.py
1
#%%
2
'''
3
Copyright (c) Microsoft Corporation. All rights reserved.
4
Licensed under the MIT License.
5
'''
6
import pandas as pd 
7
import numpy as np
8
import SimpleITK as sitk 
9
import os 
10
from glob import glob
11
import sys
12
import argparse
13
config_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..")
14
sys.path.append(config_dir)
15
from config import RESULTS_FOLDER
16
from metrics.metrics import *
17
18
def get_spacing_from_niftipath(path):
19
    spacing = sitk.ReadImage(path).GetSpacing()
20
    return spacing
21
22
23
def main(args):
24
    fold = args.fold
25
    network = args.network_name
26
    inputsize = args.input_patch_size
27
    experiment_code = f"{network}_fold{fold}_randcrop{inputsize}"
28
    preddir = os.path.join(RESULTS_FOLDER, 'predictions', f'fold{fold}', network, experiment_code)
29
    predpaths = sorted(glob(os.path.join(preddir, '*.nii.gz')))
30
    gtpaths = sorted(list(pd.read_csv('./../data_split/test_filepaths.csv')['GTPATH']))
31
    ptpaths = sorted(list(pd.read_csv('./../data_split/test_filepaths.csv')['PTPATH'])) # PET image paths (ptpaths) for calculating the detection metrics using criterion3 
32
    
33
    imageids = [os.path.basename(path)[:-7] for path in gtpaths]
34
    DSC = [] 
35
    SUVmean_orig, SUVmean_pred = [], []
36
    SUVmax_orig, SUVmax_pred = [], [] 
37
    LesionCount_orig, LesionCount_pred = [], [] 
38
    TMTV_orig, TMTV_pred = [], []
39
    TLG_orig, TLG_pred = [], []
40
    Dmax_orig, Dmax_pred = [], []
41
    
42
    for i in range(len(gtpaths)):
43
        ptpath = ptpaths[i]
44
        gtpath = gtpaths[i]
45
        predpath = predpaths[i]
46
        
47
        ptarray = get_3darray_from_niftipath(ptpath)
48
        gtarray = get_3darray_from_niftipath(gtpath)
49
        predarray = get_3darray_from_niftipath(predpath)
50
        spacing = get_spacing_from_niftipath(gtpath)
51
52
        # Dice score between mask gt and pred
53
        dsc = calculate_patient_level_dice_score(gtarray, predarray)
54
        # Lesion SUVmean
55
        suvmean_orig = calculate_patient_level_lesion_suvmean_suvmax(ptarray, gtarray, marker='SUVmean')
56
        suvmean_pred = calculate_patient_level_lesion_suvmean_suvmax(ptarray, predarray, marker='SUVmean')
57
        # Lesion SUVmax
58
        suvmax_orig = calculate_patient_level_lesion_suvmean_suvmax(ptarray, gtarray, marker='SUVmax')
59
        suvmax_pred = calculate_patient_level_lesion_suvmean_suvmax(ptarray, predarray, marker='SUVmax')
60
        # Lesion Count 
61
        lesioncount_orig = calculate_patient_level_lesion_count(gtarray)
62
        lesioncount_pred = calculate_patient_level_lesion_count(predarray)
63
        # TMTV
64
        tmtv_orig = calculate_patient_level_tmtv(gtarray, spacing)
65
        tmtv_pred = calculate_patient_level_tmtv(predarray, spacing)
66
        # TLG
67
        tlg_orig = calculate_patient_level_tlg(ptarray, gtarray, spacing)
68
        tlg_pred = calculate_patient_level_tlg(ptarray, predarray, spacing)
69
        # Dmax
70
        dmax_orig = calculate_patient_level_dissemination(gtarray, spacing)
71
        dmax_pred = calculate_patient_level_dissemination(predarray, spacing)
72
        
73
        DSC.append(dsc)
74
        SUVmean_orig.append(suvmean_orig)
75
        SUVmean_pred.append(suvmean_pred)
76
        SUVmax_orig.append(suvmax_orig)
77
        SUVmax_pred.append(suvmax_pred)
78
        LesionCount_orig.append(lesioncount_orig)
79
        LesionCount_pred.append(lesioncount_pred)
80
        TMTV_orig.append(tmtv_orig)
81
        TMTV_pred.append(tmtv_pred)
82
        TLG_orig.append(tlg_orig)
83
        TLG_pred.append(tlg_pred)
84
        Dmax_orig.append(dmax_orig)
85
        Dmax_pred.append(dmax_pred)
86
        
87
        
88
        print(f"{i}: {imageids[i]}")
89
        print(f"Dice Score: {round(dsc,4)}")
90
        print(f"SUVmean: GT: {suvmean_orig}, Pred: {suvmean_pred}")
91
        print(f"SUVmax: GT: {suvmax_orig}, Pred: {suvmax_pred}")
92
        print(f"LesionCount: GT: {lesioncount_orig}, Pred: {lesioncount_pred}")
93
        print(f"TMTV: GT: {tmtv_orig} ml, Pred: {tmtv_pred} ml")
94
        print(f"TLG: GT: {tlg_orig} ml, Pred: {tlg_pred} ml")
95
        print(f"Dmax: GT: {dmax_orig} cm, Pred: {dmax_pred} cm")
96
        print("\n")
97
98
    save_lesionmeasures_dir = os.path.join(RESULTS_FOLDER, f'test_lesion_measures', 'fold'+str(fold), network, experiment_code)
99
    os.makedirs(save_lesionmeasures_dir, exist_ok=True)
100
    filepath = os.path.join(save_lesionmeasures_dir, f'testlesionmeasures.csv')
101
    
102
    data = np.column_stack(
103
            [
104
                imageids,
105
                DSC,
106
                SUVmean_orig,
107
                SUVmean_pred,
108
                SUVmax_orig,
109
                SUVmax_pred,
110
                LesionCount_orig,
111
                LesionCount_pred,
112
                TMTV_orig,
113
                TMTV_pred,
114
                TLG_orig,
115
                TLG_pred,
116
                Dmax_orig,
117
                Dmax_pred
118
            ]
119
        )
120
121
    data_df = pd.DataFrame(
122
        data=data,
123
        columns=[
124
            'PatientID',
125
            'DSC',
126
            'SUVmean_orig',
127
            'SUVmean_pred',
128
            'SUVmax_orig',
129
            'SUVmax_pred',
130
            'LesionCount_orig',
131
            'LesionCount_pred',
132
            'TMTV_orig',
133
            'TMTV_pred',
134
            'TLG_orig',
135
            'TLG_pred',
136
            'Dmax_orig',
137
            'Dmax_pred'
138
        ]
139
    )
140
    data_df.to_csv(filepath, index=False)
141
        
142
143
if __name__ == "__main__":  
144
    parser = argparse.ArgumentParser(description='Lymphoma PET/CT lesion segmentation using MONAI-PyTorch')
145
    parser.add_argument('--fold', type=int, default=0, metavar='fold',
146
                        help='validation fold (default: 0), remaining folds will be used for training')
147
    parser.add_argument('--network-name', type=str, default='unet', metavar='netname',
148
                        help='network name for training (default: unet)')
149
    parser.add_argument('--input-patch-size', type=int, default=192, metavar='inputsize',
150
                        help='size of cropped input patch for training (default: 192)')
151
    args = parser.parse_args()
152
    main(args)
153