[1bdb11]: / aitrika / llm / anthropic.py

Download this file

72 lines (65 with data), 2.8 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import os
from typing import List

from llama_index.core import (
    Document,
    Settings,
    StorageContext,
    VectorStoreIndex,
    load_index_from_storage,
)
from llama_index.core.node_parser import SimpleNodeParser
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.llms.anthropic import Anthropic
from llama_index.vector_stores.lancedb import LanceDBVectorStore

from aitrika.config.config import LLMConfig
from aitrika.llm.base_llm import BaseLLM


class AnthropicLLM(BaseLLM):
    """Retrieval-augmented question answering backed by an Anthropic model.

    Documents are split into chunks, embedded with the configured
    HuggingFace embedding model and stored in a LanceDB vector store under
    ``_VECTORSTORE_DIR``. Queries are answered by a llama_index query engine
    over that index.

    NOTE: when the store directory already exists, ``self.documents`` are
    appended to it without de-duplication, so re-creating an instance with
    the same documents will index them again.
    """

    # Shared settings: embedding model name, chunking and context sizes.
    config = LLMConfig()
    # Location of the LanceDB table and the persisted llama_index metadata.
    _VECTORSTORE_DIR = "aitrika/rag/vectorstores/anthropic"

    def __init__(
        self,
        documents: List[Document],
        api_key: str,
        model_name: str = "claude-3-5-sonnet-20240620",
    ):
        """Store the inputs; the index is built lazily on the first query.

        Args:
            documents: Documents to index.
            api_key: Anthropic API key.
            model_name: Anthropic model identifier.

        Raises:
            ValueError: If ``api_key`` is empty or ``None``.
        """
        self.documents = documents
        self.model_name = model_name
        if not api_key:
            raise ValueError("API key is required for Anthropic.")
        self.api_key = api_key
        # Populated by _build_index() on the first call to query().
        self.index = None
        self._llm = None

    def _parse_nodes(self):
        """Split ``self.documents`` into nodes using the configured chunking."""
        parser = SimpleNodeParser(
            chunk_size=self.config.CHUNK_SIZE,
            chunk_overlap=self.config.CHUNK_OVERLAP,
        )
        return parser.get_nodes_from_documents(self.documents)

    def _build_index(self):
        """Configure llama_index and load (or create) the persisted index.

        If the store directory exists, the stored index is loaded and the
        chunks of ``self.documents`` are inserted into it; otherwise a new
        index is built from those chunks. Either way the index metadata is
        persisted afterwards, and the index is stored on ``self.index``.
        """
        llm = Anthropic(model=self.model_name, api_key=self.api_key)
        embeddings_name = self.config.DEFAULT_EMBEDDINGS
        embed_model = HuggingFaceEmbedding(
            model_name=embeddings_name,
            cache_folder=f"aitrika/rag/embeddings/{embeddings_name.replace('/','_')}",
        )
        # llama_index reads these process-wide defaults when it builds or
        # loads an index.
        Settings.llm = llm
        Settings.embed_model = embed_model
        Settings.chunk_size = self.config.CHUNK_SIZE
        Settings.chunk_overlap = self.config.CHUNK_OVERLAP
        Settings.context_window = self.config.CONTEXT_WINDOW
        Settings.num_output = self.config.NUM_OUTPUT

        persist_dir = self._VECTORSTORE_DIR
        # Check before opening the vector store, which may create the folder.
        store_exists = os.path.exists(persist_dir)
        vector_store = LanceDBVectorStore(uri=persist_dir)
        if store_exists:
            storage_context = StorageContext.from_defaults(
                vector_store=vector_store,
                persist_dir=persist_dir,
            )
            index = load_index_from_storage(storage_context=storage_context)
            index.insert_nodes(self._parse_nodes())
        else:
            storage_context = StorageContext.from_defaults(vector_store=vector_store)
            index = VectorStoreIndex(
                nodes=self._parse_nodes(), storage_context=storage_context
            )
        # Persist in both branches so newly inserted nodes are recorded on disk.
        index.storage_context.persist(persist_dir=persist_dir)
        self._llm = llm
        self.index = index

    def query(self, query: str) -> str:
        """Answer ``query`` using the indexed documents.

        The index is built only on the first call, so repeated queries do not
        re-insert the documents.

        Returns:
            The response text with leading/trailing whitespace removed.
        """
        if getattr(self, "index", None) is None:
            self._build_index()
        # Pass the LLM explicitly so another component changing the global
        # Settings.llm cannot redirect this instance's queries.
        query_engine = self.index.as_query_engine(llm=getattr(self, "_llm", None))
        response = query_engine.query(query)
        return str(response).strip()