#!/usr/bin/env python3
"""
Enrichment class using an Ollama model, based on the LangChain Enrichments class.
"""
import ast
import subprocess
import time
from typing import List

import ollama
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_ollama import ChatOllama

from .enrichments import Enrichments


class EnrichmentWithOllama(Enrichments):
    """
    Enrichment class using an Ollama model, based on the Enrichments abstract class.
    """
    def __init__(
        self,
        model_name: str,
        prompt_enrichment: str,
        temperature: float,
        streaming: bool,
    ):
        """
        Initialize the EnrichmentWithOllama class.

        Args:
            model_name: The name of the Ollama model to be used.
            prompt_enrichment: The prompt enrichment template.
            temperature: The temperature for the Ollama model.
            streaming: The streaming flag for the Ollama model.
        """
        # Set up the Ollama server and make sure the model is available
        self.__setup(model_name)
        # Set parameters
        self.model_name = model_name
        self.prompt_enrichment = prompt_enrichment
        self.temperature = temperature
        self.streaming = streaming
        # Prepare the prompt template; the human turn exposes a single
        # 'input' variable that receives the texts to enrich
        self.prompt_template = ChatPromptTemplate.from_messages(
            [
                ("system", self.prompt_enrichment),
                ("human", "{input}"),
            ]
        )
        # Prepare the model
        self.model = ChatOllama(
            model=self.model_name,
            temperature=self.temperature,
            streaming=self.streaming,
        )
    def __setup(self, model_name: str) -> None:
        """
        Check if the Ollama model is available and run the Ollama server if needed.

        Args:
            model_name: The name of the Ollama model to be used.
        """
        try:
            models_list = ollama.list()["models"]
            # Pull the model if it is not yet available locally
            if model_name not in [m["model"].replace(":latest", "") for m in models_list]:
                ollama.pull(model_name)
                time.sleep(30)
                # Raising here drops into the except branch below, which
                # restarts the server so the freshly pulled model is served
                raise ValueError(f"Pulled {model_name} model")
        except Exception as e:
            # ollama.list() fails when the server is not running; (re)start it
            # in the background, then surface the original cause to the caller
            with subprocess.Popen(
                "ollama serve", shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE
            ):
                time.sleep(10)
            raise ValueError(f"Error: {e} and restarted Ollama server.") from e
    def enrich_documents(self, texts: List[str]) -> List[str]:
        """
        Enrich a list of input texts with additional textual features using the
        Ollama model.

        Important: make sure the input is a list of texts matching the defined
        prompt template, with 'input' as the variable name.

        Args:
            texts: The list of texts to be enriched.

        Returns:
            The list of enriched texts.
        """
        # Build the enrichment chain: prompt -> model -> plain-string output
        chain = self.prompt_template | self.model | StrOutputParser()
        # Generate the enriched texts; the inputs are joined into a single
        # bracketed string that fills the 'input' variable of the prompt
        enriched_texts = chain.invoke({"input": "[" + ", ".join(texts) + "]"})
        # Parse the model output (a Python-style list literal) into a list,
        # stripping any markdown code fences first
        enriched_texts = ast.literal_eval(enriched_texts.replace("```", ""))
        # Final check: one enriched text per input text
        assert len(enriched_texts) == len(texts)
        return enriched_texts
    def enrich_documents_with_rag(self, texts: List[str], docs: List[str]) -> List[str]:
        """
        Enrich a list of input texts with additional textual features using the
        Ollama model with RAG. A RAG pipeline is not available yet, so this
        method simply falls back to enrich_documents.

        Args:
            texts: The list of texts to be enriched.
            docs: The list of reference documents to enrich the input texts.

        Returns:
            The list of enriched texts.
        """
        return self.enrich_documents(texts)
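

# --- Usage sketch (illustrative only) ----------------------------------------
# A minimal sketch of driving EnrichmentWithOllama end to end, assuming a local
# Ollama installation. The model name and the prompt below are assumptions for
# illustration, not part of this module; any Ollama chat model and any system
# prompt that makes the model return a Python-style list literal will do.
# Because of the relative import above, run this as a module, e.g.
# `python -m <package>.<this_module>`, not as a standalone script.
if __name__ == "__main__":
    # Hypothetical system prompt: instructs the model to return a Python list
    # of the same length as the input, as enrich_documents expects.
    SAMPLE_PROMPT = (
        "You will receive a list of terms. Return only a Python list of strings "
        "of the same length, where each entry is the original term followed by "
        "a one-sentence description."
    )
    enricher = EnrichmentWithOllama(
        model_name="llama3.1",  # assumed model name; use any model you have pulled
        prompt_enrichment=SAMPLE_PROMPT,
        temperature=0.0,
        streaming=False,
    )
    enriched = enricher.enrich_documents(["aspirin", "ibuprofen"])
    print(enriched)  # expected: a list of two enriched strings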