|
a |
|
b/aiagents4pharma/talk2biomodels/agents/t2b_agent.py |
|
|
1 |
#/usr/bin/env python3 |
|
|
2 |
|
|
|
3 |
''' |
|
|
4 |
This is the agent file for the Talk2BioModels agent. |
|
|
5 |
''' |
|
|
6 |
|
|
|
7 |
import logging |
|
|
8 |
from typing import Annotated |
|
|
9 |
import hydra |
|
|
10 |
from langchain_core.language_models.chat_models import BaseChatModel |
|
|
11 |
from langgraph.checkpoint.memory import MemorySaver |
|
|
12 |
from langgraph.graph import START, StateGraph |
|
|
13 |
from langgraph.prebuilt import create_react_agent, ToolNode, InjectedState |
|
|
14 |
from ..tools.search_models import SearchModelsTool |
|
|
15 |
from ..tools.get_modelinfo import GetModelInfoTool |
|
|
16 |
from ..tools.simulate_model import SimulateModelTool |
|
|
17 |
from ..tools.custom_plotter import CustomPlotterTool |
|
|
18 |
from ..tools.get_annotation import GetAnnotationTool |
|
|
19 |
from ..tools.ask_question import AskQuestionTool |
|
|
20 |
from ..tools.parameter_scan import ParameterScanTool |
|
|
21 |
from ..tools.steady_state import SteadyStateTool |
|
|
22 |
from ..tools.query_article import QueryArticle |
|
|
23 |
from ..states.state_talk2biomodels import Talk2Biomodels |
|
|
24 |
|
|
|
25 |
# Initialize logger |
|
|
26 |
logging.basicConfig(level=logging.INFO) |
|
|
27 |
logger = logging.getLogger(__name__) |
|
|
28 |
|
|
|
29 |
def get_app(uniq_id, |
|
|
30 |
llm_model: BaseChatModel): |
|
|
31 |
''' |
|
|
32 |
This function returns the langraph app. |
|
|
33 |
''' |
|
|
34 |
def agent_t2b_node(state: Annotated[dict, InjectedState]): |
|
|
35 |
''' |
|
|
36 |
This function calls the model. |
|
|
37 |
''' |
|
|
38 |
logger.log(logging.INFO, "Calling t2b_agent node with thread_id %s", uniq_id) |
|
|
39 |
response = model.invoke(state, {"configurable": {"thread_id": uniq_id}}) |
|
|
40 |
return response |
|
|
41 |
|
|
|
42 |
# Define the tools |
|
|
43 |
tools = ToolNode([ |
|
|
44 |
SimulateModelTool(), |
|
|
45 |
AskQuestionTool(), |
|
|
46 |
CustomPlotterTool(), |
|
|
47 |
SearchModelsTool(), |
|
|
48 |
GetModelInfoTool(), |
|
|
49 |
SteadyStateTool(), |
|
|
50 |
ParameterScanTool(), |
|
|
51 |
GetAnnotationTool(), |
|
|
52 |
QueryArticle() |
|
|
53 |
]) |
|
|
54 |
|
|
|
55 |
# Load hydra configuration |
|
|
56 |
logger.log(logging.INFO, "Load Hydra configuration for Talk2BioModels agent.") |
|
|
57 |
with hydra.initialize(version_base=None, config_path="../configs"): |
|
|
58 |
cfg = hydra.compose(config_name='config', |
|
|
59 |
overrides=['agents/t2b_agent=default']) |
|
|
60 |
cfg = cfg.agents.t2b_agent |
|
|
61 |
logger.log(logging.INFO, "state_modifier: %s", cfg.state_modifier) |
|
|
62 |
# Create the agent |
|
|
63 |
model = create_react_agent( |
|
|
64 |
llm_model, |
|
|
65 |
tools=tools, |
|
|
66 |
state_schema=Talk2Biomodels, |
|
|
67 |
prompt=cfg.state_modifier, |
|
|
68 |
version='v2', |
|
|
69 |
checkpointer=MemorySaver() |
|
|
70 |
) |
|
|
71 |
|
|
|
72 |
# Define a new graph |
|
|
73 |
workflow = StateGraph(Talk2Biomodels) |
|
|
74 |
|
|
|
75 |
# Define the two nodes we will cycle between |
|
|
76 |
workflow.add_node("agent_t2b", agent_t2b_node) |
|
|
77 |
|
|
|
78 |
# Set the entrypoint as the first node |
|
|
79 |
# This means that this node is the first one called |
|
|
80 |
workflow.add_edge(START, "agent_t2b") |
|
|
81 |
|
|
|
82 |
# Initialize memory to persist state between graph runs |
|
|
83 |
checkpointer = MemorySaver() |
|
|
84 |
|
|
|
85 |
# Finally, we compile it! |
|
|
86 |
# This compiles it into a LangChain Runnable, |
|
|
87 |
# meaning you can use it as you would any other runnable. |
|
|
88 |
# Note that we're (optionally) passing the memory |
|
|
89 |
# when compiling the graph |
|
|
90 |
app = workflow.compile(checkpointer=checkpointer, |
|
|
91 |
name="T2B_Agent") |
|
|
92 |
logger.log(logging.INFO, |
|
|
93 |
"Compiled the graph with thread_id %s and llm_model %s", |
|
|
94 |
uniq_id, |
|
|
95 |
llm_model) |
|
|
96 |
|
|
|
97 |
return app |