Diff of /shepherd/samplers.py [000000] .. [db6163]

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch_sparse import SparseTensor
from torch_cluster import random_walk
from torch_geometric.data.sampler import EdgeIndex, Adj
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset
from torch_geometric.utils import add_self_loops, add_remaining_self_loops
from torch_geometric.data import Data, DataLoader  # NOTE: torch_geometric's NeighborSampler is deliberately not imported; a custom one is defined below

from typing import List, Optional, Tuple, NamedTuple, Union, Callable, Dict
from collections import defaultdict, Counter
import time
import random
import pickle
from operator import itemgetter
import copy
import numpy as np
from utils.pretrain_utils import get_indices_into_edge_index, HeterogeneousEdgeIndex
from sklearn.preprocessing import label_binarize

import project_config


class NeighborSampler(torch.utils.data.DataLoader):
    r"""The neighbor sampler from the `"Inductive Representation Learning on
    Large Graphs" <https://arxiv.org/abs/1706.02216>`_ paper, which allows
    for mini-batch training of GNNs on large-scale graphs where full-batch
    training is not feasible.

    Given a GNN with :math:`L` layers and a specific mini-batch of nodes
    :obj:`node_idx` for which we want to compute embeddings, this module
    iteratively samples neighbors and constructs bipartite graphs that simulate
    the actual computation flow of GNNs.

    More specifically, :obj:`sizes` denotes how many neighbors we want to
    sample for each node in each layer.
    This module then takes in these :obj:`sizes` and iteratively samples
    :obj:`sizes[l]` neighbors for each node involved in layer :obj:`l`.
    In the next layer, sampling is repeated for the union of nodes that were
    already encountered.
    The actual computation graphs are then returned in reverse-mode, meaning
    that we pass messages from a larger set of nodes to a smaller one, until we
    reach the nodes for which we originally wanted to compute embeddings.

    Hence, an item returned by :class:`NeighborSampler` holds the current
    :obj:`batch_size`, the IDs :obj:`n_id` of all nodes involved in the
    computation, and a list of bipartite graph objects via the tuple
    :obj:`(edge_index, e_id, size)`, where :obj:`edge_index` represents the
    bipartite edges between source and target nodes, :obj:`e_id` denotes the
    IDs of the original edges in the full graph, and :obj:`size` holds the
    shape of the bipartite graph.
    For each bipartite graph, target nodes are also included at the beginning
    of the list of source nodes so that one can easily apply skip-connections
    or add self-loops.

    .. note::

        For an example of using :obj:`NeighborSampler`, see
        `examples/reddit.py
        <https://github.com/rusty1s/pytorch_geometric/blob/master/examples/
        reddit.py>`_ or
        `examples/ogbn_products_sage.py
        <https://github.com/rusty1s/pytorch_geometric/blob/master/examples/
        ogbn_products_sage.py>`_.

    Args:
        edge_index (Tensor or SparseTensor): A :obj:`torch.LongTensor` or a
            :obj:`torch_sparse.SparseTensor` that defines the underlying graph
            connectivity/message passing flow.
            :obj:`edge_index` holds the indices of a (sparse) symmetric
            adjacency matrix.
            If :obj:`edge_index` is of type :obj:`torch.LongTensor`, its shape
            must be defined as :obj:`[2, num_edges]`, where messages from nodes
            :obj:`edge_index[0]` are sent to nodes in :obj:`edge_index[1]`
            (in case :obj:`flow="source_to_target"`).
            If :obj:`edge_index` is of type :obj:`torch_sparse.SparseTensor`,
            its sparse indices :obj:`(row, col)` should relate to
            :obj:`row = edge_index[1]` and :obj:`col = edge_index[0]`.
            The major difference between both formats is that we need to input
            the *transposed* sparse adjacency matrix.
        sizes ([int]): The number of neighbors to sample for each node in each
            layer. If set to :obj:`sizes[l] = -1`, all neighbors are included
            in layer :obj:`l`.
        node_idx (LongTensor, optional): The nodes that should be considered
            for creating mini-batches. If set to :obj:`None`, all nodes will be
            considered.
        num_nodes (int, optional): The number of nodes in the graph.
            (default: :obj:`None`)
        return_e_id (bool, optional): If set to :obj:`False`, will not return
            the original edge indices of sampled edges. This is only useful
            when operating on graphs without edge features, in order to save
            memory. (default: :obj:`True`)
        transform (callable, optional): A function/transform that takes in
            a sampled mini-batch and returns a transformed version.
            (default: :obj:`None`)
        **kwargs (optional): Additional arguments of
            :class:`torch.utils.data.DataLoader`, such as :obj:`batch_size`,
            :obj:`shuffle`, :obj:`drop_last` or :obj:`num_workers`.
    """
    def __init__(self, dataset_type: str, edge_index: Union[Tensor, SparseTensor],
                 sample_edge_index: Union[Tensor, SparseTensor],
                 sizes: List[int],
                 node_idx: Optional[Tensor] = None,
                 num_nodes: Optional[int] = None, return_e_id: bool = True,
                 transform: Callable = None,
                 do_filter_edges: bool = True,
                 **kwargs):

        edge_index = edge_index.to('cpu')
        sample_edge_index = sample_edge_index.to('cpu')

        # add self loops
        sample_edge_index, _ = add_self_loops(sample_edge_index)

        if 'collate_fn' in kwargs:
            del kwargs['collate_fn']

        # Save for PyTorch Lightning...
        self.dataset_type = dataset_type
        self.edge_index = edge_index  # always the train edge index
        self.sample_edge_index = sample_edge_index  # depends on train/val/test
        self.node_idx = node_idx
        self.num_nodes = num_nodes

        self.sizes = sizes
        self.return_e_id = return_e_id
        self.transform = transform
        self.is_sparse_tensor = isinstance(edge_index, SparseTensor)
        self.__val__ = None
        self.do_filter_edges = do_filter_edges

        # Obtain a *transposed* `SparseTensor` instance.
        if not self.is_sparse_tensor:
            # ensure sample_num_nodes is defined even when num_nodes is provided explicitly
            sample_num_nodes = num_nodes
            if (num_nodes is None and node_idx is not None
                    and node_idx.dtype == torch.bool):
                num_nodes = node_idx.size(0)
                sample_num_nodes = num_nodes
            if (num_nodes is None and node_idx is not None
                    and node_idx.dtype == torch.long):
                num_nodes = max(int(edge_index.max()), int(node_idx.max())) + 1
                sample_num_nodes = num_nodes
            if num_nodes is None:
                num_nodes = int(edge_index.max()) + 1
                sample_num_nodes = int(sample_edge_index.max()) + 1

            value = torch.arange(edge_index.size(1)) if return_e_id else None
            sample_value = torch.arange(sample_edge_index.size(1)) if return_e_id else None
            self.adj_t = SparseTensor(row=edge_index[0], col=edge_index[1],
                                      value=value,
                                      sparse_sizes=(num_nodes, num_nodes)).t()
            self.adj_t_sample = SparseTensor(row=sample_edge_index[0], col=sample_edge_index[1],
                                             value=sample_value,
                                             sparse_sizes=(sample_num_nodes, sample_num_nodes)).t()
        else:
            adj_t = edge_index
            adj_t_sample = sample_edge_index
            if return_e_id:
                self.__val__ = adj_t.storage.value()
                value = torch.arange(adj_t.nnz())
                adj_t = adj_t.set_value(value, layout='coo')
                adj_t_sample = adj_t_sample.set_value(torch.arange(adj_t_sample.nnz()), layout='coo')
            self.adj_t = adj_t
            self.adj_t_sample = adj_t_sample

        self.adj_t.storage.rowptr()
        self.adj_t_sample.storage.rowptr()

        if node_idx is None:
            node_idx = torch.arange(self.adj_t_sample.sparse_size(0))
        elif node_idx.dtype == torch.bool:
            node_idx = node_idx.nonzero(as_tuple=False).view(-1)

        super(NeighborSampler, self).__init__(
            node_idx.view(-1).tolist(), collate_fn=self.sample, **kwargs)

    def filter_edges(self, edge_index, e_id, source_nodes, target_nodes):
        '''
        Filter out the edges we're trying to predict in the current batch
        from the edge index.
        NOTE: edge_index here is re-indexed to batch-local node IDs.
        '''
        reindex_source_nodes = torch.arange(source_nodes.size(0))
        reindex_target_nodes = torch.arange(start=source_nodes.size(0), end=source_nodes.size(0) + target_nodes.size(0))

        # filter the reverse edges as well
        all_source_nodes = torch.cat([reindex_source_nodes, reindex_target_nodes])
        all_target_nodes = torch.cat([reindex_target_nodes, reindex_source_nodes])
        ind_to_edge_index, ind_to_nodes = get_indices_into_edge_index(edge_index, all_source_nodes, all_target_nodes)  # index into the original edge index (this returns e_ids)
        mask = torch.ones(edge_index.size(1), dtype=torch.bool)
        mask[ind_to_edge_index] = False

        return edge_index[:, mask], e_id[mask]
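
    # Illustrative sketch (toy values, not executed; assumes
    # get_indices_into_edge_index matches the aligned (source, target) pairs):
    # with two source nodes and two target nodes, the batch-local IDs are
    # sources [0, 1] and targets [2, 3], so the positive pairs to hide are
    # (0, 2) and (1, 3), plus their reverses.
    #
    #   edge_index = torch.tensor([[0, 2, 1], [2, 0, 2]])  # columns: (0,2), (2,0), (1,2)
    #   e_id = torch.tensor([10, 11, 12])
    #   # filter_edges would drop columns 0 and 1 (a positive edge and its
    #   # reverse) and keep column 2, since (1, 2) is not a positive pair
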
    def sample(self, source_batch):

        # convert to tensor
        if not isinstance(source_batch, Tensor):
            source_batch = torch.tensor(source_batch)

        # sample nodes to form positive edges; we will try to predict these edges
        row, col, e_id = self.adj_t_sample.coo()
        target_batch = random_walk(row, col, source_batch, walk_length=1, coalesced=False)[:, 1]  # NOTE: the walk only returns self-loops when a node has no edges in the current partition of the dataset
        batch = torch.cat([source_batch, target_batch], dim=0)

        batch_size: int = len(batch)
        adjs = []
        n_id = batch
        for size in self.sizes:
            adj_t, n_id = self.adj_t.sample_adj(n_id, size, replace=False)
            e_id = adj_t.storage.value()
            size = adj_t.sparse_sizes()[::-1]
            if self.__val__ is not None:
                adj_t.set_value_(self.__val__[e_id], layout='coo')

            if self.is_sparse_tensor:  # TODO: implement filter_edges for the SparseTensor case
                adjs.append(Adj(adj_t, e_id, size))
            else:
                row, col, _ = adj_t.coo()
                edge_index = torch.stack([col, row], dim=0)

                if self.do_filter_edges and self.dataset_type == 'train':
                    edge_index, e_id = self.filter_edges(edge_index, e_id, source_batch, target_batch)
                adjs.append(EdgeIndex(edge_index, e_id, size))

        adjs = adjs[0] if len(adjs) == 1 else adjs[::-1]
        out = (batch_size, n_id, adjs)
        out = self.transform(*out) if self.transform is not None else out
        return out

    def __repr__(self):
        return '{}(sizes={})'.format(self.__class__.__name__, self.sizes)

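# Illustrative sketch (toy values, not executed): how this NeighborSampler is
# typically driven. The graph and sizes below are made-up for demonstration.
#
#   toy_edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]])
#   loader = NeighborSampler('train', toy_edge_index, toy_edge_index,
#                            sizes=[2, 2], batch_size=2, shuffle=True)
#   for batch_size, n_id, adjs in loader:
#       # `n_id[:batch_size]` are the source nodes followed by their sampled
#       # positive targets; `adjs` holds one (edge_index, e_id, size) triple
#       # per layer, ordered from the outermost hop inward.
#       pass

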
class PatientNeighborSampler(torch.utils.data.DataLoader):

    def __init__(self, dataset_type: str, edge_index: Union[Tensor, SparseTensor],
                 sample_edge_index: Union[Tensor, SparseTensor],
                 sizes: List[int],
                 patient_dataset,
                 all_edge_attributes,
                 n_nodes: int,
                 relevant_node_idx=None,
                 do_filter_edges: Optional[bool] = False,
                 num_nodes: Optional[int] = None,
                 return_e_id: bool = True,
                 sparse_sample: Optional[int] = 0,
                 train_phenotype_counter: Dict = None,
                 train_gene_counter: Dict = None,
                 sample_edges_from_train_patients=False,
                 upsample_cand: Optional[int] = 0,
                 n_cand_diseases=-1,
                 use_diseases=False,
                 nid_to_spl_dict=None,
                 gp_spl=None,
                 spl_indexing_dict=None,
                 gene_similarity_dict=None,
                 gene_deg_dict=None,
                 hparams=None,
                 transform: Callable = None,
                 **kwargs):

        edge_index = edge_index.to('cpu')
        sample_edge_index = sample_edge_index.to('cpu')

        # add self loops
        sample_edge_index = torch.cat((sample_edge_index, torch.stack([edge_index.unique(), edge_index.unique()])), 1)
        sample_edge_index, _ = add_remaining_self_loops(sample_edge_index)

        if 'collate_fn' in kwargs:
            del kwargs['collate_fn']

        # Save for PyTorch Lightning...
        self.do_filter_edges = do_filter_edges
        self.relevant_node_idx = relevant_node_idx
        self.n_nodes = n_nodes
        self.all_edge_attr = all_edge_attributes
        self.dataset_type = dataset_type
        self.sparse_sample = sparse_sample
        self.edge_index = edge_index  # always the train edge index
        self.sample_edge_index = sample_edge_index  # depends on train/val/test
        self.patient_dataset = patient_dataset
        self.num_nodes = num_nodes
        self.train_phenotype_counter = train_phenotype_counter
        self.train_gene_counter = train_gene_counter
        self.sample_edges_from_train_patients = sample_edges_from_train_patients
        self.sizes = sizes
        self.return_e_id = return_e_id
        self.transform = transform
        self.is_sparse_tensor = isinstance(edge_index, SparseTensor)
        self.__val__ = None

        # For SPL (shortest path length) features
        self.nid_to_spl_dict = nid_to_spl_dict
        if hparams["alpha"] < 1:
            self.gp_spl = gp_spl
        else:
            self.gp_spl = None
        self.spl_indexing_dict = spl_indexing_dict

        # Up-sample candidate genes
        self.upsample_cand = upsample_cand
        with open(str(project_config.KG_DIR / f'ensembl_to_idx_dict_{project_config.CURR_KG}.pkl'), 'rb') as handle:
            ensembl_to_idx_dict = pickle.load(handle)  # ensembl ID -> node_idx map
        idx_to_ensembl_dict = {v: k for k, v in ensembl_to_idx_dict.items()}
        self.cand_gene_freq = Counter([k for k in nid_to_spl_dict if k in idx_to_ensembl_dict])  # up-sample from all gene nodes in the KG

        self.n_cand_diseases = n_cand_diseases
        self.use_diseases = use_diseases
        self.hparams = hparams

        self.gene_similarity_dict = gene_similarity_dict
        self.gene_deg_dict = gene_deg_dict

        # Obtain a *transposed* `SparseTensor` instance.
        if not self.is_sparse_tensor:
            if num_nodes is None:
                num_nodes = int(edge_index.max()) + 1
                sample_num_nodes = int(sample_edge_index.max()) + 1
            else:
                sample_num_nodes = num_nodes  # ensure this is defined when num_nodes is provided

            value = torch.arange(edge_index.size(1)) if return_e_id else None
            sample_value = torch.arange(sample_edge_index.size(1)) if return_e_id else None
            self.adj_t = SparseTensor(row=edge_index[0], col=edge_index[1],
                                      value=value,
                                      sparse_sizes=(num_nodes, num_nodes)).t()
            self.adj_t_sample = SparseTensor(row=sample_edge_index[0], col=sample_edge_index[1],
                                             value=sample_value,
                                             sparse_sizes=(sample_num_nodes, sample_num_nodes)).t()
        else:
            adj_t = edge_index
            adj_t_sample = sample_edge_index
            if return_e_id:
                self.__val__ = adj_t.storage.value()
                value = torch.arange(adj_t.nnz())
                adj_t = adj_t.set_value(value, layout='coo')
                adj_t_sample = adj_t_sample.set_value(torch.arange(adj_t_sample.nnz()), layout='coo')
            self.adj_t = adj_t
            self.adj_t_sample = adj_t_sample

        self.adj_t.storage.rowptr()
        self.adj_t_sample.storage.rowptr()

        super(PatientNeighborSampler, self).__init__(
            self.patient_dataset, collate_fn=self.collate, **kwargs)

    def filter_edges(self, edge_index, e_id, source_nodes, target_nodes):
        '''
        Filter out the edges we're trying to predict in the current batch
        from the edge index.
        NOTE: edge_index here is re-indexed to batch-local node IDs.
        '''
        reindex_source_nodes = torch.arange(source_nodes.size(0))
        reindex_target_nodes = torch.arange(start=source_nodes.size(0), end=source_nodes.size(0) + target_nodes.size(0))

        # filter the reverse edges as well
        all_source_nodes = torch.cat([reindex_source_nodes, reindex_target_nodes])
        all_target_nodes = torch.cat([reindex_target_nodes, reindex_source_nodes])
        ind_to_edge_index, ind_to_nodes = get_indices_into_edge_index(edge_index, all_source_nodes, all_target_nodes)  # index into the original edge index (this returns e_ids)
        mask = torch.ones(edge_index.size(1), dtype=torch.bool)
        mask[ind_to_edge_index] = False

        return edge_index[:, mask], e_id[mask]

    def get_source_nodes(self, phenotype_node_idx, candidate_gene_node_idx, correct_genes_node_idx, disease_node_idx, candidate_disease_node_idx, sim_gene_node_idx):

        # Get batch node indices based on patient phenotypes and genes
        if sim_gene_node_idx is not None:
            source_batch = torch.cat(phenotype_node_idx + candidate_gene_node_idx + correct_genes_node_idx + disease_node_idx + candidate_disease_node_idx + sim_gene_node_idx)
        else:
            source_batch = torch.cat(phenotype_node_idx + candidate_gene_node_idx + correct_genes_node_idx + disease_node_idx + candidate_disease_node_idx)

        # Randomly sample nodes in the KG
        if self.sparse_sample > 0:
            if self.relevant_node_idx is None:
                rand_idx = torch.randint(high=self.n_nodes, size=(self.sparse_sample,))  # NOTE: this can sample duplicates, but has the benefit of randomly sampling new nodes each epoch
            else:
                rand_idx = self.relevant_node_idx[torch.randint(high=self.relevant_node_idx.size(0), size=(self.sparse_sample,))]

            source_batch = torch.cat([source_batch, rand_idx])
            source_batch = torch.unique(source_batch)
            sparse_idx = torch.unique(rand_idx)
        else:
            source_batch = torch.unique(source_batch)
            sparse_idx = torch.Tensor([])

        return source_batch, sparse_idx

    def sample_target_nodes(self, source_batch):
        row, col, e_id = self.adj_t_sample.coo()

        if self.sample_edges_from_train_patients:
            train_patient_nodes = torch.tensor(list(self.train_phenotype_counter.keys()) + list(self.train_gene_counter.keys()))
            ind_with_train_patient_nodes = (col == train_patient_nodes.unsqueeze(-1)).nonzero(as_tuple=True)[1]
            subset_row = row[ind_with_train_patient_nodes]
            subset_col = col[ind_with_train_patient_nodes]
            try:
                # first try to find an edge that connects back to the training set patient data
                targets = random_walk(subset_row, subset_col, source_batch, walk_length=1, coalesced=False)[:, 1]  # NOTE: the walk only returns self-loops when a node has no edges in the current partition of the dataset
                source_batch_1 = source_batch[~torch.eq(source_batch, targets)]
                targets_1 = targets[~torch.eq(source_batch, targets)]

                # if no edges are found (i.e., the walk stayed in place), use all available edges in this split of the data
                source_batch_2 = source_batch[torch.eq(source_batch, targets)]
                targets_2 = random_walk(row, col, source_batch_2, walk_length=1, coalesced=False)[:, 1]

                # concat the two together
                source_batch = torch.cat([source_batch_1, source_batch_2])
                targets = torch.cat([targets_1, targets_2])

            except Exception:
                targets = random_walk(row, col, source_batch, walk_length=1, coalesced=False)[:, 1]
        else:
            targets = random_walk(row, col, source_batch, walk_length=1, coalesced=False)[:, 1]
        return source_batch, targets
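
    # Illustrative sketch (toy values, not executed): a single random_walk step
    # draws one positive-edge partner per source node. Starting from nodes
    # [0, 1] on a graph with edges 0-1 and 1-2 (as coalesced row/col):
    #
    #   row = torch.tensor([0, 1, 1, 2]); col = torch.tensor([1, 0, 2, 1])
    #   walks = random_walk(row, col, torch.tensor([0, 1]), walk_length=1)
    #   # walks[:, 0] are the start nodes; walks[:, 1] are sampled neighbors,
    #   # e.g. tensor([[0, 1], [1, 2]]); isolated nodes map back to themselves
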
    def add_patient_information(self, patient_ids, phenotype_node_idx, candidate_gene_node_idx, correct_genes_node_idx, sim_gene_node_idx, gene_sims, gene_degs, disease_node_idx, candidate_disease_node_idx, labels, disease_labels, patient_labels, additional_labels, adjs, batch_size, n_id, sparse_idx, target_batch):

        # Create Data object & add patient-level information
        adjs = [HeterogeneousEdgeIndex(adj.edge_index, adj.e_id, self.all_edge_attr[adj.e_id], adj.size) for adj in adjs]
        max_n_candidates = max([len(l) for l in candidate_gene_node_idx])
        data = Data(adjs=adjs,
                    batch_size=batch_size,
                    patient_ids=patient_ids,
                    n_id=n_id
                    )
        if self.hparams['loss'] != 'patient_disease_NCA' and self.hparams['loss'] != 'patient_patient_NCA':
            if None in list(labels):
                data['one_hot_labels'] = None
            else:
                data['one_hot_labels'] = torch.LongTensor(label_binarize(labels, classes=list(range(max_n_candidates))))

        if self.use_diseases:
            data['disease_one_hot_labels'] = disease_labels

        if self.hparams['loss'] == 'patient_patient_NCA':
            if patient_labels is None:
                data['patient_labels'] = None
            else:
                data['patient_labels'] = torch.stack(patient_labels)

        # Get candidate gene to phenotype SPLs (shortest path lengths)
        if self.gp_spl is not None:
            if self.spl_indexing_dict is not None:
                patient_ids = np.vectorize(self.spl_indexing_dict.get)(patient_ids).astype(int)
            gene_to_phenotypes_spl = -torch.Tensor(self.gp_spl[patient_ids, :])
            # map gene node idx to SPL matrix indices
            cand_gene_idx_to_spl = [torch.LongTensor(np.vectorize(self.nid_to_spl_dict.get)(cand_genes)) for cand_genes in list(candidate_gene_node_idx)]
            # get SPLs for each patient's candidate genes
            batch_cand_gene_to_phenotypes_spl = [gene_spls[cand_genes] for cand_genes, gene_spls in zip(cand_gene_idx_to_spl, gene_to_phenotypes_spl)]
            # pad to the same number of candidate genes
            data['batch_cand_gene_to_phenotypes_spl'] = pad_sequence(batch_cand_gene_to_phenotypes_spl, batch_first=True, padding_value=0)
            # get unique gene idx across all patients in the batch
            cand_gene_idx_flattened_unique = torch.unique(torch.cat(cand_gene_idx_to_spl)).flatten()
            # get SPLs for unique genes in the batch
            data['batch_concat_cand_gene_to_phenotypes_spl'] = gene_to_phenotypes_spl[:, cand_gene_idx_flattened_unique]
        else:
            data['batch_cand_gene_to_phenotypes_spl'] = None
            data['batch_concat_cand_gene_to_phenotypes_spl'] = None

        # Create mapping from KG node IDs (shifted by +1 to reserve 0 for padding) to batch indices
        node2batch = {n + 1: int(i + 1) for i, n in enumerate(data.n_id.tolist())}
        node2batch[0] = 0
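        # Illustrative sketch (toy values, not executed): if the sampled
        # subgraph contains KG nodes n_id = [7, 3, 9], then after the +1 shift
        #   node2batch == {8: 1, 4: 2, 10: 3, 0: 0}
        # so padded KG IDs (where 0 is padding) map to padded batch positions.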

        # add phenotype / gene / disease names
        data['phenotype_names'] = [[(self.patient_dataset.node_idx_to_name(p.item()), self.patient_dataset.node_idx_to_degree(p.item())) for p in p_list] for p_list in phenotype_node_idx]
        data['cand_gene_names'] = [[self.patient_dataset.node_idx_to_name(g.item()) for g in g_list] for g_list in candidate_gene_node_idx]
        data['corr_gene_names'] = [[self.patient_dataset.node_idx_to_name(g.item()) for g in g_list] for g_list in correct_genes_node_idx]
        data['disease_names'] = [[self.patient_dataset.node_idx_to_name(d.item()) for d in d_list] for d_list in disease_node_idx]

        if self.use_diseases:
            data['cand_disease_names'] = [[self.patient_dataset.node_idx_to_name(d.item()) for d in d_list] for d_list in candidate_disease_node_idx]

        # reindex nodes to make room for padding (0 is reserved as the padding index)
        phenotype_node_idx = [p + 1 for p in phenotype_node_idx]
        candidate_gene_node_idx = [g + 1 for g in candidate_gene_node_idx]
        correct_genes_node_idx = [g + 1 for g in correct_genes_node_idx]
        if self.use_diseases:
            disease_node_idx = [d + 1 for d in disease_node_idx]
            candidate_disease_node_idx = [d + 1 for d in candidate_disease_node_idx]
        if 'augment_genes' in self.hparams and self.hparams['augment_genes']:
            sim_gene_node_idx = [g + 1 for g in sim_gene_node_idx]

        # if there aren't any disease idx in the batch, we add filler
        if self.use_diseases:
            if all(len(t) == 0 for t in disease_node_idx):
                disease_node_idx = [torch.LongTensor([0]) for i in range(len(disease_node_idx))]
            if all(len(t) == 0 for t in candidate_disease_node_idx):
                candidate_disease_node_idx = [torch.LongTensor([0]) for i in range(len(candidate_disease_node_idx))]

        # add padding to patient phenotype and gene node idx
        data['batch_pheno_nid'] = pad_sequence(phenotype_node_idx, batch_first=True, padding_value=0)
        if len(candidate_gene_node_idx[0]) > 0:
            data['batch_cand_gene_nid'] = pad_sequence(candidate_gene_node_idx, batch_first=True, padding_value=0)
        data['batch_corr_gene_nid'] = pad_sequence(correct_genes_node_idx, batch_first=True, padding_value=0)
        if self.use_diseases:
            data['batch_disease_nid'] = pad_sequence(disease_node_idx, batch_first=True, padding_value=0)
            data['batch_cand_disease_nid'] = pad_sequence(candidate_disease_node_idx, batch_first=True, padding_value=0)
        if 'augment_genes' in self.hparams and self.hparams['augment_genes']:
            data['batch_cand_gene_degs'] = pad_sequence(gene_degs, batch_first=True, padding_value=0)
            data['batch_sim_gene_nid'] = pad_sequence(sim_gene_node_idx, batch_first=True, padding_value=0)
            data['batch_sim_gene_sims'] = pad_sequence(gene_sims, batch_first=True, padding_value=0)
            # normalize the similarity scores
            data['batch_sim_gene_sims'] = data['batch_sim_gene_sims'] / torch.sum(data['batch_sim_gene_sims'], dim=1, keepdim=True)
        # NOTE: batch_cand_gene_nid is already set above for the non-augmented case

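        # Illustrative sketch (toy values, not executed) of the padding above:
        #   pad_sequence([torch.tensor([5, 6]), torch.tensor([7])],
        #                batch_first=True, padding_value=0)
        #   # -> tensor([[5, 6],
        #   #            [7, 0]])
        # i.e., variable-length per-patient index lists become one rectangular
        # batch tensor, with 0 (the reserved padding ID) filling the gaps.
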
        # Convert KG node IDs to batch IDs
        # When performing inference (i.e., predict.py), use the original node IDs because the full KG is used in the forward pass of the node model
        if self.dataset_type != "predict":
            data['batch_pheno_nid'] = torch.LongTensor(np.vectorize(node2batch.get)(data['batch_pheno_nid']))
            if len(candidate_gene_node_idx[0]) > 0:
                data['batch_cand_gene_nid'] = torch.LongTensor(np.vectorize(node2batch.get)(data['batch_cand_gene_nid']))
            if len(correct_genes_node_idx[0]) > 0:
                data['batch_corr_gene_nid'] = torch.LongTensor(np.vectorize(node2batch.get)(data['batch_corr_gene_nid']))
            if self.use_diseases:
                data['batch_disease_nid'] = torch.LongTensor(np.vectorize(node2batch.get)(data['batch_disease_nid']))
                data['batch_cand_disease_nid'] = torch.LongTensor(np.vectorize(node2batch.get)(data['batch_cand_disease_nid']))
            if 'augment_genes' in self.hparams and self.hparams['augment_genes']:
                data['batch_sim_gene_nid'] = torch.LongTensor(np.vectorize(node2batch.get)(data['batch_sim_gene_nid']))
        return data

    def get_candidate_diseases(self, disease_node_idx, candidate_gene_node_idx):
        cand_diseases = self.patient_dataset.get_candidate_diseases(cand_type=self.hparams['candidate_disease_type'])
        if self.n_cand_diseases != -1:
            cand_diseases = cand_diseases[torch.randperm(len(cand_diseases))][0:self.n_cand_diseases]

        if self.hparams['only_hard_distractors']:  # add the same candidates to every patient
            candidate_disease_node_idx = tuple(torch.unique(torch.cat([corr_dis, cand_diseases]), sorted=False) for corr_dis in disease_node_idx)
            candidate_disease_node_idx = tuple(torch.unique(dis[torch.randperm(len(dis))], sorted=False, return_inverse=False, return_counts=False) for dis in candidate_disease_node_idx)
        else:  # split candidates across all patients in the batch
            all_correct_diseases = torch.cat(disease_node_idx)
            all_diseases = torch.unique(torch.cat([all_correct_diseases, cand_diseases]))
            all_diseases = all_diseases[torch.randperm(len(all_diseases))]
            candidate_disease_node_idx = np.array_split(all_diseases, len(candidate_gene_node_idx))
            candidate_disease_node_idx = tuple(candidate_disease_node_idx)
        max_n_dis_candidates = max([len(l) for l in candidate_disease_node_idx])
        if max_n_dis_candidates == 0:
            max_n_dis_candidates = 1
            print('WARNING: there are no disease candidates')

        # multi-hot labels: 1 wherever a patient's candidate disease is one of their correct diseases
        disease_ind = [(dis.unsqueeze(1) == corr_dis.unsqueeze(0)).nonzero(as_tuple=True)[0] if len(corr_dis) > 0 else torch.tensor(-1) for dis, corr_dis in zip(candidate_disease_node_idx, disease_node_idx)]
        disease_labels = torch.zeros((len(candidate_disease_node_idx), max_n_dis_candidates))
        for i, ind in enumerate(disease_ind):
            disease_labels[i, ind[ind != -1]] = 1
        return candidate_disease_node_idx, disease_labels
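
    # Illustrative sketch (toy values, not executed): for one patient with
    # candidates [12, 40, 7] and correct diseases [40], the comparison
    #   (dis.unsqueeze(1) == corr_dis.unsqueeze(0)).nonzero(as_tuple=True)[0]
    # yields tensor([1]), so that patient's label row becomes [0., 1., 0.].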

    def get_candidate_patients(self, patient_ids):
        # get patients with the same disease/gene
        similar_pat_ids = [self.patient_dataset.get_similar_patients(p_id, similarity_type=self.hparams['patient_similarity_type']) for p_id in patient_ids]
        # subset to the first n_similar_patients so we have a fixed number of similar patients per patient in the batch (shuffling via torch.randperm is currently disabled)
        similar_pat_ids = [p[:self.hparams['n_similar_patients']] for p in similar_pat_ids]  # [torch.randperm(len(p))]
        # retrieve the patients for each of the sampled patient ids if they aren't already in the batch
        patient_ids = list(patient_ids)
        similar_pats = [self.patient_dataset[self.patient_dataset.patient_id_to_index[p_id.item()]] for p_ids in similar_pat_ids for p_id in p_ids if p_id.item() not in patient_ids]
        return similar_pats

    def sample(self, batch, source_batch, target_batch):
        batch_size: int = len(batch)
        adjs = []
        n_id = batch
        for size in self.sizes:
            adj_t, n_id = self.adj_t.sample_adj(n_id, size, replace=False)
            e_id = adj_t.storage.value()
            size = adj_t.sparse_sizes()[::-1]
            if self.__val__ is not None:
                adj_t.set_value_(self.__val__[e_id], layout='coo')

            if self.is_sparse_tensor:  # TODO: implement filter_edges for the SparseTensor case
                adjs.append(Adj(adj_t, e_id, size))
            else:
                row, col, _ = adj_t.coo()
                edge_index = torch.stack([col, row], dim=0)
                if self.do_filter_edges and self.dataset_type == 'train':
                    edge_index, e_id = self.filter_edges(edge_index, e_id, source_batch, target_batch)
                adjs.append(EdgeIndex(edge_index, e_id, size))

        adjs = [adjs[0]] if len(adjs) == 1 else adjs[::-1]
        return adjs, batch_size, n_id

    def get_similar_genes(self, patient_ids, candidate_gene_node_idx):
        k = self.hparams['n_sim_genes']
        gene_ids = []
        sims = []
        degs = []
        assert len(patient_ids) == len(candidate_gene_node_idx)
        for p, p_cand_genes in zip(patient_ids, candidate_gene_node_idx):
            p_genes = []
            p_sims = []
            p_degs = []
            for g in p_cand_genes:
                p_genes.append(torch.LongTensor([idx for idx, sim in list(self.gene_similarity_dict[int(g)])[:k]]))
                p_sims.append(torch.LongTensor([sim for idx, sim in list(self.gene_similarity_dict[int(g)])[:k]]))
                p_degs.append(self.gene_deg_dict[int(g)])
            gene_ids.append(torch.stack(p_genes))
            sims.append(torch.stack(p_sims))
            degs.append(torch.LongTensor(p_degs))
        assert len(gene_ids) == len(patient_ids)
        assert len(sims) == len(patient_ids)
        unique_genes = torch.unique(torch.cat(gene_ids).flatten()).unsqueeze(-1)
        return tuple(gene_ids), tuple(sims), tuple(degs), tuple(unique_genes)

    def collate(self, batch):
        t00 = time.time()
        phenotype_node_idx, candidate_gene_node_idx, correct_genes_node_idx, disease_node_idx, labels, additional_labels, patient_ids = zip(*batch)

        # Up-sample under-represented candidate genes
        t0 = time.time()
        if self.upsample_cand > 0:
            curr_cand_gene_freq = Counter(torch.cat(candidate_gene_node_idx).flatten().tolist())
            self.cand_gene_freq += curr_cand_gene_freq
            num_patients = len(candidate_gene_node_idx) * self.upsample_cand
            lowest_k_cand = self.cand_gene_freq.most_common()[:-num_patients - 1:-1]  # the least common candidate genes so far
            lowest_k_cand = np.array_split([g[0] for g in lowest_k_cand], len(candidate_gene_node_idx))

            upsampled_candidate_gene_node_idx = []
            added_cand_gene = []
            for patient, cand_gene, corr_gene_idx in zip(candidate_gene_node_idx, lowest_k_cand, labels):

                # Remove correct genes from the list of upsampled candidate genes
                corr_gene_nid = patient[corr_gene_idx]
                cand_gene = cand_gene[~np.isin(cand_gene, corr_gene_nid)].flatten()

                # Remove duplicates
                unique_cand_genes, new_cand_genes_freq = torch.unique(torch.tensor(patient.tolist() + list(cand_gene)), return_counts=True)
                unique_cand_genes = unique_cand_genes[new_cand_genes_freq == 1]
                cand_gene = cand_gene[np.isin(cand_gene, unique_cand_genes)]

                # Add upsampled candidate genes
                added_cand_gene.extend(list(cand_gene))
                new_cand_list = torch.tensor(patient.tolist() + list(cand_gene))
                upsampled_candidate_gene_node_idx.append(new_cand_list)

            candidate_gene_node_idx = tuple(upsampled_candidate_gene_node_idx)
            self.cand_gene_freq += Counter(added_cand_gene)

        # Add similar patients to batch (for the "patients like me" head)
        if self.hparams['add_similar_patients']:
            similar_pats = self.get_candidate_patients(patient_ids)
            # merge original batch with sampled patients
            phenotype_node_idx_sim, candidate_gene_node_idx_sim, correct_genes_node_idx_sim, disease_node_idx_sim, labels_sim, additional_labels_sim, patient_ids_sim = zip(*similar_pats)
            phenotype_node_idx = phenotype_node_idx + phenotype_node_idx_sim
            candidate_gene_node_idx = candidate_gene_node_idx + candidate_gene_node_idx_sim
            correct_genes_node_idx = correct_genes_node_idx + correct_genes_node_idx_sim
            disease_node_idx = disease_node_idx + disease_node_idx_sim
            labels = labels + labels_sim
            additional_labels = additional_labels + additional_labels_sim
            patient_ids = patient_ids + patient_ids_sim

        # get patient labels
        patient_labels = correct_genes_node_idx

        # Add candidate diseases to batch
        if self.hparams['add_cand_diseases']:
            candidate_disease_node_idx, disease_labels = self.get_candidate_diseases(disease_node_idx, candidate_gene_node_idx)
        else:
            candidate_disease_node_idx = disease_node_idx
            disease_labels = torch.tensor([1] * len(candidate_disease_node_idx))

        if self.hparams['augment_genes']:
            sim_gene_node_idx, gene_sims, gene_degs, unique_sim_genes = self.get_similar_genes(patient_ids, candidate_gene_node_idx)
        else:
            unique_sim_genes = gene_degs = gene_sims = sim_gene_node_idx = None

        t1 = time.time()

        # get nodes from patients + randomly sampled nodes
        source_batch, sparse_idx = self.get_source_nodes(phenotype_node_idx, candidate_gene_node_idx, correct_genes_node_idx, disease_node_idx, candidate_disease_node_idx, unique_sim_genes)

        # sample nodes to form positive edges
        source_batch, target_batch = self.sample_target_nodes(source_batch)
        batch = torch.cat([source_batch, target_batch], dim=0)
        t2 = time.time()

        # get k-hop adjacency graph
        adjs, batch_size, n_id = self.sample(batch, source_batch, target_batch)
        t3 = time.time()

        # add patient information to data object
        data = self.add_patient_information(patient_ids, phenotype_node_idx, candidate_gene_node_idx, correct_genes_node_idx, sim_gene_node_idx, gene_sims, gene_degs, disease_node_idx, candidate_disease_node_idx, labels, disease_labels, patient_labels, additional_labels, adjs, batch_size, n_id, sparse_idx, target_batch)
        t4 = time.time()

        if self.hparams['time']:
            print(f'It takes {t0-t00:0.4f}s to unzip the batch, {t1-t0:0.4f}s to upsample candidate gene nodes, {t2-t1:0.4f}s to sample positive nodes, {t3-t2:0.4f}s to get k-hop adjs, and {t4-t3:0.4f}s to add patient information')
        return data

    def __repr__(self):
        return '{}(sizes={})'.format(self.__class__.__name__, self.sizes)
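

# Illustrative sketch (hypothetical values, not executed): how a
# PatientNeighborSampler might be constructed for training. `train_dataset`,
# `kg_edge_index`, `edge_attr`, `n_kg_nodes`, and `nid_to_spl_dict` stand in
# for objects built elsewhere in the pipeline, and `hparams` must include the
# keys this file reads (e.g. 'alpha', 'loss', 'augment_genes', 'time').
#
#   loader = PatientNeighborSampler(
#       'train', kg_edge_index, kg_edge_index, sizes=[-1, 10, 5],
#       patient_dataset=train_dataset, all_edge_attributes=edge_attr,
#       n_nodes=n_kg_nodes, nid_to_spl_dict=nid_to_spl_dict,
#       hparams=hparams, batch_size=32, num_workers=4)
#   for data in loader:
#       # `data` bundles the sampled k-hop subgraph (data.adjs) with padded,
#       # batch-indexed patient phenotype/gene/disease tensors.
#       pass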