Switch to unified view

a b/aitrika/llm/huggingface.py
1
from llama_index.llms.huggingface import HuggingFaceInferenceAPI
2
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
3
from llama_index.core.node_parser import SimpleNodeParser
4
from llama_index.core import (
5
    VectorStoreIndex,
6
    Settings,
7
    StorageContext,
8
    load_index_from_storage,
9
    Document,
10
)
11
from llama_index.vector_stores.lancedb import LanceDBVectorStore
12
import os
13
from aitrika.llm.base_llm import BaseLLM
14
from aitrika.config.config import LLMConfig
15
16
17
class HuggingFaceLLM(BaseLLM):
18
    config = LLMConfig()
19
20
    def __init__(
21
        self,
22
        documents: Document,
23
        api_key: str,
24
        model_endpoint: str = "microsoft/Phi-3-mini-4k-instruct",
25
    ):
26
        self.documents = documents
27
        self.model_endpoint = model_endpoint
28
        if not api_key:
29
            raise ValueError("API key is required for HuggingFace.")
30
        self.api_key = api_key
31
32
    def _build_index(self):
33
        llm = HuggingFaceInferenceAPI(
34
            model_name=self.model_endpoint, token=self.api_key
35
        )
36
        embed_model = HuggingFaceEmbedding(
37
            model_name=self.config.DEFAULT_EMBEDDINGS,
38
            cache_folder=f"aitrika/rag/embeddings/{self.config.DEFAULT_EMBEDDINGS.replace('/','_')}",
39
        )
40
        Settings.llm = llm
41
        Settings.embed_model = embed_model
42
        Settings.chunk_size = self.config.CHUNK_SIZE
43
        Settings.chunk_overlap = self.config.CHUNK_OVERLAP
44
        Settings.context_window = self.config.CONTEXT_WINDOW
45
        Settings.num_output = self.config.NUM_OUTPUT
46
47
        if os.path.exists("aitrika/rag/vectorstores/huggingface"):
48
            vector_store = LanceDBVectorStore(
49
                uri="aitrika/rag/vectorstores/huggingface"
50
            )
51
            storage_context = StorageContext.from_defaults(
52
                vector_store=vector_store,
53
                persist_dir="aitrika/rag/vectorstores/huggingface",
54
            )
55
            index = load_index_from_storage(storage_context=storage_context)
56
            parser = SimpleNodeParser()
57
            new_nodes = parser.get_nodes_from_documents(self.documents)
58
            index.insert_nodes(new_nodes)
59
            index = load_index_from_storage(storage_context=storage_context)
60
        else:
61
            vector_store = LanceDBVectorStore(
62
                uri="aitrika/rag/vectorstores/huggingface"
63
            )
64
            storage_context = StorageContext.from_defaults(vector_store=vector_store)
65
            index = VectorStoreIndex(
66
                nodes=self.documents, storage_context=storage_context
67
            )
68
            index.storage_context.persist(
69
                persist_dir="aitrika/rag/vectorstores/huggingface"
70
            )
71
        self.index = index
72
73
    def query(self, query: str):
74
        self._build_index()
75
        query_engine = self.index.as_query_engine()
76
        response = query_engine.query(query)
77
        return str(response).strip()