# Exported from a repository web viewer (revision 3af7d7):
# aiagents4pharma/talk2biomodels/tools/custom_plotter.py
# (158 lines, 144 with data, 6.9 kB). The viewer's line-number gutter was removed.
#!/usr/bin/env python3
"""
Tool for plotting a custom y-axis of a simulation plot.
"""
import logging
from typing import Type, Annotated, List, Tuple, Union, Literal
from pydantic import BaseModel, Field
import hydra
import pandas as pd
from langchain_core.tools import BaseTool
from langchain_core.prompts import ChatPromptTemplate
from langgraph.prebuilt import InjectedState
from .load_biomodel import ModelData, load_biomodel
from .utils import get_model_units
# Configure root logging at INFO level on import and create the
# module-level logger used by this tool.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(name=__name__)
def extract_relevant_species(question, species_names, state):
    """
    Extract the relevant species from the user question.

    The LLM stored in ``state['llm_model']`` is asked for a structured
    ``CustomHeader`` answer whose ``relevant_species`` entries are restricted
    (through ``Literal``) to ``species_names``.

    Args:
        question (str): The user question.
        species_names (list): The species names available in the simulation results.
        state (dict): The state of the graph; must contain ``'llm_model'``.

    Returns:
        CustomHeader: The relevant species. ``relevant_species`` may be None
        when the LLM finds no relevant species.
    """
    # The species are extracted from the user question. Literal restricts
    # the species names to the ones available in the simulation results.
    # ``Literal[tuple(...)]`` is equivalent to ``Literal[*...]`` but does not
    # require the Python 3.11 star-in-subscript syntax.
    class CustomHeader(BaseModel):
        """
        A list of species based on user question.

        This is a Pydantic model that restricts the species
        names to the ones available in the simulation results.
        If no species is relevant, set the attribute
        `relevant_species` to None.
        """
        # Each string fragment ends with a space so that the concatenated
        # description sent to the LLM has proper word/sentence breaks.
        relevant_species: Union[None, List[Literal[tuple(species_names)]]] = Field(
            description="This is a list of species based on the user question. "
            "It is restricted to the species available in the simulation results. "
            "If no species is relevant, set this attribute to None. "
            "If the user asks for very specific species (for example, using the "
            "keyword `only` in the question), set this attribute to correspond "
            "to the species available in the simulation results, otherwise set it to None."
        )

    # Load hydra configuration
    with hydra.initialize(version_base=None, config_path="../configs"):
        cfg = hydra.compose(config_name='config',
                            overrides=['tools/custom_plotter=default'])
        cfg = cfg.tools.custom_plotter
    # Get the system prompt
    # NOTE(review): the prompt is parsed by ChatPromptTemplate, so any literal
    # braces in it would be treated as template variables — confirm the config.
    system_prompt = cfg.system_prompt_custom_header
    # Use the LLM stored in the graph state
    llm = state['llm_model']
    logger.info("LLM model: %s", llm)
    structured_llm = llm.with_structured_output(CustomHeader)
    prompt = ChatPromptTemplate.from_messages([("system", system_prompt),
                                               ("human", "{input}")])
    # The template has a single variable, so the plain question string is
    # bound to ``{input}``.
    chain = prompt | structured_llm
    return chain.invoke(question)
class CustomPlotterInput(BaseModel):
    """
    Argument schema of the custom_plotter tool.

    ``state`` is annotated with ``InjectedState``; presumably LangGraph fills
    it from the graph state rather than the LLM supplying it — confirm.
    """
    # Free-text description of the requested plot.
    question: str = Field(description="Description of the plot")
    # Model whose units are attached to the plot; defaults to None.
    # NOTE(review): annotated as ModelData but defaulting to None — confirm
    # whether Optional[ModelData] is intended.
    sys_bio_model: ModelData = Field(default=None,
                                     description="model data")
    # Name under which the simulation was stored in the graph state.
    simulation_name: str = Field(description="Name assigned to the simulation")
    # Graph state (injected).
    state: Annotated[dict, InjectedState]
# Note: It's important that every field has type hints.
# BaseTool is a Pydantic class and not having type hints
# can lead to unexpected behavior.
class CustomPlotterTool(BaseTool):
    """
    Tool for custom plotting the y-axis of a plot.

    It selects a subset of the species columns of a stored simulation,
    chosen by the LLM from the user's question.
    """
    name: str = "custom_plotter"
    description: str = '''A visualization tool designed to extract and display a subset
of the larger simulation plot generated by the simulate_model tool.
It allows users to specify particular species for the y-axis,
providing a more targeted view of key species without the clutter
of the full plot.'''
    args_schema: Type[BaseModel] = CustomPlotterInput
    # _run returns (content, artifact).
    response_format: str = "content_and_artifact"

    def _run(self,
             question: str,
             sys_bio_model: ModelData,
             simulation_name: str,
             state: Annotated[dict, InjectedState]
             ) -> Tuple[str, dict]:
        """
        Run the tool.

        Args:
            question (str): The question about the custom plot.
            sys_bio_model (ModelData): The model data.
            simulation_name (str): The name assigned to the simulation.
            state (dict): The state of the graph.

        Returns:
            Tuple[str, dict]: The content ``"Custom plot <simulation_name>"``
            and the artifact ``{'dic_data': [...]}`` (the ``Time`` column plus
            the selected species, as a list of row records) merged with the
            dict returned by ``get_model_units``.

        Raises:
            ValueError: If no simulation named ``simulation_name`` is stored
                in ``state['dic_simulated_data']``, if it has no ``Time``
                column, if it has no species columns, or if no available
                species matches the question.
        """
        logger.info("Calling custom_plotter tool %s, %s", question, sys_bio_model)
        # Load the model, using the last SBML file path in the state (if any)
        sbml_file_paths = state.get('sbml_file_path') or []
        sbml_file_path = sbml_file_paths[-1] if sbml_file_paths else None
        model_object = load_biomodel(sys_bio_model, sbml_file_path=sbml_file_path)

        # Get the simulated data for the current tool call (first entry with
        # a matching name), without building a DataFrame of all simulations.
        simulated_data = next(
            (entry['data'] for entry in state.get('dic_simulated_data') or []
             if entry.get('name') == simulation_name),
            None,
        )
        if simulated_data is None:
            raise ValueError(
                f"No simulation named '{simulation_name}' found in the state.")
        df = pd.DataFrame(simulated_data)

        # Every column except the time axis is a candidate species;
        # list.remove raises ValueError if there is no 'Time' column.
        species_names = df.columns.tolist()
        species_names.remove('Time')
        logger.info("Species names: %s", species_names)
        if not species_names:
            # An empty Literal cannot constrain the LLM's answer.
            raise ValueError("No species found in the simulation results.")

        # Extract the relevant species from the user question
        results = extract_relevant_species(question, species_names, state)
        logger.info("Relevant species returned by the LLM: %s", results)

        # Keep only species available in the simulation results; dict.fromkeys
        # drops duplicates (which would duplicate columns) while keeping order.
        available = set(species_names)
        extracted_species = list(dict.fromkeys(
            species for species in (results.relevant_species or [])
            if species in available
        ))
        # None, an empty list, or only unknown names all mean "no match".
        if not extracted_species:
            raise ValueError("No species found in the simulation results "
                             "that matches the user prompt.")
        logger.info("Extracted species: %s", extracted_species)

        # Include the time column first
        columns = ['Time'] + extracted_species
        return f"Custom plot {simulation_name}", {
            'dic_data': df[columns].to_dict(orient='records')
        } | get_model_units(model_object)