--- /dev/null
+++ b/shepherd/samplers.py
@@ -0,0 +1,690 @@
+import torch
+from torch import Tensor
+from torch.nn.utils.rnn import pad_sequence
+from torch_sparse import SparseTensor
+from torch_cluster import random_walk
+from torch_geometric.data import Data
+from torch_geometric.data.sampler import EdgeIndex, Adj
+from torch_geometric.utils import add_self_loops, add_remaining_self_loops
+
+from typing import List, Optional, Union, Callable, Dict
+from collections import Counter
+import time
+import pickle
+import numpy as np
+from sklearn.preprocessing import label_binarize
+
+from utils.pretrain_utils import get_indices_into_edge_index, HeterogeneousEdgeIndex
+
+import project_config
+
+
+class NeighborSampler(torch.utils.data.DataLoader):
+    r"""The neighbor sampler from the `"Inductive Representation Learning on
+    Large Graphs" <https://arxiv.org/abs/1706.02216>`_ paper, which allows
+    for mini-batch training of GNNs on large-scale graphs where full-batch
+    training is not feasible.
+    Given a GNN with :math:`L` layers and a specific mini-batch of nodes
+    :obj:`node_idx` for which we want to compute embeddings, this module
+    iteratively samples neighbors and constructs bipartite graphs that simulate
+    the actual computation flow of GNNs.
+    More specifically, :obj:`sizes` denotes how many neighbors we want to
+    sample for each node in each layer.
+    This module then takes in these :obj:`sizes` and iteratively samples
+    :obj:`sizes[l]` neighbors for each node involved in layer :obj:`l`.
+    In the next layer, sampling is repeated for the union of nodes that were
+    already encountered.
+    The actual computation graphs are then returned in reverse-mode, meaning
+    that we pass messages from a larger set of nodes to a smaller one, until we
+    reach the nodes for which we originally wanted to compute embeddings.
+    Hence, an item returned by :class:`NeighborSampler` holds the current
+    :obj:`batch_size`, the IDs :obj:`n_id` of all nodes involved in the
+    computation, and a list of bipartite graph objects via the tuple
+    :obj:`(edge_index, e_id, size)`, where :obj:`edge_index` represents the
+    bipartite edges between source and target nodes, :obj:`e_id` denotes the
+    IDs of original edges in the full graph, and :obj:`size` holds the shape
+    of the bipartite graph.
+    For each bipartite graph, target nodes are also included at the beginning
+    of the list of source nodes so that one can easily apply skip-connections
+    or add self-loops.
+    .. note::
+        For an example of using :obj:`NeighborSampler`, see
+        `examples/reddit.py
+        <https://github.com/rusty1s/pytorch_geometric/blob/master/examples/
+        reddit.py>`_ or
+        `examples/ogbn_products_sage.py
+        <https://github.com/rusty1s/pytorch_geometric/blob/master/examples/
+        ogbn_products_sage.py>`_.
+    Args:
+        edge_index (Tensor or SparseTensor): A :obj:`torch.LongTensor` or a
+            :obj:`torch_sparse.SparseTensor` that defines the underlying graph
+            connectivity/message passing flow.
+            :obj:`edge_index` holds the indices of a (sparse) symmetric
+            adjacency matrix.
+            If :obj:`edge_index` is of type :obj:`torch.LongTensor`, its shape
+            must be defined as :obj:`[2, num_edges]`, where messages from nodes
+            :obj:`edge_index[0]` are sent to nodes in :obj:`edge_index[1]`
+            (in case :obj:`flow="source_to_target"`).
+            If :obj:`edge_index` is of type :obj:`torch_sparse.SparseTensor`,
+            its sparse indices :obj:`(row, col)` should relate to
+            :obj:`row = edge_index[1]` and :obj:`col = edge_index[0]`.
+            The major difference between both formats is that we need to input
+            the *transposed* sparse adjacency matrix.
+        sizes ([int]): The number of neighbors to sample for each node in each
+            layer. If set to :obj:`sizes[l] = -1`, all neighbors are included
+            in layer :obj:`l`.
+        node_idx (LongTensor, optional): The nodes that should be considered
+            for creating mini-batches. If set to :obj:`None`, all nodes will be
+            considered.
+        num_nodes (int, optional): The number of nodes in the graph.
+            (default: :obj:`None`)
+        return_e_id (bool, optional): If set to :obj:`False`, will not return
+            original edge indices of sampled edges. This is only useful when
+            operating on graphs without edge features, in order to save memory.
+            (default: :obj:`True`)
+        transform (callable, optional): A function/transform that takes in
+            a sampled mini-batch and returns a transformed version.
+            (default: :obj:`None`)
+        **kwargs (optional): Additional arguments of
+            :class:`torch.utils.data.DataLoader`, such as :obj:`batch_size`,
+            :obj:`shuffle`, :obj:`drop_last` or :obj:`num_workers`.
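+
+    Example (a minimal sketch on a toy graph; the tensors and sizes are
+    illustrative only, and edge filtering is disabled for brevity)::
+
+        edge_index = torch.tensor([[0, 1, 1, 2, 3], [1, 0, 2, 3, 2]])
+        loader = NeighborSampler('train', edge_index, edge_index,
+                                 sizes=[2, 2], do_filter_edges=False,
+                                 batch_size=2, shuffle=True)
+        for batch_size, n_id, adjs in loader:
+            ...  # message passing over the bipartite graphs in `adjs`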
+    """
+    def __init__(self, dataset_type: str, edge_index: Union[Tensor, SparseTensor], 
+                sample_edge_index: Union[Tensor, SparseTensor],
+                 sizes: List[int],
+                 node_idx: Optional[Tensor] = None,
+                 num_nodes: Optional[int] = None, return_e_id: bool = True,
+                 transform: Callable = None,
+                 do_filter_edges: bool = True, 
+                 **kwargs):
+
+        edge_index = edge_index.to('cpu')
+        sample_edge_index = sample_edge_index.to('cpu')
+
+        # add self-loops so that the random walk can fall back to the source
+        # node when it has no edges in the current partition of the dataset
+        sample_edge_index, _ = add_self_loops(sample_edge_index)
+
+        if 'collate_fn' in kwargs:
+            del kwargs['collate_fn']
+
+        # Save for Pytorch Lightning...
+        self.dataset_type = dataset_type
+        self.edge_index = edge_index #always train edge index
+        self.sample_edge_index = sample_edge_index # depends on train/val/test
+        self.node_idx = node_idx
+        self.num_nodes = num_nodes
+
+        self.sizes = sizes
+        self.return_e_id = return_e_id
+        self.transform = transform
+        self.is_sparse_tensor = isinstance(edge_index, SparseTensor)
+        self.__val__ = None
+        self.do_filter_edges = do_filter_edges
+
+        # Obtain a *transposed* `SparseTensor` instance.
+        if not self.is_sparse_tensor:
+            if (num_nodes is None and node_idx is not None
+                    and node_idx.dtype == torch.bool):
+                num_nodes = node_idx.size(0)
+            if (num_nodes is None and node_idx is not None
+                    and node_idx.dtype == torch.long):
+                num_nodes = max(int(edge_index.max()), int(node_idx.max())) + 1
+            if num_nodes is None:
+                num_nodes = int(edge_index.max()) + 1
+                sample_num_nodes = int(sample_edge_index.max()) + 1
+            else:
+                sample_num_nodes = num_nodes # also covers the case where `num_nodes` is passed in explicitly
+
+            value = torch.arange(edge_index.size(1)) if return_e_id else None
+            sample_value = torch.arange(sample_edge_index.size(1)) if return_e_id else None
+            self.adj_t = SparseTensor(row=edge_index[0], col=edge_index[1],
+                                      value=value,
+                                      sparse_sizes=(num_nodes, num_nodes)).t()
+            self.adj_t_sample = SparseTensor(row=sample_edge_index[0], col=sample_edge_index[1],
+                                      value=sample_value,
+                                      sparse_sizes=(sample_num_nodes, sample_num_nodes)).t()
+        else:
+            adj_t = edge_index
+            adj_t_sample = sample_edge_index
+            if return_e_id:
+                self.__val__ = adj_t.storage.value()
+                value = torch.arange(adj_t.nnz())
+                adj_t = adj_t.set_value(value, layout='coo')
+                adj_t_sample = adj_t_sample.set_value(torch.arange(adj_t_sample.nnz()), layout='coo')
+            self.adj_t = adj_t
+            self.adj_t_sample = adj_t_sample
+
+        self.adj_t.storage.rowptr()
+        self.adj_t_sample.storage.rowptr()
+
+        if node_idx is None:
+            node_idx = torch.arange(self.adj_t_sample.sparse_size(0)) 
+        elif node_idx.dtype == torch.bool:
+            node_idx = node_idx.nonzero(as_tuple=False).view(-1)
+
+        super(NeighborSampler, self).__init__(
+            node_idx.view(-1).tolist(), collate_fn=self.sample, **kwargs)
+
+
+    def filter_edges(self, edge_index, e_id, source_nodes, target_nodes):
+        '''
+        Filter the edges we are trying to predict in the current batch out of the edge index.
+        NOTE: `edge_index` here is re-indexed into batch space: source nodes occupy
+        positions [0, n_sources) and target nodes occupy positions
+        [n_sources, n_sources + n_targets), matching the order in which the batch was built.
+        '''
+        reindex_source_nodes = torch.arange(source_nodes.size(0))
+        reindex_target_nodes = torch.arange(start=source_nodes.size(0), end=source_nodes.size(0) + target_nodes.size(0))
+
+        # filter the reverse edges as well
+        all_source_nodes = torch.cat([reindex_source_nodes, reindex_target_nodes])
+        all_target_nodes = torch.cat([reindex_target_nodes, reindex_source_nodes])
+        ind_to_edge_index, _ = get_indices_into_edge_index(edge_index, all_source_nodes, all_target_nodes) # indices into the re-indexed edge index
+        mask = torch.ones(edge_index.size(1), dtype=torch.bool)
+        mask[ind_to_edge_index] = False
+
+        return edge_index[:, mask], e_id[mask]
+
+
+    def sample(self, source_batch):
+
+        # convert to tensor
+        if not isinstance(source_batch, Tensor):
+            source_batch = torch.tensor(source_batch)
+
+        # sample target nodes to form positive edges; we will try to predict these edges
+        row, col, e_id = self.adj_t_sample.coo()
+        target_batch = random_walk(row, col, source_batch, walk_length=1, coalesced=False)[:, 1] # NOTE: random_walk only takes a self-loop step when a node has no edges in the current partition of the dataset
+        batch = torch.cat([source_batch, target_batch], dim=0)
+
+        batch_size: int = len(batch)
+        adjs = []
+        n_id = batch
+        for size in self.sizes:
+            adj_t, n_id = self.adj_t.sample_adj(n_id, size, replace=False) 
+            e_id = adj_t.storage.value()
+            size = adj_t.sparse_sizes()[::-1]
+            if self.__val__ is not None:
+                adj_t.set_value_(self.__val__[e_id], layout='coo')
+
+            if self.is_sparse_tensor: #TODO: implement filter_edges if sparse tensor
+                adjs.append(Adj(adj_t, e_id, size))
+            else:
+                row, col, _ = adj_t.coo()
+                edge_index = torch.stack([col, row], dim=0)
+
+                if self.do_filter_edges and self.dataset_type == 'train':
+                    edge_index, e_id = self.filter_edges(edge_index, e_id, source_batch, target_batch)
+                adjs.append(EdgeIndex(edge_index, e_id, size))
+
+        adjs = adjs[0] if len(adjs) == 1 else adjs[::-1]
+        out = (batch_size, n_id, adjs)
+        out = self.transform(*out) if self.transform is not None else out
+        return out
+
+    def __repr__(self):
+        return '{}(sizes={})'.format(self.__class__.__name__, self.sizes)
+
+class PatientNeighborSampler(torch.utils.data.DataLoader):
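+    r"""Neighbor sampler that builds batches of patients for model training and
+    evaluation.
+
+    For each batch of patients, the collate function (i) gathers the KG node
+    indices of the patients' phenotypes, candidate/correct genes and
+    (optionally) diseases, (ii) optionally up-samples rare candidate genes and
+    adds candidate diseases and "patients like me" candidates, (iii) samples a
+    positive target node for every source node via one-step random walks on
+    :obj:`sample_edge_index`, (iv) extracts the k-hop bipartite computation
+    graphs from the training :obj:`edge_index`, and (v) attaches patient-level
+    information (labels, gene-phenotype shortest path lengths, node names) to
+    the resulting :class:`torch_geometric.data.Data` object.
+    """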
+
+    def __init__(self, dataset_type: str, edge_index: Union[Tensor, SparseTensor], 
+                 sample_edge_index: Union[Tensor, SparseTensor],
+                 sizes: List[int], 
+                 patient_dataset,
+                 all_edge_attributes,
+                 n_nodes: int,
+                 relevant_node_idx = None,
+                 do_filter_edges: Optional[bool] = False,
+                 num_nodes: Optional[int] = None, 
+                 return_e_id: bool = True,
+                 sparse_sample: Optional[int] = 0,
+                 train_phenotype_counter: Dict = None,
+                 train_gene_counter: Dict = None,
+                 sample_edges_from_train_patients=False,
+                 upsample_cand: Optional[int] = 0,
+                 n_cand_diseases=-1,
+                 use_diseases=False,
+                 nid_to_spl_dict=None,
+                 gp_spl=None,
+                 spl_indexing_dict=None,
+                 gene_similarity_dict=None,
+                 gene_deg_dict=None,
+                 hparams=None,
+                 transform: Callable = None, 
+                 **kwargs):
+
+        edge_index = edge_index.to('cpu')
+        sample_edge_index = sample_edge_index.to('cpu')
+
+        # add self-loops (one for every node in the train edge index) so that the
+        # random walk can fall back to the source node when it has no edges here
+        sample_edge_index = torch.cat((sample_edge_index, torch.stack([edge_index.unique(), edge_index.unique()])), dim=1)
+        sample_edge_index, _ = add_remaining_self_loops(sample_edge_index)
+
+        if 'collate_fn' in kwargs:
+            del kwargs['collate_fn']
+
+        # Save for Pytorch Lightning...
+        self.do_filter_edges = do_filter_edges
+        self.relevant_node_idx = relevant_node_idx
+        self.n_nodes = n_nodes
+        self.all_edge_attr = all_edge_attributes
+        self.dataset_type = dataset_type
+        self.sparse_sample = sparse_sample
+        self.edge_index = edge_index #always train edge index
+        self.sample_edge_index = sample_edge_index # depends on train/val/test
+        self.patient_dataset = patient_dataset
+        self.num_nodes = num_nodes
+        self.train_phenotype_counter = train_phenotype_counter
+        self.train_gene_counter = train_gene_counter
+        self.sample_edges_from_train_patients = sample_edges_from_train_patients
+        self.sizes = sizes
+        self.return_e_id = return_e_id
+        self.transform = transform
+        self.is_sparse_tensor = isinstance(edge_index, SparseTensor)
+        self.__val__ = None
+
+        # For SPL
+        self.nid_to_spl_dict = nid_to_spl_dict 
+        self.gp_spl = gp_spl if hparams["alpha"] < 1 else None # SPL features are only used when alpha < 1
+        self.spl_indexing_dict = spl_indexing_dict
+
+        # Up-sample candidate genes
+        self.upsample_cand = upsample_cand
+        with open(str(project_config.KG_DIR / f'ensembl_to_idx_dict_{project_config.CURR_KG}.pkl'), 'rb') as handle:
+            ensembl_to_idx_dict = pickle.load(handle) # ensembl ID -> node idx map
+        idx_to_ensembl_dict = {v: k for k, v in ensembl_to_idx_dict.items()}
+        self.cand_gene_freq = Counter([k for k in nid_to_spl_dict if k in idx_to_ensembl_dict]) # up-sample from all gene nodes in the KG
+        
+        self.n_cand_diseases = n_cand_diseases
+        self.use_diseases = use_diseases
+        self.hparams = hparams
+
+        self.gene_similarity_dict = gene_similarity_dict
+        self.gene_deg_dict = gene_deg_dict
+
+        # Obtain a *transposed* `SparseTensor` instance.
+        if not self.is_sparse_tensor:
+            if num_nodes is None:
+                num_nodes = int(edge_index.max()) + 1
+                sample_num_nodes = int(sample_edge_index.max()) + 1
+            else:
+                sample_num_nodes = num_nodes
+
+            value = torch.arange(edge_index.size(1)) if return_e_id else None
+            sample_value = torch.arange(sample_edge_index.size(1)) if return_e_id else None
+            self.adj_t = SparseTensor(row=edge_index[0], col=edge_index[1],
+                                      value=value,
+                                      sparse_sizes=(num_nodes, num_nodes)).t()
+            self.adj_t_sample = SparseTensor(row=sample_edge_index[0], col=sample_edge_index[1],
+                                      value=sample_value,
+                                      sparse_sizes=(sample_num_nodes, sample_num_nodes)).t()
+        else:
+            adj_t = edge_index
+            adj_t_sample = sample_edge_index
+            if return_e_id:
+                self.__val__ = adj_t.storage.value()
+                value = torch.arange(adj_t.nnz())
+                adj_t = adj_t.set_value(value, layout='coo')
+                adj_t_sample = adj_t_sample.set_value(torch.arange(adj_t_sample.nnz()), layout='coo')
+            self.adj_t = adj_t
+            self.adj_t_sample = adj_t_sample
+
+        self.adj_t.storage.rowptr()
+        self.adj_t_sample.storage.rowptr()
+
+        super(PatientNeighborSampler, self).__init__(
+            self.patient_dataset, collate_fn=self.collate, **kwargs)
+
+    def filter_edges(self, edge_index, e_id, source_nodes, target_nodes):
+        '''
+        Filter the edges we are trying to predict in the current batch out of the edge index.
+        NOTE: `edge_index` here is re-indexed into batch space: source nodes occupy
+        positions [0, n_sources) and target nodes occupy positions
+        [n_sources, n_sources + n_targets), matching the order in which the batch was built.
+        '''
+        reindex_source_nodes = torch.arange(source_nodes.size(0))
+        reindex_target_nodes = torch.arange(start=source_nodes.size(0), end=source_nodes.size(0) + target_nodes.size(0))
+
+        # filter the reverse edges as well
+        all_source_nodes = torch.cat([reindex_source_nodes, reindex_target_nodes])
+        all_target_nodes = torch.cat([reindex_target_nodes, reindex_source_nodes])
+        ind_to_edge_index, _ = get_indices_into_edge_index(edge_index, all_source_nodes, all_target_nodes) # indices into the re-indexed edge index
+        mask = torch.ones(edge_index.size(1), dtype=torch.bool)
+        mask[ind_to_edge_index] = False
+
+        return edge_index[:, mask], e_id[mask]
+
+    def get_source_nodes(self, phenotype_node_idx, candidate_gene_node_idx, correct_genes_node_idx, disease_node_idx, candidate_disease_node_idx, sim_gene_node_idx): 
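+        '''Collect the source nodes for the batch: all patient-associated node indices, plus `sparse_sample` randomly drawn KG nodes.'''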
+
+        # Gather batch node indices from patient phenotypes, genes, and diseases
+        if sim_gene_node_idx is not None:
+            source_batch = torch.cat(phenotype_node_idx + candidate_gene_node_idx + correct_genes_node_idx + disease_node_idx + candidate_disease_node_idx + sim_gene_node_idx)
+        else:
+            source_batch = torch.cat(phenotype_node_idx + candidate_gene_node_idx + correct_genes_node_idx + disease_node_idx + candidate_disease_node_idx)
+
+        # Randomly sample additional nodes in the KG
+        if self.sparse_sample > 0:
+            if self.relevant_node_idx is None:
+                rand_idx = torch.randint(high=self.n_nodes, size=(self.sparse_sample,)) # NOTE: this can sample duplicates, but has the benefit of drawing new nodes each epoch
+            else:
+                rand_idx = self.relevant_node_idx[torch.randint(high=self.relevant_node_idx.size(0), size=(self.sparse_sample,))]
+
+            source_batch = torch.cat([source_batch, rand_idx])
+            source_batch = torch.unique(source_batch)
+            sparse_idx = torch.unique(rand_idx)
+        else:
+            source_batch = torch.unique(source_batch)
+            sparse_idx = torch.Tensor([])
+
+        return source_batch, sparse_idx
+
+    def sample_target_nodes(self, source_batch):
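+        '''Sample one positive target node per source node via a single random-walk step on the sampling graph, optionally preferring edges that connect back to training-set patients.'''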
+        row, col, e_id = self.adj_t_sample.coo()
+
+        if self.sample_edges_from_train_patients:
+            train_patient_nodes = torch.tensor(list(self.train_phenotype_counter.keys()) + list(self.train_gene_counter.keys()))
+            ind_with_train_patient_nodes = (col == train_patient_nodes.unsqueeze(-1)).nonzero(as_tuple=True)[1]
+            subset_row = row[ind_with_train_patient_nodes]
+            subset_col = col[ind_with_train_patient_nodes]
+            try:
+                # first, try to find edges that connect back to training-set patient data
+                # NOTE: random_walk only takes a self-loop step when a node has no edges in the current partition of the dataset
+                targets = random_walk(subset_row, subset_col, source_batch, walk_length=1, coalesced=False)[:, 1]
+                source_batch_1 = source_batch[~torch.eq(source_batch, targets)]
+                targets_1 = targets[~torch.eq(source_batch, targets)]
+
+                # for nodes where no such edge was found (i.e., the walk stayed in place),
+                # fall back to all available edges in this split of the data
+                source_batch_2 = source_batch[torch.eq(source_batch, targets)]
+                targets_2 = random_walk(row, col, source_batch_2, walk_length=1, coalesced=False)[:, 1]
+
+                # concatenate the two together
+                source_batch = torch.cat([source_batch_1, source_batch_2])
+                targets = torch.cat([targets_1, targets_2])
+
+            except Exception:
+                targets = random_walk(row, col, source_batch, walk_length=1, coalesced=False)[:, 1]
+        else:
+            targets = random_walk(row, col, source_batch, walk_length=1, coalesced=False)[:, 1]
+        return source_batch, targets
+
+    def add_patient_information(self, patient_ids, phenotype_node_idx, candidate_gene_node_idx, correct_genes_node_idx, sim_gene_node_idx, gene_sims, gene_degs, disease_node_idx, candidate_disease_node_idx, labels, disease_labels, patient_labels, additional_labels, adjs, batch_size, n_id, sparse_idx, target_batch):
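+        '''Package the sampled subgraph and patient-level information (labels, SPL features for candidate genes, node names, and padded, re-indexed node ID tensors) into a `Data` object.'''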
+
+        # Create Data Object & Add patient level information
+        adjs = [HeterogeneousEdgeIndex(adj.edge_index, adj.e_id, self.all_edge_attr[adj.e_id], adj.size) for adj in adjs] 
+        max_n_candidates = max([len(l) for l in candidate_gene_node_idx])
+        data = Data(adjs = adjs, 
+                batch_size = batch_size,
+                patient_ids = patient_ids,
+                n_id = n_id
+                )
+        if self.hparams['loss'] != 'patient_disease_NCA' and self.hparams['loss'] != 'patient_patient_NCA':
+            if None in list(labels): data['one_hot_labels'] = None
+            else: data['one_hot_labels'] = torch.LongTensor(label_binarize(labels, classes = list(range(max_n_candidates))))
+
+        if self.use_diseases:
+            data['disease_one_hot_labels'] = disease_labels 
+
+        if self.hparams['loss'] == 'patient_patient_NCA':
+            if patient_labels is None: data['patient_labels'] = None
+            else: data['patient_labels'] = torch.stack(patient_labels)
+
+        # Get candidate genes to phenotypes SPL
+        if self.gp_spl is not None:
+            if self.spl_indexing_dict is not None:
+                patient_ids = np.vectorize(self.spl_indexing_dict.get)(patient_ids).astype(int)
+            gene_to_phenotypes_spl = -torch.Tensor(self.gp_spl[patient_ids,:])
+            # get gene idx to spl information
+            cand_gene_idx_to_spl = [torch.LongTensor(np.vectorize(self.nid_to_spl_dict.get)(cand_genes)) for cand_genes in list(candidate_gene_node_idx)]
+            # get SPLs for each patient's candidate genes
+            batch_cand_gene_to_phenotypes_spl = [gene_spls[cand_genes] for cand_genes, gene_spls in zip(cand_gene_idx_to_spl, gene_to_phenotypes_spl)]
+            # pad to same # of candidate genes
+            data['batch_cand_gene_to_phenotypes_spl'] = pad_sequence(batch_cand_gene_to_phenotypes_spl, batch_first=True, padding_value=0)
+            # get unique gene idx across all patients in the batch
+            cand_gene_idx_flattened_unique = torch.unique(torch.cat(cand_gene_idx_to_spl)).flatten()
+            # get SPLs for unique genes in the batch
+            data['batch_concat_cand_gene_to_phenotypes_spl'] = gene_to_phenotypes_spl[:, cand_gene_idx_flattened_unique]
+        else:
+            data['batch_cand_gene_to_phenotypes_spl'] = None
+            data['batch_concat_cand_gene_to_phenotypes_spl'] = None
+
+
+        # Create mapping from KG node IDs to batch indices; all IDs are shifted
+        # by +1 so that 0 can be reserved as the padding index
+        node2batch = {n+1: int(i+1) for i, n in enumerate(data.n_id.tolist())}
+        node2batch[0] = 0
+
+        # add phenotype / gene / disease names
+        data['phenotype_names'] = [[(self.patient_dataset.node_idx_to_name(p.item()), self.patient_dataset.node_idx_to_degree(p.item())) for p in p_list] for p_list in phenotype_node_idx ]
+        data['cand_gene_names'] = [[self.patient_dataset.node_idx_to_name(g.item()) for g in g_list] for g_list in candidate_gene_node_idx ]
+        data['corr_gene_names'] = [[self.patient_dataset.node_idx_to_name(g.item()) for g in g_list] for g_list in correct_genes_node_idx  ]
+        data['disease_names'] = [[self.patient_dataset.node_idx_to_name(d.item()) for d in d_list] for d_list in disease_node_idx ]
+
+        if self.use_diseases:
+            data['cand_disease_names'] = [[self.patient_dataset.node_idx_to_name(d.item()) for d in d_list] for d_list in candidate_disease_node_idx ]
+
+
+        # reindex nodes to make room for padding (shift all node IDs by +1; 0 is the pad value)
+        phenotype_node_idx = [p + 1 for p in phenotype_node_idx]
+        candidate_gene_node_idx = [g + 1 for g in candidate_gene_node_idx]
+        correct_genes_node_idx = [g + 1 for g in correct_genes_node_idx]
+        if self.use_diseases:
+            disease_node_idx = [d + 1 for d in disease_node_idx]
+            candidate_disease_node_idx = [d + 1 for d in candidate_disease_node_idx]
+        if 'augment_genes' in self.hparams and self.hparams['augment_genes']:
+            sim_gene_node_idx = [g + 1 for g in sim_gene_node_idx]
+
+        # if there aren't any disease idx in the batch, we add filler
+        if self.use_diseases:
+            if all(len(t) == 0 for t in disease_node_idx):
+                disease_node_idx = [torch.LongTensor([0]) for i in range(len(disease_node_idx))]
+            if all(len(t) == 0 for t in candidate_disease_node_idx):
+                candidate_disease_node_idx = [torch.LongTensor([0]) for i in range(len(candidate_disease_node_idx))]
+
+        # add padding to patient phenotype and gene node idx
+        data['batch_pheno_nid'] = pad_sequence(phenotype_node_idx, batch_first=True, padding_value=0) 
+        if len(candidate_gene_node_idx[0]) > 0:
+            data['batch_cand_gene_nid'] = pad_sequence(candidate_gene_node_idx, batch_first=True, padding_value=0) 
+        data['batch_corr_gene_nid'] = pad_sequence(correct_genes_node_idx, batch_first=True, padding_value=0) 
+        if self.use_diseases:
+            data['batch_disease_nid'] = pad_sequence(disease_node_idx, batch_first=True, padding_value=0) 
+            data['batch_cand_disease_nid'] = pad_sequence(candidate_disease_node_idx, batch_first=True, padding_value=0) 
+        if 'augment_genes' in self.hparams and self.hparams['augment_genes']:
+            data['batch_cand_gene_degs'] = pad_sequence(gene_degs, batch_first=True, padding_value=0) 
+            data['batch_sim_gene_nid'] = pad_sequence(sim_gene_node_idx, batch_first=True, padding_value=0) 
+            data['batch_sim_gene_sims'] = pad_sequence(gene_sims, batch_first=True, padding_value=0)
+            # Normalize
+            data['batch_sim_gene_sims'] = data['batch_sim_gene_sims'] / torch.sum(data['batch_sim_gene_sims'], dim=1, keepdim=True)
+
+        # Convert KG node IDs to batch IDs
+        # When performing inference (i.e., predict.py), use the original node IDs because the full KG is used in forward pass of node model
+        if self.dataset_type != "predict":
+            data['batch_pheno_nid']  = torch.LongTensor(np.vectorize(node2batch.get)(data['batch_pheno_nid']))
+            if len(candidate_gene_node_idx[0]) > 0:
+                data['batch_cand_gene_nid'] = torch.LongTensor(np.vectorize(node2batch.get)(data['batch_cand_gene_nid']))
+            if len(correct_genes_node_idx[0]) > 0:
+                data['batch_corr_gene_nid'] = torch.LongTensor(np.vectorize(node2batch.get)(data['batch_corr_gene_nid']))
+            if self.use_diseases:
+                data['batch_disease_nid'] = torch.LongTensor(np.vectorize(node2batch.get)(data['batch_disease_nid']))
+                data['batch_cand_disease_nid'] = torch.LongTensor(np.vectorize(node2batch.get)(data['batch_cand_disease_nid']))
+            if 'augment_genes' in self.hparams and self.hparams['augment_genes']:
+                data['batch_sim_gene_nid'] = torch.LongTensor(np.vectorize(node2batch.get)(data['batch_sim_gene_nid']))
+        return data
+
+    def get_candidate_diseases(self, disease_node_idx, candidate_gene_node_idx):
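+        '''Sample candidate (distractor) diseases and build the binary label matrix marking each patient's true diseases among the candidates.'''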
+        cand_diseases = self.patient_dataset.get_candidate_diseases(cand_type=self.hparams['candidate_disease_type'])
+        if self.n_cand_diseases != -1: cand_diseases = cand_diseases[torch.randperm(len(cand_diseases))][0:self.n_cand_diseases] 
+        
+        if self.hparams['only_hard_distractors']: # add every candidate disease to every patient
+            candidate_disease_node_idx = tuple(torch.unique(torch.cat([corr_dis, cand_diseases]), sorted=False) for corr_dis in disease_node_idx)
+            candidate_disease_node_idx = tuple(torch.unique(dis[torch.randperm(len(dis))], sorted=False) for dis in candidate_disease_node_idx)
+        else: # split candidates across all patients in the batch
+            all_correct_diseases = torch.cat(disease_node_idx)
+            all_diseases = torch.unique(torch.cat([all_correct_diseases, cand_diseases]))
+            all_diseases = all_diseases[torch.randperm(len(all_diseases))]
+            candidate_disease_node_idx = np.array_split(all_diseases, len(candidate_gene_node_idx))
+            candidate_disease_node_idx = tuple(candidate_disease_node_idx)
+        max_n_dis_candidates = max([len(l) for l in candidate_disease_node_idx])
+        if max_n_dis_candidates == 0: 
+            max_n_dis_candidates = 1
+            print('WARNING: there are no disease candidates')
+
+        disease_ind = [(dis.unsqueeze(1) == corr_dis.unsqueeze(0)).nonzero(as_tuple=True)[0] if len(corr_dis) > 0 else torch.tensor(-1) for dis, corr_dis in zip(candidate_disease_node_idx, disease_node_idx)]
+        disease_labels = torch.zeros((len(candidate_disease_node_idx), max_n_dis_candidates))
+        for i, ind in enumerate(disease_ind): disease_labels[i,ind[ind != -1]] = 1
+        return candidate_disease_node_idx, disease_labels
+
+    def get_candidate_patients(self, patient_ids):
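+        '''Retrieve "patients like me" candidates for each patient in the batch, skipping patients that are already present in the batch.'''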
+        # get patients with the same disease/gene
+        similar_pat_ids = [self.patient_dataset.get_similar_patients(p_id, similarity_type=self.hparams['patient_similarity_type']) for p_id in patient_ids]
+        # subset to the first n_similar_patients per patient in the batch
+        # (shuffling via torch.randperm is currently disabled)
+        similar_pat_ids = [p[:self.hparams['n_similar_patients']] for p in similar_pat_ids]
+        # Retrieve the patients for each of the sampled patient ids if they aren't already in the batch
+        patient_ids = list(patient_ids) 
+        similar_pats = [self.patient_dataset[self.patient_dataset.patient_id_to_index[p_id.item()]] for p_ids in similar_pat_ids for p_id in p_ids if p_id.item() not in patient_ids]
+        return similar_pats
+    
+    def sample(self, batch, source_batch, target_batch):
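+        '''Sample the k-hop bipartite computation graphs (`adjs`) for the batch; during training, the positive edges being predicted are filtered out.'''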
+        batch_size: int = len(batch)
+        adjs = []
+        n_id = batch
+        for size in self.sizes:
+            adj_t, n_id = self.adj_t.sample_adj(n_id, size, replace=False)
+            e_id = adj_t.storage.value()
+            size = adj_t.sparse_sizes()[::-1]
+            if self.__val__ is not None:
+                adj_t.set_value_(self.__val__[e_id], layout='coo')
+
+            if self.is_sparse_tensor: #TODO: implement filter_edges if sparse tensor
+                adjs.append(Adj(adj_t, e_id, size))
+            else:
+                row, col, _ = adj_t.coo()
+                edge_index = torch.stack([col, row], dim=0)
+                if self.do_filter_edges and self.dataset_type == 'train':
+                    edge_index, e_id = self.filter_edges(edge_index, e_id, source_batch, target_batch)
+                adjs.append(EdgeIndex(edge_index, e_id, size))
+
+        adjs = [adjs[0]] if len(adjs) == 1 else adjs[::-1]
+        return adjs, batch_size, n_id
+    
+    def get_similar_genes(self, patient_ids, candidate_gene_node_idx):
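+        '''Look up the top-`n_sim_genes` most similar genes (with similarity scores and degrees) for each patient's candidate genes from precomputed dictionaries.'''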
+        k = self.hparams['n_sim_genes']
+        gene_ids = []
+        sims = []
+        degs = []
+        assert len(patient_ids) == len(candidate_gene_node_idx)
+        for p, p_cand_genes in zip(patient_ids, candidate_gene_node_idx):
+            p_genes = []
+            p_sims = []
+            p_degs = []
+            for g in p_cand_genes:
+                top_k = list(self.gene_similarity_dict[int(g)])[:k] # top-k most similar genes as (node idx, similarity score) pairs
+                p_genes.append(torch.LongTensor([idx for idx, sim in top_k]))
+                p_sims.append(torch.FloatTensor([sim for idx, sim in top_k])) # NOTE: similarity scores are assumed to be fractional, so store them as floats rather than truncating to ints
+                p_degs.append(self.gene_deg_dict[int(g)])
+            gene_ids.append(torch.stack(p_genes))
+            sims.append(torch.stack(p_sims))
+            degs.append(torch.LongTensor(p_degs))
+        assert len(gene_ids) == len(patient_ids)
+        assert len(sims) == len(patient_ids)
+        unique_genes = torch.unique(torch.cat(gene_ids).flatten()).unsqueeze(-1)
+        return tuple(gene_ids), tuple(sims), tuple(degs), tuple(unique_genes)
+
+    def collate(self, batch):
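+        '''Collate a list of patients into a single batched `Data` object: up-sample rare candidate genes, add similar patients and candidate diseases (if configured), sample positive edges, extract the k-hop subgraph, and attach patient-level information.'''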
+        t00 = time.time()
+        phenotype_node_idx, candidate_gene_node_idx, correct_genes_node_idx, disease_node_idx, labels, additional_labels, patient_ids = zip(*batch)
+
+        # Up-sample under-represented candidate genes
+        t0 = time.time()
+        if self.upsample_cand > 0:
+            curr_cand_gene_freq = Counter(torch.cat(candidate_gene_node_idx).flatten().tolist())
+            self.cand_gene_freq += curr_cand_gene_freq
+            num_patients = len(candidate_gene_node_idx) * self.upsample_cand
+            lowest_k_cand = self.cand_gene_freq.most_common()[:-num_patients-1:-1] # the `num_patients` least frequent candidate genes seen so far
+            lowest_k_cand = np.array_split([g[0] for g in lowest_k_cand], len(candidate_gene_node_idx))
+            
+            upsampled_candidate_gene_node_idx = []
+            added_cand_gene = []
+            for patient, cand_gene, corr_gene_idx in zip(candidate_gene_node_idx, lowest_k_cand, labels):
+                
+                # Remove correct genes from list of upsampled candidate genes
+                corr_gene_nid = patient[corr_gene_idx]
+                cand_gene = cand_gene[~np.isin(cand_gene, corr_gene_nid)].flatten()
+                
+                # Remove duplicates, i.e. up-sampled genes that already appear among the patient's candidates
+                unique_cand_genes, new_cand_genes_freq = torch.unique(torch.tensor(patient.tolist() + list(cand_gene)), return_counts = True)
+                unique_cand_genes = unique_cand_genes[new_cand_genes_freq == 1]
+                cand_gene = cand_gene[np.isin(cand_gene, unique_cand_genes)]                
+                
+                # Add upsampled candidate genes
+                added_cand_gene.extend(list(cand_gene))
+                new_cand_list = torch.tensor(patient.tolist() + list(cand_gene))
+                upsampled_candidate_gene_node_idx.append(new_cand_list)
+            
+            candidate_gene_node_idx = tuple(upsampled_candidate_gene_node_idx)
+            self.cand_gene_freq += Counter(added_cand_gene)
+
+        
+        # Add similar patients to batch (for "patients like me" head)
+        if self.hparams['add_similar_patients']:
+            similar_pats = self.get_candidate_patients(patient_ids)
+            # merge original batch with sampled patients
+            phenotype_node_idx_sim, candidate_gene_node_idx_sim, correct_genes_node_idx_sim, disease_node_idx_sim, labels_sim, additional_labels_sim, patient_ids_sim = zip(*similar_pats)
+            phenotype_node_idx = phenotype_node_idx + phenotype_node_idx_sim
+            candidate_gene_node_idx = candidate_gene_node_idx + candidate_gene_node_idx_sim
+            correct_genes_node_idx = correct_genes_node_idx + correct_genes_node_idx_sim
+            disease_node_idx = disease_node_idx + disease_node_idx_sim
+            labels = labels + labels_sim
+            additional_labels = additional_labels + additional_labels_sim
+            patient_ids = patient_ids + patient_ids_sim
+        
+        # the patients' correct genes serve as patient labels (used for the patient-patient NCA loss)
+        patient_labels = correct_genes_node_idx
+        
+        # Add candidate diseases to batch
+        if self.hparams['add_cand_diseases']:
+            candidate_disease_node_idx, disease_labels = self.get_candidate_diseases(disease_node_idx, candidate_gene_node_idx)
+        else: 
+            candidate_disease_node_idx = disease_node_idx
+            disease_labels = torch.tensor([1] * len(candidate_disease_node_idx))
+
+        if self.hparams['augment_genes']:
+            sim_gene_node_idx, gene_sims, gene_degs, unique_sim_genes = self.get_similar_genes(patient_ids, candidate_gene_node_idx)
+        else:
+            unique_sim_genes = gene_degs = gene_sims = sim_gene_node_idx = None
+
+        t1 = time.time()
+
+        # get nodes from patients + randomly sampled nodes
+        source_batch, sparse_idx = self.get_source_nodes(phenotype_node_idx, candidate_gene_node_idx, correct_genes_node_idx, disease_node_idx, candidate_disease_node_idx, unique_sim_genes)
+       
+        # sample nodes to form positive edges
+        source_batch, target_batch = self.sample_target_nodes(source_batch) 
+        batch = torch.cat([source_batch, target_batch], dim=0) 
+        t2 = time.time()
+
+        # get k hop adj graph
+        adjs, batch_size, n_id = self.sample(batch, source_batch, target_batch)
+        t3 = time.time()
+
+        # add patient information to data object
+        data = self.add_patient_information(patient_ids, phenotype_node_idx, candidate_gene_node_idx, correct_genes_node_idx, sim_gene_node_idx, gene_sims, gene_degs, disease_node_idx, candidate_disease_node_idx, labels, disease_labels, patient_labels, additional_labels, adjs, batch_size, n_id, sparse_idx, target_batch) #candidate_disease_node_idx
+        t4 = time.time()
+        
+        if self.hparams['time']:
+            print(f'It takes {t0-t00:0.4f}s to unzip batch, {t1-t0:0.4f}s to upsample candidate gene nodes, {t2-t1:0.4f}s to sample positive nodes, {t3-t2:0.4f}s to get k-hop adjs, and {t4-t3:0.4f}s to add patient information')
+        return data
+
+    def __repr__(self):
+        return '{}(sizes={})'.format(self.__class__.__name__, self.sizes)
+
+
+