--- a/aiagents4pharma/talk2knowledgegraphs/utils/enrichments/ollama.py
+++ b/aiagents4pharma/talk2knowledgegraphs/utils/enrichments/ollama.py
@@ -0,0 +1,129 @@
+#!/usr/bin/env python3
+
+"""
+Enrichment class using an Ollama model, based on the LangChain Enrichments class.
+"""
+
+import time
+from typing import List
+import subprocess
+import ast
+import ollama
+from langchain_ollama import ChatOllama
+from langchain_core.prompts import ChatPromptTemplate
+from langchain_core.output_parsers import StrOutputParser
+from .enrichments import Enrichments
+
+class EnrichmentWithOllama(Enrichments):
+    """
+    Enrichment class using an Ollama model, based on the Enrichments abstract class.
+    """
+    def __init__(
+        self,
+        model_name: str,
+        prompt_enrichment: str,
+        temperature: float,
+        streaming: bool,
+    ):
+        """
+        Initialize the EnrichmentWithOllama class.
+
+        Args:
+            model_name: The name of the Ollama model to be used.
+            prompt_enrichment: The prompt enrichment template.
+            temperature: The temperature for the Ollama model.
+            streaming: The streaming flag for the Ollama model.
+        """
+        # Set up the Ollama server and make sure the model is available
+        self.__setup(model_name)
+
+        # Set parameters
+        self.model_name = model_name
+        self.prompt_enrichment = prompt_enrichment
+        self.temperature = temperature
+        self.streaming = streaming
+
+        # Prepare the prompt template
+        self.prompt_template = ChatPromptTemplate.from_messages(
+            [
+                ("system", self.prompt_enrichment),
+                ("human", "{input}"),
+            ]
+        )
+
+        # Prepare the model
+        self.model = ChatOllama(
+            model=self.model_name,
+            temperature=self.temperature,
+            streaming=self.streaming,
+        )
+
+    def __setup(self, model_name: str) -> None:
+        """
+        Check if the Ollama model is available and run the Ollama server if needed.
+
+        Args:
+            model_name: The name of the Ollama model to be used.
+        """
+        try:
+            # Query the models that are already available locally
+            models_list = ollama.list()["models"]
+        except Exception:
+            # The Ollama server is likely not running; start it in the
+            # background and give it a moment before querying again
+            subprocess.Popen(
+                "ollama serve",
+                shell=True,
+                stdout=subprocess.DEVNULL,
+                stderr=subprocess.DEVNULL,
+            )
+            time.sleep(10)
+            models_list = ollama.list()["models"]
+
+        # Pull the model if it is not yet available locally
+        if model_name not in [m["model"].replace(":latest", "") for m in models_list]:
+            ollama.pull(model_name)
+            time.sleep(30)
+
+    def enrich_documents(self, texts: List[str]) -> List[str]:
+        """
+        Enrich a list of input texts with additional textual features using the Ollama model.
+        Important: the prompt template must use 'input' as the variable name and must
+        instruct the model to reply with a list of the same length as the input.
+
+        Args:
+            texts: The list of texts to be enriched.
+
+        Returns:
+            The list of enriched texts.
+        """
+        # Build the enrichment chain
+        chain = self.prompt_template | self.model | StrOutputParser()
+
+        # Generate the enriched texts; the input list is rendered into the
+        # prompt template through the 'input' variable
+        enriched_texts = chain.invoke({"input": "[" + ", ".join(texts) + "]"})
+
+        # Parse the reply, a Python-style list literal, into a list,
+        # stripping any code fences the model may have added
+        enriched_texts = ast.literal_eval(enriched_texts.replace("```", ""))
+
+        # The model must return exactly one enriched text per input
+        assert len(enriched_texts) == len(texts)
+
+        return enriched_texts
+
+    def enrich_documents_with_rag(self, texts: List[str], docs: List[str]) -> List[str]:
+        """
+        Enrich a list of input texts with additional textual features using the Ollama
+        model with RAG. As there is no RAG model to test against yet, this method
+        simply falls back to enrich_documents.
+
+        Args:
+            texts: The list of texts to be enriched.
+            docs: The list of reference documents to enrich the input texts.
+
+        Returns:
+            The list of enriched texts.
+        """
+        return self.enrich_documents(texts)
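
Usage sketch for reviewers (not part of the patch). The model name and prompt below are assumptions for illustration, not values shipped in this diff. The one hard contract, visible in enrich_documents, is that the prompt keeps 'input' as its template variable and instructs the model to reply with a Python-style list containing exactly one enriched string per input text, since the reply is parsed with ast.literal_eval and checked against len(texts).

    from aiagents4pharma.talk2knowledgegraphs.utils.enrichments.ollama import (
        EnrichmentWithOllama,
    )

    # Hypothetical prompt: any template works as long as it asks for a
    # Python-style list with one entry per input text, in the same order.
    PROMPT = (
        "Given a list of terms as {input}, return a Python list of short "
        "textual descriptions, one per term, in the same order."
    )

    enricher = EnrichmentWithOllama(
        model_name="llama3.2",  # assumed; use any locally available Ollama model
        prompt_enrichment=PROMPT,
        temperature=0.0,
        streaming=False,
    )

    # Returns a list of two description strings, same order as the input
    enriched = enricher.enrich_documents(["TP53", "BRCA1"])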