supplementary_files/performance_evaluation.py

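"""Evaluate predicted feature-feature associations against a ground truth.

Reads a ground truth associations file and one or more predicted associations
files, then saves, per algorithm, a confusion matrix and an association-ranking
plot, plus shared precision-recall, TP-vs-FP and threshold-vs-recall curves and
Venn/UpSet diagrams comparing the algorithms.
"""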

import argparse
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import precision_recall_curve, average_precision_score
from matplotlib_venn import venn2, venn3

##################################### Functions #############################################


def plot_confusion_matrix(cm, target_names, cmap=None, normalize=False):
    """Function that plots the confusion matrix given cm. Mattias Ohlsson's code extended."""

    import itertools

    accuracy = np.trace(cm) / float(np.sum(cm))
    misclass = 1 - accuracy

    if cmap is None:
        cmap = plt.get_cmap("Blues")

    fig = plt.figure(figsize=(4, 3))
    plt.imshow(cm, interpolation="nearest", cmap=cmap)
    plt.colorbar()

    if target_names is not None:
        tick_marks = np.arange(len(target_names))
        plt.xticks(tick_marks, target_names, rotation=0, fontsize=12)
        plt.yticks(tick_marks, target_names, fontsize=12)

    if normalize:
        cm = cm.astype("float") / cm.sum(axis=1)[:, np.newaxis]

    # Annotate each cell, switching to white text on dark backgrounds
    thresh = cm.max() / 1.5 if normalize else cm.max() / 2
    cell_fmt = "{:0.4f}" if normalize else "{:,}"
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(
            j,
            i,
            cell_fmt.format(cm[i, j]),
            horizontalalignment="center",
            color="white" if cm[i, j] > thresh else "black",
            fontsize=14,
        )

    plt.tight_layout()
    plt.ylabel("True label", fontsize=14)
    plt.xlabel(
        "Predicted label\naccuracy={:0.4f}; misclass={:0.4f}".format(
            accuracy, misclass
        ),
        fontsize=14,
    )

    return fig

def classify_associations(target_file, assoc_tuples):
    self_assoc = 0  # Self associations
    found_assoc_dict = {}
    false_assoc_dict = {}
    tp_fp = np.array([[0, 0]])  # Cumulative [false positive, true positive] counts
    with open(target_file, "r") as f:
        for line in f:
            if line[0] != "f":  # Skip the header line
                splitline = line.strip().split("\t")
                feat_a = splitline[2]
                feat_b = splitline[3]
                score = abs(float(splitline[5]))
                if feat_a == feat_b:  # Self associations will not be counted
                    self_assoc += 1
                elif (feat_a, feat_b) in assoc_tuples:
                    found_assoc_dict[(feat_a, feat_b)] = score
                    # Only count it if we had not found the reverse direction yet
                    if (feat_b, feat_a) not in found_assoc_dict:
                        tp_fp = np.vstack((tp_fp, tp_fp[-1] + [0, 1]))
                else:
                    false_assoc_dict[(feat_a, feat_b)] = score
                    if (feat_b, feat_a) not in false_assoc_dict:
                        tp_fp = np.vstack((tp_fp, tp_fp[-1] + [1, 0]))

    # Remove duplicated associations, keeping the direction encountered first
    # (the strongest one when the file is sorted by score):
    for i, j in list(found_assoc_dict.keys()):
        if (i, j) in found_assoc_dict and (j, i) in found_assoc_dict:
            del found_assoc_dict[(j, i)]

    for i, j in list(false_assoc_dict.keys()):
        if (i, j) in false_assoc_dict and (j, i) in false_assoc_dict:
            del false_assoc_dict[(j, i)]

    return self_assoc, found_assoc_dict, false_assoc_dict, tp_fp

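# Note on the expected target-file format (an assumption inferred from the
# column indices used above): tab-separated rows with a header line starting
# with "f"; columns 3 and 4 hold the feature names and column 6 the association
# score. The remaining columns are not used here.
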
def create_confusion_matrix(n_feat, associations, real_assoc, false_assoc):
    cm = np.empty((2, 2))
    # TN: only counting the upper half of the matrix (non-doubled associations);
    # the diagonal (self associations) is discarded
    cm[0, 0] = (n_feat * n_feat - n_feat) / 2 - (associations + false_assoc)
    cm[0, 1] = false_assoc  # False positives
    cm[1, 0] = associations - real_assoc  # False negatives
    cm[1, 1] = real_assoc  # True positives

    return cm

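# Worked example with hypothetical numbers: for n_feat = 100 there are
# (100 * 100 - 100) / 2 = 4950 unique feature pairs; with 50 ground truth
# associations, 40 of them found and 10 false predictions, this gives
# TN = 4890, FP = 10, FN = 10, TP = 40.
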
def get_precision_recall(found_assoc_dict, false_assoc_dict, associations):
    y_true = []
    y_pred = []

    # True positives: ground truth associations found, with their scores
    for score in found_assoc_dict.values():
        y_true.append(1)
        y_pred.append(score)
    # False positives: predicted associations absent from the ground truth
    for score in false_assoc_dict.values():
        y_true.append(0)
        y_pred.append(score)
    # False negatives: missed ground truth associations, given a score of 0
    for _ in range(associations - len(found_assoc_dict)):
        y_true.append(1)
        y_pred.append(0)

    precision, recall, thr = precision_recall_curve(
        y_true, y_pred
    )  # thr will tell us score values
    avg_prec = average_precision_score(y_true, y_pred)

    return precision, recall, thr, avg_prec

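# Toy example (hypothetical values) of how the inputs are assembled: two true
# associations found with scores 0.9 and 0.4, one false association with score
# 0.7 and one missed association give y_true = [1, 1, 0, 1] and
# y_pred = [0.9, 0.4, 0.7, 0].
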
def plot_precision_recall(precision, recall, avg_prec, label, ax):
    ax.scatter(
        recall,
        precision,
        lw=0,
        marker=".",
        s=5,
        edgecolors="none",
        label=f"{label} - APS:{avg_prec:.2f}",
    )
    ax.legend()
    return ax

def plot_thr_recall(thr, recall, label, ax):
    # thr has one element fewer than recall, hence recall[:-1]
    ax.scatter(recall[:-1], thr, lw=0, marker=".", s=5, edgecolors="none", label=label)
    ax.legend()
    return ax

def plot_TP_vs_FP(tp_fp, label, ax):
    ax.scatter(tp_fp[:, 0], tp_fp[:, 1], s=2, label=label, edgecolors="none")
    ax.legend()
    return ax

def plot_filling_order(order_list, last_rank=None):
    if last_rank is None:
        last_rank = len(order_list)
    fig = plt.figure()
    # order_list holds the zero-based ground truth ranking of each correct
    # prediction, so rows are indexed directly (no off-by-one shift)
    order_img = np.zeros((np.max(order_list) + 1, len(order_list)))
    for i, element in enumerate(order_list):
        order_img[element, i:] = 1

    plt.imshow(order_img[:last_rank, :], cmap="binary")
    plt.xlabel("Correct prediction number")
    plt.ylabel("Association ranking")
    plt.plot(np.arange(last_rank), np.arange(last_rank))
    return fig

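# For reference: a predictor that recovers associations in ground truth ranking
# order fills the image exactly along the plotted diagonal line.
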
def plot_effect_size_matching(
    assoc_tuples_dict, found_assoc_dict, label, ALGORITHM, ax
):
    ground_truth_effects = [
        assoc_tuples_dict[key] for key in list(found_assoc_dict.keys())
    ]
    predicted_effects = np.array(list(found_assoc_dict.values()))

    if ALGORITHM == "ttest":
        # Eq 15 on https://doi.org/10.1146/annurev-statistics-031017-100307
        # -1 is a dummy value for p == 0, replaced below by the max likelihood
        predicted_effects = np.array(
            [-np.log10(p) if p != 0 else -1.0 for p in predicted_effects]
        )
        predicted_effects[predicted_effects == -1] = np.max(predicted_effects)

    # Min-max standardization of the predicted effects to [0, 1]
    eff_max, eff_min = np.max(predicted_effects), np.min(predicted_effects)
    standardized_pred_effects = (predicted_effects - eff_min) / (eff_max - eff_min)
    ax.scatter(
        ground_truth_effects,
        standardized_pred_effects,
        s=12,
        edgecolors="none",
        label=label,
    )
    ax.legend()
    return ax

def plot_venn_diagram(venn, ax, mode="all"):
    sets = [set(venn[key][mode]) for key in venn.keys()]
    labels = tuple(venn.keys())

    if len(venn) == 2:
        venn2(sets, labels, ax=ax)
    elif len(venn) == 3:
        venn3(sets, labels, ax=ax)
    else:
        raise ValueError("Unsupported number of input files.")
    return ax

def plot_upsetplot(venn, assoc_tuples):
    from upsetplot import UpSet
    import pandas as pd
    from matplotlib import cm

    # Union of all associations predicted by any algorithm
    all_assoc = set(
        association
        for ALGORITHM in venn.keys()
        for association in venn[ALGORITHM]["all"]
    )
    columns = ["ground truth"]
    columns.extend(list(venn.keys()))

    # One row per association: a TP/FP flag plus a 0/1 membership column per algorithm
    df = {}
    for association in all_assoc:
        df[association] = []

        if association in assoc_tuples:
            df[association].append("TP")
        else:
            df[association].append("FP")

        for ALGORITHM in venn.keys():
            if association in venn[ALGORITHM]["all"]:
                df[association].append(1)
            else:
                df[association].append(0)

    df = pd.DataFrame.from_dict(df, orient="index", columns=columns)
    df = df.set_index([pd.Index(df[ALGORITHM] == 1) for ALGORITHM in venn.keys()])
    upset = UpSet(df, intersection_plot_elements=0, show_counts=True)

    upset.add_stacked_bars(
        by="ground truth",
        colors=cm.Pastel1,
        title="Count by ground truth value",
        elements=10,
    )

    return upset


###################################### Main code ################################################

parser = argparse.ArgumentParser(
    description="Read two files with ground truth associations and predicted associations."
)
parser.add_argument(
    "-p",
    "--perturbed",
    metavar="pert",
    type=str,
    required=True,
    help="perturbed feature names",
)
parser.add_argument(
    "-n",
    "--features",
    metavar="feat",
    type=int,
    required=True,
    help="total number of features",
)
parser.add_argument(
    "-r",
    "--reference",
    metavar="ref",
    type=str,
    required=True,
    help="path to the ground truth associations file",
)
parser.add_argument(
    "-t",
    "--targets",
    metavar="tar",
    type=str,
    required=True,
    nargs="+",
    help="paths to the predicted associations files",
)
args = parser.parse_args()

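# Example invocation (hypothetical file names; target files must follow the
# <a>_<b>_<c>_<ALGORITHM>.<ext> naming pattern parsed in the loop below):
#
#   python performance_evaluation.py -p pert_feat -n 1000 \
#       -r ground_truth_changes.tsv \
#       -t results_sim_latent_MOVE.tsv results_sim_latent_ttest.tsv
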

# Defining main performance evaluation figures:
fig_0, ax_0 = plt.subplots(figsize=(7, 7))  # Effect size matching
fig_1, ax_1 = plt.subplots(figsize=(7, 7))  # Precision-recall
fig_2, ax_2 = plt.subplots()  # TP vs FP
fig_3, ax_3 = plt.subplots()  # Threshold vs recall

assoc_tuples_dict = {}

# Reading the file with the ground truth changes:
with open(args.reference, "r") as f:
    for line in f:
        if line[0] != "f" and line[0] != "n":  # Skip header lines
            splitline = line.strip().split("\t")
            feat_a = splitline[2]
            feat_b = splitline[3]
            assoc_strength = abs(float(splitline[4]))
            # Only associations involving a perturbed feature can be detected
            if args.perturbed in feat_a or args.perturbed in feat_b:
                # Store both directions of each association
                assoc_tuples_dict[(feat_a, feat_b)] = assoc_strength
                assoc_tuples_dict[(feat_b, feat_a)] = assoc_strength

associations = len(assoc_tuples_dict) // 2  # Each association is stored twice
venn = {}

# Count and save found associations
for target_file in args.targets:

    # Assumes file names like <a>_<b>_<c>_<ALGORITHM>.<ext>
    ALGORITHM = target_file.split("/")[-1].split("_")[3][:-4]
    self_assoc, found_assoc_dict, false_assoc_dict, tp_fp = classify_associations(
        target_file, list(assoc_tuples_dict.keys())
    )
    real_assoc = len(found_assoc_dict)  # True predicted associations
    false_assoc = len(false_assoc_dict)  # False predicted associations
    total_assoc = real_assoc + false_assoc

    venn[ALGORITHM] = {}
    venn[ALGORITHM]["correct"] = list(found_assoc_dict.keys())
    venn[ALGORITHM]["all"] = list(found_assoc_dict.keys()) + list(
        false_assoc_dict.keys()
    )

    # Assess ranking of associations (they are doubled in assoc_tuples_dict):
    order_list = [
        list(assoc_tuples_dict.keys()).index((feat_a, feat_b)) // 2
        for (feat_a, feat_b) in list(found_assoc_dict.keys())
    ]
    fig = plot_filling_order(order_list)
    fig.savefig(f"Order_image_{ALGORITHM}.png", dpi=200)

    ax_0 = plot_effect_size_matching(
        assoc_tuples_dict, found_assoc_dict, ALGORITHM, ALGORITHM, ax_0
    )

    # Plot confusion matrix:
    cm = create_confusion_matrix(args.features, associations, real_assoc, false_assoc)
    fig = plot_confusion_matrix(
        cm, ["No assoc", "Association"], cmap=None, normalize=False
    )
    fig.savefig(f"Confusion_matrix_{ALGORITHM}.png", dpi=100, bbox_inches="tight")

    # Plot precision-recall and TP-FP curves
    precision, recall, thr, avg_prec = get_precision_recall(
        found_assoc_dict, false_assoc_dict, associations
    )

    ax_1 = plot_precision_recall(precision, recall, avg_prec, ALGORITHM, ax_1)
    ax_2 = plot_TP_vs_FP(tp_fp, ALGORITHM, ax_2)
    ax_3 = plot_thr_recall(thr, recall, ALGORITHM, ax_3)

    # Write results:
    with open("Performance_evaluation_summary_results.txt", "a") as f:
        f.write(f"File: {target_file}\n")
        f.write(
            f"Ground truth detectable associations (i.e. involving the perturbed feature, {args.perturbed}): {associations}\n"
        )
        f.write(
            f"{total_assoc} unique associations found\n"
            f"{self_assoc} self-associations were found before filtering\n"
            f"{real_assoc} were real associations\n"
            f"{false_assoc} were either false or below the significance threshold\n"
        )
        f.write(
            f"Sensitivity: {real_assoc}/{associations} = {real_assoc / associations}\n"
        )
        f.write(f"Precision: {real_assoc}/{total_assoc} = {real_assoc / total_assoc}\n")
        f.write(f"Order list: {order_list}\n\n")
        f.write("______________________________________________________\n")


# Edit figures: layout
ax_0.set_xlabel("Real effect")
ax_0.set_ylabel("Predicted effect")
ax_0.set_ylim((-0.02, 1.02))
ax_0.set_xlim((0, 1.02))
ax_0.legend(
    loc="upper center", bbox_to_anchor=(0.5, 1.1), ncol=3, fancybox=True, shadow=True
)

ax_1.set_xlabel("Recall")
ax_1.set_ylabel("Precision")
ax_1.legend()
ax_1.set_ylim((0, 1.05))
ax_1.set_xlim((0, 1.05))

ax_2.set_xlabel("False Positives")
ax_2.set_ylabel("True Positives")
ax_2.set_aspect("equal")

ax_3.set_ylabel("Threshold")
ax_3.set_xlabel("Recall")

# Save main figures:
fig_0.savefig("Effect_size_matching.png", dpi=200)
fig_1.savefig("Precision_recall.png", dpi=200)
fig_2.savefig("TP_vs_FP.png", dpi=200)
fig_3.savefig("thr_vs_recall.png", dpi=200)

# Plotting Venn diagram (only defined for two or three sets):
if len(venn) == 2 or len(venn) == 3:
    fig_v, ax_v = plt.subplots()
    ax_v = plot_venn_diagram(venn, ax_v, mode="correct")
    fig_v.savefig("Venn_diagram.png", dpi=200)

# Plotting UpSet plot
upset = plot_upsetplot(venn, list(assoc_tuples_dict.keys()))
upset.plot()
plt.savefig("UpSet.png", dpi=200)