Switch to side-by-side view

--- a
+++ b/aiagents4pharma/talk2knowledgegraphs/datasets/starkqa_primekg.py
@@ -0,0 +1,201 @@
+"""
+Class for loading StarkQAPrimeKG dataset.
+"""
+
+import os
+import shutil
+import pickle
+import numpy as np
+import pandas as pd
+from tqdm import tqdm
+import torch
+from huggingface_hub import hf_hub_download, list_repo_files
+import gdown
+from .dataset import Dataset
+
+class StarkQAPrimeKG(Dataset):
+    """
+    Class for loading StarkQAPrimeKG dataset.
+    It downloads the data from the HuggingFace repo and stores it in the local directory.
+    The data is then loaded into pandas DataFrame of QA pairs, dictionary of split indices,
+    and node information.
+    """
+
+    def __init__(self, local_dir: str = "../../../data/starkqa_primekg/"):
+        """
+        Constructor for StarkQAPrimeKG class.
+
+        Args:
+            local_dir (str): The local directory to store the dataset files.
+        """
+        self.name: str = "starkqa_primekg"
+        self.hf_repo_id: str = "snap-stanford/stark"
+        self.local_dir: str = local_dir
+        # Attributes to store the data
+        self.starkqa: pd.DataFrame = None
+        self.starkqa_split_idx: dict = None
+        self.starkqa_node_info: dict = None
+        self.query_emb_dict: dict = None
+        self.node_emb_dict: dict = None
+
+        # Set up the dataset
+        self.setup()
+
+    def setup(self):
+        """
+        A method to set up the dataset.
+        """
+        # Make the directory if it doesn't exist
+        os.makedirs(os.path.dirname(self.local_dir), exist_ok=True)
+
+    def _load_stark_repo(self) -> tuple[pd.DataFrame, dict, dict]:
+        """
+        Private method to load related files of StarkQAPrimeKG dataset.
+
+        Returns:
+            The nodes dataframe of StarkQAPrimeKG dataset.
+            The split indices of StarkQAPrimeKG dataset.
+            The node information of StarkQAPrimeKG dataset.
+        """
+        # Download the file if it does not exist in the local directory
+        # Otherwise, load the data from the local directory
+        local_file = os.path.join(self.local_dir, "qa/prime/stark_qa/stark_qa.csv")
+        if os.path.exists(local_file):
+            print(f"{local_file} already exists. Loading the data from the local directory.")
+        else:
+            print(f"Downloading files from {self.hf_repo_id}")
+
+            # List all related files in the HuggingFace Hub repository
+            files = list_repo_files(self.hf_repo_id, repo_type="dataset")
+            files = [f for f in files if ((f.startswith("qa/prime/") or
+                                           f.startswith("skb/prime/")) and f.find("raw") == -1)]
+
+            # Download and save each file in the specified folder
+            for file in tqdm(files):
+                _ = hf_hub_download(self.hf_repo_id,
+                                    file,
+                                    repo_type="dataset",
+                                    local_dir=self.local_dir)
+
+            # Unzip the processed files
+            shutil.unpack_archive(
+                os.path.join(self.local_dir, "skb/prime/processed.zip"),
+                os.path.join(self.local_dir, "skb/prime/")
+            )
+
+        # Load StarkQA dataframe
+        starkqa = pd.read_csv(
+            os.path.join(self.local_dir, "qa/prime/stark_qa/stark_qa.csv"),
+            low_memory=False)
+
+        # Read split indices
+        qa_indices = sorted(starkqa['id'].tolist())
+        starkqa_split_idx = {}
+        for split in ['train', 'val', 'test', 'test-0.1']:
+            indices_file = os.path.join(self.local_dir, "qa/prime/split", f'{split}.index')
+            with open(indices_file, 'r', encoding='utf-8') as f:
+                indices = f.read().strip().split('\n')
+            query_ids = [int(idx) for idx in indices]
+            starkqa_split_idx[split] = np.array(
+                [qa_indices.index(query_id) for query_id in query_ids]
+            )
+
+        # Load the node info of PrimeKG preprocessed for StarkQA
+        with open(os.path.join(self.local_dir, 'skb/prime/processed/node_info.pkl'), 'rb') as f:
+            starkqa_node_info = pickle.load(f)
+
+        return starkqa, starkqa_split_idx, starkqa_node_info
+
+    def _load_stark_embeddings(self) -> tuple[dict, dict]:
+        """
+        Private method to load the embeddings of StarkQAPrimeKG dataset.
+
+        Returns:
+            The query embeddings of StarkQAPrimeKG dataset.
+            The node embeddings of StarkQAPrimeKG dataset.
+        """
+        # Load the provided embeddings of query and nodes
+        # Note that they utilized 'text-embedding-ada-002' for embeddings
+        emb_model = 'text-embedding-ada-002'
+        query_emb_url = 'https://drive.google.com/uc?id=1MshwJttPZsHEM2cKA5T13SIrsLeBEdyU'
+        node_emb_url = 'https://drive.google.com/uc?id=16EJvCMbgkVrQ0BuIBvLBp-BYPaye-Edy'
+
+        # Prepare respective directories to store the embeddings
+        emb_dir = os.path.join(self.local_dir, emb_model)
+        query_emb_dir = os.path.join(emb_dir, "query")
+        node_emb_dir = os.path.join(emb_dir, "doc")
+        os.makedirs(query_emb_dir, exist_ok=True)
+        os.makedirs(node_emb_dir, exist_ok=True)
+        query_emb_path = os.path.join(query_emb_dir, "query_emb_dict.pt")
+        node_emb_path = os.path.join(node_emb_dir, "candidate_emb_dict.pt")
+
+        # Download the embeddings if they do not exist in the local directory
+        if not os.path.exists(query_emb_path) or not os.path.exists(node_emb_path):
+            # Download the query embeddings
+            gdown.download(query_emb_url, query_emb_path, quiet=False)
+
+            # Download the node embeddings
+            gdown.download(node_emb_url, node_emb_path, quiet=False)
+
+        # Load the embeddings
+        query_emb_dict = torch.load(query_emb_path)
+        node_emb_dict = torch.load(node_emb_path)
+
+        return query_emb_dict, node_emb_dict
+
+    def load_data(self):
+        """
+        Load the StarkQAPrimeKG dataset into pandas DataFrame of QA pairs,
+        dictionary of split indices, and node information.
+        """
+        print("Loading StarkQAPrimeKG dataset...")
+        self.starkqa, self.starkqa_split_idx, self.starkqa_node_info = self._load_stark_repo()
+
+        print("Loading StarkQAPrimeKG embeddings...")
+        self.query_emb_dict, self.node_emb_dict = self._load_stark_embeddings()
+
+
+    def get_starkqa(self) -> pd.DataFrame:
+        """
+        Get the dataframe of StarkQAPrimeKG dataset, containing the QA pairs.
+
+        Returns:
+            The nodes dataframe of PrimeKG dataset.
+        """
+        return self.starkqa
+
+    def get_starkqa_split_indicies(self) -> dict:
+        """
+        Get the split indices of StarkQAPrimeKG dataset.
+
+        Returns:
+            The split indices of StarkQAPrimeKG dataset.
+        """
+        return self.starkqa_split_idx
+
+    def get_starkqa_node_info(self) -> dict:
+        """
+        Get the node information of StarkQAPrimeKG dataset.
+
+        Returns:
+            The node information of StarkQAPrimeKG dataset.
+        """
+        return self.starkqa_node_info
+
+    def get_query_embeddings(self) -> dict:
+        """
+        Get the query embeddings of StarkQAPrimeKG dataset.
+
+        Returns:
+            The query embeddings of StarkQAPrimeKG dataset.
+        """
+        return self.query_emb_dict
+
+    def get_node_embeddings(self) -> dict:
+        """
+        Get the node embeddings of StarkQAPrimeKG dataset.
+
+        Returns:
+            The node embeddings of StarkQAPrimeKG dataset.
+        """
+        return self.node_emb_dict