Switch to unified view

a b/aiagents4pharma/talk2biomodels/agents/t2b_agent.py
1
#/usr/bin/env python3
2
3
'''
4
This is the agent file for the Talk2BioModels agent.
5
'''
6
7
import logging
8
from typing import Annotated
9
import hydra
10
from langchain_core.language_models.chat_models import BaseChatModel
11
from langgraph.checkpoint.memory import MemorySaver
12
from langgraph.graph import START, StateGraph
13
from langgraph.prebuilt import create_react_agent, ToolNode, InjectedState
14
from ..tools.search_models import SearchModelsTool
15
from ..tools.get_modelinfo import GetModelInfoTool
16
from ..tools.simulate_model import SimulateModelTool
17
from ..tools.custom_plotter import CustomPlotterTool
18
from ..tools.get_annotation import GetAnnotationTool
19
from ..tools.ask_question import AskQuestionTool
20
from ..tools.parameter_scan import ParameterScanTool
21
from ..tools.steady_state import SteadyStateTool
22
from ..tools.query_article import QueryArticle
23
from ..states.state_talk2biomodels import Talk2Biomodels
24
25
# Initialize logger
26
logging.basicConfig(level=logging.INFO)
27
logger = logging.getLogger(__name__)
28
29
def get_app(uniq_id,
30
            llm_model: BaseChatModel):
31
    '''
32
    This function returns the langraph app.
33
    '''
34
    def agent_t2b_node(state: Annotated[dict, InjectedState]):
35
        '''
36
        This function calls the model.
37
        '''
38
        logger.log(logging.INFO, "Calling t2b_agent node with thread_id %s", uniq_id)
39
        response = model.invoke(state, {"configurable": {"thread_id": uniq_id}})
40
        return response
41
42
    # Define the tools
43
    tools = ToolNode([
44
                    SimulateModelTool(),
45
                    AskQuestionTool(),
46
                    CustomPlotterTool(),
47
                    SearchModelsTool(),
48
                    GetModelInfoTool(),
49
                    SteadyStateTool(),
50
                    ParameterScanTool(),
51
                    GetAnnotationTool(),
52
                    QueryArticle()
53
                ])
54
55
    # Load hydra configuration
56
    logger.log(logging.INFO, "Load Hydra configuration for Talk2BioModels agent.")
57
    with hydra.initialize(version_base=None, config_path="../configs"):
58
        cfg = hydra.compose(config_name='config',
59
                            overrides=['agents/t2b_agent=default'])
60
        cfg = cfg.agents.t2b_agent
61
    logger.log(logging.INFO, "state_modifier: %s", cfg.state_modifier)
62
    # Create the agent
63
    model = create_react_agent(
64
                llm_model,
65
                tools=tools,
66
                state_schema=Talk2Biomodels,
67
                prompt=cfg.state_modifier,
68
                version='v2',
69
                checkpointer=MemorySaver()
70
            )
71
72
    # Define a new graph
73
    workflow = StateGraph(Talk2Biomodels)
74
75
    # Define the two nodes we will cycle between
76
    workflow.add_node("agent_t2b", agent_t2b_node)
77
78
    # Set the entrypoint as the first node
79
    # This means that this node is the first one called
80
    workflow.add_edge(START, "agent_t2b")
81
82
    # Initialize memory to persist state between graph runs
83
    checkpointer = MemorySaver()
84
85
    # Finally, we compile it!
86
    # This compiles it into a LangChain Runnable,
87
    # meaning you can use it as you would any other runnable.
88
    # Note that we're (optionally) passing the memory
89
    # when compiling the graph
90
    app = workflow.compile(checkpointer=checkpointer,
91
                           name="T2B_Agent")
92
    logger.log(logging.INFO,
93
               "Compiled the graph with thread_id %s and llm_model %s",
94
               uniq_id,
95
               llm_model)
96
97
    return app