a b/aiagents4pharma/talk2biomodels/tools/steady_state.py
1
#!/usr/bin/env python3
2
3
"""
4
Tool for steady state analysis.
5
"""
6
7
import logging
8
from typing import Type, Annotated
9
import basico
10
from pydantic import BaseModel, Field
11
from langgraph.types import Command
12
from langgraph.prebuilt import InjectedState
13
from langchain_core.tools import BaseTool
14
from langchain_core.messages import ToolMessage
15
from langchain_core.tools.base import InjectedToolCallId
16
from .load_biomodel import ModelData, load_biomodel
17
from .load_arguments import ArgumentData, add_rec_events
18
19
# Set up INFO-level root logging at import time, and a logger named after
# this module for the functions and tool below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(name=__name__)
22
23
def run_steady_state(model_object,
                     dic_species_to_be_analyzed_before_experiment):
    """
    Run the steady state analysis.

    The values in ``dic_species_to_be_analyzed_before_experiment`` are first
    passed to ``model_object.update_parameters`` (these are the initial
    conditions set by the user). The steady state task is then run through
    basico on ``model_object.copasi_model``, and the resulting species table
    is returned.

    Args:
        model_object: The model object. It must expose ``update_parameters``
            and a ``copasi_model`` attribute.
        dic_species_to_be_analyzed_before_experiment: Dictionary mapping
            species names to the concentrations to set before the analysis.

    Returns:
        DataFrame: The species table reported by basico after the analysis.
        The columns ``name``, ``concentration`` and ``transition_time`` are
        renamed to ``species_name``, ``steady_state_concentration`` and
        ``steady_state_transition_time``. Bookkeeping columns (keys, SBML ids,
        expressions, particle numbers, ...) are dropped.

    Raises:
        ValueError: If basico reports that no steady state was found.
    """
    # Apply the user-supplied initial conditions before running the task.
    model_object.update_parameters(dic_species_to_be_analyzed_before_experiment)
    logger.info("Running steady state analysis")
    # A return value of 0 from basico is treated as "no steady state found".
    output = basico.task_steadystate.run_steadystate(model=model_object.copasi_model)
    if output == 0:
        logger.error("Steady state analysis failed")
        raise ValueError("A steady state was not found")
    logger.info("Steady state analysis successful")
    # Collect the species values (now at steady state) into a DataFrame;
    # reset_index() turns the species-name index into a regular 'name' column.
    df_steady_state = basico.model_info.get_species(
        model=model_object.copasi_model).reset_index()
    df_steady_state = df_steady_state.rename(columns={
        'name': 'species_name',
        'concentration': 'steady_state_concentration',
        'transition_time': 'steady_state_transition_time',
    })
    # errors='ignore': the exact set of columns returned by basico can differ
    # between versions; a missing column should not abort the analysis.
    df_steady_state = df_steady_state.drop(
        columns=[
            'initial_particle_number',
            'initial_expression',
            'expression',
            'particle_number',
            'type',
            'particle_number_rate',
            'key',
            'sbml_id',
            'display_name',
        ],
        errors='ignore')
    logger.info("Steady state results with shape %s", df_steady_state.shape)
    return df_steady_state
73
74
class SteadyStateInput(BaseModel):
    """
    Arguments accepted by the steady state tool.

    ``sys_bio_model`` and ``arg_data`` are optional and default to ``None``.
    ``tool_call_id`` and ``state`` are injected at call time (see the
    ``InjectedToolCallId`` / ``InjectedState`` annotations) rather than being
    supplied by the caller.
    """
    sys_bio_model: ModelData = Field(default=None,
                                     description="model data")
    arg_data: ArgumentData = Field(
        default=None,
        description=("time, species, and reocurring data"
                     " that must be set before the steady state analysis"
                     " as well as the experiment name"))
    tool_call_id: Annotated[str, InjectedToolCallId]
    state: Annotated[dict, InjectedState]
86
87
# Note: It's important that every field has type hints. BaseTool is a
88
# Pydantic class and not having type hints can lead to unexpected behavior.
89
class SteadyStateTool(BaseTool):
    """
    Tool to bring a model to steady state.

    Loads the requested model (or the most recently uploaded SBML file),
    applies any user-supplied species concentrations and recurring events,
    runs the steady state analysis and records the result in the graph state.
    """
    name: str = "steady_state"
    description: str = "A tool to bring a model to steady state."
    args_schema: Type[BaseModel] = SteadyStateInput

    def _run(self,
             tool_call_id: Annotated[str, InjectedToolCallId],
             state: Annotated[dict, InjectedState],
             sys_bio_model: ModelData = None,
             arg_data: ArgumentData = None
    ) -> Command:
        """
        Run the tool.

        Args:
            tool_call_id (str): The tool call ID. This is injected by the system.
            state (dict): The state of the graph. Its ``sbml_file_path`` list
                is read to find the most recently uploaded SBML file.
            sys_bio_model (ModelData): The model data. May be ``None``.
            arg_data (ArgumentData): Species concentrations, recurring events
                and the experiment name. May be ``None``.

        Returns:
            Command: The state update, containing the model id, SBML file
            path, steady state data and a ToolMessage reporting success.

        Raises:
            ValueError: Propagated from ``run_steady_state`` when no steady
                state is found.
        """
        logger.log(logging.INFO, "Calling the steady_state tool %s, %s",
                   sys_bio_model, arg_data)
        # Use the most recently uploaded SBML file, if there is one. A missing
        # or None entry in the state is treated like an empty list.
        sbml_file_paths = state.get('sbml_file_path') or []
        sbml_file_path = sbml_file_paths[-1] if sbml_file_paths else None
        model_object = load_biomodel(sys_bio_model,
                                     sbml_file_path=sbml_file_path)
        # Dictionary of species -> concentration applied to the model before
        # the steady state analysis.
        dic_species_to_be_analyzed_before_experiment = {}
        if arg_data is not None:
            species_data = arg_data.species_to_be_analyzed_before_experiment
            if species_data is not None:
                dic_species_to_be_analyzed_before_experiment = dict(
                    zip(species_data.species_name,
                        species_data.species_concentration))
            # Add reocurring events (if any) to the model
            if arg_data.reocurring_data is not None:
                add_rec_events(model_object, arg_data.reocurring_data)
        # Run the steady state analysis
        df_steady_state = run_steady_state(model_object,
                                           dic_species_to_be_analyzed_before_experiment)
        # arg_data and sys_bio_model are optional in the input schema, so
        # guard every attribute access on them. The fallback experiment name
        # is an arbitrary label chosen here (NOTE(review): confirm preferred).
        experiment_name = (arg_data.experiment_name
                           if arg_data is not None else 'steady_state')
        biomodel_id = (sys_bio_model.biomodel_id
                       if sys_bio_model is not None else None)
        # Prepare the dictionary of steady state data
        # that will be passed to the state of the graph
        dic_steady_state_data = {
            'name': experiment_name,
            'source': biomodel_id if biomodel_id else 'upload',
            'tool_call_id': tool_call_id,
            'data': df_steady_state.to_dict(orient='records')
        }
        # Every value is a one-element list (presumably appended to the
        # existing lists by the state's reducers — TODO confirm), so all
        # three keys are always part of the update.
        dic_updated_state_for_model = {
            "model_id": [biomodel_id],
            "sbml_file_path": [sbml_file_path],
            "dic_steady_state_data": [dic_steady_state_data],
        }
        # Return the updated state together with a message for the history
        return Command(
            update=dic_updated_state_for_model | {
                "messages": [
                    ToolMessage(
                        content=f"Steady state analysis of"
                                f" {experiment_name}"
                                " was successful.",
                        tool_call_id=tool_call_id,
                        artifact={'dic_data': df_steady_state.to_dict(orient='records')}
                    )
                ],
            }
        )