Switch to unified view

a b/aiagents4pharma/talk2biomodels/tools/search_models.py
1
#!/usr/bin/env python3
2
3
"""
4
Tool for searching models based on search query.
5
"""
6
7
from typing import Type, Annotated
8
import logging
9
from pydantic import BaseModel, Field
10
import pandas as pd
11
from basico import biomodels
12
from langgraph.types import Command
13
from langchain_core.tools import BaseTool
14
from langchain_core.messages import ToolMessage
15
from langchain_core.tools.base import InjectedToolCallId
16
17
# Initialize logger
18
logging.basicConfig(level=logging.INFO)
19
logger = logging.getLogger(__name__)
20
21
class SearchModelsInput(BaseModel):
22
    """
23
    Input schema for the search models tool.
24
    """
25
    query: str = Field(description="Search models query", default=None)
26
    num_query: int = Field(description="Top number of models to search",
27
                           default=10,
28
                           le=100)
29
    tool_call_id: Annotated[str, InjectedToolCallId]
30
31
# Note: It's important that every field has type hints. BaseTool is a
32
# Pydantic class and not having type hints can lead to unexpected behavior.
33
class SearchModelsTool(BaseTool):
34
    """
35
    Tool for returning the search results based on the search query.
36
    """
37
    name: str = "search_models"
38
    description: str = "Search for only manually curated models in "
39
    "the BioMmodels database based on keywords."
40
    args_schema: Type[BaseModel] = SearchModelsInput
41
    return_direct: bool = False
42
43
    def _run(self,
44
             tool_call_id: Annotated[str, InjectedToolCallId],
45
             query: str = None,
46
             num_query: int = 10) -> dict:
47
        """
48
        Run the tool.
49
50
        Args:
51
            query (str): The search query.
52
            num_query (int): The number of models to search.
53
            tool_call_id (str): The tool call ID.
54
55
        Returns:
56
            dict: The answer to the question in the form of a dictionary.
57
        """
58
        logger.log(logging.INFO, "Searching models with the query and number %s, %s",
59
                   query, num_query)
60
        # Search for models based on the query
61
        search_results = biomodels.search_for_model(query, num_results=num_query)
62
        # Convert the search results to a pandas DataFrame
63
        df = pd.DataFrame(search_results)
64
        # Prepare a message to return
65
        first_n = min(3, len(search_results))
66
        content = f"Found {len(search_results)} manually curated models"
67
        content += f" for the query: {query}."
68
        # Pass the first 3 models to the LLM
69
        # to avoid hallucinations
70
        content += f" Here is the summary of the first {first_n} models:"
71
        for i in range(first_n):
72
            content += f"\nModel {i+1}: {search_results[i]['name']} (ID: {search_results[i]['id']})"
73
        # Return the updated state of the tool
74
        return Command(
75
                update={
76
                    # update the message history
77
                    "messages": [
78
                        ToolMessage(
79
                            content=content,
80
                            tool_call_id=tool_call_id,
81
                            artifact={'dic_data': df.to_dict(orient='records')}
82
                            )
83
                        ],
84
                    }
85
            )