|
a |
|
b/ehragent/evaluate.py |
|
|
1 |
import os |
|
|
2 |
import json |
|
|
3 |
|
|
|
4 |
def judge(pred, ans): |
|
|
5 |
old_flag = True |
|
|
6 |
if not ans in pred: |
|
|
7 |
old_flag = False |
|
|
8 |
if "True" in pred: |
|
|
9 |
pred = pred.replace("True", "1") |
|
|
10 |
else: |
|
|
11 |
pred = pred.replace("False", "0") |
|
|
12 |
if ans == "False" or ans == "false": |
|
|
13 |
ans = "0" |
|
|
14 |
if ans == "True" or ans == "true": |
|
|
15 |
ans = "1" |
|
|
16 |
if ans == "None" or ans == "none": |
|
|
17 |
ans = "0" |
|
|
18 |
if ", " in ans: |
|
|
19 |
ans = ans.split(', ') |
|
|
20 |
if ans[-2:] == ".0": |
|
|
21 |
ans = ans[:-2] |
|
|
22 |
if not type(ans) == list: |
|
|
23 |
ans = [ans] |
|
|
24 |
new_flag = True |
|
|
25 |
for i in range(len(ans)): |
|
|
26 |
if not ans[i] in pred: |
|
|
27 |
new_flag = False |
|
|
28 |
break |
|
|
29 |
return (old_flag or new_flag) |
|
|
30 |
|
|
|
31 |
# Evaluation script: for every log file in `logs_path` whose name (up to the
# first dot) matches a question id in the answer book, extract the final
# prediction from the log and score it with judge(). Counts are accumulated in
# `stats` and printed at the end.
logs_path = "<YOUR_LOGS_PATH>"
files = os.listdir(logs_path)

# Load the reference answers and index them by question id.
answer_book = "<YOUR_DATASET_PATH>"
with open(answer_book, 'r') as f:
    contents = json.load(f)
answers = {item['id']: item['answer'] for item in contents}

stats = {"total_num": 0, "correct": 0, "unfinished": 0, "incorrect": 0}

for file in files:
    # A log file is matched to its question by the part of its name that
    # precedes the first dot.
    question_id = file.split('.')[0]
    if question_id not in answers:
        continue

    # os.path.join keeps the path valid whether or not logs_path ends with a
    # path separator (plain string concatenation did not).
    with open(os.path.join(logs_path, file), 'r') as f:
        logs = f.read()

    answer = answers[question_id]
    if isinstance(answer, list):
        # str() guards against non-string items (e.g. numbers) in the list.
        answer = ', '.join(str(item) for item in answer)

    stats["total_num"] += 1

    # A run that never emitted TERMINATE is counted as unfinished, not scored.
    if "TERMINATE" not in logs:
        stats["unfinished"] += 1
        continue

    # The prediction is the text between the end of the last code cell (or,
    # when the log has no code cell, the last "Solution:" marker) and the
    # last TERMINATE.
    # NOTE(review): if the start marker is missing, rfind returns -1 and the
    # slice below is usually empty, so the run is scored as incorrect.
    if '"cell": "' in logs:
        prediction_start = logs.rfind('"\n}')
    else:
        prediction_start = logs.rfind('Solution:')
    prediction_end = logs.rfind('TERMINATE')
    prediction = logs[prediction_start:prediction_end]

    if judge(prediction, answer):
        stats["correct"] += 1
    else:
        stats["incorrect"] += 1

print(stats)