# -*- coding: utf-8 -*-
# Roberta+LLM/eval_file.py
# from eval_file import *

import argparse
import json
import pickle as pkl
from collections import defaultdict
from itertools import chain
from math import pow
from pathlib import Path

# from common_utils.common_io import load_bio_file_into_sents
# from common_utils.common_log import create_logger


def read_from_file(ifn):
    with open(ifn, "r") as f:
        text = f.read()
    return text


def write_to_file(text, ofn):
    with open(ofn, "w") as f:
        f.write(text)
    return True


def pkl_load(ifn):
    with open(ifn, "rb") as f:
        pdata = pkl.load(f)
    return pdata


def pkl_dump(pdata, ofn):
    with open(ofn, "wb") as f:
        pkl.dump(pdata, f)
    return True


def json_load(ifn):
    with open(ifn, "r") as f:
        jdata = json.load(f)
    return jdata


def json_dump(jdata, ofn):
    with open(ofn, "w") as f:
        json.dump(jdata, f)
    return True


def load_bio_file_into_sents(bio_file, word_sep=" ", do_lower=False):
    bio_text = read_from_file(bio_file)
    bio_text = bio_text.strip()
    if do_lower:
        bio_text = bio_text.lower()

    new_sents = []
    sents = bio_text.split("\n\n")

    for sent in sents:
        new_sent = []
        words = sent.split("\n")
        for word in words:
            new_word = word.split(word_sep)
            new_sent.append(new_word)
        new_sents.append(new_sent)

    return new_sents


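# Illustrative sketch (not from the original repo): load_bio_file_into_sents expects
# one token per line, sentences separated by a blank line, and the label in the last
# column. The tokens and labels below are hypothetical.
#
#   Aspirin B-Drug
#   100 I-Drug
#   mg I-Drug
#   daily O
#   <blank line>
#   No O
#   rash B-ADE
#
# would be parsed into:
#   [[["aspirin", "b-drug"], ["100", "i-drug"], ["mg", "i-drug"], ["daily", "o"]],
#    [["no", "o"], ["rash", "b-ade"]]]
# (lower-cased here because the evaluator calls it with do_lower=True).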
def output_bio(bio_data, output_file, sep=" "):
    with open(output_file, "w") as f:
        for sent in bio_data:
            for word in sent:
                line = sep.join(word)
                f.write(line)
                f.write("\n")
            f.write("\n")


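# Usage sketch (illustrative, not part of the original module): output_bio writes the
# same token-per-line layout that load_bio_file_into_sents reads, so the two round-trip:
#
#   sents = load_bio_file_into_sents("gold.bio")   # hypothetical file name
#   output_bio(sents, "gold_copy.bio")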
class PRF:
    def __init__(self):
        self.true = 0
        self.false = 0

    def add_true_case(self):
        self.true += 1

    def add_false_case(self):
        self.false += 1

    def get_true_false_counts(self):
        return self.true, self.false

    def __str__(self):
        return str(self.__dict__)


class BioEval:
    def __init__(self):
        self.acc = PRF()
        # prediction
        self.all_strict = PRF()
        self.all_relax = PRF()
        self.cat_strict = defaultdict(PRF)
        self.cat_relax = defaultdict(PRF)
        # gold standard
        self.gs_all = 0
        self.gs_cat = defaultdict(int)
        self.performance = dict()
        self.counts = dict()
        self.beta = 1
        self.label_not_for_eval = {'o'}

    def reset(self):
        self.acc = PRF()
        self.all_strict = PRF()
        self.all_relax = PRF()
        self.cat_strict = defaultdict(PRF)
        self.cat_relax = defaultdict(PRF)
        self.gs_all = 0
        self.gs_cat = defaultdict(int)
        self.performance = dict()
        self.counts = dict()

    def set_beta_for_f_score(self, beta):
        print("Using beta={} for calculating F-score".format(beta))
        self.beta = beta

    # def set_logger(self, logger):
    #     self.logger = logger

    def add_labels_not_for_eval(self, *labels):
        for each in labels:
            self.label_not_for_eval.add(each.lower())

    def __calc_prf(self, tp, fp, tp_tn):
        """
        Calculate precision, recall, and the F-beta score: beta=1 gives the F1 score,
        beta=2 favors recall, and beta=0.5 favors precision.
        Use set_beta_for_f_score() to change the beta value.
        Here tp_tn is the total number of gold-standard entities (i.e., tp + fn).
        """
        tp_fp = tp + fp
        pre = 1.0 * tp / tp_fp if tp_fp > 0 else 0.0
        rec = 1.0 * tp / tp_tn if tp_tn > 0 else 0.0
        beta2 = pow(self.beta, 2)
        f_beta = (1 + beta2) * pre * rec / (beta2 * pre + rec) if (pre + rec) > 0 else 0.0
        return pre, rec, f_beta

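    # Worked example (illustrative numbers, not from the original code): with beta=1,
    # tp=6 true predictions, fp=2 false predictions, and tp_tn=10 gold entities:
    #   precision = 6 / (6 + 2) = 0.75
    #   recall    = 6 / 10      = 0.60
    #   F1        = 2 * 0.75 * 0.60 / (0.75 + 0.60) ≈ 0.667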
    def __measure_performance(self):
        self.performance['overall'] = dict()

        acc_true_num, acc_false_num = self.acc.get_true_false_counts()
        total_acc_num = acc_true_num + acc_false_num
        # calc acc
        overall_acc = round(1.0 * acc_true_num / total_acc_num, 4) if total_acc_num > 0 else 0.0
        self.performance['overall']['acc'] = overall_acc

        strict_true_counts, strict_false_counts = self.all_strict.get_true_false_counts()
        strict_pre, strict_rec, strict_f_score = self.__calc_prf(strict_true_counts, strict_false_counts, self.gs_all)
        self.performance['overall']['strict'] = dict()
        self.performance['overall']['strict']['precision'] = strict_pre
        self.performance['overall']['strict']['recall'] = strict_rec
        self.performance['overall']['strict']['f_score'] = strict_f_score

        relax_true_counts, relax_false_counts = self.all_relax.get_true_false_counts()
        relax_pre, relax_rec, relax_f_score = self.__calc_prf(relax_true_counts, relax_false_counts, self.gs_all)
        self.performance['overall']['relax'] = dict()
        self.performance['overall']['relax']['precision'] = relax_pre
        self.performance['overall']['relax']['recall'] = relax_rec
        self.performance['overall']['relax']['f_score'] = relax_f_score

        self.performance['category'] = dict()
        self.performance['category']['strict'] = dict()
        for k, v in self.cat_strict.items():
            self.performance['category']['strict'][k] = dict()
            stc, sfc = v.get_true_false_counts()
            p, r, f = self.__calc_prf(stc, sfc, self.gs_cat[k])
            self.performance['category']['strict'][k]['precision'] = p
            self.performance['category']['strict'][k]['recall'] = r
            self.performance['category']['strict'][k]['f_score'] = f

        self.performance['category']['relax'] = dict()
        for k, v in self.cat_relax.items():
            self.performance['category']['relax'][k] = dict()
            rtc, rfc = v.get_true_false_counts()
            p, r, f = self.__calc_prf(rtc, rfc, self.gs_cat[k])
            self.performance['category']['relax'][k]['precision'] = p
            self.performance['category']['relax'][k]['recall'] = r
            self.performance['category']['relax'][k]['f_score'] = f

    def __measure_counts(self):
        # gold standard
        self.counts['expect'] = dict()
        self.counts['expect']['overall'] = self.gs_all
        for k, v in self.gs_cat.items():
            self.counts['expect'][k] = v
        # prediction
        self.counts['prediction'] = {'strict': dict(), 'relax': dict()}
        # strict
        strict_true_counts, strict_false_counts = self.all_strict.get_true_false_counts()
        self.counts['prediction']['strict']['overall'] = dict()
        self.counts['prediction']['strict']['overall']['total'] = strict_true_counts + strict_false_counts
        self.counts['prediction']['strict']['overall']['true'] = strict_true_counts
        self.counts['prediction']['strict']['overall']['false'] = strict_false_counts
        for k, v in self.cat_strict.items():
            t, f = v.get_true_false_counts()
            self.counts['prediction']['strict'][k] = dict()
            self.counts['prediction']['strict'][k]['total'] = t + f
            self.counts['prediction']['strict'][k]['true'] = t
            self.counts['prediction']['strict'][k]['false'] = f
        # relax
        relax_true_counts, relax_false_counts = self.all_relax.get_true_false_counts()
        self.counts['prediction']['relax']['overall'] = dict()
        self.counts['prediction']['relax']['overall']['total'] = relax_true_counts + relax_false_counts
        self.counts['prediction']['relax']['overall']['true'] = relax_true_counts
        self.counts['prediction']['relax']['overall']['false'] = relax_false_counts
        for k, v in self.cat_relax.items():
            t, f = v.get_true_false_counts()
            self.counts['prediction']['relax'][k] = dict()
            self.counts['prediction']['relax'][k]['total'] = t + f
            self.counts['prediction']['relax'][k]['true'] = t
            self.counts['prediction']['relax'][k]['false'] = f

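    # Shape of the results built by the two methods above ('drug' is a hypothetical
    # category key; real keys come from the entity types in the data):
    #   self.performance = {'overall': {'acc': ..., 'strict': {...}, 'relax': {...}},
    #                       'category': {'strict': {'drug': {'precision', 'recall', 'f_score'}},
    #                                    'relax': {'drug': {...}}}}
    #   self.counts = {'expect': {'overall': ..., 'drug': ...},
    #                  'prediction': {'strict': {'overall': {'total', 'true', 'false'}, 'drug': {...}},
    #                                 'relax': {...}}}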
    @staticmethod
    def __strict_match(gs, pred, s_idx, e_idx, en_type):
        if e_idx < len(gs) and gs[e_idx] == f"i-{en_type}":
            # the gold token right after the predicted span still continues the entity, so the boundaries differ
            return False
        elif gs[s_idx] != f"b-{en_type}" or pred[s_idx] != f"b-{en_type}":
            # the first token of the span must be B- in both gold standard and prediction
            return False
        # every token inside the span must carry the same label
        for idx in range(s_idx, e_idx):
            if gs[idx] != pred[idx]:
                return False
        return True

    @staticmethod
    def __relax_match(gs, pred, s_idx, e_idx, en_type):
        # partial-match strategy, which is very loose compared with right-left or approximate matching
        for idx in range(s_idx, e_idx):
            gs_cate = gs[idx].split("-")[-1]
            pred_bound, pred_cate = pred[idx].split("-")
            if gs_cate == pred_cate == en_type:
                return True
        return False

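    # Toy example (hypothetical labels): for gold = ["b-drug", "i-drug", "o"] and
    # pred = ["b-drug", "o", "o"], the predicted span (s_idx=0, e_idx=1, en_type="drug")
    # fails __strict_match because the next gold token is still "i-drug", but passes
    # __relax_match because gold and prediction overlap on the same category.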
    @staticmethod
    def __check_evaluated_already(gs_dict, cate, start_idx, end_idx):
        for k, v in gs_dict.items():
            c, s, e = k
            if not (e < start_idx or s > end_idx) and c == cate:
                if v == 0:
                    return True
                else:
                    gs_dict[k] -= 1
                    return False
        return False

    def __process_bio(self, gs_bio, pred_bio):
        # measure acc
        for w_idx, (gs_word, pred_word) in enumerate(zip(gs_bio, pred_bio)):
            if gs_word == pred_word:
                self.acc.add_true_case()
            else:
                self.acc.add_false_case()

        # process gold standard
        llen = len(gs_bio)
        gs_dict = defaultdict(int)
        cur_idx = 0
        while cur_idx < llen:
            if gs_bio[cur_idx].strip() in self.label_not_for_eval:
                cur_idx += 1
            else:
                start_idx = cur_idx
                end_idx = start_idx + 1
                _, cate = gs_bio[start_idx].strip().split('-')
                while end_idx < llen and gs_bio[end_idx].strip() == f"i-{cate}":
                    end_idx += 1
                self.gs_all += 1
                self.gs_cat[cate] += 1
                gs_dict[(cate, start_idx, end_idx)] += 1
                cur_idx = end_idx
        # process predictions
        cur_idx = 0
        while cur_idx < llen:
            if pred_bio[cur_idx].strip() in self.label_not_for_eval:
                cur_idx += 1
            else:
                start_idx = cur_idx
                end_idx = start_idx + 1
                _, cate = pred_bio[start_idx].strip().split("-")
                while end_idx < llen and pred_bio[end_idx].strip() == f"i-{cate}":
                    end_idx += 1
                if self.__strict_match(gs_bio, pred_bio, start_idx, end_idx, cate):
                    self.all_strict.add_true_case()
                    self.cat_strict[cate].add_true_case()
                    self.all_relax.add_true_case()
                    self.cat_relax[cate].add_true_case()
                elif self.__relax_match(gs_bio, pred_bio, start_idx, end_idx, cate):
                    if self.__check_evaluated_already(gs_dict, cate, start_idx, end_idx):
                        cur_idx = end_idx
                        continue
                    self.all_strict.add_false_case()
                    self.cat_strict[cate].add_false_case()
                    self.all_relax.add_true_case()
                    self.cat_relax[cate].add_true_case()
                else:
                    self.all_strict.add_false_case()
                    self.cat_strict[cate].add_false_case()
                    self.all_relax.add_false_case()
                    self.cat_relax[cate].add_false_case()
                cur_idx = end_idx

    def eval_file(self, gs_file, pred_file):
        print("processing gold standard file: {} and prediction file: {}".format(gs_file, pred_file))
        pred_bio_sents = load_bio_file_into_sents(pred_file, do_lower=True)
        gs_bio_sents = load_bio_file_into_sents(gs_file, do_lower=True)
        # check that the two files have the same number of sentences
        assert len(gs_bio_sents) == len(pred_bio_sents), \
            "gold standard and prediction have different dimension: gs: {}; pred: {}".format(len(gs_bio_sents), len(pred_bio_sents))
        # measure performance
        for s_idx, (gs_sent, pred_sent) in enumerate(zip(gs_bio_sents, pred_bio_sents)):
            # check that the two sentences have the same number of words
            assert len(gs_sent) == len(pred_sent), \
                "In {}th sentence, the word counts are different; gs: {}; pred: {}".format(s_idx, gs_sent, pred_sent)
            # keep only the label column (last column of each word line)
            gs_sent = list(map(lambda x: x[-1], gs_sent))
            pred_sent = list(map(lambda x: x[-1], pred_sent))
            self.__process_bio(gs_sent, pred_sent)
        # compute the evaluation metrics
        self.__measure_performance()
        self.__measure_counts()

    def eval_mem(self, gs, pred, do_flat=False):
        # flatten sentences into one sequence; we assume the inputs are label sequences only
        if do_flat:
            print('Sentences have been flattened to 1 dim.')
            gs = list(chain(*gs))
            pred = list(chain(*pred))
            gs = list(map(lambda x: x.lower(), gs))
            pred = list(map(lambda x: x.lower(), pred))
            self.__process_bio(gs, pred)
        else:
            for sidx, (gs_s, pred_s) in enumerate(zip(gs, pred)):
                gs_s = list(map(lambda x: x.lower(), gs_s))
                pred_s = list(map(lambda x: x.lower(), pred_s))
                self.__process_bio(gs_s, pred_s)

        self.__measure_performance()
        self.__measure_counts()

    def evaluate_annotations(self, gs, pred, do_lower=False):
        for gs_sent, pred_sent in zip(gs, pred):
            if do_lower:
                gs_sent = list(map(lambda x: x.lower(), gs_sent))
                pred_sent = list(map(lambda x: x.lower(), pred_sent))
            self.__process_bio(gs_sent, pred_sent)

        self.__measure_performance()
        self.__measure_counts()

    def get_performance(self):
        return self.performance

    def get_counts(self):
        return self.counts

    def save_evaluation(self, file):
        with open(file, "w") as f:
            json.dump(self.performance, f)

    def show_evaluation(self, digits=4):
        if len(self.performance) == 0:
            raise RuntimeError('call eval_mem() first to get the performance attribute')

        cate = self.performance['category']['strict'].keys()

        headers = ['precision', 'recall', 'f1']
        width = max(max([len(c) for c in cate]), len('overall'), digits)
        head_fmt = '{:>{width}s} ' + ' {:>9}' * len(headers)

        report = head_fmt.format(u'', *headers, width=width)
        report += '\n\nstrict\n'

        row_fmt = '{:>{width}s} ' + ' {:>9.{digits}f}' * 3 + '\n'
        for c in cate:
            precision = self.performance['category']['strict'][c]['precision']
            recall = self.performance['category']['strict'][c]['recall']
            f1 = self.performance['category']['strict'][c]['f_score']
            report += row_fmt.format(c, *[precision, recall, f1], width=width, digits=digits)

        report += '\nrelax\n'

        for c in cate:
            precision = self.performance['category']['relax'][c]['precision']
            recall = self.performance['category']['relax'][c]['recall']
            f1 = self.performance['category']['relax'][c]['f_score']
            report += row_fmt.format(c, *[precision, recall, f1], width=width, digits=digits)

        report += '\n\noverall\n'
        report += 'acc: ' + str(self.performance['overall']['acc'])
        report += '\nstrict\n'
        report += row_fmt.format('', *[self.performance['overall']['strict']['precision'],
                                       self.performance['overall']['strict']['recall'],
                                       self.performance['overall']['strict']['f_score']], width=width, digits=digits)

        report += '\nrelax\n'
        report += row_fmt.format('', *[self.performance['overall']['relax']['precision'],
                                       self.performance['overall']['relax']['recall'],
                                       self.performance['overall']['relax']['f_score']], width=width, digits=digits)
        return report
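

# Minimal usage sketch (illustrative only; the label names and sequences below are
# made up and are not part of the original module):
if __name__ == "__main__":
    gold = [["B-Drug", "I-Drug", "O", "B-ADE"], ["O", "B-Drug"]]
    pred = [["B-Drug", "I-Drug", "O", "O"], ["O", "B-Drug"]]
    bio_eval = BioEval()
    bio_eval.eval_mem(gold, pred)
    print(bio_eval.show_evaluation())
    # bio_eval.eval_file("gold.bio", "pred.bio") would run the same evaluation on
    # BIO-formatted files (hypothetical file names).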