|
a |
|
b/aitrika/llm/ollama.py |
|
|
1 |
from llama_index.llms.ollama import Ollama |
|
|
2 |
from llama_index.embeddings.huggingface import HuggingFaceEmbedding |
|
|
3 |
from llama_index.core.node_parser import SimpleNodeParser |
|
|
4 |
from llama_index.core import ( |
|
|
5 |
VectorStoreIndex, |
|
|
6 |
Settings, |
|
|
7 |
StorageContext, |
|
|
8 |
load_index_from_storage, |
|
|
9 |
Document, |
|
|
10 |
) |
|
|
11 |
from llama_index.vector_stores.lancedb import LanceDBVectorStore |
|
|
12 |
import os |
|
|
13 |
from aitrika.llm.base_llm import BaseLLM |
|
|
14 |
from aitrika.config.config import LLMConfig |
|
|
15 |
|
|
|
16 |
|
|
|
17 |
class OllamaLLM(BaseLLM): |
|
|
18 |
config = LLMConfig() |
|
|
19 |
|
|
|
20 |
def __init__(self, documents: Document): |
|
|
21 |
self.documents = documents |
|
|
22 |
|
|
|
23 |
def _build_index(self): |
|
|
24 |
llm = Ollama(model=self.model_name, request_timeout=120.0) |
|
|
25 |
embed_model = HuggingFaceEmbedding( |
|
|
26 |
model_name=self.config.DEFAULT_EMBEDDINGS, |
|
|
27 |
cache_folder=f"aitrika/rag/embeddings/{self.config.DEFAULT_EMBEDDINGS.replace('/','_')}", |
|
|
28 |
) |
|
|
29 |
Settings.llm = llm |
|
|
30 |
Settings.embed_model = embed_model |
|
|
31 |
Settings.chunk_size = self.config.CHUNK_SIZE |
|
|
32 |
Settings.chunk_overlap = self.config.CHUNK_OVERLAP |
|
|
33 |
Settings.context_window = self.config.CONTEXT_WINDOW |
|
|
34 |
Settings.num_output = self.config.NUM_OUTPUT |
|
|
35 |
|
|
|
36 |
if os.path.exists("aitrika/rag/vectorstores/ollama"): |
|
|
37 |
vector_store = LanceDBVectorStore(uri="aitrika/rag/vectorstores/ollama") |
|
|
38 |
storage_context = StorageContext.from_defaults( |
|
|
39 |
vector_store=vector_store, persist_dir="aitrika/rag/vectorstores/ollama" |
|
|
40 |
) |
|
|
41 |
index = load_index_from_storage(storage_context=storage_context) |
|
|
42 |
parser = SimpleNodeParser() |
|
|
43 |
new_nodes = parser.get_nodes_from_documents(self.documents) |
|
|
44 |
index.insert_nodes(new_nodes) |
|
|
45 |
index = load_index_from_storage(storage_context=storage_context) |
|
|
46 |
else: |
|
|
47 |
vector_store = LanceDBVectorStore(uri="aitrika/rag/vectorstores/ollama") |
|
|
48 |
storage_context = StorageContext.from_defaults(vector_store=vector_store) |
|
|
49 |
index = VectorStoreIndex( |
|
|
50 |
nodes=self.documents, storage_context=storage_context |
|
|
51 |
) |
|
|
52 |
index.storage_context.persist(persist_dir="aitrika/rag/vectorstores/ollama") |
|
|
53 |
self.index = index |
|
|
54 |
|
|
|
55 |
def query(self, query: str): |
|
|
56 |
self._build_index() |
|
|
57 |
query_engine = self.index.as_query_engine() |
|
|
58 |
response = query_engine.query(query) |
|
|
59 |
return str(response).strip() |