#!/usr/bin/env python3
'''
This is the agent file for the Talk2Cells graph.
'''

# Standard library
import logging
import os

# Third-party
from langchain_openai import ChatOpenAI
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import START, StateGraph
from langgraph.prebuilt import create_react_agent, ToolNode

# Local package
from ..states.state_talk2cells import Talk2Cells
from ..tools.scp_agent.display_studies import display_studies
from ..tools.scp_agent.search_studies import search_studies

# Initialize module-level logger
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def get_app(uniq_id):
    '''
    Build and return the compiled LangGraph app for the Talk2Cells agent.

    Args:
        uniq_id: Unique thread ID used to key the checkpointed
            conversation state for this session.

    Returns:
        A compiled LangGraph runnable (single-node graph wrapping a
        prebuilt ReAct agent).
    '''
    def agent_scp_node(state: Talk2Cells):
        '''
        Graph node that invokes the ReAct agent with the current state.
        '''
        logger.info("Creating SCP_Agent node with thread_id %s", uniq_id)
        # The ReAct agent returns the state update (including the messages
        # list, which may contain tool calls), so it is returned as-is.
        response = model.invoke(state, {"configurable": {"thread_id": uniq_id}})
        return response

    # Wrap the tools in a ToolNode so the ReAct agent can dispatch tool calls
    tools = ToolNode([search_studies, display_studies])

    # LLM model name comes from the environment; default to 'gpt-4o-mini'
    llm_model = os.getenv('AIAGENTS4PHARMA_LLM_MODEL', 'gpt-4o-mini')
    llm = ChatOpenAI(model=llm_model, temperature=0)

    # Prebuilt ReAct agent: LLM + tools + custom state schema.
    # NOTE(review): passing a checkpointer here looks redundant with the
    # checkpointer given to workflow.compile() below — the parent graph's
    # checkpointer is the one that persists state; confirm against the
    # installed langgraph version before removing.
    model = create_react_agent(
        llm,
        tools=tools,
        state_schema=Talk2Cells,
        state_modifier="You are Talk2Cells agent.",
        checkpointer=MemorySaver(),
    )

    # Define a single-node graph: START -> agent_scp
    workflow = StateGraph(Talk2Cells)
    workflow.add_node("agent_scp", agent_scp_node)
    workflow.add_edge(START, "agent_scp")

    # Initialize memory to persist state between graph runs
    checkpointer = MemorySaver()

    # Compile into a LangChain Runnable, passing the memory so that
    # conversation state survives across invocations with the same thread_id
    app = workflow.compile(checkpointer=checkpointer)
    logger.info("Compiled the graph")

    return app