[9d3784]: / aiagents4pharma / talk2scholars / tests / test_s2_display.py

Download this file

75 lines (62 with data), 2.2 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
"""
Unit tests for S2 tools functionality.
"""
# pylint: disable=redefined-outer-name
import re

import pytest
from langgraph.types import Command

from ..tools.s2.display_results import (
    display_results,
    NoPapersFoundError as raised_error,
)
@pytest.fixture
def initial_state():
    """Build a fresh agent state whose paper collections start out empty."""
    empty_state = dict(papers={}, multi_papers={})
    return empty_state
# Fixed test data for deterministic results
MOCK_SEARCH_RESPONSE = {
"data": [
{
"paperId": "123",
"title": "Machine Learning Basics",
"abstract": "An introduction to ML",
"year": 2023,
"citationCount": 100,
"url": "https://example.com/paper1",
"authors": [{"name": "Test Author"}],
}
]
}
MOCK_STATE_PAPER = {
"123": {
"Title": "Machine Learning Basics",
"Abstract": "An introduction to ML",
"Year": 2023,
"Citation Count": 100,
"URL": "https://example.com/paper1",
}
}
class TestS2Tools:
    """Unit tests for individual S2 tools (currently ``display_results``)."""

    def test_display_results_empty_state(self, initial_state):
        """display_results raises NoPapersFoundError when the state holds no papers."""
        # pytest.raises(match=...) applies re.search, so escape the expected
        # message to match it literally rather than letting "." act as a
        # wildcard.
        with pytest.raises(
            raised_error,
            match=re.escape(
                "No papers found. A search/rec needs to be performed first."
            ),
        ):
            display_results.invoke(
                input={"state": initial_state, "tool_call_id": "test123"}
            )

    def test_display_results_shows_papers(self, initial_state):
        """display_results returns a Command with one message reporting the papers."""
        # Build a new state dict on top of the fixture instead of mutating a
        # copy; "last_displayed_papers" names the state key holding the papers.
        state = {
            **initial_state,
            "last_displayed_papers": "papers",
            "papers": MOCK_STATE_PAPER,
        }

        result = display_results.invoke(
            input={"state": state, "tool_call_id": "test123"}
        )

        assert isinstance(result, Command)  # Expect a Command object
        assert isinstance(result.update, dict)  # Ensure update is a dictionary
        assert "messages" in result.update
        messages = result.update["messages"]
        assert len(messages) == 1
        assert (
            "1 papers found. Papers are attached as an artifact."
            in messages[0].content
        )