a b/aiagents4pharma/talk2biomodels/tools/parameter_scan.py
1
#!/usr/bin/env python3
2
3
"""
4
Tool for parameter scan.
5
"""
6
7
import logging
8
from dataclasses import dataclass
9
from typing import Type, Union, List, Annotated, Optional
10
import pandas as pd
11
import basico
12
from pydantic import BaseModel, Field
13
from langgraph.types import Command
14
from langgraph.prebuilt import InjectedState
15
from langchain_core.tools import BaseTool
16
from langchain_core.messages import ToolMessage
17
from langchain_core.tools.base import InjectedToolCallId
18
from .load_biomodel import ModelData, load_biomodel
19
from .load_arguments import TimeData, SpeciesInitialData
20
from .utils import get_model_units
21
22
# Module-wide logging setup: set the root logger to INFO, then obtain this
# module's own named logger.
# NOTE(review): basicConfig runs at import time and configures the root
# logger for the whole process (it is a no-op if the root logger already
# has handlers) — confirm this side effect is wanted in a library module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(name=__name__)
25
26
class ParameterScanData(BaseModel):
    """
    Pydantic model describing the parameter scan to perform.

    Attributes:
        species_names: Species whose concentrations are observed after each
            scan step (one result table is produced per species).
        species_parameter_name: Name of the single species or parameter whose
            value is varied during the scan.
        species_parameter_values: Values assigned, one at a time, to
            ``species_parameter_name``.
    """
    # NOTE: deliberately no stdlib ``@dataclass`` decorator. Stacking it on a
    # pydantic ``BaseModel`` makes the dataclass machinery install its own
    # generated ``__init__`` (plus ``__eq__``/``__repr__``) on this subclass,
    # shadowing pydantic's validating ``__init__``.
    species_names: List[str] = Field(
                    description="species to be observed after each scan."
                    " These are the species whose concentration"
                    " will be observed after the parameter scan."
                    " Do not make up this data.",
                    # pydantic copies mutable defaults per instance,
                    # so this list is not shared between models.
                    default=[])
    species_parameter_name: str = Field(
                    description="Species or parameter name to be scanned."
                    " This is the species or parameter whose value will be scanned"
                    " over a range of values. This does not include the species"
                    " that are to be observed after the scan."
                    " Do not make up this data.",
                    default=None)
    species_parameter_values: List[Union[int, float]] = Field(
                    description="Species or parameter values to be scanned."
                    " These are the values of the species or parameters that will be"
                    " scanned over a range of values. This does not include the "
                    "species that are to be observed after the scan."
                    " Do not make up this data.",
                    default=None)
51
52
@dataclass
class ArgumentData:
    """
    Arguments supplied to the parameter scan tool.

    This dataclass is the type of ``ParameterScanInput.arg_data``, so the
    field descriptions below become part of the tool's argument schema.

    NOTE(review): every default here is the object returned by pydantic's
    ``Field(...)`` (a ``FieldInfo``), not ``None``. The defaults resolve to
    ``None`` only when pydantic validates this dataclass; constructing
    ``ArgumentData(...)`` directly leaves ``FieldInfo`` instances in unset
    fields. Confirm no caller builds it directly.
    """
    # Simulation duration/interval settings.
    time_data: TimeData = Field(default=None, description="time data")
    # Initial conditions applied before the scan (excludes the scanned
    # species/parameter and the observed species).
    species_to_be_analyzed_before_experiment: Optional[SpeciesInitialData] = Field(
        default=None,
        description=(
            " This is the initial condition of the model. This does not"
            " include species that reoccur or the species whose"
            " concentration is to be determined/observed at the end of the"
            " experiment. This also does not include the species or the"
            " parameter that is to be scanned. Do not make up this data."
        ),
    )
    # What to scan and what to observe.
    parameter_scan_data: ParameterScanData = Field(
        default=None,
        description="parameter scan data",
    )
    # Name used to label the stored results of this experiment.
    experiment_name: str = Field(
        description=(
            "An AI assigned `_` separated unique name of the parameter scan"
            " experiment based on human query. This must be unique for each"
            " experiment."
        ),
    )
72
73
def make_list_dic_scanned_data(dic_param_scan, arg_data, sys_bio_model, tool_call_id):
    """
    Convert per-species parameter scan tables into graph-state entries.

    Args:
        dic_param_scan: Mapping of observed species name -> scan DataFrame.
        arg_data: The argument data; only ``experiment_name`` is read.
        sys_bio_model: The model data; only ``biomodel_id`` is read.
        tool_call_id: The tool call ID.

    Returns:
        list: One dict per species with keys ``name``
        (``<experiment_name>:<species>``), ``source`` (the BioModels ID, or
        ``'upload'`` when the ID is falsy), ``tool_call_id`` and ``data``
        (the DataFrame converted with ``to_dict()``).
    """
    entries = []
    for species, table in dic_param_scan.items():
        logger.log(logging.INFO, "Parameter scan results for %s with shape %s",
                   species, table.shape)
        entries.append({
            'name': ':'.join((arg_data.experiment_name, species)),
            # Models without a BioModels ID came from an uploaded SBML file
            'source': sys_bio_model.biomodel_id or 'upload',
            'tool_call_id': tool_call_id,
            'data': table.to_dict(),
        })
    return entries
101
102
def run_parameter_scan(model_object,
                       arg_data,
                       dic_species_data,
                       duration,
                       interval) -> dict:
    """
    Run parameter scan on the model.

    For every value in ``species_parameter_values`` the scanned species or
    parameter is set on a copy of the model and the model is simulated once.
    The trajectory of every observed species is taken from that single
    simulation, so the number of simulations equals the number of scanned
    values, independent of how many species are observed.

    Args:
        model_object: The model object. It is used through ``copasi_model``,
            ``model_copy()``, ``update_parameters()``, ``simulate()`` and
            ``simulation_results``.
        arg_data: The argument data; ``arg_data.parameter_scan_data`` holds
            the scan specification.
        dic_species_data: Mapping of species/parameter name -> value, applied
            to the model copy before the scan starts.
        duration: Duration of the simulation.
        interval: Interval between time points in the simulation.

    Returns:
        dict: Dictionary of parameter scan results. Each key is an observed
        species name and each value is a DataFrame with a ``Time`` column
        plus one ``<scanned name>_<value>`` column per scanned value.

    Raises:
        ValueError: If no parameter scan data is given, or if the scanned
            name or any observed species name is not present in the model.
    """
    scan_data = arg_data.parameter_scan_data if arg_data is not None else None
    if scan_data is None:
        logger.error("No parameter scan data provided.")
        raise ValueError("No parameter scan data provided.")
    scanned_name = scan_data.species_parameter_name

    # Extract all parameter names from the model. get_parameters returns
    # None for models without parameters (e.g. model 10 in the BioModels
    # database).
    df_all_parameters = basico.model_info.get_parameters(model=model_object.copasi_model)
    all_parameters = (set() if df_all_parameters is None
                      else set(df_all_parameters.index.tolist()))

    # Extract all species display names from the model (guarded the same
    # way in case a model has no species).
    df_all_species = basico.model_info.get_species(model=model_object.copasi_model)
    all_species = (set() if df_all_species is None
                   else set(df_all_species['display_name'].tolist()))

    # Verify that the species or parameter to be scanned exists
    if scanned_name not in all_parameters and scanned_name not in all_species:
        logger.error("Invalid species or parameter name: %s", scanned_name)
        raise ValueError(f"Invalid species or parameter name: {scanned_name}.")

    # Verify every observed species up front, so an invalid name fails
    # before any simulation is run. dict.fromkeys drops duplicates while
    # keeping the original order.
    observed_species = list(dict.fromkeys(scan_data.species_names))
    for species_name in observed_species:
        if species_name not in all_species:
            logger.error("Invalid species name: %s", species_name)
            raise ValueError(f"Invalid species name: {species_name}.")

    if not observed_species:
        return {}

    # Copy the model object to avoid modifying the original model.
    # NOTE(review): if model_copy() is a shallow copy, the underlying COPASI
    # model may still be shared with ``model_object`` — confirm.
    model_object_copy = model_object.model_copy()
    # Apply the fixed initial conditions set by the user
    model_object_copy.update_parameters(dic_species_data)

    # One results table per observed species
    dic_param_scan_results = {name: pd.DataFrame() for name in observed_species}

    for param_value in scan_data.species_parameter_values:
        # Update the scanned value and simulate once; the results contain
        # the trajectories of all species, shared by every observed species.
        model_object_copy.update_parameters({scanned_name: param_value})
        model_object_copy.simulate(duration=duration, interval=interval)
        results = model_object_copy.simulation_results
        col_name = f"{scanned_name}_{param_value}"
        for species_name, df_param_scan in dic_param_scan_results.items():
            # The time axis is taken from the first simulation
            if 'Time' not in df_param_scan.columns:
                df_param_scan['Time'] = results['Time']
            df_param_scan[col_name] = results[species_name]

    for df_param_scan in dic_param_scan_results.values():
        logger.log(logging.INFO, "Parameter scan results with shape %s", df_param_scan.shape)

    return dic_param_scan_results
184
185
class ParameterScanInput(BaseModel):
    """
    Input schema for the ParameterScan tool.

    ``tool_call_id`` and ``state`` are annotated with ``InjectedToolCallId``
    and ``InjectedState``; they are injected at call time rather than
    supplied by the model.
    """
    sys_bio_model: ModelData = Field(description="model data",
                                     default=None)
    # The description is shown to the LLM. It lists what ArgumentData
    # actually carries; the old text mentioned recurring-event data,
    # which ArgumentData has no field for.
    arg_data: ArgumentData = Field(
        description="time data, initial species data, and parameter scan"
                    " data as well as the name of the experiment",
        default=None)
    tool_call_id: Annotated[str, InjectedToolCallId]
    state: Annotated[dict, InjectedState]
198
199
# Note: It's important that every field has type hints. BaseTool is a
# Pydantic class and not having type hints can lead to unexpected behavior.
class ParameterScanTool(BaseTool):
    """
    Tool for parameter scan.

    Scans one species or parameter over a range of values and records the
    effect on the concentration of the observed species.
    """
    name: str = "parameter_scan"
    description: str = """A tool to perform scanning of a given
    parameter over a range of values and observe the effect on
    the concentration of a given species"""
    args_schema: Type[BaseModel] = ParameterScanInput

    def _run(self,
        tool_call_id: Annotated[str, InjectedToolCallId],
        state: Annotated[dict, InjectedState],
        sys_bio_model: ModelData = None,
        arg_data: ArgumentData = None
    ) -> Command:
        """
        Run the tool.

        Args:
            tool_call_id (str): The tool call ID. This is injected by the system.
            state (dict): The state of the tool. ``state['sbml_file_path']``
                (a list) is read; its last entry, if any, is the SBML file used.
            sys_bio_model (ModelData): The model data.
            arg_data (ArgumentData): The argument data.

        Returns:
            Command: The state update, holding ``model_id``, ``sbml_file_path``,
            ``dic_scanned_data`` (only when non-empty) and a ``ToolMessage``.

        Raises:
            ValueError: If ``arg_data`` or its ``parameter_scan_data`` is missing.
        """
        logger.log(logging.INFO, "Calling parameter_scan tool %s, %s",
                   sys_bio_model, arg_data)
        # Fail fast: without scan data there is nothing to run, so do not
        # load the model only to fail later with an AttributeError.
        if arg_data is None or arg_data.parameter_scan_data is None:
            logger.error("No parameter scan data provided.")
            raise ValueError("No parameter scan data provided.")

        # Use the most recent SBML file path, if the state holds any
        sbml_paths = state.get('sbml_file_path') or []
        sbml_file_path = sbml_paths[-1] if sbml_paths else None
        model_object = load_biomodel(sys_bio_model,
                                     sbml_file_path=sbml_file_path)

        # Defaults used when the caller supplies no time data
        duration = 100.0
        interval = 10
        # Initial conditions (name -> value) applied before the scan
        dic_species_data = {}
        initial_data = arg_data.species_to_be_analyzed_before_experiment
        if initial_data is not None:
            dic_species_data = dict(zip(initial_data.species_name,
                                        initial_data.species_concentration))
        if arg_data.time_data is not None:
            duration = arg_data.time_data.duration
            interval = arg_data.time_data.interval

        # Run the parameter scan
        dic_param_scan = run_parameter_scan(model_object,
                                            arg_data,
                                            dic_species_data,
                                            duration,
                                            interval)
        logger.log(logging.INFO, "Parameter scan results ready")

        # Prepare the list dictionary of scanned data for the graph state
        list_dic_scanned_data = make_list_dic_scanned_data(dic_param_scan,
                                                           arg_data,
                                                           sys_bio_model,
                                                           tool_call_id)

        # Only truthy values are written to the state. The one-element lists
        # below are truthy even when their element is None, so "model_id" and
        # "sbml_file_path" are always written; only an empty
        # "dic_scanned_data" list is skipped.
        candidate_updates = {
            "model_id": [sys_bio_model.biomodel_id],
            "sbml_file_path": [sbml_file_path],
            "dic_scanned_data": list_dic_scanned_data,
        }
        dic_updated_state_for_model = {
            key: value for key, value in candidate_updates.items() if value
        }

        # Return the updated state together with the tool message
        return Command(
            update=dic_updated_state_for_model | {
                # update the message history
                "messages": [
                    ToolMessage(
                        content=f"Parameter scan results of {arg_data.experiment_name}",
                        tool_call_id=tool_call_id,
                        artifact=get_model_units(model_object),
                    )
                ],
            }
        )