|
a |
|
b/aiagents4pharma/talk2aiagents4pharma/agents/main_agent.py |
|
|
#!/usr/bin/env python3
|
|
2 |
|
|
|
3 |
''' |
|
|
4 |
This is the main agent file for the AIAgents4Pharma. |
|
|
5 |
''' |
|
|
6 |
|
|
|
7 |
import logging |
|
|
8 |
import hydra |
|
|
9 |
from langgraph_supervisor import create_supervisor |
|
|
10 |
from langchain_openai import ChatOpenAI |
|
|
11 |
from langchain_core.language_models.chat_models import BaseChatModel |
|
|
12 |
from langgraph.checkpoint.memory import MemorySaver |
|
|
13 |
from ...talk2biomodels.agents.t2b_agent import get_app as get_app_t2b |
|
|
14 |
from ...talk2knowledgegraphs.agents.t2kg_agent import get_app as get_app_t2kg |
|
|
15 |
from ..states.state_talk2aiagents4pharma import Talk2AIAgents4Pharma |
|
|
16 |
|
|
|
# Initialize logger
# Configure the root logging handler once at import time; INFO level so the
# agent-launch and prompt-composition messages below are visible.
logging.basicConfig(level=logging.INFO)
# Module-level logger named after this module (standard logging convention).
logger = logging.getLogger(__name__)
|
def get_app(uniq_id, llm_model: BaseChatModel):
    '''
    Return the compiled langgraph app for AIAgents4Pharma.

    Builds a supervisor workflow over the Talk2BioModels (T2B) and
    Talk2KnowledgeGraphs (T2KG) sub-agents, concatenates their system
    prompts into the supervisor prompt, and compiles the workflow with
    an in-memory checkpointer.

    Args:
        uniq_id: Unique thread/session identifier, passed to both
            sub-agents and used in launch logging.
        llm_model (BaseChatModel): Chat model shared by the supervisor
            and both sub-agents.

    Returns:
        The compiled langgraph application.
    '''
    # For gpt-4o-mini, re-create the model with parallel tool calls
    # disabled (presumably to force sequential tool execution — confirm
    # with the sub-agents' tool design).
    if hasattr(llm_model, 'model_name'):
        if llm_model.model_name == 'gpt-4o-mini':
            llm_model = ChatOpenAI(model='gpt-4o-mini',
                                   temperature=0,
                                   model_kwargs={"parallel_tool_calls": False})
    # Load hydra configuration
    logger.info("Launching AIAgents4Pharma_Agent with thread_id %s", uniq_id)
    with hydra.initialize(version_base=None, config_path="../configs"):
        cfg = hydra.compose(config_name='config',
                            overrides=['agents/main_agent=default'])
        cfg = cfg.agents.main_agent
    logger.info("System_prompt of T2AA4P: %s", cfg.system_prompt)
    with hydra.initialize(version_base=None,
                          config_path="../../talk2biomodels/configs"):
        cfg_t2b = hydra.compose(config_name='config',
                                overrides=['agents/t2b_agent=default'])
        cfg_t2b = cfg_t2b.agents.t2b_agent
    with hydra.initialize(version_base=None,
                          config_path="../../talk2knowledgegraphs/configs"):
        cfg_t2kg = hydra.compose(config_name='config',
                                 overrides=['agents/t2kg_agent=default'])
        cfg_t2kg = cfg_t2kg.agents.t2kg_agent
    # Concatenate the supervisor prompt with both sub-agents' prompts so
    # the supervisor knows each agent's capabilities when routing.
    system_prompt = cfg.system_prompt
    system_prompt += "\n\nHere is the system prompt of T2B agent\n"
    system_prompt += cfg_t2b.state_modifier
    system_prompt += "\n\nHere is the system prompt of T2KG agent\n"
    system_prompt += cfg_t2kg.state_modifier
    # Create supervisor workflow
    workflow = create_supervisor(
        [
            get_app_t2b(uniq_id, llm_model),  # Talk2BioModels
            get_app_t2kg(uniq_id, llm_model)  # Talk2KnowledgeGraphs
        ],
        model=llm_model,
        state_schema=Talk2AIAgents4Pharma,
        # Full history is needed to extract
        # the tool artifacts
        output_mode="full_history",
        add_handoff_back_messages=True,
        prompt=system_prompt
    )

    # Compile and run: in-memory checkpointer keeps per-thread state
    app = workflow.compile(checkpointer=MemorySaver(),
                           name="AIAgents4Pharma_Agent")

    return app