Switch to side-by-side view

--- a
+++ b/aiagents4pharma/talk2biomodels/tools/parameter_scan.py
@@ -0,0 +1,293 @@
+#!/usr/bin/env python3
+
+"""
+Tool for parameter scan.
+"""
+
+import logging
+from dataclasses import dataclass
+from typing import Type, Union, List, Annotated, Optional
+import pandas as pd
+import basico
+from pydantic import BaseModel, Field
+from langgraph.types import Command
+from langgraph.prebuilt import InjectedState
+from langchain_core.tools import BaseTool
+from langchain_core.messages import ToolMessage
+from langchain_core.tools.base import InjectedToolCallId
+from .load_biomodel import ModelData, load_biomodel
+from .load_arguments import TimeData, SpeciesInitialData
+from .utils import get_model_units
+
+# Initialize logger
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+@dataclass
+class ParameterScanData(BaseModel):
+    """
+    Dataclass for storing the parameter scan data.
+    """
+    species_names: List[str] = Field(
+                    description="species to be observed after each scan."
+                    " These are the species whose concentration"
+                    " will be observed after the parameter scan."
+                    " Do not make up this data.",
+                    default=[])
+    species_parameter_name: str = Field(
+                    description="Species or parameter name to be scanned."
+                    " This is the species or parameter whose value will be scanned"
+                    " over a range of values. This does not include the species"
+                    " that are to be observed after the scan."
+                    "Do not make up this data.",
+                    default=None)
+    species_parameter_values: List[Union[int, float]] = Field(
+                    description="Species or parameter values to be scanned."
+                    " These are the values of the species or parameters that will be"
+                    " scanned over a range of values. This does not include the "
+                    "species that are to be observed after the scan."
+                    "Do not make up this data.",
+                    default=None)
+
+@dataclass
+class ArgumentData:
+    """
+    Dataclass for storing the argument data.
+    """
+    time_data: TimeData = Field(description="time data", default=None)
+    species_to_be_analyzed_before_experiment: Optional[SpeciesInitialData] = Field(
+                    description=" This is the initial condition of the model."
+                    " This does not include species that reoccur or the species"
+                    " whose concentration is to be determined/observed at the end"
+                    " of the experiment. This also does not include the species"
+                    " or the parameter that is to be scanned. Do not make up this data.",
+                    default=None)
+    parameter_scan_data: ParameterScanData = Field(
+                    description="parameter scan data",
+                    default=None)
+    experiment_name: str = Field(
+                    description="An AI assigned `_` separated unique name of"
+                    " the parameter scan experiment based on human query."
+                    " This must be unique for each experiment.")
+
+def make_list_dic_scanned_data(dic_param_scan, arg_data, sys_bio_model, tool_call_id):
+    """
+    Prepare the list dictionary of scanned data
+    that will be passed to the state of the graph.
+
+    Args:
+        dic_param_scan: Dictionary of parameter scan results.
+        arg_data: The argument data.
+        sys_bio_model: The model data.
+        tool_call_id: The tool call ID.
+
+    Returns:
+        list: List of dictionary of scanned data.
+    """
+    list_dic_scanned_data = []
+    for species_name, df_param_scan in dic_param_scan.items():
+        logger.log(logging.INFO, "Parameter scan results for %s with shape %s",
+                    species_name,
+                    df_param_scan.shape)
+        # Prepare the list dictionary of scanned data
+        # that will be passed to the state of the graph
+        list_dic_scanned_data.append({
+            'name': arg_data.experiment_name+':'+species_name,
+            'source': sys_bio_model.biomodel_id if sys_bio_model.biomodel_id else 'upload',
+            'tool_call_id': tool_call_id,
+            'data': df_param_scan.to_dict()
+        })
+    return list_dic_scanned_data
+
+def run_parameter_scan(model_object,
+                       arg_data,
+                       dic_species_data,
+                       duration,
+                       interval) -> dict:
+    """
+    Run parameter scan on the model.
+
+    Args:
+        model_object: The model object.
+        arg_data: The argument data.
+        dic_species_data: Dictionary of species data.
+        duration: Duration of the simulation.
+        interval: Interval between time points in the simulation.
+
+    Returns:
+        dict: Dictionary of parameter scan results. Each key is a species name
+        and each value is a DataFrame containing the results of the parameter scan.
+    """
+    # Extract all parameter names from the model
+    df_all_parameters = basico.model_info.get_parameters(model=model_object.copasi_model)
+    all_parameters = []
+    if df_all_parameters is not None:
+        # For example model 10 in the BioModels database
+        # has no parameters
+        all_parameters = df_all_parameters.index.tolist()
+
+    # Extract all species name from the model
+    df_all_species = basico.model_info.get_species(model=model_object.copasi_model)
+    all_species = df_all_species['display_name'].tolist()
+
+    # Verify if the given species or parameter names to be scanned are valid
+    if arg_data.parameter_scan_data.species_parameter_name not in all_parameters + all_species:
+        logger.error(
+            "Invalid species or parameter name: %s",
+            arg_data.parameter_scan_data.species_parameter_name)
+        raise ValueError(
+            "Invalid species or parameter name: "
+            f"{arg_data.parameter_scan_data.species_parameter_name}.")
+
+    # Dictionary to store the parameter scan results
+    dic_param_scan_results = {}
+
+    # Loop through the species names that are to be observed
+    for species_name in arg_data.parameter_scan_data.species_names:
+        # Verify if the given species name to be observed is valid
+        if species_name not in all_species:
+            logger.error("Invalid species name: %s", species_name)
+            raise ValueError(f"Invalid species name: {species_name}.")
+
+        # Copy the model object to avoid modifying the original model
+        model_object_copy = model_object.model_copy()
+
+        # Update the fixed model species and parameters
+        # These are the initial conditions of the model
+        # set by the user
+        model_object_copy.update_parameters(dic_species_data)
+
+        # Initialize empty DataFrame to store results
+        # of the parameter scan
+        df_param_scan = pd.DataFrame()
+
+        # Loop through the parameter that are to be scanned
+        for param_value in arg_data.parameter_scan_data.species_parameter_values:
+            # Update the parameter value in the model
+            model_object_copy.update_parameters(
+                {arg_data.parameter_scan_data.species_parameter_name: param_value})
+            # Simulate the model
+            model_object_copy.simulate(duration=duration, interval=interval)
+            # If the column name 'Time' is not present in the results DataFrame
+            if 'Time' not in df_param_scan.columns:
+                df_param_scan['Time'] = model_object_copy.simulation_results['Time']
+            # Add the simulation results to the results DataFrame
+            col_name = f"{arg_data.parameter_scan_data.species_parameter_name}_{param_value}"
+            df_param_scan[col_name] = model_object_copy.simulation_results[species_name]
+
+        logger.log(logging.INFO, "Parameter scan results with shape %s", df_param_scan.shape)
+
+        # Add the results of the parameter scan to the dictionary
+        dic_param_scan_results[species_name] = df_param_scan
+    # return df_param_scan
+    return dic_param_scan_results
+
+class ParameterScanInput(BaseModel):
+    """
+    Input schema for the ParameterScan tool.
+    """
+    sys_bio_model: ModelData = Field(description="model data",
+                                     default=None)
+    arg_data: ArgumentData = Field(description=
+                                   """time, species, and reocurring data
+                                   as well as the parameter scan name and
+                                   data""",
+                                   default=None)
+    tool_call_id: Annotated[str, InjectedToolCallId]
+    state: Annotated[dict, InjectedState]
+
+# Note: It's important that every field has type hints. BaseTool is a
+# Pydantic class and not having type hints can lead to unexpected behavior.
+class ParameterScanTool(BaseTool):
+    """
+    Tool for parameter scan.
+    """
+    name: str = "parameter_scan"
+    description: str = """A tool to perform scanning of a given
+    parameter over a range of values and observe the effect on
+    the concentration of a given species"""
+    args_schema: Type[BaseModel] = ParameterScanInput
+
+    def _run(self,
+        tool_call_id: Annotated[str, InjectedToolCallId],
+        state: Annotated[dict, InjectedState],
+        sys_bio_model: ModelData = None,
+        arg_data: ArgumentData = None
+    ) -> Command:
+        """
+        Run the tool.
+
+        Args:
+            tool_call_id (str): The tool call ID. This is injected by the system.
+            state (dict): The state of the tool.
+            sys_bio_model (ModelData): The model data.
+            arg_data (ArgumentData): The argument data.
+
+        Returns:
+            Command: The updated state of the tool.
+        """
+        logger.log(logging.INFO, "Calling parameter_scan tool %s, %s",
+                   sys_bio_model, arg_data)
+        sbml_file_path = state['sbml_file_path'][-1] if len(state['sbml_file_path']) > 0 else None
+        model_object = load_biomodel(sys_bio_model,
+                                  sbml_file_path=sbml_file_path)
+        # Prepare the dictionary of species data
+        # that will be passed to the simulate method
+        # of the BasicoModel class
+        duration = 100.0
+        interval = 10
+        dic_species_data = {}
+        if arg_data:
+            # Prepare the dictionary of species data
+            if arg_data.species_to_be_analyzed_before_experiment is not None:
+                dic_species_data = dict(
+                    zip(
+                        arg_data.species_to_be_analyzed_before_experiment.species_name,
+                        arg_data.species_to_be_analyzed_before_experiment.species_concentration
+                        )
+                    )
+
+            # # Add reocurring events (if any) to the model
+            # if arg_data.reocurring_data is not None:
+            #     add_rec_events(model_object, arg_data.reocurring_data)
+
+            # Set the duration and interval
+            if arg_data.time_data is not None:
+                duration = arg_data.time_data.duration
+                interval = arg_data.time_data.interval
+
+        # Run the parameter scan
+        dic_param_scan = run_parameter_scan(model_object,
+                                           arg_data,
+                                           dic_species_data,
+                                           duration,
+                                           interval)
+
+        logger.log(logging.INFO, "Parameter scan results ready")
+        # Prepare the list dictionary of scanned data
+        list_dic_scanned_data = make_list_dic_scanned_data(dic_param_scan,
+                                                           arg_data,
+                                                           sys_bio_model,
+                                                           tool_call_id)
+        # Prepare the dictionary of updated state for the model
+        dic_updated_state_for_model = {}
+        for key, value in {
+            "model_id": [sys_bio_model.biomodel_id],
+            "sbml_file_path": [sbml_file_path],
+            "dic_scanned_data": list_dic_scanned_data,
+            }.items():
+            if value:
+                dic_updated_state_for_model[key] = value
+        # Return the updated state
+        return Command(
+                update=dic_updated_state_for_model|{
+                # update the message history
+                "messages": [
+                    ToolMessage(
+                        content=f"Parameter scan results of {arg_data.experiment_name}",
+                        tool_call_id=tool_call_id,
+                        artifact=get_model_units(model_object)
+                        )
+                    ],
+                }
+            )