"""Score agent run logs against a reference answer book.

Each log file in ``logs_path`` is matched to an answer by the part of its
file name before the first ``.``.  A log that never contains ``TERMINATE``
is counted as unfinished; otherwise the text between the last answer marker
and the last ``TERMINATE`` is compared with the reference answer by
``judge``.
"""
import os
import json

# Placeholders: set these before running the script.
logs_path = "<YOUR_LOGS_PATH>"
answer_book = "<YOUR_DATASET_PATH>"


def judge(pred, ans):
    """Return True if the prediction text ``pred`` contains the answer ``ans``.

    Two checks are made and either one is sufficient:

    * the raw ``ans`` string occurs verbatim in the raw ``pred``;
    * after normalization, every expected item occurs in ``pred``.

    Normalization rewrites ``True`` to ``1`` in ``pred`` (or, when ``pred``
    has no ``True``, rewrites ``False`` to ``0``), maps the answers
    ``True``/``true`` to ``1`` and ``False``/``false``/``None``/``none`` to
    ``0``, splits a ``", "``-separated answer into items, and drops a
    trailing ``.0`` from a single-item answer.

    Args:
        pred: Text extracted from the log.
        ans: Reference answer as a string.

    Returns:
        bool: whether the prediction is accepted.
    """
    # Verbatim check is done on the inputs before any normalization.
    exact_match = ans in pred

    # Only one of the two boolean spellings is rewritten, as in the original.
    if "True" in pred:
        pred = pred.replace("True", "1")
    else:
        pred = pred.replace("False", "0")

    if ans in ("False", "false"):
        ans = "0"
    if ans in ("True", "true"):
        ans = "1"
    if ans in ("None", "none"):
        ans = "0"

    if ", " in ans:
        expected = ans.split(", ")
    else:
        # The ".0" suffix is only stripped from a single (unsplit) answer.
        if ans.endswith(".0"):
            ans = ans[:-2]
        expected = [ans]

    return exact_match or all(item in pred for item in expected)


def extract_prediction(logs):
    """Return the slice of ``logs`` that is handed to ``judge``.

    The slice starts at the last ``"\\n}`` when the log contains a
    ``"cell": "`` entry, otherwise at the last ``Solution:``, and ends just
    before the last ``TERMINATE``.

    Args:
        logs: Full text of one log file.

    Returns:
        str: the prediction text, or ``""`` when a marker is missing or the
        start marker lies after the last ``TERMINATE``.
    """
    if '"cell": "' in logs:
        start = logs.rfind('"\n}')
    else:
        start = logs.rfind('Solution:')
    end = logs.rfind('TERMINATE')
    # rfind returns -1 for a missing marker.  A start of -1 always produced
    # an empty slice in the original code; make that outcome explicit.
    if start == -1 or end == -1:
        return ""
    return logs[start:end]


def evaluate(logs_dir, answers_file):
    """Score every log in ``logs_dir`` that has an entry in ``answers_file``.

    Args:
        logs_dir: Directory containing the log files.
        answers_file: Path to a JSON list of objects with ``id`` and
            ``answer`` keys.

    Returns:
        dict: counters ``total_num``, ``correct``, ``unfinished`` and
        ``incorrect``.
    """
    files = os.listdir(logs_dir)

    with open(answers_file, 'r') as f:
        contents = json.load(f)
    answers = {entry['id']: entry['answer'] for entry in contents}

    stats = {"total_num": 0, "correct": 0, "unfinished": 0, "incorrect": 0}

    for file in files:
        task_id = file.split('.')[0]
        if task_id not in answers:
            continue
        # os.path.join works whether or not logs_dir ends with a separator.
        with open(os.path.join(logs_dir, file), 'r') as f:
            logs = f.read()

        answer = answers[task_id]
        if isinstance(answer, list):
            answer = ', '.join(answer)

        stats["total_num"] += 1
        if "TERMINATE" not in logs:
            stats["unfinished"] += 1
        elif judge(extract_prediction(logs), answer):
            stats["correct"] += 1
        else:
            stats["incorrect"] += 1

    return stats


def main():
    """Evaluate the configured paths and print the resulting counters."""
    print(evaluate(logs_path, answer_book))


if __name__ == "__main__":
    main()