Diff of /utils/metrics.py [000000] .. [4cda31]

Switch to unified view

a b/utils/metrics.py
1
# Manuel A. Morales (moralesq@mit.edu)
2
# Harvard-MIT Department of Health Sciences & Technology  
3
# Athinoula A. Martinos Center for Biomedical Imaging
4
5
import numpy as np
6
import pandas as pd
7
from medpy.metric.binary import hd, dc
8
9
def get_geometric_metrics(M_gt, M_pred, voxelspacing, 
10
                          tissue_labels=[1, 2, 3], tissue_label_names=['RV','LVM','LV'], phase=0):
11
    """Calculate the Dice Similarity Coefficient and Hausdorff distance. 
12
    """
13
14
    Dice        = []
15
    Hausdorff   = []
16
    TissueClass = []
17
    for label in tissue_labels:
18
        TissueClass += [tissue_label_names[label-1]]
19
        
20
        gt_label = np.copy(M_gt)
21
        gt_label[gt_label != label] = 0
22
23
        pred_label = np.copy(M_pred)
24
        pred_label[pred_label != label] = 0
25
26
        gt_label   = np.clip(gt_label, 0, 1)
27
        pred_label = np.clip(pred_label, 0, 1)
28
29
        dice      = dc(gt_label, pred_label)
30
        hausdorff = hd(gt_label, pred_label, voxelspacing=voxelspacing)
31
        
32
        Dice.append(dice)
33
        Hausdorff.append(hausdorff)
34
        
35
    output = {'DSC':Dice,'HD':Hausdorff,'TissueClass':TissueClass,'Phase':[phase]*len(tissue_labels)} 
36
    return pd.DataFrame(output)
37
        
38
def get_volume_ml(M, voxel_spacing_mm, tissue_label=1):
39
40
    voxel_vol_cm3 = np.prod(voxel_spacing_mm) / 1000 
41
    volume_ml = (M==tissue_label).sum()*voxel_vol_cm3
42
43
    return volume_ml
44
45
def get_mass_g(M, voxel_spacing_mm, tissue_label=2, tissue_density_g_per_ml=1.05):
46
    volume_ml = get_volume_ml(M, voxel_spacing_mm, tissue_label=tissue_label)
47
    mass_g    = volume_ml * tissue_density_g_per_ml
48
    return mass_g
49
50
def get_volumes_ml_and_ef(M_ed, M_es, voxel_spacing_mm, tissue_label=1):
51
    EDV_ml = get_volume_ml(M_ed, voxel_spacing_mm, tissue_label=tissue_label) 
52
    ESV_ml = get_volume_ml(M_es, voxel_spacing_mm, tissue_label=tissue_label)
53
    EF     = (EDV_ml-ESV_ml)/EDV_ml
54
    return EDV_ml, ESV_ml, EF*100
55
56
def get_clinical_parameters_rv(M_ed, M_es, voxel_spacing_mm):
57
    RV_EDV_ml, RV_ESV_ml, RV_EF = get_volumes_ml_and_ef(M_ed, M_es, voxel_spacing_mm, tissue_label=1)
58
    return RV_EDV_ml, RV_ESV_ml, RV_EF
59
60
def get_clinical_parameters_lv(M_ed, M_es, voxel_spacing_mm):
61
    LV_EDV_ml, LV_ESV_ml, LV_EF = get_volumes_ml_and_ef(M_ed, M_es, voxel_spacing_mm, tissue_label=3)
62
    LV_mass_g = get_mass_g(M_ed, voxel_spacing_mm, tissue_label=2)
63
    return LV_EDV_ml, LV_ESV_ml, LV_EF, LV_mass_g
64
65
def get_clinical_parameters(M_ed, M_es, voxel_spacing_mm):
66
    """Generate left- and right-ventricular parameters using a mask of the myocardium. Mask values should be: 
67
68
    0 background 
69
    1 right-ventricular blood pool 
70
    2 left-ventricular myocardium 
71
    3 left-ventricular blood pool 
72
73
    Input
74
    -----
75
    M_ed             : 3D array containing binary labels for the myocardium at end-diastole. 
76
    M_es             : 3D array containing binary labels for the myocardium at end-systole. 
77
    voxel_spacing_mm : tuple containing spatial resolution of 3D volume in mm, i.e., (dx, dy, dz)
78
79
    Output
80
    ------
81
    clinical_parameters : dictionary of cardiac parameters
82
83
    Example
84
    -------
85
    import nibabel as nib
86
    from aux import metrics 
87
88
    >>> # load cine image and corresonding segmentation, each of shape (nx, ny, nz, nt)
89
    >>> V_nifti = nib.load('sample.nii.gz')
90
    >>> M_nifti = nib.load('sample_segmentation.nii.gz')
91
92
    >>> # get the 4D array of size (nx,ny,nz,nt)
93
    >>> V = V_nifti.get_fdata()
94
    >>> M = M_nifti.get_fdata()
95
96
    >>> # get the spatial resolution 
97
    >>> resolution = M_nifti.header.get_zooms()[:3]
98
    >>> print(M.shape, resolution)
99
    (256, 256, 7, 30) (1.40625, 1.40625, 8.0)
100
    >>> # Assume diastole is at t=0, systole at t=10
101
    >>> print(M[...,0].shape)
102
    (256, 256, 7)
103
    >>> params = metrics.get_clinical_parameters(M_ed=M[...,0], M_es=M[...,10], voxel_spacing_mm=resolution)
104
    >>> print(params)
105
    {'RV_EDV_ml': 69.1189453125, 'RV_ESV_ml': 17.370703125, 'RV_EF': 74.868390936141, 
106
     'LV_EDV_ml': 51.526757812499994, 'LV_ESV_ml': 25.5181640625, 'LV_EF': 50.47589806570464, 'LV_mass_g': 73.903798828125}
107
    >>> print(params['LV_EDV_ml'])
108
    51.526757812499994
109
    """
110
    #print(M_ed.shape)
111
    RV_EDV_ml, RV_ESV_ml, RV_EF = get_clinical_parameters_rv(M_ed, M_es, voxel_spacing_mm)
112
    LV_EDV_ml, LV_ESV_ml, LV_EF, LV_mass_g = get_clinical_parameters_lv(M_ed, M_es, voxel_spacing_mm)
113
    clinical_parameters = {'RV_EDV_ml':RV_EDV_ml, 'RV_ESV_ml':RV_ESV_ml, 'RV_EF':RV_EF, 
114
                           'LV_EDV_ml':LV_EDV_ml, 'LV_ESV_ml':LV_ESV_ml, 'LV_EF':LV_EF, 'LV_mass_g':LV_mass_g}
115
    return clinical_parameters
116
117
## Stats
118
119
def clinical_metrics_statistics(x, y):
120
    """Calculate correlation (corr), bias, standard deviation (std), mean absolute error between x and y measures. 
121
    
122
    Bias: The bias between the two tests is measured by the mean of the differences. 
123
    std : The standard deviation (also known as limits of agreement) between the two tests are defined by a 95% 
124
          prediction interval of a particular value of the difference.
125
    
126
    See: https://ncss-wpengine.netdna-ssl.com/wp-content/themes/ncss/pdf/Procedures/NCSS/Bland-Altman_Plot_and_Analysis.pdf
127
128
    """
129
    dk   = x-y
130
    bias = np.mean(dk)
131
    std  = np.sqrt(np.sum((dk-bias)**2)/(len(x)-1))
132
    mae  = np.mean(np.abs(dk))
133
    return bias, std, mae, x.corrwith(y)
134
135
def get_clinical_metrics_on_dataloader(loading_fn, listSIDs, ED_ids=0, ES_ids=1):
136
    """Calculate clinical metrics on data loader function `loading_fn` for subjects in `listSIDs`.
137
    Assumes end-diastole and end-systole time frame = `end_diastolic_frame_id`, `end_systolic_frame_id`.
138
    """
139
    Clinical_params_pred = pd.DataFrame({'RV_EDV_ml':[], 'RV_ESV_ml':[], 'RV_EF':[], 
140
                                         'LV_EDV_ml':[], 'LV_ESV_ml':[], 'LV_EF':[], 'LV_mass_g':[]})
141
    if type(ED_ids) == int: ED_ids = [ED_ids] * len(listSIDs)
142
    if type(ES_ids) == int: ES_ids = [ES_ids] * len(listSIDs)    
143
    for subject_id, ED_id, ES_id in zip(listSIDs, ED_ids, ES_ids):
144
        V, M_pred_ed, affine, zooms = loading_fn(subject_id, ED_id)
145
        V, M_pred_es, affine, zooms = loading_fn(subject_id, ES_id)
146
147
        clinical_params_pred = get_clinical_parameters(np.argmax(M_pred_ed,-1), 
148
                                                               np.argmax(M_pred_es,-1), 
149
                                                               voxel_spacing_mm=zooms[:3])
150
151
        Clinical_params_pred = Clinical_params_pred.append(clinical_params_pred,ignore_index=True)
152
153
    Clinical_params_pred.index = pd.Index(listSIDs, name='SubjectID') 
154
155
    return Clinical_params_pred
156
157
def compare_clinical_metrics_on_dataloader(loading_fn, listSIDs, ED_ids=0, ES_ids=1):
158
    """Calculate clinical metrics on data loader function `loading_fn` for subjects in `listSIDs`.
159
    Assumes end-diastole and end-systole time frame = `end_diastolic_frame_id`, `end_systolic_frame_id`.
160
    """
161
    Clinical_params_gt   = pd.DataFrame({'RV_EDV_ml':[], 'RV_ESV_ml':[], 'RV_EF':[], 
162
                                         'LV_EDV_ml':[], 'LV_ESV_ml':[], 'LV_EF':[], 'LV_mass_g':[]})
163
    Clinical_params_pred = Clinical_params_gt.copy()
164
    
165
    if type(ED_ids) == int: ED_ids = [ED_ids] * len(listSIDs)
166
    if type(ES_ids) == int: ES_ids = [ES_ids] * len(listSIDs)    
167
    for subject_id, ED_id, ES_id in zip(listSIDs, ED_ids, ES_ids):
168
        V, M_ed, M_pred_ed, affine, zooms = loading_fn(subject_id, ED_id)
169
        V, M_es, M_pred_es, affine, zooms = loading_fn(subject_id, ES_id)
170
        
171
        clinical_params_gt   = get_clinical_parameters(np.argmax(M_ed,-1), 
172
                                                               np.argmax(M_es,-1), 
173
                                                               voxel_spacing_mm=zooms[:3])
174
        clinical_params_pred = get_clinical_parameters(np.argmax(M_pred_ed,-1), 
175
                                                               np.argmax(M_pred_es,-1), 
176
                                                               voxel_spacing_mm=zooms[:3])
177
178
        Clinical_params_gt   = Clinical_params_gt.append(clinical_params_gt,ignore_index=True)
179
        Clinical_params_pred = Clinical_params_pred.append(clinical_params_pred,ignore_index=True)
180
181
    Clinical_params_gt.index   = pd.Index(listSIDs, name='SubjectID') 
182
    Clinical_params_pred.index = pd.Index(listSIDs, name='SubjectID') 
183
    
184
    stats_df = clinical_metrics_statistics(Clinical_params_gt,Clinical_params_pred)
185
    stats_df = pd.DataFrame(stats_df,index=['bias','std','MAE','corr']).T[['corr','bias','std','MAE']]
186
187
    return Clinical_params_gt, Clinical_params_pred, stats_df
188
189
190
def compare_geometric_metrics_on_dataloader(loading_fn, listSIDs, listTimeFrames,
191
                                    tissue_labels=[1, 2, 3], tissue_label_names=['RV','LVM','LV']):
192
    """Calculate geometric metrics on data loader function `loading_fn` for subjects in `listSIDs`.
193
    Metrics are calculated for all frames in `listTimeFrames`.
194
    """
195
    Geometric_params = pd.DataFrame({'DSC':[],'HD':[],'TissueClass':[], 'Phase':[]})    
196
    for subject_id in listSIDs:
197
        for time_frame in listTimeFrames:
198
            V, M, M_pred, affine, zooms = loading_fn(subject_id, time_frame)
199
        
200
            # GEOMETRIC METRICS
201
            geometric_metrics = get_geometric_metrics(np.argmax(M,-1), np.argmax(M_pred,-1), 
202
                                                      voxelspacing=zooms[:3], phase=time_frame,
203
                                                      tissue_labels=tissue_labels,tissue_label_names=tissue_label_names)
204
205
            Geometric_params = Geometric_params.append(geometric_metrics, ignore_index=True)
206
        
207
    Geometric_params.index = pd.Index(np.repeat(listSIDs, len(tissue_labels)*len(listTimeFrames)), name='SubjectID') 
208
    return Geometric_params
209
210
211