--- a
+++ b/Roberta+LLM/eval_file.py
@@ -0,0 +1,418 @@
+# -*- coding: utf-8 -*-
+# from eval_file import *
+
+import argparse
+import json
+import pickle as pkl
+from collections import defaultdict
+from itertools import chain
+from math import pow
+from pathlib import Path
+
+# from common_utils.common_io import load_bio_file_into_sents
+# from common_utils.common_log import create_logger
+
+
+def read_from_file(ifn):
+    with open(ifn, "r") as f:
+        text = f.read()
+    return text
+
+
+def write_to_file(text, ofn):
+    with open(ofn, "w") as f:
+        f.write(text)
+    return True
+
+
+def pkl_load(ifn):
+    with open(ifn, "rb") as f:
+        pdata = pkl.load(f)
+    return pdata
+
+
+def pkl_dump(pdata, ofn):
+    with open(ofn, "wb") as f:
+        pkl.dump(pdata, f)
+    return True
+
+
+def json_load(ifn):
+    with open(ifn, "r") as f:
+        jdata = json.load(f)
+    return jdata
+
+
+def json_dump(jdata, ofn):
+    with open(ofn, "w") as f:
+        json.dump(jdata, f)
+    return True
+
+
+def load_bio_file_into_sents(bio_file, word_sep=" ", do_lower=False):
+    bio_text = read_from_file(bio_file)
+    bio_text = bio_text.strip()
+    if do_lower:
+        bio_text = bio_text.lower()
+
+    new_sents = []
+    sents = bio_text.split("\n\n")
+
+    for sent in sents:
+        new_sent = []
+        words = sent.split("\n")
+        for word in words:
+            new_word = word.split(word_sep)
+            new_sent.append(new_word)
+        new_sents.append(new_sent)
+
+    return new_sents
+
+
+def output_bio(bio_data, output_file, sep=" "):
+    with open(output_file, "w") as f:
+        for sent in bio_data:
+            for word in sent:
+                line = sep.join(word)
+                f.write(line)
+                f.write("\n")
+            f.write("\n")
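+
+
+# Illustrative note (not part of the original code): load_bio_file_into_sents()
+# expects one token per line with space-separated columns (word_sep defaults to a
+# single space), a blank line between sentences, and the BIO label in the last
+# column (eval_file() reads it via x[-1]). The entity types below are placeholders:
+#
+#   Aspirin B-DRUG
+#   81 B-STRENGTH
+#   mg I-STRENGTH
+#
+#   Take O
+#   daily O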
+
+
+class PRF:
+    def __init__(self):
+        self.true = 0
+        self.false = 0
+
+    def add_true_case(self):
+        self.true += 1
+
+    def add_false_case(self):
+        self.false += 1
+
+    def get_true_false_counts(self):
+        return self.true, self.false
+
+    def __str__(self):
+        return str(self.__dict__)
+
+
+class BioEval:
+    def __init__(self):
+        self.acc = PRF()
+        # prediction
+        self.all_strict = PRF()
+        self.all_relax = PRF()
+        self.cat_strict = defaultdict(PRF)
+        self.cat_relax = defaultdict(PRF)
+        # gold standard
+        self.gs_all = 0
+        self.gs_cat = defaultdict(int)
+        self.performance = dict()
+        self.counts = dict()
+        self.beta = 1
+        self.label_not_for_eval = {'o'}
+
+    def reset(self):
+        self.acc = PRF()
+        self.all_strict = PRF()
+        self.all_relax = PRF()
+        self.cat_strict = defaultdict(PRF)
+        self.cat_relax = defaultdict(PRF)
+        self.gs_all = 0
+        self.gs_cat = defaultdict(int)
+        self.performance = dict()
+        self.counts = dict()
+
+    def set_beta_for_f_score(self, beta):
+        print("Using beta={} for calculating F-score".format(beta))
+        self.beta = beta
+
+    # def set_logger(self, logger):
+    #     self.logger = logger
+
+    def add_labels_not_for_eval(self, *labels):
+        for each in labels:
+            self.label_not_for_eval.add(each.lower())
+
+    def __calc_prf(self, tp, fp, tp_tn):
+        """
+        Calculate the F-beta score: beta=1 gives the F1 score, beta=2 favors recall, and beta=0.5 favors precision.
+        Use set_beta_for_f_score() to change the beta value.
+        """
+        # tp_tn is the number of gold-standard entities (tp + fn), used as the recall denominator
+        tp_fp = tp + fp
+        pre = 1.0 * tp / tp_fp if tp_fp > 0 else 0.0
+        rec = 1.0 * tp / tp_tn if tp_tn > 0 else 0.0
+        beta2 = pow(self.beta, 2)
+        f_beta = (1 + beta2) * pre * rec / (beta2 * pre + rec) if (pre + rec) > 0 else 0.0
+        return pre, rec, f_beta
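+
+    # Worked example (illustrative): with tp=6, fp=2 and 10 gold entities,
+    # precision = 6/8 = 0.75 and recall = 6/10 = 0.60, so F1 ~= 0.667,
+    # F2 ~= 0.625 (favors recall) and F0.5 ~= 0.714 (favors precision).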
+
+    def __measure_performance(self):
+        self.performance['overall'] = dict()
+
+        acc_true_num, acc_false_num = self.acc.get_true_false_counts()
+        total_acc_num = acc_true_num + acc_false_num
+        # calc acc
+        overall_acc = round(1.0 * acc_true_num / total_acc_num, 4) if total_acc_num > 0 else 0.0
+        self.performance['overall']['acc'] = overall_acc
+
+        strict_true_counts, strict_false_counts = self.all_strict.get_true_false_counts()
+        strict_pre, strict_rec, strict_f_score = self.__calc_prf(strict_true_counts, strict_false_counts, self.gs_all)
+        self.performance['overall']['strict'] = dict()
+        self.performance['overall']['strict']['precision'] = strict_pre
+        self.performance['overall']['strict']['recall'] = strict_rec
+        self.performance['overall']['strict']['f_score'] = strict_f_score
+
+        relax_true_counts, relax_false_counts = self.all_relax.get_true_false_counts()
+        relax_pre, relax_rec, relax_f_score = self.__calc_prf(relax_true_counts, relax_false_counts, self.gs_all)
+        self.performance['overall']['relax'] = dict()
+        self.performance['overall']['relax']['precision'] = relax_pre
+        self.performance['overall']['relax']['recall'] = relax_rec
+        self.performance['overall']['relax']['f_score'] = relax_f_score
+
+        self.performance['category'] = dict()
+        self.performance['category']['strict'] = dict()
+        for k, v in self.cat_strict.items():
+            self.performance['category']['strict'][k] = dict()
+            stc, sfc = v.get_true_false_counts()
+            p, r, f = self.__calc_prf(stc, sfc, self.gs_cat[k])
+            self.performance['category']['strict'][k]['precision'] = p
+            self.performance['category']['strict'][k]['recall'] = r
+            self.performance['category']['strict'][k]['f_score'] = f
+
+        self.performance['category']['relax'] = dict()
+        for k, v in self.cat_relax.items():
+            self.performance['category']['relax'][k] = dict()
+            rtc, rfc = v.get_true_false_counts()
+            p, r, f = self.__calc_prf(rtc, rfc, self.gs_cat[k])
+            self.performance['category']['relax'][k]['precision'] = p
+            self.performance['category']['relax'][k]['recall'] = r
+            self.performance['category']['relax'][k]['f_score'] = f
+
+    def __measure_counts(self):
+        # gold standard
+        self.counts['expect'] = dict()
+        self.counts['expect']['overall'] = self.gs_all
+        for k, v in self.gs_cat.items():
+            self.counts['expect'][k] = v
+        # prediction
+        self.counts['prediction'] = {'strict': dict(), 'relax': dict()}
+        # strict
+        strict_true_counts, strict_false_counts = self.all_strict.get_true_false_counts()
+        self.counts['prediction']['strict']['overall'] = dict()
+        self.counts['prediction']['strict']['overall']['total'] = strict_true_counts + strict_false_counts
+        self.counts['prediction']['strict']['overall']['true'] = strict_true_counts
+        self.counts['prediction']['strict']['overall']['false'] = strict_false_counts
+        for k, v in self.cat_strict.items():
+            t, f = v.get_true_false_counts()
+            self.counts['prediction']['strict'][k] = dict()
+            self.counts['prediction']['strict'][k]['total'] = t + f
+            self.counts['prediction']['strict'][k]['true'] = t
+            self.counts['prediction']['strict'][k]['false'] = f
+        # relax
+        relax_true_counts, relax_false_counts = self.all_relax.get_true_false_counts()
+        self.counts['prediction']['relax']['overall'] = dict()
+        self.counts['prediction']['relax']['overall']['total'] = relax_true_counts + relax_false_counts
+        self.counts['prediction']['relax']['overall']['true'] = relax_true_counts
+        self.counts['prediction']['relax']['overall']['false'] = relax_false_counts
+        for k, v in self.cat_relax.items():
+            t, f = v.get_true_false_counts()
+            self.counts['prediction']['relax'][k] = dict()
+            self.counts['prediction']['relax'][k]['total'] = t + f
+            self.counts['prediction']['relax'][k]['true'] = t
+            self.counts['prediction']['relax'][k]['false'] = f
+
+    @staticmethod
+    def __strict_match(gs, pred, s_idx, e_idx, en_type):
+        if e_idx < len(gs) and gs[e_idx] == f"i-{en_type}":
+            # the gold token right after the span must not continue the same entity
+            return False
+        elif gs[s_idx] != f"b-{en_type}" or pred[s_idx] != f"b-{en_type}":
+            # force the first token to be B- in both gold and prediction
+            return False
+        # every token in the span must be identical
+        for idx in range(s_idx, e_idx):
+            if gs[idx] != pred[idx]:
+                return False
+        return True
+
+    @staticmethod
+    def __relax_match(gs, pred, s_idx, e_idx, en_type):
+        # we adopt the partial-match strategy, which is much looser than right-left or approximate matching
+        for idx in range(s_idx, e_idx):
+            gs_cate = gs[idx].split("-")[-1]
+            pred_cate = pred[idx].split("-")[-1]
+            if gs_cate == pred_cate == en_type:
+                return True
+        return False
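+
+    # Illustrative example (entity types are placeholders): for gold "b-drug i-drug o"
+    # and prediction "o b-drug i-drug", the predicted span fails the strict check
+    # (the gold label at its first token is not "b-drug"), but passes the relax check
+    # because an overlapping token shares the entity type "drug".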
+
+    @staticmethod
+    def __check_evaluated_already(gs_dict, cate, start_idx, end_idx):
+        for k, v in gs_dict.items():
+            c, s, e = k
+            if not (e < start_idx or s > end_idx) and c == cate:
+                if v == 0:
+                    return True
+                else:
+                    gs_dict[k] -= 1
+                    return False
+        return False
+
+    def __process_bio(self, gs_bio, pred_bio):
+        # measure token-level accuracy
+        for w_idx, (gs_word, pred_word) in enumerate(zip(gs_bio, pred_bio)):
+            if gs_word == pred_word:
+                self.acc.add_true_case()
+            else:
+                self.acc.add_false_case()
+
+        # process gold standard
+        llen = len(gs_bio)
+        gs_dict = defaultdict(int)
+        cur_idx = 0
+        while cur_idx < llen:
+            if gs_bio[cur_idx].strip() in self.label_not_for_eval:
+                cur_idx += 1
+            else:
+                start_idx = cur_idx
+                end_idx = start_idx + 1
+                _, cate = gs_bio[start_idx].strip().split('-')
+                while end_idx < llen and gs_bio[end_idx].strip() == f"i-{cate}":
+                    end_idx += 1
+                self.gs_all += 1
+                self.gs_cat[cate] += 1
+                gs_dict[(cate, start_idx, end_idx)] += 1
+                cur_idx = end_idx
+        # process predictions
+        cur_idx = 0
+        while cur_idx < llen:
+            if pred_bio[cur_idx].strip() in self.label_not_for_eval:
+                cur_idx += 1
+            else:
+                start_idx = cur_idx
+                end_idx = start_idx + 1
+                _, cate = pred_bio[start_idx].strip().split("-")
+                while end_idx < llen and pred_bio[end_idx].strip() == f"i-{cate}":
+                    end_idx += 1
+                if self.__strict_match(gs_bio, pred_bio, start_idx, end_idx, cate):
+                    self.all_strict.add_true_case()
+                    self.cat_strict[cate].add_true_case()
+                    self.all_relax.add_true_case()
+                    self.cat_relax[cate].add_true_case()
+                elif self.__relax_match(gs_bio, pred_bio, start_idx, end_idx, cate):
+                    if self.__check_evaluated_already(gs_dict, cate, start_idx, end_idx):
+                        cur_idx = end_idx
+                        continue
+                    self.all_strict.add_false_case()
+                    self.cat_strict[cate].add_false_case()
+                    self.all_relax.add_true_case()
+                    self.cat_relax[cate].add_true_case()
+                else:
+                    self.all_strict.add_false_case()
+                    self.cat_strict[cate].add_false_case()
+                    self.all_relax.add_false_case()
+                    self.cat_relax[cate].add_false_case()
+                cur_idx = end_idx
+
+    def eval_file(self, gs_file, pred_file):
+        print("processing gold standard file: {} and prediction file: {}".format(gs_file, pred_file))
+        pred_bio_sents = load_bio_file_into_sents(pred_file, do_lower=True)
+        gs_bio_sents = load_bio_file_into_sents(gs_file, do_lower=True)
+        # check that the two files have the same number of sentences
+        assert len(gs_bio_sents) == len(pred_bio_sents), \
+            "gold standard and prediction have different dimensions: gs: {}; pred: {}".format(len(gs_bio_sents), len(pred_bio_sents))
+        # measure performance
+        for s_idx, (gs_sent, pred_sent) in enumerate(zip(gs_bio_sents, pred_bio_sents)):
+            # check that the two sentences have the same number of words
+            assert len(gs_sent) == len(pred_sent), \
+                "In the {}th sentence, the word counts differ; gs: {}; pred: {}".format(s_idx, gs_sent, pred_sent)
+            gs_sent = list(map(lambda x: x[-1], gs_sent))
+            pred_sent = list(map(lambda x: x[-1], pred_sent))
+            self.__process_bio(gs_sent, pred_sent)
+        # compute the evaluation metrics
+        self.__measure_performance()
+        self.__measure_counts()
+
+    def eval_mem(self, gs, pred, do_flat=False):
+        # optionally flatten sentences into one sequence; flattened input is assumed to contain labels only
+        if do_flat:
+            print('Sentences have been flattened to 1 dim.')
+            gs = list(chain(*gs))
+            pred = list(chain(*pred))
+            gs = list(map(lambda x: x.lower(), gs))
+            pred = list(map(lambda x: x.lower(), pred))
+            self.__process_bio(gs, pred)
+        else:
+            for sidx, (gs_s, pred_s) in enumerate(zip(gs, pred)):
+                gs_s = list(map(lambda x: x.lower(), gs_s))
+                pred_s = list(map(lambda x: x.lower(), pred_s))
+                self.__process_bio(gs_s, pred_s)
+
+        self.__measure_performance()
+        self.__measure_counts()
+
+    def evaluate_annotations(self, gs, pred, do_lower=False):
+        for gs_sent, pred_sent in zip(gs, pred):
+            if do_lower:
+                gs_sent = list(map(lambda x: x.lower(), gs_sent))
+                pred_sent = list(map(lambda x: x.lower(), pred_sent))
+            self.__process_bio(gs_sent, pred_sent)
+
+        self.__measure_performance()
+        self.__measure_counts()
+
+    def get_performance(self):
+        return self.performance
+
+    def get_counts(self):
+        return self.counts
+
+    def save_evaluation(self, file):
+        with open(file, "w") as f:
+            json.dump(self.performance, f)
+
+    def show_evaluation(self, digits=4):
+        if len(self.performance) == 0:
+            raise RuntimeError('call eval_file() or eval_mem() first to populate the performance attribute')
+
+        cate = self.performance['category']['strict'].keys()
+
+        headers = ['precision', 'recall', 'f1']
+        width = max(max([len(c) for c in cate]), len('overall'), digits)
+        head_fmt = '{:>{width}s} ' + ' {:>9}' * len(headers)
+
+        report = head_fmt.format(u'', *headers, width=width)
+        report += '\n\nstrict\n'
+
+        row_fmt = '{:>{width}s} ' + ' {:>9.{digits}f}' * 3 + '\n'
+        for c in cate:
+            precision = self.performance['category']['strict'][c]['precision']
+            recall = self.performance['category']['strict'][c]['recall']
+            f1 = self.performance['category']['strict'][c]['f_score']
+            report += row_fmt.format(c, *[precision, recall, f1], width=width, digits=digits)
+
+        report += '\nrelax\n'
+
+        for c in cate:
+            precision = self.performance['category']['relax'][c]['precision']
+            recall = self.performance['category']['relax'][c]['recall']
+            f1 = self.performance['category']['relax'][c]['f_score']
+            report += row_fmt.format(c, *[precision, recall, f1], width=width, digits=digits)
+
+        report += '\n\noverall\n'
+        report += 'acc: ' + str(self.performance['overall']['acc'])
+        report += '\nstrict\n'
+        report += row_fmt.format('', *[self.performance['overall']['strict']['precision'],
+                                       self.performance['overall']['strict']['recall'],
+                                       self.performance['overall']['strict']['f_score']], width=width, digits=digits)
+
+        report += '\nrelax\n'
+        report += row_fmt.format('', *[self.performance['overall']['relax']['precision'],
+                                       self.performance['overall']['relax']['recall'],
+                                       self.performance['overall']['relax']['f_score']], width=width, digits=digits)
+        return report
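+
+
+# Minimal CLI sketch (an addition, not part of the original module): argparse is
+# imported above but otherwise unused; the flag names below are illustrative
+# assumptions, while the BioEval calls match the API defined in this file.
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Evaluate BIO predictions against a gold standard file.")
+    parser.add_argument("--gs_file", required=True, help="gold standard BIO file")
+    parser.add_argument("--pred_file", required=True, help="prediction BIO file")
+    parser.add_argument("--beta", type=float, default=1.0, help="beta for the F-beta score")
+    parser.add_argument("--output", default=None, help="optional JSON file for the performance report")
+    args = parser.parse_args()
+
+    bio_eval = BioEval()
+    bio_eval.set_beta_for_f_score(args.beta)
+    bio_eval.eval_file(args.gs_file, args.pred_file)
+    print(bio_eval.show_evaluation())
+    if args.output:
+        bio_eval.save_evaluation(args.output)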