a b/aiagents4pharma/talk2biomodels/tools/ask_question.py
1
#!/usr/bin/env python3
2
3
"""
4
Tool for asking a question about the simulation results.
5
"""
6
7
import logging
8
from typing import Type, Annotated, Literal
9
import hydra
10
import basico
11
import pandas as pd
12
from pydantic import BaseModel, Field
13
from langchain_core.tools.base import BaseTool
14
from langchain_experimental.agents import create_pandas_dataframe_agent
15
from langgraph.prebuilt import InjectedState
16
17
# Module-level logger named after this module; the root logging
# configuration is set to INFO when the module is imported.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
20
21
class AskQuestionInput(BaseModel):
    """
    Argument schema accepted by the AskQuestion tool.
    """
    # Free-text question to be answered from the stored results.
    question: str = Field(
        description="question about the simulation and steady state results"
    )
    # Name under which the simulation / steady state result was stored.
    experiment_name: str = Field(
        description="""Name assigned to the simulation
                                            or steady state analysis when the tool 
                                            simulate_model or steady_state is invoked."""
    )
    # Restricts the question to one of the two supported result kinds.
    question_context: Literal["simulation", "steady_state"] = Field(
        description="Context of the question"
    )
    # Graph state, marked for injection via the InjectedState annotation.
    state: Annotated[dict, InjectedState]
32
33
# Note: It's important that every field has type hints.
34
# BaseTool is a Pydantic class and not having type hints
35
# can lead to unexpected behavior.
36
class AskQuestionTool(BaseTool):
    """
    Tool for asking a question about the simulation or steady state results.

    The stored result table of the requested experiment is handed to a
    pandas dataframe agent, which answers the question with the LLM taken
    from the graph state.
    """
    name: str = "ask_question"
    description: str = """A tool to ask question about the
                        simulation or steady state results."""
    args_schema: Type[BaseModel] = AskQuestionInput
    return_direct: bool = False

    def _run(self,
             question: str,
             experiment_name: str,
             question_context: Literal["simulation", "steady_state"],
             state: Annotated[dict, InjectedState]) -> str:
        """
        Run the tool.

        Args:
            question (str): The question to ask about the simulation or steady state results.
            experiment_name (str): The name assigned to the simulation or steady state analysis.
            question_context (str): Either "simulation" or "steady_state"; selects which
                stored results and which prompt from the configuration are used.
            state (dict): The state of the graph. Must provide "llm_model" and, depending
                on the context, "dic_steady_state_data" or "dic_simulated_data".

        Returns:
            str: The answer to the question.

        Raises:
            KeyError: If the state lacks the entry required for the chosen context.
            IndexError: If no stored result is named `experiment_name`.
        """
        logger.info("Calling ask_question tool %s, %s, %s",
                    question,
                    question_context,
                    experiment_name)
        # Load hydra configuration
        with hydra.initialize(version_base=None, config_path="../configs"):
            cfg = hydra.compose(config_name='config',
                                overrides=['tools/ask_question=default'])
            cfg = cfg.tools.ask_question
        # Based on the context of the question, pick the stored results
        # and the prompt content used to ask the question
        if question_context == "steady_state":
            dic_context = state["dic_steady_state_data"]
            prompt_content = cfg.steady_state_prompt
        else:
            dic_context = state["dic_simulated_data"]
            prompt_content = cfg.simulation_prompt
        # Extract the data of the first stored experiment whose name matches.
        # Scanning the records directly avoids building a throw-away dataframe
        # of all experiments and lets us report a meaningful error.
        matching = next((entry for entry in dic_context
                         if entry.get("name") == experiment_name), None)
        if matching is None:
            available = [entry.get("name") for entry in dic_context]
            raise IndexError(
                f"No {question_context} results found for experiment "
                f"'{experiment_name}'. Available experiments: {available}"
            )
        # Create a pandas dataframe of the experiment data
        df = pd.DataFrame(matching["data"])
        logger.info("Shape of the dataframe: %s", df.shape)
        # Update the prompt content with the model units.
        # NOTE(review): the units come from basico's model_info; presumably this
        # is the currently loaded model - confirm it matches the experiment.
        prompt_content += "Following are the model units:\n"
        prompt_content += f"{basico.model_info.get_model_units()}\n\n"
        # Create a pandas dataframe agent.
        # SECURITY: allow_dangerous_code=True lets the agent execute
        # LLM-generated Python code; only run this in a trusted/sandboxed setup.
        df_agent = create_pandas_dataframe_agent(
                        state['llm_model'],
                        allow_dangerous_code=True,
                        agent_type='tool-calling',
                        df=df,
                        max_iterations=5,
                        include_df_in_prompt=True,
                        # Show every row of the dataframe in the prompt
                        number_of_head_rows=df.shape[0],
                        verbose=True,
                        prefix=prompt_content)
        # Invoke the agent with the question
        llm_result = df_agent.invoke(question, stream_mode=None)
        return llm_result["output"]