|
a |
|
b/modules/chatbot/dataloader.py |
|
|
1 |
from typing import List, Union

import faiss
import numpy as np
import pandas as pd
from datasets import load_dataset
|
|
5 |
|
|
|
6 |
|
|
|
7 |
def convert(item: str) -> np.ndarray:
    """
    Convert a string representation of a 1-D array to a numpy array.

    Accepts NumPy-style text ("[0.1 -0.2 3e-05]") as well as Python-list-style
    text ("[0.1, -0.2, 3e-05]"). Surrounding whitespace and the enclosing
    brackets are optional.

    Args:
        item (str): String representation of an array.

    Returns:
        np.ndarray: 1-D float64 array of the parsed values (empty for "[]").

    Raises:
        ValueError: If any token between the brackets is not a number.
    """
    # Drop surrounding whitespace first, then any enclosing brackets. Unlike
    # slicing [1:-1], this never discards a digit when brackets are absent.
    body = item.strip().strip("[]")
    # Treat commas as whitespace so list-style reprs parse too. str.split()
    # also absorbs runs of spaces and the newlines NumPy inserts when it wraps
    # long arrays. Converting explicitly raises on malformed tokens instead of
    # silently truncating the way np.fromstring(sep=...) does.
    tokens = body.replace(",", " ").split()
    return np.array(tokens, dtype=np.float64)
|
|
21 |
|
|
|
22 |
|
|
|
23 |
def get_dataset(
    huggingface_repo: str, config_name: str = "csv", split: str = "train"
) -> pd.DataFrame:
    """
    Load a dataset from a Hugging Face repository into a pandas DataFrame.

    The "Q_FFNN_embeds" and "A_FFNN_embeds" columns are parsed from their
    string form into numpy arrays with ``convert``.

    Args:
        huggingface_repo (str): Name of the Hugging Face repository.
        config_name (str): Dataset configuration name, passed to
            ``load_dataset`` as its second positional argument (``name``).
            Defaults to "csv".
        split (str): Split of the loaded DatasetDict to return.
            Defaults to "train".

    Returns:
        pd.DataFrame: DataFrame with the requested split, with both embedding
        columns converted to numpy arrays.
    """
    dataset_dict = load_dataset(huggingface_repo, config_name)
    # Dataset.to_pandas() converts through Arrow in one pass instead of
    # building the frame row by row from Python dicts.
    df = dataset_dict[split].to_pandas()

    # The embedding columns are expected to hold strings such as "[0.1 0.2 ...]"
    # (convert() takes a str); turn them back into numeric arrays.
    for column in ("Q_FFNN_embeds", "A_FFNN_embeds"):
        df[column] = df[column].apply(convert)

    return df
|
|
39 |
|
|
|
40 |
|
|
|
41 |
def get_bert_index(
    df: pd.DataFrame, target_columns: Union[str, List[str]]
) -> faiss.IndexFlatIP:
    """
    Build and return a FAISS inner-product index over BERT embedding column(s).

    Vector ``i`` in the index is built from row ``i`` of ``df`` (positional
    order), so search ids can be mapped back with ``df.iloc``.

    Args:
        df (pd.DataFrame): DataFrame whose target column(s) hold one 1-D
            embedding vector (array or list of floats) per row.
        target_columns (Union[str, List[str]]): Name, or list of names, of the
            columns containing BERT embeddings. When several columns are given,
            each row's vectors are concatenated in the listed order.

    Returns:
        faiss.IndexFlatIP: Flat inner-product index holding one float32 vector
        per row of ``df``.

    Raises:
        ValueError: If ``df`` has no rows, ``target_columns`` is an empty list,
            or a column does not hold equal-length 1-D vectors.
    """
    embedded_bert = _stack_embeddings(df, target_columns)
    index = faiss.IndexFlatIP(embedded_bert.shape[1])
    index.add(embedded_bert)

    return index


def _stack_embeddings(
    df: pd.DataFrame, target_columns: Union[str, List[str]]
) -> np.ndarray:
    """Stack the embedding column(s) of ``df`` into an (n_rows, dim) float32 matrix."""
    # Any value that is not a list/tuple is treated as a single column label,
    # matching the original single-column behaviour.
    columns = (
        list(target_columns)
        if isinstance(target_columns, (list, tuple))
        else [target_columns]
    )
    if not columns:
        raise ValueError("target_columns must name at least one column")
    if len(df) == 0:
        raise ValueError("cannot build an index from an empty DataFrame")

    blocks = []
    for column in columns:
        # Selecting one label at a time keeps .tolist() valid; df[list] would
        # return a DataFrame, which has no .tolist().
        block = np.array(df[column].tolist(), dtype="float32")
        if block.ndim != 2:
            raise ValueError(
                f"column {column!r} must hold one 1-D embedding vector per row"
            )
        blocks.append(block)

    # Concatenate per row so index id i still corresponds to df row i; FAISS
    # needs a C-contiguous float32 matrix.
    return np.ascontiguousarray(np.concatenate(blocks, axis=1))