coderpp/test/confusion_matrix.py

import argparse
import itertools
import os
import pickle
import random

import ahocorasick
import matplotlib.pyplot as plt
import numpy as np
from prettytable import PrettyTable
from tqdm import tqdm

from load_umls import UMLS
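
# Evaluates nearest-neighbour phrase clustering against UMLS gold synonym
# groups (CUIs): for each threshold in a sweep, gold same-CUI pairs scored
# above the threshold count as TP, gold pairs below it (or never retrieved)
# as FN, and remaining above-threshold pairs as FP; run() then reports
# accuracy/recall/precision/F1.
#
# Expected inputs, inferred from their usage below (the script does not
# document them itself):
#   indices.npy    -- int array (num_phrases, k): indices[i] holds the ids of
#                     the k nearest neighbours of phrase i
#   similarity.npy -- float array (num_phrases, k): the matching scores
#   idx2phrase.pkl -- pickled dict mapping phrase id -> phrase string
#
# A minimal sketch of how such files could be produced with faiss, assuming
# unit-normalised phrase embeddings in `embeddings` of shape (N, d) -- the
# variable names here are illustrative, not from this repo:
#
#   import faiss
#   index = faiss.IndexFlatIP(embeddings.shape[1])  # inner product = cosine
#   index.add(embeddings)
#   similarity, indices = index.search(embeddings, k)  # both of shape (N, k)
#   np.save('use_data/similarity.npy', similarity)
#   np.save('use_data/indices.npy', indices)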

def eval_clustering(umls, similarity, indices, threshold, args):
    with open(args.idx2phrase_path, 'rb') as f:
        idx2phrase = pickle.load(f)
    gt_clustering = list(umls.cui2str.values())
    confusion_matrix = np.zeros((2, 2))
    # Phase 1: pairs inside a gold group (same CUI) -> TP / FN.
    # The Aho-Corasick automatons are used as plain string sets: T records
    # every gold pair already visited (in both orderings, so each unordered
    # pair is counted once), TP records the true-positive pairs.
    TP = ahocorasick.Automaton()
    similarity_list_actual_p = []
    similarity_list_actual_n = []
    T = ahocorasick.Automaton()
    # with open(args.output_dir+'FN_finetune.txt', 'w') as f:
    if True:  # placeholder keeping the indentation of the commented-out `with`
        for group in tqdm(gt_clustering):
            if len(group) < 2:
                continue
            group = list(group)
            for pair in itertools.combinations(group, r=2):
                if str(pair[0])+str(pair[1])+idx2phrase[pair[0]]+idx2phrase[pair[1]] in T or str(pair[1])+str(pair[0])+idx2phrase[pair[1]]+idx2phrase[pair[0]] in T:
                    continue
                T.add_word(str(pair[0])+str(pair[1])+idx2phrase[pair[0]]+idx2phrase[pair[1]], '')
                T.add_word(str(pair[1])+str(pair[0])+idx2phrase[pair[1]]+idx2phrase[pair[0]], '')
                if pair[1] in indices[pair[0]]:
                    # np.where returns a tuple, so the fancy index below
                    # yields a (1, 1) array, hence the [0][0].
                    index = np.where(indices[pair[0]] == pair[1])
                    similarity_list_actual_p.append(min(similarity[pair[0], index][0][0], 1))
                    if similarity[pair[0], index] > threshold:
                        confusion_matrix[0, 0] += 1
                        TP.add_word(str(pair[0])+str(pair[1])+idx2phrase[pair[0]]+idx2phrase[pair[1]], '')
                        TP.add_word(str(pair[1])+str(pair[0])+idx2phrase[pair[1]]+idx2phrase[pair[0]], '')
                    else:
                        confusion_matrix[0, 1] += 1
                        # f.write(idx2phrase[pair[0]]+', '+idx2phrase[pair[1]]+', '+'\t'+str(similarity[pair[0], index][0][0])+'\n')
                elif pair[0] in indices[pair[1]]:
                    index = np.where(indices[pair[1]] == pair[0])
                    similarity_list_actual_p.append(min(similarity[pair[1], index][0][0], 1))
                    if similarity[pair[1], index] > threshold:
                        confusion_matrix[0, 0] += 1
                        TP.add_word(str(pair[0])+str(pair[1])+idx2phrase[pair[0]]+idx2phrase[pair[1]], '')
                        TP.add_word(str(pair[1])+str(pair[0])+idx2phrase[pair[1]]+idx2phrase[pair[0]], '')
                    else:
                        confusion_matrix[0, 1] += 1
                        # f.write(idx2phrase[pair[0]]+', '+idx2phrase[pair[1]]+', '+'\t'+str(similarity[pair[1], index][0][0])+'\n')
                else:
                    # Neither phrase retrieved the other: a missed gold pair.
                    confusion_matrix[0, 1] += 1
                    # f.write(idx2phrase[pair[0]]+', '+idx2phrase[pair[1]]+'\n')

    # Phase 2: pairs not in a gold group -> predicted positives / FP.
    predicted_p = 0
    fp = 0
    fp_list = []
    A = ahocorasick.Automaton()
    for string in tqdm(umls.stridx_list):
        A.add_word(str(string), str(string))
    # with open(args.output_dir+'FP_finetune.txt', 'w') as f:
    if True:  # placeholder keeping the indentation of the commented-out `with`
        for i in tqdm(umls.stridx_list):
            for j in range(1, indices.shape[1]):  # j = 0 is (presumably) the query itself
                if idx2phrase[i] != idx2phrase[indices[i][j]] and str(indices[i][j]) in A and str(i)+str(indices[i][j])+idx2phrase[i]+idx2phrase[indices[i][j]] not in T and similarity[i][j] > 0:
                    similarity_list_actual_n.append(min(similarity[i][j], 1))
                    # Zero the reverse direction so the pair is only counted once.
                    if i in indices[indices[i][j]]:
                        index = np.where(indices[indices[i][j]] == i)
                        similarity[indices[i][j]][index] = 0

                if similarity[i][j] > threshold and idx2phrase[i] != idx2phrase[indices[i][j]] and str(indices[i][j]) in A:
                    predicted_p += 1
                    if str(i)+str(indices[i][j])+idx2phrase[i]+idx2phrase[indices[i][j]] not in TP:
                        fp += 1
                        # print((idx2phrase[i], idx2phrase[indices[i][j]]))
                        # f.write(idx2phrase[i]+'\t'+idx2phrase[indices[i][j]]+'\t'+str(similarity[i][j])+'\n')
                        fp_list.append(idx2phrase[i]+'\t'+idx2phrase[indices[i][j]]+'\t'+str(similarity[i][j])+'\n')
                    if i in indices[indices[i][j]]:
                        index = np.where(indices[indices[i][j]] == i)
                        similarity[indices[i][j]][index] = 0
    # Dump a random sample of false-positive pairs for manual inspection;
    # min() guards against fp_list holding fewer than 20 entries.
    with open(args.output_dir+'fp.txt', 'w') as f:
        for string in random.sample(fp_list, min(20, len(fp_list))):
            f.write(string)
    confusion_matrix[1, 0] += predicted_p - confusion_matrix[0, 0]
    print(confusion_matrix[1, 0] - fp)  # debug: should be 0 if both FP counts agree
    length = len(umls.stridx_list)
    # Every candidate pair not counted above is a true negative.
    confusion_matrix[1, 1] += (length * (length - 1) / 2 - confusion_matrix[0, 0] - confusion_matrix[0, 1] - confusion_matrix[1, 0])
    print('threshold:', threshold)
    print(confusion_matrix)
    return confusion_matrix, similarity_list_actual_p, similarity_list_actual_n
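
# Layout of the 2x2 confusion matrix filled in by eval_clustering()
# (rows = gold label, columns = prediction):
#   [0, 0] TP -- same-CUI pair retrieved with similarity > threshold
#   [0, 1] FN -- same-CUI pair below threshold, or never retrieved
#   [1, 0] FP -- above-threshold pairs that are not gold TPs
#   [1, 1] TN -- the remaining length * (length - 1) / 2 candidate pairs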

def print_result(threshold_list, accuracy_list, recall_list, precision_list, args):
    table = PrettyTable()
    column_names = ["Threshold", "Accuracy", "Recall", "Precision", "F1"]
    table.add_column(column_names[0], threshold_list)
    table.add_column(column_names[1], [format(accuracy, '.3f') for accuracy in accuracy_list])
    table.add_column(column_names[2], [format(recall, '.3f') for recall in recall_list])
    table.add_column(column_names[3], [format(precision, '.3f') for precision in precision_list])
    # Guard against division by zero when precision + recall == 0.
    table.add_column(column_names[4], [format(2 * precision * recall / (precision + recall), '.3f') if precision + recall > 0 else 'n/a' for (precision, recall) in zip(precision_list, recall_list)])
    print(table)
    table = table.get_string()
    with open(args.output_dir + args.title + '.txt', 'w') as f:
        f.write(table)

def plot_histogram(listp, listn, name):
    # Frequency histogram of similarity scores for same-CUI vs. different-CUI pairs.
    plt.figure()
    plt.hist(listp, bins=50, range=(min(listn), 1), density=False, alpha=0.5, label='Pair with same CUI')
    plt.hist(listn, bins=50, range=(min(listn), 1), density=False, alpha=0.5, label='Pair with different CUI')
    plt.xlabel('Similarity score')
    plt.ylabel('Frequency')
    plt.title('Similarity score for pairs (frequency)')
    plt.legend(loc='upper left')
    plt.savefig('frequency_' + name + '.png')
    plt.show()

    # Same data, normalised to a density for easier comparison of the two distributions.
    plt.figure()
    plt.hist(listp, bins=50, range=(min(listn), 1), density=True, alpha=0.5, label='Pair with same CUI')
    plt.hist(listn, bins=50, range=(min(listn), 1), density=True, alpha=0.5, label='Pair with different CUI')
    plt.xlabel('Similarity score')
    plt.ylabel('Density')
    plt.title('Similarity score for pairs (density)')
    plt.legend(loc='upper left')
    plt.savefig('density_' + name + '.png')
    plt.show()
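
# Metrics derived from the confusion matrix cm in run() below:
#   accuracy  = (cm[0, 0] + cm[1, 1]) / cm.sum()
#   recall    = cm[0, 0] / (cm[0, 0] + cm[0, 1])
#   precision = cm[0, 0] / (cm[0, 0] + cm[1, 0])
# print_result() adds F1 = 2 * precision * recall / (precision + recall).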

def run(args):
    # `umls` and `threshold_list` are module-level globals defined under
    # `if __name__ == '__main__':` below.
    accuracy_list = []
    recall_list = []
    precision_list = []
    thre_list = []
    for threshold in threshold_list:
        # Reload on every iteration: eval_clustering() mutates `similarity` in place.
        similarity = np.load(args.similarity_path)
        indices = np.load(args.indices_path)
        confusion_matrix, similarity_list_actual_p, similarity_list_actual_n = eval_clustering(umls, similarity, indices, threshold, args)
        accuracy_list.append((confusion_matrix[0, 0] + confusion_matrix[1, 1]) / confusion_matrix.sum())
        recall_list.append(confusion_matrix[0, 0] / confusion_matrix[0].sum())
        precision_list.append(confusion_matrix[0, 0] / confusion_matrix[:, 0].sum())
        thre_list.append(threshold)
    print_result(thre_list, accuracy_list, recall_list, precision_list, args)
    # plot_histogram(similarity_list_actual_p, similarity_list_actual_n, args.title)
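
# Example invocation (paths are illustrative):
#   python confusion_matrix.py \
#       --umls_dir /path/to/umls \
#       --use_data_dir use_data/ \
#       --output_dir output/ \
#       --title finetune_eval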

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--umls_dir",
        default="/media/sda1/GanjinZero/UMLSBert/umls",
        type=str,
        help="Directory of UMLS data"
    )
    parser.add_argument(
        "--output_dir",
        default="output/",
        type=str,
        help="Directory to save results"
    )
    parser.add_argument(
        "--use_data_dir",
        default="use_data/",
        type=str,
        help="Directory of faiss index, idx2phrase and other use data"
    )
    parser.add_argument(
        "--title",
        type=str,
        help="Title of the task"
    )
    args = parser.parse_args()

    args.indices_path = os.path.join(args.use_data_dir, "indices.npy")
    args.similarity_path = os.path.join(args.use_data_dir, "similarity.npy")
    args.phrase2idx_path = os.path.join(args.use_data_dir, "phrase2idx.pkl")
    args.idx2phrase_path = os.path.join(args.use_data_dir, "idx2phrase.pkl")
    # Module-level globals consumed by run() above.
    umls = UMLS(umls_path=args.umls_dir, phrase2idx_path=args.phrase2idx_path, idx2phrase_path=args.idx2phrase_path)
    threshold_list = [0.98, 0.96, 0.94, 0.92, 0.90, 0.88, 0.86, 0.84, 0.82, 0.80,
                      0.78, 0.76, 0.74, 0.72, 0.70, 0.68, 0.66, 0.64, 0.62, 0.60]
    # threshold_list = [.8]
    run(args)