|
a |
|
b/aiagents4pharma/talk2biomodels/tools/parameter_scan.py |
|
|
1 |
#!/usr/bin/env python3 |
|
|
2 |
|
|
|
3 |
""" |
|
|
4 |
Tool for parameter scan. |
|
|
5 |
""" |
|
|
6 |
|
|
|
7 |
import logging |
|
|
8 |
from dataclasses import dataclass |
|
|
9 |
from typing import Type, Union, List, Annotated, Optional |
|
|
10 |
import pandas as pd |
|
|
11 |
import basico |
|
|
12 |
from pydantic import BaseModel, Field |
|
|
13 |
from langgraph.types import Command |
|
|
14 |
from langgraph.prebuilt import InjectedState |
|
|
15 |
from langchain_core.tools import BaseTool |
|
|
16 |
from langchain_core.messages import ToolMessage |
|
|
17 |
from langchain_core.tools.base import InjectedToolCallId |
|
|
18 |
from .load_biomodel import ModelData, load_biomodel |
|
|
19 |
from .load_arguments import TimeData, SpeciesInitialData |
|
|
20 |
from .utils import get_model_units |
|
|
21 |
|
|
|
22 |
# Module-level logging setup: request INFO-level root configuration when
# this module is imported and create the logger used by the functions below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
|
|
25 |
|
|
|
26 |
# Plain pydantic model on purpose: stacking the stdlib ``@dataclass``
# decorator on a ``BaseModel`` subclass would layer dataclass-generated
# ``__init__``/``__eq__``/``__repr__`` on top of the pydantic ones.
class ParameterScanData(BaseModel):
    """
    Pydantic model storing the parameter scan data.

    It names the species or parameter to scan, the values to scan it
    over, and the species to observe after each scan.
    """
    # Species whose values are collected for every scanned value.
    species_names: List[str] = Field(
        description="species to be observed after each scan."
        " These are the species whose concentration"
        " will be observed after the parameter scan."
        " Do not make up this data.",
        default=[])
    # Name of the single species or parameter that is varied.
    species_parameter_name: str = Field(
        description="Species or parameter name to be scanned."
        " This is the species or parameter whose value will be scanned"
        " over a range of values. This does not include the species"
        " that are to be observed after the scan."
        " Do not make up this data.",
        default=None)
    # Values assigned, one per simulation, to the scanned species/parameter.
    species_parameter_values: List[Union[int, float]] = Field(
        description="Species or parameter values to be scanned."
        " These are the values of the species or parameters that will be"
        " scanned over a range of values. This does not include the "
        "species that are to be observed after the scan."
        " Do not make up this data.",
        default=None)
|
|
51 |
|
|
|
52 |
@dataclass
class ArgumentData:
    """
    Container for every argument of a parameter scan experiment.

    Groups the time settings, the optional initial species data, the
    parameter scan data and the experiment name.
    """
    # Duration/interval settings of the simulation.
    time_data: TimeData = Field(default=None, description="time data")
    # Initial species data applied before the scan (may be omitted).
    species_to_be_analyzed_before_experiment: Optional[SpeciesInitialData] = Field(
        default=None,
        description=(
            " This is the initial condition of the model. This does not"
            " include species that reoccur or the species whose"
            " concentration is to be determined/observed at the end of the"
            " experiment. This also does not include the species or the"
            " parameter that is to be scanned. Do not make up this data."
        ),
    )
    # What to scan, over which values, and which species to observe.
    parameter_scan_data: ParameterScanData = Field(
        default=None,
        description="parameter scan data",
    )
    # Name used to label the results of this experiment.
    experiment_name: str = Field(
        description=(
            "An AI assigned `_` separated unique name of the parameter scan"
            " experiment based on human query. This must be unique for each"
            " experiment."
        ),
    )
|
|
72 |
|
|
|
73 |
def make_list_dic_scanned_data(dic_param_scan, arg_data, sys_bio_model, tool_call_id):
    """
    Turn parameter scan results into the records stored in the graph state.

    Args:
        dic_param_scan: Dictionary mapping each observed species name to
            the DataFrame holding its parameter scan results.
        arg_data: The argument data; its ``experiment_name`` prefixes the
            name of every record.
        sys_bio_model: The model data; its ``biomodel_id`` is used as the
            record source, with 'upload' as fallback when it is falsy.
        tool_call_id: The tool call ID stored on every record.

    Returns:
        list: One dictionary per species with the keys 'name', 'source',
        'tool_call_id' and 'data'.
    """
    scanned_records = []
    for observed_species in dic_param_scan:
        scan_frame = dic_param_scan[observed_species]
        logger.info("Parameter scan results for %s with shape %s",
                    observed_species,
                    scan_frame.shape)
        record = {}
        record['name'] = arg_data.experiment_name + ':' + observed_species
        record['source'] = sys_bio_model.biomodel_id or 'upload'
        record['tool_call_id'] = tool_call_id
        # The DataFrame is converted to a plain dictionary for the state
        record['data'] = scan_frame.to_dict()
        scanned_records.append(record)
    return scanned_records
|
|
101 |
|
|
|
102 |
def run_parameter_scan(model_object,
                       arg_data,
                       dic_species_data,
                       duration,
                       interval) -> dict:
    """
    Run parameter scan on the model.

    For every observed species, a copy of the model is simulated once per
    scanned value and the species' simulated values are collected in one
    column per scanned value.

    Args:
        model_object: The model object.
        arg_data: The argument data.
        dic_species_data: Dictionary of species data.
        duration: Duration of the simulation.
        interval: Interval between time points in the simulation.

    Returns:
        dict: Dictionary of parameter scan results. Each key is a species name
        and each value is a DataFrame containing the results of the parameter
        scan: a 'Time' column plus one '<scanned name>_<value>' column per
        scanned value.

    Raises:
        ValueError: If the name to be scanned is neither a parameter nor a
            species of the model, or if any species to be observed is not
            a species of the model.
    """
    scan_data = arg_data.parameter_scan_data
    scanned_name = scan_data.species_parameter_name

    # Extract all parameter names from the model. The lookup yields None
    # when there are none (for example model 10 in the BioModels database
    # has no parameters).
    df_all_parameters = basico.model_info.get_parameters(model=model_object.copasi_model)
    all_parameters = []
    if df_all_parameters is not None:
        all_parameters = df_all_parameters.index.tolist()

    # Extract all species names from the model, guarded the same way as
    # the parameters so that an empty lookup ends in the ValueError below
    # rather than a TypeError on subscripting None.
    # NOTE(review): assumes get_species mirrors get_parameters and returns
    # None for a model without species -- confirm against basico.
    df_all_species = basico.model_info.get_species(model=model_object.copasi_model)
    all_species = []
    if df_all_species is not None:
        all_species = df_all_species['display_name'].tolist()

    # Verify if the given species or parameter name to be scanned is valid
    if scanned_name not in all_parameters + all_species:
        logger.error("Invalid species or parameter name: %s", scanned_name)
        raise ValueError(f"Invalid species or parameter name: {scanned_name}.")

    # Verify every species to be observed before running any simulation,
    # so that an invalid name is reported without first simulating the
    # species listed ahead of it.
    for species_name in scan_data.species_names:
        if species_name not in all_species:
            logger.error("Invalid species name: %s", species_name)
            raise ValueError(f"Invalid species name: {species_name}.")

    # Dictionary to store the parameter scan results
    dic_param_scan_results = {}

    # Loop through the species names that are to be observed
    for species_name in scan_data.species_names:
        # Copy the model object to avoid modifying the original model
        model_object_copy = model_object.model_copy()

        # Apply the initial conditions of the model set by the user
        model_object_copy.update_parameters(dic_species_data)

        # DataFrame collecting one column per scanned value
        df_param_scan = pd.DataFrame()

        # Loop through the values that are to be scanned
        for param_value in scan_data.species_parameter_values:
            # Update the scanned species/parameter and simulate the model
            model_object_copy.update_parameters({scanned_name: param_value})
            model_object_copy.simulate(duration=duration, interval=interval)
            # The 'Time' column is taken from the first simulation only
            if 'Time' not in df_param_scan.columns:
                df_param_scan['Time'] = model_object_copy.simulation_results['Time']
            # Add the simulation results to the results DataFrame
            col_name = f"{scanned_name}_{param_value}"
            df_param_scan[col_name] = model_object_copy.simulation_results[species_name]

        logger.info("Parameter scan results with shape %s", df_param_scan.shape)

        # Add the results of the parameter scan to the dictionary
        dic_param_scan_results[species_name] = df_param_scan
    return dic_param_scan_results
|
|
184 |
|
|
|
185 |
class ParameterScanInput(BaseModel):
    """
    Argument schema accepted by the ParameterScan tool.
    """
    # Model to run the scan on; defaults to None in the schema.
    sys_bio_model: ModelData = Field(default=None, description="model data")
    # Time, species and parameter scan settings; defaults to None.
    arg_data: ArgumentData = Field(
        default=None,
        description=(
            "time, species, and reocurring data\n"
            "as well as the parameter scan name and\n"
            "data"
        ),
    )
    # Annotated as injected values (tool call ID and graph state).
    tool_call_id: Annotated[str, InjectedToolCallId]
    state: Annotated[dict, InjectedState]
|
|
198 |
|
|
|
199 |
# Note: It's important that every field has type hints. BaseTool is a |
|
|
200 |
# Pydantic class and not having type hints can lead to unexpected behavior. |
|
|
201 |
class ParameterScanTool(BaseTool):
    """
    Tool for parameter scan.

    Loads the requested model, runs the parameter scan described by the
    argument data and returns a Command that stores the scan results in
    the state together with a ToolMessage.
    """
    name: str = "parameter_scan"
    description: str = (
        "A tool to perform scanning of a given\n"
        "parameter over a range of values and observe the effect on\n"
        "the concentration of a given species"
    )
    args_schema: Type[BaseModel] = ParameterScanInput

    def _run(self,
             tool_call_id: Annotated[str, InjectedToolCallId],
             state: Annotated[dict, InjectedState],
             sys_bio_model: ModelData = None,
             arg_data: ArgumentData = None
             ) -> Command:
        """
        Run the tool.

        Args:
            tool_call_id (str): The tool call ID. This is injected by the system.
            state (dict): The state of the tool.
            sys_bio_model (ModelData): The model data.
            arg_data (ArgumentData): The argument data.

        Returns:
            Command: The updated state of the tool.

        Raises:
            ValueError: If ``arg_data`` or its ``parameter_scan_data`` is
                missing, since a scan cannot be run without them.
        """
        logger.info("Calling parameter_scan tool %s, %s",
                    sys_bio_model, arg_data)

        # The scan itself and the result message both dereference arg_data,
        # so fail here with a clear message instead of an AttributeError
        # raised further down.
        if arg_data is None or arg_data.parameter_scan_data is None:
            raise ValueError(
                "arg_data with parameter_scan_data is required to run a parameter scan.")

        # Use the most recent SBML file path stored in the state, if any
        sbml_paths = state['sbml_file_path']
        sbml_file_path = sbml_paths[-1] if sbml_paths else None
        model_object = load_biomodel(sys_bio_model,
                                     sbml_file_path=sbml_file_path)

        # Prepare the dictionary of species data
        # that will be applied to the model before the scan
        dic_species_data = {}
        initial_species = arg_data.species_to_be_analyzed_before_experiment
        if initial_species is not None:
            dic_species_data = dict(
                zip(initial_species.species_name,
                    initial_species.species_concentration))

        # Set the duration and interval (defaults used without time data)
        duration = 100.0
        interval = 10
        if arg_data.time_data is not None:
            duration = arg_data.time_data.duration
            interval = arg_data.time_data.interval

        # Run the parameter scan
        dic_param_scan = run_parameter_scan(model_object,
                                            arg_data,
                                            dic_species_data,
                                            duration,
                                            interval)

        logger.info("Parameter scan results ready")
        # Prepare the list dictionary of scanned data
        list_dic_scanned_data = make_list_dic_scanned_data(dic_param_scan,
                                                           arg_data,
                                                           sys_bio_model,
                                                           tool_call_id)
        # Prepare the dictionary of updated state for the model,
        # keeping only the truthy values.
        # NOTE(review): the one-element lists are always truthy (even when
        # they hold None), so in practice only an empty dic_scanned_data is
        # dropped here -- confirm this is intended before tightening it.
        candidate_updates = {
            "model_id": [sys_bio_model.biomodel_id],
            "sbml_file_path": [sbml_file_path],
            "dic_scanned_data": list_dic_scanned_data,
        }
        dic_updated_state_for_model = {
            key: value for key, value in candidate_updates.items() if value
        }
        # Return the updated state
        return Command(
            update=dic_updated_state_for_model | {
                # update the message history
                "messages": [
                    ToolMessage(
                        content=f"Parameter scan results of {arg_data.experiment_name}",
                        tool_call_id=tool_call_id,
                        artifact=get_model_units(model_object)
                    )
                ],
            }
        )