shepherd/samplers.py
|
|
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch_sparse import SparseTensor
from torch_cluster import random_walk
from torch_geometric.data.sampler import EdgeIndex, Adj
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset
from torch_geometric.utils import add_self_loops, add_remaining_self_loops
from torch_geometric.data import Data, DataLoader  # NOTE: torch_geometric's NeighborSampler is not imported because the class of the same name defined below would shadow it

from typing import List, Optional, Tuple, NamedTuple, Union, Callable, Dict
from collections import defaultdict, Counter
from operator import itemgetter
import time
import random
import pickle
import copy
import numpy as np
from utils.pretrain_utils import get_indices_into_edge_index, HeterogeneousEdgeIndex
from sklearn.preprocessing import label_binarize

import project_config
|
|
|
class NeighborSampler(torch.utils.data.DataLoader):
    r"""The neighbor sampler from the `"Inductive Representation Learning on
    Large Graphs" <https://arxiv.org/abs/1706.02216>`_ paper, which allows
    for mini-batch training of GNNs on large-scale graphs where full-batch
    training is not feasible.

    Given a GNN with :math:`L` layers and a specific mini-batch of nodes
    :obj:`node_idx` for which we want to compute embeddings, this module
    iteratively samples neighbors and constructs bipartite graphs that
    simulate the actual computation flow of GNNs.

    More specifically, :obj:`sizes` denotes how many neighbors we want to
    sample for each node in each layer.
    This module then takes in these :obj:`sizes` and iteratively samples
    :obj:`sizes[l]` neighbors for each node involved in layer :obj:`l`.
    In the next layer, sampling is repeated for the union of nodes that were
    already encountered.
    The actual computation graphs are then returned in reverse order, meaning
    that we pass messages from a larger set of nodes to a smaller one, until
    we reach the nodes for which we originally wanted to compute embeddings.

    Hence, an item returned by :class:`NeighborSampler` holds the current
    :obj:`batch_size`, the IDs :obj:`n_id` of all nodes involved in the
    computation, and a list of bipartite graph objects via the tuple
    :obj:`(edge_index, e_id, size)`, where :obj:`edge_index` represents the
    bipartite edges between source and target nodes, :obj:`e_id` denotes the
    IDs of the original edges in the full graph, and :obj:`size` holds the
    shape of the bipartite graph.
    For each bipartite graph, target nodes are also included at the beginning
    of the list of source nodes so that one can easily apply skip-connections
    or add self-loops.

    .. note::

        For an example of using :obj:`NeighborSampler`, see
        `examples/reddit.py
        <https://github.com/rusty1s/pytorch_geometric/blob/master/examples/
        reddit.py>`_ or
        `examples/ogbn_products_sage.py
        <https://github.com/rusty1s/pytorch_geometric/blob/master/examples/
        ogbn_products_sage.py>`_.

    Args:
        edge_index (Tensor or SparseTensor): A :obj:`torch.LongTensor` or a
            :obj:`torch_sparse.SparseTensor` that defines the underlying graph
            connectivity/message passing flow.
            :obj:`edge_index` holds the indices of a (sparse) symmetric
            adjacency matrix.
            If :obj:`edge_index` is of type :obj:`torch.LongTensor`, its shape
            must be defined as :obj:`[2, num_edges]`, where messages from
            nodes :obj:`edge_index[0]` are sent to nodes in
            :obj:`edge_index[1]` (in case :obj:`flow="source_to_target"`).
            If :obj:`edge_index` is of type :obj:`torch_sparse.SparseTensor`,
            its sparse indices :obj:`(row, col)` should relate to
            :obj:`row = edge_index[1]` and :obj:`col = edge_index[0]`.
            The major difference between both formats is that we need to input
            the *transposed* sparse adjacency matrix.
        sizes ([int]): The number of neighbors to sample for each node in each
            layer. If set to :obj:`sizes[l] = -1`, all neighbors are included
            in layer :obj:`l`.
        node_idx (LongTensor, optional): The nodes that should be considered
            for creating mini-batches. If set to :obj:`None`, all nodes will
            be considered.
        num_nodes (int, optional): The number of nodes in the graph.
            (default: :obj:`None`)
        return_e_id (bool, optional): If set to :obj:`False`, will not return
            the original edge indices of sampled edges. This is only useful
            when operating on graphs without edge features, to save memory.
            (default: :obj:`True`)
        transform (callable, optional): A function/transform that takes in
            a sampled mini-batch and returns a transformed version.
            (default: :obj:`None`)
        **kwargs (optional): Additional arguments of
            :class:`torch.utils.data.DataLoader`, such as :obj:`batch_size`,
            :obj:`shuffle`, :obj:`drop_last` or :obj:`num_workers`.
    """
|
|
    def __init__(self, dataset_type: str, edge_index: Union[Tensor, SparseTensor],
                 sample_edge_index: Union[Tensor, SparseTensor],
                 sizes: List[int],
                 node_idx: Optional[Tensor] = None,
                 num_nodes: Optional[int] = None, return_e_id: bool = True,
                 transform: Callable = None,
                 do_filter_edges: bool = True,
                 **kwargs):

        edge_index = edge_index.to('cpu')
        sample_edge_index = sample_edge_index.to('cpu')

        # Add self-loops so every node can be reached by the random walk,
        # even if it has no incident edges in this split
        sample_edge_index, _ = add_self_loops(sample_edge_index)

        if 'collate_fn' in kwargs:
            del kwargs['collate_fn']

        # Save for PyTorch Lightning...
        self.dataset_type = dataset_type
        self.edge_index = edge_index  # always the train edge index
        self.sample_edge_index = sample_edge_index  # depends on the train/val/test split
        self.node_idx = node_idx
        self.num_nodes = num_nodes

        self.sizes = sizes
        self.return_e_id = return_e_id
        self.transform = transform
        self.is_sparse_tensor = isinstance(edge_index, SparseTensor)
        self.__val__ = None
        self.do_filter_edges = do_filter_edges

        # Obtain a *transposed* `SparseTensor` instance.
        if not self.is_sparse_tensor:
            sample_num_nodes = num_nodes  # fallback so this is defined when num_nodes is passed in explicitly
            if (num_nodes is None and node_idx is not None
                    and node_idx.dtype == torch.bool):
                num_nodes = node_idx.size(0)
                sample_num_nodes = num_nodes
            if (num_nodes is None and node_idx is not None
                    and node_idx.dtype == torch.long):
                num_nodes = max(int(edge_index.max()), int(node_idx.max())) + 1
                sample_num_nodes = num_nodes
            if num_nodes is None:
                num_nodes = int(edge_index.max()) + 1
                sample_num_nodes = int(sample_edge_index.max()) + 1

            value = torch.arange(edge_index.size(1)) if return_e_id else None
            sample_value = torch.arange(sample_edge_index.size(1)) if return_e_id else None
            self.adj_t = SparseTensor(row=edge_index[0], col=edge_index[1],
                                      value=value,
                                      sparse_sizes=(num_nodes, num_nodes)).t()
            self.adj_t_sample = SparseTensor(row=sample_edge_index[0], col=sample_edge_index[1],
                                             value=sample_value,
                                             sparse_sizes=(sample_num_nodes, sample_num_nodes)).t()
        else:
            adj_t = edge_index
            adj_t_sample = sample_edge_index
            if return_e_id:
                self.__val__ = adj_t.storage.value()
                value = torch.arange(adj_t.nnz())
                adj_t = adj_t.set_value(value, layout='coo')
                adj_t_sample = adj_t_sample.set_value(torch.arange(adj_t_sample.nnz()), layout='coo')
            self.adj_t = adj_t
            self.adj_t_sample = adj_t_sample

        # Pre-compute and cache the CSR row pointers for faster sampling
        self.adj_t.storage.rowptr()
        self.adj_t_sample.storage.rowptr()

        if node_idx is None:
            node_idx = torch.arange(self.adj_t_sample.sparse_size(0))
        elif node_idx.dtype == torch.bool:
            node_idx = node_idx.nonzero(as_tuple=False).view(-1)

        super(NeighborSampler, self).__init__(
            node_idx.view(-1).tolist(), collate_fn=self.sample, **kwargs)
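    # Note on the transposed layout (a sketch of the convention, not extra
    # functionality): for an edge u -> v stored as edge_index[:, k] == (u, v),
    # the SparseTensor above is built with row=u, col=v and then transposed,
    # so `adj_t` has row=v, col=u. `sample_adj(n_id, ...)` therefore gathers
    # the *incoming* neighbors of `n_id`, matching the "source_to_target"
    # message-passing flow described in the docstring.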
|
|
|
    def filter_edges(self, edge_index, e_id, source_nodes, target_nodes):
        '''
        Filter out the edges we're trying to predict in the current batch from the edge index.
        NOTE: edge_index here is re-indexed.
        '''
        reindex_source_nodes = torch.arange(source_nodes.size(0))
        reindex_target_nodes = torch.arange(start=source_nodes.size(0), end=source_nodes.size(0) + target_nodes.size(0))

        # Get reverse edges to filter as well
        all_source_nodes = torch.cat([reindex_source_nodes, reindex_target_nodes])
        all_target_nodes = torch.cat([reindex_target_nodes, reindex_source_nodes])
        ind_to_edge_index, ind_to_nodes = get_indices_into_edge_index(edge_index, all_source_nodes, all_target_nodes)  # get index into the original edge index (this returns e_ids)
        mask = torch.ones(edge_index.size(1), dtype=torch.bool)
        mask[ind_to_edge_index] = False

        return edge_index[:, mask], e_id[mask]
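    # Why filtering matters (explanatory note): in the re-indexed batch,
    # source nodes occupy positions [0, n_src) and their sampled targets
    # occupy [n_src, n_src + n_tgt), so the positive edge for source i is
    # (i, n_src + i). Masking these edges (and their reverses) out of the
    # message-passing graph prevents the model from trivially "seeing" the
    # very edges it is being trained to predict.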
|
|
|
    def sample(self, source_batch):

        # Convert to tensor
        if not isinstance(source_batch, Tensor):
            source_batch = torch.tensor(source_batch)

        # Sample nodes to form positive edges. We will try to predict these edges.
        row, col, e_id = self.adj_t_sample.coo()
        target_batch = random_walk(row, col, source_batch, walk_length=1, coalesced=False)[:, 1]  # NOTE: only produces self-loops when there are no edges in the current partition of the dataset
        batch = torch.cat([source_batch, target_batch], dim=0)

        batch_size: int = len(batch)
        adjs = []
        n_id = batch
        for size in self.sizes:
            adj_t, n_id = self.adj_t.sample_adj(n_id, size, replace=False)
            e_id = adj_t.storage.value()
            size = adj_t.sparse_sizes()[::-1]
            if self.__val__ is not None:
                adj_t.set_value_(self.__val__[e_id], layout='coo')

            if self.is_sparse_tensor:  # TODO: implement filter_edges for the sparse tensor case
                adjs.append(Adj(adj_t, e_id, size))
            else:
                row, col, _ = adj_t.coo()
                edge_index = torch.stack([col, row], dim=0)

                if self.do_filter_edges and self.dataset_type == 'train':
                    edge_index, e_id = self.filter_edges(edge_index, e_id, source_batch, target_batch)
                adjs.append(EdgeIndex(edge_index, e_id, size))

        adjs = adjs[0] if len(adjs) == 1 else adjs[::-1]
        out = (batch_size, n_id, adjs)
        out = self.transform(*out) if self.transform is not None else out
        return out
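    # Consuming the sampler output (a hypothetical sketch; `model.convs` and
    # `node_features` are stand-ins for objects defined elsewhere, and `adjs`
    # is a list only when len(sizes) > 1):
    #
    #   for batch_size, n_id, adjs in sampler:
    #       x = node_features[n_id]
    #       for (edge_index, e_id, size), conv in zip(adjs, model.convs):
    #           x_target = x[:size[1]]  # target nodes are listed first
    #           x = conv((x, x_target), edge_index)
    #       out = x[:batch_size]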
|
|
|
    def __repr__(self):
        return '{}(sizes={})'.format(self.__class__.__name__, self.sizes)
|
|
|
class PatientNeighborSampler(torch.utils.data.DataLoader):

    def __init__(self, dataset_type: str, edge_index: Union[Tensor, SparseTensor],
                 sample_edge_index: Union[Tensor, SparseTensor],
                 sizes: List[int],
                 patient_dataset,
                 all_edge_attributes,
                 n_nodes: int,
                 relevant_node_idx=None,
                 do_filter_edges: Optional[bool] = False,
                 num_nodes: Optional[int] = None,
                 return_e_id: bool = True,
                 sparse_sample: Optional[int] = 0,
                 train_phenotype_counter: Dict = None,
                 train_gene_counter: Dict = None,
                 sample_edges_from_train_patients=False,
                 upsample_cand: Optional[int] = 0,
                 n_cand_diseases=-1,
                 use_diseases=False,
                 nid_to_spl_dict=None,
                 gp_spl=None,
                 spl_indexing_dict=None,
                 gene_similarity_dict=None,
                 gene_deg_dict=None,
                 hparams=None,
                 transform: Callable = None,
                 **kwargs):

        edge_index = edge_index.to('cpu')
        sample_edge_index = sample_edge_index.to('cpu')

        # Add self-loops: first explicitly for every node in the train edge
        # index, then for any remaining nodes
        sample_edge_index = torch.cat((sample_edge_index, torch.stack([edge_index.unique(), edge_index.unique()])), 1)
        sample_edge_index, _ = add_remaining_self_loops(sample_edge_index)

        if 'collate_fn' in kwargs:
            del kwargs['collate_fn']

        # Save for PyTorch Lightning...
        self.do_filter_edges = do_filter_edges
        self.relevant_node_idx = relevant_node_idx
        self.n_nodes = n_nodes
        self.all_edge_attr = all_edge_attributes
        self.dataset_type = dataset_type
        self.sparse_sample = sparse_sample
        self.edge_index = edge_index  # always the train edge index
        self.sample_edge_index = sample_edge_index  # depends on the train/val/test split
        self.patient_dataset = patient_dataset
        self.num_nodes = num_nodes
        self.train_phenotype_counter = train_phenotype_counter
        self.train_gene_counter = train_gene_counter
        self.sample_edges_from_train_patients = sample_edges_from_train_patients
        self.sizes = sizes
        self.return_e_id = return_e_id
        self.transform = transform
        self.is_sparse_tensor = isinstance(edge_index, SparseTensor)
        self.__val__ = None

        # For SPL (shortest path length) features
        self.nid_to_spl_dict = nid_to_spl_dict
        if hparams["alpha"] < 1:
            self.gp_spl = gp_spl
        else:
            self.gp_spl = None
        self.spl_indexing_dict = spl_indexing_dict

        # Up-sample candidate genes
        self.upsample_cand = upsample_cand
        self.cand_gene_freq = Counter([])
        with open(str(project_config.KG_DIR / f'ensembl_to_idx_dict_{project_config.CURR_KG}.pkl'), 'rb') as handle:
            ensembl_to_idx_dict = pickle.load(handle)  # Ensembl ID -> node idx map
        idx_to_ensembl_dict = {v: k for k, v in ensembl_to_idx_dict.items()}
        self.cand_gene_freq = Counter([k for k in nid_to_spl_dict if k in idx_to_ensembl_dict])  # upsample from all gene nodes in the KG

        self.n_cand_diseases = n_cand_diseases
        self.use_diseases = use_diseases
        self.hparams = hparams

        self.gene_similarity_dict = gene_similarity_dict
        self.gene_deg_dict = gene_deg_dict

        # Obtain a *transposed* `SparseTensor` instance.
        if not self.is_sparse_tensor:
            if num_nodes is None:
                num_nodes = int(edge_index.max()) + 1
                sample_num_nodes = int(sample_edge_index.max()) + 1
            else:
                sample_num_nodes = num_nodes  # NOTE: guards a NameError when num_nodes is passed in explicitly

            value = torch.arange(edge_index.size(1)) if return_e_id else None
            sample_value = torch.arange(sample_edge_index.size(1)) if return_e_id else None
            self.adj_t = SparseTensor(row=edge_index[0], col=edge_index[1],
                                      value=value,
                                      sparse_sizes=(num_nodes, num_nodes)).t()
            self.adj_t_sample = SparseTensor(row=sample_edge_index[0], col=sample_edge_index[1],
                                             value=sample_value,
                                             sparse_sizes=(sample_num_nodes, sample_num_nodes)).t()
        else:
            adj_t = edge_index
            adj_t_sample = sample_edge_index
            if return_e_id:
                self.__val__ = adj_t.storage.value()
                value = torch.arange(adj_t.nnz())
                adj_t = adj_t.set_value(value, layout='coo')
                adj_t_sample = adj_t_sample.set_value(torch.arange(adj_t_sample.nnz()), layout='coo')
            self.adj_t = adj_t
            self.adj_t_sample = adj_t_sample

        # Pre-compute and cache the CSR row pointers for faster sampling
        self.adj_t.storage.rowptr()
        self.adj_t_sample.storage.rowptr()

        super(PatientNeighborSampler, self).__init__(
            self.patient_dataset, collate_fn=self.collate, **kwargs)
|
|
|
    def filter_edges(self, edge_index, e_id, source_nodes, target_nodes):
        '''
        Filter out the edges we're trying to predict in the current batch from the edge index.
        NOTE: edge_index here is re-indexed.
        '''
        reindex_source_nodes = torch.arange(source_nodes.size(0))
        reindex_target_nodes = torch.arange(start=source_nodes.size(0), end=source_nodes.size(0) + target_nodes.size(0))

        # Get reverse edges to filter as well
        all_source_nodes = torch.cat([reindex_source_nodes, reindex_target_nodes])
        all_target_nodes = torch.cat([reindex_target_nodes, reindex_source_nodes])
        ind_to_edge_index, ind_to_nodes = get_indices_into_edge_index(edge_index, all_source_nodes, all_target_nodes)  # get index into the original edge index (this returns e_ids)
        mask = torch.ones(edge_index.size(1), dtype=torch.bool)
        mask[ind_to_edge_index] = False

        return edge_index[:, mask], e_id[mask]
|
|
|
    def get_source_nodes(self, phenotype_node_idx, candidate_gene_node_idx, correct_genes_node_idx, disease_node_idx, candidate_disease_node_idx, sim_gene_node_idx):

        # Get batch node indices based on patient phenotypes and genes
        if sim_gene_node_idx is not None:
            source_batch = torch.cat(phenotype_node_idx + candidate_gene_node_idx + correct_genes_node_idx + disease_node_idx + candidate_disease_node_idx + sim_gene_node_idx)
        else:
            source_batch = torch.cat(phenotype_node_idx + candidate_gene_node_idx + correct_genes_node_idx + disease_node_idx + candidate_disease_node_idx)

        # Randomly sample nodes in the KG
        if self.sparse_sample > 0:
            if self.relevant_node_idx is None:
                rand_idx = torch.randint(high=self.n_nodes, size=(self.sparse_sample,))  # NOTE: this can sample duplicates, but has the benefit of randomly sampling new nodes each epoch
            else:
                rand_idx = self.relevant_node_idx[torch.randint(high=self.relevant_node_idx.size(0), size=(self.sparse_sample,))]

            source_batch = torch.cat([source_batch, rand_idx])
            source_batch = torch.unique(source_batch)
            sparse_idx = torch.unique(rand_idx)
        else:
            source_batch = torch.unique(source_batch)
            sparse_idx = torch.Tensor([])

        return source_batch, sparse_idx
|
|
|
    def sample_target_nodes(self, source_batch):
        row, col, e_id = self.adj_t_sample.coo()

        if self.sample_edges_from_train_patients:
            train_patient_nodes = torch.tensor(list(self.train_phenotype_counter.keys()) + list(self.train_gene_counter.keys()))
            ind_with_train_patient_nodes = (col == train_patient_nodes.unsqueeze(-1)).nonzero(as_tuple=True)[1]
            subset_row = row[ind_with_train_patient_nodes]
            subset_col = col[ind_with_train_patient_nodes]
            try:
                # First try to find an edge that connects back to the training set patient data
                targets = random_walk(subset_row, subset_col, source_batch, walk_length=1, coalesced=False)[:, 1]  # NOTE: only produces self-loops when there are no edges in the current partition of the dataset
                source_batch_1 = source_batch[~torch.eq(source_batch, targets)]
                targets_1 = targets[~torch.eq(source_batch, targets)]

                # If no such edge is found (the walk stays at the source node), use all available edges in this split of the data
                source_batch_2 = source_batch[torch.eq(source_batch, targets)]
                targets_2 = random_walk(row, col, source_batch_2, walk_length=1, coalesced=False)[:, 1]

                # Concatenate the two together
                source_batch = torch.cat([source_batch_1, source_batch_2])
                targets = torch.cat([targets_1, targets_2])

            except Exception:
                targets = random_walk(row, col, source_batch, walk_length=1, coalesced=False)[:, 1]
        else:
            targets = random_walk(row, col, source_batch, walk_length=1, coalesced=False)[:, 1]
        return source_batch, targets
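    # Positive-edge sampling in a nutshell (explanatory note): a length-1
    # random walk from each source node returns one uniformly sampled
    # neighbor, so each (source, target) pair is an observed edge in
    # `adj_t_sample`. A walk that returns the source itself (a self-loop)
    # signals that the node had no usable edge in this split, which is what
    # the torch.eq checks above detect.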
|
|
|
    def add_patient_information(self, patient_ids, phenotype_node_idx, candidate_gene_node_idx, correct_genes_node_idx, sim_gene_node_idx, gene_sims, gene_degs, disease_node_idx, candidate_disease_node_idx, labels, disease_labels, patient_labels, additional_labels, adjs, batch_size, n_id, sparse_idx, target_batch):

        # Create Data object & add patient-level information
        adjs = [HeterogeneousEdgeIndex(adj.edge_index, adj.e_id, self.all_edge_attr[adj.e_id], adj.size) for adj in adjs]
        max_n_candidates = max([len(l) for l in candidate_gene_node_idx])
        data = Data(adjs=adjs,
                    batch_size=batch_size,
                    patient_ids=patient_ids,
                    n_id=n_id
                    )
        if self.hparams['loss'] != 'patient_disease_NCA' and self.hparams['loss'] != 'patient_patient_NCA':
            if None in list(labels):
                data['one_hot_labels'] = None
            else:
                data['one_hot_labels'] = torch.LongTensor(label_binarize(labels, classes=list(range(max_n_candidates))))

        if self.use_diseases:
            data['disease_one_hot_labels'] = disease_labels

        if self.hparams['loss'] == 'patient_patient_NCA':
            if patient_labels is None:
                data['patient_labels'] = None
            else:
                data['patient_labels'] = torch.stack(patient_labels)

        # Get candidate gene to phenotype SPLs
        if self.gp_spl is not None:
            if self.spl_indexing_dict is not None:
                patient_ids = np.vectorize(self.spl_indexing_dict.get)(patient_ids).astype(int)
            gene_to_phenotypes_spl = -torch.Tensor(self.gp_spl[patient_ids, :])
            # Get gene idx to SPL information
            cand_gene_idx_to_spl = [torch.LongTensor(np.vectorize(self.nid_to_spl_dict.get)(cand_genes)) for cand_genes in list(candidate_gene_node_idx)]
            # Get SPLs for each patient's candidate genes
            batch_cand_gene_to_phenotypes_spl = [gene_spls[cand_genes] for cand_genes, gene_spls in zip(cand_gene_idx_to_spl, gene_to_phenotypes_spl)]
            # Pad to the same number of candidate genes
            data['batch_cand_gene_to_phenotypes_spl'] = pad_sequence(batch_cand_gene_to_phenotypes_spl, batch_first=True, padding_value=0)
            # Get unique gene idx across all patients in the batch
            cand_gene_idx_flattened_unique = torch.unique(torch.cat(cand_gene_idx_to_spl)).flatten()
            # Get SPLs for the unique genes in the batch
            data['batch_concat_cand_gene_to_phenotypes_spl'] = gene_to_phenotypes_spl[:, cand_gene_idx_flattened_unique]
        else:
            data['batch_cand_gene_to_phenotypes_spl'] = None
            data['batch_concat_cand_gene_to_phenotypes_spl'] = None

        # Create mapping from KG node IDs to batch indices
        node2batch = {n + 1: int(i + 1) for i, n in enumerate(data.n_id.tolist())}
        node2batch[0] = 0

        # Add phenotype / gene / disease names
        data['phenotype_names'] = [[(self.patient_dataset.node_idx_to_name(p.item()), self.patient_dataset.node_idx_to_degree(p.item())) for p in p_list] for p_list in phenotype_node_idx]
        data['cand_gene_names'] = [[self.patient_dataset.node_idx_to_name(g.item()) for g in g_list] for g_list in candidate_gene_node_idx]
        data['corr_gene_names'] = [[self.patient_dataset.node_idx_to_name(g.item()) for g in g_list] for g_list in correct_genes_node_idx]
        data['disease_names'] = [[self.patient_dataset.node_idx_to_name(d.item()) for d in d_list] for d_list in disease_node_idx]

        if self.use_diseases:
            data['cand_disease_names'] = [[self.patient_dataset.node_idx_to_name(d.item()) for d in d_list] for d_list in candidate_disease_node_idx]

        # Re-index nodes (shift by +1) to make room for padding at index 0
        phenotype_node_idx = [p + 1 for p in phenotype_node_idx]
        candidate_gene_node_idx = [g + 1 for g in candidate_gene_node_idx]
        correct_genes_node_idx = [g + 1 for g in correct_genes_node_idx]
        if self.use_diseases:
            disease_node_idx = [d + 1 for d in disease_node_idx]
            candidate_disease_node_idx = [d + 1 for d in candidate_disease_node_idx]
        if 'augment_genes' in self.hparams and self.hparams['augment_genes']:
            sim_gene_node_idx = [g + 1 for g in sim_gene_node_idx]

        # If there aren't any disease idx in the batch, add filler
        if self.use_diseases:
            if all(len(t) == 0 for t in disease_node_idx):
                disease_node_idx = [torch.LongTensor([0]) for i in range(len(disease_node_idx))]
            if all(len(t) == 0 for t in candidate_disease_node_idx):
                candidate_disease_node_idx = [torch.LongTensor([0]) for i in range(len(candidate_disease_node_idx))]

        # Add padding to patient phenotype and gene node idx
        data['batch_pheno_nid'] = pad_sequence(phenotype_node_idx, batch_first=True, padding_value=0)
        if len(candidate_gene_node_idx[0]) > 0:
            data['batch_cand_gene_nid'] = pad_sequence(candidate_gene_node_idx, batch_first=True, padding_value=0)
        data['batch_corr_gene_nid'] = pad_sequence(correct_genes_node_idx, batch_first=True, padding_value=0)
        if self.use_diseases:
            data['batch_disease_nid'] = pad_sequence(disease_node_idx, batch_first=True, padding_value=0)
            data['batch_cand_disease_nid'] = pad_sequence(candidate_disease_node_idx, batch_first=True, padding_value=0)
        if 'augment_genes' in self.hparams and self.hparams['augment_genes']:
            data['batch_cand_gene_degs'] = pad_sequence(gene_degs, batch_first=True, padding_value=0)
            data['batch_sim_gene_nid'] = pad_sequence(sim_gene_node_idx, batch_first=True, padding_value=0)
            data['batch_sim_gene_sims'] = pad_sequence(gene_sims, batch_first=True, padding_value=0)
            # Normalize similarities so they sum to 1 per patient
            data['batch_sim_gene_sims'] = data['batch_sim_gene_sims'] / torch.sum(data['batch_sim_gene_sims'], dim=1, keepdim=True)
        else:
            if len(candidate_gene_node_idx[0]) > 0:
                data['batch_cand_gene_nid'] = pad_sequence(candidate_gene_node_idx, batch_first=True, padding_value=0)

        # Convert KG node IDs to batch IDs.
        # When performing inference (i.e., predict.py), use the original node IDs because the full KG is used in the forward pass of the node model.
        if self.dataset_type != "predict":
            data['batch_pheno_nid'] = torch.LongTensor(np.vectorize(node2batch.get)(data['batch_pheno_nid']))
            if len(candidate_gene_node_idx[0]) > 0:
                data['batch_cand_gene_nid'] = torch.LongTensor(np.vectorize(node2batch.get)(data['batch_cand_gene_nid']))
            if len(correct_genes_node_idx[0]) > 0:
                data['batch_corr_gene_nid'] = torch.LongTensor(np.vectorize(node2batch.get)(data['batch_corr_gene_nid']))
            if self.use_diseases:
                data['batch_disease_nid'] = torch.LongTensor(np.vectorize(node2batch.get)(data['batch_disease_nid']))
                data['batch_cand_disease_nid'] = torch.LongTensor(np.vectorize(node2batch.get)(data['batch_cand_disease_nid']))
            if 'augment_genes' in self.hparams and self.hparams['augment_genes']:
                data['batch_sim_gene_nid'] = torch.LongTensor(np.vectorize(node2batch.get)(data['batch_sim_gene_nid']))
        return data
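    # node2batch example (explanatory note): if n_id = tensor([42, 7, 99]),
    # then node2batch == {43: 1, 8: 2, 100: 3, 0: 0}. Every node ID is
    # shifted by +1 so that index 0 stays free to act as the padding value,
    # and the same +1 shift is baked into the mapping itself.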
|
|
|
    def get_candidate_diseases(self, disease_node_idx, candidate_gene_node_idx):
        cand_diseases = self.patient_dataset.get_candidate_diseases(cand_type=self.hparams['candidate_disease_type'])
        if self.n_cand_diseases != -1:
            cand_diseases = cand_diseases[torch.randperm(len(cand_diseases))][0:self.n_cand_diseases]

        if self.hparams['only_hard_distractors']:  # add the same candidates to every patient
            candidate_disease_node_idx = tuple(torch.unique(torch.cat([corr_dis, cand_diseases]), sorted=False) for corr_dis in disease_node_idx)
            candidate_disease_node_idx = tuple(torch.unique(dis[torch.randperm(len(dis))], sorted=False, return_inverse=False, return_counts=False) for dis in candidate_disease_node_idx)
        else:  # split candidates across all patients in the batch
            all_correct_diseases = torch.cat(disease_node_idx)
            all_diseases = torch.unique(torch.cat([all_correct_diseases, cand_diseases]))
            all_diseases = all_diseases[torch.randperm(len(all_diseases))]
            candidate_disease_node_idx = np.array_split(all_diseases, len(candidate_gene_node_idx))
            candidate_disease_node_idx = tuple(candidate_disease_node_idx)
        max_n_dis_candidates = max([len(l) for l in candidate_disease_node_idx])
        if max_n_dis_candidates == 0:
            max_n_dis_candidates = 1
            print('WARNING: there are no disease candidates')

        disease_ind = [(dis.unsqueeze(1) == corr_dis.unsqueeze(0)).nonzero(as_tuple=True)[0] if len(corr_dis) > 0 else torch.tensor(-1) for dis, corr_dis in zip(candidate_disease_node_idx, disease_node_idx)]
        disease_labels = torch.zeros((len(candidate_disease_node_idx), max_n_dis_candidates))
        for i, ind in enumerate(disease_ind):
            disease_labels[i, ind[ind != -1]] = 1
        return candidate_disease_node_idx, disease_labels
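    # Label construction (explanatory note): `disease_labels` is a multi-hot
    # matrix of shape [n_patients, max_n_dis_candidates]; entry (i, j) is 1
    # iff patient i's j-th candidate disease is one of their true diseases.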
|
|
|
    def get_candidate_patients(self, patient_ids):
        # Get patients with the same disease/gene
        similar_pat_ids = [self.patient_dataset.get_similar_patients(p_id, similarity_type=self.hparams['patient_similarity_type']) for p_id in patient_ids]
        # Subset to the first n_similar_patients so we have a fixed number of similar patients per patient in the batch (shuffling is currently disabled; see the commented-out randperm)
        similar_pat_ids = [p[:self.hparams['n_similar_patients']] for p in similar_pat_ids]  # [torch.randperm(len(p))]
        # Retrieve the patients for each of the sampled patient ids if they aren't already in the batch
        patient_ids = list(patient_ids)
        similar_pats = [self.patient_dataset[self.patient_dataset.patient_id_to_index[p_id.item()]] for p_ids in similar_pat_ids for p_id in p_ids if p_id.item() not in patient_ids]
        return similar_pats
|
|
|
    def sample(self, batch, source_batch, target_batch):
        batch_size: int = len(batch)
        adjs = []
        n_id = batch
        for size in self.sizes:

            adj_t, n_id = self.adj_t.sample_adj(n_id, size, replace=False)
            e_id = adj_t.storage.value()
            size = adj_t.sparse_sizes()[::-1]
            if self.__val__ is not None:
                adj_t.set_value_(self.__val__[e_id], layout='coo')

            if self.is_sparse_tensor:  # TODO: implement filter_edges for the sparse tensor case
                adjs.append(Adj(adj_t, e_id, size))
            else:
                row, col, _ = adj_t.coo()
                edge_index = torch.stack([col, row], dim=0)
                if self.do_filter_edges and self.dataset_type == 'train':
                    edge_index, e_id = self.filter_edges(edge_index, e_id, source_batch, target_batch)
                adjs.append(EdgeIndex(edge_index, e_id, size))

        adjs = [adjs[0]] if len(adjs) == 1 else adjs[::-1]
        return adjs, batch_size, n_id
|
|
|
    def get_similar_genes(self, patient_ids, candidate_gene_node_idx):
        k = self.hparams['n_sim_genes']
        gene_ids = []
        sims = []
        degs = []
        assert len(patient_ids) == len(candidate_gene_node_idx)
        for p, p_cand_genes in zip(patient_ids, candidate_gene_node_idx):
            p_genes = []
            p_sims = []
            p_degs = []
            for g in p_cand_genes:
                p_genes.append(torch.LongTensor([idx for idx, sim in list(self.gene_similarity_dict[int(g)])[:k]]))
                p_sims.append(torch.LongTensor([sim for idx, sim in list(self.gene_similarity_dict[int(g)])[:k]]))
                p_degs.append(self.gene_deg_dict[int(g)])
            gene_ids.append(torch.stack(p_genes))
            sims.append(torch.stack(p_sims))
            degs.append(torch.LongTensor(p_degs))
        assert len(gene_ids) == len(patient_ids)
        assert len(sims) == len(patient_ids)
        unique_genes = torch.unique(torch.cat(gene_ids).flatten()).unsqueeze(-1)
        return tuple(gene_ids), tuple(sims), tuple(degs), tuple(unique_genes)
|
|
|
    def collate(self, batch):
        t00 = time.time()
        phenotype_node_idx, candidate_gene_node_idx, correct_genes_node_idx, disease_node_idx, labels, additional_labels, patient_ids = zip(*batch)

        # Up-sample under-represented candidate genes
        t0 = time.time()
        if self.upsample_cand > 0:
            curr_cand_gene_freq = Counter(torch.cat(candidate_gene_node_idx).flatten().tolist())
            self.cand_gene_freq += curr_cand_gene_freq
            num_patients = len(candidate_gene_node_idx) * self.upsample_cand
            lowest_k_cand = self.cand_gene_freq.most_common()[:-num_patients - 1:-1]
            lowest_k_cand = np.array_split([g[0] for g in lowest_k_cand], len(candidate_gene_node_idx))

            upsampled_candidate_gene_node_idx = []
            added_cand_gene = []
            for patient, cand_gene, corr_gene_idx in zip(candidate_gene_node_idx, lowest_k_cand, labels):

                # Remove correct genes from the list of upsampled candidate genes
                corr_gene_nid = patient[corr_gene_idx]
                cand_gene = cand_gene[~np.isin(cand_gene, corr_gene_nid)].flatten()

                # Remove duplicates
                unique_cand_genes, new_cand_genes_freq = torch.unique(torch.tensor(patient.tolist() + list(cand_gene)), return_counts=True)
                unique_cand_genes = unique_cand_genes[new_cand_genes_freq == 1]
                cand_gene = cand_gene[np.isin(cand_gene, unique_cand_genes)]

                # Add upsampled candidate genes
                added_cand_gene.extend(list(cand_gene))
                new_cand_list = torch.tensor(patient.tolist() + list(cand_gene))
                upsampled_candidate_gene_node_idx.append(new_cand_list)

            candidate_gene_node_idx = tuple(upsampled_candidate_gene_node_idx)
            self.cand_gene_freq += Counter(added_cand_gene)

        # Add similar patients to the batch (for the "patients like me" head)
        if self.hparams['add_similar_patients']:
            similar_pats = self.get_candidate_patients(patient_ids)
            # Merge the original batch with the sampled patients
            phenotype_node_idx_sim, candidate_gene_node_idx_sim, correct_genes_node_idx_sim, disease_node_idx_sim, labels_sim, additional_labels_sim, patient_ids_sim = zip(*similar_pats)
            phenotype_node_idx = phenotype_node_idx + phenotype_node_idx_sim
            candidate_gene_node_idx = candidate_gene_node_idx + candidate_gene_node_idx_sim
            correct_genes_node_idx = correct_genes_node_idx + correct_genes_node_idx_sim
            disease_node_idx = disease_node_idx + disease_node_idx_sim
            labels = labels + labels_sim
            additional_labels = additional_labels + additional_labels_sim
            patient_ids = patient_ids + patient_ids_sim

        # Get patient labels
        patient_labels = correct_genes_node_idx

        # Add candidate diseases to the batch
        if self.hparams['add_cand_diseases']:
            candidate_disease_node_idx, disease_labels = self.get_candidate_diseases(disease_node_idx, candidate_gene_node_idx)
        else:
            candidate_disease_node_idx = disease_node_idx
            disease_labels = torch.tensor([1] * len(candidate_disease_node_idx))

        if self.hparams['augment_genes']:
            sim_gene_node_idx, gene_sims, gene_degs, unique_sim_genes = self.get_similar_genes(patient_ids, candidate_gene_node_idx)
        else:
            unique_sim_genes = gene_degs = gene_sims = sim_gene_node_idx = None

        t1 = time.time()

        # Get nodes from patients + randomly sampled nodes
        source_batch, sparse_idx = self.get_source_nodes(phenotype_node_idx, candidate_gene_node_idx, correct_genes_node_idx, disease_node_idx, candidate_disease_node_idx, unique_sim_genes)

        # Sample nodes to form positive edges
        source_batch, target_batch = self.sample_target_nodes(source_batch)
        batch = torch.cat([source_batch, target_batch], dim=0)
        t2 = time.time()

        # Get the k-hop adjacency graph
        adjs, batch_size, n_id = self.sample(batch, source_batch, target_batch)
        t3 = time.time()

        # Add patient information to the data object
        data = self.add_patient_information(patient_ids, phenotype_node_idx, candidate_gene_node_idx, correct_genes_node_idx, sim_gene_node_idx, gene_sims, gene_degs, disease_node_idx, candidate_disease_node_idx, labels, disease_labels, patient_labels, additional_labels, adjs, batch_size, n_id, sparse_idx, target_batch)
        t4 = time.time()

        if self.hparams['time']:
            print(f'It takes {t0 - t00:0.4f}s to unzip the batch, {t1 - t0:0.4f}s to upsample candidate gene nodes, {t2 - t1:0.4f}s to sample positive nodes, {t3 - t2:0.4f}s to get k-hop adjs, and {t4 - t3:0.4f}s to add patient information')
        return data
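    # Putting it together (a hypothetical sketch; `train_edge_index`,
    # `patient_dataset`, `edge_attr`, `n_nodes`, and `hparams` stand in for
    # objects built elsewhere in the pipeline):
    #
    #   loader = PatientNeighborSampler(
    #       'train', train_edge_index, train_edge_index, sizes=[-1, 10],
    #       patient_dataset=patient_dataset, all_edge_attributes=edge_attr,
    #       n_nodes=n_nodes, hparams=hparams, batch_size=8, shuffle=True)
    #   for data in loader:
    #       ...  # data.adjs, data['batch_pheno_nid'], data['batch_cand_gene_nid'], etc.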
|
|
|
    def __repr__(self):
        return '{}(sizes={})'.format(self.__class__.__name__, self.sizes)
|