[9d3784]: aiagents4pharma/talk2scholars/tools/s2/utils/search_helper.py

#!/usr/bin/env python3
"""
Utility for fetching papers from Semantic Scholar based on a search query.
"""

import logging
from typing import Any, Dict, Optional

import hydra
import requests

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class SearchData:
    """Helper class to organize search-related data."""

    def __init__(
        self,
        query: str,
        limit: int,
        year: Optional[str],
        tool_call_id: str,
    ):
        self.query = query
        self.limit = limit
        self.year = year
        self.tool_call_id = tool_call_id
        self.cfg = self._load_config()
        self.endpoint = self.cfg.api_endpoint
        self.params = self._create_params()
        self.response = None
        self.data = None
        self.papers = []
        self.filtered_papers = {}
        self.content = ""

    def _load_config(self) -> Any:
        """Load hydra configuration."""
        with hydra.initialize(version_base=None, config_path="../../../configs"):
            cfg = hydra.compose(
                config_name="config", overrides=["tools/search=default"]
            )
            logger.info("Loaded configuration for search tool")
            return cfg.tools.search
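
    # NOTE: the composed config is expected to expose at least `api_endpoint`
    # (used in `_fetch_papers`) and `api_fields` (used below). A purely
    # illustrative sketch of a matching `tools/search/default.yaml` follows;
    # the endpoint URL and field list are assumptions, not taken from this
    # repository:
    #
    #   api_endpoint: https://api.semanticscholar.org/graph/v1/paper/search
    #   api_fields:
    #     - paperId
    #     - title
    #     - abstract
    #     - year
    #     - publicationDate
    #     - venue
    #     - journal
    #     - citationCount
    #     - authors
    #     - url
    #     - externalIds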
    def _create_params(self) -> Dict[str, Any]:
        """Create parameters for the API request."""
        params = {
            "query": self.query,
            "limit": min(self.limit, 100),
            "fields": ",".join(self.cfg.api_fields),
        }
        if self.year:
            params["year"] = self.year
        return params

    def _fetch_papers(self) -> None:
        """Fetch papers from the Semantic Scholar API."""
        logger.info("Searching for papers on %s", self.query)
        # Wrap the API call in try/except to catch connectivity issues
        for attempt in range(10):
            try:
                self.response = requests.get(
                    self.endpoint, params=self.params, timeout=10
                )
                self.response.raise_for_status()  # Raises HTTPError for bad responses
                break  # Exit the loop if the request succeeds
            except requests.exceptions.RequestException as e:
                logger.error(
                    "Attempt %d: Failed to connect to Semantic Scholar API: %s",
                    attempt + 1,
                    e,
                )
                if attempt == 9:  # Last attempt
                    raise RuntimeError(
                        "Failed to connect to Semantic Scholar API after 10 attempts. "
                        "Please retry the same query."
                    ) from e
        if self.response is None:
            raise RuntimeError(
                "Failed to obtain a response from the Semantic Scholar API."
            )
        self.data = self.response.json()
        # Check for the expected data format
        if "data" not in self.data:
            logger.error("Unexpected API response format: %s", self.data)
            raise RuntimeError(
                "Unexpected response from Semantic Scholar API. The results could "
                "not be retrieved due to an unexpected format. "
                "Please modify your search query and try again."
            )
        self.papers = self.data.get("data", [])
        if not self.papers:
            logger.error(
                "No papers returned from Semantic Scholar API for query: %s",
                self.query,
            )
            raise RuntimeError(
                "No papers were found for your query. Consider refining your search "
                "by using more specific keywords or different terms."
            )

    def _filter_papers(self) -> None:
        """Filter and format papers."""
        self.filtered_papers = {
            paper["paperId"]: {
                "semantic_scholar_paper_id": paper["paperId"],
                "Title": paper.get("title", "N/A"),
                "Abstract": paper.get("abstract", "N/A"),
                "Year": paper.get("year", "N/A"),
                "Publication Date": paper.get("publicationDate", "N/A"),
                "Venue": paper.get("venue", "N/A"),
                "Journal Name": (paper.get("journal") or {}).get("name", "N/A"),
                "Citation Count": paper.get("citationCount", "N/A"),
                "Authors": [
                    f"{author.get('name', 'N/A')} (ID: {author.get('authorId', 'N/A')})"
                    for author in paper.get("authors", [])
                ],
                "URL": paper.get("url", "N/A"),
                "arxiv_id": paper.get("externalIds", {}).get("ArXiv", "N/A"),
            }
            for paper in self.papers
            if paper.get("title") and paper.get("authors")
        }
        logger.info("Filtered %d papers", len(self.filtered_papers))

    def _create_content(self) -> None:
        """Create the content message for the response."""
        top_papers = list(self.filtered_papers.values())[:3]
        top_papers_info = "\n".join(
            [
                f"{i+1}. {paper['Title']} ({paper['Year']}; "
                f"semantic_scholar_paper_id: {paper['semantic_scholar_paper_id']}; "
                f"arXiv ID: {paper['arxiv_id']})"
                for i, paper in enumerate(top_papers)
            ]
        )
        logger.info("-----------Filtered %d papers", self.get_paper_count())
        self.content = (
            "Search was successful. Papers are attached as an artifact. "
            "Here is a summary of the search results:\n"
        )
        self.content += f"Number of papers found: {self.get_paper_count()}\n"
        self.content += f"Query: {self.query}\n"
        self.content += f"Year: {self.year}\n" if self.year else ""
        self.content += "Top 3 papers:\n" + top_papers_info

    def process_search(self) -> Dict[str, Any]:
        """Process the search request and return results."""
        self._fetch_papers()
        self._filter_papers()
        self._create_content()
        return {
            "papers": self.filtered_papers,
            "content": self.content,
        }

    def get_paper_count(self) -> int:
        """Get the number of papers found in the search.

        Returns:
            int: The number of papers in the filtered papers dictionary.
        """
        return len(self.filtered_papers)
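

# Minimal usage sketch (not part of the original module): shows how the helper
# could be exercised directly. The query, year range, and tool_call_id values
# below are placeholders supplied here only for illustration; in practice the
# calling tool provides them, and the hydra config path must resolve relative
# to this file for `_load_config` to succeed.
if __name__ == "__main__":
    search = SearchData(
        query="machine learning for drug discovery",  # illustrative query
        limit=5,
        year="2020-2024",  # assumed Semantic Scholar year-range format
        tool_call_id="demo_tool_call",  # placeholder id
    )
    results = search.process_search()
    print(results["content"])
    logger.info("Total papers returned: %d", search.get_paper_count())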