Diff of /ehragent/main.py [000000] .. [6cf5c7]


import os
import json
import random
import numpy as np
import argparse
import autogen
from toolset_high import *
from medagent import MedAgent
from config import openai_config, llm_config_list
import time

def judge(pred, ans):
    # Heuristic answer check: the prediction is judged correct if the ground-truth
    # answer (after light normalization) appears in the predicted text, either as
    # the raw string or item by item for comma-separated answers.
    old_flag = True
    if ans not in pred:
        old_flag = False
    if "True" in pred:
        pred = pred.replace("True", "1")
    else:
        pred = pred.replace("False", "0")
    # Map common boolean/yes-no/empty answers onto "0"/"1".
    if ans in ("False", "false", "No", "no", "None", "none"):
        ans = "0"
    if ans in ("True", "true", "Yes", "yes"):
        ans = "1"
    # Comma-separated answers are split and checked item by item.
    if ", " in ans:
        ans = ans.split(', ')
    # Drop a trailing ".0" from numeric answers.
    if ans[-2:] == ".0":
        ans = ans[:-2]
    if not isinstance(ans, list):
        ans = [ans]
    new_flag = True
    for item in ans:
        if item not in pred:
            new_flag = False
            break
    return old_flag or new_flag
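
# Illustrative check (example strings, not part of the original file): judge() returns
# True when every comma-separated ground-truth item appears in the prediction, e.g.
#   judge("answer = ['aspirin', 'heparin'] TERMINATE", "aspirin, heparin")  # -> True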

def set_seed(seed):
    # Fix the Python and NumPy RNG seeds so the dataset shuffle is reproducible.
    random.seed(seed)
    np.random.seed(seed)

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--llm", type=str, default="<YOUR_LLM_NAME>")
    parser.add_argument("--num_questions", type=int, default=1)
    parser.add_argument("--dataset", type=str, default="mimic_iii")
    parser.add_argument("--data_path", type=str, default="<YOUR_DATASET_PATH>")
    parser.add_argument("--logs_path", type=str, default="<YOUR_LOGS_PATH>")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--debug", action="store_true")
    parser.add_argument("--debug_id", type=str, default="521fd2885f51641a963f8d3e")
    parser.add_argument("--start_id", type=int, default=0)
    parser.add_argument("--num_shots", type=int, default=4)
    args = parser.parse_args()
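    # Example invocation (illustrative only; placeholder values as in the defaults above):
    #   python main.py --llm <YOUR_LLM_NAME> --dataset mimic_iii \
    #     --data_path <YOUR_DATASET_PATH> --logs_path <YOUR_LOGS_PATH> --num_questions -1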
    set_seed(args.seed)
    # Choose the 4-shot knowledge prompt that matches the selected dataset.
    if args.dataset == 'mimic_iii':
        from prompts_mimic import EHRAgent_4Shots_Knowledge
    else:
        from prompts_eicu import EHRAgent_4Shots_Knowledge

    config_list = [openai_config(args.llm)]
    llm_config = llm_config_list(args.seed, config_list)

    # Assistant agent that generates code; the MedAgent user proxy executes it,
    # supplies tool results, and maintains the few-shot memory.
    chatbot = autogen.agentchat.AssistantAgent(
        name="chatbot",
        system_message="For coding tasks, only use the functions you have been provided with. Reply TERMINATE when the task is done. Save the answers to the questions in the variable 'answer'. Please only generate the code.",
        llm_config=llm_config,
    )

    user_proxy = MedAgent(
        name="user_proxy",
        is_termination_msg=lambda x: x.get("content", "") and x.get("content", "").rstrip().endswith("TERMINATE"),
        human_input_mode="NEVER",
        max_consecutive_auto_reply=10,
        code_execution_config={"work_dir": "coding"},
        config_list=config_list,
    )

    # Register the tool the chatbot can call: "python" maps to run_code
    # (provided by the star-imported toolset).
    user_proxy.register_function(
        function_map={
            "python": run_code
        }
    )

    user_proxy.register_dataset(args.dataset)

    # Read the question set from the JSON file.
    file_path = args.data_path
    with open(file_path, 'r') as f:
        contents = json.load(f)
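    # Each entry is expected to provide at least 'id', 'template' (the question text)
    # and 'answer' fields; those keys are read in the evaluation loop below.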

    # Randomly shuffle the questions.
    random.shuffle(contents)
    # Per-question logs go to <logs_path>/<num_shots>/<question id>.txt.
    file_path = "{}/{}/".format(args.logs_path, args.num_shots) + "{id}.txt"

    start_time = time.time()
    if args.num_questions == -1:
        args.num_questions = len(contents)
    # Seed the long-term memory with the 4-shot examples from the prompt file.
    long_term_memory = []
    init_memory = EHRAgent_4Shots_Knowledge.split('\n\n')
    for item in init_memory:
        item = item.split('Question:')[-1]
        question = item.split('\nKnowledge:\n')[0]
        item = item.split('\nKnowledge:\n')[-1]
        knowledge = item.split('\nSolution:')[0]
        code = item.split('\nSolution:')[-1]
        new_item = {"question": question, "knowledge": knowledge, "code": code}
        long_term_memory.append(new_item)
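    # The splits above assume each few-shot block in EHRAgent_4Shots_Knowledge looks
    # roughly like:
    #   Question: <question text>
    #   Knowledge:
    #   <knowledge text>
    #   Solution: <code>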

    for i in range(args.start_id, args.num_questions):
        if args.debug and contents[i]['id'] != args.debug_id:
            continue
        question = contents[i]['template']
        answer = contents[i]['answer']
        try:
            # Refresh the proxy's few-shot memory, then run one multi-turn episode.
            user_proxy.update_memory(args.num_shots, long_term_memory)
            user_proxy.initiate_chat(
                chatbot,
                message=question,
            )
            logs = user_proxy._oai_messages

            # Flatten the conversation into a list of strings for logging.
            logs_string = []
            logs_string.append(str(question))
            logs_string.append(str(answer))
            for agent in logs:
                for msg in logs[agent]:
                    if msg['content'] is not None:
                        logs_string.append(msg['content'])
                    else:
                        # Function-call messages carry the generated code in their arguments.
                        argums = msg['function_call']['arguments']
                        if isinstance(argums, dict) and 'cell' in argums:
                            logs_string.append(argums['cell'])
                        else:
                            logs_string.append(argums)
        except Exception as e:
            logs_string = [str(e)]
        print(logs_string)
        file_directory = file_path.format(id=contents[i]['id'])
        if isinstance(answer, list):
            answer = ', '.join(answer)
        logs_string.append("Ground-Truth Answer ---> " + answer)
        # Write the full conversation log, one message per separator-delimited block.
        logs_string = '\n----------------------------------------------------------\n'.join(logs_string)
        with open(file_directory, 'w') as f:
            f.write(logs_string)
        # Extract the prediction: everything between the end of the last generated
        # code cell (or the last 'Solution:' marker) and the final TERMINATE.
        if '"cell": "' in logs_string:
            last_code_end = logs_string.rfind('"\n}')
        else:
            last_code_end = logs_string.rfind('Solution:')
        prediction_end = logs_string.rfind('TERMINATE')
        prediction = logs_string[last_code_end:prediction_end]
        result = judge(prediction, answer)
        if result:
            # Keep correctly answered questions in long-term memory as future few-shot examples.
            new_item = {"question": question, "knowledge": user_proxy.knowledge, "code": user_proxy.code}
            long_term_memory.append(new_item)
    end_time = time.time()
    print("Time elapsed: ", end_time - start_time)

if __name__ == "__main__":
    main()