[3af7d7]: / aiagents4pharma / talk2scholars / tests / test_paper_download_agent.py

Download this file

143 lines (119 with data), 5.3 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
"""Unit tests for the paper download agent in Talk2Scholars."""
from unittest import mock
import pytest
from langchain_core.messages import HumanMessage, AIMessage
from langchain_core.language_models.chat_models import BaseChatModel
from ..agents.paper_download_agent import get_app
from ..state.state_talk2scholars import Talk2Scholars
@pytest.fixture(autouse=True)
def mock_hydra_fixture():
    """Stub out Hydra so ``get_app`` never reads real configuration files.

    Runs automatically for every test in this module. ``hydra.initialize`` is
    replaced by a plain mock, and ``hydra.compose`` returns a ``MagicMock``
    config whose paper-download prompt is the fixed string ``"Test prompt"``.

    Yields:
        The mock standing in for ``hydra.compose``.
    """
    fake_cfg = mock.MagicMock()
    talk2scholars_cfg = fake_cfg.agents.talk2scholars
    # NOTE(review): this pins the *s2_agent* temperature even though this
    # module exercises the paper download agent -- confirm which key get_app reads.
    talk2scholars_cfg.s2_agent.temperature = 0
    talk2scholars_cfg.paper_download_agent.prompt = "Test prompt"
    with (
        mock.patch("hydra.initialize"),
        mock.patch("hydra.compose", return_value=fake_cfg) as compose_mock,
    ):
        yield compose_mock
@pytest.fixture
def mock_tools_fixture():
    """Swap the paper-download tools for mocks so no HTTP requests are made.

    The arXiv downloader returns a dummy ``pdf_data`` payload, and the S2
    query tool returns a fixed result string.

    Yields:
        ``[download_arxiv_paper_mock, query_results_mock]``, in that order.
    """
    # NOTE(review): these targets patch the tools' defining modules; modules
    # that already imported the names keep their own references -- confirm
    # this is the intended patch location.
    arxiv_target = (
        "aiagents4pharma.talk2scholars.tools.paper_download."
        "download_arxiv_input.download_arxiv_paper"
    )
    query_target = (
        "aiagents4pharma.talk2scholars.tools.s2.query_results.query_results"
    )
    with (
        mock.patch(arxiv_target) as arxiv_mock,
        mock.patch(query_target) as query_mock,
    ):
        arxiv_mock.return_value = {"pdf_data": {"dummy_key": "dummy_value"}}
        query_mock.return_value = {"result": "Mocked Query Result"}
        yield [arxiv_mock, query_mock]
@pytest.mark.usefixtures("mock_hydra_fixture")
def test_paper_download_agent_initialization():
    """get_app should return an app once create_react_agent is stubbed out."""
    llm_stub = mock.Mock(spec=BaseChatModel)  # stands in for a real chat model
    agent_factory = (
        "aiagents4pharma.talk2scholars.agents.paper_download_agent.create_react_agent"
    )
    with mock.patch(agent_factory, return_value=mock.Mock()) as create_agent_mock:
        app = get_app("test_thread_paper_dl", llm_stub)
    assert app is not None, "The agent app should be successfully created."
    assert create_agent_mock.called
@pytest.mark.usefixtures("mock_tools_fixture")
def test_paper_download_agent_invocation():
    """Verifies agent processes queries and updates state correctly.

    ``mock_tools_fixture`` is requested via ``usefixtures`` so that its tool
    patches are active for the whole test. Merely referencing the fixture
    function by name inside the body does not activate a pytest fixture.
    """
    thread_id = "test_thread_paper_dl"
    mock_state = Talk2Scholars(
        messages=[HumanMessage(content="Download paper 1234.5678")]
    )
    llm_mock = mock.Mock(spec=BaseChatModel)
    with mock.patch(
        "aiagents4pharma.talk2scholars.agents.paper_download_agent.create_react_agent"
    ) as mock_create_agent:
        mock_agent = mock.Mock()
        mock_create_agent.return_value = mock_agent
        # The stubbed react agent replies with a message plus PDF bytes; the
        # assertions below check both keys surface in the app's result.
        mock_agent.invoke.return_value = {
            "messages": [AIMessage(content="Here is the paper")],
            "pdf_data": {"file_bytes": b"FAKE_PDF_CONTENTS"},
        }
        app = get_app(thread_id, llm_mock)
        result = app.invoke(
            mock_state,
            config={
                "configurable": {
                    "thread_id": thread_id,
                    "checkpoint_ns": "test_ns",
                    "checkpoint_id": "test_checkpoint",
                }
            },
        )
    assert mock_agent.invoke.called, "The stubbed react agent should be invoked."
    assert "messages" in result
    assert "pdf_data" in result
def test_paper_download_agent_tools_assignment(request):
    """Checks correct tool assignment (download_arxiv_paper, query_results).

    The assertion inspects the arguments ``get_app`` actually passes to
    ``ToolNode``, not a value the test assigned itself.
    """
    thread_id = "test_thread_paper_dl"
    # Resolving the fixture here activates its tool patches for this test.
    mock_tools = request.getfixturevalue("mock_tools_fixture")
    llm_mock = mock.Mock(spec=BaseChatModel)
    with (
        mock.patch(
            "aiagents4pharma.talk2scholars.agents.paper_download_agent.create_react_agent"
        ) as mock_create_agent,
        mock.patch(
            "aiagents4pharma.talk2scholars.agents.paper_download_agent.ToolNode"
        ) as mock_toolnode,
    ):
        mock_create_agent.return_value = mock.Mock()
        get_app(thread_id, llm_mock)
    assert mock_toolnode.called, "get_app should wrap its tools in a ToolNode."
    tool_call = mock_toolnode.call_args
    # NOTE(review): assumes the tool list is ToolNode's first positional
    # argument or its ``tools`` keyword -- confirm against get_app.
    passed_tools = tool_call.args[0] if tool_call.args else tool_call.kwargs["tools"]
    # One ToolNode entry per tool mocked by mock_tools_fixture.
    assert len(passed_tools) == len(mock_tools)
def test_paper_download_agent_hydra_failure():
    """An error from hydra.initialize must surface from get_app with its message."""
    llm_stub = mock.Mock(spec=BaseChatModel)
    hydra_error = Exception("Mock Hydra failure")
    with (
        mock.patch("hydra.initialize", side_effect=hydra_error),
        pytest.raises(Exception) as raised,
    ):
        get_app("test_thread_paper_dl", llm_stub)
    assert "Mock Hydra failure" in str(raised.value)
def test_paper_download_agent_model_failure():
    """An error raised while building the react agent must reach the caller."""
    llm_stub = mock.Mock(spec=BaseChatModel)
    agent_factory = (
        "aiagents4pharma.talk2scholars.agents.paper_download_agent.create_react_agent"
    )
    with (
        mock.patch(agent_factory, side_effect=Exception("Mock model failure")),
        pytest.raises(Exception) as raised,
    ):
        get_app("test_thread_paper_dl", llm_stub)
    assert "Mock model failure" in str(raised.value), (
        "Model initialization failure should raise an exception."
    )