[3af7d7]: / aiagents4pharma / talk2biomodels / tools / parameter_scan.py

Download this file

294 lines (263 with data), 12.7 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
#!/usr/bin/env python3
"""
Tool for parameter scan.
"""
import logging
from dataclasses import dataclass
from typing import Type, Union, List, Annotated, Optional
import pandas as pd
import basico
from pydantic import BaseModel, Field
from langgraph.types import Command
from langgraph.prebuilt import InjectedState
from langchain_core.tools import BaseTool
from langchain_core.messages import ToolMessage
from langchain_core.tools.base import InjectedToolCallId
from .load_biomodel import ModelData, load_biomodel
from .load_arguments import TimeData, SpeciesInitialData
from .utils import get_model_units
# Module-level logger. basicConfig runs at import time and sets the root
# logger to INFO (it does nothing if the root logger already has handlers).
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
@dataclass
class ParameterScanData(BaseModel):
    """
    Schema describing one parameter scan: which species or parameter is
    varied, over which values, and which species are observed afterwards.

    The Field descriptions end up in the tool's argument schema
    (ParameterScanData -> ArgumentData -> ParameterScanInput), so they are
    phrased as instructions for whoever fills in the tool call.

    NOTE(review): stdlib ``@dataclass`` applied on top of a pydantic
    ``BaseModel`` is unusual; kept as-is — confirm it is intended.
    """
    # Species whose concentration time courses are collected for every scan value.
    species_names: List[str] = Field(
        description="species to be observed after each scan."
                    " These are the species whose concentration"
                    " will be observed after the parameter scan."
                    " Do not make up this data.",
        default=[])
    # Name of the single species or parameter that is varied.
    species_parameter_name: str = Field(
        description="Species or parameter name to be scanned."
                    " This is the species or parameter whose value will be scanned"
                    " over a range of values. This does not include the species"
                    " that are to be observed after the scan."
                    # Leading space: without it the sentences run together
                    # ("...after the scan.Do not make up this data.").
                    " Do not make up this data.",
        default=None)
    # Values assigned, one after another, to species_parameter_name.
    species_parameter_values: List[Union[int, float]] = Field(
        description="Species or parameter values to be scanned."
                    " These are the values of the species or parameters that will be"
                    " scanned over a range of values. This does not include the "
                    "species that are to be observed after the scan."
                    " Do not make up this data.",
        default=None)
@dataclass
class ArgumentData:
    """
    Everything the caller supplies to configure one parameter-scan
    experiment: simulation time settings, initial species values, the scan
    definition itself, and a unique experiment name.
    """
    # Simulation duration/interval settings.
    time_data: TimeData = Field(default=None, description="time data")
    # Initial conditions applied to the model before the scan is run.
    species_to_be_analyzed_before_experiment: Optional[SpeciesInitialData] = Field(
        default=None,
        description=(" This is the initial condition of the model. This does not"
                     " include species that reoccur or the species whose"
                     " concentration is to be determined/observed at the end of"
                     " the experiment. This also does not include the species or"
                     " the parameter that is to be scanned. Do not make up this data."))
    # What to scan and what to observe.
    parameter_scan_data: ParameterScanData = Field(
        default=None,
        description="parameter scan data")
    # Used as the prefix of the stored result names ("<experiment_name>:<species>").
    experiment_name: str = Field(
        description=("An AI assigned `_` separated unique name of the parameter scan"
                     " experiment based on human query. This must be unique for each experiment."))
def make_list_dic_scanned_data(dic_param_scan, arg_data, sys_bio_model, tool_call_id):
    """
    Convert parameter-scan results into the records stored in the graph state.

    Args:
        dic_param_scan (dict): Maps each observed species name to the
            DataFrame holding the scan results for that species.
        arg_data: Argument data; only ``experiment_name`` is read.
        sys_bio_model: Model data; its ``biomodel_id`` is used as the record
            source, falling back to ``'upload'`` when it is falsy.
        tool_call_id (str): ID of the tool call that produced the results.

    Returns:
        list: One dict per species with the keys ``name``
        (``"<experiment_name>:<species>"``), ``source``, ``tool_call_id`` and
        ``data`` (the DataFrame converted with ``to_dict()``).
    """
    records = []
    for observed_species, scan_frame in dic_param_scan.items():
        logger.log(logging.INFO, "Parameter scan results for %s with shape %s",
                   observed_species, scan_frame.shape)
        records.append({
            'name': arg_data.experiment_name + ':' + observed_species,
            'source': sys_bio_model.biomodel_id or 'upload',
            'tool_call_id': tool_call_id,
            'data': scan_frame.to_dict(),
        })
    return records
def run_parameter_scan(model_object,
                       arg_data,
                       dic_species_data,
                       duration,
                       interval) -> dict:
    """
    Run a parameter scan on the model.

    The scanned species/parameter is set to each value of
    ``arg_data.parameter_scan_data.species_parameter_values`` in turn and the
    model is simulated. Every observed species is read from the same
    simulation results, so the model is simulated once per scan value.

    Args:
        model_object: The model object (provides ``copasi_model``,
            ``model_copy``, ``update_parameters``, ``simulate`` and
            ``simulation_results``).
        arg_data: The argument data; ``parameter_scan_data`` defines the scan.
        dic_species_data: Initial species/parameter values applied to the
            model copy before scanning.
        duration: Duration of the simulation.
        interval: Interval between time points in the simulation.

    Returns:
        dict: Maps each observed species name to a DataFrame with a 'Time'
        column followed by one column per scan value, named
        ``f"{species_parameter_name}_{value}"``. Empty if no species are
        to be observed.

    Raises:
        ValueError: If the scanned name is neither a model parameter nor a
            species, or if any observed species is not in the model.
    """
    scan_data = arg_data.parameter_scan_data
    scanned_name = scan_data.species_parameter_name
    # Extract all parameter names from the model
    df_all_parameters = basico.model_info.get_parameters(model=model_object.copasi_model)
    all_parameters = []
    if df_all_parameters is not None:
        # For example model 10 in the BioModels database
        # has no parameters
        all_parameters = df_all_parameters.index.tolist()
    # Extract all species names from the model
    df_all_species = basico.model_info.get_species(model=model_object.copasi_model)
    all_species = df_all_species['display_name'].tolist()
    # Verify that the species or parameter to be scanned exists
    if scanned_name not in all_parameters + all_species:
        logger.error("Invalid species or parameter name: %s", scanned_name)
        raise ValueError(
            "Invalid species or parameter name: "
            f"{scanned_name}.")
    # Validate every observed species before any simulation is run, so an
    # invalid name does not waste the simulations of the names before it.
    known_species = set(all_species)
    for species_name in scan_data.species_names:
        if species_name not in known_species:
            logger.error("Invalid species name: %s", species_name)
            raise ValueError(f"Invalid species name: {species_name}.")
    # One result table per observed species (duplicate names share a table)
    dic_param_scan_results = {name: pd.DataFrame() for name in scan_data.species_names}
    if not dic_param_scan_results:
        # Nothing to observe, so nothing to simulate
        return dic_param_scan_results
    # Copy the model object to avoid modifying the original model, then
    # apply the user-provided initial conditions
    model_object_copy = model_object.model_copy()
    model_object_copy.update_parameters(dic_species_data)
    for param_value in scan_data.species_parameter_values:
        # Set the scanned value and simulate once for all observed species
        model_object_copy.update_parameters({scanned_name: param_value})
        model_object_copy.simulate(duration=duration, interval=interval)
        simulation_results = model_object_copy.simulation_results
        col_name = f"{scanned_name}_{param_value}"
        for species_name, df_param_scan in dic_param_scan_results.items():
            # The time axis is added once, as the first column
            if 'Time' not in df_param_scan.columns:
                df_param_scan['Time'] = simulation_results['Time']
            df_param_scan[col_name] = simulation_results[species_name]
    for df_param_scan in dic_param_scan_results.values():
        logger.log(logging.INFO, "Parameter scan results with shape %s", df_param_scan.shape)
    return dic_param_scan_results
class ParameterScanInput(BaseModel):
    """
    Argument schema of the ``parameter_scan`` tool.

    ``tool_call_id`` and ``state`` are annotated as injected values
    (``InjectedToolCallId`` / ``InjectedState``). The remaining fields carry
    Field descriptions and default to None.
    """
    sys_bio_model: ModelData = Field(
        default=None,
        description="model data")
    arg_data: ArgumentData = Field(
        default=None,
        description=("time, species, and reocurring data\n"
                     "as well as the parameter scan name and\n"
                     "data"))
    tool_call_id: Annotated[str, InjectedToolCallId]
    state: Annotated[dict, InjectedState]
# Note: It's important that every field has type hints. BaseTool is a
# Pydantic class and not having type hints can lead to unexpected behavior.
class ParameterScanTool(BaseTool):
    """
    Tool that scans one species or parameter over a list of values and
    records the resulting concentration time courses of the observed species.
    """
    name: str = "parameter_scan"
    description: str = ("A tool to perform scanning of a given\n"
                        "parameter over a range of values and observe the effect on\n"
                        "the concentration of a given species")
    args_schema: Type[BaseModel] = ParameterScanInput

    def _run(self,
             tool_call_id: Annotated[str, InjectedToolCallId],
             state: Annotated[dict, InjectedState],
             sys_bio_model: ModelData = None,
             arg_data: ArgumentData = None
             ) -> Command:
        """
        Run the parameter scan and return the corresponding state update.

        Args:
            tool_call_id (str): The tool call ID. This is injected by the system.
            state (dict): The injected graph state. The last entry of its
                ``sbml_file_path`` list (if present and non-empty) is passed
                to ``load_biomodel``.
            sys_bio_model (ModelData): The model data.
            arg_data (ArgumentData): Time settings, initial conditions, scan
                definition and experiment name. Required.

        Returns:
            Command: Update containing ``model_id``, ``sbml_file_path``,
            ``dic_scanned_data`` (omitted when empty) and a ToolMessage whose
            artifact is ``get_model_units(model_object)``.

        Raises:
            ValueError: If ``arg_data`` or its ``parameter_scan_data`` is
                missing, or if ``run_parameter_scan`` rejects a name.
        """
        logger.log(logging.INFO, "Calling parameter_scan tool %s, %s",
                   sys_bio_model, arg_data)
        # A scan needs a scan definition; fail fast with a clear message
        # instead of an AttributeError inside run_parameter_scan.
        if arg_data is None or arg_data.parameter_scan_data is None:
            raise ValueError(
                "parameter_scan requires arg_data with parameter_scan_data.")
        # Use the most recently recorded SBML file path, if any
        sbml_file_paths = state.get('sbml_file_path') or []
        sbml_file_path = sbml_file_paths[-1] if sbml_file_paths else None
        model_object = load_biomodel(sys_bio_model,
                                     sbml_file_path=sbml_file_path)
        # Simulation defaults, overridden by arg_data.time_data when given
        duration = 100.0
        interval = 10
        # Initial species values applied to the model before scanning
        dic_species_data = {}
        initial_data = arg_data.species_to_be_analyzed_before_experiment
        if initial_data is not None:
            dic_species_data = dict(zip(initial_data.species_name,
                                        initial_data.species_concentration))
        if arg_data.time_data is not None:
            duration = arg_data.time_data.duration
            interval = arg_data.time_data.interval
        # Run the parameter scan
        dic_param_scan = run_parameter_scan(model_object,
                                            arg_data,
                                            dic_species_data,
                                            duration,
                                            interval)
        logger.log(logging.INFO, "Parameter scan results ready")
        # Convert the results into the records stored in the graph state
        list_dic_scanned_data = make_list_dic_scanned_data(dic_param_scan,
                                                           arg_data,
                                                           sys_bio_model,
                                                           tool_call_id)
        candidate_updates = {
            "model_id": [sys_bio_model.biomodel_id],
            "sbml_file_path": [sbml_file_path],
            "dic_scanned_data": list_dic_scanned_data,
        }
        # model_id and sbml_file_path are one-element lists and thus always
        # kept; only an empty dic_scanned_data list is dropped.
        dic_updated_state_for_model = {
            key: value for key, value in candidate_updates.items() if value
        }
        return Command(
            update=dic_updated_state_for_model | {
                # update the message history
                "messages": [
                    ToolMessage(
                        content=f"Parameter scan results of {arg_data.experiment_name}",
                        tool_call_id=tool_call_id,
                        artifact=get_model_units(model_object)
                    )
                ],
            }
        )