Switch to side-by-side view

--- a
+++ b/aiagents4pharma/talk2biomodels/tools/custom_plotter.py
@@ -0,0 +1,157 @@
+#!/usr/bin/env python3
+
+"""
+Tool for plotting a custom y-axis of a simulation plot.
+"""
+
+import logging
+from typing import Type, Annotated, List, Tuple, Union, Literal
+from pydantic import BaseModel, Field
+import hydra
+import pandas as pd
+from langchain_core.tools import BaseTool
+from langchain_core.prompts import ChatPromptTemplate
+from langgraph.prebuilt import InjectedState
+from .load_biomodel import ModelData, load_biomodel
+from .utils import get_model_units
+
+# Initialize logger
+# NOTE(review): basicConfig runs at import time and configures the root
+# logger (INFO level) unless it already has handlers -- confirm that this
+# module-level side effect is intended.
+logging.basicConfig(level=logging.INFO)
+# Module-level logger named after this module.
+logger = logging.getLogger(__name__)
+
+def extract_relevant_species(question, species_names, state):
+    """
+    Extract the relevant species from the user question.
+
+    A Pydantic model is built on the fly so that the structured output
+    of the LLM can only contain names listed in `species_names`.
+
+    Args:
+        question (str): The user question.
+        species_names (list): The species names available in the simulation results.
+        state (dict): The state of the graph. The LLM is read from
+            `state['llm_model']`.
+
+    Returns:
+        CustomHeader: The structured output of the LLM. Its attribute
+            `relevant_species` is either a list of species names or None.
+    """
+    # In the following code, we extract the species
+    # from the user question. We use Literal to restrict
+    # the species names to the ones available in the
+    # simulation results.
+    # Literal[tuple(...)] is equivalent to Literal[*species_names] but does
+    # not depend on the star-in-subscript syntax introduced in Python 3.11.
+    species_literal = Literal[tuple(species_names)]
+    class CustomHeader(BaseModel):
+        """
+        A list of species based on user question.
+
+        This is a Pydantic model that restricts the species
+        names to the ones available in the simulation results.
+
+        If no species is relevant, set the attribute
+        `relevant_species` to None.
+        """
+        # Each fragment ends with a space so that the implicitly
+        # concatenated sentences do not run into each other.
+        relevant_species: Union[None, List[species_literal]] = Field(
+                description="This is a list of species based on the user question. "
+                "It is restricted to the species available in the simulation results. "
+                "If no species is relevant, set this attribute to None. "
+                "If the user asks for very specific species (for example, using the "
+                "keyword `only` in the question), set this attribute to correspond "
+                "to the species available in the simulation results, otherwise set it to None."
+                )
+    # Load hydra configuration
+    with hydra.initialize(version_base=None, config_path="../configs"):
+        cfg = hydra.compose(config_name='config',
+                            overrides=['tools/custom_plotter=default'])
+        cfg = cfg.tools.custom_plotter
+    # Get the system prompt
+    system_prompt = cfg.system_prompt_custom_header
+    # The LLM instance is provided through the graph state
+    llm = state['llm_model']
+    logger.info("LLM model: %s", llm)
+    llm_with_structured_output = llm.with_structured_output(CustomHeader)
+    prompt = ChatPromptTemplate.from_messages([("system", system_prompt),
+                                               ("human", "{input}")])
+    # Pipe the prompt into the LLM that returns a CustomHeader instance
+    few_shot_structured_llm = prompt | llm_with_structured_output
+    return few_shot_structured_llm.invoke(question)
+
+class CustomPlotterInput(BaseModel):
+    """
+    Input schema for the custom plotter tool.
+    """
+    # Free-text description of the requested plot
+    question: str = Field(description="Description of the plot")
+    # Model data; the field falls back to None when it is not supplied
+    sys_bio_model: ModelData = Field(default=None,
+                                     description="model data")
+    # Name under which the simulation was stored
+    simulation_name: str = Field(description="Name assigned to the simulation")
+    # Graph state, annotated with InjectedState
+    state: Annotated[dict, InjectedState]
+
+# Note: It's important that every field has type hints.
+# BaseTool is a Pydantic class and not having type hints
+# can lead to unexpected behavior.
+class CustomPlotterTool(BaseTool):
+    """
+    Tool for custom plotting the y-axis of a plot.
+
+    It looks up a stored simulation by name in the graph state, asks the
+    LLM which species the user question refers to, and returns the
+    simulated data restricted to those species (plus the time column).
+    """
+    name: str = "custom_plotter"
+    description: str = '''A visualization tool designed to extract and display a subset
+                        of the larger simulation plot generated by the simulate_model tool.
+                        It allows users to specify particular species for the y-axis, 
+                        providing a more targeted view of key species without the clutter 
+                        of the full plot.'''
+    args_schema: Type[BaseModel] = CustomPlotterInput
+    response_format: str = "content_and_artifact"
+
+    def _run(self,
+             question: str,
+             sys_bio_model: ModelData,
+             simulation_name: str,
+             state: Annotated[dict, InjectedState]
+             ) -> Tuple[str, dict]:
+        """
+        Run the tool.
+
+        Args:
+            question (str): The question about the custom plot.
+            sys_bio_model (ModelData): The model data.
+            simulation_name (str): The name assigned to the simulation.
+            state (dict): The state of the graph. The keys `sbml_file_path`,
+                `dic_simulated_data` and `llm_model` are read.
+
+        Returns:
+            Tuple[str, dict]: A short message naming the plot, and the
+                artifact: a dictionary whose key `dic_data` holds the
+                selected columns as a list of records, merged with the
+                dictionary returned by `get_model_units`.
+
+        Raises:
+            ValueError: If no stored simulation is named `simulation_name`,
+                if the simulated data has no `Time` column, or if no
+                species matching the question is found.
+        """
+        logger.info("Calling custom_plotter tool %s, %s", question, sys_bio_model)
+        # Load the model, using the most recent SBML file path if there is one
+        sbml_paths = state['sbml_file_path']
+        sbml_file_path = sbml_paths[-1] if sbml_paths else None
+        model_object = load_biomodel(sys_bio_model, sbml_file_path=sbml_file_path)
+        # Get the simulated data for the current tool call, i.e. the first
+        # stored entry whose name matches the requested simulation
+        simulated_data = next(
+            (entry['data'] for entry in state["dic_simulated_data"]
+             if entry.get('name') == simulation_name),
+            None)
+        if simulated_data is None:
+            raise ValueError(f"No simulation named '{simulation_name}' "
+                             "found in the simulation results.")
+        df = pd.DataFrame(simulated_data)
+        species_names = df.columns.tolist()
+        # Exclude the time column (list.remove raises ValueError if absent)
+        species_names.remove('Time')
+        logger.info("Species names: %s", species_names)
+        # Extract the relevant species from the user question
+        results = extract_relevant_species(question, species_names, state)
+        logger.info("Relevant species: %s", results)
+        # Defensive: a missing structured output is handled like
+        # "no relevant species".
+        relevant_species = getattr(results, 'relevant_species', None) or []
+        # Keep only the species that are available in the simulation results
+        extracted_species = [species for species in relevant_species
+                             if species in species_names]
+        logger.info("Extracted species: %s", extracted_species)
+        if not extracted_species:
+            raise ValueError("No species found in the simulation results "
+                             "that matches the user prompt.")
+        # Include the time column
+        selected_columns = ['Time'] + extracted_species
+        dic_data = df[selected_columns].to_dict(orient='records')
+        return (f"Custom plot {simulation_name}",
+                {'dic_data': dic_data} | get_model_units(model_object))