# coderpp/test/confusion_matrix.py
import pickle
import numpy as np
import itertools
from tqdm import tqdm
from load_umls import UMLS
import ahocorasick
from prettytable import PrettyTable
import matplotlib.pyplot as plt
import os
import argparse
import random

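# Evaluate approximate nearest-neighbour synonym retrieval against UMLS
# ground truth: same-CUI phrase pairs yield TP/FN counts, all remaining
# phrase pairs yield FP/TN counts, reported per similarity threshold.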
def pair_key(i, j, idx2phrase):
    # Collision-safe key for an ordered pair of phrase indices. The original
    # raw concatenation str(i)+str(j)+phrase_i+phrase_j could collide
    # (e.g. indices 1,23 vs 12,3); the '||' separator prevents that.
    return str(i) + '||' + str(j) + '||' + idx2phrase[i] + '||' + idx2phrase[j]

def eval_clustering(umls, similarity, indices, threshold, args):
    with open(args.idx2phrase_path, 'rb') as f:
        idx2phrase = pickle.load(f)
    gt_clustering = list(umls.cui2str.values())
    confusion_matrix = np.zeros((2, 2))
    # Pairs inside the same CUI group (ground-truth positives).
    TP = ahocorasick.Automaton()
    similarity_list_actual_p = []
    similarity_list_actual_n = []
    T = ahocorasick.Automaton()
    # with open(args.output_dir+'FN_finetune.txt', 'w') as f:
    if True:  # keeps the indentation of the commented-out `with` block above
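        # Walk every ground-truth synonym pair; the T automaton deduplicates
        # pairs shared across CUI groups, and TP records the pairs the model
        # also predicts above the threshold (true positives).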
        for group in tqdm(gt_clustering):
            if len(group) < 2:
                continue
            group = list(group)
            for pair in itertools.combinations(group, r=2):
                # Skip pairs already seen in another group.
                if pair_key(pair[0], pair[1], idx2phrase) in T or pair_key(pair[1], pair[0], idx2phrase) in T:
                    continue
                T.add_word(pair_key(pair[0], pair[1], idx2phrase), '')
                T.add_word(pair_key(pair[1], pair[0], idx2phrase), '')
                if pair[1] in indices[pair[0]]:
                    index = np.where(indices[pair[0]] == pair[1])
                    similarity_list_actual_p.append(min(similarity[pair[0], index][0][0], 1))
                    if similarity[pair[0], index] > threshold:
                        confusion_matrix[0, 0] += 1
                        TP.add_word(pair_key(pair[0], pair[1], idx2phrase), '')
                        TP.add_word(pair_key(pair[1], pair[0], idx2phrase), '')
                    else:
                        confusion_matrix[0, 1] += 1
                        # f.write(idx2phrase[pair[0]]+', '+idx2phrase[pair[1]]+', '+'\t'+str(similarity[pair[0], index][0][0])+'\n')
                elif pair[0] in indices[pair[1]]:
                    index = np.where(indices[pair[1]] == pair[0])
                    similarity_list_actual_p.append(min(similarity[pair[1], index][0][0], 1))
                    if similarity[pair[1], index] > threshold:
                        confusion_matrix[0, 0] += 1
                        TP.add_word(pair_key(pair[0], pair[1], idx2phrase), '')
                        TP.add_word(pair_key(pair[1], pair[0], idx2phrase), '')
                    else:
                        confusion_matrix[0, 1] += 1
                        # f.write(idx2phrase[pair[0]]+', '+idx2phrase[pair[1]]+', '+'\t'+str(similarity[pair[1], index][0][0])+'\n')
                else:
                    # Ground-truth pair not among the retrieved neighbours: a miss.
                    confusion_matrix[0, 1] += 1
                    # f.write(idx2phrase[pair[0]]+', '+idx2phrase[pair[1]]+'\n')

    # Pairs outside any shared CUI group: predicted positives here are
    # false positives (FP); everything else contributes to the TN count.
    predicted_p = 0
    fp = 0
    fp_list = []
    # Automaton over all valid string indices, used as a fast membership set.
    A = ahocorasick.Automaton()
    for string in tqdm(umls.stridx_list):
        A.add_word(str(string), str(string))
    # with open(args.output_dir+'FP_finetune.txt', 'w') as f:
    if True:  # keeps the indentation of the commented-out `with` block above
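        # Scan each phrase's retrieved neighbours (j starts at 1 because
        # column 0 is assumed to be the query phrase itself); symmetric
        # entries are zeroed so every pair is counted only once.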
        for i in tqdm(umls.stridx_list):
            for j in range(1, indices.shape[1]):
                if idx2phrase[i] != idx2phrase[indices[i][j]] and str(indices[i][j]) in A and pair_key(i, indices[i][j], idx2phrase) not in T and similarity[i][j] > 0:
                    similarity_list_actual_n.append(min(similarity[i][j], 1))
                    if i in indices[indices[i][j]]:
                        index = np.where(indices[indices[i][j]] == i)
                        similarity[indices[i][j]][index] = 0

                if similarity[i][j] > threshold and idx2phrase[i] != idx2phrase[indices[i][j]] and str(indices[i][j]) in A:
                    predicted_p += 1
                    if pair_key(i, indices[i][j], idx2phrase) not in TP:
                        fp += 1
                        # print((idx2phrase[i], idx2phrase[indices[i][j]]))
                        # f.write(idx2phrase[i]+'\t'+idx2phrase[indices[i][j]]+'\t'+str(similarity[i][j])+'\n')
                        fp_list.append(idx2phrase[i]+'\t'+idx2phrase[indices[i][j]]+'\t'+str(similarity[i][j])+'\n')
                    if i in indices[indices[i][j]]:
                        index = np.where(indices[indices[i][j]] == i)
                        similarity[indices[i][j]][index] = 0
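    # Dump a random sample of false-positive pairs for manual inspection.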
    with open(args.output_dir+'fp.txt', 'w') as f:
        for string in random.sample(fp_list, min(20, len(fp_list))):
            f.write(string)
    confusion_matrix[1, 0] += predicted_p - confusion_matrix[0, 0]
    # Sanity check: should be zero if the FP bookkeeping is consistent.
    print(confusion_matrix[1, 0] - fp)
    length = len(umls.stridx_list)
    # Everything else among the C(length, 2) candidate pairs is a true negative.
    confusion_matrix[1, 1] += (length * (length - 1) / 2 - confusion_matrix[0, 0] - confusion_matrix[0, 1] - confusion_matrix[1, 0])
    print('threshold:', threshold)
    print(confusion_matrix)
    return confusion_matrix, similarity_list_actual_p, similarity_list_actual_n

def print_result(threshold_list, accuracy_list, recall_list, precision_list, args):
    table = PrettyTable()
    column_names = ["Threshold", "Accuracy", "Recall", "Precision", "F1"]
    table.add_column(column_names[0], threshold_list)
    table.add_column(column_names[1], [format(accuracy, '.3f') for accuracy in accuracy_list])
    table.add_column(column_names[2], [format(recall, '.3f') for recall in recall_list])
    table.add_column(column_names[3], [format(precision, '.3f') for precision in precision_list])
    # Guard the F1 computation against division by zero when precision + recall == 0.
    table.add_column(column_names[4], [format(2 * precision * recall / (precision + recall) if precision + recall > 0 else 0.0, '.3f') for (precision, recall) in zip(precision_list, recall_list)])
    print(table)
    table = table.get_string()
    with open(args.output_dir + args.title + '.txt', 'w') as f:
        f.write(table)

def plot_histogram(listp, listn, name):
    plt.figure()
    plt.hist(listp, bins=50, range=(min(listn), 1), density=False, alpha=0.5, label='Pairs with the same CUI')
    plt.hist(listn, bins=50, range=(min(listn), 1), density=False, alpha=0.5, label='Pairs with different CUIs')
    plt.xlabel('Similarity score')
    plt.ylabel('Frequency')
    plt.title('Similarity score for pairs (frequency)')
    plt.legend(loc='upper left')
    plt.savefig('frequency_' + name + '.png')
    plt.show()

    plt.figure()
    plt.hist(listp, bins=50, range=(min(listn), 1), density=True, alpha=0.5, label='Pairs with the same CUI')
    plt.hist(listn, bins=50, range=(min(listn), 1), density=True, alpha=0.5, label='Pairs with different CUIs')
    plt.xlabel('Similarity score')
    plt.ylabel('Density')
    plt.title('Similarity score for pairs (density)')
    plt.legend(loc='upper left')
    plt.savefig('density_' + name + '.png')
    plt.show()

def run(args):
    # `umls` and `threshold_list` are defined at module level in __main__.
    accuracy_list = []
    recall_list = []
    precision_list = []
    thre_list = []
    for threshold in threshold_list:
        # Reload similarity/indices on every iteration: eval_clustering mutates
        # the similarity matrix in place when deduplicating symmetric pairs.
        similarity = np.load(args.similarity_path)
        indices = np.load(args.indices_path)
        confusion_matrix, similarity_list_actual_p, similarity_list_actual_n = eval_clustering(umls, similarity, indices, threshold, args)
        accuracy_list.append((confusion_matrix[0, 0] + confusion_matrix[1, 1]) / confusion_matrix.sum())
        recall_list.append(confusion_matrix[0, 0] / confusion_matrix[0].sum())
        precision_list.append(confusion_matrix[0, 0] / confusion_matrix[:, 0].sum())
        thre_list.append(threshold)
        print_result(thre_list, accuracy_list, recall_list, precision_list, args)
    # plot_histogram(similarity_list_actual_p, similarity_list_actual_n, args.title)

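# Example invocation (paths are assumptions; adjust to your setup):
#   python confusion_matrix.py --umls_dir <umls_dir> --use_data_dir use_data/ \
#       --output_dir output/ --title finetune_eval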
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--umls_dir",
        default="/media/sda1/GanjinZero/UMLSBert/umls",
        type=str,
        help="Directory of UMLS data"
    )
    parser.add_argument(
        "--output_dir",
        default="output/",
        type=str,
        help="Directory to save results"
    )
    parser.add_argument(
        "--use_data_dir",
        default="use_data/",
        type=str,
        help="Directory of the faiss indices, idx2phrase mapping and other runtime data"
    )
    parser.add_argument(
        "--title",
        type=str,
        required=True,  # used to name the result table file
        help="Title of the task"
    )
    args = parser.parse_args()

    args.indices_path = os.path.join(args.use_data_dir, "indices.npy")
    args.similarity_path = os.path.join(args.use_data_dir, "similarity.npy")
    args.phrase2idx_path = os.path.join(args.use_data_dir, "phrase2idx.pkl")
    args.idx2phrase_path = os.path.join(args.use_data_dir, "idx2phrase.pkl")
    umls = UMLS(umls_path=args.umls_dir, phrase2idx_path=args.phrase2idx_path, idx2phrase_path=args.idx2phrase_path)
    threshold_list = [0.98, 0.96, 0.94, 0.92, 0.90, 0.88, 0.86, 0.84, 0.82, 0.80, 0.78, 0.76, 0.74, 0.72, 0.70, 0.68, 0.66, 0.64, 0.62, 0.60]
    # threshold_list = [.8]
    run(args)
182