|
a |
|
b/aiagents4pharma/talk2biomodels/tools/steady_state.py |
|
|
1 |
#!/usr/bin/env python3 |
|
|
2 |
|
|
|
3 |
""" |
|
|
4 |
Tool for steady state analysis. |
|
|
5 |
""" |
|
|
6 |
|
|
|
7 |
import logging |
|
|
8 |
from typing import Type, Annotated |
|
|
9 |
import basico |
|
|
10 |
from pydantic import BaseModel, Field |
|
|
11 |
from langgraph.types import Command |
|
|
12 |
from langgraph.prebuilt import InjectedState |
|
|
13 |
from langchain_core.tools import BaseTool |
|
|
14 |
from langchain_core.messages import ToolMessage |
|
|
15 |
from langchain_core.tools.base import InjectedToolCallId |
|
|
16 |
from .load_biomodel import ModelData, load_biomodel |
|
|
17 |
from .load_arguments import ArgumentData, add_rec_events |
|
|
18 |
|
|
|
19 |
# Module logging setup: configure the root logger to emit INFO and above,
# then create a logger named after this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(name=__name__)
|
|
22 |
|
|
|
23 |
def run_steady_state(model_object,
                     dic_species_to_be_analyzed_before_experiment):
    """
    Run the steady state analysis.

    Applies the user-supplied initial conditions to the model, runs the
    basico steady-state task on the underlying COPASI model, and returns
    the per-species results with self-describing column names.

    Args:
        model_object: The model object. Must provide ``update_parameters``
            and a ``copasi_model`` attribute (both are used below).
        dic_species_to_be_analyzed_before_experiment: Dictionary of species
            data passed to ``model_object.update_parameters`` before the
            analysis (presumably species name -> concentration; see caller).

    Returns:
        DataFrame: The results of the steady state analysis, one row per
        species as returned by ``basico.model_info.get_species``, with the
        columns ``name``, ``concentration`` and ``transition_time`` renamed
        to ``species_name``, ``steady_state_concentration`` and
        ``steady_state_transition_time``, and bookkeeping columns removed.

    Raises:
        ValueError: If the steady state task returns 0 (no steady state
            was found).
    """
    # Update the fixed model species and parameters.
    # These are the initial conditions of the model set by the user.
    model_object.update_parameters(dic_species_to_be_analyzed_before_experiment)
    logger.info("Running steady state analysis")
    # Run the steady state analysis; a result of 0 is treated as failure.
    output = basico.task_steadystate.run_steadystate(model=model_object.copasi_model)
    if output == 0:
        logger.error("Steady state analysis failed")
        raise ValueError("A steady state was not found")
    logger.info("Steady state analysis successful")
    # Store the steady state results in a DataFrame
    df_steady_state = basico.model_info.get_species(
        model=model_object.copasi_model).reset_index()
    # Rename the result columns in a single pass (no name collides with
    # another mapping's source, so this equals sequential renames).
    df_steady_state = df_steady_state.rename(columns={
        'name': 'species_name',
        'concentration': 'steady_state_concentration',
        'transition_time': 'steady_state_transition_time',
    })
    # Drop columns not needed by the caller. errors='ignore' keeps this
    # working if a basico version does not emit one of them.
    df_steady_state = df_steady_state.drop(
        columns=[
            'initial_particle_number',
            'initial_expression',
            'expression',
            'particle_number',
            'type',
            'particle_number_rate',
            'key',
            'sbml_id',
            'display_name',
        ],
        errors='ignore')
    logger.info("Steady state results with shape %s", df_steady_state.shape)
    return df_steady_state
|
|
73 |
|
|
|
74 |
class SteadyStateInput(BaseModel):
    """
    Arguments accepted by the steady state tool.

    ``sys_bio_model`` and ``arg_data`` are optional and default to None.
    ``tool_call_id`` and ``state`` carry injection annotations
    (``InjectedToolCallId`` / ``InjectedState``), so they are supplied by
    the system rather than by the caller.
    """
    sys_bio_model: ModelData = Field(default=None,
                                     description="model data")
    arg_data: ArgumentData = Field(
        default=None,
        description=("time, species, and reocurring data that must be set"
                     " before the steady state analysis as well as the"
                     " experiment name"),
    )
    tool_call_id: Annotated[str, InjectedToolCallId]
    state: Annotated[dict, InjectedState]
|
|
86 |
|
|
|
87 |
# Note: It's important that every field has type hints. BaseTool is a |
|
|
88 |
# Pydantic class and not having type hints can lead to unexpected behavior. |
|
|
89 |
class SteadyStateTool(BaseTool):
    """
    Tool to bring a model to steady state.
    """
    name: str = "steady_state"
    description: str = "A tool to bring a model to steady state."
    args_schema: Type[BaseModel] = SteadyStateInput

    def _run(self,
             tool_call_id: Annotated[str, InjectedToolCallId],
             state: Annotated[dict, InjectedState],
             sys_bio_model: ModelData = None,
             arg_data: ArgumentData = None
             ) -> Command:
        """
        Run the tool.

        Loads the model via ``load_biomodel``, passing the model data and the
        most recent SBML file path found in the state. It then applies any
        initial species values and recurring events from ``arg_data``, runs
        the steady state analysis, and returns a Command that records the
        results in the graph state.

        Args:
            tool_call_id (str): The tool call ID. This is injected by the system.
            state (dict): The state of the tool.
            sys_bio_model (ModelData): The model data. Optional.
            arg_data (ArgumentData): The argument data. Optional; when absent
                no initial conditions or events are applied and the
                experiment name is None.

        Returns:
            Command: The updated state of the tool.

        Raises:
            ValueError: Propagated from ``run_steady_state`` when no steady
                state is found.
        """
        logger.info("Calling the steady_state tool %s, %s",
                    sys_bio_model, arg_data)
        # Use the most recent SBML file path, if any. ``.get`` guards against
        # a state that has no 'sbml_file_path' entry yet.
        sbml_file_paths = state.get('sbml_file_path') or []
        sbml_file_path = sbml_file_paths[-1] if sbml_file_paths else None
        model_object = load_biomodel(sys_bio_model,
                                     sbml_file_path=sbml_file_path)
        # Both optional arguments default to None, so read their fields
        # defensively instead of dereferencing unconditionally.
        model_id = (sys_bio_model.biomodel_id
                    if sys_bio_model is not None else None)
        experiment_name = None
        # Dictionary of species data applied to the model before the
        # steady state analysis is run.
        dic_species_to_be_analyzed_before_experiment = {}
        if arg_data:
            experiment_name = arg_data.experiment_name
            species_data = arg_data.species_to_be_analyzed_before_experiment
            if species_data is not None:
                dic_species_to_be_analyzed_before_experiment = dict(
                    zip(species_data.species_name,
                        species_data.species_concentration))
            # Add recurring events (if any) to the model
            if arg_data.reocurring_data is not None:
                add_rec_events(model_object, arg_data.reocurring_data)
        # Run the steady state analysis
        df_steady_state = run_steady_state(model_object,
                                           dic_species_to_be_analyzed_before_experiment)
        # Prepare the dictionary of steady state data
        # that will be passed to the state of the graph
        dic_steady_state_data = {
            'name': experiment_name,
            'source': model_id if model_id else 'upload',
            'tool_call_id': tool_call_id,
            'data': df_steady_state.to_dict(orient='records')
        }
        # Every key is always written, each as a one-element list, even when
        # the wrapped value is None (e.g. model_id for an uploaded model).
        dic_updated_state_for_model = {
            "model_id": [model_id],
            "sbml_file_path": [sbml_file_path],
            "dic_steady_state_data": [dic_steady_state_data],
        }
        # Return the updated state
        return Command(
            update=dic_updated_state_for_model | {
                # Update the message history
                "messages": [
                    ToolMessage(
                        content=f"Steady state analysis of"
                                f" {experiment_name}"
                                " was successful.",
                        tool_call_id=tool_call_id,
                        artifact={'dic_data': df_steady_state.to_dict(orient='records')}
                    )
                ],
            }
        )