Diff of /src/build_RAG_private.py [000000] .. [014e6e]

Switch to side-by-side view

--- a
+++ b/src/build_RAG_private.py
@@ -0,0 +1,58 @@
+# -*- coding: utf-8 -*-
+"""
+@Time : 2024/2/29 11:19
+@Auth : Juexiao Zhou
+@File :build_RAG.py
+@IDE :PyCharm
+@Page: www.joshuachou.ink
+"""
+
+import os.path
+import os
+from llama_index.core import (
+    VectorStoreIndex,
+    SimpleDirectoryReader,
+    StorageContext,
+    load_index_from_storage,
+    Settings
+)
+from llama_index.embeddings.openai import OpenAIEmbedding
+from llama_index.embeddings.huggingface import HuggingFaceEmbedding
+
+def preload_retriever(local_engine = True, openai = None, PERSIST_DIR = "../softwares_database_RAG", SOURCE_DIR = "../softwares_database"):
+    if not local_engine:
+        os.environ['OPENAI_API_KEY'] = openai
+        Settings.embed_model = OpenAIEmbedding()
+    else:
+        Settings.embed_model = HuggingFaceEmbedding(
+            model_name="BAAI/bge-small-en-v1.5"
+        )
+
+    if not os.path.exists(PERSIST_DIR):
+        # load the documents and create the index
+        documents = SimpleDirectoryReader(SOURCE_DIR).load_data()
+        index = VectorStoreIndex.from_documents(documents, embeddings=Settings.embed_model)
+        # store it for later
+        index.storage_context.persist(persist_dir=PERSIST_DIR)
+    else:
+        # load the existing index
+        storage_context = StorageContext.from_defaults(persist_dir=PERSIST_DIR)
+        index = load_index_from_storage(storage_context)
+
+    # Either way we can now query the index
+    retriever = index.as_retriever(similarity_top_k=1)
+    return retriever
+
+def retrive(retriever,
+            retriever_prompt="To perform RNA-seq alignment, what tools and parameters should I use?"
+):
+    response = retriever.retrieve(retriever_prompt)
+    response = response[0].get_text()
+    return response
+
+if __name__ == '__main__':
+    local_engine = True
+    openai = 'sk-xxx'
+    retriever = preload_retriever(local_engine, openai)
+    response = retrive(retriever)
+    print(response)
\ No newline at end of file