# supplementary_files/performance_evaluation.py
import argparse
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import precision_recall_curve, average_precision_score
from matplotlib_venn import venn2, venn3

##################################### Functions #############################################


def plot_confusion_matrix(cm, target_names, cmap=None, normalize=False):
    """Function that plots the confusion matrix given cm. Mattias Ohlsson's code extended."""

    import itertools

    accuracy = np.trace(cm) / float(np.sum(cm))
    misclass = 1 - accuracy

    if cmap is None:
        cmap = plt.get_cmap("Blues")

    fig = plt.figure(figsize=(4, 3))
    plt.imshow(cm, interpolation="nearest", cmap=cmap)
    plt.colorbar()

    if target_names is not None:
        tick_marks = np.arange(len(target_names))
        plt.xticks(tick_marks, target_names, rotation=0, fontsize=12)
        plt.yticks(tick_marks, target_names, fontsize=12)

    if normalize:
        cm = cm.astype("float") / cm.sum(axis=1)[:, np.newaxis]

    thresh = cm.max() / 1.5 if normalize else cm.max() / 2
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        if normalize:
            plt.text(
                j,
                i,
                "{:0.4f}".format(cm[i, j]),
                horizontalalignment="center",
                color="white" if cm[i, j] > thresh else "black",
                fontsize=14,
            )
        else:
            plt.text(
                j,
                i,
                "{:,}".format(cm[i, j]),
                horizontalalignment="center",
                color="white" if cm[i, j] > thresh else "black",
                fontsize=14,
            )

    plt.tight_layout()
    plt.ylabel("True label", fontsize=14)
    plt.xlabel(
        "Predicted label\naccuracy={:0.4f}; misclass={:0.4f}".format(
            accuracy, misclass
        ),
        fontsize=14,
    )

    return fig


def classify_associations(target_file, assoc_tuples):
    self_assoc = 0  # Self associations
    found_assoc_dict = {}
    false_assoc_dict = {}
    tp_fp = np.array([[0, 0]])  # Cumulative counts per prediction: [FP, TP]
    with open(target_file, "r") as f:
        for line in f:
            if line[0] != "f":  # Skip the header line
                splitline = line.strip().split("\t")
                feat_a = splitline[2]
                feat_b = splitline[3]
                score = abs(float(splitline[5]))
                if feat_a == feat_b:  # Self associations will not be counted
                    self_assoc += 1
                else:
                    if (feat_a, feat_b) in assoc_tuples:
                        found_assoc_dict[(feat_a, feat_b)] = score
                        if (
                            feat_b,
                            feat_a,
                        ) not in found_assoc_dict:  # If we had not found it yet
                            tp_fp = np.vstack((tp_fp, tp_fp[-1] + [0, 1]))
                    else:
                        false_assoc_dict[(feat_a, feat_b)] = score
                        if (feat_b, feat_a) not in false_assoc_dict:
                            tp_fp = np.vstack((tp_fp, tp_fp[-1] + [1, 0]))

    # Remove duplicated associations: if both directions were reported, keep only
    # the direction with the stronger score
    for i, j in list(found_assoc_dict.keys()):
        if (i, j) in found_assoc_dict and (j, i) in found_assoc_dict:
            weaker = min([(i, j), (j, i)], key=found_assoc_dict.get)
            del found_assoc_dict[weaker]

    for i, j in list(false_assoc_dict.keys()):
        if (j, i) in false_assoc_dict:
            del false_assoc_dict[(i, j)]

    return self_assoc, found_assoc_dict, false_assoc_dict, tp_fp

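# Note on the predicted-association file format (inferred from the parsing above):
# tab-separated, with feature names in columns 2 and 3 (0-based), the association
# score in column 5, and a header line assumed to start with "f".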


def create_confusion_matrix(n_feat, associations, real_assoc, false_assoc):
    cm = np.empty((2, 2), dtype=int)  # Counts, so an integer matrix
    # TN: only counting the upper half matrix (non doubled associations)
    cm[0, 0] = (n_feat * n_feat - n_feat) / 2 - (
        associations + false_assoc
    )  # Diagonal is discarded
    cm[0, 1] = false_assoc  # FP
    cm[1, 0] = associations - real_assoc  # FN: ground truth associations not found
    cm[1, 1] = real_assoc  # TP

    return cm


def get_precision_recall(found_assoc_dict, false_assoc_dict, associations):
    y_true = []
    y_pred = []

    # True Positives
    for score in found_assoc_dict.values():
        y_true.append(1)
        y_pred.append(score)
    # False Positives
    for score in false_assoc_dict.values():
        y_true.append(0)
        y_pred.append(score)
    # False negatives
    for _ in range(associations - len(found_assoc_dict)):
        y_true.append(1)
        y_pred.append(0)

    precision, recall, thr = precision_recall_curve(
        y_true, y_pred
    )  # thr will tell us score values
    avg_prec = average_precision_score(y_true, y_pred)

    return precision, recall, thr, avg_prec


def plot_precision_recall(precision, recall, avg_prec, label, ax):
    ax.scatter(
        recall,
        precision,
        lw=0,
        marker=".",
        s=5,
        edgecolors="none",
        label=f"{label} - APS:{avg_prec:.2f}",
    )
    ax.legend()
    return ax


def plot_thr_recall(thr, recall, label, ax):
    ax.scatter(recall[:-1], thr, lw=0, marker=".", s=5, edgecolors="none", label=label)
    ax.legend()
    return ax


def plot_TP_vs_FP(tp_fp, label, ax):
    ax.scatter(tp_fp[:, 0], tp_fp[:, 1], s=2, label=label, edgecolors="none")
    ax.legend()
    return ax


def plot_filling_order(order_list, last_rank=None):

    if last_rank is None:
        last_rank = len(order_list)
    fig = plt.figure()
    # order_list holds the 0-based rank of each correct prediction; row r of the image
    # is filled from the column where the association ranked r was first recovered
    order_img = np.zeros((np.max(order_list) + 1, len(order_list)))
    for i, element in enumerate(order_list):
        order_img[element, i:] = 1

    plt.imshow(order_img[:last_rank, :], cmap="binary")
    plt.xlabel("Correct prediction number")
    plt.ylabel("Association ranking")
    plt.plot(np.arange(last_rank), np.arange(last_rank))
    return fig


def plot_effect_size_matching(
    assoc_tuples_dict, found_assoc_dict, label, ALGORITHM, ax
):

    ground_truth_effects = [
        assoc_tuples_dict[key] for key in list(found_assoc_dict.keys())
    ]
    predicted_effects = np.array(list(found_assoc_dict.values()))

    if ALGORITHM == "ttest":
        # Eq 15 in https://doi.org/10.1146/annurev-statistics-031017-100307
        # Convert p-values to -log10(p); zero p-values get -1 as a dummy value
        predicted_effects = np.array(
            [-np.log10(p) if p != 0 else -1 for p in predicted_effects]
        )
        # Replace the dummy values (zero p-values) with the maximum observed value
        predicted_effects[predicted_effects == -1] = np.max(predicted_effects)

    # Min-max scale the predicted effects to [0, 1]
    eff_max, eff_min = np.max(predicted_effects), np.min(predicted_effects)
    standardized_pred_effects = (predicted_effects - eff_min) / (eff_max - eff_min)
    ax.scatter(
        ground_truth_effects,
        standardized_pred_effects,
        s=12,
        edgecolors="none",
        label=label,
    )
    ax.legend()
    return ax


def plot_venn_diagram(venn, ax, mode="all", scale="log"):
    sets = [set(venn[key][mode]) for key in list(venn.keys())]
    labels = tuple(venn.keys())

    if len(venn) == 2:
        venn2(sets, labels, ax=ax)
    elif len(venn) == 3:
        venn3(sets, labels, ax=ax)
    else:
        raise ValueError("Unsupported number of input files.")

    return ax


def plot_upsetplot(venn, assoc_tuples):
    from upsetplot import UpSet
    import pandas as pd
    from matplotlib import cm

    all_assoc = set(
        [
            association
            for ALGORITHM in venn.keys()
            for association in venn[ALGORITHM]["all"]
        ]
    )
    columns = ["ground truth"]
    columns.extend([ALGORITHM for ALGORITHM in list(venn.keys())])

    df = {}
    for association in all_assoc:
        df[association] = []

        if association in assoc_tuples:
            df[association].append("TP")
        else:
            df[association].append("FP")

        for ALGORITHM in list(venn.keys()):
            if association in venn[ALGORITHM]["all"]:
                df[association].append(1)
            else:
                df[association].append(0)

    df = pd.DataFrame.from_dict(df, orient="index", columns=columns)
    df = df.set_index([pd.Index(df[ALGORITHM] == 1) for ALGORITHM in list(venn.keys())])
    upset = UpSet(df, intersection_plot_elements=0, show_counts=True)

    upset.add_stacked_bars(
        by="ground truth",
        colors=cm.Pastel1,
        title="Count by ground truth value",
        elements=10,
    )

    return upset


###################################### Main code ################################################

parser = argparse.ArgumentParser(
    description="Compare one or more files with predicted associations against a ground truth associations file."
)
parser.add_argument(
    "-p",
    "--perturbed",
    metavar="pert",
    type=str,
    required=True,
    help="perturbed feature name (matched as a substring of feature names)",
)
parser.add_argument(
    "-n",
    "--features",
    metavar="feat",
    type=int,
    required=True,
    help="total number of features",
)
parser.add_argument(
    "-r",
    "--reference",
    metavar="ref",
    type=str,
    required=True,
    help="path to the ground truth associations file",
)
parser.add_argument(
    "-t",
    "--targets",
    metavar="tar",
    type=str,
    required=True,
    nargs="+",
    help="paths to the predicted associations files",
)
args = parser.parse_args()
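
# Example invocation (file names are hypothetical; target file names must follow the
# underscore-based naming convention parsed in the loop below):
#   python performance_evaluation.py -p protein -n 1000 \
#       -r changes_ground_truth.tsv \
#       -t results_sim_assoc_MOVE.tsv results_sim_assoc_ttest.tsv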


# Defining main performance evaluation figures:
fig_0, ax_0 = plt.subplots(figsize=(7, 7))
fig_1, ax_1 = plt.subplots(figsize=(7, 7))
fig_2, ax_2 = plt.subplots()
fig_3, ax_3 = plt.subplots()

assoc_tuples_dict = {}

# Reading the file with the ground truth changes:
with open(args.reference, "r") as f:
    for line in f:
        if line[0] != "f" and line[0] != "n":  # Skip header lines
            splitline = line.strip().split("\t")
            feat_a = splitline[2]
            feat_b = splitline[3]
            assoc_strength = abs(float(splitline[4]))
            # Only associations involving a perturbed feature can be detected
            if args.perturbed in feat_a or args.perturbed in feat_b:
                assoc_tuples_dict[(feat_a, feat_b)] = assoc_strength
                assoc_tuples_dict[(feat_b, feat_a)] = assoc_strength

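# Each ground truth association was stored in both directions, (a, b) and (b, a),
# so the number of unique associations is half the dictionary length.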
associations = int(len(assoc_tuples_dict) / 2)
venn = {}
# Count and save found associations
for target_file in args.targets:

    # Algorithm label is taken from the 4th underscore-separated field of the file
    # name, with a 4-character extension (e.g. ".tsv") stripped
    ALGORITHM = target_file.split("/")[-1].split("_")[3][:-4]
    self_assoc, found_assoc_dict, false_assoc_dict, tp_fp = classify_associations(
        target_file, list(assoc_tuples_dict.keys())
    )
    real_assoc = len(found_assoc_dict)  # True predicted associations
    false_assoc = len(false_assoc_dict)  # False predicted associations
    total_assoc = real_assoc + false_assoc

    venn[ALGORITHM] = {}
    venn[ALGORITHM]["correct"] = list(found_assoc_dict.keys())
    venn[ALGORITHM]["all"] = list(found_assoc_dict.keys()) + list(
        false_assoc_dict.keys()
    )

    # Assess ranking of associations (they are doubled in assoc_tuples):
    order_list = [
        list(assoc_tuples_dict.keys()).index((feat_a, feat_b)) // 2
        for (feat_a, feat_b) in list(found_assoc_dict.keys())
    ]
    fig = plot_filling_order(order_list)
    fig.savefig(f"Order_image_{ALGORITHM}.png", dpi=200)

    ax_0 = plot_effect_size_matching(
        assoc_tuples_dict, found_assoc_dict, ALGORITHM, ALGORITHM, ax_0
    )

    # Plot confusion matrix:
    cm = create_confusion_matrix(args.features, associations, real_assoc, false_assoc)
    fig = plot_confusion_matrix(
        cm, ["No assoc", "Association"], cmap=None, normalize=False
    )

    fig.savefig(f"Confusion_matrix_{ALGORITHM}.png", dpi=100, bbox_inches="tight")

    # Plot precision-recall and TP-FP curves
    precision, recall, thr, avg_prec = get_precision_recall(
        found_assoc_dict, false_assoc_dict, associations
    )

    ax_1 = plot_precision_recall(precision, recall, avg_prec, ALGORITHM, ax_1)
    ax_2 = plot_TP_vs_FP(tp_fp, ALGORITHM, ax_2)
    ax_3 = plot_thr_recall(thr, recall, ALGORITHM, ax_3)

    # Write results:
    with open("Performance_evaluation_summary_results.txt", "a") as f:
        f.write(f"File: {target_file}\n")
        f.write(
            f"Ground truth detectable associations (i.e. involving perturbed feature {args.perturbed}): {associations}\n"
        )
        f.write(
            f"{total_assoc} unique associations found\n"
            f"{self_assoc} self-associations were found before filtering\n"
            f"{real_assoc} were real associations\n"
            f"{false_assoc} were either false or below the significance threshold\n"
        )
        # print("Correct associations:\n", found_assoc_tuples, "\n")
        f.write(
            f"Sensitivity: {real_assoc}/{associations} = {real_assoc / associations}\n"
        )
        f.write(f"Precision: {real_assoc}/{total_assoc} = {real_assoc / total_assoc}\n")
        f.write(f"Order list: {order_list}\n\n")
        f.write("______________________________________________________\n")


# Edit figures: layout
ax_0.set_xlabel("Real effect")
ax_0.set_ylabel("Predicted effect")
ax_0.set_ylim((-0.02, 1.02))
ax_0.set_xlim((0, 1.02))
ax_0.legend(
    loc="upper center", bbox_to_anchor=(0.5, 1.1), ncol=3, fancybox=True, shadow=True
)

ax_1.set_xlabel("Recall")
ax_1.set_ylabel("Precision")
ax_1.legend()
ax_1.set_ylim((0, 1.05))
ax_1.set_xlim((0, 1.05))

ax_2.set_xlabel("False Positives")
ax_2.set_ylabel("True Positives")
ax_2.set_aspect("equal")

ax_3.set_ylabel("Threshold")
ax_3.set_xlabel("Recall")


# Save main figures:
fig_0.savefig("Effect_size_matching.png", dpi=200)
fig_1.savefig("Precision_recall.png", dpi=200)
fig_2.savefig("TP_vs_FP.png", dpi=200)
fig_3.savefig("thr_vs_recall.png", dpi=200)

# Plot Venn diagram (only supported for two or three target files):
if len(venn) == 2 or len(venn) == 3:
    fig_v, ax_v = plt.subplots()
    ax_v = plot_venn_diagram(venn, ax_v, mode="correct")
    fig_v.savefig("Venn_diagram.png", dpi=200)

# Plot UpSet plot
upset = plot_upsetplot(venn, list(assoc_tuples_dict.keys()))
upset.plot()
plt.savefig("UpSet.png", dpi=200)