Diff of /ehragent/evaluate.py [000000] .. [6cf5c7]

Switch to unified view

a b/ehragent/evaluate.py
1
import os
2
import json
3
4
def judge(pred, ans):
    """Return True if the predicted text contains the expected answer.

    Two checks are combined with ``or``:

    * a literal check: the answer, exactly as given, is a substring of
      ``pred``;
    * a normalized check: boolean/None answers are mapped to "1"/"0", a
      trailing ".0" is dropped from each expected value, and every value of
      a ", "-separated answer must be a substring of the normalized ``pred``.

    Args:
        pred: Text of the prediction to search in.
        ans: Expected answer. Usually a string, optionally several values
            joined by ", ". Lists/tuples and other scalars are converted to
            that string form first.

    Returns:
        bool: True when either check succeeds.
    """
    # Coerce the answer to text so substring checks cannot raise TypeError
    # (e.g. for numeric answers loaded from JSON).
    if isinstance(ans, (list, tuple)):
        ans = ", ".join(str(item) for item in ans)
    elif not isinstance(ans, str):
        ans = str(ans)

    # Literal check against the untouched prediction.
    old_flag = ans in pred

    # Map boolean words in the prediction to digits. "False" is rewritten
    # only when "True" is absent from the prediction.
    if "True" in pred:
        pred = pred.replace("True", "1")
    else:
        pred = pred.replace("False", "0")

    # Map boolean / None answers to the same digit encoding.
    if ans in ("False", "false", "None", "none"):
        ans = "0"
    elif ans in ("True", "true"):
        ans = "1"

    expected = ans.split(", ") if ", " in ans else [ans]
    # Drop a trailing ".0" from every expected value (not only from a
    # single-value answer) so "5.0" matches a prediction that prints "5".
    expected = [item[:-2] if item.endswith(".0") else item for item in expected]

    new_flag = all(item in pred for item in expected)
    return old_flag or new_flag
30
31
# Directory holding the log files; a file is matched to a question by the
# part of its name before the first '.'.
logs_path = "<YOUR_LOGS_PATH>"
files = os.listdir(logs_path)

# Load the answer book and index the expected answers by question id.
answer_book = "<YOUR_DATASET_PATH>"
with open(answer_book, 'r') as f:
    contents = json.load(f)
answers = {item['id']: item['answer'] for item in contents}

stats = {"total_num": 0, "correct": 0, "unfinished": 0, "incorrect": 0}

for file in files:
    question_id = file.split('.')[0]
    if question_id not in answers:
        continue
    # os.path.join works whether or not logs_path ends with a separator.
    with open(os.path.join(logs_path, file), 'r') as f:
        logs = f.read()

    answer = answers[question_id]
    # judge() expects a string; multi-valued answers are joined with ", ".
    if isinstance(answer, list):
        answer = ', '.join(str(item) for item in answer)
    elif not isinstance(answer, str):
        answer = str(answer)

    stats["total_num"] += 1
    if "TERMINATE" not in logs:
        # A run that never emitted TERMINATE is counted as unfinished.
        stats["unfinished"] += 1
        continue

    # The prediction is the text between the last start marker and the last
    # TERMINATE. The start marker is the last '"\n}' when the log contains
    # a '"cell": "' entry, otherwise the last 'Solution:'.
    if '"cell": "' in logs:
        prediction_start = logs.rfind('"\n}')
    else:
        prediction_start = logs.rfind('Solution:')
    prediction_end = logs.rfind('TERMINATE')
    # rfind returns -1 when the marker is missing; slicing from -1 up to
    # prediction_end is always empty, so make the empty prediction explicit.
    prediction = logs[prediction_start:prediction_end] if prediction_start >= 0 else ""

    if judge(prediction, answer):
        stats["correct"] += 1
    else:
        stats["incorrect"] += 1

print(stats)