#evaluate.py
#Copyright (c) 2020 Rachel Lea Ballantyne Draelos

#MIT License

#Permission is hereby granted, free of charge, to any person obtaining a copy
#of this software and associated documentation files (the "Software"), to deal
#in the Software without restriction, including without limitation the rights
#to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
#copies of the Software, and to permit persons to whom the Software is
#furnished to do so, subject to the following conditions:

#The above copyright notice and this permission notice shall be included in all
#copies or substantial portions of the Software.

#THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
#IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
#FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
#AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
#LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
#OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
#SOFTWARE.
#Imports
import os
import copy
import time
import bisect
import shutil
import operator
import itertools
import numpy as np
import pandas as pd
import sklearn.metrics
from scipy import interp #note: scipy.interp is an alias for np.interp and is deprecated in newer SciPy versions
from itertools import cycle

import matplotlib
matplotlib.use('agg') #so that it does not attempt to display via SSH
import seaborn
import matplotlib.pyplot as plt
plt.ioff() #turn interactive plotting off

#suppress numpy warnings
import warnings
warnings.filterwarnings('ignore')

#######################
# Reporting Functions #---------------------------------------------------------
#######################
def initialize_evaluation_dfs(all_labels, num_epochs):
    """Create empty "eval_dfs_dict"
    Variables
    <all_labels>: a list of strings describing the labels in order
    <num_epochs>: int for total number of epochs"""
    if len(all_labels)==2:
        index = [all_labels[1]]
        numrows = 1
    else:
        index = all_labels
        numrows = len(all_labels)
    #Initialize empty pandas dataframe to store evaluation results across epochs
    #for accuracy, AUROC, and AP
    result_df = pd.DataFrame(data=np.zeros((numrows, num_epochs)),
                             index = index,
                             columns = ['epoch_'+str(n) for n in range(0,num_epochs)])
    #Initialize empty pandas dataframe to store evaluation results for top k
    top_k_result_df = pd.DataFrame(np.zeros((len(all_labels), num_epochs)),
                             index=[x for x in range(1,len(all_labels)+1)], #e.g. 1,...,64 for len(all_labels)=64
                             columns = ['epoch_'+str(n) for n in range(0,num_epochs)])
    
    #Make eval results dictionaries
    eval_results_valid = {'accuracy':copy.deepcopy(result_df),
                          'auroc':copy.deepcopy(result_df),
                          'avg_precision':copy.deepcopy(result_df),
                          'top_k':top_k_result_df}
    eval_results_test = copy.deepcopy(eval_results_valid)
    return eval_results_valid, eval_results_test
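
#Illustrative note (values invented, not from the original pipeline): with
#all_labels=['pneumonia','nodule','mass'] and num_epochs=2, each of
#eval_results_valid['accuracy'], ['auroc'], and ['avg_precision'] is a 3x2
#DataFrame indexed by label with columns ['epoch_0','epoch_1'], while
#eval_results_valid['top_k'] is indexed by k = 1, 2, 3.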

def save(eval_dfs_dict, results_dir, descriptor):
    """Variables
    <eval_dfs_dict> is a dict of pandas dataframes
    <descriptor> is a string"""
    for k in eval_dfs_dict.keys():
        eval_dfs_dict[k].to_csv(os.path.join(results_dir, descriptor+'_'+k+'_Table.csv'))

def save_final_summary(eval_dfs_dict, best_valid_epoch, setname, results_dir):
    """Save to overall df and print summary of best epoch."""
    #final_descriptor is e.g. '2019-11-15-awesome-model_epoch15'
    final_descriptor = results_dir.replace('results/','')+'_epoch'+str(best_valid_epoch)
    if setname=='valid': print('***Summary for',setname,results_dir,'***')
    for metricname in list(eval_dfs_dict.keys()):
        #metricnames are accuracy, auroc, avg_precision, and top_k.
        #df holds a particular metric for the particular model we just ran.
        #for accuracy, auroc, and avg_precision, df index is diseases, columns are epochs.
        #for top_k, df index is the k value (an int) and columns are epochs.
        df = eval_dfs_dict[metricname]
        #all_df tracks results of all models in one giant table.
        #all_df has index of diseases or k value, and columns which are particular models.
        all_df_path = os.path.join('results',setname+'_'+metricname+'_all.csv') #e.g. valid_accuracy_all.csv
        if os.path.isfile(all_df_path):
            all_df = pd.read_csv(all_df_path,header=0,index_col=0)
            all_df[final_descriptor] = np.nan
        else: #all_df doesn't exist yet - create it.
            all_df = pd.DataFrame(np.empty((df.shape[0],1)),
                                  index = df.index.values.tolist(),
                                  columns = [final_descriptor])
        #Print off and save results for best_valid_epoch
        if setname=='valid': print('\tEpoch',best_valid_epoch,metricname)
        for label in df.index.values:
            #print off to console
            value = df.at[label,'epoch_'+str(best_valid_epoch)]
            if setname=='valid': print('\t\t',label,':',str( round(value, 3) ))
            #save in all_df
            all_df.at[label,final_descriptor] = value
        all_df.to_csv(all_df_path,header=True,index=True)

def clean_up_output_files(best_valid_epoch, results_dir):
    """Delete output files that aren't from the best epoch"""
    #Delete all the backup parameters (they take a lot of space and you do not
    #need to have them)
    shutil.rmtree(os.path.join(results_dir,'backup'))
    #Delete all the extra output files:
    for subdir in ['heatmaps','curves','pred_probs']:
        #Clean up saved ROC and PR curves
        fullpath = os.path.join(results_dir,subdir)
        if os.path.exists(fullpath): #e.g. there may not be a heatmaps dir for a non-bottleneck model
            allfiles = os.listdir(fullpath)
            for filename in allfiles:
                if str(best_valid_epoch) not in filename:
                    os.remove(os.path.join(fullpath,filename))
    print('Output files all clean')

#########################
# Calculation Functions #-------------------------------------------------------
#########################
def evaluate_all(eval_dfs_dict, epoch, label_meanings,
                 true_labels_array, pred_probs_array):
    """Fill out the pandas dataframes in the dictionary <eval_dfs_dict>
    which is created in cnn.py. <epoch> and the label name are used to index
    into the dataframe for each metric. Metrics calculated for the provided
    arrays are: accuracy, AUROC, average precision, and top k accuracy.
    (Confusion matrix metrics and partial AUROC can be computed separately
    with compute_confusion_matrix() and compute_partial_auroc().)
    
    Variables:
    <eval_dfs_dict> is a dictionary of pandas dataframes created in cnn.py
    <epoch> is an integer indicating which epoch it is, starting from epoch 1
    <label_meanings>: list of strings describing the labels in order
    <true_labels_array>: array of true labels. examples x labels
    <pred_probs_array>: array of predicted probabilities. examples x labels"""
    #Accuracy, AUROC, and AP (iter over labels)
    for label_number in range(len(label_meanings)):
        which_label = label_meanings[label_number] #descriptive string for the label
        true_labels = true_labels_array[:,label_number]
        pred_probs = pred_probs_array[:,label_number]
        pred_labels = (pred_probs>=0.5).astype(dtype='int') #decision threshold of 0.5
        
        #Accuracy and confusion matrix (dependent on decision threshold)
        (eval_dfs_dict['accuracy']).at[which_label, 'epoch_'+str(epoch)] = compute_accuracy(true_labels, pred_labels)
        #confusion_matrix, sensitivity, specificity, ppv, npv = compute_confusion_matrix(true_labels, pred_labels)
        
        #AUROC and AP (sliding across multiple decision thresholds)
        fpr, tpr, thresholds = sklearn.metrics.roc_curve(y_true = true_labels,
                                                         y_score = pred_probs,
                                                         pos_label = 1)
        (eval_dfs_dict['auroc']).at[which_label, 'epoch_'+str(epoch)] = sklearn.metrics.auc(fpr, tpr)
        (eval_dfs_dict['avg_precision']).at[which_label, 'epoch_'+str(epoch)] = sklearn.metrics.average_precision_score(true_labels, pred_probs)
    
    #Top k eval metrics (iter over examples)
    eval_dfs_dict['top_k'] = evaluate_top_k(eval_dfs_dict['top_k'],
                                            epoch, true_labels_array, pred_probs_array)
    return eval_dfs_dict

#################
# Top K Metrics #---------------------------------------------------------------
#################
def evaluate_top_k(eval_top_k_df, epoch, true_labels_array,
                   pred_probs_array):
    """<eval_top_k_df> is a pandas dataframe with epoch number as columns and
    k values as rows, where k is an integer"""
    num_labels = true_labels_array.shape[1] #e.g. 64
    total_examples = true_labels_array.shape[0]
    vals = [0 for x in range(1,num_labels+2)] #e.g. length 65 list but the index of the last element is 64 for num_labels=64
    for example_number in range(total_examples):
        #iterate through individual examples (predictions for an individual CT)
        #rather than iterating through predicted labels
        true_labels = true_labels_array[example_number,:]
        pred_probs = pred_probs_array[example_number,:]
        for k in range(1,num_labels+1): #e.g. 1,...,64
            previous_value = vals[k]
            incremental_update = calculate_top_k_accuracy(true_labels, pred_probs, k)
            new_value = previous_value + incremental_update
            vals[k] = new_value
    #Now update the dataframe. Should reach 100% performance by the end.
    for k in range(1,num_labels+1):
        eval_top_k_df.at[k,'epoch_'+str(epoch)] = vals[k]/total_examples
    
    ##Now average over all the examples
    #eval_top_k_df.loc[:,'epoch_'+str(epoch)] = eval_top_k_df.loc[:,'epoch_'+str(epoch)] / total_examples
    return eval_top_k_df

def calculate_top_k_accuracy(true_labels, pred_probs, k):
    k = min(k, len(true_labels)) #avoid accessing array elements that don't exist
    #argpartition described here: https://stackoverflow.com/questions/6910641/how-do-i-get-indices-of-n-maximum-values-in-a-numpy-array
    #get the indices of the largest k probabilities
    ind = np.argpartition(pred_probs, -1*k)[-1*k:]
    #now figure out what percent of these top predictions were equal to 1 in the
    #true_labels.
    #Note that the denominator should not exceed the number of true labels, to
    #avoid penalizing the model inappropriately:
    denom = min(k, np.sum(true_labels))
    if denom == 0: #because np.sum(true_labels) is 0
        #super important! must return 1 to avoid dividing by 0 and producing nan
        #we don't return 0 because then the model can never get perfect performance
        #even at k=num_labels because it'll get 0 for anything that has no labels
        return 1
    else:
        return float(np.sum(true_labels[ind]))/denom
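
#Worked example (values invented for illustration):
#  true_labels = np.array([0,1,0,1]) and pred_probs = np.array([0.1,0.8,0.7,0.2])
#  k=2: the two largest probabilities sit at indices 1 and 2; true_labels[[1,2]]
#       contains one positive and denom = min(2, 2) = 2, so the result is 0.5.
#  k=4: every index is selected, np.sum(true_labels[ind]) = 2, and denom is
#       still min(4, 2) = 2, so the result is 1.0 (perfect at k = num_labels).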

######################
# Accuracy and AUROC #----------------------------------------------------------
######################
def compute_accuracy(true_labels, labels_pred):
    """Return the accuracy of the predictions <labels_pred> against the
    ground truth <true_labels>"""
    correct = (true_labels == labels_pred)
    correct_sum = correct.sum()
    return (float(correct_sum)/len(true_labels))

def compute_confusion_matrix(true_labels, labels_pred):
    """Return the confusion matrix (as a string) along with sensitivity,
    specificity, PPV, and NPV"""
    cm = sklearn.metrics.confusion_matrix(y_true=true_labels,
                                          y_pred=labels_pred)
    if cm.size < 4: #cm is too small to calculate anything
        return np.nan, np.nan, np.nan, np.nan, np.nan
    true_neg, false_pos, false_neg, true_pos = cm.ravel()
    sensitivity = float(true_pos)/(true_pos + false_neg)
    specificity = float(true_neg)/(true_neg + false_pos)
    ppv = float(true_pos)/(true_pos + false_pos)
    npv = float(true_neg)/(true_neg + false_neg)
    
    return((str(cm).replace("\n","_")), sensitivity, specificity, ppv, npv)

def compute_partial_auroc(fpr, tpr, thresh = 0.2, trapezoid = False, verbose=False):
    fpr_thresh, tpr_thresh = get_fpr_tpr_for_thresh(fpr, tpr, thresh)
    if len(fpr_thresh) < 2: #can't calculate an AUC with only 1 data point
        return np.nan
    if verbose:
        print('fpr: '+str(fpr))
        print('fpr_thresh: '+str(fpr_thresh))
        print('tpr: '+str(tpr))
        print('tpr_thresh: '+str(tpr_thresh))
    return sklearn.metrics.auc(fpr_thresh, tpr_thresh)

def get_fpr_tpr_for_thresh(fpr, tpr, thresh):
    """The <fpr> and <tpr> are already sorted according to threshold (which is
    sorted from highest to lowest, and is NOT the same as <thresh>; threshold
    is the third output of sklearn.metrics.roc_curve and is a vector of the
    thresholds used to calculate FPR and TPR). This function figures out where
    to bisect the FPR so that the remaining elements are no greater than
    <thresh>. It bisects the TPR in the same place."""
    p = (bisect.bisect_left(fpr, thresh)-1) #subtract one so that the FPR
        #of the remaining elements is NO GREATER THAN <thresh>
    return fpr[: p + 1], tpr[: p + 1]
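
#Worked example (values invented for illustration): with fpr = [0.0, 0.1, 0.3, 1.0],
#tpr = [0.0, 0.5, 0.8, 1.0], and thresh = 0.2, bisect_left returns 2, so p = 1 and
#the function returns fpr[:2] = [0.0, 0.1] and tpr[:2] = [0.0, 0.5].
#compute_partial_auroc() then integrates only this truncated region:
#sklearn.metrics.auc([0.0, 0.1], [0.0, 0.5]) = 0.025.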

######################
# Plotting Functions #----------------------------------------------------------
######################
def plot_pr_and_roc_curves(results_dir, label_meanings, true_labels, pred_probs,
                           epoch):
    """Plot and save a precision-recall curve and an ROC curve for a single
    label (the output filenames used here are an arbitrary naming choice)."""
    #Plot Precision Recall Curve
    pr_outfilepath = os.path.join(results_dir, 'PR_curve_ep'+str(epoch)+'.pdf')
    plot_pr_curve_single_class(true_labels, pred_probs, pr_outfilepath)
    
    #Plot ROC Curve
    fpr, tpr, thresholds = sklearn.metrics.roc_curve(y_true = true_labels,
                                                     y_score = pred_probs,
                                                     pos_label = 1)
    roc_outfilepath = os.path.join(results_dir, 'ROC_curve_ep'+str(epoch)+'.pdf')
    plot_roc_curve_single_class(fpr, tpr, epoch, roc_outfilepath)


def plot_roc_curve_multi_class(label_meanings, y_test, y_score,
                               outdir, setname, epoch):
    """<label_meanings>: list of strings, one for each label
    <y_test>: matrix of ground truth
    <y_score>: matrix of predicted probabilities
    <outdir>: directory to save output file
    <setname>: string e.g. 'train' 'valid' or 'test'
    <epoch>: int for epoch"""
    #Modified from https://scikit-learn.org/stable/auto_examples/model_selection/plot_roc.html
    n_classes = len(label_meanings)
    lw = 2
    
    # Compute ROC curve and ROC area for each class
    fpr = dict()
    tpr = dict()
    roc_auc = dict()
    for i in range(n_classes):
        fpr[i], tpr[i], _ = sklearn.metrics.roc_curve(y_test[:, i], y_score[:, i])
        roc_auc[i] = sklearn.metrics.auc(fpr[i], tpr[i])
    
    #make order df. (note that roc_auc is a dictionary with ints as keys
    #and AUCs as values.)
    order = pd.DataFrame(np.zeros((n_classes,1)), index = [x for x in range(n_classes)],
                         columns = ['roc_auc'])
    for i in range(n_classes):
        order.at[i,'roc_auc'] = roc_auc[i]
    order = order.sort_values(by='roc_auc',ascending=False)
    
    #Plot all ROC curves
    #Plot in order of the rainbow colors, from highest AUC to lowest AUC
    plt.figure()
    colors_list = ['palevioletred','darkorange','yellowgreen','olive','deepskyblue','royalblue','navy']
    curves_plotted = 0
    for i in order.index.values.tolist()[0:10]: #only plot the top ten so the plot is readable
        color_idx = curves_plotted%len(colors_list) #cycle through the colors list in order of colors
        color = colors_list[color_idx]
        plt.plot(fpr[i], tpr[i], color=color, lw=lw,
                 label='{:5s} (area {:0.2f})'.format(label_meanings[i], roc_auc[i]))
        curves_plotted+=1
    
    plt.plot([0, 1], [0, 1], 'k--', lw=lw)
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title(setname.lower().capitalize()+' ROC Epoch '+str(epoch))
    plt.legend(loc="lower right",prop={'size':6})
    outfilepath = os.path.join(outdir,setname+'_ROC_ep'+str(epoch)+'.pdf')
    plt.savefig(outfilepath)
    plt.close()

def plot_pr_curve_multi_class(label_meanings, y_test, y_score,
                              outdir, setname, epoch):
    """<label_meanings>: list of strings, one for each label
    <y_test>: matrix of ground truth
    <y_score>: matrix of predicted probabilities
    <outdir>: directory to save output file
    <setname>: string e.g. 'train' 'valid' or 'test'
    <epoch>: int for epoch"""
    #Modified from https://scikit-learn.org/stable/auto_examples/model_selection/plot_roc.html
    #https://stackoverflow.com/questions/29656550/how-to-plot-pr-curve-over-10-folds-of-cross-validation-in-scikit-learn
    n_classes = len(label_meanings)
    lw = 2
    
    #make order df.
    order = pd.DataFrame(np.zeros((n_classes,1)), index = [x for x in range(n_classes)],
                         columns = ['prc'])
    for i in range(n_classes):
        order.at[i,'prc'] = sklearn.metrics.average_precision_score(y_test[:,i], y_score[:,i])
    order = order.sort_values(by='prc',ascending=False)
    
    #Plot
    plt.figure()
    colors_list = ['palevioletred','darkorange','yellowgreen','olive','deepskyblue','royalblue','navy']
    curves_plotted = 0
    for i in order.index.values.tolist()[0:10]: #only plot the top ten so the plot is readable
        color_idx = curves_plotted%len(colors_list) #cycle through the colors list in order of colors
        color = colors_list[color_idx]
        average_precision = sklearn.metrics.average_precision_score(y_test[:,i], y_score[:,i])
        precision, recall, _ = sklearn.metrics.precision_recall_curve(y_test[:,i], y_score[:,i])
        plt.step(recall, precision, color=color, where='post',
                 label='{:5s} (area {:0.2f})'.format(label_meanings[i], average_precision))
        curves_plotted+=1
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.ylim([0.0, 1.05])
    plt.xlim([0.0, 1.0])
    plt.title(setname.lower().capitalize()+' PRC Epoch '+str(epoch))
    plt.legend(loc="lower right",prop={'size':6})
    outfilepath = os.path.join(outdir,setname+'_PR_ep'+str(epoch)+'.pdf')
    plt.savefig(outfilepath)
    plt.close()

def plot_pr_curve_single_class(true_labels, pred_probs, outfilepath):
    #http://scikit-learn.org/stable/auto_examples/model_selection/plot_precision_recall.html
    average_precision = sklearn.metrics.average_precision_score(true_labels, pred_probs)
    precision, recall, _ = sklearn.metrics.precision_recall_curve(true_labels, pred_probs)
    plt.step(recall, precision, color='b', alpha=0.2,
             where='post')
    plt.fill_between(recall, precision, step='post', alpha=0.2,
                     color='b')
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.ylim([0.0, 1.05])
    plt.xlim([0.0, 1.0])
    plt.title('2-class Precision-Recall curve: AP={0:0.2f}'.format(
              average_precision))
    plt.savefig(outfilepath)
    plt.close()

def plot_roc_curve_single_class(fpr, tpr, epoch, outfilepath):
    #http://scikit-learn.org/stable/auto_examples/model_selection/plot_roc.html#sphx-glr-auto-examples-model-selection-plot-roc-py
    roc_auc = sklearn.metrics.auc(fpr, tpr)
    lw = 2
    plt.plot(fpr, tpr, color='darkorange',
             lw=lw, label='ROC curve (area = %0.2f)' % roc_auc)
    plt.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver Operating Characteristic')
    plt.legend(loc="lower right")
    plt.savefig(outfilepath)
    plt.close()

def plot_learning_curves(train_loss, valid_loss, results_dir, descriptor):
    """Variables
    <train_loss> and <valid_loss> are numpy arrays with one numerical entry
    for each epoch quantifying the loss for that epoch."""
    x = np.arange(0,len(train_loss))
    plt.plot(x, train_loss, color='blue', lw=2, label='train')
    plt.plot(x, valid_loss, color='green',lw = 2, label='valid')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training and Validation Loss')
    plt.legend(loc='lower right')
    plt.savefig(os.path.join(results_dir, descriptor+'_Learning_Curves.png'))
    plt.close()
    #save numpy arrays of the losses
    np.save(os.path.join(results_dir,'train_loss.npy'),train_loss)
    np.save(os.path.join(results_dir,'valid_loss.npy'),valid_loss)

def plot_heatmap(outprefix, numeric_array, center, xticklabels, yticklabels):
    """Save a heatmap based on numeric_array"""
    seaborn.set(font_scale=0.6)
    seaplt = (seaborn.heatmap(numeric_array,
                              center=center,
                              xticklabels=xticklabels,
                              yticklabels=yticklabels)).get_figure()
    seaplt.savefig(outprefix+'.png')
    seaplt.clf()
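

#The block below is a minimal usage sketch, not part of the original training
#pipeline: it exercises the metric functions on a tiny synthetic multilabel
#problem. The label names and probability values are invented purely for
#illustration.
if __name__ == '__main__':
    demo_labels = ['atelectasis','cardiomegaly','nodule']
    demo_true = np.array([[1,0,1],
                          [0,1,0],
                          [1,1,0],
                          [0,0,1]])
    demo_probs = np.array([[0.9,0.2,0.7],
                           [0.3,0.8,0.1],
                           [0.6,0.7,0.4],
                           [0.2,0.1,0.8]])
    eval_dfs_valid, _eval_dfs_test = initialize_evaluation_dfs(demo_labels, num_epochs=3)
    eval_dfs_valid = evaluate_all(eval_dfs_valid, epoch=1, label_meanings=demo_labels,
                                  true_labels_array=demo_true, pred_probs_array=demo_probs)
    print(eval_dfs_valid['auroc'])
    print(eval_dfs_valid['top_k'])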