aiagents4pharma/talk2scholars/agents/zotero_agent.py
#!/usr/bin/env python3

"""
Agent for interacting with Zotero with human-in-the-loop features
"""

import logging
from typing import Any, Dict
import hydra

from langchain_core.language_models.chat_models import BaseChatModel
from langgraph.graph import START, StateGraph
from langgraph.prebuilt import create_react_agent, ToolNode
from langgraph.checkpoint.memory import MemorySaver
from ..state.state_talk2scholars import Talk2Scholars
from ..tools.zotero.zotero_read import zotero_read
from ..tools.zotero.zotero_review import zotero_review
from ..tools.zotero.zotero_write import zotero_write
from ..tools.s2.display_results import display_results as s2_display
from ..tools.s2.query_results import query_results as s2_query_results
from ..tools.s2.retrieve_semantic_scholar_paper_id import (
    retrieve_semantic_scholar_paper_id,
)

# Initialize logger
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def get_app(uniq_id, llm_model: BaseChatModel):
    """
    Initializes and returns the LangGraph application for the Zotero agent.

    This function sets up the Zotero agent, which integrates various tools to search,
    retrieve, and display research papers from Zotero. The agent follows the ReAct
    pattern for structured interaction and includes human-in-the-loop features.

    Args:
        uniq_id (str): Unique identifier for the current conversation session.
        llm_model (BaseChatModel): The language model used by the agent,
            e.g. `ChatOpenAI(model="gpt-4o-mini", temperature=0)`.

    Returns:
        StateGraph: A compiled LangGraph application that enables the Zotero agent to
            process user queries and retrieve research papers.

    Example:
        >>> app = get_app("thread_123", llm_model)
        >>> result = app.invoke(initial_state)
    """

    def agent_zotero_node(state: Talk2Scholars) -> Dict[str, Any]:
        """
        Processes the user query and retrieves relevant research papers from Zotero.

        This function calls the language model using the configured `ReAct` agent to
        analyze the state and generate an appropriate response. The function then
        returns control to the main supervisor.

        Args:
            state (Talk2Scholars): The current conversation state, including messages
                exchanged and any previously retrieved research papers.

        Returns:
            Dict[str, Any]: A dictionary containing the updated conversation state.

        Example:
            >>> result = agent_zotero_node(current_state)
            >>> papers = result.get("papers", [])
        """
        logger.log(
            logging.INFO, "Creating Agent_Zotero node with thread_id %s", uniq_id
        )
        # `model` is the ReAct agent created further below in get_app; it is bound
        # before the compiled graph ever invokes this node.
        result = model.invoke(state, {"configurable": {"thread_id": uniq_id}})

        return result

    # Load hydra configuration
    logger.log(logging.INFO, "Load Hydra configuration for Talk2Scholars Zotero agent.")
    with hydra.initialize(version_base=None, config_path="../configs"):
        cfg = hydra.compose(
            config_name="config",
            overrides=["agents/talk2scholars/zotero_agent=default"],
        )
        cfg = cfg.agents.talk2scholars.zotero_agent
        logger.log(logging.INFO, "Loaded configuration for Zotero agent")
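
    # Note: the composed `cfg.zotero_agent` value is used further below as the
    # prompt for the ReAct agent (`prompt=cfg.zotero_agent`).
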
    # Define the tools
    tools = ToolNode(
        [
            zotero_read,
            s2_display,
            s2_query_results,
            retrieve_semantic_scholar_paper_id,
            zotero_review,  # First review
            zotero_write,  # Then save with user confirmation
        ]
    )
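
    # Human-in-the-loop: zotero_review is intended to pause the run so the user
    # can confirm a proposed change before zotero_write saves it to Zotero; the
    # interrupt relies on the checkpointer configured below.
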
    # Log the model being used
    logger.log(logging.INFO, "Using model %s", llm_model)

    # Create the ReAct agent
    model = create_react_agent(
        llm_model,
        tools=tools,
        state_schema=Talk2Scholars,
        prompt=cfg.zotero_agent,
        checkpointer=MemorySaver(),  # Required for interrupts to work
    )
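
    # Wrap the ReAct agent as the single node of an outer StateGraph; per the
    # node docstring above, control returns to the main supervisor afterwards.
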
    workflow = StateGraph(Talk2Scholars)
    workflow.add_node("agent_zotero", agent_zotero_node)
    workflow.add_edge(START, "agent_zotero")

    # Initialize memory to persist state between graph runs
    checkpointer = MemorySaver()

    # Compile the graph
    app = workflow.compile(checkpointer=checkpointer, name="agent_zotero")
    logger.log(
        logging.INFO,
        "Compiled the graph with thread_id %s and llm_model %s",
        uniq_id,
        llm_model,
    )

    return app
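

# ---------------------------------------------------------------------------
# Illustrative usage sketch, not part of the agent API. Both the ReAct agent
# and the outer graph are checkpointed with MemorySaver, so invoke() needs a
# thread_id in the "configurable" config. The initial-state shape assumes
# Talk2Scholars carries a LangGraph-style "messages" list; the model choice
# and query below are placeholders. Run as a module (python -m ...) so the
# relative imports above resolve.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from langchain_core.messages import HumanMessage
    from langchain_openai import ChatOpenAI  # assumed to be available

    demo_app = get_app("thread_123", ChatOpenAI(model="gpt-4o-mini", temperature=0))
    demo_result = demo_app.invoke(
        {"messages": [HumanMessage(content="Show the papers in my Zotero library")]},
        {"configurable": {"thread_id": "thread_123"}},
    )
    print(demo_result["messages"][-1].content)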