#evaluate.py
#Copyright (c) 2020 Rachel Lea Ballantyne Draelos

#MIT License

#Permission is hereby granted, free of charge, to any person obtaining a copy
#of this software and associated documentation files (the "Software"), to deal
#in the Software without restriction, including without limitation the rights
#to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
#copies of the Software, and to permit persons to whom the Software is
#furnished to do so, subject to the following conditions:

#The above copyright notice and this permission notice shall be included in all
#copies or substantial portions of the Software.

#THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
#IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
#FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
#AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
#LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
#OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
#SOFTWARE.
#Imports
import os
import copy
import time
import bisect
import shutil
import operator
import itertools
import numpy as np
import pandas as pd
import sklearn.metrics
from itertools import cycle

import matplotlib
matplotlib.use('agg') #so that it does not attempt to display via SSH
import seaborn
import matplotlib.pyplot as plt
plt.ioff() #turn interactive plotting off

#suppress numpy warnings
import warnings
warnings.filterwarnings('ignore')

#######################
# Reporting Functions #---------------------------------------------------------
#######################
def initialize_evaluation_dfs(all_labels, num_epochs):
    """Create empty "eval_dfs_dict"
    Variables
    <all_labels>: a list of strings describing the labels in order
    <num_epochs>: int for total number of epochs"""
    if len(all_labels)==2:
        index = [all_labels[1]]
        numrows = 1
    else:
        index = all_labels
        numrows = len(all_labels)
    #Initialize empty pandas dataframe to store evaluation results across epochs
    #for accuracy, AUROC, and AP
    result_df = pd.DataFrame(data=np.zeros((numrows, num_epochs)),
                            index = index,
                            columns = ['epoch_'+str(n) for n in range(0,num_epochs)])
    #Initialize empty pandas dataframe to store evaluation results for top k
    top_k_result_df = pd.DataFrame(np.zeros((len(all_labels), num_epochs)),
                                   index=[x for x in range(1,len(all_labels)+1)], #e.g. 1,...,64 for len(all_labels)=64
                                   columns = ['epoch_'+str(n) for n in range(0,num_epochs)])
    
    #Make eval results dictionaries
    eval_results_valid = {'accuracy':copy.deepcopy(result_df),
        'auroc':copy.deepcopy(result_df),
        'avg_precision':copy.deepcopy(result_df),
        'top_k':top_k_result_df}
    eval_results_test = copy.deepcopy(eval_results_valid)
    return eval_results_valid, eval_results_test

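#A minimal usage sketch (illustrative only; the label names are invented):
#    valid_dfs, test_dfs = initialize_evaluation_dfs(['labelA','labelB','labelC'], num_epochs=2)
#    valid_dfs['auroc']   #3 x 2 DataFrame: index = label names, columns = ['epoch_0','epoch_1']
#    valid_dfs['top_k']   #3 x 2 DataFrame: index = k values 1, 2, 3, columns = ['epoch_0','epoch_1']
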
def save(eval_dfs_dict, results_dir, descriptor):
    """Variables
    <eval_dfs_dict> is a dict of pandas dataframes
    <results_dir> is a string path to the results directory
    <descriptor> is a string"""
    for k in eval_dfs_dict.keys():
        eval_dfs_dict[k].to_csv(os.path.join(results_dir, descriptor+'_'+k+'_Table.csv'))

def save_final_summary(eval_dfs_dict, best_valid_epoch, setname, results_dir):
    """Save to overall df and print summary of best epoch."""
    #final_descriptor is e.g. '2019-11-15-awesome-model_epoch15'
    final_descriptor = results_dir.replace('results/','')+'_epoch'+str(best_valid_epoch)
    if setname=='valid': print('***Summary for',setname,results_dir,'***')
    for metricname in list(eval_dfs_dict.keys()):
        #metricnames are accuracy, auroc, avg_precision, and top_k.
        #df holds a particular metric for the particular model we just ran.
        #for accuracy, auroc, and avg_precision, df index is diseases, columns are epochs.
        #for top_k, df index is the k value (an int) and columns are epochs.
        df = eval_dfs_dict[metricname]
        #all_df tracks results of all models in one giant table.
        #all_df has index of diseases or k value, and columns which are particular models.
        all_df_path = os.path.join('results',setname+'_'+metricname+'_all.csv') #e.g. valid_accuracy_all.csv
        if os.path.isfile(all_df_path):
            all_df = pd.read_csv(all_df_path,header=0,index_col=0)
            all_df[final_descriptor] = np.nan
        else: #all_df doesn't exist yet - create it.
            all_df = pd.DataFrame(np.empty((df.shape[0],1)),
                                  index = df.index.values.tolist(),
                                  columns = [final_descriptor])
        #Print off and save results for best_valid_epoch
        if setname=='valid': print('\tEpoch',best_valid_epoch,metricname)
        for label in df.index.values:
            #print off to console
            value = df.at[label,'epoch_'+str(best_valid_epoch)]
            if setname=='valid': print('\t\t',label,':',str( round(value, 3) ))
            #save in all_df
            all_df.at[label,final_descriptor] = value
        all_df.to_csv(all_df_path,header=True,index=True)

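#Illustrative layout of one all_df table (the model names and values below are
#invented, not real results): rows are diseases (or k values) and each column
#is one trained model at its best epoch, e.g.
#                2019-11-15-model-a_epoch15   2019-12-01-model-b_epoch8
#    disease_1                        0.87                        0.91
#    disease_2                        0.72                        0.69
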
def clean_up_output_files(best_valid_epoch, results_dir):
    """Delete output files that aren't from the best epoch"""
    #Delete all the backup parameters (they take a lot of space and you do not
    #need to have them)
    shutil.rmtree(os.path.join(results_dir,'backup'))
    #Delete all the extra output files:
    for subdir in ['heatmaps','curves','pred_probs']:
        #Clean up saved ROC and PR curves
        fullpath = os.path.join(results_dir,subdir)
        if os.path.exists(fullpath): #e.g. there may not be a heatmaps dir for a non-bottleneck model
            allfiles = os.listdir(fullpath)
            for filename in allfiles:
                if str(best_valid_epoch) not in filename:
                    os.remove(os.path.join(fullpath,filename))
    print('Output files all clean')

#########################
# Calculation Functions #-------------------------------------------------------
#########################
def evaluate_all(eval_dfs_dict, epoch, label_meanings,
                 true_labels_array, pred_probs_array):
    """Fill out the pandas dataframes in the dictionary <eval_dfs_dict>,
    which is created in cnn.py. The metrics calculated for the provided arrays
    are accuracy, AUROC, and average precision for each label, plus top k
    accuracy for each value of k.
    
    Variables:
    <eval_dfs_dict> is a dictionary of pandas dataframes created in cnn.py
    <epoch> is an integer indicating which epoch it is; it selects the
        'epoch_<epoch>' column of each dataframe
    <label_meanings>: list of strings describing the labels in order
    <true_labels_array>: array of true labels. examples x labels
    <pred_probs_array>: array of predicted probabilities. examples x labels"""
    #Accuracy, AUROC, and AP (iter over labels)
    for label_number in range(len(label_meanings)):
        which_label = label_meanings[label_number] #descriptive string for the label
        true_labels = true_labels_array[:,label_number]
        pred_probs = pred_probs_array[:,label_number]
        pred_labels = (pred_probs>=0.5).astype(dtype='int') #decision threshold of 0.5
        
        #Accuracy and confusion matrix (dependent on decision threshold)
        (eval_dfs_dict['accuracy']).at[which_label, 'epoch_'+str(epoch)] = compute_accuracy(true_labels, pred_labels)
        #confusion_matrix, sensitivity, specificity, ppv, npv = compute_confusion_matrix(true_labels, pred_labels)
        
        #AUROC and AP (sliding across multiple decision thresholds)
        fpr, tpr, thresholds = sklearn.metrics.roc_curve(y_true = true_labels,
                                         y_score = pred_probs,
                                         pos_label = 1)
        (eval_dfs_dict['auroc']).at[which_label, 'epoch_'+str(epoch)] = sklearn.metrics.auc(fpr, tpr)
        (eval_dfs_dict['avg_precision']).at[which_label, 'epoch_'+str(epoch)] = sklearn.metrics.average_precision_score(true_labels, pred_probs)
    
    #Top k eval metrics (iter over examples)
    eval_dfs_dict['top_k'] = evaluate_top_k(eval_dfs_dict['top_k'],
                                    epoch, true_labels_array, pred_probs_array)
    return eval_dfs_dict
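#A minimal usage sketch (illustrative only; a runnable version appears in the
#__main__ block at the end of this file): with 4 examples and 3 labels,
#<true_labels_array> and <pred_probs_array> are both 4 x 3 arrays, and one call
#per epoch fills in the 'epoch_<epoch>' column of each dataframe in eval_dfs_dict.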

#################
# Top K Metrics #---------------------------------------------------------------
#################
def evaluate_top_k(eval_top_k_df, epoch, true_labels_array,
                   pred_probs_array):
    """<eval_top_k_df> is a pandas dataframe with epoch number as columns and
        k values as rows, where k is an integer"""
    num_labels = true_labels_array.shape[1] #e.g. 64
    total_examples = true_labels_array.shape[0]
    vals = [0 for x in range(1,num_labels+2)] #e.g. length 65 list but the index of the last element is 64 for num_labels=64
    for example_number in range(total_examples):
        #iterate through individual examples (predictions for an individual CT)
        #rather than iterating through predicted labels
        true_labels = true_labels_array[example_number,:]
        pred_probs = pred_probs_array[example_number,:]
        for k in range(1,num_labels+1): #e.g. 1,...,64
            previous_value = vals[k]
            incremental_update = calculate_top_k_accuracy(true_labels, pred_probs, k)
            new_value = previous_value + incremental_update
            vals[k] = new_value
    #Now update the dataframe. Should reach 100% performance by the end.
    for k in range(1,num_labels+1):
        eval_top_k_df.at[k,'epoch_'+str(epoch)] = vals[k]/total_examples
    
    ##Now average over all the examples
    #eval_top_k_df.loc[:,'epoch_'+str(epoch)] = eval_top_k_df.loc[:,'epoch_'+str(epoch)] / total_examples
    return eval_top_k_df

def calculate_top_k_accuracy(true_labels, pred_probs, k):
    k = min(k, len(true_labels)) #avoid accessing array elements that don't exist
    #argpartition described here: https://stackoverflow.com/questions/6910641/how-do-i-get-indices-of-n-maximum-values-in-a-numpy-array
    #get the indices of the largest k probabilities
    ind = np.argpartition(pred_probs, -1*k)[-1*k:]
    #now figure out what percent of these top predictions were equal to 1 in the
    #true_labels.
    #Note that the denominator should not exceed the number of true labels, to
    #avoid penalizing the model inappropriately:
    denom = min(k, np.sum(true_labels))
    if denom == 0: #because np.sum(true_labels) is 0
        #super important! must return 1 to avoid dividing by 0 and producing nan
        #we don't return 0 because then the model can never get perfect performance
        #even at k=num_labels because it'll get 0 for anything that has no labels
        return 1
    else:
        return float(np.sum(true_labels[ind]))/denom

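#A worked example for calculate_top_k_accuracy (the values are invented, for
#illustration only):
#    true_labels = np.array([1, 0, 1, 0])
#    pred_probs  = np.array([0.9, 0.8, 0.3, 0.1])
#    calculate_top_k_accuracy(true_labels, pred_probs, k=2)
#returns 0.5: the top-2 predictions are indices 0 and 1, only one of which is a
#true label, and the denominator is min(k=2, number of positives=2) = 2.
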
######################
# Accuracy and AUROC #----------------------------------------------------------
######################
def compute_accuracy(true_labels, labels_pred):
    """Return the accuracy of the model on the dataset"""
    correct = (true_labels == labels_pred)
    correct_sum = correct.sum()
    return (float(correct_sum)/len(true_labels))

def compute_confusion_matrix(true_labels, labels_pred):
    """Return the confusion matrix (as a string) along with the sensitivity,
    specificity, PPV, and NPV"""
    cm = sklearn.metrics.confusion_matrix(y_true=true_labels,
                          y_pred=labels_pred)
    if cm.size < 4: #cm is too small to calculate anything
        return np.nan, np.nan, np.nan, np.nan, np.nan
    true_neg, false_pos, false_neg, true_pos = cm.ravel()
    sensitivity = float(true_pos)/(true_pos + false_neg)
    specificity = float(true_neg)/(true_neg + false_pos)
    ppv = float(true_pos)/(true_pos + false_pos)
    npv = float(true_neg)/(true_neg + false_neg)
    
    return((str(cm).replace("\n","_")), sensitivity, specificity, ppv, npv)

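#A worked example for compute_confusion_matrix (invented values, for
#illustration only):
#    true_labels = np.array([0, 0, 1, 1])
#    labels_pred = np.array([0, 1, 1, 1])
#gives true_neg=1, false_pos=1, false_neg=0, true_pos=2, and therefore
#sensitivity=1.0, specificity=0.5, ppv=0.67, npv=1.0.
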
def compute_partial_auroc(fpr, tpr, thresh = 0.2, trapezoid = False, verbose=False):
    fpr_thresh, tpr_thresh = get_fpr_tpr_for_thresh(fpr, tpr, thresh)
    if len(fpr_thresh) < 2: #can't calculate an AUC with only 1 data point
        return np.nan
    if verbose:
        print('fpr: '+str(fpr))
        print('fpr_thresh: '+str(fpr_thresh))
        print('tpr: '+str(tpr))
        print('tpr_thresh: '+str(tpr_thresh))
    return sklearn.metrics.auc(fpr_thresh, tpr_thresh)

def get_fpr_tpr_for_thresh(fpr, tpr, thresh):
    """The <fpr> and <tpr> are already sorted according to threshold (which is
    sorted from highest to lowest, and is NOT the same as <thresh>; threshold
    is the third output of sklearn.metrics.roc_curve and is a vector of the
    thresholds used to calculate FPR and TPR). This function figures out where
    to bisect the FPR so that the remaining elements are no greater than
    <thresh>. It bisects the TPR in the same place."""
    p = (bisect.bisect_left(fpr, thresh)-1) #subtract one so that the FPR
    #of the remaining elements is NO GREATER THAN <thresh>
    return fpr[: p + 1], tpr[: p + 1]

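#A worked example for compute_partial_auroc (invented values, for illustration
#only):
#    fpr = np.array([0.0, 0.1, 0.3, 1.0])
#    tpr = np.array([0.0, 0.6, 0.8, 1.0])
#    compute_partial_auroc(fpr, tpr, thresh=0.2)
#keeps only the points with FPR below 0.2 (here [0.0, 0.1]) and returns the
#trapezoidal area under that truncated curve, 0.03.
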
######################
# Plotting Functions #----------------------------------------------------------
######################
def plot_pr_and_roc_curves(results_dir, label_meanings, true_labels, pred_probs,
                           epoch):
    """Plot and save single-class PR and ROC curves for one label under
    <results_dir>."""
    #Plot Precision Recall Curve
    pr_outfilepath = os.path.join(results_dir, 'PR_ep'+str(epoch)+'.pdf') #assumed file name
    plot_pr_curve_single_class(true_labels, pred_probs, pr_outfilepath)
    
    #Plot ROC Curve
    fpr, tpr, thresholds = sklearn.metrics.roc_curve(y_true = true_labels,
                                         y_score = pred_probs,
                                         pos_label = 1)
    roc_outfilepath = os.path.join(results_dir, 'ROC_ep'+str(epoch)+'.pdf') #assumed file name
    plot_roc_curve_single_class(fpr, tpr, epoch, roc_outfilepath)


def plot_roc_curve_multi_class(label_meanings, y_test, y_score,
                               outdir, setname, epoch):
    """<label_meanings>: list of strings, one for each label
    <y_test>: matrix of ground truth
    <y_score>: matrix of predicted probabilities
    <outdir>: directory to save output file
    <setname>: string e.g. 'train' 'valid' or 'test'
    <epoch>: int for epoch"""
    #Modified from https://scikit-learn.org/stable/auto_examples/model_selection/plot_roc.html
    n_classes = len(label_meanings)
    lw = 2
    
    # Compute ROC curve and ROC area for each class
    fpr = dict()
    tpr = dict()
    roc_auc = dict()
    for i in range(n_classes):
        fpr[i], tpr[i], _ = sklearn.metrics.roc_curve(y_test[:, i], y_score[:, i])
        roc_auc[i] = sklearn.metrics.auc(fpr[i], tpr[i])
    
    #make order df (note that roc_auc is a dictionary with ints as keys
    #and AUCs as values)
    order = pd.DataFrame(np.zeros((n_classes,1)), index = [x for x in range(n_classes)],
                         columns = ['roc_auc'])
    for i in range(n_classes):
        order.at[i,'roc_auc'] = roc_auc[i]
    order = order.sort_values(by='roc_auc',ascending=False)
    
    #Plot all ROC curves
    #Plot from highest AUC to lowest AUC, cycling through the colors list
    plt.figure()
    colors_list = ['palevioletred','darkorange','yellowgreen','olive','deepskyblue','royalblue','navy']
    curves_plotted = 0
    for i in order.index.values.tolist()[0:10]: #only plot the top ten so the plot is readable
        color_idx = curves_plotted%len(colors_list) #cycle through the colors list in order of colors
        color = colors_list[color_idx]
        plt.plot(fpr[i], tpr[i], color=color, lw=lw,
                 label='{:5s} (area {:0.2f})'.format(label_meanings[i], roc_auc[i]))
        curves_plotted+=1
    
    plt.plot([0, 1], [0, 1], 'k--', lw=lw)
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title(setname.lower().capitalize()+' ROC Epoch '+str(epoch))
    plt.legend(loc="lower right",prop={'size':6})
    outfilepath = os.path.join(outdir,setname+'_ROC_ep'+str(epoch)+'.pdf')
    plt.savefig(outfilepath)
    plt.close()

def plot_pr_curve_multi_class(label_meanings, y_test, y_score,
                              outdir, setname, epoch):
    """<label_meanings>: list of strings, one for each label
    <y_test>: matrix of ground truth
    <y_score>: matrix of predicted probabilities
    <outdir>: directory to save output file
    <setname>: string e.g. 'train' 'valid' or 'test'
    <epoch>: int for epoch"""
    #Modified from https://scikit-learn.org/stable/auto_examples/model_selection/plot_roc.html
    #https://stackoverflow.com/questions/29656550/how-to-plot-pr-curve-over-10-folds-of-cross-validation-in-scikit-learn
    n_classes = len(label_meanings)
    lw = 2
    
    #make order df.
    order = pd.DataFrame(np.zeros((n_classes,1)), index = [x for x in range(n_classes)],
                         columns = ['prc'])
    for i in range(n_classes):
        order.at[i,'prc'] = sklearn.metrics.average_precision_score(y_test[:,i], y_score[:,i])
    order = order.sort_values(by='prc',ascending=False)
    
    #Plot
    plt.figure()
    colors_list = ['palevioletred','darkorange','yellowgreen','olive','deepskyblue','royalblue','navy']
    curves_plotted = 0
    for i in order.index.values.tolist()[0:10]: #only plot the top ten so the plot is readable
        color_idx = curves_plotted%len(colors_list) #cycle through the colors list in order of colors
        color = colors_list[color_idx]
        average_precision = sklearn.metrics.average_precision_score(y_test[:,i], y_score[:,i])
        precision, recall, _ = sklearn.metrics.precision_recall_curve(y_test[:,i], y_score[:,i])
        plt.step(recall, precision, color=color, where='post',
                 label='{:5s} (area {:0.2f})'.format(label_meanings[i], average_precision))
        curves_plotted+=1
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.ylim([0.0, 1.05])
    plt.xlim([0.0, 1.0])
    plt.title(setname.lower().capitalize()+' PRC Epoch '+str(epoch))
    plt.legend(loc="lower right",prop={'size':6})
    outfilepath = os.path.join(outdir,setname+'_PR_ep'+str(epoch)+'.pdf')
    plt.savefig(outfilepath)
    plt.close()

def plot_pr_curve_single_class(true_labels, pred_probs, outfilepath):
    #http://scikit-learn.org/stable/auto_examples/model_selection/plot_precision_recall.html
    average_precision = sklearn.metrics.average_precision_score(true_labels, pred_probs)
    precision, recall, _ = sklearn.metrics.precision_recall_curve(true_labels, pred_probs)
    plt.step(recall, precision, color='b', alpha=0.2,
             where='post')
    plt.fill_between(recall, precision, step='post', alpha=0.2,
                     color='b')
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.ylim([0.0, 1.05])
    plt.xlim([0.0, 1.0])
    plt.title('2-class Precision-Recall curve: AP={0:0.2f}'.format(
              average_precision))
    plt.savefig(outfilepath)
    plt.close()

def plot_roc_curve_single_class(fpr, tpr, epoch, outfilepath):
    #http://scikit-learn.org/stable/auto_examples/model_selection/plot_roc.html#sphx-glr-auto-examples-model-selection-plot-roc-py
    roc_auc = sklearn.metrics.auc(fpr, tpr)
    lw = 2
    plt.plot(fpr, tpr, color='darkorange',
             lw=lw, label='ROC curve (area = %0.2f)' % roc_auc)
    plt.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver Operating Characteristic')
    plt.legend(loc="lower right")
    plt.savefig(outfilepath)
    plt.close()

def plot_learning_curves(train_loss, valid_loss, results_dir, descriptor):
    """Variables
    <train_loss> and <valid_loss> are numpy arrays with one numerical entry
    for each epoch quantifying the loss for that epoch."""
    x = np.arange(0,len(train_loss))
    plt.plot(x, train_loss, color='blue', lw=2, label='train')
    plt.plot(x, valid_loss, color='green',lw = 2, label='valid')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training and Validation Loss')
    plt.legend(loc='lower right')
    plt.savefig(os.path.join(results_dir, descriptor+'_Learning_Curves.png'))
    plt.close()
    #save numpy arrays of the losses
    np.save(os.path.join(results_dir,'train_loss.npy'),train_loss)
    np.save(os.path.join(results_dir,'valid_loss.npy'),valid_loss)

def plot_heatmap(outprefix, numeric_array, center, xticklabels, yticklabels):
    """Save a heatmap based on numeric_array"""
    seaborn.set(font_scale=0.6)
    seaplt = (seaborn.heatmap(numeric_array,
                           center=center,
                           xticklabels=xticklabels,
                           yticklabels=yticklabels)).get_figure()
    seaplt.savefig(outprefix+'.png')
    seaplt.clf()
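
#The following block is a minimal, self-contained smoke test sketch (added for
#illustration; it is not part of the training pipeline, which calls these
#functions from cnn.py, and the label names and arrays below are invented).
#It exercises initialize_evaluation_dfs, evaluate_all, and
#plot_roc_curve_multi_class on synthetic data.
if __name__ == '__main__':
    import tempfile
    label_meanings = ['labelA', 'labelB', 'labelC']
    true_labels_array = np.array([[1, 0, 1],
                                  [0, 1, 0],
                                  [1, 1, 0],
                                  [0, 0, 1]])
    pred_probs_array = np.array([[0.9, 0.2, 0.7],
                                 [0.1, 0.8, 0.4],
                                 [0.6, 0.7, 0.2],
                                 [0.3, 0.1, 0.9]])
    #Build empty metric dataframes for a single epoch and fill them in
    valid_dfs, _ = initialize_evaluation_dfs(label_meanings, num_epochs=1)
    valid_dfs = evaluate_all(valid_dfs, 0, label_meanings,
                             true_labels_array, pred_probs_array)
    print(valid_dfs['auroc'])
    print(valid_dfs['top_k'])
    #Write a multi-class ROC plot to a temporary directory
    outdir = tempfile.mkdtemp()
    plot_roc_curve_multi_class(label_meanings, true_labels_array,
                               pred_probs_array, outdir, 'valid', 0)
    print('Wrote', os.path.join(outdir, 'valid_ROC_ep0.pdf'))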