# modules/chatbot/dataloader.py
1
from typing import List, Union

import faiss
import numpy as np
import pandas as pd
from datasets import load_dataset
5
6
7
def convert(item: str) -> np.ndarray:
    """
    Convert a string representation of an array to a numpy array.

    Accepts whitespace-separated text such as ``"[0.1 0.2 0.3]"`` (the form
    ``str(np_array)`` produces, newlines included) as well as comma-separated
    list text such as ``"[0.1, 0.2, 0.3]"``. Surrounding whitespace and the
    square brackets are optional.

    Args:
        item (str): String representation of an array.

    Returns:
        np.ndarray: 1-D float64 array of the parsed values (empty for ``"[]"``).

    Raises:
        ValueError: If a token cannot be parsed as a float.
    """
    # strip("[]") rather than slicing [1:-1]: slicing cut off real digits
    # whenever the brackets were missing.
    body = item.strip().strip("[]")
    # Commas are treated as separators, and each token is converted explicitly
    # so malformed input raises instead of being silently truncated (which
    # would produce an embedding of the wrong length).
    return np.asarray(body.replace(",", " ").split(), dtype=np.float64)
21
22
23
def get_dataset(huggingface_repo: str) -> pd.DataFrame:
    """
    Load dataset from Hugging Face repository and convert to pandas DataFrame.

    The ``"train"`` split is taken from the result of
    ``load_dataset(huggingface_repo, "csv")``, and its ``Q_FFNN_embeds`` and
    ``A_FFNN_embeds`` columns (string representations of arrays) are parsed
    into numpy arrays with ``convert``.

    Args:
        huggingface_repo (str): Name of the Hugging Face repository.

    Returns:
        pd.DataFrame: Pandas DataFrame containing the loaded dataset, with the
        two embedding columns holding ``np.ndarray`` values.
    """
    # NOTE(review): "csv" lands in load_dataset's second positional parameter
    # (the configuration name) - confirm the repository defines that config.
    dataset_dict = load_dataset(huggingface_repo, "csv")
    # Dataset.to_pandas() is the library's documented, Arrow-backed conversion;
    # wrapping the Dataset object in pd.DataFrame relied on pandas' generic
    # list-like handling instead.
    df = dataset_dict["train"].to_pandas()

    for column in ("Q_FFNN_embeds", "A_FFNN_embeds"):
        df[column] = df[column].apply(convert)

    return df
39
40
41
def get_bert_index(
    df: pd.DataFrame, target_columns: Union[str, List[str]]
) -> faiss.IndexFlatIP:
    """
    Build and return the FAISS index for BERT embeddings.

    Each cell of a target column is presumably an array-like of floats, with
    every row of a column having the same length. When several columns are
    given, each row's embeddings are concatenated along the feature axis, so
    FAISS id ``i`` always refers to row position ``i`` of ``df``.

    Args:
        df (pd.DataFrame): DataFrame containing the BERT embeddings.
        target_columns (Union[str, List[str]]): Name or list of names of the
            columns containing BERT embeddings.

    Returns:
        faiss.IndexFlatIP: Inner-product FAISS index holding one float32 vector
        per row of ``df``.

    Raises:
        ValueError: If ``target_columns`` is an empty list.
    """
    if isinstance(target_columns, str):
        columns = [target_columns]
    else:
        columns = list(target_columns)
    if not columns:
        raise ValueError("target_columns must name at least one column")

    # df[list_of_names] is a DataFrame, which has no .tolist(), so the list
    # form used to crash. Build one float32 matrix per column instead and join
    # them side by side; np.hstack also keeps a single column unchanged.
    blocks = [np.array(df[column].tolist(), dtype="float32") for column in columns]
    embedded_bert = np.hstack(blocks)

    index = faiss.IndexFlatIP(embedded_bert.shape[-1])
    index.add(embedded_bert)

    return index