|
a |
|
b/aiagents4pharma/talk2biomodels/tools/search_models.py |
|
|
1 |
#!/usr/bin/env python3 |
|
|
2 |
|
|
|
3 |
""" |
|
|
4 |
Tool for searching models based on search query. |
|
|
5 |
""" |
|
|
6 |
|
|
|
7 |
from typing import Type, Annotated |
|
|
8 |
import logging |
|
|
9 |
from pydantic import BaseModel, Field |
|
|
10 |
import pandas as pd |
|
|
11 |
from basico import biomodels |
|
|
12 |
from langgraph.types import Command |
|
|
13 |
from langchain_core.tools import BaseTool |
|
|
14 |
from langchain_core.messages import ToolMessage |
|
|
15 |
from langchain_core.tools.base import InjectedToolCallId |
|
|
16 |
|
|
|
17 |
# Initialize logger.
# basicConfig() sets up the root logger at INFO level when this module is
# imported; per the logging docs it does nothing if the root logger already
# has handlers configured. The module-level logger is named after this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
|
|
20 |
|
|
|
21 |
class SearchModelsInput(BaseModel):
    """
    Input schema for the search models tool.
    """
    # Keywords to search for; no query is set unless the caller provides one.
    query: str = Field(default=None, description="Search models query")
    # How many models to search for: 10 unless overridden, never above 100.
    num_query: int = Field(
        default=10,
        le=100,
        description="Top number of models to search",
    )
    # ID of the tool call, marked with the InjectedToolCallId annotation.
    tool_call_id: Annotated[str, InjectedToolCallId]
|
|
30 |
|
|
|
31 |
# Note: It's important that every field has type hints. BaseTool is a |
|
|
32 |
# Pydantic class and not having type hints can lead to unexpected behavior. |
|
|
33 |
class SearchModelsTool(BaseTool):
    """
    Tool for returning the search results based on the search query.
    """
    name: str = "search_models"
    # The two fragments must be wrapped in parentheses: without them the
    # second literal is a separate (discarded) expression statement and the
    # description would stop after "models in ".
    description: str = (
        "Search for only manually curated models in "
        "the BioModels database based on keywords."
    )
    args_schema: Type[BaseModel] = SearchModelsInput
    return_direct: bool = False

    def _run(self,
             tool_call_id: Annotated[str, InjectedToolCallId],
             query: str = None,
             num_query: int = 10) -> Command:
        """
        Run the tool.

        Args:
            tool_call_id (str): The tool call ID.
            query (str): The search query.
            num_query (int): The number of models to search.

        Returns:
            Command: A state update whose "messages" entry holds a single
                ToolMessage. Its content is a text summary of the search
                and its artifact carries all results as a list of records
                under the 'dic_data' key.
        """
        logger.info("Searching models with the query and number %s, %s",
                    query, num_query)
        # Search for models based on the query
        # NOTE(review): assumes search_for_model returns a list of dicts
        # that each carry 'name' and 'id' keys - confirm against basico.
        search_results = biomodels.search_for_model(query, num_results=num_query)
        # Convert the search results to a pandas DataFrame
        df = pd.DataFrame(search_results)
        # Pass only the first 3 models to the LLM to avoid hallucinations;
        # the full result set still travels in the message artifact.
        top_models = search_results[:3]
        summary = "".join(
            f"\nModel {rank}: {model['name']} (ID: {model['id']})"
            for rank, model in enumerate(top_models, start=1)
        )
        content = (
            f"Found {len(search_results)} manually curated models"
            f" for the query: {query}."
            f" Here is the summary of the first {len(top_models)} models:"
            f"{summary}"
        )
        # Return the updated state of the tool
        return Command(
            update={
                # update the message history
                "messages": [
                    ToolMessage(
                        content=content,
                        tool_call_id=tool_call_id,
                        artifact={'dic_data': df.to_dict(orient='records')},
                    )
                ],
            }
        )