Switch to side-by-side view

--- a
+++ b/aiagents4pharma/talk2biomodels/tools/steady_state.py
@@ -0,0 +1,169 @@
+#!/usr/bin/env python3
+
+"""
+Tool for parameter scan.
+"""
+
+import logging
+from typing import Type, Annotated
+import basico
+from pydantic import BaseModel, Field
+from langgraph.types import Command
+from langgraph.prebuilt import InjectedState
+from langchain_core.tools import BaseTool
+from langchain_core.messages import ToolMessage
+from langchain_core.tools.base import InjectedToolCallId
+from .load_biomodel import ModelData, load_biomodel
+from .load_arguments import ArgumentData, add_rec_events
+
+# Initialize logger
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+def run_steady_state(model_object,
+                     dic_species_to_be_analyzed_before_experiment):
+    """
+    Run the steady state analysis.
+
+    Args:
+        model_object: The model object.
+        dic_species_to_be_analyzed_before_experiment: Dictionary of species data.
+
+    Returns:
+        DataFrame: The results of the steady state analysis.
+    """
+    # Update the fixed model species and parameters
+    # These are the initial conditions of the model
+    # set by the user
+    model_object.update_parameters(dic_species_to_be_analyzed_before_experiment)
+    logger.log(logging.INFO, "Running steady state analysis")
+    # Run the steady state analysis
+    output = basico.task_steadystate.run_steadystate(model=model_object.copasi_model)
+    if output == 0:
+        logger.error("Steady state analysis failed")
+        raise ValueError("A steady state was not found")
+    logger.log(logging.INFO, "Steady state analysis successful")
+    # Store the steady state results in a DataFrame
+    df_steady_state = basico.model_info.get_species(model=model_object.copasi_model).reset_index()
+    # print (df_steady_state)
+    # Rename the column name to species_name
+    df_steady_state.rename(columns={'name': 'species_name'},
+                           inplace=True)
+    # Rename the column concentration to steady_state_concentration
+    df_steady_state.rename(columns={'concentration': 'steady_state_concentration'},
+                           inplace=True)
+    # Rename the column transition_time to steady_state_transition_time
+    df_steady_state.rename(columns={'transition_time': 'steady_state_transition_time'},
+                           inplace=True)
+    # Drop some columns
+    df_steady_state.drop(columns=
+                         [
+                            'initial_particle_number',
+                            'initial_expression',
+                            'expression',
+                            'particle_number',
+                            'type',
+                            'particle_number_rate',
+                            'key',
+                            'sbml_id',
+                            'display_name'],
+                            inplace=True)
+    logger.log(logging.INFO, "Steady state results with shape %s", df_steady_state.shape)
+    return df_steady_state
+
+class SteadyStateInput(BaseModel):
+    """
+    Input schema for the steady state tool.
+    """
+    sys_bio_model: ModelData = Field(description="model data",
+                                     default=None)
+    arg_data: ArgumentData = Field(
+        description="time, species, and reocurring data"
+                " that must be set before the steady state analysis"
+                " as well as the experiment name", default=None)
+    tool_call_id: Annotated[str, InjectedToolCallId]
+    state: Annotated[dict, InjectedState]
+
+# Note: It's important that every field has type hints. BaseTool is a
+# Pydantic class and not having type hints can lead to unexpected behavior.
+class SteadyStateTool(BaseTool):
+    """
+    Tool to bring a model to steady state.
+    """
+    name: str = "steady_state"
+    description: str = "A tool to bring a model to steady state."
+    args_schema: Type[BaseModel] = SteadyStateInput
+
+    def _run(self,
+        tool_call_id: Annotated[str, InjectedToolCallId],
+        state: Annotated[dict, InjectedState],
+        sys_bio_model: ModelData = None,
+        arg_data: ArgumentData = None
+    ) -> Command:
+        """
+        Run the tool.
+
+        Args:
+            tool_call_id (str): The tool call ID. This is injected by the system.
+            state (dict): The state of the tool.
+            sys_bio_model (ModelData): The model data.
+            arg_data (ArgumentData): The argument data.
+
+        Returns:
+            Command: The updated state of the tool.
+        """
+        logger.log(logging.INFO, "Calling the steady_state tool %s, %s",
+                   sys_bio_model, arg_data)
+        # print (f'Calling steady_state tool {sys_bio_model}, {arg_data}, {tool_call_id}')
+        sbml_file_path = state['sbml_file_path'][-1] if len(state['sbml_file_path']) > 0 else None
+        model_object = load_biomodel(sys_bio_model,
+                                  sbml_file_path=sbml_file_path)
+        # Prepare the dictionary of species data
+        # that will be passed to the simulate method
+        # of the BasicoModel class
+        dic_species_to_be_analyzed_before_experiment = {}
+        if arg_data:
+            # Prepare the dictionary of species data
+            if arg_data.species_to_be_analyzed_before_experiment is not None:
+                dic_species_to_be_analyzed_before_experiment = dict(
+                    zip(arg_data.species_to_be_analyzed_before_experiment.species_name,
+                        arg_data.species_to_be_analyzed_before_experiment.species_concentration))
+            # Add reocurring events (if any) to the model
+            if arg_data.reocurring_data is not None:
+                add_rec_events(model_object, arg_data.reocurring_data)
+        # Run the parameter scan
+        df_steady_state = run_steady_state(model_object,
+                                           dic_species_to_be_analyzed_before_experiment)
+        print (df_steady_state)
+        # Prepare the dictionary of scanned data
+        # that will be passed to the state of the graph
+        dic_steady_state_data = {
+            'name': arg_data.experiment_name,
+            'source': sys_bio_model.biomodel_id if sys_bio_model.biomodel_id else 'upload',
+            'tool_call_id': tool_call_id,
+            'data': df_steady_state.to_dict(orient='records')
+        }
+        # Prepare the dictionary of updated state for the model
+        dic_updated_state_for_model = {}
+        for key, value in {
+            "model_id": [sys_bio_model.biomodel_id],
+            "sbml_file_path": [sbml_file_path],
+            "dic_steady_state_data": [dic_steady_state_data]
+            }.items():
+            if value:
+                dic_updated_state_for_model[key] = value
+        # Return the updated state
+        return Command(
+                update=dic_updated_state_for_model|{
+                # Update the message history
+                "messages": [
+                ToolMessage(
+                        content=f"Steady state analysis of"
+                                f" {arg_data.experiment_name}"
+                                " was successful.",
+                        tool_call_id=tool_call_id,
+                        artifact={'dic_data': df_steady_state.to_dict(orient='records')}
+                        )
+                    ],
+                }
+            )