Diff of /src/agent.py [000000] .. [014e6e]

#!/usr/bin/env python
# -*- coding: UTF-8 -*-
'''
@Project :Auto-BioinfoGPT
@File    :agent.py
@Author  :Juexiao Zhou
@Contact : juexiao.zhou@gmail.com
@Date    :2023/5/3 13:24
'''
import os
import torch.cuda
from src.prompt import PromptGenerator
from src.spinner import Spinner
from src.executor import CodeExecutor
from src.build_RAG_private import preload_retriever
from src.local_llm import api_preload, api_generator, api_preload_deepseek, api_generator_deepseek
import openai
from openai import OpenAI
import time
import json

class Agent:
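    """LLM-driven agent that plans and executes bioinformatics analyses.

    Given input data descriptions and a goal, the agent first asks the
    configured model for a task plan, then generates bash code for each task
    and (optionally) executes it, feeding execution feedback back into the
    next prompt until the task succeeds.
    """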
    def __init__(self,
                 initial_data_list,
                 output_dir,
                 initial_goal_description,
                 model_engine,
                 openai_api,
                 execute=True,
                 blacklist='',
                 gui_mode=False,
                 cpu=False,
                 rag=False):
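        """Store the run configuration, optionally build the RAG retriever,
        and preload the selected model backend (OpenAI client, ollama, or a
        local llama/deepseek checkpoint)."""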
        self.initial_data_list = initial_data_list
        self.initial_goal_description = initial_goal_description
        self.tasks = []
        self.update_data_lists = list(initial_data_list)
        self.output_dir = output_dir
        self.update_data_lists.append(f'{output_dir}: all outputs should be stored under this dir')
        self.model_engine = model_engine
        self.rag = rag
        self.local_model_engines = ['codellama-7bi',
                                    'codellama-13bi',
                                    'codellama-34bi',
                                    'llama2-7bc',
                                    'llama2-13bc',
                                    'llama2-70bc',
                                    'deepseek-6.7bi',
                                    'deepseek-7bi',
                                    'deepseek-33bi',
                                    'deepseek-67bc']
        self.gpt_model_engines = ['gpt-3.5-turbo',
                                  'gpt-4-turbo',
                                  'gpt-4',
                                  'gpt-4o',
                                  'gpt-4o-mini',
                                  'gpt-3.5-turbo-1106',
                                  'gpt-4-0613',
                                  'gpt-4-32k-0613',
                                  'gpt-4-1106-preview']
        self.ollama_engines = ['ollama_llama3.1']
        self.valid_model_engines = self.local_model_engines + self.gpt_model_engines
        self.openai_api = openai_api

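        # Optionally preload a RAG retriever over the local software database:
        # OpenAI embeddings for GPT engines, a local embedding model otherwise.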
        if self.rag:
            if self.model_engine in self.gpt_model_engines:
                self.retriever = preload_retriever(False,
                                                   self.openai_api,
                                                   PERSIST_DIR="./softwares_database_RAG_openai",
                                                   SOURCE_DIR="./softwares_database"
                                                   )
            else:
                self.retriever = preload_retriever(True,
                                                   None,
                                                   PERSIST_DIR="./softwares_database_RAG_local",
                                                   SOURCE_DIR="./softwares_database"
                                                   )
        else:
            self.retriever = None

        self.generator = PromptGenerator(blacklist=blacklist,
                                         engine=self.model_engine,
                                         rag=self.rag,
                                         retriever=self.retriever)
        self.global_round = 0
        self.execute = execute
        self.execute_success = True
        self.execute_info = ''
        self.code_executor = CodeExecutor()
        self.gui_mode = gui_mode
        self.cpu = cpu

        if self.model_engine.startswith('ollama_'):
            print('[INFO] using ollama engine!')
        elif self.model_engine not in self.valid_model_engines:
            print('[ERROR] invalid model engine, please check the selected model engine!')
            exit()

        # use gpt model
        if self.model_engine in self.gpt_model_engines:
            self.openai_client = OpenAI(
                # This is the default and can be omitted
                api_key=self.openai_api,
            )

        # load local model with ollama
        if self.model_engine.startswith('ollama_'):
            from langchain_community.llms import Ollama
            self.local_llm_generator = Ollama(model=self.model_engine.split('ollama_')[-1])

        # preload local model (the deepseek branches below sit in this block,
        # so the condition must match deepseek engines too)
        if ('llama' in self.model_engine or 'deepseek' in self.model_engine) and 'ollama_' not in self.model_engine:
            import torch.distributed as dist
            os.environ['MASTER_ADDR'] = 'localhost'
            os.environ['MASTER_PORT'] = '5678'
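            # The llama/codellama reference code expects a torch.distributed
            # process group even for a single process; NCCL requires a GPU, so
            # fall back to gloo when running on CPU.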
            if torch.cuda.is_available():
                dist.init_process_group(backend='nccl', init_method='env://', rank=0, world_size=1)
            else:
                dist.init_process_group(backend='gloo', init_method='env://', rank=0, world_size=1)

            if self.model_engine == 'codellama-7bi':
                self.local_llm_generator = api_preload(ckpt_dir='src/codellama-main/CodeLlama-7b-Instruct/',
                                                       tokenizer_path='src/codellama-main/CodeLlama-7b-Instruct/tokenizer.model',
                                                       max_seq_len=4096)
            elif self.model_engine == 'codellama-13bi':
                self.local_llm_generator = api_preload(ckpt_dir='src/codellama-main/CodeLlama-13b-Instruct/one-gpu/',
                                                       tokenizer_path='src/codellama-main/CodeLlama-13b-Instruct/tokenizer.model',
                                                       max_seq_len=4096)
            elif self.model_engine == 'codellama-34bi':
                self.local_llm_generator = api_preload(ckpt_dir='src/codellama-main/CodeLlama-34b-Instruct/one-gpu/',
                                                       tokenizer_path='src/codellama-main/CodeLlama-34b-Instruct/tokenizer.model',
                                                       max_seq_len=4096)
            elif self.model_engine == 'llama2-7bc':
                self.local_llm_generator = api_preload(ckpt_dir='src/llama-main/llama-2-7b-chat/',
                                                       tokenizer_path='src/llama-main/tokenizer.model',
                                                       max_seq_len=4096)
            elif self.model_engine == 'llama2-13bc':
                self.local_llm_generator = api_preload(ckpt_dir='src/llama-main/llama-2-13b-chat/one-gpu/',
                                                       tokenizer_path='src/llama-main/tokenizer.model',
                                                       max_seq_len=4096)
            elif self.model_engine == 'llama2-70bc':
                self.local_llm_generator = api_preload(ckpt_dir='src/llama-main/llama-2-70b-chat/',
                                                       tokenizer_path='src/llama-main/tokenizer.model',
                                                       max_seq_len=4096)
            elif self.model_engine == 'deepseek-6.7bi':
                self.tokenizer, self.local_llm_generator = api_preload_deepseek(
                    ckpt_dir='src/deepseek/deepseek-coder-6.7b-instruct/',
                    tokenizer_path='src/deepseek/deepseek-coder-6.7b-instruct/',
                    cpu=self.cpu
                )
            elif self.model_engine == 'deepseek-7bi':
                self.tokenizer, self.local_llm_generator = api_preload_deepseek(
                    ckpt_dir='src/deepseek/deepseek-coder-7b-instruct-v1.5/',
                    tokenizer_path='src/deepseek/deepseek-coder-7b-instruct-v1.5/',
                    cpu=self.cpu
                )
            elif self.model_engine == 'deepseek-33bi':
                self.tokenizer, self.local_llm_generator = api_preload_deepseek(
                    ckpt_dir='src/deepseek/deepseek-coder-33b-instruct/',
                    tokenizer_path='src/deepseek/deepseek-coder-33b-instruct/',
                    cpu=self.cpu
                )
            elif self.model_engine == 'deepseek-67bc':
                self.tokenizer, self.local_llm_generator = api_preload_deepseek(
                    ckpt_dir='src/deepseek/deepseek-llm-67b-chat/',
                    tokenizer_path='src/deepseek/deepseek-llm-67b-chat/',
                    cpu=self.cpu
                )

    def get_single_response(self, prompt):
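        """Send one prompt to the configured backend (OpenAI, local
        llama/deepseek, or ollama) and return the raw reply text."""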

        # use openai
        if self.model_engine in self.gpt_model_engines:

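            # These model snapshots support OpenAI's JSON mode, which forces
            # the completion to be valid JSON.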
            if self.model_engine in ['gpt-3.5-turbo-1106', 'gpt-4-1106-preview']:
                response = self.openai_client.chat.completions.create(
                    model=self.model_engine,
                    response_format={"type": "json_object"},
                    messages=[
                        {"role": "user", "content": str(prompt)}],
                    max_tokens=1024,
                    temperature=0,
                )
            else:
                response = self.openai_client.chat.completions.create(
                    model=self.model_engine,
                    messages=[
                        {"role": "user", "content": str(prompt)}],
                    max_tokens=1024,
                    temperature=0,
                )

            """
            {
              "choices": [
                {
                  "finish_reason": "stop",
                  "index": 0,
                  "message": {
                    "content": "Hello! As an AI language model, I don't have emotions, but I'm functioning well. I'm here to assist you with any questions or tasks you may have. How can I help you today?",
                    "role": "assistant"
                  }
                }
              ],
              "created": 1683014436,
              "id": "chatcmpl-7BfE4AdTo5YlSIWyMDS6nL6CYv5is",
              "model": "gpt-3.5-turbo-0301",
              "object": "chat.completion",
              "usage": {
                "completion_tokens": 42,
                "prompt_tokens": 20,
                "total_tokens": 62
              }
            }
            """
            response_message = response.choices[0].message.content
        elif self.model_engine in self.local_model_engines:
            instructions = [
                [
                    {
                        "role": "user",
                        "content": str(prompt),
                    }
                ],
            ]
            if 'deepseek' in self.model_engine:
                results = api_generator_deepseek(instructions=instructions,
                                                 tokenizer=self.tokenizer,
                                                 generator=self.local_llm_generator,
                                                 max_new_tokens=4096,
                                                 top_k=50,
                                                 top_p=0.95)
            elif 'llama' in self.model_engine:
                results = api_generator(instructions=instructions,
                                        generator=self.local_llm_generator,
                                        temperature=0.6)
            response_message = results[0]['generation']['content']
        elif self.model_engine.startswith('ollama_'):
            response_message = self.local_llm_generator.invoke(str(prompt))
        return response_message

    def valid_json_response(self, response_message):
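        """Persist response_message to {output_dir}/{round}_response.json and
        return True only if it parses as valid JSON."""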
        if not os.path.isdir(f'{self.output_dir}'):
            os.makedirs(f'{self.output_dir}')
        try:
            with open(f'{self.output_dir}/{self.global_round}_response.json', 'w') as w:
                json.dump(json.loads(response_message), w)
            json.load(open(f'{self.output_dir}/{self.global_round}_response.json'))
        except Exception:
            print('[INVALID RESPONSE]\n', response_message)
            return False
        return True

    def valid_json_response_executor(self, response_message):
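        """Like valid_json_response, but for the executor reply: additionally
        require a 'stat' field equal to '0' or '1'."""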
        if not os.path.isdir(f'{self.output_dir}'):
            os.makedirs(f'{self.output_dir}')
        try:
            with open(f'{self.output_dir}/executor_response.json', 'w') as w:
                json.dump(json.loads(response_message), w)
            tmp_data = json.load(open(f'{self.output_dir}/executor_response.json'))
            if str(tmp_data['stat']) not in ['0', '1']:
                return False
        except Exception:
            print('[INVALID RESPONSE]\n', response_message)
            return False
        return True

    def process_tasks(self, response_message):
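        """Adopt the model-proposed plan as the pending task queue."""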
        self.tasks = response_message['plan']

    def find_json(self, response_message):
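        """Extract the JSON payload (or a bare bash block wrapped as JSON)
        from a raw model reply."""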
        if "```json\n" in response_message:
            start_index = response_message.find("{")
            end_index = response_message.rfind("}") + 1
            # Extract the JSON portion
            return response_message[start_index:end_index]
        elif "```bash\n" in response_message:
            # Wrap a bare bash code block in the expected JSON structure;
            # json.dumps (rather than str()) so the result parses with json.loads
            start_index = response_message.find("```bash\n") + len("```bash\n")
            end_index = response_message.find("```", start_index)
            return json.dumps({'tool': '', 'code': response_message[start_index:end_index]})
        else:
            start_index = response_message.find("{")
            end_index = response_message.rfind("}") + 1
            # Extract the JSON portion
            return response_message[start_index:end_index]

    def execute_code(self, response_message):
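        """Write the generated bash code to {output_dir}/{round}.sh, run it if
        execution is enabled, and return (success, info); on failure the info
        string is fed back into the next prompt."""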
        if not os.path.isdir(f'{self.output_dir}'):
            os.makedirs(f'{self.output_dir}')
        try:
            with open(f'{self.output_dir}/{self.global_round}.sh', 'w') as w:
                w.write(response_message['code'])
            if self.execute:
                self.last_execute_code = response_message['code']
                executor_info = self.code_executor.execute(bash_code_path=f'{self.output_dir}/{self.global_round}.sh')
                if len(executor_info) == 0:
                    execute_status, execute_info = True, 'No error message'
                else:
                    executor_response_message = self.get_single_response(self.generator.get_executor_prompt(executor_info=executor_info))
                    print('[CHECKING EXECUTION RESULTS]\n')
                    if 'llama' in self.model_engine or 'deepseek' in self.model_engine:
                        executor_response_message = self.find_json(executor_response_message)
                    max_tries = 10
                    n_tries = 0
                    while not self.valid_json_response_executor(executor_response_message):
                        print(f'[EXECUTOR RESPONSE CHECKING TEST #{n_tries}/{max_tries}]')
                        if 'gpt' in self.model_engine:
                            time.sleep(20)
                        executor_response_message = self.get_single_response(
                            self.generator.get_executor_prompt(executor_info=executor_info))
                        print('[CHECKING EXECUTION RESULTS]\n')
                        if 'llama' in self.model_engine or 'deepseek' in self.model_engine:
                            executor_response_message = self.find_json(executor_response_message)
                        if n_tries > max_tries:
                            # Give up: persist a failure record so the json.load below succeeds
                            with open(f'{self.output_dir}/executor_response.json', 'w') as w:
                                json.dump({'stat': 0, 'info': 'None'}, w)
                            break
                        n_tries += 1
                    executor_response_message = json.load(open(f'{self.output_dir}/executor_response.json'))
                    execute_status, execute_info = executor_response_message['stat'], executor_response_message['info']
                    # execute_status, execute_info = executor_response_message['stat'], executor_info
                # os.system(f'bash {self.output_dir}/{self.global_round}.sh')
                return bool(int(execute_status)), execute_info
            return True, 'Success without executing'
        except Exception as e:
            return False, str(e)

    def run_plan_phase(self):
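        """Round 0: ask the model for an overall task plan and record it."""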
        # initial prompt
        init_prompt = self.generator.get_prompt(
            data_list=self.initial_data_list,
            goal_description=self.initial_goal_description,
            global_round=self.global_round,
            execute_success=self.execute_success,
            execute_info=self.execute_info
        )

        INFO_STR_USER = self.generator.format_user_prompt(init_prompt, self.global_round, self.gui_mode)
        if self.gui_mode:
            print('[AI Thinking...]')
            response_message = self.get_single_response(init_prompt)
            if 'llama' in self.model_engine or 'deepseek' in self.model_engine:
                response_message = self.find_json(response_message)
            while not self.valid_json_response(response_message):
                print('[Invalid Response, Waiting for 20s and Retrying...]')
                print(f'invalid response: {response_message}')
                if 'gpt' in self.model_engine:
                    time.sleep(20)
                response_message = self.get_single_response(init_prompt)
                if 'llama' in self.model_engine or 'deepseek' in self.model_engine:
                    response_message = self.find_json(response_message)
            response_message = json.load(open(f'{self.output_dir}/{self.global_round}_response.json'))
        else:
            with Spinner('\033[32m[AI Thinking...]\033[0m'):
                response_message = self.get_single_response(init_prompt)
                if 'llama' in self.model_engine or 'deepseek' in self.model_engine:
                    response_message = self.find_json(response_message)
                while not self.valid_json_response(response_message):
                    print('\033[32m[Invalid Response, Waiting for 20s and Retrying...]\033[0m')
                    print(f'invalid response: {response_message}')
                    if 'gpt' in self.model_engine:
                        time.sleep(20)
                    response_message = self.get_single_response(init_prompt)
                    if 'llama' in self.model_engine or 'deepseek' in self.model_engine:
                        response_message = self.find_json(response_message)
                response_message = json.load(open(f'{self.output_dir}/{self.global_round}_response.json'))
        INFO_STR_AI = self.generator.format_ai_response(response_message, self.gui_mode)

        # process tasks
        self.process_tasks(response_message)
        self.generator.set_tasks(self.tasks)
        self.generator.add_history(None, self.global_round, self.update_data_lists)
        self.global_round += 1

        if not self.execute:
            time.sleep(15)

    def run_code_generation_phase(self):
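        """Work through the task queue: generate code for each task and retry
        with execution feedback until it runs successfully."""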
        # finish task one-by-one with code
        # print('[DEBUG] ', self.tasks)
        while len(self.tasks) > 0:
            task = self.tasks.pop(0)

            prompt = self.generator.get_prompt(
                data_list=self.update_data_lists,
                goal_description=task,
                global_round=self.global_round,
                execute_success=self.execute_success,
                execute_info=self.execute_info
            )

            self.first_prompt = True
            self.execute_success = False
            while not self.execute_success:

                if not self.first_prompt:
                    prompt = self.generator.get_prompt(
                        data_list=self.update_data_lists,
                        goal_description=task,
                        global_round=self.global_round,
                        execute_success=self.execute_success,
                        execute_info=self.execute_info,
                        last_execute_code=self.last_execute_code
                    )

                INFO_STR_USER = self.generator.format_user_prompt(prompt, self.global_round, self.gui_mode)
                if self.gui_mode:
                    print('[AI Thinking...]')
                    response_message = self.get_single_response(prompt)
                    if 'llama' in self.model_engine or 'deepseek' in self.model_engine:
                        response_message = self.find_json(response_message)
                    while not self.valid_json_response(response_message):
                        print('[Invalid Response, Waiting for 20s and Retrying...]')
                        print(f'invalid response: {response_message}')
                        if 'gpt' in self.model_engine:
                            time.sleep(20)
                        response_message = self.get_single_response(prompt)
                        if 'llama' in self.model_engine or 'deepseek' in self.model_engine:
                            response_message = self.find_json(response_message)
                    response_message = json.load(open(f'{self.output_dir}/{self.global_round}_response.json'))
                else:
                    with Spinner('\033[32m[AI Thinking...]\033[0m'):
                        response_message = self.get_single_response(prompt)
                        if 'llama' in self.model_engine or 'deepseek' in self.model_engine:
                            response_message = self.find_json(response_message)
                        while not self.valid_json_response(response_message):
                            print('\033[32m[Invalid Response, Waiting for 20s and Retrying...]\033[0m')
                            print(f'invalid response: {response_message}')
                            if 'gpt' in self.model_engine:
                                time.sleep(20)
                            response_message = self.get_single_response(prompt)
                            if 'llama' in self.model_engine or 'deepseek' in self.model_engine:
                                response_message = self.find_json(response_message)
                        response_message = json.load(open(f'{self.output_dir}/{self.global_round}_response.json'))
                INFO_STR_AI = self.generator.format_ai_response(response_message, self.gui_mode)

                # execute code
                if self.gui_mode:
                    print('[AI Executing codes...]')
                    print('[Execute Code Start]')
                    execute_success, execute_info = self.execute_code(response_message)
                    self.execute_success = execute_success
                    self.execute_info = execute_info
                    print('[Execute Code Finish]', self.execute_success, self.execute_info)

                    if self.execute_success:
                        print('[Execute Code Success!]')
                        self.execute_info = ''
                    else:
                        print('[Execute Code Failed!]')
                        self.first_prompt = False
                else:
                    with Spinner('\033[32m[AI Executing codes...]\033[0m'):
                        print('\033[32m[Execute Code Start]\033[0m')
                        execute_success, execute_info = self.execute_code(response_message)
                        self.execute_success = execute_success
                        self.execute_info = execute_info
                        print('\033[32m[Execute Code Finish]\033[0m', self.execute_success, self.execute_info)

                        if self.execute_success:
                            print('\033[32m[Execute Code Success!]\033[0m')
                            self.execute_info = ''
                        else:
                            print('\033[31m[Execute Code Failed!]\033[0m')
                            self.first_prompt = False

            self.generator.add_history(task, self.global_round, self.update_data_lists, code=response_message['code'])
            self.global_round += 1
            if not self.execute:
                time.sleep(15)

    def run(self):
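        """Run the full pipeline: planning phase, then code generation."""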
        self.run_plan_phase()
        self.run_code_generation_phase()
        print('\033[31m[Job Finished! Cheers!]\033[0m')
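
# Example usage (a minimal sketch: the real entry point lives elsewhere in the
# project, and the argument values below are illustrative assumptions only):
#
#   agent = Agent(
#       initial_data_list=['./data/reads.fastq: raw sequencing reads'],
#       output_dir='./outputs',
#       initial_goal_description='Run quality control on the reads',
#       model_engine='gpt-4o',
#       openai_api='<OPENAI_API_KEY>',
#   )
#   agent.run()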