--- a +++ b/aiagents4pharma/talk2biomodels/tools/get_modelinfo.py @@ -0,0 +1,153 @@ +#!/usr/bin/env python3 + +""" +Tool for get model information. +""" + +import logging +from typing import Type, Optional, Annotated +from dataclasses import dataclass +import basico +from pydantic import BaseModel, Field +from langchain_core.tools import BaseTool +from langchain_core.messages import ToolMessage +from langchain_core.tools.base import InjectedToolCallId +from langgraph.prebuilt import InjectedState +from langgraph.types import Command +from .load_biomodel import ModelData, load_biomodel + +# Initialize logger +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +@dataclass +class RequestedModelInfo: + """ + Dataclass for storing the requested model information. + """ + species: bool = Field(description="Get species from the model.", default=False) + parameters: bool = Field(description="Get parameters from the model.", default=False) + compartments: bool = Field(description="Get compartments from the model.", default=False) + units: bool = Field(description="Get units from the model.", default=False) + description: bool = Field(description="Get description from the model.", default=False) + name: bool = Field(description="Get name from the model.", default=False) + +class GetModelInfoInput(BaseModel): + """ + Input schema for the GetModelInfo tool. + """ + requested_model_info: RequestedModelInfo = Field(description="requested model information") + sys_bio_model: ModelData = Field(description="model data") + tool_call_id: Annotated[str, InjectedToolCallId] + state: Annotated[dict, InjectedState] + +# Note: It's important that every field has type hints. BaseTool is a +# Pydantic class and not having type hints can lead to unexpected behavior. +class GetModelInfoTool(BaseTool): + """ + This tool ise used extract model information. + """ + name: str = "get_modelinfo" + description: str = """A tool for extracting name, + description, species, parameters, + compartments, and units from a model.""" + args_schema: Type[BaseModel] = GetModelInfoInput + + def _run(self, + requested_model_info: RequestedModelInfo, + tool_call_id: Annotated[str, InjectedToolCallId], + state: Annotated[dict, InjectedState], + sys_bio_model: Optional[ModelData] = None, + ) -> Command: + """ + Run the tool. + + Args: + requested_model_info (RequestedModelInfo): The requested model information. + tool_call_id (str): The tool call ID. This is injected by the system. + state (dict): The state of the tool. + sys_bio_model (ModelData): The model data. + + Returns: + Command: The updated state of the tool. + """ + logger.log(logging.INFO, + "Calling get_modelinfo tool %s, %s", + sys_bio_model, + requested_model_info) + # print (state, 'state') + sbml_file_path = state['sbml_file_path'][-1] if len(state['sbml_file_path']) > 0 else None + model_obj = load_biomodel(sys_bio_model, + sbml_file_path=sbml_file_path) + dic_results = {} + # Extract species from the model + if requested_model_info.species: + df_species = basico.model_info.get_species(model=model_obj.copasi_model) + if df_species is None: + raise ValueError("Unable to extract species from the model.") + # Convert index into a column + df_species.reset_index(inplace=True) + dic_results['Species'] = df_species[ + ['name', + 'compartment', + 'type', + 'unit', + 'initial_concentration', + 'display_name']] + # Convert this into a dictionary + dic_results['Species'] = dic_results['Species'].to_dict(orient='records') + + # Extract parameters from the model + if requested_model_info.parameters: + df_parameters = basico.model_info.get_parameters(model=model_obj.copasi_model) + if df_parameters is None: + raise ValueError("Unable to extract parameters from the model.") + # Convert index into a column + df_parameters.reset_index(inplace=True) + dic_results['Parameters'] = df_parameters[ + ['name', + 'type', + 'unit', + 'initial_value', + 'display_name']] + # Convert this into a dictionary + dic_results['Parameters'] = dic_results['Parameters'].to_dict(orient='records') + + # Extract compartments from the model + if requested_model_info.compartments: + df_compartments = basico.model_info.get_compartments(model=model_obj.copasi_model) + dic_results['Compartments'] = df_compartments.index.tolist() + dic_results['Compartments'] = ','.join(dic_results['Compartments']) + + # Extract description from the model + if requested_model_info.description: + dic_results['Description'] = model_obj.description + + # Extract description from the model + if requested_model_info.name: + dic_results['Name'] = model_obj.name + + # Extract time unit from the model + if requested_model_info.units: + dic_results['Units'] = basico.model_info.get_model_units(model=model_obj.copasi_model) + + # Prepare the dictionary of updated state for the model + dic_updated_state_for_model = {} + for key, value in { + "model_id": [sys_bio_model.biomodel_id], + "sbml_file_path": [sbml_file_path], + }.items(): + if value: + dic_updated_state_for_model[key] = value + + return Command( + update=dic_updated_state_for_model|{ + # update the message history + "messages": [ + ToolMessage( + content=dic_results, + tool_call_id=tool_call_id + ) + ], + } + )