a b/kgwas/eval_utils.py
1
import numpy as np
2
import pandas as pd
3
import torch
4
from scipy import interpolate
5
6
from .utils import load_dict
7
import pandas as pd
8
import numpy as np
9
from copy import copy
10
11
def find_closest_x(df_pred, lower_bound=0, upper_bound=200, tolerance=0.01, p_low=1e-3, p_high=1e-2):
    """Binary-search a scale factor that calibrates weighted p-values to the originals.

    Looks for ``mid`` in ``[lower_bound, upper_bound]`` such that the number of
    scaled weighted p-values (``df_pred.P_weighted * mid``) lying strictly inside
    ``(p_low, p_high)`` matches, within ``tolerance`` in ratio, the number of
    original p-values (``df_pred.P``) inside the same window.

    Args:
        df_pred: DataFrame with ``P`` and ``P_weighted`` columns.
        lower_bound, upper_bound: search interval for the scale factor.
        tolerance: accepted deviation of the count ratio from 1; also the step by
            which the interval is shrunk past ``mid``.
        p_low, p_high: exclusive bounds of the p-value window (formerly fixed at
            1e-3 and 1e-2, which remain the defaults).

    Returns:
        The scale factor whose ratio is within tolerance, or the last midpoint
        tried once the interval is exhausted (the interval midpoint when the
        interval is empty from the start).

    Raises:
        ValueError: if no original p-value falls inside the window, so the
            ratio is undefined.
    """
    p_weighted = np.asarray(df_pred.P_weighted.values, dtype=float)
    p_orig = np.asarray(df_pred.P.values, dtype=float)

    # The reference count does not depend on the scale factor: compute it once.
    n_orig = np.count_nonzero((p_orig < p_high) & (p_orig > p_low))
    if n_orig == 0:
        raise ValueError('no P values inside the (p_low, p_high) window; ratio is undefined')

    # Defined up front so an empty search interval still returns a value.
    mid = (lower_bound + upper_bound) / 2
    while lower_bound <= upper_bound:
        mid = (lower_bound + upper_bound) / 2
        scaled = p_weighted * mid
        ratio = np.count_nonzero((scaled < p_high) & (scaled > p_low)) / n_orig
        if abs(ratio - 1) < tolerance:
            return mid
        if ratio > 1:
            lower_bound = mid + tolerance
        else:
            upper_bound = mid - tolerance
    return mid
def get_clumps_gold_label(data_path, gold_label_gwas, t_p = 5e-8, no_hla = False, column = 'P', snp2ld_snps = None):
    """Greedily group significant SNPs of a GWAS into LD clumps.

    SNPs with ``gold_label_gwas[column] < t_p`` are visited from most to least
    significant. Each SNP not already placed in a clump starts a new clump made
    of itself followed by its LD-tagged SNPs from ``snp2ld_snps`` (if any).

    Args:
        data_path: data root. Only used when ``snp2ld_snps`` is None, to load
            ``ld_score/ukb_white_ld_10MB_no_hla.pkl`` (``no_hla=True``) or
            ``ld_score/ukb_white_ld_10MB.pkl``.
        gold_label_gwas: DataFrame with an ``SNP`` column and the p-value column.
        t_p: significance threshold.
        no_hla: select the LD map without the HLA region when loading.
        column: name of the p-value column.
        snp2ld_snps: optional mapping SNP -> list of LD-tagged SNPs. Any mapping
            passed (including an empty one) is used as-is.

    Returns:
        List of clumps; each clump is a list whose first element is its lead SNP.
    """
    if snp2ld_snps is None:
        # Load only the one LD map that is needed (both used to be loaded).
        if no_hla:
            snp2ld_snps = load_dict(data_path + 'ld_score/ukb_white_ld_10MB_no_hla.pkl')
        else:
            snp2ld_snps = load_dict(data_path + 'ld_score/ukb_white_ld_10MB.pkl')

    clumps = []
    snps_in_clumps = set()  # O(1) membership instead of scanning a growing list
    snp_hits = gold_label_gwas[gold_label_gwas[column] < t_p].sort_values(column).SNP.values
    for snp in snp_hits:
        if snp in snps_in_clumps:
            # already tagged by a more significant lead SNP => no new clump
            continue
        tagged = list(snp2ld_snps[snp]) if snp in snp2ld_snps else []
        clump = [snp] + tagged
        clumps.append(clump)
        snps_in_clumps.update(clump)
    return clumps
def get_meta_clumps(clumps, data_path, max_cm_gap=0.1):
    """Merge LD clumps that are close in genetic distance into mega-clumps.

    Clumps are grouped per chromosome and visited in order of their lead SNP's
    cM position. A clump joins the current mega-clump when its smallest cM
    position is less than ``max_cm_gap`` past the largest cM position of the
    previously visited clump; otherwise it starts a new mega-clump.

    Args:
        clumps: list of clumps, each a list of rsids whose first element is the
            lead SNP.
        data_path: data root; ``misc_data/ukb_white_with_cm.bim`` below it is read
            for each rsid's chromosome (column 0) and cM position (column 2).
        max_cm_gap: merge threshold in cM (0.1, the former hard-coded value).

    Returns:
        (idx2mega_clump, idx2mega_clump_rsid, idx2mega_clump_chrom): dicts keyed by
        ``'Mega-Clump <i>'`` giving the member ``'Clump <j>'`` ids, the flattened
        rsids of those clumps, and the chromosome. Empty dicts for no clumps.
    """
    if len(clumps) == 0:
        # Previously crashed building the (column-less) clump table.
        return {}, {}, {}

    # Parse the .bim once (it used to be read twice).
    bim = pd.read_csv(data_path + 'misc_data/ukb_white_with_cm.bim', sep = '\t', header = None)
    snp2cm = dict(bim[[1, 2]].values)
    snp2chr = dict(bim[[1, 0]].values)

    clump_ids = ['Clump ' + str(idx) for idx in range(len(clumps))]
    idx2clump = dict(zip(clump_ids, clumps))
    idx2clump_chromosome = {cid: snp2chr[clump[0]] for cid, clump in zip(clump_ids, clumps)}

    df_clumps = pd.DataFrame({
        'Clump idx': clump_ids,
        'Chromosome': [idx2clump_chromosome[cid] for cid in clump_ids],
        'cM': [snp2cm[clump[0]] for clump in clumps],
        'cM_min': [min(snp2cm[x] for x in clump) for clump in clumps],
        'cM_max': [max(snp2cm[x] for x in clump) for clump in clumps],
    })

    all_mega_clump_across_chr = []
    for chrom in df_clumps['Chromosome'].unique():
        # Stable sort so clumps with tied lead cM keep a deterministic order.
        df_clump_chr = df_clumps[df_clumps['Chromosome'] == chrom].sort_values('cM', kind='mergesort')
        chrom_mega_clumps = []
        cur_mega_clump = []
        base_cM = 0
        for clump_id, cM_min, cM_max in zip(df_clump_chr['Clump idx'],
                                            df_clump_chr['cM_min'],
                                            df_clump_chr['cM_max']):
            if cM_min - base_cM < max_cm_gap:
                cur_mega_clump.append(clump_id)
            else:
                # this clump is >= max_cm_gap cM away from the previous clump
                chrom_mega_clumps.append(cur_mega_clump)
                cur_mega_clump = [clump_id]
            # NOTE(review): the reference is the cM_max of the last clump visited,
            # not the running maximum; kept as in the original algorithm.
            base_cM = cM_max
        chrom_mega_clumps.append(cur_mega_clump)
        # Only the first group can be empty (when the first clump opened a new group).
        all_mega_clump_across_chr += [group for group in chrom_mega_clumps if group]

    mega_ids = ['Mega-Clump ' + str(idx) for idx in range(len(all_mega_clump_across_chr))]
    idx2mega_clump = dict(zip(mega_ids, all_mega_clump_across_chr))
    idx2mega_clump_rsid = {
        mid: [snp for cid in group for snp in idx2clump[cid]]
        for mid, group in zip(mega_ids, all_mega_clump_across_chr)
    }
    idx2mega_clump_chrom = {
        mid: idx2clump_chromosome[group[0]]
        for mid, group in zip(mega_ids, all_mega_clump_across_chr)
    }
    return idx2mega_clump, idx2mega_clump_rsid, idx2mega_clump_chrom
def get_mega_clump_query(data_path, clumps, snp_hits, no_hla = False, snp2ld_snps = None):
    """Build LD clumps from ranked predicted SNPs and merge them into mega-clumps.

    ``snp_hits`` is walked in the given (ranked) order. Each SNP not already in a
    predicted clump starts a clump of itself plus its LD-tagged SNPs. At most
    ``K = max(3 * len(clumps), 100)`` clumps are built: deeper clumps are never
    prioritized in evaluation, so generating them is wasted work.

    Args:
        data_path: data root; used to load the LD map when ``snp2ld_snps`` is
            None, and passed on to get_meta_clumps.
        clumps: gold clumps; only their count is used, to size K.
        snp_hits: ranked iterable of predicted SNP ids.
        no_hla: select the LD map without the HLA region when loading.
        snp2ld_snps: optional mapping SNP -> list of LD-tagged SNPs, used as-is
            (including an empty mapping).

    Returns:
        The (idx2mega_clump, idx2mega_clump_rsid, idx2mega_clump_chrom) tuple
        from get_meta_clumps for the predicted clumps.
    """
    if snp2ld_snps is None:
        # Load only the one LD map that is needed (both used to be loaded).
        if no_hla:
            snp2ld_snps = load_dict(data_path + 'ld_score/ukb_white_ld_10MB_no_hla.pkl')
        else:
            snp2ld_snps = load_dict(data_path + 'ld_score/ukb_white_ld_10MB.pkl')

    clumps_pred = []
    snps_in_clumps_pred = set()  # O(1) membership instead of scanning a list
    K = max(len(clumps) * 3, 100)
    for snp in snp_hits:
        if len(clumps_pred) >= K:
            break
        if snp in snps_in_clumps_pred:
            # already part of a higher-ranked clump
            continue
        # SNPs without LD-tagged partners (at least in UKB) form singleton clumps.
        tagged = list(snp2ld_snps[snp]) if snp in snp2ld_snps else []
        clump = [snp] + tagged
        clumps_pred.append(clump)
        snps_in_clumps_pred.update(clump)
    return get_meta_clumps(clumps_pred, data_path)
def get_curve(mega_clump_pred, mega_clump_gold):
    """Compute recall@k and precision@k over a ranked list of predicted clumps.

    For each k (1-based), recall is the fraction of gold clumps sharing at least
    one SNP with any of the top-k predicted clumps. Precision is the fraction of
    the top-k predicted clumps that share a SNP with at least one gold clump.

    Args:
        mega_clump_pred: ranked sequence of predicted clumps (each a list of SNP ids).
        mega_clump_gold: sequence of gold clumps (each a list of SNP ids).

    Returns:
        (recall_k, precision_k, clump_idx_record): dicts keyed by k. The record
        holds, for each k, the gold clump indices found so far in discovery
        order. Recall is 0.0 when there are no gold clumps.
    """
    def _as_set(clump):
        # np.intersect1d flattened its inputs, so a scalar id is a 1-element clump.
        return set(np.ravel(clump).tolist())

    # Build each gold set once instead of re-sorting it for every predicted clump.
    gold_sets = [_as_set(clump) for clump in mega_clump_gold]
    n_gold = len(gold_sets)

    recall_k = {}
    precision_k = {}
    clump_idx_record = {}
    found_clump_idx = []  # discovery order, reported in clump_idx_record
    found = set()         # fast membership for found_clump_idx
    pred_clump_has_hit_count = 0
    for k, query_clump in enumerate(mega_clump_pred, start=1):
        query = _as_set(query_clump)
        overlaps_any_gold = False
        for clump_idx, gold in enumerate(gold_sets):
            if query.isdisjoint(gold):
                continue
            overlaps_any_gold = True
            if clump_idx not in found:
                found.add(clump_idx)
                found_clump_idx.append(clump_idx)
        clump_idx_record[k] = list(found_clump_idx)
        if overlaps_any_gold:
            pred_clump_has_hit_count += 1
        recall_k[k] = len(found_clump_idx) / n_gold if n_gold else 0.0
        precision_k[k] = pred_clump_has_hit_count / k
    return recall_k, precision_k, clump_idx_record
def get_prec_recall(pred_hits, gold_hits):
    """Return SNP-level precision and recall of ``pred_hits`` against ``gold_hits``.

    Overlap is counted over unique ids (np.intersect1d). Precision is 0 when
    there are no predictions; recall is 0.0 when there are no gold hits (it
    previously raised ZeroDivisionError).

    Returns:
        ``{'recall': float, 'precision': float}``
    """
    n_overlap = len(np.intersect1d(pred_hits, gold_hits))  # computed once
    recall = n_overlap / len(gold_hits) if len(gold_hits) != 0 else 0.0
    precision = n_overlap / len(pred_hits) if len(pred_hits) != 0 else 0
    return {'recall': recall,
            'precision': precision}
def find_nearest(array, value):
    """Return the element of ``array`` closest to ``value`` (the first one on ties)."""
    values = np.asarray(array)
    distances = np.abs(values - value)
    return values[np.argmin(distances)]
def get_cluster_from_gwas(df, cluster_distance_threshold = 500000,
                          threshold_extend = False, cluster_compare_threshold = None,
                          verbose = True):
    """Group GWAS hits into positional clusters, per chromosome.

    Hits on each chromosome (``#CHROM``) are sorted by ``POS``. A hit joins the
    current cluster when it lies less than ``cluster_distance_threshold`` from
    the cluster's anchor. The anchor is the cluster's first position, or, with
    ``threshold_extend``, the most recently added position (chaining).

    Args:
        df: DataFrame with ``#CHROM``, ``POS`` and ``ID`` columns.
        cluster_distance_threshold: maximal anchor distance within a cluster.
        threshold_extend: move the anchor to each newly added hit.
        cluster_compare_threshold: padding added to both ends of every cluster
            range. When None it defaults to ``int(cluster_distance_threshold / 2)``,
            the same default get_pr_curve uses; passing None used to raise
            TypeError.
        verbose: print the total number of clusters.

    Returns:
        (cluster_chr_pos, cluster_chr_rs, cluster_chr_pos_flatten,
         cluster_chr_cluster_idx_flatten, cluster_chr_cluster_pos2idx_flatten,
         cluster_chr_range), all dicts keyed by chromosome. In order, they hold:
        - the clusters as position lists, and the clusters as ID lists;
        - every position in one flat list, with its cluster index alongside;
        - a position -> cluster index mapping;
        - the padded ``(min - pad, max + pad)`` range of each cluster.
    """
    if cluster_compare_threshold is None:
        cluster_compare_threshold = int(cluster_distance_threshold / 2)

    cluster_chr_pos = {}
    cluster_chr_rs = {}
    for chr_num in df['#CHROM'].unique():
        df_hits_chr = df[df['#CHROM'] == chr_num].sort_values('POS')
        pos = df_hits_chr.POS.values
        rs = df_hits_chr.ID.values

        cluster_set = []
        cluster_set_rs = []
        anchor = pos[0]
        cur_set = [pos[0]]
        cur_set_rs = [rs[0]]
        for next_pos, next_rs in zip(pos[1:], rs[1:]):
            if next_pos - anchor < cluster_distance_threshold:
                cur_set.append(next_pos)
                cur_set_rs.append(next_rs)
                if threshold_extend:
                    anchor = next_pos
            else:
                cluster_set.append(cur_set)
                cluster_set_rs.append(cur_set_rs)
                anchor = next_pos
                cur_set = [next_pos]
                cur_set_rs = [next_rs]
        cluster_set.append(cur_set)
        cluster_set_rs.append(cur_set_rs)

        cluster_chr_pos[chr_num] = cluster_set
        cluster_chr_rs[chr_num] = cluster_set_rs

    cluster_chr_pos_flatten = {}
    cluster_chr_cluster_idx_flatten = {}
    cluster_chr_cluster_pos2idx_flatten = {}
    for chr_num, cluster_list in cluster_chr_pos.items():
        # Linear flattening (repeated list + list was quadratic).
        pos_flatten = [p for cluster in cluster_list for p in cluster]
        idx_flatten = [idx for idx, cluster in enumerate(cluster_list) for _ in cluster]
        cluster_chr_pos_flatten[chr_num] = pos_flatten
        cluster_chr_cluster_idx_flatten[chr_num] = idx_flatten
        cluster_chr_cluster_pos2idx_flatten[chr_num] = dict(zip(pos_flatten, idx_flatten))

    if verbose:
        print('Number of clusters: ' + str(sum([len(j) for j in cluster_chr_pos.values()])))

    cluster_chr_range = {
        chr_num: [(min(c) - cluster_compare_threshold, max(c) + cluster_compare_threshold) for c in clusters]
        for chr_num, clusters in cluster_chr_pos.items()
    }

    return cluster_chr_pos, cluster_chr_rs, cluster_chr_pos_flatten, \
            cluster_chr_cluster_idx_flatten, cluster_chr_cluster_pos2idx_flatten, cluster_chr_range
def get_cluster_hits_from_pred(pred_hits, threshold, lr_uni, cluster_chr_pos_flatten, cluster_chr_cluster_pos2idx_flatten):
    """Assign predicted hits to their nearest reference cluster.

    Each predicted SNP (row of ``lr_uni`` whose ``ID`` is in ``pred_hits``) is
    matched to the nearest clustered position on its chromosome. It counts as
    part of that cluster when the distance is below ``threshold``. Hits on a
    chromosome with no reference cluster are novel: they get a NaN nearest
    position, an infinite distance and a None cluster id (this used to raise
    KeyError).

    Args:
        pred_hits: iterable of predicted SNP ids.
        threshold: maximal distance to a clustered position.
        lr_uni: DataFrame with ``ID``, ``#CHROM`` and ``POS`` columns.
        cluster_chr_pos_flatten: chromosome -> list of clustered positions.
        cluster_chr_cluster_pos2idx_flatten: chromosome -> {position: cluster index}.

    Returns:
        (cluster2count, num_non_hits, df_hits, novel_rs_id). ``df_hits`` is a copy
        of the matched rows with added columns ``closest_cluster``,
        ``distance2cluster``, ``include_as_cluster`` and ``cluster_id``.
    """
    # Work on a copy: columns used to be assigned onto a filtered view of lr_uni.
    df_hits = lr_uni[lr_uni.ID.isin(pred_hits)].copy()

    closest, distance, cluster_ids = [], [], []
    for chrom, pos in zip(df_hits['#CHROM'], df_hits['POS']):
        ref_pos = cluster_chr_pos_flatten.get(chrom)
        if ref_pos is None or len(ref_pos) == 0:
            closest.append(np.nan)
            distance.append(np.inf)
            cluster_ids.append(None)
            continue
        ref_pos = np.asarray(ref_pos)
        nearest = ref_pos[np.abs(ref_pos - pos).argmin()]  # first one on ties
        closest.append(nearest)
        distance.append(abs(nearest - pos))
        cluster_ids.append(str(chrom) + '_' + str(cluster_chr_cluster_pos2idx_flatten[chrom][nearest]))

    df_hits['closest_cluster'] = closest
    df_hits['distance2cluster'] = distance
    df_hits['include_as_cluster'] = np.asarray(distance, dtype=float) < threshold
    df_hits['cluster_id'] = cluster_ids

    in_cluster = df_hits[df_hits.include_as_cluster]
    novel = df_hits[~df_hits.include_as_cluster]
    cluster2count = dict(in_cluster.cluster_id.value_counts())
    num_non_hits = len(novel)
    novel_rs_id = novel.ID.values
    print('Number of predicted hits: ' + str(len(pred_hits)))
    print('Number of predicted hits not in the existing clusters: ' + str(len(novel_rs_id)))
    print('Number of cluster hits: ' + str(len(cluster2count)))
    return cluster2count, num_non_hits, df_hits, novel_rs_id
def plot_cluster_range(chr_num, gnn_cluster_chr_range, cluster_chr_range, \
                       gold_cluster_chr_range, findor_cluster_chr_range, x_start = None, x_end = None, \
                       base_gwas_name = 'FastGWA', gold_ref_name = 'GWAS Catalog'):
    """Plot the cluster ranges of several methods on one chromosome.

    Each ``*_cluster_chr_range`` maps chromosome -> list of (start, end) ranges.
    Every range is drawn as a segment on a row labelled with the method name.
    Rows are drawn in this order: FINDOR, GNN, ``base_gwas_name``, ``gold_ref_name``.
    Chromosomes missing from a mapping are plotted as empty. The caller's dicts
    are no longer modified (empty entries used to be inserted).

    Args:
        chr_num: chromosome to plot.
        x_start, x_end: optional x-axis limits (applied when ``x_start`` is given).
        base_gwas_name, gold_ref_name: row labels of the baseline and reference.
    """
    # ``plt`` is not imported at module level in this file; without this the
    # function raised NameError.
    import matplotlib.pyplot as plt

    fig = plt.figure(figsize=(14, 3))  # Set the figure size
    ax = fig.add_subplot(111)

    tracks = [
        (findor_cluster_chr_range, 'FINDOR', '*-'),
        (gnn_cluster_chr_range, 'GNN', 's-'),
        (cluster_chr_range, base_gwas_name, '^-'),
        (gold_cluster_chr_range, gold_ref_name, 'o-'),
    ]
    for chr2range, label, style in tracks:
        for start_end in chr2range.get(chr_num, []):
            ax.plot(start_end, [label, label], style)

    ax.set_xlabel('Position Index at Chromosome ' + str(chr_num))
    if x_start is not None:
        ax.set_xlim([x_start, x_end])
    plt.show()
def get_pr_curve(cluster_distance_threshold, gold_label_gwas_hits, method_hit_gwas, low_data_gwas_hits,
                 cluster_compare_threshold = None, method_name = 'gnn', threshold_extend = False):
    """Compute cluster-level precision and recall of a method and a baseline GWAS against a gold reference.

    All three hit tables are clustered with get_cluster_from_gwas using the same
    distance threshold, range padding and ``threshold_extend``. A query cluster
    is a hit when its padded range overlaps a gold cluster's padded range
    (find_overlap_clusters).

    Args:
        cluster_distance_threshold: positional distance used to form clusters.
        gold_label_gwas_hits: reference hits (``#CHROM``, ``POS``, ``ID`` columns).
        method_hit_gwas: hits of the evaluated method, reported under ``method_name``.
        low_data_gwas_hits: baseline GWAS hits, reported under the ``plink_*`` keys.
        cluster_compare_threshold: padding of each cluster range; defaults to
            half of ``cluster_distance_threshold``.
        method_name: key prefix for the method's metrics.
        threshold_extend: forwarded to get_cluster_from_gwas. It used to be read
            from an undefined global (NameError); it now defaults to False,
            get_cluster_from_gwas's own default.

    Returns:
        Dict of precision, recall and the underlying counts. Precision is -1 when
        the corresponding hit set yields no clusters. Recall is -1 when the
        reference has no clusters (this used to raise ZeroDivisionError).
    """
    if cluster_compare_threshold is None:
        cluster_compare_threshold = int(cluster_distance_threshold/2)

    def _cluster_ranges(hits):
        # Only the padded per-chromosome ranges (last return value) are needed.
        return get_cluster_from_gwas(hits, cluster_distance_threshold,
                                     threshold_extend = threshold_extend,
                                     cluster_compare_threshold = cluster_compare_threshold,
                                     verbose = False)[-1]

    gold_cluster_chr_range = _cluster_ranges(gold_label_gwas_hits)
    cluster_chr_range = _cluster_ranges(low_data_gwas_hits)
    gnn_cluster_chr_range = _cluster_ranges(method_hit_gwas)

    def _overlap_counts(query_range):
        # Return (#distinct gold clusters hit, #distinct query clusters that hit).
        # Ranges are compared as tuples in a set. The former np.unique over
        # Python sets ordered them by the subset relation, so equal gold ranges
        # were not adjacent and were counted more than once.
        n_ref = n_query = 0
        for pairs in find_overlap_clusters(query_range, gold_cluster_chr_range).values():
            n_ref += len({tuple(gold) for _, gold in pairs})
            n_query += len({tuple(query) for query, _ in pairs})
        return n_ref, n_query

    def _ratio(num, den):
        # -1 flags an undefined ratio (empty denominator).
        return num / den if den else -1

    total = sum(len(r) for r in gold_cluster_chr_range.values())
    plink_set_total = sum(len(r) for r in cluster_chr_range.values())
    gnn_set_total = sum(len(r) for r in gnn_cluster_chr_range.values())
    plink_set_overlap_ref, plink_set_overlap_query = _overlap_counts(cluster_chr_range)
    gnn_set_overlap_ref, gnn_set_overlap_query = _overlap_counts(gnn_cluster_chr_range)

    return {'plink_precision': _ratio(plink_set_overlap_query, plink_set_total),
            'plink_recall': _ratio(plink_set_overlap_ref, total),
            method_name + '_precision': _ratio(gnn_set_overlap_query, gnn_set_total),
            method_name + '_recall': _ratio(gnn_set_overlap_ref, total),
            'plink_set_overlap_ref': plink_set_overlap_ref,
            'plink_set_overlap_query': plink_set_overlap_query,
            'plink_set_total': plink_set_total,
            method_name + '_set_overlap_ref': gnn_set_overlap_ref,
            method_name + '_set_overlap_query': gnn_set_overlap_query,
            method_name + '_set_total': gnn_set_total,
            'total_set': total
           }
from tqdm import tqdm
394
def find_overlap_clusters(query_cluster2range, gold_cluster2range):
    """Pair each query cluster range with the first overlapping gold range.

    Both arguments map chromosome -> list of (start, end) ranges. Two ranges
    overlap when they share at least one point (closed intervals). Chromosomes
    absent from ``gold_cluster2range`` are left out of the result. Query ranges
    without an overlapping gold range are omitted from their chromosome's list.

    Returns:
        dict chromosome -> list of (query_range, first_overlapping_gold_range).
    """
    overlaps = {}
    for chr_num, query_ranges in query_cluster2range.items():
        if chr_num not in gold_cluster2range:
            continue
        gold_ranges = gold_cluster2range[chr_num]
        matches = []
        for q in query_ranges:
            for g in gold_ranges:
                if q[0] <= g[1] and g[0] <= q[1]:
                    matches.append((q, g))
                    break  # only the first overlapping gold range is recorded
        overlaps[chr_num] = matches
    return overlaps
def find_non_overlap_clusters(query_cluster2range, gold_cluster2range):
    """Collect, per chromosome, the query cluster ranges that overlap no gold range.

    Both arguments map chromosome -> list of (start, end) ranges; overlap uses
    closed intervals, as in find_overlap_clusters. A chromosome with no gold
    ranges now yields all of its query ranges (this used to raise KeyError).

    Returns:
        dict chromosome -> list of non-overlapping query ranges, one entry per
        query chromosome.
    """
    set_not_found_cluster_all = {}
    for chr_num, eval_cluster in query_cluster2range.items():
        gold_cluster = gold_cluster2range.get(chr_num, [])
        set_not_found_cluster_all[chr_num] = [
            a for a in eval_cluster
            if not any(a[0] <= b[1] and b[0] <= a[1] for b in gold_cluster)
        ]
    return set_not_found_cluster_all
### eval support functions
432
433
def quantileNormalize(df_input):
    """Return a quantile-normalized copy of ``df_input`` (columns share one distribution).

    The reference distribution is the row-wise mean of the column-wise sorted
    values. Each value is replaced by the reference value at its rank within
    its column; tied values all receive the rank of their first occurrence.
    The input frame is not modified.
    """
    df = df_input.copy()
    reference = pd.DataFrame({col: sorted(df[col]) for col in df}).mean(axis = 1).tolist()
    for col in df:
        ranks = np.searchsorted(np.sort(df[col]), df[col])
        df[col] = [reference[r] for r in ranks]
    return df
def get_cluster_count(method_hit_gwas, cluster_distance_threshold, cluster_compare_threshold, threshold_extend, gold_cluster_chr_range):
    """Cluster a method's hits and count their overlaps with gold cluster ranges.

    Args:
        method_hit_gwas: hits with ``#CHROM``, ``POS`` and ``ID`` columns.
        cluster_distance_threshold, cluster_compare_threshold, threshold_extend:
            forwarded to get_cluster_from_gwas.
        gold_cluster_chr_range: chromosome -> list of (start, end) gold ranges.

    Returns:
        ``{'set_overlap_ref': #distinct gold clusters overlapped,
           'set_overlap_query': #distinct method clusters overlapping gold,
           'set_total': #method clusters, 'total_set': #gold clusters}``
    """
    # Only the padded per-chromosome ranges (last return value) are needed.
    gnn_cluster_chr_range = get_cluster_from_gwas(method_hit_gwas,
                                                  cluster_distance_threshold,
                                                  threshold_extend = threshold_extend,
                                                  cluster_compare_threshold = cluster_compare_threshold,
                                                  verbose = False)[-1]

    total = sum(len(r) for r in gold_cluster_chr_range.values())
    gnn_set_total = sum(len(r) for r in gnn_cluster_chr_range.values())

    gnn_set_overlap_ref = 0
    gnn_set_overlap_query = 0
    for pairs in find_overlap_clusters(gnn_cluster_chr_range, gold_cluster_chr_range).values():
        # Count distinct ranges via hashable tuples. The former np.unique over
        # Python sets sorted by the subset relation and could count one gold
        # cluster several times.
        gnn_set_overlap_ref += len({tuple(gold) for _, gold in pairs})
        gnn_set_overlap_query += len({tuple(query) for query, _ in pairs})

    return {'set_overlap_ref': gnn_set_overlap_ref,
            'set_overlap_query': gnn_set_overlap_query,
            'set_total': gnn_set_total,
            'total_set': total
           }
def get_top_k_clusters(query_rank, top_hits_k_range, cluster_distance_threshold, cluster_compare_threshold, threshold_extend, gold_cluster_chr_range):
    """For each target cluster count k, find the ranked-SNP prefix giving about k clusters.

    The prefix length is searched coarse to fine:
    1. Grow it in steps of 100 until the prefix yields at least k clusters.
    2. Step back 100 and grow in steps of 10.
    3. Scan the preceding 10 lengths one by one and keep the length whose
       cluster count is closest to k (later lengths win ties).

    The search position carries over between successive k, so
    ``top_hits_k_range`` is presumably meant to be increasing (TODO confirm).
    The prefix length never goes below 0: negative lengths used to slice from
    the end of ``query_rank``. The search also stops once the ranking is
    exhausted: it used to loop forever when k clusters were unreachable.

    Args:
        query_rank: ranked hits (DataFrame accepted by get_cluster_from_gwas).
        top_hits_k_range: iterable of target cluster counts.
        cluster_distance_threshold, cluster_compare_threshold, threshold_extend,
        gold_cluster_chr_range: forwarded to get_cluster_count.

    Returns:
        (k_to_cluster, k_to_closest_x): per k, the get_cluster_count result at
        the chosen prefix and the chosen prefix length.
    """
    n_snps = len(query_rank)

    def _count(n):
        return get_cluster_count(query_rank[:n], cluster_distance_threshold,
                                 cluster_compare_threshold, threshold_extend, gold_cluster_chr_range)

    snp_k = 0
    k_to_cluster = {}
    k_to_closest_x = {}
    for k in top_hits_k_range:
        # Coarse search, in steps of 100.
        while snp_k < n_snps and _count(snp_k)['set_total'] < k:
            snp_k += 100
        snp_k = max(snp_k - 100, 0)

        # Medium search, in steps of 10.
        out = _count(snp_k)
        while out['set_total'] < k and snp_k < n_snps:
            snp_k += 10
            out = _count(snp_k)

        # Fine search over the preceding 10 prefix lengths.
        closest_x = snp_k
        closest_distance = abs(out['set_total'] - k)
        for x in range(max(snp_k - 10, 0), snp_k):
            distance = abs(_count(x)['set_total'] - k)
            if distance <= closest_distance:
                closest_x = x
                closest_distance = distance

        k_to_cluster[k] = _count(closest_x)
        k_to_closest_x[k] = closest_x

    return k_to_cluster, k_to_closest_x
def storey_pi_estimator(gwas_data, bin_index):
    """
    Estimate pi0 (the null proportion) with the Storey and Tibshirani (PNAS 2003) estimator.

    Args
    ====
    gwas_data: DataFrame with a ``P`` column.
    bin_index: row selector for a particular bin (e.g. a boolean mask), passed to ``.loc``.

    Returns
    =======
    The cubic-spline estimate of pi0 at the largest lambda of the grid
    0.05, 0.10, ..., 0.90, capped at 1. Returns 1 when fewer than two grid
    points give a finite estimate (e.g. an empty bin). Previously that case
    crashed because a Python list was boolean-indexed.
    """
    pvalue = gwas_data.loc[bin_index, 'P']  # p-values of the selected bin
    total_tests = float(len(pvalue))
    lam = np.arange(0.05, 0.95, 0.05)
    counts = np.array([(pvalue > l).sum() for l in lam])
    with np.errstate(divide = 'ignore', invalid = 'ignore'):
        # An empty bin gives 0/0 here; those points are filtered out below.
        pi0 = counts / (total_tests * (1 - lam))

    if not np.all(np.isfinite(pi0)):
        print("Not all pi0 is finite!!! filtering to finite indices...")
        finite_indices = np.isfinite(pi0)
        lam = lam[finite_indices]
        pi0 = pi0[finite_indices]

    if len(lam) < 2:
        # A cubic spline needs at least two points; fall back to the
        # conservative "all null" estimate.
        return 1

    # fit cubic spline
    cubic_spline = interpolate.CubicSpline(lam, pi0)
    pi0_est = cubic_spline(lam[-1])
    if pi0_est > 1:  # take care of out of bounds estimate
        pi0_est = 1
    return pi0_est
def storey_ribshirani_integrate(gwas_data, column = 'pred', num_bins = 100):
    """Reweight GWAS p-values using bin-wise Storey pi0 estimates.

    SNPs are binned by quantiles of ``gwas_data[column]`` and pi0 is estimated
    per bin with storey_pi_estimator, clamped to ``[1e-5, 1 - 1e-5]``. The weights
    ``(1 - pi0) / pi0`` are normalized to mean 1. Then
    ``P_weighted = P / weight``, except that P is kept where the ratio exceeds 1
    and 1 is used where the ratio is undefined.

    Side effects: ``gwas_data`` is modified in place. It gains or overwrites the
    ``bin_number``, ``pi0``, ``weights`` and ``P_weighted`` columns, and ``P`` is
    clipped into [0, 1] when it falls outside that range.

    Args:
        gwas_data: DataFrame with a ``P`` column and the binning ``column``.
        column: column whose quantiles define the bins.
        num_bins: requested number of quantile bins. Bins whose quantile edges
            coincide are merged.

    Returns:
        numpy array of the ``P_weighted`` column (float dtype).
    """
    num_bins = int(num_bins)
    # linspace ends exactly at 1.0. The former arange-based grid could end at
    # 0.999... for some bin counts, which broke the label-based edge lookup.
    quantiles = np.linspace(0, 1, num_bins + 1)
    edges = gwas_data[column].quantile(quantiles)
    # widen the outermost edges so that the min and max values fall inside a bin
    edges.iloc[0] = edges.iloc[0] - 1
    edges.iloc[-1] = edges.iloc[-1] + 1
    edges = edges.drop_duplicates()
    num_bins = len(edges) - 1
    gwas_data['bin_number'] = pd.cut(gwas_data[column], edges, labels = np.arange(num_bins))

    # NaN (not None) keeps the column numeric; rows outside every bin stay NaN.
    gwas_data['pi0'] = np.nan

    if (gwas_data['P'].min() < 0) or (gwas_data['P'].max() > 1):
        print("detected p-values < 0 or > 1, please double check. we clipped it to 0-1 for now...")
        gwas_data['P'] = gwas_data['P'].clip(lower=0, upper=1)

    for i in range(num_bins):
        bin_index = gwas_data['bin_number'] == i  # SNPs in bin number i
        if bin_index.any():
            pi0 = float(storey_pi_estimator(gwas_data, bin_index))
            # Clamp away from 0 and 1 to prevent exploding weights. After this
            # clamp no bin can hold pi0 == 0 or pi0 == 1, so the former fix-ups
            # for those values were unreachable and have been removed.
            pi0 = min(max(pi0, 1e-5), 1 - 1e-5)
            gwas_data.loc[bin_index, 'pi0'] = pi0

    weights = (1 - gwas_data['pi0']) / gwas_data['pi0']
    weights = weights / weights.mean()  # normalize weights to have mean 1

    gwas_data['weights'] = weights
    gwas_data['P_weighted'] = gwas_data['P'] / weights  # reweight SNPs

    # keep the original p-value where reweighting pushes it above 1
    index = gwas_data['P_weighted'] > 1
    gwas_data.loc[index, 'P_weighted'] = gwas_data['P'][index]
    gwas_data.loc[gwas_data['P_weighted'].isnull(), 'P_weighted'] = 1
    return gwas_data['P_weighted'].values