|
a |
|
b/src/agent.py |
|
|
1 |
#!/usr/bin/env python |
|
|
2 |
# -*- coding: UTF-8 -*- |
|
|
3 |
''' |
|
|
4 |
@Project :Auto-BioinfoGPT |
|
|
5 |
@File :agent.py |
|
|
6 |
@Author :Juexiao Zhou |
|
|
7 |
@Contact : juexiao.zhou@gmail.com |
|
|
8 |
@Date :2023/5/3 13:24 |
|
|
9 |
''' |
|
|
10 |
import os.path |
|
|
11 |
import os |
|
|
12 |
import torch.cuda |
|
|
13 |
from src.prompt import PromptGenerator |
|
|
14 |
from src.spinner import Spinner |
|
|
15 |
from src.executor import CodeExecutor |
|
|
16 |
from src.build_RAG_private import preload_retriever |
|
|
17 |
from src.local_llm import api_preload, api_generator, api_preload_deepseek, api_generator_deepseek |
|
|
18 |
import openai |
|
|
19 |
from openai import OpenAI |
|
|
20 |
import time |
|
|
21 |
import json |
|
|
22 |
|
|
|
23 |
class Agent: |
|
|
24 |
def __init__(self, |
|
|
25 |
initial_data_list, |
|
|
26 |
output_dir, |
|
|
27 |
initial_goal_description, |
|
|
28 |
model_engine, |
|
|
29 |
openai_api, |
|
|
30 |
execute = True, |
|
|
31 |
blacklist='', |
|
|
32 |
gui_mode=False, |
|
|
33 |
cpu=False, |
|
|
34 |
rag=False): |
|
|
35 |
self.initial_data_list = initial_data_list |
|
|
36 |
self.initial_goal_description = initial_goal_description |
|
|
37 |
self.tasks = [] |
|
|
38 |
self.update_data_lists = [_ for _ in initial_data_list] |
|
|
39 |
self.output_dir = output_dir |
|
|
40 |
self.update_data_lists.append(f'{output_dir}: all outputs should be stored under this dir') |
|
|
41 |
self.model_engine = model_engine |
|
|
42 |
self.rag = rag |
|
|
43 |
self.local_model_engines = ['codellama-7bi', |
|
|
44 |
'codellama-13bi', |
|
|
45 |
'codellama-34bi', |
|
|
46 |
'llama2-7bc', |
|
|
47 |
'llama2-13bc', |
|
|
48 |
'llama2-70bc', |
|
|
49 |
'deepseek-6.7bi', |
|
|
50 |
'deepseek-7bi', |
|
|
51 |
'deepseek-33bi', |
|
|
52 |
'deepseek-67bc'] |
|
|
53 |
self.gpt_model_engines = ['gpt-3.5-turbo', |
|
|
54 |
'gpt-4-turbo', |
|
|
55 |
'gpt-4', |
|
|
56 |
'gpt-4o', |
|
|
57 |
'gpt-4o-mini', |
|
|
58 |
'gpt-3.5-turbo-1106', |
|
|
59 |
'gpt-4-0613', |
|
|
60 |
'gpt-4-32k-0613', |
|
|
61 |
'gpt-4-1106-preview'] |
|
|
62 |
self.ollama_engines = ['ollama_llama3.1'] |
|
|
63 |
self.valid_model_engines = self.local_model_engines + self.gpt_model_engines |
|
|
64 |
self.openai_api = openai_api |
|
|
65 |
|
|
|
66 |
if self.rag: |
|
|
67 |
if self.model_engine in self.gpt_model_engines: |
|
|
68 |
self.retriever = preload_retriever(False, |
|
|
69 |
self.openai_api, |
|
|
70 |
PERSIST_DIR = "./softwares_database_RAG_openai", |
|
|
71 |
SOURCE_DIR = "./softwares_database" |
|
|
72 |
) |
|
|
73 |
else: |
|
|
74 |
self.retriever = preload_retriever(True, |
|
|
75 |
None, |
|
|
76 |
PERSIST_DIR="./softwares_database_RAG_local", |
|
|
77 |
SOURCE_DIR="./softwares_database" |
|
|
78 |
) |
|
|
79 |
else: |
|
|
80 |
self.retriever = None |
|
|
81 |
|
|
|
82 |
self.generator = PromptGenerator(blacklist=blacklist, |
|
|
83 |
engine = self.model_engine, |
|
|
84 |
rag = self.rag, |
|
|
85 |
retriever = self.retriever) |
|
|
86 |
self.global_round = 0 |
|
|
87 |
self.execute = execute |
|
|
88 |
self.execute_success = True |
|
|
89 |
self.execute_info = '' |
|
|
90 |
self.code_executor = CodeExecutor() |
|
|
91 |
self.gui_mode = gui_mode |
|
|
92 |
self.cpu = cpu |
|
|
93 |
|
|
|
94 |
if self.model_engine.startswith('ollama_'): |
|
|
95 |
print('[INFO] using ollama engine!') |
|
|
96 |
elif self.model_engine not in self.valid_model_engines: |
|
|
97 |
print('[ERROR] model invalid, please check the model engine selected!') |
|
|
98 |
exit() |
|
|
99 |
|
|
|
100 |
# use gpt model |
|
|
101 |
if self.model_engine in self.gpt_model_engines: |
|
|
102 |
self.openai_client = OpenAI( |
|
|
103 |
# This is the default and can be omitted |
|
|
104 |
api_key=self.openai_api, |
|
|
105 |
) |
|
|
106 |
|
|
|
107 |
# load local model with ollama |
|
|
108 |
if self.model_engine.startswith('ollama_'): |
|
|
109 |
from langchain_community.llms import Ollama |
|
|
110 |
self.local_llm_generator = Ollama(model=self.model_engine.split('ollama_')[-1]) |
|
|
111 |
|
|
|
112 |
# preload local model |
|
|
113 |
if 'llama' in self.model_engine and 'ollama_' not in self.model_engine: |
|
|
114 |
import torch.distributed as dist |
|
|
115 |
os.environ['MASTER_ADDR'] = 'localhost' |
|
|
116 |
os.environ['MASTER_PORT'] = '5678' |
|
|
117 |
if torch.cuda.is_available(): |
|
|
118 |
dist.init_process_group(backend='nccl', init_method='env://', rank=0, world_size=1) |
|
|
119 |
else: |
|
|
120 |
dist.init_process_group(backend='gloo', init_method='env://', rank=0, world_size=1) |
|
|
121 |
|
|
|
122 |
if self.model_engine == 'codellama-7bi': |
|
|
123 |
self.local_llm_generator = api_preload(ckpt_dir='src/codellama-main/CodeLlama-7b-Instruct/', |
|
|
124 |
tokenizer_path='src/codellama-main/CodeLlama-7b-Instruct/tokenizer.model', |
|
|
125 |
max_seq_len=4096) |
|
|
126 |
elif self.model_engine == 'codellama-13bi': |
|
|
127 |
self.local_llm_generator = api_preload(ckpt_dir='src/codellama-main/CodeLlama-13b-Instruct/one-gpu/', |
|
|
128 |
tokenizer_path='src/codellama-main/CodeLlama-13b-Instruct/tokenizer.model', |
|
|
129 |
max_seq_len=4096) |
|
|
130 |
elif self.model_engine == 'codellama-34bi': |
|
|
131 |
self.local_llm_generator = api_preload(ckpt_dir='src/codellama-main/CodeLlama-34b-Instruct/one-gpu/', |
|
|
132 |
tokenizer_path='src/codellama-main/CodeLlama-34b-Instruct/tokenizer.model', |
|
|
133 |
max_seq_len=4096) |
|
|
134 |
elif self.model_engine == 'llama2-7bc': |
|
|
135 |
self.local_llm_generator = api_preload(ckpt_dir='src/llama-main/llama-2-7b-chat/', |
|
|
136 |
tokenizer_path='src/llama-main/tokenizer.model', |
|
|
137 |
max_seq_len=4096) |
|
|
138 |
elif self.model_engine == 'llama2-13bc': |
|
|
139 |
self.local_llm_generator = api_preload(ckpt_dir='src/llama-main/llama-2-13b-chat/one-gpu/', |
|
|
140 |
tokenizer_path='src/llama-main/tokenizer.model', |
|
|
141 |
max_seq_len=4096) |
|
|
142 |
elif self.model_engine == 'llama2-70bc': |
|
|
143 |
self.local_llm_generator = api_preload(ckpt_dir='src/llama-main/llama-2-70b-chat/', |
|
|
144 |
tokenizer_path='src/llama-main/tokenizer.model', |
|
|
145 |
max_seq_len=4096) |
|
|
146 |
elif self.model_engine == 'deepseek-6.7bi': |
|
|
147 |
self.tokenizer, self.local_llm_generator = api_preload_deepseek( |
|
|
148 |
ckpt_dir='src/deepseek/deepseek-coder-6.7b-instruct/', |
|
|
149 |
tokenizer_path='src/deepseek/deepseek-coder-6.7b-instruct/', |
|
|
150 |
cpu = self.cpu |
|
|
151 |
) |
|
|
152 |
elif self.model_engine == 'deepseek-7bi': |
|
|
153 |
self.tokenizer, self.local_llm_generator = api_preload_deepseek( |
|
|
154 |
ckpt_dir='src/deepseek/deepseek-coder-7b-instruct-v1.5/', |
|
|
155 |
tokenizer_path='src/deepseek/deepseek-coder-7b-instruct-v1.5/', |
|
|
156 |
cpu = self.cpu |
|
|
157 |
) |
|
|
158 |
elif self.model_engine == 'deepseek-33bi': |
|
|
159 |
self.tokenizer, self.local_llm_generator = api_preload_deepseek( |
|
|
160 |
ckpt_dir='src/deepseek/deepseek-coder-33b-instruct/', |
|
|
161 |
tokenizer_path='src/deepseek/deepseek-coder-33b-instruct/', |
|
|
162 |
cpu = self.cpu |
|
|
163 |
) |
|
|
164 |
elif self.model_engine == 'deepseek-67bc': |
|
|
165 |
self.tokenizer, self.local_llm_generator = api_preload_deepseek( |
|
|
166 |
ckpt_dir='src/deepseek/deepseek-llm-67b-chat/', |
|
|
167 |
tokenizer_path='src/deepseek/deepseek-llm-67b-chat/', |
|
|
168 |
cpu = self.cpu |
|
|
169 |
) |
|
|
170 |
|
|
|
171 |
|
|
|
172 |
def get_single_response(self, prompt): |
|
|
173 |
|
|
|
174 |
# use openai |
|
|
175 |
if self.model_engine in self.gpt_model_engines: |
|
|
176 |
|
|
|
177 |
if self.model_engine in ['gpt-3.5-turbo-1106', 'gpt-4-1106-preview']: |
|
|
178 |
response = self.openai_client.chat.completions.create( |
|
|
179 |
model=self.model_engine, |
|
|
180 |
response_format={"type": "json_object"}, |
|
|
181 |
messages=[ |
|
|
182 |
{"role": "user", "content": str(prompt)}], |
|
|
183 |
max_tokens=1024, |
|
|
184 |
temperature=0, |
|
|
185 |
) |
|
|
186 |
else: |
|
|
187 |
response = self.openai_client.chat.completions.create( |
|
|
188 |
model=self.model_engine, |
|
|
189 |
messages=[ |
|
|
190 |
{"role": "user", "content": str(prompt)}], |
|
|
191 |
max_tokens=1024, |
|
|
192 |
temperature=0, |
|
|
193 |
) |
|
|
194 |
|
|
|
195 |
""" |
|
|
196 |
{ |
|
|
197 |
"choices": [ |
|
|
198 |
{ |
|
|
199 |
"finish_reason": "stop", |
|
|
200 |
"index": 0, |
|
|
201 |
"message": { |
|
|
202 |
"content": "Hello! As an AI language model, I don't have emotions, but I'm functioning well. I'm here to assist you with any questions or tasks you may have. How can I help you today?", |
|
|
203 |
"role": "assistant" |
|
|
204 |
} |
|
|
205 |
} |
|
|
206 |
], |
|
|
207 |
"created": 1683014436, |
|
|
208 |
"id": "chatcmpl-7BfE4AdTo5YlSIWyMDS6nL6CYv5is", |
|
|
209 |
"model": "gpt-3.5-turbo-0301", |
|
|
210 |
"object": "chat.completion", |
|
|
211 |
"usage": { |
|
|
212 |
"completion_tokens": 42, |
|
|
213 |
"prompt_tokens": 20, |
|
|
214 |
"total_tokens": 62 |
|
|
215 |
} |
|
|
216 |
} |
|
|
217 |
""" |
|
|
218 |
response_message = response.choices[0].message.content |
|
|
219 |
elif self.model_engine in self.local_model_engines: |
|
|
220 |
instructions = [ |
|
|
221 |
[ |
|
|
222 |
{ |
|
|
223 |
"role": "user", |
|
|
224 |
"content": str(prompt), |
|
|
225 |
} |
|
|
226 |
], |
|
|
227 |
] |
|
|
228 |
if 'deepseek' in self.model_engine: |
|
|
229 |
results = api_generator_deepseek(instructions = instructions, |
|
|
230 |
tokenizer = self.tokenizer, |
|
|
231 |
generator = self.local_llm_generator, |
|
|
232 |
max_new_tokens=4096, |
|
|
233 |
top_k=50, |
|
|
234 |
top_p=0.95) |
|
|
235 |
elif 'llama' in self.model_engine: |
|
|
236 |
results = api_generator(instructions=instructions, |
|
|
237 |
generator=self.local_llm_generator, |
|
|
238 |
temperature=0.6) |
|
|
239 |
response_message = results[0]['generation']['content'] |
|
|
240 |
elif self.model_engine.startswith('ollama_'): |
|
|
241 |
response_message = self.local_llm_generator.invoke(str(prompt)) |
|
|
242 |
return response_message |
|
|
243 |
|
|
|
244 |
def valid_json_response(self, response_message): |
|
|
245 |
if not os.path.isdir(f'{self.output_dir}'): |
|
|
246 |
os.makedirs(f'{self.output_dir}') |
|
|
247 |
try: |
|
|
248 |
with open(f'{self.output_dir}/{self.global_round}_response.json', 'w') as w: |
|
|
249 |
json.dump(json.loads(response_message), w) |
|
|
250 |
json.load(open(f'{self.output_dir}/{self.global_round}_response.json')) |
|
|
251 |
except: |
|
|
252 |
print('[INVALID RESSPONSE]\n', response_message) |
|
|
253 |
return False |
|
|
254 |
return True |
|
|
255 |
|
|
|
256 |
def valid_json_response_executor(self, response_message): |
|
|
257 |
if not os.path.isdir(f'{self.output_dir}'): |
|
|
258 |
os.makedirs(f'{self.output_dir}') |
|
|
259 |
try: |
|
|
260 |
with open(f'{self.output_dir}/executor_response.json', 'w') as w: |
|
|
261 |
json.dump(json.loads(response_message), w) |
|
|
262 |
tmp_data = json.load(open(f'{self.output_dir}/executor_response.json')) |
|
|
263 |
if str(tmp_data['stat']) not in ['0', '1']: |
|
|
264 |
return False |
|
|
265 |
except: |
|
|
266 |
print('[INVALID RESSPONSE]\n', response_message) |
|
|
267 |
return False |
|
|
268 |
return True |
|
|
269 |
|
|
|
270 |
def process_tasks(self, response_message): |
|
|
271 |
self.tasks = response_message['plan'] |
|
|
272 |
|
|
|
273 |
def find_json(self, response_message): |
|
|
274 |
if "```json\n" in response_message: |
|
|
275 |
start_index = response_message.find("{") |
|
|
276 |
end_index = response_message.rfind("}") + 1 |
|
|
277 |
# 提取 JSON 部分 |
|
|
278 |
return response_message[start_index:end_index] |
|
|
279 |
elif "```bash\n" in response_message: |
|
|
280 |
start_index = response_message.find("```bash\n") |
|
|
281 |
end_index = response_message.find("```\n") |
|
|
282 |
return str({'tool':'','code':response_message[start_index:end_index].lstrip('```bash\n')}) |
|
|
283 |
else: |
|
|
284 |
start_index = response_message.find("{") |
|
|
285 |
end_index = response_message.rfind("}") + 1 |
|
|
286 |
# 提取 JSON 部分 |
|
|
287 |
return response_message[start_index:end_index] |
|
|
288 |
|
|
|
289 |
def execute_code(self, response_message): |
|
|
290 |
if not os.path.isdir(f'{self.output_dir}'): |
|
|
291 |
os.makedirs(f'{self.output_dir}') |
|
|
292 |
try: |
|
|
293 |
with open(f'{self.output_dir}/{self.global_round}.sh', 'w') as w: |
|
|
294 |
w.write(response_message['code']) |
|
|
295 |
if self.execute: |
|
|
296 |
self.last_execute_code = response_message['code'] |
|
|
297 |
executor_info = self.code_executor.execute(bash_code_path=f'{self.output_dir}/{self.global_round}.sh') |
|
|
298 |
if len(executor_info) == 0: |
|
|
299 |
execute_statu, execute_info = True, 'No error message' |
|
|
300 |
else: |
|
|
301 |
executor_response_message = self.get_single_response(self.generator.get_executor_prompt(executor_info=executor_info)) |
|
|
302 |
print('[CHECKING EXECUTION RESULTS]\n') |
|
|
303 |
if 'llama' in self.model_engine or 'deepseek' in self.model_engine: |
|
|
304 |
executor_response_message = self.find_json(executor_response_message) |
|
|
305 |
max_tries = 10 |
|
|
306 |
n_tries = 0 |
|
|
307 |
while not self.valid_json_response_executor(executor_response_message): |
|
|
308 |
print(f'[EXECUTOR RESPONSE CHECKING TEST #{n_tries}/{max_tries}]') |
|
|
309 |
if 'gpt' in self.model_engine: |
|
|
310 |
time.sleep(20) |
|
|
311 |
executor_response_message = self.get_single_response( |
|
|
312 |
self.generator.get_executor_prompt(executor_info=executor_info)) |
|
|
313 |
print('[CHECKING EXECUTION RESULTS]\n') |
|
|
314 |
if 'llama' in self.model_engine or 'deepseek' in self.model_engine: |
|
|
315 |
executor_response_message = self.find_json(executor_response_message) |
|
|
316 |
if n_tries > max_tries: |
|
|
317 |
executor_response_message = {'stat':0, 'info':'None'} |
|
|
318 |
break |
|
|
319 |
n_tries += 1 |
|
|
320 |
executor_response_message = json.load(open(f'{self.output_dir}/executor_response.json')) |
|
|
321 |
execute_statu, execute_info = executor_response_message['stat'], executor_response_message['info'] |
|
|
322 |
#execute_statu, execute_info = executor_response_message['stat'], executor_info |
|
|
323 |
#os.system(f'bash {self.output_dir}/{self.global_round}.sh') |
|
|
324 |
return bool(int(execute_statu)), execute_info |
|
|
325 |
return True, 'Success without executing' |
|
|
326 |
except Exception as e: |
|
|
327 |
return False, e |
|
|
328 |
|
|
|
329 |
def run_plan_phase(self): |
|
|
330 |
# initial prompt |
|
|
331 |
init_prompt = self.generator.get_prompt( |
|
|
332 |
data_list=self.initial_data_list, |
|
|
333 |
goal_description=self.initial_goal_description, |
|
|
334 |
global_round=self.global_round, |
|
|
335 |
execute_success=self.execute_success, |
|
|
336 |
execute_info=self.execute_info |
|
|
337 |
) |
|
|
338 |
|
|
|
339 |
INFO_STR_USER = self.generator.format_user_prompt(init_prompt, self.global_round, self.gui_mode) |
|
|
340 |
if self.gui_mode: |
|
|
341 |
print('[AI Thinking...]') |
|
|
342 |
response_message = self.get_single_response(init_prompt) |
|
|
343 |
if 'llama' in self.model_engine or 'deepseek' in self.model_engine: |
|
|
344 |
response_message = self.find_json(response_message) |
|
|
345 |
while not self.valid_json_response(response_message): |
|
|
346 |
print(f'[Invalid Response, Waiting for 20s and Retrying...]') |
|
|
347 |
print(f'invalid response: {response_message}') |
|
|
348 |
if 'gpt' in self.model_engine: |
|
|
349 |
time.sleep(20) |
|
|
350 |
response_message = self.get_single_response(init_prompt) |
|
|
351 |
response_message = json.load(open(f'{self.output_dir}/{self.global_round}_response.json')) |
|
|
352 |
else: |
|
|
353 |
with Spinner(f'\033[32m[AI Thinking...]\033[0m'): |
|
|
354 |
response_message = self.get_single_response(init_prompt) |
|
|
355 |
if 'llama' in self.model_engine or 'deepseek' in self.model_engine: |
|
|
356 |
response_message = self.find_json(response_message) |
|
|
357 |
while not self.valid_json_response(response_message): |
|
|
358 |
print(f'\033[32m[Invalid Response, Waiting for 20s and Retrying...]\033[0m') |
|
|
359 |
print(f'invalid response: {response_message}') |
|
|
360 |
if 'gpt' in self.model_engine: |
|
|
361 |
time.sleep(20) |
|
|
362 |
response_message = self.get_single_response(init_prompt) |
|
|
363 |
response_message = json.load(open(f'{self.output_dir}/{self.global_round}_response.json')) |
|
|
364 |
INFO_STR_AI = self.generator.format_ai_response(response_message, self.gui_mode) |
|
|
365 |
|
|
|
366 |
# process tasks |
|
|
367 |
self.process_tasks(response_message) |
|
|
368 |
self.generator.set_tasks(self.tasks) |
|
|
369 |
self.generator.add_history(None, self.global_round, self.update_data_lists) |
|
|
370 |
self.global_round += 1 |
|
|
371 |
|
|
|
372 |
if self.execute == False: |
|
|
373 |
time.sleep(15) |
|
|
374 |
else: |
|
|
375 |
pass |
|
|
376 |
|
|
|
377 |
def run_code_generation_phase(self): |
|
|
378 |
# finish task one-by-one with code |
|
|
379 |
# print('[DEBUG] ', self.tasks) |
|
|
380 |
while len(self.tasks) > 0: |
|
|
381 |
task = self.tasks.pop(0) |
|
|
382 |
|
|
|
383 |
prompt = self.generator.get_prompt( |
|
|
384 |
data_list=self.update_data_lists, |
|
|
385 |
goal_description=task, |
|
|
386 |
global_round=self.global_round, |
|
|
387 |
execute_success=self.execute_success, |
|
|
388 |
execute_info=self.execute_info |
|
|
389 |
) |
|
|
390 |
|
|
|
391 |
self.first_prompt = True |
|
|
392 |
self.execute_success = False |
|
|
393 |
while self.execute_success == False: |
|
|
394 |
|
|
|
395 |
if self.first_prompt == False: |
|
|
396 |
prompt = self.generator.get_prompt( |
|
|
397 |
data_list=self.update_data_lists, |
|
|
398 |
goal_description=task, |
|
|
399 |
global_round=self.global_round, |
|
|
400 |
execute_success=self.execute_success, |
|
|
401 |
execute_info=self.execute_info, |
|
|
402 |
last_execute_code=self.last_execute_code |
|
|
403 |
) |
|
|
404 |
|
|
|
405 |
INFO_STR_USER = self.generator.format_user_prompt(prompt, self.global_round, self.gui_mode) |
|
|
406 |
if self.gui_mode: |
|
|
407 |
print('[AI Thinking...]') |
|
|
408 |
response_message = self.get_single_response(prompt) |
|
|
409 |
if 'llama' in self.model_engine or 'deepseek' in self.model_engine: |
|
|
410 |
response_message = self.find_json(response_message) |
|
|
411 |
while not self.valid_json_response(response_message): |
|
|
412 |
print(f'[Invalid Response, Waiting for 20s and Retrying...]') |
|
|
413 |
print(f'invalid response: {response_message}') |
|
|
414 |
if 'gpt' in self.model_engine: |
|
|
415 |
time.sleep(20) |
|
|
416 |
response_message = self.get_single_response(prompt) |
|
|
417 |
response_message = json.load(open(f'{self.output_dir}/{self.global_round}_response.json')) |
|
|
418 |
else: |
|
|
419 |
with Spinner(f'\033[32m[AI Thinking...]\033[0m'): |
|
|
420 |
response_message = self.get_single_response(prompt) |
|
|
421 |
if 'llama' in self.model_engine or 'deepseek' in self.model_engine: |
|
|
422 |
response_message = self.find_json(response_message) |
|
|
423 |
while not self.valid_json_response(response_message): |
|
|
424 |
print(f'\033[32m[Invalid Response, Waiting for 20s and Retrying...]\033[0m') |
|
|
425 |
print(f'invalid response: {response_message}') |
|
|
426 |
if 'gpt' in self.model_engine: |
|
|
427 |
time.sleep(20) |
|
|
428 |
response_message = self.get_single_response(prompt) |
|
|
429 |
response_message = json.load(open(f'{self.output_dir}/{self.global_round}_response.json')) |
|
|
430 |
INFO_STR_AI = self.generator.format_ai_response(response_message, self.gui_mode) |
|
|
431 |
|
|
|
432 |
# execute code |
|
|
433 |
if self.gui_mode: |
|
|
434 |
print('[AI Executing codes...]') |
|
|
435 |
print(f'[Execute Code Start]') |
|
|
436 |
execute_success, execute_info = self.execute_code(response_message) |
|
|
437 |
self.execute_success = execute_success |
|
|
438 |
self.execute_info = execute_info |
|
|
439 |
print('[Execute Code Finish]', self.execute_success, self.execute_info) |
|
|
440 |
|
|
|
441 |
if self.execute_success: |
|
|
442 |
print(f'[Execute Code Success!]') |
|
|
443 |
self.execute_info = '' |
|
|
444 |
else: |
|
|
445 |
print(f'[Execute Code Failed!]') |
|
|
446 |
self.first_prompt = False |
|
|
447 |
else: |
|
|
448 |
with Spinner(f'\033[32m[AI Executing codes...]\033[0m'): |
|
|
449 |
print(f'\033[32m[Execute Code Start]\033[0m') |
|
|
450 |
execute_success, execute_info = self.execute_code(response_message) |
|
|
451 |
self.execute_success = execute_success |
|
|
452 |
self.execute_info = execute_info |
|
|
453 |
print('\033[32m[Execute Code Finish]\033[0m', self.execute_success, self.execute_info) |
|
|
454 |
|
|
|
455 |
if self.execute_success: |
|
|
456 |
print(f'\033[32m[Execute Code Success!]\033[0m') |
|
|
457 |
self.execute_info = '' |
|
|
458 |
else: |
|
|
459 |
print(f'\033[31m[Execute Code Failed!]\033[0m') |
|
|
460 |
self.first_prompt = False |
|
|
461 |
|
|
|
462 |
self.generator.add_history(task, self.global_round, self.update_data_lists, code=response_message['code']) |
|
|
463 |
self.global_round += 1 |
|
|
464 |
if self.execute == False: |
|
|
465 |
time.sleep(15) |
|
|
466 |
|
|
|
467 |
def run(self): |
|
|
468 |
self.run_plan_phase() |
|
|
469 |
self.run_code_generation_phase() |
|
|
470 |
print(f'\033[31m[Job Finished! Cheers!]\033[0m') |