a b/aiagents4pharma/talk2biomodels/tools/get_modelinfo.py
1
#!/usr/bin/env python3
2
3
"""
4
Tool for getting model information.
5
"""
6
7
import logging
8
from typing import Type, Optional, Annotated
9
from dataclasses import dataclass
10
import basico
11
from pydantic import BaseModel, Field
12
from langchain_core.tools import BaseTool
13
from langchain_core.messages import ToolMessage
14
from langchain_core.tools.base import InjectedToolCallId
15
from langgraph.prebuilt import InjectedState
16
from langgraph.types import Command
17
from .load_biomodel import ModelData, load_biomodel
18
19
# Module logger. basicConfig below sets the root logger up at INFO level
# (per the logging docs it does nothing if the root logger already has
# handlers); the two statements are independent of each other.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
22
23
@dataclass
class RequestedModelInfo:
    """
    Dataclass for storing the requested model information.

    Each boolean flag selects one item for the ``get_modelinfo`` tool to
    extract from the model. Every flag defaults to ``False``.
    """
    # The pydantic ``Field`` (which carries the description used in the tool's
    # input schema) is attached through ``Annotated`` metadata, and the plain
    # ``False`` is the actual default. Assigning ``Field(...)`` as the default
    # of a stdlib dataclass would make the ``FieldInfo`` object itself the
    # default whenever the class is instantiated directly (outside pydantic
    # validation); that object is truthy, so every flag would read as set.
    species: Annotated[bool, Field(description="Get species from the model.")] = False
    parameters: Annotated[bool, Field(description="Get parameters from the model.")] = False
    compartments: Annotated[bool, Field(description="Get compartments from the model.")] = False
    units: Annotated[bool, Field(description="Get units from the model.")] = False
    description: Annotated[bool, Field(description="Get description from the model.")] = False
    name: Annotated[bool, Field(description="Get name from the model.")] = False
34
35
class GetModelInfoInput(BaseModel):
    """
    Input schema for the GetModelInfo tool.

    ``requested_model_info`` selects which items to extract and
    ``sys_bio_model`` identifies the model; both are required.
    ``tool_call_id`` and ``state`` are annotated as injected values
    (``InjectedToolCallId`` / ``InjectedState``).
    """
    requested_model_info: Annotated[RequestedModelInfo,
                                    Field(description="requested model information")]
    sys_bio_model: Annotated[ModelData, Field(description="model data")]
    tool_call_id: Annotated[str, InjectedToolCallId]
    state: Annotated[dict, InjectedState]
43
44
# Note: It's important that every field has type hints. BaseTool is a
45
# Pydantic class and not having type hints can lead to unexpected behavior.
46
class GetModelInfoTool(BaseTool):
    """
    This tool is used to extract model information.
    """
    name: str = "get_modelinfo"
    description: str = """A tool for extracting name,
                    description, species, parameters,
                    compartments, and units from a model."""
    args_schema: Type[BaseModel] = GetModelInfoInput

    @staticmethod
    def _table_records(df, label: str, columns: list) -> list:
        """
        Convert a ``basico.model_info`` DataFrame into a list of row dicts.

        Args:
            df: DataFrame returned by a ``basico.model_info`` getter, or None.
            label (str): Entity name used in the error message
                (e.g. ``"species"``, ``"parameters"``).
            columns (list): Columns to keep, after the index has been
                moved into a regular column.

        Returns:
            list: One dictionary per row (``orient='records'``).

        Raises:
            ValueError: If ``df`` is None.
        """
        if df is None:
            raise ValueError(f"Unable to extract {label} from the model.")
        # Move the index into a column (presumably it holds the entity
        # names, since ``name`` is among the selected columns).
        df = df.reset_index()
        return df[columns].to_dict(orient='records')

    def _run(self,
            requested_model_info: RequestedModelInfo,
            tool_call_id: Annotated[str, InjectedToolCallId],
            state: Annotated[dict, InjectedState],
            sys_bio_model: Optional[ModelData] = None,
             ) -> Command:
        """
        Run the tool.

        Args:
            requested_model_info (RequestedModelInfo): Flags selecting which
                items (species, parameters, compartments, units,
                description, name) to extract.
            tool_call_id (str): The tool call ID. This is injected by the system.
            state (dict): The agent state; ``state['sbml_file_path']`` is read
                and its last entry (if any) is used to load the model.
            sys_bio_model (ModelData): The model data.

        Returns:
            Command: A state update carrying ``model_id`` and
            ``sbml_file_path`` entries plus a ``ToolMessage`` whose content is
            the dictionary of extracted information.

        Raises:
            ValueError: If species, parameters or compartments are requested
                but cannot be extracted from the model.
        """
        logger.info("Calling get_modelinfo tool %s, %s",
                    sys_bio_model,
                    requested_model_info)
        # Use the most recently recorded SBML file path, if there is one.
        sbml_file_path = state['sbml_file_path'][-1] if state['sbml_file_path'] else None
        model_obj = load_biomodel(sys_bio_model,
                                  sbml_file_path=sbml_file_path)
        copasi_model = model_obj.copasi_model
        dic_results = {}

        # Extract species from the model
        if requested_model_info.species:
            dic_results['Species'] = self._table_records(
                basico.model_info.get_species(model=copasi_model),
                "species",
                ['name',
                 'compartment',
                 'type',
                 'unit',
                 'initial_concentration',
                 'display_name'])

        # Extract parameters from the model
        if requested_model_info.parameters:
            dic_results['Parameters'] = self._table_records(
                basico.model_info.get_parameters(model=copasi_model),
                "parameters",
                ['name',
                 'type',
                 'unit',
                 'initial_value',
                 'display_name'])

        # Extract compartments from the model as a comma-separated string
        # of compartment names (the DataFrame index).
        if requested_model_info.compartments:
            df_compartments = basico.model_info.get_compartments(model=copasi_model)
            # Same guard as species/parameters: without it a None result
            # would surface as an AttributeError on ``.index``.
            if df_compartments is None:
                raise ValueError("Unable to extract compartments from the model.")
            dic_results['Compartments'] = ','.join(df_compartments.index.tolist())

        # Extract description from the model
        if requested_model_info.description:
            dic_results['Description'] = model_obj.description

        # Extract name from the model
        if requested_model_info.name:
            dic_results['Name'] = model_obj.name

        # Extract the model units (as returned by basico's get_model_units)
        if requested_model_info.units:
            dic_results['Units'] = basico.model_info.get_model_units(model=copasi_model)

        # Prepare the dictionary of updated state for the model.
        # NOTE(review): every value is a one-element list, so ``if value`` is
        # always true and a None id/path is still recorded; confirm whether
        # the filter was meant to test the wrapped element instead.
        dic_updated_state_for_model = {}
        for key, value in {
                        "model_id": [sys_bio_model.biomodel_id],
                        "sbml_file_path": [sbml_file_path],
                        }.items():
            if value:
                dic_updated_state_for_model[key] = value

        return Command(
            update=dic_updated_state_for_model|{
                    # update the message history
                    "messages": [
                        ToolMessage(
                            content=dic_results,
                            tool_call_id=tool_call_id
                            )
                        ],
                    }
            )