
--- a
+++ b/aiagents4pharma/talk2biomodels/tools/get_annotation.py
@@ -0,0 +1,327 @@
+#!/usr/bin/env python3
+
+"""
+This module contains the `GetAnnotationTool` for fetching species annotations
+based on the provided model and species names.
+"""
+import math
+from typing import List, Annotated, Type, Union, Literal
+import logging
+from dataclasses import dataclass
+import hydra
+from pydantic import BaseModel, Field
+import basico
+import pandas as pd
+from langgraph.types import Command
+from langgraph.prebuilt import InjectedState
+from langchain_core.tools.base import BaseTool
+from langchain_core.tools.base import InjectedToolCallId
+from langchain_core.messages import ToolMessage
+from .load_biomodel import ModelData, load_biomodel
+from ..api.uniprot import search_uniprot_labels
+from ..api.ols import search_ols_labels
+from ..api.kegg import fetch_kegg_annotations
+
+# Initialize logger
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
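+# Ontology abbreviations whose identifiers are looked up via the OLS API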
+ols_ontology_abbreviations = {'pato', 'chebi', 'sbo', 'fma', 'pr', 'go'}
+
+def extract_relevant_species_names(model_object, arg_data, state):
+    """
+    Extract relevant species names based on the user question.
+    """
+    # Load hydra configuration
+    with hydra.initialize(version_base=None, config_path="../configs"):
+        cfg = hydra.compose(config_name='config',
+                            overrides=['tools/get_annotation=default'])
+        cfg = cfg.tools.get_annotation
+    logger.info("Loaded the following system prompt for the LLM"
+                " to get a structured output: %s", cfg.prompt)
+
+    # Extract all the species names from the model
+    df_species = basico.model_info.get_species(model=model_object.copasi_model)
+    if df_species is None:
+        raise ValueError("Unable to extract species from the model.")
+    # Get all the species names
+    all_species_names = df_species.index.tolist()
+
+    # Define a structured output for the LLM model
+    class CustomHeader(BaseModel):
+        """
+        A list of species based on user question.
+        """
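+        # Literal[*all_species_names] restricts the structured output to
+        # species names that exist in the loaded model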
+        relevant_species: Union[None, List[Literal[*all_species_names]]] = Field(
+                description="""List of species based on user question.
+                If no relevant species are found, it must be None.""")
+
+    # Create an instance of the LLM model
+    llm = state['llm_model']
+    # Get the structured output from the LLM model
+    llm_with_structured_output = llm.with_structured_output(CustomHeader)
+    # Define the question for the LLM model using the prompt
+    question = cfg.prompt
+    question += f' Here is the user question: {arg_data.user_question}'
+    # Invoke the LLM model with the user question
+    results = llm_with_structured_output.invoke(question)
+    logger.info("Results from the LLM model: %s", results)
+    # Check if the returned species names are empty
+    if not results.relevant_species:
+        raise ValueError("Model does not contain the requested species.")
+    extracted_species = []
+    # Keep only the returned species names that exist in the model
+    for species in results.relevant_species:
+        if species in all_species_names:
+            extracted_species.append(species)
+    logger.info("Extracted species: %s", extracted_species)
+    return extracted_species
+
+def prepare_content_msg(species_without_description: List[str]):
+    """
+    Prepare the content message.
+    """
+    content = 'Successfully extracted annotations for the species.'
+    if species_without_description:
+        content += (' The descriptions for the following species were not found: '
+                    f'{", ".join(species_without_description)}.')
+    return content
+
+@dataclass
+class ArgumentData:
+    """
+    Dataclass for storing the argument data.
+    """
+    experiment_name: Annotated[str, "An AI-assigned, underscore-separated name"
+                                    " of the experiment based on the human query"
+                                    " and the context of the experiment."
+                                    " This must be set before the experiment is run."]
+    user_question: Annotated[str, "Description of the user question"]
+
+class GetAnnotationInput(BaseModel):
+    """
+    Input schema for annotation tool.
+    """
+    arg_data: ArgumentData = Field(description="argument data")
+    sys_bio_model: ModelData = Field(description="model data")
+    tool_call_id: Annotated[str, InjectedToolCallId]
+    state: Annotated[dict, InjectedState]
+
+class GetAnnotationTool(BaseTool):
+    """
+    Tool for fetching species annotations based on the provided model and species names.
+    """
+    name: str = "get_annotation"
+    description: str = '''A tool to extract annotations for a list of species names
+                        based on the provided model. Annotations include
+                        the species name, description, database, ID, link,
+                        and qualifier. The tool can handle multiple species
+                        in a single invoke.'''
+    args_schema: Type[BaseModel] = GetAnnotationInput
+    return_direct: bool = False
+
+    def _run(self,
+             arg_data: ArgumentData,
+             tool_call_id: Annotated[str, InjectedToolCallId],
+             state: Annotated[dict, InjectedState],
+             sys_bio_model: ModelData = None) -> str:
+        """
+        Run the tool.
+        """
+        logger.info("Running GetAnnotationTool for question '%s' (experiment: %s)",
+                    arg_data.user_question,
+                    arg_data.experiment_name)
+
+        # Prepare the model object
+        sbml_file_path = state['sbml_file_path'][-1] if state['sbml_file_path'] else None
+        model_object = load_biomodel(sys_bio_model, sbml_file_path=sbml_file_path)
+
+        # Extract relevant species names based on the user question
+        list_species_names = extract_relevant_species_names(model_object, arg_data, state)
+        logger.debug("Relevant species names: %s", list_species_names)
+
+        (annotations_df,
+         species_without_description) = self._fetch_annotations(list_species_names)
+
+        # Process annotations
+        annotations_df = self._process_annotations(annotations_df)
+
+        # Prepare the annotations data to store in the state
+        dic_annotations_data = {
+            'name': arg_data.experiment_name,
+            'source': sys_bio_model.biomodel_id if sys_bio_model.biomodel_id else 'upload',
+            'tool_call_id': tool_call_id,
+            'data': annotations_df.to_dict()
+        }
+
+        # Update the state with the annotations data
+        dic_updated_state_for_model = {}
+        for key, value in {
+            "model_id": [sys_bio_model.biomodel_id],
+            "sbml_file_path": [sbml_file_path],
+            "dic_annotations_data": [dic_annotations_data]
+        }.items():
+            if value:
+                dic_updated_state_for_model[key] = value
+
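+        # Return a LangGraph Command that merges these updates into the agent state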
+        return Command(
+            update=dic_updated_state_for_model | {
+                "messages": [
+                    ToolMessage(
+                        content=prepare_content_msg(species_without_description),
+                        artifact=True,
+                        tool_call_id=tool_call_id
+                    )
+                ]
+            }
+        )
+
+    def _fetch_annotations(self, list_species_names: List[str]) -> tuple:
+        """
+        Fetch annotations for the given species names from the model.
+        In this method, we fetch the MIRIAM annotations for the species names.
+        If no descriptions are found for a species, it is added to the list of
+        species without descriptions; otherwise, the descriptions are extracted
+        from the annotation and added to the data list.
+
+        Args:
+            list_species_names (List[str]): List of species names to fetch annotations for.
+
+        Returns:
+            tuple: A tuple containing the annotations dataframe and the list of
+                   species for which no description was found.
+        """
+        description_not_found = []
+        data = []
+
+        # Loop through the species names
+        for species in list_species_names:
+            # Get the MIRIAM annotation for the species
+            annotation = basico.get_miriam_annotation(name=species)
+
+            # Extract the descriptions from the annotation
+            descriptions = annotation.get("descriptions", [])
+
+            if not descriptions:
+                description_not_found.append(species)
+                continue
+
+            # Loop through the descriptions and add them to the data list
+            for desc in descriptions:
+                data.append({
+                    "Species Name": species,
+                    "Link": desc["id"],
+                    "Qualifier": desc["qualifier"]
+                })
+
+        # Create a dataframe from the data list
+        annotations_df = pd.DataFrame(data)
+
+        # Return the annotations dataframe and the species not found list
+        return annotations_df, description_not_found
+
+    def _process_annotations(self, annotations_df: pd.DataFrame) -> pd.DataFrame:
+        """
+        Process annotations dataframe to add additional information.
+        In this method, we add a new column for the ID, a new column for the database,
+        and a new column for the description. We then reorder the columns and process
+        the link to format it correctly.
+
+        Args:
+            annotations_df (pd.DataFrame): Annotations dataframe to process.
+
+        Returns:
+            pd.DataFrame: Processed annotations dataframe
+        """
+        logger.info("Processing annotations.")
+        # Add a new column for the ID
+        # Get the ID from the link key
+        annotations_df['Id'] = annotations_df['Link'].str.split('/').str[-1]
+
+        # Add a new column for the database
+        # Get the database from the link key
+        annotations_df['Database'] = annotations_df['Link'].str.split('/').str[-2]
+
+        # Fetch descriptions for the IDs based on the database type
+        # by querying the respective APIs
+        identifiers = annotations_df[['Id', 'Database']].to_dict(orient='records')
+        descriptions = self._fetch_descriptions(identifiers)
+
+        # Add a new column for the description
+        # Get the description from the descriptions dictionary
+        # based on the ID. If the description is not found, use '-'
+        annotations_df['Description'] = annotations_df['Id'].apply(lambda x:
+                                                                   descriptions.get(x, '-'))
+
+        # Reorder the columns
+        annotations_df = annotations_df[
+            ["Species Name", "Description", "Database", "Id", "Link", "Qualifier"]
+        ]
+
+        # Process the link to format it correctly
+        annotations_df["Link"] = annotations_df["Link"].apply(self._process_link)
+
+        # Return the processed annotations dataframe
+        return annotations_df
+
+    def _process_link(self, link: str) -> str:
+        """
+        Process link to format it correctly.
+        """
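+        # Drop the ontology path segment (e.g. "chebi/") from OLS-style links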
+        for ols_ontology_abbreviation in ols_ontology_abbreviations:
+            if ols_ontology_abbreviation +'/' in link:
+                link = link.replace(f"{ols_ontology_abbreviation}/", "")
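+        # Rewrite KEGG compound links from path form to the colon-separated form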
+        if "kegg.compound" in link:
+            link = link.replace("kegg.compound/", "kegg.compound:")
+        return link
+
+    def _fetch_descriptions(self, data: List[dict[str, str]]) -> dict[str, str]:
+        """
+        Fetch protein names or labels based on the database type.
+        """
+        logger.info("Fetching descriptions for the IDs.")
+        results = {}
+        grouped_data = {}
+
+        # In the following loop, we group the identifiers by database.
+        # If the database is missing (it shows up as NaN, a float, in the
+        # dataframe), we record a '-' description for that identifier.
+        for entry in data:
+            identifier = entry.get('Id')
+            database = entry.get('Database')
+            # Check if the database is NaN
+            if isinstance(database, float):
+                if math.isnan(database):
+                    results[identifier or "unknown"] = "-"
+            else:
+                database = database.lower()
+                grouped_data.setdefault(database, []).append(identifier)
+
+        # In the following loop, we fetch the descriptions for the identifiers
+        # based on the database type.
+        for database, identifiers in grouped_data.items():
+            if database == 'uniprot':
+                results.update(search_uniprot_labels(identifiers))
+            elif database in ols_ontology_abbreviations:
+                annotations = search_ols_labels([
+                    {"Id": id_, "Database": database}
+                    for id_ in identifiers
+                ])
+                for identifier in identifiers:
+                    results[identifier] = annotations.get(database, {}).get(identifier, "-")
+            elif database == 'kegg.compound':
+                kegg_data = [{"Id": identifier, "Database": "kegg.compound"}
+                             for identifier in identifiers]
+                annotations = fetch_kegg_annotations(kegg_data)
+                for identifier in identifiers:
+                    results[identifier] = annotations.get(database, {}).get(identifier, "-")
+            else:
+                for identifier in identifiers:
+                    results[identifier] = "-"
+        return results