|
a |
|
b/aiagents4pharma/talk2scholars/agents/zotero_agent.py |
|
|
1 |
#!/usr/bin/env python3 |
|
|
2 |
|
|
|
3 |
""" |
|
|
4 |
Agent for interacting with Zotero with human-in-the-loop features |
|
|
5 |
""" |
|
|
6 |
|
|
|
7 |
import logging |
|
|
8 |
from typing import Any, Dict |
|
|
9 |
import hydra |
|
|
10 |
|
|
|
11 |
from langchain_core.language_models.chat_models import BaseChatModel |
|
|
12 |
from langgraph.graph import START, StateGraph |
|
|
13 |
from langgraph.prebuilt import create_react_agent, ToolNode |
|
|
14 |
from langgraph.checkpoint.memory import MemorySaver |
|
|
15 |
from ..state.state_talk2scholars import Talk2Scholars |
|
|
16 |
from ..tools.zotero.zotero_read import zotero_read |
|
|
17 |
from ..tools.zotero.zotero_review import zotero_review |
|
|
18 |
from ..tools.zotero.zotero_write import zotero_write |
|
|
19 |
from ..tools.s2.display_results import display_results as s2_display |
|
|
20 |
from ..tools.s2.query_results import query_results as s2_query_results |
|
|
21 |
from ..tools.s2.retrieve_semantic_scholar_paper_id import ( |
|
|
22 |
retrieve_semantic_scholar_paper_id, |
|
|
23 |
) |
|
|
24 |
|
|
|
25 |
# Initialize logger |
|
|
26 |
logging.basicConfig(level=logging.INFO) |
|
|
27 |
logger = logging.getLogger(__name__) |
|
|
28 |
|
|
|
29 |
|
|
|
30 |
def get_app(uniq_id, llm_model: BaseChatModel): |
|
|
31 |
""" |
|
|
32 |
Initializes and returns the LangGraph application for the Zotero agent. |
|
|
33 |
|
|
|
34 |
This function sets up the Zotero agent, which integrates various tools to search, |
|
|
35 |
retrieve, and display research papers from Zotero. The agent follows the ReAct |
|
|
36 |
pattern for structured interaction and includes human-in-the-loop features. |
|
|
37 |
|
|
|
38 |
Args: |
|
|
39 |
uniq_id (str): Unique identifier for the current conversation session. |
|
|
40 |
llm_model (BaseChatModel, optional): The language model to be used by the agent. |
|
|
41 |
Defaults to `ChatOpenAI(model="gpt-4o-mini", temperature=0)`. |
|
|
42 |
|
|
|
43 |
Returns: |
|
|
44 |
StateGraph: A compiled LangGraph application that enables the Zotero agent to |
|
|
45 |
process user queries and retrieve research papers. |
|
|
46 |
|
|
|
47 |
Example: |
|
|
48 |
>>> app = get_app("thread_123") |
|
|
49 |
>>> result = app.invoke(initial_state) |
|
|
50 |
""" |
|
|
51 |
|
|
|
52 |
def agent_zotero_node(state: Talk2Scholars) -> Dict[str, Any]: |
|
|
53 |
""" |
|
|
54 |
Processes the user query and retrieves relevant research papers from Zotero. |
|
|
55 |
|
|
|
56 |
This function calls the language model using the configured `ReAct` agent to |
|
|
57 |
analyze the state and generate an appropriate response. The function then |
|
|
58 |
returns control to the main supervisor. |
|
|
59 |
|
|
|
60 |
Args: |
|
|
61 |
state (Talk2Scholars): The current conversation state, including messages exchanged |
|
|
62 |
and any previously retrieved research papers. |
|
|
63 |
|
|
|
64 |
Returns: |
|
|
65 |
Dict[str, Any]: A dictionary containing the updated conversation state. |
|
|
66 |
|
|
|
67 |
Example: |
|
|
68 |
>>> result = agent_zotero_node(current_state) |
|
|
69 |
>>> papers = result.get("papers", []) |
|
|
70 |
""" |
|
|
71 |
logger.log( |
|
|
72 |
logging.INFO, "Creating Agent_Zotero node with thread_id %s", uniq_id |
|
|
73 |
) |
|
|
74 |
result = model.invoke(state, {"configurable": {"thread_id": uniq_id}}) |
|
|
75 |
|
|
|
76 |
return result |
|
|
77 |
|
|
|
78 |
# Load hydra configuration |
|
|
79 |
logger.log(logging.INFO, "Load Hydra configuration for Talk2Scholars Zotero agent.") |
|
|
80 |
with hydra.initialize(version_base=None, config_path="../configs"): |
|
|
81 |
cfg = hydra.compose( |
|
|
82 |
config_name="config", |
|
|
83 |
overrides=["agents/talk2scholars/zotero_agent=default"], |
|
|
84 |
) |
|
|
85 |
cfg = cfg.agents.talk2scholars.zotero_agent |
|
|
86 |
logger.log(logging.INFO, "Loaded configuration for Zotero agent") |
|
|
87 |
|
|
|
88 |
# Define the tools |
|
|
89 |
tools = ToolNode( |
|
|
90 |
[ |
|
|
91 |
zotero_read, |
|
|
92 |
s2_display, |
|
|
93 |
s2_query_results, |
|
|
94 |
retrieve_semantic_scholar_paper_id, |
|
|
95 |
zotero_review, # First review |
|
|
96 |
zotero_write, # Then save with user confirmation |
|
|
97 |
] |
|
|
98 |
) |
|
|
99 |
|
|
|
100 |
# Define the model |
|
|
101 |
logger.log(logging.INFO, "Using model %s", llm_model) |
|
|
102 |
|
|
|
103 |
# Create the agent |
|
|
104 |
model = create_react_agent( |
|
|
105 |
llm_model, |
|
|
106 |
tools=tools, |
|
|
107 |
state_schema=Talk2Scholars, |
|
|
108 |
prompt=cfg.zotero_agent, |
|
|
109 |
checkpointer=MemorySaver(), # Required for interrupts to work |
|
|
110 |
) |
|
|
111 |
|
|
|
112 |
workflow = StateGraph(Talk2Scholars) |
|
|
113 |
workflow.add_node("agent_zotero", agent_zotero_node) |
|
|
114 |
workflow.add_edge(START, "agent_zotero") |
|
|
115 |
|
|
|
116 |
# Initialize memory to persist state between graph runs |
|
|
117 |
checkpointer = MemorySaver() |
|
|
118 |
|
|
|
119 |
# Compile the graph |
|
|
120 |
app = workflow.compile(checkpointer=checkpointer, name="agent_zotero") |
|
|
121 |
logger.log( |
|
|
122 |
logging.INFO, |
|
|
123 |
"Compiled the graph with thread_id %s and llm_model %s", |
|
|
124 |
uniq_id, |
|
|
125 |
llm_model, |
|
|
126 |
) |
|
|
127 |
|
|
|
128 |
return app |