# Reconstructed from a whitespace-collapsed diff of
# aiagents4pharma/talk2knowledgegraphs/utils/embeddings/ollama.py
"""
Embedding class using Ollama model based on LangChain Embeddings class.
"""

import subprocess
import time
from typing import List

import ollama
from langchain_ollama import OllamaEmbeddings

from .embeddings import Embeddings


class EmbeddingWithOllama(Embeddings):
    """
    Embedding class using Ollama model based on LangChain Embeddings class.
    """

    def __init__(self, model_name: str):
        """
        Initialize the EmbeddingWithOllama class.

        Args:
            model_name: The name of the Ollama model to be used.

        Raises:
            ValueError: Propagated from the setup step when the model had to
                be pulled or when talking to Ollama failed.
        """
        # Setup the Ollama server
        self.__setup(model_name)

        # Set parameters
        self.model_name = model_name

        # Prepare model
        self.model = OllamaEmbeddings(model=self.model_name)

    def __setup(self, model_name: str) -> None:
        """
        Check if the Ollama model is available and run the Ollama server if needed.

        Args:
            model_name: The name of the Ollama model to be used. An explicit
                ":latest" tag is ignored when comparing against the models
                reported by Ollama.

        Raises:
            ValueError: When the model was missing and has just been pulled,
                or when any call to Ollama failed. In both cases
                ``ollama serve`` is spawned before the error is raised, so
                the caller is expected to retry the construction.
        """
        try:
            models_list = ollama.list()["models"]
            # Installed names are compared without the ":latest" tag, so the
            # requested name must be normalized the same way; otherwise
            # "<model>:latest" would never match and be re-pulled every time.
            installed = {m['model'].replace(":latest", "") for m in models_list}
            if model_name.replace(":latest", "") not in installed:
                ollama.pull(model_name)
                # Fixed wait after the pull (presumably to let the model
                # finish registering — TODO confirm the required duration).
                time.sleep(30)
                # Deliberately routed through the handler below, which also
                # (re)starts the server before reporting to the caller.
                raise ValueError(f"Pulled {model_name} model")
        except Exception as e:
            # NOTE(review): leaving the `with` block closes the pipes and
            # waits for the spawned process to exit; confirm this is the
            # intended lifetime for a long-running `ollama serve`.
            with subprocess.Popen(
                "ollama serve", shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE
            ):
                time.sleep(10)
            raise ValueError(f"Error: {e} and restarted Ollama server.") from e

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """
        Generate embedding for a list of input texts using Ollama model.

        Args:
            texts: The list of texts to be embedded.

        Returns:
            The list of embeddings for the given texts, one per input text.
        """

        # Generate the embedding
        embeddings = self.model.embed_documents(texts)

        return embeddings

    def embed_query(self, text: str) -> List[float]:
        """
        Generate embeddings for an input text using Ollama model.

        Args:
            text: A query to be embedded.

        Returns:
            The embeddings for the given query.
        """

        # Generate the embedding
        embeddings = self.model.embed_query(text)

        return embeddings