[3af7d7]: / aiagents4pharma / talk2scholars / tests / test_main_agent.py

Download this file

221 lines (164 with data), 6.8 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
"""
Unit tests for the Talk2Scholars main agent.
Checks how get_app assembles the supervisor (model selection, system prompt
and app name) with Hydra, the sub-agents and create_supervisor patched out.
"""
# pylint: disable=redefined-outer-name,too-few-public-methods
from types import SimpleNamespace
import pytest
import hydra
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_openai import ChatOpenAI
from pydantic import Field
from aiagents4pharma.talk2scholars.agents.main_agent import get_app
# --- Dummy LLM Implementation ---
class DummyLLM(BaseChatModel):
    """A dummy chat model that stands in for a real LLM in these tests.

    Attributes:
        model_name: Required model identifier; the get_app tests use it to
            choose between the "gpt-4o-mini" and non-"gpt-4o-mini" paths.
    """

    model_name: str = Field(...)

    def _generate(self, prompt, stop=None, run_manager=None, **kwargs):
        """Record ``prompt`` and return a canned response.

        ``run_manager`` and ``**kwargs`` are accepted (and ignored) so the
        override stays call-compatible with ``BaseChatModel._generate``.

        Args:
            prompt: First positional argument; stored as the class attribute
                ``DummyLLM.called_prompt`` for later inspection.
            stop: Ignored.
            run_manager: Ignored.
            **kwargs: Ignored.

        Returns:
            The string ``"dummy output"``. This is not a ``ChatResult``; the
            tests in this module call ``_generate`` directly.
        """
        DummyLLM.called_prompt = prompt
        return "dummy output"

    @property
    def _llm_type(self):
        """Return ``"dummy"`` as this model's type identifier."""
        return "dummy"
# --- Dummy Workflow and Sub-agent Functions ---
class DummyWorkflow:
    """Stand-in for a workflow that remembers how it was built and compiled.

    Attributes:
        supervisor_args: Arguments captured at construction time (a fresh
            empty dict when none, or a falsy value, is supplied).
        checkpointer: Checkpointer passed to ``compile``; ``None`` until then.
        name: Name passed to ``compile``; ``None`` until then.
    """

    def __init__(self, supervisor_args=None):
        """Capture ``supervisor_args`` and reset the compile-time fields."""
        if not supervisor_args:
            supervisor_args = {}
        self.supervisor_args = supervisor_args
        self.checkpointer, self.name = None, None

    def compile(self, checkpointer, name):
        """Remember the checkpointer and name, then hand back this same object."""
        self.name = name
        self.checkpointer = checkpointer
        return self
def dummy_get_app_s2(uniq_id, llm_model):
    """Fake S2 sub-agent factory.

    Stores its arguments on itself as ``called_uniq_id`` / ``called_llm_model``
    and returns a DummyWorkflow tagged with agent ``"s2"`` and ``uniq_id``.
    """
    dummy_get_app_s2.called_uniq_id, dummy_get_app_s2.called_llm_model = (
        uniq_id,
        llm_model,
    )
    tag = {"agent": "s2", "uniq_id": uniq_id}
    return DummyWorkflow(supervisor_args=tag)
def dummy_get_app_zotero(uniq_id, llm_model):
    """Fake Zotero sub-agent factory.

    Stores its arguments on itself as ``called_uniq_id`` / ``called_llm_model``
    and returns a DummyWorkflow tagged with agent ``"zotero"`` and ``uniq_id``.
    """
    dummy_get_app_zotero.called_uniq_id, dummy_get_app_zotero.called_llm_model = (
        uniq_id,
        llm_model,
    )
    tag = {"agent": "zotero", "uniq_id": uniq_id}
    return DummyWorkflow(supervisor_args=tag)
def dummy_get_app_pdf(uniq_id, llm_model):
    """Fake PDF sub-agent factory.

    Stores its arguments on itself as ``called_uniq_id`` / ``called_llm_model``
    and returns a DummyWorkflow tagged with agent ``"pdf"`` and ``uniq_id``.
    """
    dummy_get_app_pdf.called_uniq_id, dummy_get_app_pdf.called_llm_model = (
        uniq_id,
        llm_model,
    )
    tag = {"agent": "pdf", "uniq_id": uniq_id}
    return DummyWorkflow(supervisor_args=tag)
def dummy_create_supervisor(apps, model, state_schema, **kwargs):
    """Fake ``create_supervisor`` that captures everything it receives.

    The extra keyword arguments are stored on the function as
    ``called_kwargs``. The returned DummyWorkflow's ``supervisor_args`` holds
    ``apps``, ``model`` and ``state_schema`` followed by every extra keyword.
    """
    dummy_create_supervisor.called_kwargs = kwargs
    captured = {"apps": apps, "model": model, "state_schema": state_schema}
    captured.update(kwargs)
    return DummyWorkflow(supervisor_args=captured)
# --- Dummy Hydra Configuration Setup ---
class DummyHydraContext:
    """No-op stand-in for the context manager returned by ``hydra.initialize``."""

    def __enter__(self):
        """Enter the context; ``with ... as x`` binds ``None`` to ``x``."""

    def __exit__(self, exc_type, exc_val, traceback):
        """Leave the context; the implicit ``None`` lets exceptions propagate."""
def dict_to_namespace(d):
    """Recursively turn a (possibly nested) dict into ``SimpleNamespace`` objects.

    Dict values are converted as well; every other value is kept unchanged,
    so nested keys become reachable through attribute access.
    """
    converted = {}
    for key, value in d.items():
        if isinstance(value, dict):
            value = dict_to_namespace(value)
        converted[key] = value
    return SimpleNamespace(**converted)
# Config tree served by the patched ``hydra.compose``. It only provides
# agents.talk2scholars.main_agent.system_prompt, the value the get_app tests
# expect to see as the supervisor prompt.
dummy_config = {
    "agents": {
        "talk2scholars": {
            "main_agent": {
                "system_prompt": "Dummy system prompt",
            },
        },
    },
}
class DummyHydraCompose:
    """Stand-in for the config object returned by ``hydra.compose``.

    Unknown attributes are looked up in a plain dict:
    - a dict-valued entry comes back as a nested ``SimpleNamespace`` (via
      ``dict_to_namespace``);
    - any other value is returned unchanged;
    - a missing key yields an empty namespace.
    """

    def __init__(self, config):
        """Store the dummy config dict."""
        self.config = config

    def __getattr__(self, item):
        """Resolve an attribute that normal lookup did not find from the config.

        Dunder names are rejected so protocol probes (``copy``, ``pickle``,
        ``hasattr(obj, "__deepcopy__")``) are not answered with a fake
        namespace. ``config`` is read from the instance ``__dict__`` so an
        instance created without ``__init__`` (as ``copy`` does) raises
        ``AttributeError`` instead of recursing through ``__getattr__``.

        Raises:
            AttributeError: For dunder names, or when no config is stored.
        """
        if item.startswith("__") and item.endswith("__"):
            raise AttributeError(item)
        config = self.__dict__.get("config")
        if config is None:
            raise AttributeError(item)
        value = config.get(item, {})
        # Leaf values (str, int, ...) have no items() and cannot become a namespace.
        return dict_to_namespace(value) if isinstance(value, dict) else value
# --- Pytest Fixtures to Patch Dependencies ---
@pytest.fixture(autouse=True)
def patch_hydra(monkeypatch):
    """Swap Hydra's config loading for in-memory dummies in every test.

    ``hydra.initialize`` yields a no-op context manager and ``hydra.compose``
    serves ``dummy_config``; ``monkeypatch`` restores both afterwards.
    """

    def fake_initialize(version_base, config_path):  # pylint: disable=unused-argument
        """Replacement for ``hydra.initialize``; its arguments are ignored."""
        return DummyHydraContext()

    def fake_compose(config_name, overrides):  # pylint: disable=unused-argument
        """Replacement for ``hydra.compose``; its arguments are ignored."""
        return DummyHydraCompose(dummy_config)

    monkeypatch.setattr(hydra, "initialize", fake_initialize)
    monkeypatch.setattr(hydra, "compose", fake_compose)
def dummy_get_app_paper_download(uniq_id, llm_model):
    """Fake paper-download sub-agent factory.

    Stores its arguments on itself as ``called_uniq_id`` / ``called_llm_model``
    and returns a DummyWorkflow tagged with agent ``"paper_download"`` and
    ``uniq_id``.
    """
    recorder = dummy_get_app_paper_download
    recorder.called_uniq_id, recorder.called_llm_model = uniq_id, llm_model
    tag = {"agent": "paper_download", "uniq_id": uniq_id}
    return DummyWorkflow(supervisor_args=tag)
@pytest.fixture(autouse=True)
def patch_sub_agents_and_supervisor(monkeypatch):
    """Replace main_agent's sub-agent factories and ``create_supervisor`` with fakes.

    Runs automatically for every test; ``monkeypatch`` undoes each patch
    when the test finishes.
    """
    main_agent_path = "aiagents4pharma.talk2scholars.agents.main_agent"
    fakes = {
        "get_app_s2": dummy_get_app_s2,
        "get_app_zotero": dummy_get_app_zotero,
        "get_app_pdf": dummy_get_app_pdf,
        "get_app_paper_download": dummy_get_app_paper_download,
        "create_supervisor": dummy_create_supervisor,
    }
    for attr_name, fake in fakes.items():
        monkeypatch.setattr(f"{main_agent_path}.{attr_name}", fake)
# --- Test Cases ---
def test_dummy_llm_generate():
    """DummyLLM._generate returns the canned output and records its prompt."""
    dummy = DummyLLM(model_name="test-model")
    output = getattr(dummy, "_generate")("any prompt")
    assert output == "dummy output"
    # _generate stashes its prompt on the class; check that side effect too.
    assert getattr(DummyLLM, "called_prompt") == "any prompt"
def test_dummy_llm_llm_type():
    """DummyLLM should identify itself with the LLM type "dummy"."""
    llm = DummyLLM(model_name="test-model")
    llm_type = getattr(llm, "_llm_type")
    assert llm_type == "dummy"
def test_get_app_with_gpt4o_mini(monkeypatch):
    """
    Test that get_app replaces a 'gpt-4o-mini' LLM with a new ChatOpenAI instance.
    """
    # ChatOpenAI needs an API key when it is constructed; a placeholder keeps
    # this test independent of the developer's or CI's real credentials.
    monkeypatch.setenv("OPENAI_API_KEY", "test-key")
    uniq_id = "test_thread"
    dummy_llm = DummyLLM(model_name="gpt-4o-mini")
    app = get_app(uniq_id, dummy_llm)
    supervisor_args = getattr(app, "supervisor_args", {})
    assert isinstance(supervisor_args.get("model"), ChatOpenAI)
    assert supervisor_args.get("prompt") == "Dummy system prompt"
    assert getattr(app, "name", "") == "Talk2Scholars_MainAgent"
def test_get_app_with_other_model():
    """
    get_app must pass the caller's LLM through untouched when its model_name
    is anything other than 'gpt-4o-mini'.
    """
    llm = DummyLLM(model_name="other-model")
    app = get_app("test_thread_2", llm)
    args = getattr(app, "supervisor_args", {})
    assert args.get("model") is llm
    assert args.get("prompt") == "Dummy system prompt"
    assert getattr(app, "name", "") == "Talk2Scholars_MainAgent"