--- a +++ b/aiagents4pharma/talk2biomodels/tools/custom_plotter.py @@ -0,0 +1,157 @@ +#!/usr/bin/env python3 + +""" +Tool for plotting a custom y-axis of a simulation plot. +""" + +import logging +from typing import Type, Annotated, List, Tuple, Union, Literal +from pydantic import BaseModel, Field +import hydra +import pandas as pd +from langchain_core.tools import BaseTool +from langchain_core.prompts import ChatPromptTemplate +from langgraph.prebuilt import InjectedState +from .load_biomodel import ModelData, load_biomodel +from .utils import get_model_units + +# Initialize logger +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +def extract_relevant_species(question, species_names, state): + """ + Extract the relevant species from the user question. + + Args: + question (str): The user question. + species_names (list): The species names available in the simulation results. + state (dict): The state of the graph. + + Returns: + CustomHeader: The relevant species + """ + # In the following code, we extract the species + # from the user question. We use Literal to restrict + # the species names to the ones available in the + # simulation results. + class CustomHeader(BaseModel): + """ + A list of species based on user question. + + This is a Pydantic model that restricts the species + names to the ones available in the simulation results. + + If no species is relevant, set the attribute + `relevant_species` to None. + """ + relevant_species: Union[None, List[Literal[*species_names]]] = Field( + description="This is a list of species based on the user question." + "It is restricted to the species available in the simulation results." + "If no species is relevant, set this attribute to None." + "If the user asks for very specific species (for example, using the" + "keyword `only` in the question), set this attribute to correspond " + "to the species available in the simulation results, otherwise set it to None." + ) + # Load hydra configuration + with hydra.initialize(version_base=None, config_path="../configs"): + cfg = hydra.compose(config_name='config', + overrides=['tools/custom_plotter=default']) + cfg = cfg.tools.custom_plotter + # Get the system prompt + system_prompt = cfg.system_prompt_custom_header + # Create an instance of the LLM model + logging.log(logging.INFO, "LLM model: %s", state['llm_model']) + llm = state['llm_model'] + llm_with_structured_output = llm.with_structured_output(CustomHeader) + prompt = ChatPromptTemplate.from_messages([("system", system_prompt), + ("human", "{input}")]) + few_shot_structured_llm = prompt | llm_with_structured_output + return few_shot_structured_llm.invoke(question) + +class CustomPlotterInput(BaseModel): + """ + Input schema for the custom plotter tool. + """ + question: str = Field(description="Description of the plot") + sys_bio_model: ModelData = Field(description="model data", + default=None) + simulation_name: str = Field(description="Name assigned to the simulation") + state: Annotated[dict, InjectedState] + +# Note: It's important that every field has type hints. +# BaseTool is a Pydantic class and not having type hints +# can lead to unexpected behavior. +# Note: It's important that every field has type hints. +# BaseTool is a Pydantic class and not having type hints +# can lead to unexpected behavior. +class CustomPlotterTool(BaseTool): + """ + Tool for custom plotting the y-axis of a plot. + """ + name: str = "custom_plotter" + description: str = '''A visualization tool designed to extract and display a subset + of the larger simulation plot generated by the simulate_model tool. + It allows users to specify particular species for the y-axis, + providing a more targeted view of key species without the clutter + of the full plot.''' + args_schema: Type[BaseModel] = CustomPlotterInput + response_format: str = "content_and_artifact" + + def _run(self, + question: str, + sys_bio_model: ModelData, + simulation_name: str, + state: Annotated[dict, InjectedState] + ) -> Tuple[str, Union[None, List[str]]]: + """ + Run the tool. + + Args: + question (str): The question about the custom plot. + sys_bio_model (ModelData): The model data. + simulation_name (str): The name assigned to the simulation. + state (dict): The state of the graph. + + Returns: + str: The answer to the question + """ + logger.log(logging.INFO, "Calling custom_plotter tool %s, %s", question, sys_bio_model) + # Load the model + sbml_file_path = state['sbml_file_path'][-1] if len(state['sbml_file_path']) > 0 else None + model_object = load_biomodel(sys_bio_model, sbml_file_path=sbml_file_path) + dic_simulated_data = {} + for data in state["dic_simulated_data"]: + for key in data: + if key not in dic_simulated_data: + dic_simulated_data[key] = [] + dic_simulated_data[key] += [data[key]] + # Create a pandas dataframe from the dictionary + df = pd.DataFrame.from_dict(dic_simulated_data) + # Get the simulated data for the current tool call + df = pd.DataFrame( + df[df['name'] == simulation_name]['data'].iloc[0] + ) + # df = pd.DataFrame.from_dict(state['dic_simulated_data']) + species_names = df.columns.tolist() + # Exclude the time column + species_names.remove('Time') + logging.log(logging.INFO, "Species names: %s", species_names) + # Extract the relevant species from the user question + results = extract_relevant_species(question, species_names, state) + print (results) + if results.relevant_species is None: + raise ValueError("No species found in the simulation results \ + that matches the user prompt.") + extracted_species = [] + # Extract the species from the results + # that are available in the simulation results + for species in results.relevant_species: + if species in species_names: + extracted_species.append(species) + logging.info("Extracted species: %s", extracted_species) + # Include the time column + extracted_species.insert(0, 'Time') + return f"Custom plot {simulation_name}",{ + 'dic_data': df[extracted_species].to_dict(orient='records') + }| get_model_units(model_object)