# kgwas/utils.py

import os, sys
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import connected_components
import pandas as pd
import numpy as np
import pickle
from tqdm import tqdm
from scipy.stats import pearsonr
from sklearn.metrics import mean_squared_error, precision_score
import torch
from torch.nn import functional as F
from torch import nn
from multiprocessing import Pool
from functools import partial

from .params import main_data_path, cohort_data_path, kinship_path, withdraw_path


def evaluate_minibatch_clean(loader, model, device):
    model.eval()
    pred_all = []
    truth = []
    results = {}
    for step, batch in enumerate(tqdm(loader)):
        batch = batch.to(device)
        bs_batch = batch['SNP'].batch_size

        out = model(batch.x_dict, batch.edge_index_dict, bs_batch)
        pred = out.reshape(-1)
        y_batch = batch['SNP'].y[:bs_batch]

        pred_all.extend(pred.detach().cpu().numpy())
        truth.extend(y_batch.detach().cpu().numpy())
        del y_batch, pred, batch, out

    results['pred'] = np.hstack(pred_all)
    results['truth'] = np.hstack(truth)
    return results


def compute_metrics(results, binary, coverage = None, uncertainty_reg = 1, loss_fct = None):
    metrics = {}
    metrics['mse'] = mean_squared_error(results['pred'], results['truth'])
    metrics['pearsonr'] = pearsonr(results['pred'], results['truth'])[0]
    return metrics

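# Usage sketch (hypothetical objects: `val_loader`, `model`, and `device` come from the
# surrounding training pipeline and are not defined in this module):
#
#     results = evaluate_minibatch_clean(val_loader, model, device)
#     metrics = compute_metrics(results, binary=False)
#     print(metrics['mse'], metrics['pearsonr'])
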
'''
Requires modifying the PyG source code, since the stock HGT convolution does not support
returning heterogeneous graph attention weights. Replace `group` in, e.g.,

miniconda3/envs/a100_env/lib/python3.8/site-packages/torch_geometric/nn/conv/hgt_conv.py

with:

def group(xs: List[Tensor], aggr: Optional[str]) -> Optional[Tensor]:
    if len(xs) == 0:
        return None
    elif aggr is None:
        return torch.stack(xs, dim=1)
    elif len(xs) == 1:
        return xs[0]
    elif isinstance(xs, list) and isinstance(xs[0], tuple):
        xs_old = [i[0] for i in xs]
        out = torch.stack(xs_old, dim=0)
        out = getattr(torch, aggr)(out, dim=0)
        out = out[0] if isinstance(out, tuple) else out
        att = [i[1] for i in xs]
        return (out, att)
    else:
        out = torch.stack(xs, dim=0)
        out = getattr(torch, aggr)(out, dim=0)
        out = out[0] if isinstance(out, tuple) else out
        return out
'''


def get_attention_weight(model, x_dict, edge_index_dict, batch):
    # `batch` is the sampled HeteroData mini-batch that produced x_dict / edge_index_dict;
    # its n_id bookkeeping maps local (mini-batch) node indices back to global node ids.
    attention_all_layers = []
    for conv in model.convs:
        out = conv(x_dict, edge_index_dict,
                   return_attention_weights_dict = dict(zip(list(edge_index_dict.keys()),
                                                            [True] * len(list(edge_index_dict.keys())))))
        x_dict = {i: j[0] for i, j in out.items()}
        attention_layer = {i: j[1] for i, j in out.items()}
        attention_all_layers.append(attention_layer)
        x_dict = {key: x.relu() for key, x in x_dict.items()}

    idx2n_id = {}
    for i in batch.node_types:
        idx2n_id[i] = dict(zip(range(len(batch[i].n_id)), batch[i].n_id.numpy()))

    node_type = 'SNP'
    edge2weight_l1 = {}
    edge2weight_l2 = {}

    edge_type_node = [i for i, j in batch.edge_index_dict.items() if i[2] == node_type]
    edge_type_node_len = [j.shape[1] for i, j in batch.edge_index_dict.items() if i[2] == node_type]

    for idx, edge_type in enumerate(edge_type_node):
        edge2weight_l1[edge_type] = attention_all_layers[0][node_type][idx]
        assert edge_type_node_len[idx] == edge2weight_l1[edge_type][0].shape[1]

        edge2weight_l2[edge_type] = attention_all_layers[1][node_type][idx]
        assert edge_type_node_len[idx] == edge2weight_l2[edge_type][0].shape[1]

        # translate local edge endpoints into global node ids
        edge2weight_l1[edge_type][0][0] = torch.LongTensor([idx2n_id[edge_type[0]][ent] for ent in edge2weight_l1[edge_type][0][0].detach().cpu().numpy()])
        edge2weight_l1[edge_type][0][1] = torch.LongTensor([idx2n_id[edge_type[2]][ent] for ent in edge2weight_l1[edge_type][0][1].detach().cpu().numpy()])

    return edge2weight_l1, edge2weight_l2


def get_fields(all_field_ids, main_data_path):
    headers = pd.read_csv(main_data_path, nrows = 1).columns
    relevant_headers = [i for i, header in enumerate(headers) if header == 'eid' or \
            any([header.startswith('%d-' % field_id) for field_id in all_field_ids])]
    return pd.read_csv(main_data_path, usecols = relevant_headers)


def get_row_last_values(df):
    # For every row, return the right-most non-null value across the columns
    # (UKBB fields typically have one column per repeated measurement).
    result = pd.Series(np.nan, index = df.index)

    for column in df.columns[::-1]:
        result = result.where(pd.notnull(result), df[column])

    return result

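# Illustrative example (toy frame, not UKBB data):
#
#     df = pd.DataFrame({'visit_0': [1.0, np.nan], 'visit_1': [2.0, np.nan]})
#     get_row_last_values(df)   # row 0 -> 2.0 (last non-null), row 1 -> NaN
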
def remove_kinships(eid, verbose = True):

    '''
    Determines which samples need to be removed such that the remaining samples will have no kinship connections whatsoever (according to the
    kinship table provided by the UKBB). To determine that, kinship groups are first derived (@see get_kinship_groups), and only one sample
    is kept within each group. For the sake of determinism, the sample with the lowest eid is selected within each kinship group, and the
    rest are discarded.
    @param eid (pd.Series): A series whose values are UKBB sample IDs, from which kinships should be removed.
    @param verbose (bool): Whether to log details of the operation of this function.
    @return: A mask of samples to keep (pd.Series with index corresponding to the eid input, and boolean values).
    '''

    all_eids = set(eid)
    kinship_groups = get_kinship_groups()

    relevant_kinship_groups = [kinship_group & all_eids for kinship_group in kinship_groups]
    relevant_kinship_groups = [kinship_group for kinship_group in relevant_kinship_groups if len(kinship_group) >= 2]
    unchosen_kinship_representatives = set.union(*[set(sorted(kinship_group)[1:]) for kinship_group in relevant_kinship_groups])
    no_kinship_mask = ~eid.isin(unchosen_kinship_representatives)

    if verbose:
        print_sys(('Constructed %d kinship groups (%d samples), of which %d (%d samples) are relevant for the dataset (i.e. containing at least 2 ' + \
                'samples in the dataset). Picking only one representative of each group and removing the %d other samples in those groups ' + \
                'has reduced the dataset from %d to %d samples.') % (len(kinship_groups), len(set.union(*kinship_groups)), \
                len(relevant_kinship_groups), len(set.union(*relevant_kinship_groups)), len(unchosen_kinship_representatives), len(no_kinship_mask), \
                no_kinship_mask.sum()))

    return no_kinship_mask


def get_kinship_groups():

    '''
    Uses the kinship table provided by the UKBB (as specified by the kinship_path configuration) in order to determine kinship groups.
    Each kinship group is a connected component of samples in the graph of kinships (where each node is a UKBB sample, and an edge exists between
    each pair of samples reported in the kinship table).
    @return: A list of sets of strings (the strings are the sample IDs, i.e. eid). Each set of samples is a kinship group.
    '''

    kinship_table = pd.read_csv(kinship_path, sep = ' ')
    kinship_ids = np.array(sorted(set(kinship_table['ID1']) | set(kinship_table['ID2'])))
    n_kinship_ids = len(kinship_ids)
    kinship_id_to_index = pd.Series(np.arange(n_kinship_ids), index = kinship_ids)

    kinship_index1 = kinship_table['ID1'].map(kinship_id_to_index).values
    kinship_index2 = kinship_table['ID2'].map(kinship_id_to_index).values

    symmetric_kinship_index1 = np.concatenate([kinship_index1, kinship_index2])
    symmetric_kinship_index2 = np.concatenate([kinship_index2, kinship_index1])

    kinship_matrix = csr_matrix((np.ones(len(symmetric_kinship_index1), dtype = bool), (symmetric_kinship_index1, \
            symmetric_kinship_index2)), shape = (n_kinship_ids, n_kinship_ids), dtype = bool)

    _, kinship_labels = connected_components(kinship_matrix, directed = False)
    kinship_labels = pd.Series(kinship_labels, index = kinship_ids)
    return [set(group_kinship_labels.index) for _, group_kinship_labels in kinship_labels.groupby(kinship_labels)]


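# Toy illustration of the grouping logic above (in-memory table instead of kinship_path):
#
#     toy = pd.DataFrame({'ID1': ['a', 'b', 'd'], 'ID2': ['b', 'c', 'e']})
#     ids = np.array(sorted(set(toy['ID1']) | set(toy['ID2'])))
#     id_to_idx = pd.Series(np.arange(len(ids)), index = ids)
#     i1, i2 = toy['ID1'].map(id_to_idx).values, toy['ID2'].map(id_to_idx).values
#     mat = csr_matrix((np.ones(2 * len(i1), dtype = bool),
#                       (np.concatenate([i1, i2]), np.concatenate([i2, i1]))),
#                      shape = (len(ids), len(ids)), dtype = bool)
#     _, labels = connected_components(mat, directed = False)
#     # labels place {'a', 'b', 'c'} in one group and {'d', 'e'} in another
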
def save_dict(path, obj):
    """save an object to a pickle file

    Args:
        path (str): the path to save the pickle file to
        obj (object): any picklable object
    """
    with open(path, 'wb') as f:
        pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)

def load_dict(path):
    """load an object from a pickle file

    Args:
        path (str): the path where the pickle file is located

    Returns:
        object: the unpickled object
    """
    with open(path, 'rb') as f:
        return pickle.load(f)


def save_model(model, config, path_dir):
    if not os.path.exists(path_dir):
        os.makedirs(path_dir)
    torch.save(model.state_dict(), path_dir + '/model.pt')
    save_dict(path_dir + '/config.pkl', config)

def load_pretrained(path, model):
    state_dict = torch.load(os.path.join(path, 'model.pt'), map_location = torch.device('cpu'))
    # to support weights trained with multi-GPU data parallelism:
    if next(iter(state_dict))[:7] == 'module.':
        # the pretrained model is from a data-parallel module
        from collections import OrderedDict
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            name = k[7:] # remove `module.`
            new_state_dict[name] = v
        state_dict = new_state_dict

    model.load_state_dict(state_dict)
    return model

def get_args(path):
    return load_dict(os.path.join(path, 'config.pkl'))

def print_sys(s):
    """system print

    Args:
        s (str): the string to print
    """
    print(s, flush = True, file = sys.stderr)


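# Save/load roundtrip sketch (hypothetical `model`, `config`, and checkpoint directory):
#
#     save_model(model, config, './checkpoints/run1')
#     config = get_args('./checkpoints/run1')
#     model = load_pretrained('./checkpoints/run1', model)
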
def get_plink_QC_fam():
    # `ukbb_cohort` is assumed to be provided by the surrounding package; it is not defined in this module.
    fam_path = '/dfs/project/datasets/20220524-ukbiobank/data/genetics/ukb_all.fam'
    data = ukbb_cohort(main_data_path, cohort_data_path, withdraw_path, keep_relatives=True).cohort
    df_fam = pd.read_csv(fam_path, sep = ' ', header = None)
    df_fam[df_fam[0].isin(data)].reset_index(drop = True).to_csv('/dfs/project/datasets/20220524-ukbiobank/data/cohort/qc_cohort.txt', header = None, index = False, sep = ' ')


def get_plink_no_rel_fam():
    fam_path = '/dfs/project/datasets/20220524-ukbiobank/data/genetics/ukb_all.fam'
    data = ukbb_cohort(main_data_path, cohort_data_path, withdraw_path, keep_relatives=False).cohort
    df_fam = pd.read_csv(fam_path, sep = ' ', header = None)
    df_fam[df_fam[0].isin(data)].reset_index(drop = True).to_csv('/dfs/project/datasets/20220524-ukbiobank/data/cohort/no_rel.fam', header = None, index = False, sep = ' ')

def get_precision_recall_at_N(res, hits_all, input_dim, N, column_rsid = 'ID', thres = 5e-8):
    # Coarse-to-fine search for the smallest K such that the top-K SNPs (ranked by P-value)
    # recover a fraction N of the ground-truth hits; precision at that K is then reported.
    eval_dict = {}
    hits_sub = res[res.P < thres][column_rsid].values
    p_sorted = res.sort_values('P')[column_rsid].values

    # scan K with progressively finer step sizes (10000 -> 1000 -> 100 -> 10 -> 1)
    for K in range(1, input_dim, 10000):
        topK_true = np.intersect1d(hits_all, p_sorted[:K])
        recall = len(topK_true)/len(hits_all)
        if recall > N:
            break

    for K in range(K-10000, K, 1000):
        topK_true = np.intersect1d(hits_all, p_sorted[:K])
        recall = len(topK_true)/len(hits_all)
        if recall > N:
            break

    for K in range(K-1000, K, 100):
        topK_true = np.intersect1d(hits_all, p_sorted[:K])
        recall = len(topK_true)/len(hits_all)
        if recall > N:
            break

    for K in range(K-100, K, 10):
        topK_true = np.intersect1d(hits_all, p_sorted[:K])
        recall = len(topK_true)/len(hits_all)
        if recall > N:
            break

    for K in range(K-10, K):
        topK_true = np.intersect1d(hits_all, p_sorted[:K])
        recall = len(topK_true)/len(hits_all)
        if recall > N:
            break

    print_sys('PR@' + str(int(N * 100)) + ' is achieved when K = ' + str(K))
    eval_dict['PR@' + str(int(N * 100)) + '_K'] = K
    topK_true = [1 if i in hits_all else 0 for i in p_sorted[:K]]
    precision = precision_score(topK_true, [1] * K)
    eval_dict['PR@' + str(int(N * 100))] = precision

    return eval_dict

def get_gwas_results(res, hits_all, input_dim, column_rsid = 'ID', thres = 5e-8):
    eval_dict = {}
    hits_sub = res[res.P < thres][column_rsid].values
    eval_dict['overall_recall'] = len(np.intersect1d(hits_sub, hits_all))/len(hits_all)
    if len(hits_sub) == 0:
        eval_dict['overall_precision'] = 0
        eval_dict['overall_f1'] = 0
    else:
        eval_dict['overall_precision'] = len(np.intersect1d(hits_sub, hits_all))/len(hits_sub)
        eval_dict['overall_f1'] = 2 * eval_dict['overall_recall'] * eval_dict['overall_precision']/(eval_dict['overall_recall'] + eval_dict['overall_precision'])
    for K in [100, 500, 1000, 5000]:
        topK_true = [1 if i in hits_all else 0 for i in res.sort_values('P').iloc[:K][column_rsid].values]
        eval_dict['precision_' + str(K)] = precision_score(topK_true, [1] * K)
        eval_dict['recall_' + str(K)] = sum(topK_true)/len(hits_all)

    eval_dict.update(get_precision_recall_at_N(res, hits_all, input_dim, 0.8, column_rsid, thres))
    eval_dict.update(get_precision_recall_at_N(res, hits_all, input_dim, 0.9, column_rsid, thres))
    eval_dict.update(get_precision_recall_at_N(res, hits_all, input_dim, 0.95, column_rsid, thres))
    return eval_dict


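# Usage sketch (hypothetical inputs): `res` is a GWAS summary-statistics frame with 'ID'
# and 'P' columns, `hits_all` holds the ground-truth significant rsids, and `input_dim`
# is the total number of SNPs scored.
#
#     eval_dict = get_gwas_results(res, hits_all, input_dim=len(res))
#     print(eval_dict['overall_f1'], eval_dict['precision_100'])
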
def find_nearest(array, value):
    array = np.asarray(array)
    idx = (np.abs(array - value)).argmin()
    return array[idx]


def get_preds(logits, multi_label):
    if multi_label:
        preds = (logits.sigmoid() > 0.5).float()
    elif logits.shape[1] > 1:  # multi-class
        preds = logits.argmax(dim=1).float()
    else:  # binary
        preds = (logits.sigmoid() > 0.5).float()
    return preds

def process_data(data, use_edge_attr):
    if not use_edge_attr:
        data.edge_attr = None
    if data.get('edge_label', None) is None:
        data.edge_label = {i: torch.zeros(j.shape[1]) for i, j in data.edge_index_dict.items()}
    return data


def load_checkpoint(model, model_dir, model_name, map_location=None):
    # `model_dir` is expected to be a pathlib.Path (the `/` operator builds the file path)
    checkpoint = torch.load(model_dir / (model_name + '.pt'), map_location=map_location)
    model.load_state_dict(checkpoint['model_state_dict'])


def save_checkpoint(model, model_dir, model_name):
    torch.save({'model_state_dict': model.state_dict()}, model_dir / (model_name + '.pt'))


def get_lr(optimizer):
    for param_group in optimizer.param_groups:
        return param_group['lr']

def flatten(list_of_lists):
    return [item for sublist in list_of_lists for item in sublist]


def find_connected_components_details(edges):
    # Build an undirected adjacency list from the edge list, then collect the nodes and
    # edges of each connected component via an iterative depth-first search.
    graph = {}
    for u, v in edges:
        if u not in graph:
            graph[u] = []
        if v not in graph:
            graph[v] = []
        graph[u].append(v)
        graph[v].append(u)

    def dfs(vertex):
        visited_nodes = set()
        visited_edges = set()
        stack = [vertex]

        while stack:
            current = stack.pop()
            if current not in visited_nodes:
                visited_nodes.add(current)
                for neighbor in graph[current]:
                    stack.append(neighbor)
                    if (current, neighbor) not in visited_edges and (neighbor, current) not in visited_edges:
                        visited_edges.add((current, neighbor))
        return list(visited_nodes), list(visited_edges)

    visited = set()
    components = []

    for vertex in tqdm(graph):
        if vertex not in visited:
            nodes, edges = dfs(vertex)
            components.append({
                'nodes': nodes,
                'edges': edges
            })
            visited.update(nodes)

    return components


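# Illustrative example: two components, one containing a cycle and one a single edge.
#
#     comps = find_connected_components_details([(1, 2), (2, 3), (3, 1), (4, 5)])
#     # -> one component with nodes {1, 2, 3} and three edges, and one with nodes {4, 5}
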
def ldsc_regression_weights(ld, w_ld, N, M, hsq, intercept=None, ii=None):
    '''
    Regression weights.

    Parameters
    ----------
    ld : np.matrix with shape (n_snp, 1)
        LD Scores (non-partitioned).
    w_ld : np.matrix with shape (n_snp, 1)
        LD Scores (non-partitioned) computed with sum r^2 taken over only those SNPs included
        in the regression.
    N :  np.matrix of ints > 0 with shape (n_snp, 1)
        Number of individuals sampled for each SNP.
    M : float > 0
        Number of SNPs used for estimating LD Score (need not equal number of SNPs included in
        the regression).
    hsq : float in [0,1]
        Heritability estimate.
    intercept : float, optional
        LD Score regression intercept; defaults to 1.
    ii : unused
        Kept for signature compatibility.

    Returns
    -------
    w : np.matrix with shape (n_snp, 1)
        Regression weights. Approx equal to reciprocal of conditional variance function.

    '''
    M = float(M)
    if intercept is None:
        intercept = 1

    # clamp hsq to [0, 1] and LD Scores to >= 1 for numerical stability
    hsq = max(hsq, 0.0)
    hsq = min(hsq, 1.0)
    ld = np.fmax(ld, 1.0)
    w_ld = np.fmax(w_ld, 1.0)
    c = hsq * N / M
    het_w = 1.0 / (2 * np.square(intercept + np.multiply(c, ld)))
    oc_w = 1.0 / w_ld
    w = np.multiply(het_w, oc_w)
    return w


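# Worked toy example (illustrative numbers only): with hsq = 0.5, N = 10000, M = 1e6 and
# ld = w_ld = 1, c = hsq * N / M = 0.005, het_w = 1 / (2 * (1 + 0.005)**2) ~= 0.495,
# oc_w = 1.0, so each weight w ~= 0.495.
#
#     w = ldsc_regression_weights(ld=np.ones((3, 1)), w_ld=np.ones((3, 1)),
#                                 N=np.full((3, 1), 10000), M=1e6, hsq=0.5)
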
def get_network_weight(run, data):
    model = run.best_model
    model = model.to('cpu')
    graph_data = data.data.to('cpu')

    x_dict, edge_index_dict = graph_data.x_dict, graph_data.edge_index_dict
    attention_all_layers = []
    print('Retrieving weights...')

    x_dict['SNP'] = model.snp_feat_mlp(x_dict['SNP'])
    x_dict['Gene'] = model.gene_feat_mlp(x_dict['Gene'])
    x_dict['CellularComponent'] = model.go_feat_mlp(x_dict['CellularComponent'])
    x_dict['BiologicalProcess'] = model.go_feat_mlp(x_dict['BiologicalProcess'])
    x_dict['MolecularFunction'] = model.go_feat_mlp(x_dict['MolecularFunction'])

    for conv in model.convs:
        x_dict = conv(x_dict, edge_index_dict,
                    return_attention_weights_dict = dict(zip(list(graph_data.edge_index_dict.keys()),
                                                            [True] * len(list(graph_data.edge_index_dict.keys())))),
                    return_raw_attention_weights_dict = dict(zip(list(graph_data.edge_index_dict.keys()),
                                                            [True] * len(list(graph_data.edge_index_dict.keys())))),
                    )
        attention_layer = {i: j[1] for i, j in x_dict.items()}
        attention_all_layers.append(attention_layer)
        x_dict = {i: j[0] for i, j in x_dict.items()}

    layer2rel2att = {
        'l1': {},
        'l2': {}
    }

    print('Aggregating across node types...')

    for node_type in graph_data.x_dict.keys():
        edge_type_node = [i for i, j in graph_data.edge_index_dict.items() if i[2] == node_type]
        for idx, i in enumerate(attention_all_layers[0][node_type]):
            layer2rel2att['l1'][edge_type_node[idx]] = np.vstack((i[0].detach().cpu().numpy(), i[1].detach().cpu().numpy().reshape(-1)))
        for idx, i in enumerate(attention_all_layers[1][node_type]):
            layer2rel2att['l2'][edge_type_node[idx]] = np.vstack((i[0].detach().cpu().numpy(), i[1].detach().cpu().numpy().reshape(-1)))

    # assemble one long edge table across both layers (DataFrame.append is deprecated,
    # so collect the per-relation frames and concatenate once)
    df_val_list = []
    for layer in ['l1', 'l2']:
        for rel, value in layer2rel2att[layer].items():
            df_val = pd.DataFrame(value).T.rename(columns = {0: 'h_idx', 1: 't_idx', 2: 'weight'})
            df_val['h_type'] = rel[0]
            df_val['rel_type'] = rel[1]
            df_val['t_type'] = rel[2]
            df_val['layer'] = layer
            df_val_list.append(df_val)
    df_val_all = pd.concat(df_val_list)

    df_val_all = df_val_all.drop_duplicates(['h_idx', 't_idx', 'rel_type', 'layer'])
    return df_val_all


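# Usage sketch (hypothetical objects): `run` is a trained KGWAS run exposing `best_model`,
# and `data` is the dataset wrapper whose `.data` attribute holds the full heterogeneous graph.
#
#     df_network = get_network_weight(run, data)
#     df_network.head()   # columns: h_idx, t_idx, weight, h_type, rel_type, t_type, layer
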
def get_local_interpretation(query_snp, v2g, g2g, g2p, g2v, id2idx, K_neighbors):
    try:
        snp2gene_around_snp = v2g[v2g.t_idx == id2idx['SNP'][query_snp]]
        snp2gene_around_snp = snp2gene_around_snp.sort_values('importance')[::-1]
        gene_hit = snp2gene_around_snp.iloc[:K_neighbors]
        gene_hit.loc[:, 'rel_type'] = gene_hit.rel_type.apply(lambda x: x[4:])

        # top-K genes linked to the query SNP, then their top-K gene, program, and variant neighbors
        g2g_focal = pd.concat([g2g[g2g.t_id == gene].sort_values('importance')[::-1].iloc[:K_neighbors]
                               for gene in gene_hit.h_id.values])
        g2g_focal.loc[:, 'rel_type'] = g2g_focal.rel_type.apply(lambda x: x.split('-')[1])

        g2p_focal = pd.concat([g2p[g2p.t_id == gene].sort_values('importance')[::-1].iloc[:K_neighbors]
                               for gene in gene_hit.h_id.values])
        g2p_focal.loc[:, 'rel_type'] = g2p_focal.rel_type.apply(lambda x: x.split('-')[1])

        g2v_focal = pd.concat([g2v[g2v.t_id == gene].sort_values('importance')[::-1].iloc[:K_neighbors]
                               for gene in gene_hit.h_id.values])

        local_neighborhood_around_snp = pd.concat((gene_hit, g2g_focal, g2p_focal, g2v_focal))
        local_neighborhood_around_snp.loc[:, 'QUERY_SNP'] = query_snp
        return local_neighborhood_around_snp
    except Exception:
        return None


def generate_viz(run, df_network, data_path, variant_threshold = 5e-8,
                magma_path = None, magma_threshold = 0.05, program_threshold = 0.05,
                K_neighbors = 3, num_cpus = 1):
    gwas = run.kgwas_res
    idx2id = run.data.idx2id
    id2idx = run.data.id2idx
    print('Start generating disease critical network...')

    gene_sets = load_dict(os.path.join(data_path, 'misc_data/gene_set_bp.pkl'))
    with open(os.path.join(data_path, 'misc_data/go2name.pkl'), 'rb') as f:
        go2name = pickle.load(f)

    df_network = df_network[~df_network.rel_type.isin(['TSS', 'rev_TSS'])]

    snp2genes = df_network[(df_network.t_type == 'SNP')
                       & (df_network.h_type == 'Gene')]
    gene2gene = df_network[(df_network.t_type == 'Gene')
                           & (df_network.h_type == 'Gene')]
    gene2go = df_network[(df_network.t_type == 'Gene')
                               & (df_network.h_type.isin(['BiologicalProcess']))]

    if 'SNP' not in gwas.columns.values:
        gwas.loc[:, 'SNP'] = gwas['ID']
    hit_snps = gwas[gwas.P < variant_threshold].SNP.values
    hit_snps_idx = [id2idx['SNP'][i] for i in hit_snps]

    if magma_path is not None:
        # use MAGMA genes and GSEA programs
        print('Using MAGMA genes to filter...')
        gwas_gene = pd.read_csv(magma_path, sep = r'\s+')
        id2gene = dict(pd.read_csv(os.path.join(data_path, 'misc_data/NCBI37.3.gene.loc'), sep = '\t', header = None)[[0,5]].values)
        gwas_gene.loc[:,'GENE'] = gwas_gene['GENE'].apply(lambda x: id2gene[x])

        import statsmodels.api as sm
        p_values = gwas_gene['P']
        corrected_p_values = sm.stats.multipletests(p_values, alpha=magma_threshold, method='bonferroni')[1]
        gwas_gene.loc[:,'corrected_p_value'] = corrected_p_values
        df_gene_hits = gwas_gene[gwas_gene['corrected_p_value'] < magma_threshold]
        rnk = df_gene_hits[['GENE', 'ZSTAT']].set_index('GENE')
        gene_hit_idx = [id2idx['Gene'][i] for i in df_gene_hits.GENE.values if i in id2idx['Gene']]

        try:
            import gseapy as gp  # assumed GSEA backend; gseapy provides prerank()
            gsea_results_BP = gp.prerank(rnk=rnk, gene_sets=gene_sets,
                                        outdir=None, permutation_num=100,
                                        min_size=2, max_size=1000, seed = 42)
            gsea_results_BP = gsea_results_BP.res2d
            go_hits = gsea_results_BP[gsea_results_BP['NOM p-val'] < program_threshold].Term.values
            if len(go_hits) <= 5:
                go_hits = gsea_results_BP.sort_values('NOM p-val')[:5].Term.values
            go_hits_idx = [id2idx['BiologicalProcess'][x] for x in go_hits]
            print('Using GSEA gene programs to filter...')
        except Exception:
            print('No significant gene programs found...')
            go_hits_idx = []
    else:
        # use all genes and gene programs
        print('No filters... Using all genes and gene programs...')
        gene_hit_idx = list(id2idx['Gene'].values())
        go_hits_idx = list(id2idx['BiologicalProcess'].values())


    # normalise edge weights within each relation type (z-score), then keep, for every
    # (head, tail) pair, the relation with the highest z-score as its importance
    snp2genes_hit = snp2genes[snp2genes.t_idx.isin(hit_snps_idx) & snp2genes.h_idx.isin(gene_hit_idx)]
    rel2mean = snp2genes_hit.groupby('rel_type').weight.mean().reset_index().rename(columns = {'weight': 'rel_type_mean'})
    rel2std = snp2genes_hit.groupby('rel_type').weight.agg(np.std).reset_index().rename(columns = {'weight': 'rel_type_std'})

    snp2genes_hit = snp2genes_hit.merge(rel2std)
    snp2genes_hit = snp2genes_hit.merge(rel2mean)
    snp2genes_hit.loc[:,'z_rel'] = (snp2genes_hit['weight'] - snp2genes_hit['rel_type_mean'])/snp2genes_hit['rel_type_std']

    v2g_hit = snp2genes_hit.groupby(['h_idx', 't_idx']).z_rel.max().reset_index().rename(columns={'z_rel': 'importance'})
    v2g_hit_with_rel_type = pd.merge(v2g_hit, snp2genes_hit, left_on=['h_idx', 't_idx', 'importance'], right_on=['h_idx', 't_idx', 'z_rel'], how='left')
    v2g_hit = v2g_hit_with_rel_type[['h_idx', 't_idx', 'importance', 'h_type', 't_type', 'rel_type']]
    v2g_hit.loc[:,'rel_type'] = v2g_hit.rel_type.apply(lambda x: x[4:])
    v2g_hit.loc[:,'Category'] = 'V2G'

    v2g_hit.loc[:,'h_id'] = v2g_hit['h_idx'].apply(lambda x: idx2id['Gene'][x])
    v2g_hit.loc[:,'t_id'] = v2g_hit['t_idx'].apply(lambda x: idx2id['SNP'][x])

    gene2gene_hit = gene2gene[gene2gene.h_idx.isin(gene_hit_idx) & gene2gene.t_idx.isin(gene_hit_idx)]
    rel2mean = gene2gene_hit.groupby('rel_type').weight.mean().reset_index().rename(columns = {'weight': 'rel_type_mean'})
    rel2std = gene2gene_hit.groupby('rel_type').weight.agg(np.std).reset_index().rename(columns = {'weight': 'rel_type_std'})

    gene2gene_hit = gene2gene_hit.merge(rel2std)
    gene2gene_hit = gene2gene_hit.merge(rel2mean)
    gene2gene_hit.loc[:,'z_rel'] = (gene2gene_hit['weight'] - gene2gene_hit['rel_type_mean'])/gene2gene_hit['rel_type_std']

    g2g_hit = gene2gene_hit.groupby(['h_idx', 't_idx']).z_rel.max().reset_index().rename(columns={'z_rel': 'importance'})
    g2g_hit_with_rel_type = pd.merge(g2g_hit, gene2gene_hit, left_on=['h_idx', 't_idx', 'importance'], right_on=['h_idx', 't_idx', 'z_rel'], how='left')
    g2g_hit = g2g_hit_with_rel_type[['h_idx', 't_idx', 'importance', 'h_type', 't_type', 'rel_type']]
    g2g_hit.loc[:,'rel_type'] = g2g_hit.rel_type.apply(lambda x: x.split('-')[1])
    g2g_hit.loc[:,'Category'] = 'G2G'

    g2g_hit.loc[:,'h_id'] = g2g_hit['h_idx'].apply(lambda x: idx2id['Gene'][x])
    g2g_hit.loc[:,'t_id'] = g2g_hit['t_idx'].apply(lambda x: idx2id['Gene'][x])

    gene2program_hit = gene2go[gene2go.t_idx.isin(gene_hit_idx) & gene2go.h_idx.isin(go_hits_idx)]
    rel2mean = gene2program_hit.groupby('rel_type').weight.mean().reset_index().rename(columns = {'weight': 'rel_type_mean'})
    rel2std = gene2program_hit.groupby('rel_type').weight.agg(np.std).reset_index().rename(columns = {'weight': 'rel_type_std'})

    gene2program_hit = gene2program_hit.merge(rel2std)
    gene2program_hit = gene2program_hit.merge(rel2mean)
    gene2program_hit.loc[:,'z_rel'] = (gene2program_hit['weight'] - gene2program_hit['rel_type_mean'])/gene2program_hit['rel_type_std']

    g2p_hit = gene2program_hit.groupby(['h_idx', 't_idx']).z_rel.max().reset_index().rename(columns={'z_rel': 'importance'})

    g2p_hit_with_rel_type = pd.merge(g2p_hit, gene2program_hit, left_on=['h_idx', 't_idx', 'importance'], right_on=['h_idx', 't_idx', 'z_rel'], how='left')
    g2p_hit = g2p_hit_with_rel_type[['h_idx', 't_idx', 'importance', 'h_type', 't_type', 'rel_type']]
    g2p_hit.loc[:,'rel_type'] = g2p_hit.rel_type.apply(lambda x: x.split('-')[1])
    g2p_hit.loc[:,'Category'] = 'G2P'
    g2p_hit.loc[:,'h_id'] = g2p_hit['h_idx'].apply(lambda x: idx2id['BiologicalProcess'][x])
    g2p_hit.loc[:,'t_id'] = g2p_hit['t_idx'].apply(lambda x: idx2id['Gene'][x])
    g2p_hit.loc[:,'h_id'] = g2p_hit.h_id.apply(lambda x: go2name[x].capitalize() if x in go2name else x)
    disease_critical_network = pd.concat((v2g_hit, g2g_hit, g2p_hit)).reset_index(drop = True)

    print('Disease critical network finished generating...')
    print('Generating variant interpretation networks...')

    #### get for variant interpretation -> since we are looking at top K neighbors, we don't filter

    # V2G
    rel2mean = snp2genes_hit.groupby('rel_type').weight.mean().reset_index().rename(columns = {'weight': 'rel_type_mean'})
    rel2std = snp2genes_hit.groupby('rel_type').weight.agg(np.std).reset_index().rename(columns = {'weight': 'rel_type_std'})

    snp2genes = snp2genes.merge(rel2std)
    snp2genes = snp2genes.merge(rel2mean)
    snp2genes.loc[:,'z_rel'] = (snp2genes['weight'] - snp2genes['rel_type_mean'])/snp2genes['rel_type_std']
    snp2genes = snp2genes.rename(columns={'z_rel': 'importance'})
    v2g = snp2genes.groupby(['h_idx', 't_idx']).importance.max().reset_index()
    v2g_with_rel_type = pd.merge(v2g, snp2genes, left_on=['h_idx', 't_idx', 'importance'], right_on=['h_idx', 't_idx', 'importance'], how='left')
    v2g = v2g_with_rel_type[['h_idx', 't_idx', 'importance', 'h_type', 't_type', 'rel_type']]

    v2g.loc[:,'h_id'] = v2g['h_idx'].apply(lambda x: idx2id['Gene'][x])
    v2g.loc[:,'t_id'] = v2g['t_idx'].apply(lambda x: idx2id['SNP'][x])

    ## G2G

    rel2mean = gene2gene_hit.groupby('rel_type').weight.mean().reset_index().rename(columns = {'weight': 'rel_type_mean'})
    rel2std = gene2gene_hit.groupby('rel_type').weight.agg(np.std).reset_index().rename(columns = {'weight': 'rel_type_std'})

    gene2gene = gene2gene.merge(rel2std)
    gene2gene = gene2gene.merge(rel2mean)
    gene2gene.loc[:,'z_rel'] = (gene2gene['weight'] - gene2gene['rel_type_mean'])/gene2gene['rel_type_std']
    gene2gene = gene2gene.rename(columns={'z_rel': 'importance'})

    g2g = gene2gene.groupby(['h_idx', 't_idx']).importance.max().reset_index()
    g2g_with_rel_type = pd.merge(g2g, gene2gene, left_on=['h_idx', 't_idx', 'importance'], right_on=['h_idx', 't_idx', 'importance'], how='left')
    g2g = g2g_with_rel_type[['h_idx', 't_idx', 'importance', 'h_type', 't_type', 'rel_type']]

    g2g.loc[:,'h_id'] = g2g['h_idx'].apply(lambda x: idx2id['Gene'][x])
    g2g.loc[:,'t_id'] = g2g['t_idx'].apply(lambda x: idx2id['Gene'][x])
    g2g = g2g[g2g.h_idx != g2g.t_idx]

    ## G2P

    rel2mean = gene2program_hit.groupby('rel_type').weight.mean().reset_index().rename(columns = {'weight': 'rel_type_mean'})
    rel2std = gene2program_hit.groupby('rel_type').weight.agg(np.std).reset_index().rename(columns = {'weight': 'rel_type_std'})

    gene2go = gene2go.merge(rel2std)
    gene2go = gene2go.merge(rel2mean)
    gene2go.loc[:,'z_rel'] = (gene2go['weight'] - gene2go['rel_type_mean'])/gene2go['rel_type_std']
    gene2go = gene2go.rename(columns={'z_rel': 'importance'})

    g2p = gene2go.groupby(['h_idx', 't_idx']).importance.max().reset_index()
    g2p_with_rel_type = pd.merge(g2p, gene2go, left_on=['h_idx', 't_idx', 'importance'], right_on=['h_idx', 't_idx', 'importance'], how='left')
    g2p = g2p_with_rel_type[['h_idx', 't_idx', 'importance', 'h_type', 't_type', 'rel_type']]

    g2p.loc[:,'h_id'] = g2p['h_idx'].apply(lambda x: go2name[idx2id['BiologicalProcess'][x]].capitalize() if idx2id['BiologicalProcess'][x] in go2name else idx2id['BiologicalProcess'][x])
    g2p.loc[:,'t_id'] = g2p['t_idx'].apply(lambda x: idx2id['Gene'][x])

    ## G2V

    gene2snp = df_network[(df_network.h_type == 'SNP')
                       & (df_network.t_type == 'Gene')]

    gene2snp_hit = gene2snp[gene2snp.h_idx.isin(hit_snps_idx) & gene2snp.t_idx.isin(gene_hit_idx)]

    rel2mean = gene2snp_hit.groupby('rel_type').weight.mean().reset_index().rename(columns = {'weight': 'rel_type_mean'})
    rel2std = gene2snp_hit.groupby('rel_type').weight.agg(np.std).reset_index().rename(columns = {'weight': 'rel_type_std'})

    gene2snp = gene2snp.merge(rel2std)
    gene2snp = gene2snp.merge(rel2mean)
    gene2snp.loc[:,'z_rel'] = (gene2snp['weight'] - gene2snp['rel_type_mean'])/gene2snp['rel_type_std']
    gene2snp = gene2snp.rename(columns={'z_rel': 'importance'})

    g2v = gene2snp.groupby(['h_idx', 't_idx']).importance.max().reset_index()
    g2v_with_rel_type = pd.merge(g2v, gene2snp, left_on=['h_idx', 't_idx', 'importance'], right_on=['h_idx', 't_idx', 'importance'], how='left')
    g2v = g2v_with_rel_type[['h_idx', 't_idx', 'importance', 'h_type', 't_type', 'rel_type']]

    g2v.loc[:,'h_id'] = g2v['h_idx'].apply(lambda x: idx2id['SNP'][x])
    g2v.loc[:,'t_id'] = g2v['t_idx'].apply(lambda x: idx2id['Gene'][x])

    print('Number of hit snps: ', len(hit_snps))
    process_func = partial(get_local_interpretation, v2g=v2g, g2g=g2g, g2p=g2p, g2v=g2v, id2idx=id2idx, K_neighbors=K_neighbors)

    with Pool(num_cpus) as p:
        res = list(tqdm(p.imap(process_func, hit_snps), total=len(hit_snps)))
    try:
        df_variant_interpretation = pd.concat([i for i in res if i is not None])
    except ValueError:
        # pd.concat raises ValueError when no SNP yielded a local neighborhood
        df_variant_interpretation = pd.DataFrame()

    return df_variant_interpretation, disease_critical_network
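

# End-to-end usage sketch (hypothetical `run`/`data` objects and data_path):
#
#     df_network = get_network_weight(run, data)
#     df_variant_interp, disease_network = generate_viz(run, df_network, data_path='./kgwas_data',
#                                                       magma_path=None, K_neighbors=3, num_cpus=4)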