Diff of /aitrika/llm/neutrino.py [000000] .. [1bdb11]

Switch to unified view

a b/aitrika/llm/neutrino.py
1
from llama_index.llms.neutrino import Neutrino
2
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
3
from llama_index.core.node_parser import SimpleNodeParser
4
from llama_index.core import (
5
    VectorStoreIndex,
6
    Settings,
7
    StorageContext,
8
    load_index_from_storage,
9
    Document,
10
)
11
from llama_index.vector_stores.lancedb import LanceDBVectorStore
12
import os
13
from aitrika.llm.base_llm import BaseLLM
14
from aitrika.config.config import LLMConfig
15
16
17
class NeutrinoLLM(BaseLLM):
18
    config = LLMConfig()
19
20
    def __init__(self, documents: Document, api_key: str):
21
        self.documents = documents
22
        if not api_key:
23
            raise ValueError("API key is required for Neutrino.")
24
        self.api_key = api_key
25
26
    def _build_index(self):
27
        llm = Neutrino(token=self.api_key)
28
        embed_model = HuggingFaceEmbedding(
29
            model_name=self.config.DEFAULT_EMBEDDINGS,
30
            cache_folder=f"aitrika/rag/embeddings/{self.config.DEFAULT_EMBEDDINGS.replace('/','_')}",
31
        )
32
        Settings.llm = llm
33
        Settings.embed_model = embed_model
34
        Settings.chunk_size = self.config.CHUNK_SIZE
35
        Settings.chunk_overlap = self.config.CHUNK_OVERLAP
36
        Settings.context_window = self.config.CONTEXT_WINDOW
37
        Settings.num_output = self.config.NUM_OUTPUT
38
39
        if os.path.exists("aitrika/rag/vectorstores/neutrino"):
40
            vector_store = LanceDBVectorStore(uri="aitrika/rag/vectorstores/neutrino")
41
            storage_context = StorageContext.from_defaults(
42
                vector_store=vector_store,
43
                persist_dir="aitrika/rag/vectorstores/neutrino",
44
            )
45
            index = load_index_from_storage(storage_context=storage_context)
46
            parser = SimpleNodeParser()
47
            new_nodes = parser.get_nodes_from_documents(self.documents)
48
            index.insert_nodes(new_nodes)
49
            index = load_index_from_storage(storage_context=storage_context)
50
        else:
51
            vector_store = LanceDBVectorStore(uri="aitrika/rag/vectorstores/neutrino")
52
            storage_context = StorageContext.from_defaults(vector_store=vector_store)
53
            index = VectorStoreIndex(
54
                nodes=self.documents, storage_context=storage_context
55
            )
56
            index.storage_context.persist(
57
                persist_dir="aitrika/rag/vectorstores/neutrino"
58
            )
59
        self.index = index
60
61
    def query(self, query: str):
62
        self._build_index()
63
        query_engine = self.index.as_query_engine()
64
        response = query_engine.query(query)
65
        return str(response).strip()