|
a |
|
b/ehragent/main.py |
|
|
1 |
import os |
|
|
2 |
import json |
|
|
3 |
import random |
|
|
4 |
import numpy as np |
|
|
5 |
import argparse |
|
|
6 |
import autogen |
|
|
7 |
from toolset_high import * |
|
|
8 |
from medagent import MedAgent |
|
|
9 |
from config import openai_config, llm_config_list |
|
|
10 |
import time |
|
|
11 |
|
|
|
12 |
def judge(pred, ans):
    """Decide whether a predicted answer text matches the ground truth.

    Two checks are combined:
      * ``old_flag`` -- the raw ground-truth string appears verbatim in the
        raw prediction;
      * ``new_flag`` -- after normalising boolean / yes-no / none spellings
        to "0"/"1", stripping a trailing ".0", and splitting comma-separated
        answers, every normalised item appears in the (boolean-normalised)
        prediction.

    Args:
        pred: Model output text to search.
        ans: Ground-truth answer as a single string; comma-separated lists
            are split internally.

    Returns:
        True if either the verbatim or the normalised check succeeds.
    """
    # Verbatim containment check against the untouched inputs.
    old_flag = ans in pred

    # Normalise booleans in the prediction to the same 0/1 encoding that is
    # applied to `ans` below: "True"->"1" when present, otherwise "False"->"0".
    if "True" in pred:
        pred = pred.replace("True", "1")
    else:
        pred = pred.replace("False", "0")

    # Map common truthy/falsy spellings of the answer onto "0"/"1".
    # (The original sequential ifs are equivalent to this chain because the
    # replacement values "0"/"1" never match a later condition.)
    if ans in ("False", "false", "No", "no", "None", "none"):
        ans = "0"
    elif ans in ("True", "true", "Yes", "yes"):
        ans = "1"

    # A comma-separated ground truth is a list of items that must each be
    # present in the prediction.
    if ", " in ans:
        ans = ans.split(', ')
    # Strip a trailing ".0" so e.g. "5.0" also matches a predicted "5".
    # (No-op when `ans` is a list: a list slice never equals a string.)
    if ans[-2:] == ".0":
        ans = ans[:-2]
    if not isinstance(ans, list):
        ans = [ans]

    new_flag = all(item in pred for item in ans)
    return old_flag or new_flag
|
|
42 |
|
|
|
43 |
def set_seed(seed):
    """Seed both the NumPy and stdlib random generators for reproducibility."""
    np.random.seed(seed)
    random.seed(seed)
|
|
46 |
|
|
|
47 |
def _parse_few_shot_memory(raw_examples):
    """Split the few-shot prompt text into structured memory items.

    Args:
        raw_examples: Prompt string whose examples are separated by blank
            lines, each formatted as
            "Question:...\\nKnowledge:\\n...\\nSolution:...".

    Returns:
        List of ``{"question", "knowledge", "code"}`` dicts.
    """
    memory = []
    for item in raw_examples.split('\n\n'):
        item = item.split('Question:')[-1]
        question = item.split('\nKnowledge:\n')[0]
        item = item.split('\nKnowledge:\n')[-1]
        knowledge = item.split('\nSolution:')[0]
        code = item.split('\nSolution:')[-1]
        memory.append({"question": question, "knowledge": knowledge, "code": code})
    return memory

def main():
    """Run the EHRAgent evaluation loop over a QA dataset.

    Parses CLI options, builds the chatbot / MedAgent pair, seeds the
    few-shot long-term memory from the prompt examples, then answers each
    question, writes one log file per question, and grows the memory with
    examples the agent answered correctly.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--llm", type=str, default="<YOUR_LLM_NAME>")
    parser.add_argument("--num_questions", type=int, default=1)
    parser.add_argument("--dataset", type=str, default="mimic_iii")
    parser.add_argument("--data_path", type=str, default="<YOUR_DATASET_PATH>")
    parser.add_argument("--logs_path", type=str, default="<YOUR_LOGS_PATH>")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--debug", action="store_true")
    parser.add_argument("--debug_id", type=str, default="521fd2885f51641a963f8d3e")
    parser.add_argument("--start_id", type=int, default=0)
    parser.add_argument("--num_shots", type=int, default=4)
    args = parser.parse_args()
    set_seed(args.seed)

    # Few-shot examples are dataset-specific.
    if args.dataset == 'mimic_iii':
        from prompts_mimic import EHRAgent_4Shots_Knowledge
    else:
        from prompts_eicu import EHRAgent_4Shots_Knowledge

    config_list = [openai_config(args.llm)]
    llm_config = llm_config_list(args.seed, config_list)

    chatbot = autogen.agentchat.AssistantAgent(
        name="chatbot",
        system_message="For coding tasks, only use the functions you have been provided with. Reply TERMINATE when the task is done. Save the answers to the questions in the variable 'answer'. Please only generate the code.",
        llm_config=llm_config,
    )

    user_proxy = MedAgent(
        name="user_proxy",
        is_termination_msg=lambda x: x.get("content", "") and x.get("content", "").rstrip().endswith("TERMINATE"),
        human_input_mode="NEVER",
        max_consecutive_auto_reply=10,
        code_execution_config={"work_dir": "coding"},
        config_list=config_list,
    )

    # Register the code-execution tool the chatbot may call.
    user_proxy.register_function(
        function_map={
            "python": run_code
        }
    )
    user_proxy.register_dataset(args.dataset)

    # Load the QA items and shuffle them; the shuffle is deterministic
    # because set_seed() above seeded the module-level `random` generator.
    with open(args.data_path, 'r') as f:
        contents = json.load(f)
    random.shuffle(contents)

    # Per-question log file path; "{id}" is filled in per question below.
    log_path_template = "{}/{}/".format(args.logs_path, args.num_shots) + "{id}.txt"

    start_time = time.time()
    if args.num_questions == -1:
        args.num_questions = len(contents)

    # Seed the long-term memory with the built-in few-shot examples.
    long_term_memory = _parse_few_shot_memory(EHRAgent_4Shots_Knowledge)

    separator = '\n----------------------------------------------------------\n'
    for i in range(args.start_id, args.num_questions):
        # In debug mode, only the single question with --debug_id is run.
        if args.debug and contents[i]['id'] != args.debug_id:
            continue
        question = contents[i]['template']
        answer = contents[i]['answer']
        try:
            user_proxy.update_memory(args.num_shots, long_term_memory)
            user_proxy.initiate_chat(
                chatbot,
                message=question,
            )
            logs = user_proxy._oai_messages

            # Flatten the conversation into printable strings: content when
            # present, otherwise the function-call arguments (code cell).
            logs_string = [str(question), str(answer)]
            for agent in list(logs.keys()):
                for message in logs[agent]:
                    if message['content'] is not None:
                        logs_string.append(message['content'])
                    else:
                        argums = message['function_call']['arguments']
                        if isinstance(argums, dict) and 'cell' in argums:
                            logs_string.append(argums['cell'])
                        else:
                            logs_string.append(argums)
        except Exception as e:
            # Best effort: a failed chat still produces a log file containing
            # the error text instead of aborting the whole run.
            logs_string = [str(e)]
        print(logs_string)
        file_directory = log_path_template.format(id=contents[i]['id'])
        if isinstance(answer, list):
            answer = ', '.join(answer)
        logs_string.append("Ground-Truth Answer ---> "+answer)
        # NOTE(review): assumes the per-shot logs directory already exists —
        # confirm it is created elsewhere before a fresh run.
        with open(file_directory, 'w') as f:
            f.write(separator.join(logs_string))

        # Extract the final prediction from the flattened log text:
        # everything between the last generated code cell (or "Solution:"
        # when no cell was emitted) and the trailing TERMINATE marker.
        logs_string = separator.join(logs_string)
        if '"cell": "' in logs_string:
            last_code_end = logs_string.rfind('"\n}')
        else:
            last_code_end = logs_string.rfind('Solution:')
        prediction_end = logs_string.rfind('TERMINATE')
        prediction = logs_string[last_code_end:prediction_end]
        if judge(prediction, answer):
            # Correct answers become new few-shot memory for later questions.
            long_term_memory.append({"question": question, "knowledge": user_proxy.knowledge, "code": user_proxy.code})
    end_time = time.time()
    print("Time elapsed: ", end_time - start_time)
|
|
170 |
|
|
|
171 |
# Script entry point: run the evaluation loop only when executed directly.
if __name__ == "__main__":
    main()