|
a |
|
b/aitrika/llm/anthropic.py |
|
|
1 |
from llama_index.llms.anthropic import Anthropic |
|
|
2 |
from llama_index.embeddings.huggingface import HuggingFaceEmbedding |
|
|
3 |
from llama_index.core.node_parser import SimpleNodeParser |
|
|
4 |
from llama_index.core import ( |
|
|
5 |
VectorStoreIndex, |
|
|
6 |
Settings, |
|
|
7 |
StorageContext, |
|
|
8 |
load_index_from_storage, |
|
|
9 |
Document, |
|
|
10 |
) |
|
|
11 |
from llama_index.vector_stores.lancedb import LanceDBVectorStore |
|
|
12 |
import os |
|
|
13 |
from aitrika.llm.base_llm import BaseLLM |
|
|
14 |
from aitrika.config.config import LLMConfig |
|
|
15 |
|
|
|
16 |
|
|
|
17 |
class AnthropicLLM(BaseLLM):
    """Question answering over documents with an Anthropic chat model.

    The documents are embedded with a HuggingFace embedding model, stored in
    a LanceDB vector store under ``_VECTOR_STORE_DIR`` and queried through a
    llama-index ``VectorStoreIndex``.
    """

    config = LLMConfig()

    # Single location used both as the LanceDB URI and as the persist
    # directory of the llama-index storage context.
    _VECTOR_STORE_DIR = "aitrika/rag/vectorstores/anthropic"

    def __init__(
        self,
        documents: Document,
        api_key: str,
        model_name: str = "claude-3-5-sonnet-20240620",
    ):
        """Store the documents and the Anthropic credentials.

        Args:
            documents: Content to index. NOTE(review): despite the
                annotation this value is handed to
                ``get_nodes_from_documents`` and to ``VectorStoreIndex`` as
                ``nodes``, so it is presumably a sequence of ``Document``
                objects - confirm against callers.
            api_key: Anthropic API key; must be non-empty.
            model_name: Name of the Anthropic model to use.

        Raises:
            ValueError: If ``api_key`` is empty or ``None``.
        """
        self.documents = documents
        self.model_name = model_name
        if not api_key:
            raise ValueError("API key is required for Anthropic.")
        self.api_key = api_key
        # Built lazily on the first call to ``query``.
        self.index = None

    def _configure_settings(self):
        """Point the global llama-index ``Settings`` at this instance's models."""
        # The llama-index Anthropic wrapper takes the key as ``api_key``.
        llm = Anthropic(model=self.model_name, api_key=self.api_key)
        embed_model = HuggingFaceEmbedding(
            model_name=self.config.DEFAULT_EMBEDDINGS,
            cache_folder=f"aitrika/rag/embeddings/{self.config.DEFAULT_EMBEDDINGS.replace('/','_')}",
        )
        Settings.llm = llm
        Settings.embed_model = embed_model
        Settings.chunk_size = self.config.CHUNK_SIZE
        Settings.chunk_overlap = self.config.CHUNK_OVERLAP
        Settings.context_window = self.config.CONTEXT_WINDOW
        Settings.num_output = self.config.NUM_OUTPUT

    def _build_index(self):
        """Create or extend the persisted vector index and set ``self.index``.

        If the store directory already exists, the persisted index is loaded
        and the nodes parsed from ``self.documents`` are inserted into it.
        Otherwise a new index is built directly from ``self.documents``. In
        both cases the storage context is persisted afterwards.
        """
        self._configure_settings()

        store_dir = self._VECTOR_STORE_DIR
        # Must be evaluated before the LanceDB store object is created below.
        store_exists = os.path.exists(store_dir)
        vector_store = LanceDBVectorStore(uri=store_dir)

        if store_exists:
            storage_context = StorageContext.from_defaults(
                vector_store=vector_store,
                persist_dir=store_dir,
            )
            index = load_index_from_storage(storage_context=storage_context)
            parser = SimpleNodeParser()
            new_nodes = parser.get_nodes_from_documents(self.documents)
            index.insert_nodes(new_nodes)
        else:
            storage_context = StorageContext.from_defaults(vector_store=vector_store)
            index = VectorStoreIndex(
                nodes=self.documents, storage_context=storage_context
            )

        # Persist after either path so the on-disk metadata matches the
        # in-memory index (previously the insert path reloaded the stale
        # on-disk state instead of saving).
        index.storage_context.persist(persist_dir=store_dir)
        self.index = index

    def query(self, query: str):
        """Run ``query`` against the index and return the stripped answer text.

        The index is built on the first call only, so repeated queries do not
        re-insert the same documents into the persistent store. The global
        ``Settings`` are re-applied on every call because other code may have
        overwritten them in the meantime.

        Args:
            query: Natural-language question to ask.

        Returns:
            The query engine's response converted to ``str`` and stripped of
            surrounding whitespace.
        """
        if getattr(self, "index", None) is None:
            self._build_index()
        else:
            self._configure_settings()
        query_engine = self.index.as_query_engine()
        response = query_engine.query(query)
        return str(response).strip()