a b/aiagents4pharma/talk2knowledgegraphs/datasets/starkqa_primekg.py
1
"""
2
Class for loading StarkQAPrimeKG dataset.
3
"""
4
5
import os
6
import shutil
7
import pickle
8
import numpy as np
9
import pandas as pd
10
from tqdm import tqdm
11
import torch
12
from huggingface_hub import hf_hub_download, list_repo_files
13
import gdown
14
from .dataset import Dataset
15
16
class StarkQAPrimeKG(Dataset):
    """
    Class for loading StarkQAPrimeKG dataset.

    It downloads the data from the HuggingFace repo and stores it in the local directory.
    The data is then loaded into pandas DataFrame of QA pairs, dictionary of split indices,
    and node information. The provided query/node embeddings are downloaded from
    Google Drive into the same local directory.
    """

    def __init__(self, local_dir: str = "../../../data/starkqa_primekg/"):
        """
        Constructor for StarkQAPrimeKG class.

        Args:
            local_dir (str): The local directory to store the dataset files.
        """
        self.name: str = "starkqa_primekg"
        self.hf_repo_id: str = "snap-stanford/stark"
        self.local_dir: str = local_dir
        # Attributes to store the data; they stay None until `load_data()` is called.
        self.starkqa: pd.DataFrame = None
        self.starkqa_split_idx: dict = None
        self.starkqa_node_info: dict = None
        self.query_emb_dict: dict = None
        self.node_emb_dict: dict = None

        # Set up the dataset
        self.setup()

    def setup(self):
        """
        A method to set up the dataset.

        Creates `local_dir` (including missing parent directories) if it does not exist.
        """
        # Create the directory itself rather than `os.path.dirname(local_dir)`: the
        # dirname form only worked when `local_dir` ended with a slash, and raised
        # FileNotFoundError for a bare relative name such as "data".
        if self.local_dir:
            os.makedirs(self.local_dir, exist_ok=True)

    def _download_stark_repo(self):
        """
        Download the StarkQA PrimeKG files from the HuggingFace Hub into `local_dir`.

        Only repository files under `qa/prime/` or `skb/prime/` whose path does not
        contain "raw" are downloaded.
        """
        print(f"Downloading files from {self.hf_repo_id}")

        # List all related files in the HuggingFace Hub repository
        files = [
            f for f in list_repo_files(self.hf_repo_id, repo_type="dataset")
            if f.startswith(("qa/prime/", "skb/prime/")) and "raw" not in f
        ]

        # Download and save each file in the specified folder
        for file in tqdm(files):
            hf_hub_download(self.hf_repo_id,
                            file,
                            repo_type="dataset",
                            local_dir=self.local_dir)

    def _load_split_indices(self, qa_ids: list) -> dict:
        """
        Map the query ids of every split to row positions in the id-sorted QA table.

        Args:
            qa_ids (list): The values of the `id` column of the StarkQA dataframe.

        Returns:
            A dict mapping each split name ('train', 'val', 'test', 'test-0.1') to a
            numpy array of positions into `sorted(qa_ids)`.

        Raises:
            ValueError: If a split file lists a query id that is not in `qa_ids`.
        """
        # Build the id -> position lookup once instead of calling `list.index` for
        # every query id (which was O(n) per lookup, O(n*m) overall).
        # `setdefault` keeps the first position of a repeated id, like `list.index`.
        position_of = {}
        for position, qa_id in enumerate(sorted(qa_ids)):
            position_of.setdefault(qa_id, position)

        split_idx = {}
        for split in ['train', 'val', 'test', 'test-0.1']:
            indices_file = os.path.join(self.local_dir, "qa/prime/split", f'{split}.index')
            with open(indices_file, 'r', encoding='utf-8') as f:
                # Skip blank lines so an empty file or stray newline does not crash int('').
                query_ids = [int(line) for line in f.read().splitlines() if line.strip()]
            try:
                positions = [position_of[query_id] for query_id in query_ids]
            except KeyError as err:
                raise ValueError(
                    f"Query id {err.args[0]} from split '{split}' is not in the StarkQA ids"
                ) from err
            split_idx[split] = np.array(positions, dtype=int)
        return split_idx

    def _load_stark_repo(self) -> tuple[pd.DataFrame, dict, dict]:
        """
        Private method to load related files of StarkQAPrimeKG dataset.

        The files are downloaded from the HuggingFace Hub when the QA csv is not present
        locally. The processed knowledge-base archive is unpacked whenever its node-info
        pickle is missing and the archive exists.

        Returns:
            The QA-pairs dataframe of StarkQAPrimeKG dataset.
            The split indices of StarkQAPrimeKG dataset.
            The node information of StarkQAPrimeKG dataset.
        """
        qa_file = os.path.join(self.local_dir, "qa/prime/stark_qa/stark_qa.csv")
        processed_zip = os.path.join(self.local_dir, "skb/prime/processed.zip")
        node_info_file = os.path.join(self.local_dir, "skb/prime/processed/node_info.pkl")

        # Download the files if they do not exist in the local directory
        # Otherwise, load the data from the local directory
        if os.path.exists(qa_file):
            print(f"{qa_file} already exists. Loading the data from the local directory.")
        else:
            self._download_stark_repo()

        # Unzip the processed files if they have not been extracted yet. Checking the
        # extracted file (not just the download) also recovers from a previous run that
        # was interrupted between downloading and unpacking.
        if not os.path.exists(node_info_file) and os.path.exists(processed_zip):
            shutil.unpack_archive(processed_zip, os.path.join(self.local_dir, "skb/prime/"))

        # Load StarkQA dataframe
        starkqa = pd.read_csv(qa_file, low_memory=False)

        # Read split indices (positions into the id-sorted QA table)
        starkqa_split_idx = self._load_split_indices(starkqa['id'].tolist())

        # Load the node info of PrimeKG preprocessed for StarkQA.
        # NOTE(review): pickle.load can execute arbitrary code from the file; it comes
        # from the downloaded dataset, so only point `local_dir` at trusted data.
        with open(node_info_file, 'rb') as f:
            starkqa_node_info = pickle.load(f)

        return starkqa, starkqa_split_idx, starkqa_node_info

    def _load_stark_embeddings(self) -> tuple[dict, dict]:
        """
        Private method to load the embeddings of StarkQAPrimeKG dataset.

        Each embedding file is downloaded from Google Drive only if it is not already
        present under `<local_dir>/text-embedding-ada-002/`.

        Returns:
            The query embeddings of StarkQAPrimeKG dataset.
            The node embeddings of StarkQAPrimeKG dataset.
        """
        # Load the provided embeddings of query and nodes
        # Note that they utilized 'text-embedding-ada-002' for embeddings
        emb_model = 'text-embedding-ada-002'
        query_emb_url = 'https://drive.google.com/uc?id=1MshwJttPZsHEM2cKA5T13SIrsLeBEdyU'
        node_emb_url = 'https://drive.google.com/uc?id=16EJvCMbgkVrQ0BuIBvLBp-BYPaye-Edy'

        # Prepare respective directories to store the embeddings
        emb_dir = os.path.join(self.local_dir, emb_model)
        query_emb_dir = os.path.join(emb_dir, "query")
        node_emb_dir = os.path.join(emb_dir, "doc")
        os.makedirs(query_emb_dir, exist_ok=True)
        os.makedirs(node_emb_dir, exist_ok=True)
        query_emb_path = os.path.join(query_emb_dir, "query_emb_dict.pt")
        node_emb_path = os.path.join(node_emb_dir, "candidate_emb_dict.pt")

        # Download each embedding file only if it is missing (previously both were
        # re-downloaded whenever either one was absent).
        for url, path in ((query_emb_url, query_emb_path), (node_emb_url, node_emb_path)):
            if not os.path.exists(path):
                gdown.download(url, path, quiet=False)

        # Load the embeddings.
        # NOTE(review): torch.load unpickles the file; only load trusted downloads.
        query_emb_dict = torch.load(query_emb_path)
        node_emb_dict = torch.load(node_emb_path)

        return query_emb_dict, node_emb_dict

    def load_data(self):
        """
        Load the StarkQAPrimeKG dataset into pandas DataFrame of QA pairs,
        dictionary of split indices, and node information, together with
        the query and node embeddings.
        """
        print("Loading StarkQAPrimeKG dataset...")
        self.starkqa, self.starkqa_split_idx, self.starkqa_node_info = self._load_stark_repo()

        print("Loading StarkQAPrimeKG embeddings...")
        self.query_emb_dict, self.node_emb_dict = self._load_stark_embeddings()

    def get_starkqa(self) -> pd.DataFrame:
        """
        Get the dataframe of StarkQAPrimeKG dataset, containing the QA pairs.

        Returns:
            The QA-pairs dataframe of StarkQAPrimeKG dataset (None before `load_data()`).
        """
        return self.starkqa

    def get_starkqa_split_indicies(self) -> dict:
        """
        Get the split indices of StarkQAPrimeKG dataset.

        Returns:
            The split indices of StarkQAPrimeKG dataset (None before `load_data()`).
        """
        return self.starkqa_split_idx

    def get_starkqa_split_indices(self) -> dict:
        """
        Get the split indices of StarkQAPrimeKG dataset.

        Correctly spelled alias of `get_starkqa_split_indicies`, kept alongside the
        original name for backward compatibility.

        Returns:
            The split indices of StarkQAPrimeKG dataset (None before `load_data()`).
        """
        return self.starkqa_split_idx

    def get_starkqa_node_info(self) -> dict:
        """
        Get the node information of StarkQAPrimeKG dataset.

        Returns:
            The node information of StarkQAPrimeKG dataset (None before `load_data()`).
        """
        return self.starkqa_node_info

    def get_query_embeddings(self) -> dict:
        """
        Get the query embeddings of StarkQAPrimeKG dataset.

        Returns:
            The query embeddings of StarkQAPrimeKG dataset (None before `load_data()`).
        """
        return self.query_emb_dict

    def get_node_embeddings(self) -> dict:
        """
        Get the node embeddings of StarkQAPrimeKG dataset.

        Returns:
            The node embeddings of StarkQAPrimeKG dataset (None before `load_data()`).
        """
        return self.node_emb_dict