# -*- coding: utf-8 -*-
import argparse
import json
import pickle as pkl
from collections import defaultdict
from itertools import chain
from math import pow
from pathlib import Path

# Originally imported from shared utilities; the I/O helpers are inlined below.
# from common_utils.common_io import load_bio_file_into_sents
# from common_utils.common_log import create_logger
def read_from_file(ifn):
with open(ifn, "r") as f:
text = f.read()
return text
def write_to_file(text, ofn):
with open(ofn, "w") as f:
f.write(text)
return True
def pkl_load(ifn):
with open(ifn, "rb") as f:
pdata = pkl.load(f)
return pdata
def pkl_dump(pdata, ofn):
with open(ofn, "wb") as f:
pkl.dump(pdata, f)
return True
def json_load(ifn):
with open(ifn, "r") as f:
jdata = json.load(f)
return jdata
def json_dump(jdata, ofn):
with open(ofn, "w") as f:
json.dump(jdata, f)
return True
def load_bio_file_into_sents(bio_file, word_sep=" ", do_lower=False):
    """Load a BIO file into sentences, each a list of [token, ..., tag] rows.

    Sentences are separated by a blank line; within a sentence each line holds
    one token and its tag separated by ``word_sep``.
    """
    bio_text = read_from_file(bio_file)
bio_text = bio_text.strip()
if do_lower:
bio_text = bio_text.lower()
new_sents = []
sents = bio_text.split("\n\n")
for sent in sents:
new_sent = []
words = sent.split("\n")
for word in words:
new_word = word.split(word_sep)
new_sent.append(new_word)
new_sents.append(new_sent)
return new_sents
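# Illustrative example (layout inferred from the parsing above; tokens and labels are assumptions):
# given a file whose lines are "aspirin B-DRUG", "81 I-DRUG", a blank line,
# "no O", "rash B-PROBLEM", the function returns
#     [[['aspirin', 'B-DRUG'], ['81', 'I-DRUG']],
#      [['no', 'O'], ['rash', 'B-PROBLEM']]]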
def output_bio(bio_data, output_file, sep=" "):
with open(output_file, "w") as f:
for sent in bio_data:
for word in sent:
line = sep.join(word)
f.write(line)
f.write("\n")
f.write("\n")
class PRF:
    """Simple true/false counter used to compute precision, recall, and F-score."""
    def __init__(self):
self.true = 0
self.false = 0
def add_true_case(self):
self.true += 1
def add_false_case(self):
self.false += 1
def get_true_false_counts(self):
return self.true, self.false
def __str__(self):
return str(self.__dict__)
class BioEval:
    """Entity-level evaluator for BIO-tagged sequences with strict and relaxed matching."""
    def __init__(self):
self.acc = PRF()
# prediction
self.all_strict = PRF()
self.all_relax = PRF()
self.cat_strict = defaultdict(PRF)
self.cat_relax = defaultdict(PRF)
# gold standard
self.gs_all = 0
self.gs_cat = defaultdict(int)
self.performance = dict()
self.counts = dict()
self.beta = 1
self.label_not_for_eval = {'o'}
def reset(self):
self.acc = PRF()
self.all_strict = PRF()
self.all_relax = PRF()
self.cat_strict = defaultdict(PRF)
self.cat_relax = defaultdict(PRF)
self.gs_all = 0
self.gs_cat = defaultdict(int)
self.performance = dict()
self.counts = dict()
def set_beta_for_f_score(self, beta):
print("Using beta={} for calculating F-score".format(beta))
self.beta = beta
# def set_logger(self, logger):
# self.logger = logger
def add_labels_not_for_eval(self, *labels):
for each in labels:
self.label_not_for_eval.add(each.lower())
    def __calc_prf(self, tp, fp, gs_total):
        """
        Compute precision, recall, and F-beta score. ``gs_total`` is the number of
        gold-standard entities (tp + fn). beta=1 gives the standard F1; beta=2 favors
        recall and beta=0.5 favors precision. Use set_beta_for_f_score() to change beta.
        """
        tp_fp = tp + fp
        pre = 1.0 * tp / tp_fp if tp_fp > 0 else 0.0
        rec = 1.0 * tp / gs_total if gs_total > 0 else 0.0
        beta2 = pow(self.beta, 2)
        f_beta = (1 + beta2) * pre * rec / (beta2 * pre + rec) if (pre + rec) > 0 else 0.0
        return pre, rec, f_beta
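    # Worked example (illustrative numbers): with tp=8, fp=2 and 16 gold entities,
    # precision = 8/10 = 0.8 and recall = 8/16 = 0.5, so F1 ~= 0.615, while
    # beta=2 (recall-weighted) gives (1+4)*0.8*0.5 / (4*0.8 + 0.5) ~= 0.541.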
def __measure_performance(self):
self.performance['overall'] = dict()
acc_true_num, acc_false_num = self.acc.get_true_false_counts()
total_acc_num = acc_true_num + acc_false_num
# calc acc
overall_acc = round(1.0 * acc_true_num / total_acc_num, 4) if total_acc_num > 0 else 0.0
self.performance['overall']['acc'] = overall_acc
strict_true_counts, strict_false_counts = self.all_strict.get_true_false_counts()
strict_pre, strict_rec, strict_f_score = self.__calc_prf(strict_true_counts, strict_false_counts, self.gs_all)
self.performance['overall']['strict'] = dict()
self.performance['overall']['strict']['precision'] = strict_pre
self.performance['overall']['strict']['recall'] = strict_rec
self.performance['overall']['strict']['f_score'] = strict_f_score
relax_true_counts, relax_false_counts = self.all_relax.get_true_false_counts()
relax_pre, relax_rec, relax_f_score = self.__calc_prf(relax_true_counts, relax_false_counts, self.gs_all)
self.performance['overall']['relax'] = dict()
self.performance['overall']['relax']['precision'] = relax_pre
self.performance['overall']['relax']['recall'] = relax_rec
self.performance['overall']['relax']['f_score'] = relax_f_score
self.performance['category'] = dict()
self.performance['category']['strict'] = dict()
for k, v in self.cat_strict.items():
self.performance['category']['strict'][k] = dict()
stc, sfc = v.get_true_false_counts()
p, r, f = self.__calc_prf(stc, sfc, self.gs_cat[k])
self.performance['category']['strict'][k]['precision'] = p
self.performance['category']['strict'][k]['recall'] = r
self.performance['category']['strict'][k]['f_score'] = f
self.performance['category']['relax'] = dict()
for k, v in self.cat_relax.items():
self.performance['category']['relax'][k] = dict()
rtc, rfc = v.get_true_false_counts()
p, r, f = self.__calc_prf(rtc, rfc, self.gs_cat[k])
self.performance['category']['relax'][k]['precision'] = p
self.performance['category']['relax'][k]['recall'] = r
self.performance['category']['relax'][k]['f_score'] = f
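    # Resulting layout of self.performance (shape only; values are floats):
    #   {'overall': {'acc': ..., 'strict': {'precision', 'recall', 'f_score'},
    #                'relax': {'precision', 'recall', 'f_score'}},
    #    'category': {'strict': {<category>: {...}}, 'relax': {<category>: {...}}}}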
def __measure_counts(self):
# gold standard
self.counts['expect'] = dict()
self.counts['expect']['overall'] = self.gs_all
for k, v in self.gs_cat.items():
self.counts['expect'][k] = v
# prediction
self.counts['prediction'] = {'strict': dict(), 'relax': dict()}
# strict
strict_true_counts, strict_false_counts = self.all_strict.get_true_false_counts()
self.counts['prediction']['strict']['overall'] = dict()
self.counts['prediction']['strict']['overall']['total'] = strict_true_counts + strict_false_counts
self.counts['prediction']['strict']['overall']['true'] = strict_true_counts
self.counts['prediction']['strict']['overall']['false'] = strict_false_counts
for k, v in self.cat_strict.items():
t, f = v.get_true_false_counts()
self.counts['prediction']['strict'][k] = dict()
self.counts['prediction']['strict'][k]['total'] = t + f
self.counts['prediction']['strict'][k]['true'] = t
self.counts['prediction']['strict'][k]['false'] = f
# relax
relax_true_counts, relax_false_counts = self.all_relax.get_true_false_counts()
self.counts['prediction']['relax']['overall'] = dict()
self.counts['prediction']['relax']['overall']['total'] = relax_true_counts + relax_false_counts
self.counts['prediction']['relax']['overall']['true'] = relax_true_counts
self.counts['prediction']['relax']['overall']['false'] = relax_false_counts
for k, v in self.cat_relax.items():
t, f = v.get_true_false_counts()
self.counts['prediction']['relax'][k] = dict()
self.counts['prediction']['relax'][k]['total'] = t + f
self.counts['prediction']['relax'][k]['true'] = t
self.counts['prediction']['relax'][k]['false'] = f
@staticmethod
def __strict_match(gs, pred, s_idx, e_idx, en_type):
        if e_idx < len(gs) and gs[e_idx] == f"i-{en_type}":
            # the gold entity continues past the predicted end, so the spans differ
            return False
        elif gs[s_idx] != f"b-{en_type}" or pred[s_idx] != f"b-{en_type}":
            # both spans must start with a B- tag of the entity type
            return False
        # every token inside the span must be identical
for idx in range(s_idx, e_idx):
if gs[idx] != pred[idx]:
return False
return True
@staticmethod
    def __relax_match(gs, pred, s_idx, e_idx, en_type):
        # partial-match strategy: any token overlap with the same category counts,
        # which is looser than right/left-boundary or approximate matching
        for idx in range(s_idx, e_idx):
            gs_cate = gs[idx].split("-", 1)[-1]
            pred_cate = pred[idx].split("-", 1)[-1]
            if gs_cate == pred_cate == en_type:
                return True
        return False
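    # Illustrative example: with gs = ['b-drug', 'i-drug', 'o'] and
    # pred = ['o', 'b-drug', 'o'], the predicted span (1, 2) overlaps a gold
    # "drug" token, so the relaxed match succeeds even though the strict one fails.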
@staticmethod
    def __check_evaluated_already(gs_dict, cate, start_idx, end_idx):
        # avoid crediting the same gold span more than once when several predicted
        # spans overlap it; gs_dict tracks how many credits each gold span has left
        for k, v in gs_dict.items():
c, s, e = k
if not (e < start_idx or s > end_idx) and c == cate:
if v == 0:
return True
else:
gs_dict[k] -= 1
return False
return False
    def __process_bio(self, gs_bio, pred_bio):
        # token-level accuracy
        for gs_word, pred_word in zip(gs_bio, pred_bio):
            if gs_word == pred_word:
                self.acc.add_true_case()
            else:
                self.acc.add_false_case()
# process gold standard
llen = len(gs_bio)
gs_dict = defaultdict(int)
cur_idx = 0
while cur_idx < llen:
if gs_bio[cur_idx].strip() in self.label_not_for_eval:
cur_idx += 1
else:
start_idx = cur_idx
end_idx = start_idx + 1
                _, cate = gs_bio[start_idx].strip().split('-', 1)
while end_idx < llen and gs_bio[end_idx].strip() == f"i-{cate}":
end_idx += 1
self.gs_all += 1
self.gs_cat[cate] += 1
gs_dict[(cate, start_idx, end_idx)] += 1
cur_idx = end_idx
# process predictions
cur_idx = 0
while cur_idx < llen:
if pred_bio[cur_idx].strip() in self.label_not_for_eval:
cur_idx += 1
else:
start_idx = cur_idx
end_idx = start_idx + 1
                _, cate = pred_bio[start_idx].strip().split("-", 1)
while end_idx < llen and pred_bio[end_idx].strip() == f"i-{cate}":
end_idx += 1
if self.__strict_match(gs_bio, pred_bio, start_idx, end_idx, cate):
self.all_strict.add_true_case()
self.cat_strict[cate].add_true_case()
self.all_relax.add_true_case()
self.cat_relax[cate].add_true_case()
elif self.__relax_match(gs_bio, pred_bio, start_idx, end_idx, cate):
if self.__check_evaluated_already(gs_dict, cate, start_idx, end_idx):
cur_idx = end_idx
continue
self.all_strict.add_false_case()
self.cat_strict[cate].add_false_case()
self.all_relax.add_true_case()
self.cat_relax[cate].add_true_case()
else:
self.all_strict.add_false_case()
self.cat_strict[cate].add_false_case()
self.all_relax.add_false_case()
self.cat_relax[cate].add_false_case()
cur_idx = end_idx
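    # Worked example (illustrative): with gs = ['b-drug', 'i-drug', 'o', 'b-problem']
    # and pred = ['b-drug', 'i-drug', 'o', 'o'], the gold standard contributes two
    # spans (drug, problem) and the prediction one span (drug) that matches strictly,
    # so strict and relaxed precision are 1/1 while recall is 1/2.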
def eval_file(self, gs_file, pred_file):
print("processing gold standard file: {} and prediciton file: {}".format(gs_file, pred_file))
pred_bio_sents = load_bio_file_into_sents(pred_file, do_lower=True)
gs_bio_sents = load_bio_file_into_sents(gs_file, do_lower=True)
# process bio data
# check two data have same amount of sents
assert len(gs_bio_sents) == len(pred_bio_sents), \
"gold standard and prediction have different dimension: gs: {}; pred: {}".format(len(gs_bio_sents), len(pred_bio_sents))
# measure performance
for s_idx, (gs_sent, pred_sent) in enumerate(zip(gs_bio_sents, pred_bio_sents)):
# check two sents have same No. of words
assert len(gs_sent) == len(pred_sent), \
"In {}th sentence, the words counts are different; gs: {}; pred: {}".format(s_idx, gs_sent, pred_sent)
gs_sent = list(map(lambda x: x[-1], gs_sent))
pred_sent = list(map(lambda x: x[-1], pred_sent))
self.__process_bio(gs_sent, pred_sent)
# get the evaluation matrix
self.__measure_performance()
self.__measure_counts()
    def eval_mem(self, gs, pred, do_flat=False):
        # optionally flatten nested sentences into a single 1-dimensional sequence of labels
        if do_flat:
            print('Sentences have been flattened to 1 dimension.')
            gs = list(chain(*gs))
            pred = list(chain(*pred))
            gs = list(map(lambda x: x.lower(), gs))
            pred = list(map(lambda x: x.lower(), pred))
            self.__process_bio(gs, pred)
        else:
            for gs_s, pred_s in zip(gs, pred):
                gs_s = list(map(lambda x: x.lower(), gs_s))
                pred_s = list(map(lambda x: x.lower(), pred_s))
                self.__process_bio(gs_s, pred_s)
self.__measure_performance()
self.__measure_counts()
def evaluate_annotations(self, gs, pred, do_lower=False):
for gs_sent, pred_sent in zip(gs, pred):
if do_lower:
gs_sent = list(map(lambda x: x.lower(), gs_sent))
pred_sent = list(map(lambda x: x.lower(), pred_sent))
self.__process_bio(gs_sent, pred_sent)
self.__measure_performance()
self.__measure_counts()
def get_performance(self):
return self.performance
def get_counts(self):
return self.counts
def save_evaluation(self, file):
with open(file, "w") as f:
json.dump(self.performance, f)
def show_evaluation(self, digits=4):
        if len(self.performance) == 0:
            raise RuntimeError('call eval_file() or eval_mem() first to populate the performance attribute')
        cate = self.performance['category']['strict'].keys()
        headers = ['precision', 'recall', 'f1']
        width = max(max([len(c) for c in cate]), len('overall'), digits)
        head_fmt = '{:>{width}s} ' + ' {:>9}' * len(headers)
        report = head_fmt.format('', *headers, width=width)
report += '\n\nstrict\n'
row_fmt = '{:>{width}s} ' + ' {:>9.{digits}f}' * 3 + '\n'
for c in cate:
precision = self.performance['category']['strict'][c]['precision']
recall = self.performance['category']['strict'][c]['recall']
f1 = self.performance['category']['strict'][c]['f_score']
report += row_fmt.format(c, *[precision, recall, f1], width=width, digits=digits)
report += '\nrelax\n'
for c in cate:
precision = self.performance['category']['relax'][c]['precision']
recall = self.performance['category']['relax'][c]['recall']
f1 = self.performance['category']['relax'][c]['f_score']
report += row_fmt.format(c, *[precision, recall, f1], width=width, digits=digits)
report += '\n\noverall\n'
report += 'acc: ' + str(self.performance['overall']['acc'])
report += '\nstrict\n'
report += row_fmt.format('', *[self.performance['overall']['strict']['precision'],
self.performance['overall']['strict']['recall'],
self.performance['overall']['strict']['f_score']], width=width, digits=digits)
report += '\nrelax\n'
report += row_fmt.format('', *[self.performance['overall']['relax']['precision'],
self.performance['overall']['relax']['recall'],
self.performance['overall']['relax']['f_score']], width=width, digits=digits)
return report