a b/aiagents4pharma/talk2biomodels/tools/simulate_model.py
1
#!/usr/bin/env python3
2
3
"""
4
Tool for simulating a model.
5
"""
6
7
import logging
from typing import Type, Annotated, Optional

from pydantic import BaseModel, Field
from langgraph.types import Command
from langgraph.prebuilt import InjectedState
from langchain_core.tools import BaseTool
from langchain_core.messages import ToolMessage
from langchain_core.tools.base import InjectedToolCallId

from .load_biomodel import ModelData, load_biomodel
from .load_arguments import ArgumentData, add_rec_events
from .utils import get_model_units
18
19
# Initialize logger: configure root logging at INFO level, then fetch the
# logger named after this module for use by the tool below.
logging.basicConfig(level=logging.INFO)
logger: logging.Logger = logging.getLogger(name=__name__)
22
23
class SimulateModelInput(BaseModel):
    """
    Input schema for the SimulateModel tool.

    ``sys_bio_model`` and ``arg_data`` are optional and default to ``None``.
    ``tool_call_id`` and ``state`` are marked with ``InjectedToolCallId`` /
    ``InjectedState``, i.e. they are supplied by the framework rather than
    by the model's tool call.
    """
    # Optional[...] so the generated schema matches the ``None`` default.
    sys_bio_model: Optional[ModelData] = Field(description="model data",
                                               default=None)
    # Implicit string concatenation keeps the description on one logical
    # line, so no source indentation leaks into the tool schema.
    arg_data: Optional[ArgumentData] = Field(
        description="time, species, and recurring data "
                    "as well as the simulation name",
        default=None)
    tool_call_id: Annotated[str, InjectedToolCallId]
    state: Annotated[dict, InjectedState]
35
36
# Note: It's important that every field has type hints. BaseTool is a
37
# Pydantic class and not having type hints can lead to unexpected behavior.
38
class SimulateModelTool(BaseTool):
    """
    Tool for simulating a model.

    Loads a biomodel (from the given model data and/or the most recent SBML
    file path recorded in the state), applies optional initial species
    values and recurring events, runs a simulation and returns a
    ``Command`` carrying the results as a state update.
    """
    name: str = "simulate_model"
    description: str = "A tool to simulate a biomodel"
    args_schema: Type[BaseModel] = SimulateModelInput

    def _run(self,
             tool_call_id: Annotated[str, InjectedToolCallId],
             state: Annotated[dict, InjectedState],
             sys_bio_model: Optional[ModelData] = None,
             arg_data: Optional[ArgumentData] = None
             ) -> Command:
        """
        Run the tool.

        Args:
            tool_call_id (str): The tool call ID. This is injected by the system.
            state (dict): The state of the tool. If it has a non-empty
                ``'sbml_file_path'`` list, its last entry is the SBML file
                passed to ``load_biomodel``.
            sys_bio_model (ModelData, optional): The model data.
            arg_data (ArgumentData, optional): Initial species values,
                recurring events, time settings and the experiment name.
                When omitted, default duration/interval are used and the
                experiment is named after the tool call ID.

        Returns:
            Command: A state update with ``model_id``, ``sbml_file_path``,
            ``dic_simulated_data`` (each a one-element list) and a
            ``ToolMessage`` appended to ``messages``.
        """
        logger.log(logging.INFO,
                   "Calling simulate_model tool %s, %s",
                   sys_bio_model,
                   arg_data)
        # Most recently recorded SBML file, if any; ``.get`` avoids a
        # KeyError when the state has no 'sbml_file_path' entry.
        sbml_file_paths = state.get('sbml_file_path') or []
        sbml_file_path = sbml_file_paths[-1] if sbml_file_paths else None
        model_object = load_biomodel(sys_bio_model,
                                     sbml_file_path=sbml_file_path)
        # Defaults used when arg_data (or its time_data) is not provided
        duration = 100.0
        interval = 10
        # Species/parameter values passed to update_parameters before
        # simulating
        dic_species_to_be_analyzed_before_experiment = {}
        if arg_data is not None:
            species_data = arg_data.species_to_be_analyzed_before_experiment
            if species_data is not None:
                dic_species_to_be_analyzed_before_experiment = dict(
                    zip(species_data.species_name,
                        species_data.species_concentration))
            # Add reocurring events (if any) to the model
            if arg_data.reocurring_data is not None:
                add_rec_events(model_object, arg_data.reocurring_data)
            # Set the duration and interval
            if arg_data.time_data is not None:
                duration = arg_data.time_data.duration
                interval = arg_data.time_data.interval
        # Update the model parameters
        model_object.update_parameters(dic_species_to_be_analyzed_before_experiment)
        logger.log(logging.INFO,
                   "Following species/parameters updated in the model %s",
                   dic_species_to_be_analyzed_before_experiment)
        # Simulate the model
        df = model_object.simulate(duration=duration, interval=interval)
        logger.log(logging.INFO, "Simulation results ready with shape %s", df.shape)
        # Both inputs are optional, so guard attribute access on them.
        model_id = sys_bio_model.biomodel_id if sys_bio_model is not None else None
        experiment_name = (arg_data.experiment_name if arg_data is not None
                           else f"simulation_{tool_call_id}")
        dic_simulated_data = {
            'name': experiment_name,
            'source': model_id if model_id else 'upload',
            'tool_call_id': tool_call_id,
            'data': df.to_dict()
        }
        # Every key is written on every call (even when its single element
        # is None). Values are one-element lists; presumably the state
        # reducer appends them to per-key histories -- TODO confirm.
        return Command(
            update={
                "model_id": [model_id],
                "sbml_file_path": [sbml_file_path],
                "dic_simulated_data": [dic_simulated_data],
                # update the message history
                "messages": [
                    ToolMessage(
                        content=f"Simulation results of {experiment_name}",
                        tool_call_id=tool_call_id,
                        artifact=get_model_units(model_object)
                    )
                ],
            }
        )