|
a |
|
b/aiagents4pharma/talk2biomodels/tools/ask_question.py |
|
|
1 |
#!/usr/bin/env python3 |
|
|
2 |
|
|
|
3 |
""" |
|
|
4 |
Tool for asking a question about the simulation results. |
|
|
5 |
""" |
|
|
6 |
|
|
|
7 |
import logging |
|
|
8 |
from typing import Type, Annotated, Literal |
|
|
9 |
import hydra |
|
|
10 |
import basico |
|
|
11 |
import pandas as pd |
|
|
12 |
from pydantic import BaseModel, Field |
|
|
13 |
from langchain_core.tools.base import BaseTool |
|
|
14 |
from langchain_experimental.agents import create_pandas_dataframe_agent |
|
|
15 |
from langgraph.prebuilt import InjectedState |
|
|
16 |
|
|
|
17 |
# Initialize logger
# NOTE: basicConfig is a no-op if the host application has already
# configured the root logger, so this is safe at import time.
logging.basicConfig(level=logging.INFO)
# Module-level logger named after this module (standard logging convention).
logger = logging.getLogger(__name__)
|
|
20 |
|
|
|
21 |
class AskQuestionInput(BaseModel):
    """
    Input schema for the AskQuestion tool.

    All fields except `state` are filled in by the LLM when it calls the
    tool; `state` is injected from the LangGraph graph state at runtime.
    """
    # Natural-language question about the stored results.
    question: str = Field(description="question about the simulation and steady state results")
    # Must match the name used when simulate_model or steady_state stored the data.
    experiment_name: str = Field(description="""Name assigned to the simulation
                            or steady state analysis when the tool
                            simulate_model or steady_state is invoked.""")
    # Selects which result store ("dic_simulated_data" vs "dic_steady_state_data")
    # and which prompt template to use.
    question_context: Literal["simulation", "steady_state"] = Field(
        description="Context of the question")
    # Graph state injected by LangGraph; never supplied by the LLM.
    state: Annotated[dict, InjectedState]
|
|
32 |
|
|
|
33 |
# Note: It's important that every field has type hints. |
|
|
34 |
# BaseTool is a Pydantic class and not having type hints |
|
|
35 |
# can lead to unexpected behavior. |
|
|
36 |
class AskQuestionTool(BaseTool):
    """
    Tool for asking a question about the simulation or steady state results.

    Builds a pandas dataframe from the experiment data stored in the graph
    state and delegates the question to a pandas dataframe agent backed by
    the LLM held in the state.
    """
    name: str = "ask_question"
    description: str = """A tool to ask question about the
                simulation or steady state results."""
    args_schema: Type[BaseModel] = AskQuestionInput
    return_direct: bool = False

    def _run(self,
             question: str,
             experiment_name: str,
             question_context: Literal["simulation", "steady_state"],
             state: Annotated[dict, InjectedState]) -> str:
        """
        Run the tool.

        Args:
            question (str): The question to ask about the simulation or steady state results.
            experiment_name (str): The name assigned to the simulation or steady state analysis.
            question_context (Literal["simulation", "steady_state"]): Whether the question
                refers to simulation results or steady state results.
            state (dict): The state of the graph, injected by LangGraph.

        Returns:
            str: The answer to the question.

        Raises:
            IndexError: If no stored experiment matches `experiment_name`.
            KeyError: If the expected data key is missing from the graph state.
        """
        logger.info("Calling ask_question tool %s, %s, %s",
                    question, question_context, experiment_name)
        # Load the hydra configuration holding the prompt templates for this tool.
        with hydra.initialize(version_base=None, config_path="../configs"):
            cfg = hydra.compose(config_name='config',
                                overrides=['tools/ask_question=default'])
            cfg = cfg.tools.ask_question
        # Based on the question context, pick the stored data
        # and the matching prompt content.
        if question_context == "steady_state":
            dic_context = state["dic_steady_state_data"]
            prompt_content = cfg.steady_state_prompt
        else:
            dic_context = state["dic_simulated_data"]
            prompt_content = cfg.simulation_prompt
        # Flatten the list of per-experiment dicts into columns:
        # each key becomes a column, each experiment contributes one row.
        dic_data = {}
        for data in dic_context:
            for key, value in data.items():
                dic_data.setdefault(key, []).append(value)
        # Create a pandas dataframe of the data
        df_data = pd.DataFrame.from_dict(dic_data)
        # Extract the data of the experiment matching the experiment name.
        df = pd.DataFrame(
            df_data[df_data['name'] == experiment_name]['data'].iloc[0]
        )
        logger.log(logging.INFO, "Shape of the dataframe: %s", df.shape)
        # Append the model units so the agent can interpret the values.
        prompt_content += "Following are the model units:\n"
        prompt_content += f"{basico.model_info.get_model_units()}\n\n"
        # Create a pandas dataframe agent over the experiment data.
        # allow_dangerous_code is required by create_pandas_dataframe_agent
        # (it executes LLM-generated Python); the LLM comes from the state.
        df_agent = create_pandas_dataframe_agent(
            state['llm_model'],
            allow_dangerous_code=True,
            agent_type='tool-calling',
            df=df,
            max_iterations=5,
            include_df_in_prompt=True,
            # Show every row to the LLM, not just the default head.
            number_of_head_rows=df.shape[0],
            verbose=True,
            prefix=prompt_content)
        # Invoke the agent with the question and return its final answer.
        llm_result = df_agent.invoke(question, stream_mode=None)
        return llm_result["output"]