--- a
+++ b/aiagents4pharma/talk2knowledgegraphs/datasets/biobridge_primekg.py
@@ -0,0 +1,562 @@
+"""
+Class for loading the BioBridgePrimeKG dataset.
+"""
+
+import os
+import pickle
+import json
+import requests
+import numpy as np
+import pandas as pd
+from tqdm import tqdm
+from .dataset import Dataset
+from .primekg import PrimeKG
+
+class BioBridgePrimeKG(Dataset):
+    """
+    Class for loading the BioBridgePrimeKG dataset.
+    It downloads the data from the BioBridge repo and stores it in the local directory.
+    The data is then loaded into pandas DataFrames of nodes and edges.
+    This class was adapted from the BioBridge repo:
+    https://github.com/RyanWangZf/BioBridge
+    """
+
+    def __init__(self,
+                 primekg_dir: str = "../../../data/primekg/",
+                 local_dir: str = "../../../data/biobridge_primekg/",
+                 random_seed: int = 0,
+                 n_neg_samples: int = 5):
+        """
+        Constructor for BioBridgePrimeKG class.
+
+        Args:
+            primekg_dir (str): The directory of the PrimeKG dataset.
+            local_dir (str): The directory to store the downloaded data.
+            random_seed (int): The random seed value.
+            n_neg_samples (int): The number of negative samples per positive triplet.
+        """
+        self.name: str = "biobridge_primekg"
+        self.primekg_dir: str = primekg_dir
+        self.local_dir: str = local_dir
+        self.random_seed = random_seed
+        self.n_neg_samples = n_neg_samples
+        # Preselected node types:
+        # protein, molecular function, cellular component, biological process, drug, disease
+        self.preselected_node_types = ["protein", "mf", "cc", "bp", "drug", "disease"]
+        self.node_type_map = {
+            "protein": "gene/protein",
+            "mf": "molecular_function",
+            "cc": "cellular_component",
+            "bp": "biological_process",
+            "drug": "drug",
+            "disease": "disease",
+        }
+
+        # Attributes to store the data
+        self.primekg = None
+        self.primekg_triplets = None
+        self.primekg_triplets_negative = None
+        self.data_config = None
+        self.emb_dict = None
+        self.df_train = None
+        self.df_node_train = None
+        self.df_test = None
+        self.df_node_test = None
+        self.node_info_dict = None
+
+        # Set up the dataset
+        self.setup()
+
+    def setup(self):
+        """
+        A method to set up the dataset.
+        """
+        # Make the directories if they don't exist
+        os.makedirs(self.primekg_dir, exist_ok=True)
+        os.makedirs(self.local_dir, exist_ok=True)
+
+        # Set the random seed
+        self.set_random_seed(self.random_seed)
+
+        # Suppress pandas SettingWithCopyWarning
+        pd.options.mode.chained_assignment = None
+
+    def _load_primekg(self) -> PrimeKG:
+        """
+        Private method to load related files of PrimeKG dataset.
+
+        Returns:
+            The PrimeKG dataset.
+        """
+        primekg_data = PrimeKG(local_dir=self.primekg_dir)
+        primekg_data.load_data()
+
+        return primekg_data
+
+    def _download_file(self,
+                       remote_url: str,
+                       local_dir: str,
+                       local_filename: str):
+        """
+        A helper function to download a file from remote URL to the local directory.
+
+        Args:
+            remote_url (str): The remote URL of the file to be downloaded.
+            local_dir (str): The local directory to store the downloaded file.
+            local_filename (str): The local filename to store the downloaded file.
+        """
+        # Make the local directory if it does not exist
+        if not os.path.exists(local_dir):
+            os.makedirs(local_dir)
+        # Download the file from remote URL to local directory
+        local_path = os.path.join(local_dir, local_filename)
+        if os.path.exists(local_path):
+            print(f"File {local_filename} already exists in {local_dir}.")
+        else:
+            print(f"Downloading {local_filename} from {remote_url} to {local_dir}...")
+            response = requests.get(remote_url, stream=True, timeout=300)
+            response.raise_for_status()
+            progress_bar = tqdm(
+                total=int(response.headers.get("content-length", 0)),
+                unit="iB",
+                unit_scale=True,
+            )
+            with open(local_path, "wb") as file:
+                for data in response.iter_content(1024):
+                    progress_bar.update(len(data))
+                    file.write(data)
+            progress_bar.close()
+
+    def _load_data_config(self) -> dict:
+        """
+        Load the data config file of BioBridgePrimeKG dataset.
+
+        Returns:
+            The data config file of BioBridgePrimeKG dataset.
+        """
+        # Download the data config file of BioBridgePrimeKG
+        self._download_file(
+            remote_url= ('https://raw.githubusercontent.com/RyanWangZf/BioBridge/'
+                         'refs/heads/main/data/BindData/data_config.json'),
+            local_dir=self.local_dir,
+            local_filename='data_config.json')
+
+        # Load the downloaded data config file
+        with open(os.path.join(self.local_dir, 'data_config.json'), 'r', encoding='utf-8') as f:
+            data_config = json.load(f)
+
+        return data_config
+
+    def _build_node_embeddings(self) -> dict:
+        """
+        Build the node embeddings for BioBridgePrimeKG dataset.
+
+        Returns:
+            The dictionary of node embeddings.
+        """
+        processed_file_path = os.path.join(self.local_dir, "embeddings", "embedding_dict.pkl")
+        if os.path.exists(processed_file_path):
+            # Load the embeddings from the local directory
+            with open(processed_file_path, "rb") as f:
+                emb_dict_all = pickle.load(f)
+        else:
+            # Download the embeddings from the BioBridge repo and further process them
+            # List of embedding source files
+            url = ('https://media.githubusercontent.com/media/RyanWangZf/BioBridge/'
+                   'refs/heads/main/data/embeddings/esm2b_unimo_pubmedbert/')
+            file_list = [f"{n}.pkl" for n in self.preselected_node_types]
+
+            # Download the embeddings
+            for file in file_list:
+                self._download_file(remote_url=f"{url}{file}",
+                                    local_dir=os.path.join(self.local_dir, "embeddings"),
+                                    local_filename=file)
+
+            # Unified embeddings
+            emb_dict_all = {}
+            for file in file_list:
+                with open(os.path.join(self.local_dir, "embeddings", file), "rb") as f:
+                    emb = pickle.load(f)
+                emb_ar = emb["embedding"]
+                if not isinstance(emb_ar, list):
+                    emb_ar = emb_ar.tolist()
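+                # Map each PrimeKG node index to its embedding vector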
+                emb_dict_all.update(dict(zip(emb["node_index"], emb_ar)))
+
+            # Store embeddings
+            with open(processed_file_path, "wb") as f:
+                pickle.dump(emb_dict_all, f)
+
+        return emb_dict_all
+
+    def _build_full_triplets(self) -> tuple[pd.DataFrame, dict]:
+        """
+        Build the full triplets for BioBridgePrimeKG dataset.
+
+        Returns:
+            The full triplets for BioBridgePrimeKG dataset.
+            The dictionary of node information.
+        """
+        processed_file_path = os.path.join(self.local_dir, "processed", "triplet_full.tsv.gz")
+        if os.path.exists(processed_file_path):
+            # Load the file from the local directory
+            with open(processed_file_path, "rb") as f:
+                primekg_triplets = pd.read_csv(f, sep="\t", compression="gzip", low_memory=False)
+
+            # Load each dataframe in the local directory
+            node_info_dict = {}
+            for node_type in self.preselected_node_types:
+                with open(os.path.join(self.local_dir, "processed",
+                                       f"{node_type}.csv"), "rb") as f:
+                    df_node = pd.read_csv(f)
+                node_info_dict[self.node_type_map[node_type]] = df_node
+        else:
+            # Download the related files from the BioBridge repo and further process them
+            # List of processed files
+            url = ('https://media.githubusercontent.com/media/RyanWangZf/BioBridge/'
+                   'refs/heads/main/data/Processed/')
+            file_list = ["protein", "molecular", "cellular", "biological", "drug", "disease"]
+
+            # Download the processed files
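+            # The BioBridge file names map positionally onto the preselected
+            # node types (e.g. "molecular" -> "mf", "cellular" -> "cc"), so each
+            # download is saved under the corresponding local short name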
+            for file, node_type in zip(file_list, self.preselected_node_types):
+                self._download_file(remote_url=f"{url}{file}.csv",
+                                    local_dir=os.path.join(self.local_dir, "processed"),
+                                    local_filename=f"{node_type}.csv")
+
+            # Build the node index list
+            node_info_dict = {}
+            node_index_list = []
+            for node_type in self.preselected_node_types:
+                df_node = pd.read_csv(os.path.join(self.local_dir, "processed",
+                                                   f"{node_type}.csv"))
+                node_info_dict[self.node_type_map[node_type]] = df_node
+                node_index_list.extend(df_node["node_index"].tolist())
+
+            # Filter the PrimeKG dataset to take into account only the selected node types
+            primekg_triplets = self.primekg.get_edges().copy()
+            primekg_triplets = primekg_triplets[
+                primekg_triplets["head_index"].isin(node_index_list) &
+                primekg_triplets["tail_index"].isin(node_index_list)
+            ]
+            primekg_triplets = primekg_triplets.reset_index(drop=True)
+
+            # Perform mapping of node types
+            primekg_triplets["head_type"] = primekg_triplets["head_type"].apply(
+                lambda x: self.data_config["node_type"][x]
+            )
+            primekg_triplets["tail_type"] = primekg_triplets["tail_type"].apply(
+                lambda x: self.data_config["node_type"][x]
+            )
+
+            # Perform mapping of relation types
+            primekg_triplets["display_relation"] = primekg_triplets["display_relation"].apply(
+                lambda x: self.data_config["relation_type"][x]
+            )
+
+            # Store the processed triplets
+            primekg_triplets.to_csv(processed_file_path, sep="\t", compression="gzip", index=False)
+
+        return primekg_triplets, node_info_dict
+
+    def _build_train_test_split(self) -> tuple[pd.DataFrame, pd.DataFrame,
+                                               pd.DataFrame, pd.DataFrame, pd.DataFrame]:
+        """
+        Build the train-test split for BioBridgePrimeKG dataset.
+
+        Returns:
+            The train triplets for BioBridgePrimeKG dataset.
+            The train nodes for BioBridgePrimeKG dataset.
+            The test triplets for BioBridgePrimeKG dataset.
+            The test nodes for BioBridgePrimeKG dataset.
+            The full triplets for BioBridgePrimeKG dataset.
+        """
+        if os.path.exists(os.path.join(self.local_dir, "processed",
+                                       "triplet_full_altered.tsv.gz")):
+            # Load each dataframe in the local directory
+            with open(os.path.join(self.local_dir, "processed",
+                                   "triplet_train.tsv.gz"), "rb") as f:
+                df_train = pd.read_csv(f, sep="\t", compression="gzip", low_memory=False)
+
+            with open(os.path.join(self.local_dir, "processed",
+                                   "node_train.tsv.gz"), "rb") as f:
+                df_node_train = pd.read_csv(f, sep="\t", compression="gzip", low_memory=False)
+
+            with open(os.path.join(self.local_dir, "processed",
+                                   "triplet_test.tsv.gz"), "rb") as f:
+                df_test = pd.read_csv(f, sep="\t", compression="gzip", low_memory=False)
+
+            with open(os.path.join(self.local_dir, "processed",
+                                   "node_test.tsv.gz"), "rb") as f:
+                df_node_test = pd.read_csv(f, sep="\t", compression="gzip", low_memory=False)
+
+            with open(os.path.join(self.local_dir, "processed",
+                                   "triplet_full_altered.tsv.gz"), "rb") as f:
+                triplets = pd.read_csv(f, sep="\t", compression="gzip", low_memory=False)
+        else:
+            # Keep only triplets whose head and tail both have embeddings
+            triplets = self.primekg_triplets.copy()
+            triplets = triplets[
+                triplets["head_index"].isin(list(self.emb_dict.keys())) &
+                triplets["tail_index"].isin(list(self.emb_dict.keys()))
+            ].reset_index(drop=True)
+
+            # Perform splitting of the triplets
+            list_split = {
+                "train": [],
+                "test": [],
+            }
+            node_split = {
+                "train": {
+                    "node_index": [],
+                    "node_type": [],
+                },
+                "test": {
+                    "node_index": [],
+                    "node_type": [],
+                }
+            }
+            # Loop over the node types
+            for node_type in triplets["head_type"].unique():
+                df_sub = triplets[triplets["head_type"] == node_type]
+                all_x_indexes = df_sub["head_index"].unique()
+                # By default, we use 90% of the nodes for training and 10% for testing
+                te_x_indexes = np.random.choice(
+                    all_x_indexes, size=int(0.1*len(all_x_indexes)), replace=False
+                )
+                df_subs = {}
+                df_subs["test"] = df_sub[df_sub["head_index"].isin(te_x_indexes)]
+                df_subs["train"] = df_sub[~df_sub["head_index"].isin(te_x_indexes)]
+                list_split["train"].append(df_subs["train"])
+                list_split["test"].append(df_subs["test"])
+
+                # record the split
+                node_index = {}
+                node_index["train"] = df_subs["train"]["head_index"].unique()
+                node_split["train"]["node_index"].extend(node_index["train"].tolist())
+                node_split["train"]["node_type"].extend([node_type]*len(node_index["train"]))
+                node_index["test"] = df_subs["test"]["head_index"].unique()
+                node_split["test"]["node_index"].extend(node_index["test"].tolist())
+                node_split["test"]["node_type"].extend([node_type]*len(node_index["test"]))
+
+                print(f"Number of {node_type} nodes in train: {len(node_index['train'])}")
+                print(f"Number of {node_type} nodes in test: {len(node_index['test'])}")
+
+            # Prepare train and test DataFrames
+            df_train = pd.concat(list_split["train"])
+            df_node_train = pd.DataFrame(node_split["train"])
+            df_test = pd.concat(list_split["test"])
+            df_node_test = pd.DataFrame(node_split["test"])
+
+            # Store each dataframe in the local directory
+            df_train.to_csv(os.path.join(self.local_dir, "processed", "triplet_train.tsv.gz"),
+                            sep="\t", compression="gzip", index=False)
+            df_node_train.to_csv(os.path.join(self.local_dir, "processed", "node_train.tsv.gz"),
+                                sep="\t", compression="gzip", index=False)
+            df_test.to_csv(os.path.join(self.local_dir, "processed", "triplet_test.tsv.gz"),
+                           sep="\t", compression="gzip", index=False)
+            df_node_test.to_csv(os.path.join(self.local_dir, "processed", "node_test.tsv.gz"),
+                                sep="\t", compression="gzip", index=False)
+            # Store altered full triplets as well
+            triplets.to_csv(os.path.join(self.local_dir, "processed",
+                                         "triplet_full_altered.tsv.gz"),
+                            sep="\t", compression="gzip", index=False)
+
+        return df_train, df_node_train, df_test, df_node_test, triplets
+
+    # def _negative_sampling(self,
+    #                        batch_df: pd.DataFrame,
+    #                        process_index: int,
+    #                        index_map: dict,
+    #                        node_train_dict: dict) -> pd.DataFrame:
+    #     """
+    #     A helper function to perform negative sampling for a batch of triplets.
+    #     """
+    #     negative_y_index_list = []
+    #     for _, row in tqdm(batch_df.iterrows(),
+    #                        total=batch_df.shape[0],
+    #                        desc=f"Process {process_index}"):
+    #         x_index = row['head_index']
+    #         # y_index = row['y_index']
+    #         y_index_type = row['tail_type']
+    #         paired_y_index_list = index_map[x_index]
+
+    #         # sample a list of negative y_index
+    #         node_train_sub = node_train_dict[y_index_type]
+    #         negative_y_index = node_train_sub[
+    #             ~node_train_sub['node_index'].isin(paired_y_index_list)
+    #         ]['node_index'].sample(self.n_neg_samples).tolist()
+    #         negative_y_index_list.append(negative_y_index)
+
+    #     batch_df.loc[:, 'negative_tail_index'] = negative_y_index_list
+    #     return batch_df
+
+    # def _build_negative_triplets(self,
+    #                              chunk_size: int=100000,
+    #                              n_neg_samples: int=10):
+    #     """
+    #     Build the negative triplets for BioBridgePrimeKG dataset.
+    #     """
+    #     processed_file_path = os.path.join(self.local_dir,
+    #                                        "processed",
+    #                                        "triplet_train_negative.tsv.gz")
+    #     if os.path.exists(processed_file_path):
+    #         # Load the negative triplets from the local directory
+    #         with open(processed_file_path, "rb") as f:
+    #             triplets_negative = pd.read_csv(f, sep="\t", compression="gzip", low_memory=False)
+    #     else:
+    #         # Set the number samples for negative sampling
+    #         self.n_neg_samples = n_neg_samples
+
+    #         # Split node list by type
+    #         node_train_dict = {}
+    #         type_list = self.df_node_train['node_type'].unique()
+    #         for node_type in type_list:
+    #             node_train_dict[node_type] = self.df_node_train[
+    #                 self.df_node_train['node_type'] == node_type
+    #             ].reset_index(drop=True)
+
+    #         # create an index mapping from x_index to y_index
+    #         index_map = self.df_train[
+    #             ['head_index', 'tail_index']
+    #         ].drop_duplicates().groupby('head_index').agg(list).to_dict()['tail_index']
+
+    #         # Negative sampling
+    #         batch_df_list = []
+    #         for i in tqdm(range(0, self.df_train.shape[0], chunk_size)):
+    #             batch_df_list.append(self.df_train.iloc[i:i+chunk_size])
+    #         # Process negative sampling
+    #         results = [
+    #             self._negative_sampling(batch_df,
+    #                                     num_piece,
+    #                                     index_map,
+    #                                     node_train_dict)
+    #                                     for num_piece, batch_df in enumerate(batch_df_list)
+    #         ]
+
+    #         # Store the negative triplets
+    #         triplets_negative = pd.concat(results, axis=0)
+    #         triplets_negative.to_csv(processed_file_path,
+    #                                  sep="\t", compression="gzip", index=False)
+
+    #     # Set attribute
+    #     self.primekg_triplets_negative = triplets_negative
+
+    #     return triplets_negative
+
+    # def load_data(self,
+    #               build_neg_triplets: bool = False,
+    #               chunk_size: int = 100000,
+    #               n_neg_samples: int = 10):
+
+    def load_data(self):
+        """
+        Load the BioBridgePrimeKG dataset into pandas DataFrames of nodes and edges.
+        """
+        # Load PrimeKG dataset
+        print("Loading PrimeKG dataset...")
+        self.primekg = self._load_primekg()
+
+        # Load data config file of BioBridgePrimeKG
+        print("Loading data config file of BioBridgePrimeKG...")
+        self.data_config = self._load_data_config()
+
+        # Build node embeddings
+        print("Building node embeddings...")
+        self.emb_dict = self._build_node_embeddings()
+
+        # Build full triplets
+        print("Building full triplets...")
+        self.primekg_triplets, self.node_info_dict = self._build_full_triplets()
+
+        # Build train-test split
+        print("Building train-test split...")
+        (self.df_train, self.df_node_train, self.df_test, self.df_node_test,
+         self.primekg_triplets) = self._build_train_test_split()
+
+        # if build_neg_triplets:
+        #     # Build negative triplets
+        #     print("Building negative triplets...")
+        #     self.primekg_triplets_negative = self._build_negative_triplets(
+        #         chunk_size=chunk_size,
+        #         n_neg_samples=n_neg_samples
+        #     )
+
+    def set_random_seed(self, seed: int):
+        """
+        Set the random seed for reproducibility.
+
+        Args:
+            seed (int): The random seed value.
+        """
+        np.random.seed(seed)
+
+    def get_primekg(self) -> PrimeKG:
+        """
+        Get the PrimeKG dataset.
+
+        Returns:
+            The PrimeKG dataset.
+        """
+        return self.primekg
+
+    def get_data_config(self) -> dict:
+        """
+        Get the data config file of BioBridgePrimeKG dataset.
+
+        Returns:
+            The data config file of BioBridgePrimeKG dataset.
+        """
+        return self.data_config
+
+    def get_node_embeddings(self) -> dict:
+        """
+        Get the node embeddings for BioBridgePrimeKG dataset.
+
+        Returns:
+            The dictionary of node embeddings.
+        """
+        return self.emb_dict
+
+    def get_primekg_triplets(self) -> pd.DataFrame:
+        """
+        Get the full triplets for BioBridgePrimeKG dataset.
+
+        Returns:
+            The full triplets for BioBridgePrimeKG dataset.
+        """
+        return self.primekg_triplets
+
+    # def get_primekg_triplets_negative(self) -> pd.DataFrame:
+    #     """
+    #     Get the negative triplets for BioBridgePrimeKG dataset.
+
+    #     Returns:
+    #         The negative triplets for BioBridgePrimeKG dataset.
+    #     """
+    #     return self.primekg_triplets_negative
+
+    def get_train_test_split(self) -> dict:
+        """
+        Get the train-test split for BioBridgePrimeKG dataset.
+
+        Returns:
+            The train-test split for BioBridgePrimeKG dataset.
+        """
+        return {
+            "train": self.df_train,
+            "node_train": self.df_node_train,
+            "test": self.df_test,
+            "node_test": self.df_node_test
+        }
+
+    def get_node_info_dict(self) -> dict:
+        """
+        Get the node information dictionary for BioBridgePrimeKG dataset.
+
+        Returns:
+            The node information dictionary for BioBridgePrimeKG dataset.
+        """
+        return self.node_info_dict
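+
+# A minimal usage sketch (paths are illustrative; the first call to load_data()
+# downloads the PrimeKG and BioBridge files, later calls reuse the local copies):
+#
+#     biobridge = BioBridgePrimeKG(primekg_dir="./data/primekg/",
+#                                  local_dir="./data/biobridge_primekg/")
+#     biobridge.load_data()
+#     splits = biobridge.get_train_test_split()
+#     print(splits["train"].shape, splits["node_train"].shape)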