kgwas/eval_utils.py
import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt  # needed by plot_cluster_range below
from copy import copy
from scipy import interpolate
from tqdm import tqdm

from .utils import load_dict
def find_closest_x(df_pred, lower_bound=0, upper_bound=200, tolerance=0.01):
    # Binary search for a scaling factor `mid` such that the number of weighted
    # p-values (P_weighted * mid) falling in the (1e-3, 1e-2) band matches the
    # number of raw p-values in that band.
    upper = 1e-2
    lower = 1e-3

    while lower_bound <= upper_bound:
        mid = (lower_bound + upper_bound) / 2
        # result = len(np.where(df_pred.P_weighted.values * mid < 2e-4)[0]) / len(np.where(df_pred.P.values < 2e-4)[0])
        res1 = len(np.where((df_pred.P_weighted.values * mid < upper) & (df_pred.P_weighted.values * mid > lower))[0])
        res2 = len(np.where((df_pred.P.values < upper) & (df_pred.P.values > lower))[0])
        result = res1 / res2
        if abs(result - 1) < tolerance:
            return mid
        elif result > 1:
            lower_bound = mid + tolerance
        else:
            upper_bound = mid - tolerance

    return mid
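
# Illustrative usage sketch for find_closest_x (hypothetical synthetic data,
# kept commented out so importing this module stays side-effect free). The
# returned factor rescales the weighted p-values so their distribution matches
# the raw one in the (1e-3, 1e-2) band:
#
#   rng = np.random.default_rng(0)
#   df_pred = pd.DataFrame({'P': rng.uniform(0, 1, 100_000)})
#   df_pred['P_weighted'] = df_pred['P'] / 2      # toy reweighting
#   scale = find_closest_x(df_pred)               # expected to land near 2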
|
|
|
def get_clumps_gold_label(data_path, gold_label_gwas, t_p=5e-8, no_hla=False, column='P', snp2ld_snps=None):
    # Build LD clumps from significant hits: iterate hits by ascending p-value
    # and start a new clump from each hit not already absorbed by an earlier one.
    if snp2ld_snps is None:
        # load the UKB LD map lazily, and only the variant that is needed
        if no_hla:
            snp2ld_snps = load_dict(data_path + 'ld_score/ukb_white_ld_10MB_no_hla.pkl')
        else:
            snp2ld_snps = load_dict(data_path + 'ld_score/ukb_white_ld_10MB.pkl')

    clumps = []
    snps_in_clumps = []
    snp_hits = gold_label_gwas[gold_label_gwas[column] < t_p].sort_values(column).SNP.values
    for snp in snp_hits:
        if snp in snps_in_clumps:
            ## already in an existing clump => do not create a new clump
            pass
        else:
            if snp in snp2ld_snps:
                # LD block
                clumps.append([snp] + snp2ld_snps[snp])
                snps_in_clumps += snp2ld_snps[snp]
                snps_in_clumps += [snp]
            else:
                # no other SNPs tagged
                clumps.append([snp])
                snps_in_clumps += [snp]
    return clumps
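
# Illustrative sketch for get_clumps_gold_label with a pre-built LD map
# (hypothetical rsIDs; commented out so import has no side effects). Passing
# snp2ld_snps directly skips the UKB LD dictionary load:
#
#   gwas = pd.DataFrame({'SNP': ['rs1', 'rs2', 'rs3'],
#                        'P':   [1e-10, 1e-9, 1e-4]})
#   ld_map = {'rs1': ['rs2']}   # rs2 is LD-tagged by rs1
#   get_clumps_gold_label('', gwas, snp2ld_snps=ld_map)
#   # -> [['rs1', 'rs2']] ; rs2 is absorbed into rs1's clump, rs3 is not significant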
|
|
|
def get_meta_clumps(clumps, data_path):
    # Merge clumps that lie within 0.1 cM of each other (per chromosome) into "mega-clumps".
    bim = pd.read_csv(data_path + 'misc_data/ukb_white_with_cm.bim', sep='\t', header=None)
    snp2cm = dict(bim[[1, 2]].values)
    snp2chr = dict(bim[[1, 0]].values)

    idx2clump = {'Clump ' + str(idx): i for idx, i in enumerate(clumps)}
    idx2clump_chromosome = {'Clump ' + str(idx): snp2chr[i[0]] for idx, i in enumerate(clumps)}
    idx2clump_cm = {'Clump ' + str(idx): snp2cm[i[0]] for idx, i in enumerate(clumps)}

    idx2clump_cm_min = {'Clump ' + str(idx): min([snp2cm[x] for x in i]) for idx, i in enumerate(clumps)}
    idx2clump_cm_max = {'Clump ' + str(idx): max([snp2cm[x] for x in i]) for idx, i in enumerate(clumps)}

    df_clumps = pd.DataFrame([idx2clump_chromosome, idx2clump_cm, idx2clump, idx2clump_cm_min, idx2clump_cm_max]).T.reset_index().rename(
        columns={'index': 'Clump idx', 0: 'Chromosome', 1: 'cM', 2: 'Clump rsids', 3: 'cM_min', 4: 'cM_max'})

    all_mega_clump_across_chr = []
    for chrom in df_clumps.Chromosome.unique():
        df_clump_chr = df_clumps[df_clumps.Chromosome == chrom]
        all_mega_clump = []
        cur_mega_clump = []
        base_cM = 0
        for i, cM_hit, cM_min, cM_max in df_clump_chr.sort_values('cM')[['Clump idx', 'cM', 'cM_min', 'cM_max']].values:
            if (cM_min - base_cM) < 0.1:
                cur_mega_clump.append(i)
                base_cM = cM_max
            else:
                ### this clump is >0.1 cM away from the previous clump => close the current mega-clump
                all_mega_clump.append(cur_mega_clump)
                base_cM = cM_max
                cur_mega_clump = [i]
        all_mega_clump.append(cur_mega_clump)
        if len(all_mega_clump[0]) == 0:
            # the first chunk is empty when the first clump immediately started a new mega-clump
            all_mega_clump_across_chr += all_mega_clump[1:]
        else:
            all_mega_clump_across_chr += all_mega_clump
    idx2mega_clump = {'Mega-Clump ' + str(idx): i for idx, i in enumerate(all_mega_clump_across_chr)}

    def flatten(l):
        return [item for sublist in l for item in sublist]

    idx2mega_clump_rsid = {'Mega-Clump ' + str(idx): flatten([idx2clump[j] for j in i]) for idx, i in enumerate(all_mega_clump_across_chr)}
    idx2mega_clump_chrom = {'Mega-Clump ' + str(idx): idx2clump_chromosome[i[0]] for idx, i in enumerate(all_mega_clump_across_chr)}

    return idx2mega_clump, idx2mega_clump_rsid, idx2mega_clump_chrom
|
|
|
def get_mega_clump_query(data_path, clumps, snp_hits, no_hla=False, snp2ld_snps=None):
    if snp2ld_snps is None:
        # load the UKB LD map lazily, and only the variant that is needed
        if no_hla:
            snp2ld_snps = load_dict(data_path + 'ld_score/ukb_white_ld_10MB_no_hla.pkl')
        else:
            snp2ld_snps = load_dict(data_path + 'ld_score/ukb_white_ld_10MB.pkl')

    clumps_pred = []
    snps_in_clumps_pred = []
    K = max(len(clumps) * 3, 100)
    for snp in snp_hits:
        ## top-ranked snps
        if len(clumps_pred) >= K:
            ## only generate the top K clumps, with K set generously large: clumps
            ## far beyond K are never prioritized or evaluated, so we skip them.
            break
        else:
            if snp in snps_in_clumps_pred:
                ## already in a previously found clump, move on
                pass
            else:
                if snp in snp2ld_snps:
                    # this snp has LD-tagged snps
                    clumps_pred.append([snp] + snp2ld_snps[snp])
                    snps_in_clumps_pred += snp2ld_snps[snp]
                    snps_in_clumps_pred += [snp]
                else:
                    # this snp has no LD-tagged snps, at least in UKB
                    clumps_pred.append([snp])
                    snps_in_clumps_pred += [snp]
    idx2mega_clump_pred, idx2mega_clump_rsid_pred, idx2mega_clump_chrom_pred = get_meta_clumps(clumps_pred, data_path)
    return idx2mega_clump_pred, idx2mega_clump_rsid_pred, idx2mega_clump_chrom_pred
|
|
|
def get_curve(mega_clump_pred, mega_clump_gold):
    recall_k = {}
    precision_k = {}
    found_clump_idx = []
    clump_idx_record = {}
    pred_clump_has_hit_count = 0
    for k, query_clump in enumerate(mega_clump_pred):
        ## go through the predicted top-ranked clumps one by one
        k += 1
        does_this_clump_overlap_with_any_true_clumps = False
        ## used for precision: does this clump overlap with any of the gold clumps?
        for clump_idx, clump in enumerate(mega_clump_gold):
            ## overlaps with this gold clump
            if len(np.intersect1d(query_clump, clump)) > 0:
                if clump_idx not in found_clump_idx:
                    ## if the gold clump was never found before, flag it
                    found_clump_idx.append(clump_idx)
                does_this_clump_overlap_with_any_true_clumps = True
        clump_idx_record[k] = copy(found_clump_idx)
        if does_this_clump_overlap_with_any_true_clumps:
            pred_clump_has_hit_count += 1

        recall_k[k] = len(found_clump_idx) / len(mega_clump_gold)
        precision_k[k] = pred_clump_has_hit_count / k

    # sns.scatterplot([recall_k[k+1] for k in range(len(mega_clump_pred))], [precision_k[k+1] for k in range(len(mega_clump_pred))], s=1)
    return recall_k, precision_k, clump_idx_record
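
# Illustrative sketch for get_curve on toy mega-clumps (commented out;
# hypothetical rsIDs). Clumps are lists of rsIDs; recall/precision are
# computed at each rank k of the predicted clump list:
#
#   pred = [['rs1', 'rs2'], ['rs9'], ['rs5']]
#   gold = [['rs2'], ['rs5', 'rs6']]
#   recall_k, precision_k, _ = get_curve(pred, gold)
#   # recall_k    -> {1: 0.5, 2: 0.5, 3: 1.0}
#   # precision_k -> {1: 1.0, 2: 0.5, 3: 2/3}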
|
|
|
def get_prec_recall(pred_hits, gold_hits):
    recall = len(np.intersect1d(pred_hits, gold_hits)) / len(gold_hits)
    if len(pred_hits) != 0:
        precision = len(np.intersect1d(pred_hits, gold_hits)) / len(pred_hits)
    else:
        precision = 0
    return {'recall': recall,
            'precision': precision}
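
# Quick sanity check for get_prec_recall (commented out; toy rsIDs):
#
#   get_prec_recall(['rs1', 'rs2'], ['rs2', 'rs3', 'rs4'])
#   # -> {'recall': 1/3, 'precision': 0.5}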
|
|
|
def find_nearest(array, value):
    array = np.asarray(array)
    idx = (np.abs(array - value)).argmin()
    return array[idx]
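
# find_nearest returns the element of `array` closest to `value`, e.g.
# (commented out):
#
#   find_nearest([100, 500, 900], 620)   # -> 500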
|
|
|
def get_cluster_from_gwas(df, cluster_distance_threshold=500000,
                          threshold_extend=False, cluster_compare_threshold=None,
                          verbose=True):

    if cluster_compare_threshold is None:
        # mirror the default used by get_pr_curve; the original raised a
        # TypeError below when this was left as None
        cluster_compare_threshold = int(cluster_distance_threshold / 2)

    cluster_chr_pos = {}
    cluster_chr_rs = {}

    for chr_num in df['#CHROM'].unique():
        df_hits_chr = df[df['#CHROM'] == chr_num]
        df_hits_chr = df_hits_chr.sort_values('POS')
        pos = df_hits_chr.POS.values
        rs = df_hits_chr.ID.values

        cluster_set = []
        cluster_set_rs = []

        cur_pos = pos[0]
        cur_set = [cur_pos]
        cur_set_rs = [rs[0]]

        for idx, next_pos in enumerate(pos[1:]):
            if next_pos - cur_pos < cluster_distance_threshold:
                cur_set.append(next_pos)
                cur_set_rs.append(rs[idx + 1])
                if threshold_extend:
                    cur_pos = next_pos
            else:
                cluster_set.append(cur_set)
                cluster_set_rs.append(cur_set_rs)
                cur_pos = next_pos
                cur_set = [cur_pos]
                cur_set_rs = [rs[idx + 1]]

        cluster_set.append(cur_set)
        cluster_set_rs.append(cur_set_rs)

        cluster_chr_pos[chr_num] = cluster_set
        cluster_chr_rs[chr_num] = cluster_set_rs

    cluster_chr_pos_flatten = {}
    cluster_chr_cluster_idx_flatten = {}
    cluster_chr_cluster_pos2idx_flatten = {}

    for chr_num, cluster_list in cluster_chr_pos.items():
        pos_flatten = []
        idx_flatten = []
        for idx, cluster in enumerate(cluster_list):
            pos_flatten = pos_flatten + cluster
            idx_flatten = idx_flatten + [idx] * len(cluster)
        cluster_chr_pos_flatten[chr_num] = pos_flatten
        cluster_chr_cluster_idx_flatten[chr_num] = idx_flatten
        cluster_chr_cluster_pos2idx_flatten[chr_num] = dict(zip(pos_flatten, idx_flatten))

    if verbose:
        print('Number of clusters: ' + str(sum([len(j) for j in cluster_chr_pos.values()])))

    cluster_chr_range = {}
    for i, j in cluster_chr_pos.items():
        cluster_chr_range[i] = [(min(x) - cluster_compare_threshold, max(x) + cluster_compare_threshold) for x in j]

    return cluster_chr_pos, cluster_chr_rs, cluster_chr_pos_flatten, \
        cluster_chr_cluster_idx_flatten, cluster_chr_cluster_pos2idx_flatten, cluster_chr_range
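
# Illustrative sketch for get_cluster_from_gwas on a toy PLINK-style hits table
# (commented out; hypothetical positions). Hits within cluster_distance_threshold
# of the current cluster seed are grouped; returned ranges are padded by
# cluster_compare_threshold on each side:
#
#   hits = pd.DataFrame({'#CHROM': [1, 1, 1],
#                        'POS':    [1_000, 2_000, 900_000],
#                        'ID':     ['rs1', 'rs2', 'rs3']})
#   out = get_cluster_from_gwas(hits, cluster_distance_threshold=500_000,
#                               cluster_compare_threshold=250_000, verbose=False)
#   # out[0] -> {1: [[1000, 2000], [900000]]}   (two clusters on chromosome 1)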
|
|
|
def get_cluster_hits_from_pred(pred_hits, threshold, lr_uni, cluster_chr_pos_flatten, cluster_chr_cluster_pos2idx_flatten):
    df_hits = lr_uni[lr_uni.ID.isin(pred_hits)].copy()  # copy to avoid SettingWithCopyWarning
    df_hits['closest_cluster'] = df_hits.apply(lambda x: find_nearest(cluster_chr_pos_flatten[x['#CHROM']], x.POS), axis=1)
    df_hits['distance2cluster'] = df_hits.apply(lambda x: abs(x.closest_cluster - x.POS), axis=1)
    df_hits['include_as_cluster'] = df_hits.apply(lambda x: x.distance2cluster < threshold, axis=1)
    df_hits['cluster_id'] = df_hits.apply(lambda x: str(x['#CHROM']) + '_' + str(cluster_chr_cluster_pos2idx_flatten[x['#CHROM']][x['closest_cluster']]), axis=1)
    cluster2count = dict(df_hits[df_hits.include_as_cluster].cluster_id.value_counts())
    num_non_hits = len(df_hits[~df_hits.include_as_cluster])
    novel_rs_id = df_hits[~df_hits.include_as_cluster].ID.values
    print('Number of predicted hits: ' + str(len(pred_hits)))
    print('Number of predicted hits not in the existing clusters: ' + str(len(novel_rs_id)))
    print('Number of cluster hits: ' + str(len(cluster2count)))
    return cluster2count, num_non_hits, df_hits, novel_rs_id
|
|
|
def plot_cluster_range(chr_num, gnn_cluster_chr_range, cluster_chr_range,
                       gold_cluster_chr_range, findor_cluster_chr_range, x_start=None, x_end=None,
                       base_gwas_name='FastGWA', gold_ref_name='GWAS Catalog'):

    fig = plt.figure(figsize=(14, 3))  # set the figure size
    ax = fig.add_subplot(111)

    # default missing chromosomes to empty range lists
    if chr_num not in cluster_chr_range:
        cluster_chr_range[chr_num] = []
    if chr_num not in gnn_cluster_chr_range:
        gnn_cluster_chr_range[chr_num] = []
    if chr_num not in gold_cluster_chr_range:
        gold_cluster_chr_range[chr_num] = []
    if chr_num not in findor_cluster_chr_range:
        findor_cluster_chr_range[chr_num] = []

    for i in findor_cluster_chr_range[chr_num]:
        plt.plot(i, ['FINDOR', 'FINDOR'], '*-')

    for i in gnn_cluster_chr_range[chr_num]:
        plt.plot(i, ['GNN', 'GNN'], 's-')

    for i in cluster_chr_range[chr_num]:
        plt.plot(i, [base_gwas_name, base_gwas_name], '^-')

    for i in gold_cluster_chr_range[chr_num]:
        plt.plot(i, [gold_ref_name, gold_ref_name], 'o-')

    plt.xlabel('Position Index at Chromosome ' + str(chr_num))

    if x_start is not None:
        ax.set_xlim([x_start, x_end])
    plt.show()
|
|
|
def get_pr_curve(cluster_distance_threshold, gold_label_gwas_hits, method_hit_gwas, low_data_gwas_hits,
                 cluster_compare_threshold=None, method_name='gnn', threshold_extend=False):
    if cluster_compare_threshold is None:
        cluster_compare_threshold = int(cluster_distance_threshold / 2)

    gold_cluster_chr_pos, gold_cluster_chr_rs, \
        gold_cluster_chr_pos_flatten, gold_cluster_chr_cluster_idx_flatten, \
        gold_cluster_chr_cluster_pos2idx_flatten, gold_cluster_chr_range = get_cluster_from_gwas(
            gold_label_gwas_hits,
            cluster_distance_threshold,
            threshold_extend=threshold_extend,
            cluster_compare_threshold=cluster_compare_threshold,
            verbose=False)

    cluster_chr_pos, cluster_chr_rs, \
        cluster_chr_pos_flatten, cluster_chr_cluster_idx_flatten, \
        cluster_chr_cluster_pos2idx_flatten, cluster_chr_range = get_cluster_from_gwas(
            low_data_gwas_hits,
            cluster_distance_threshold,
            threshold_extend=threshold_extend,
            cluster_compare_threshold=cluster_compare_threshold,
            verbose=False)

    gnn_cluster_chr_pos, gnn_cluster_chr_rs, \
        gnn_cluster_chr_pos_flatten, gnn_cluster_chr_cluster_idx_flatten, \
        gnn_cluster_chr_cluster_pos2idx_flatten, gnn_cluster_chr_range = get_cluster_from_gwas(
            method_hit_gwas,
            cluster_distance_threshold,
            threshold_extend=threshold_extend,
            cluster_compare_threshold=cluster_compare_threshold,
            verbose=False)

    total = sum([len(j) for i, j in gold_cluster_chr_range.items()])

    plink_set_total = sum([len(j) for i, j in cluster_chr_range.items()])

    plink_set_overlap_ref = 0
    plink_set_overlap_query = 0
    for j in find_overlap_clusters(cluster_chr_range, gold_cluster_chr_range).values():
        plink_set_overlap_ref += len(np.unique([set(i[1]) for i in j]))
        plink_set_overlap_query += len(np.unique([set(i[0]) for i in j]))

    gnn_set_total = sum([len(j) for i, j in gnn_cluster_chr_range.items()])

    gnn_set_overlap_ref = 0
    gnn_set_overlap_query = 0
    for j in find_overlap_clusters(gnn_cluster_chr_range, gold_cluster_chr_range).values():
        gnn_set_overlap_ref += len(np.unique([set(i[1]) for i in j]))
        gnn_set_overlap_query += len(np.unique([set(i[0]) for i in j]))

    if gnn_set_total == 0:
        gnn_set_precision = -1
    else:
        gnn_set_precision = gnn_set_overlap_query / gnn_set_total

    if plink_set_total == 0:
        plink_precision = -1
    else:
        plink_precision = plink_set_overlap_query / plink_set_total

    return {'plink_precision': plink_precision,
            'plink_recall': plink_set_overlap_ref / total,
            method_name + '_precision': gnn_set_precision,
            method_name + '_recall': gnn_set_overlap_ref / total,
            'plink_set_overlap_ref': plink_set_overlap_ref,
            'plink_set_overlap_query': plink_set_overlap_query,
            'plink_set_total': plink_set_total,
            method_name + '_set_overlap_ref': gnn_set_overlap_ref,
            method_name + '_set_overlap_query': gnn_set_overlap_query,
            method_name + '_set_total': gnn_set_total,
            'total_set': total}
|
|
|
def find_overlap_clusters(query_cluster2range, gold_cluster2range):
    set_found_cluster_all = {}
    for chr_num, eval_cluster in query_cluster2range.items():
        if chr_num in gold_cluster2range:
            gold_cluster = gold_cluster2range[chr_num]
            set_found_cluster = []
            for a in eval_cluster:
                for b in gold_cluster:
                    if (a[0] <= b[1]) and (b[0] <= a[1]):
                        set_found_cluster.append((a, b))
                        break
            set_found_cluster_all[chr_num] = set_found_cluster

    return set_found_cluster_all
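
# Illustrative sketch for find_overlap_clusters (commented out; toy ranges).
# Two (start, end) ranges overlap when a[0] <= b[1] and b[0] <= a[1]:
#
#   query = {1: [(0, 10), (50, 60)]}
#   gold  = {1: [(5, 20)]}
#   find_overlap_clusters(query, gold)
#   # -> {1: [((0, 10), (5, 20))]} ; (50, 60) matches nothing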
|
|
|
def find_non_overlap_clusters(query_cluster2range, gold_cluster2range):
    set_not_found_cluster_all = {}
    for chr_num, eval_cluster in query_cluster2range.items():
        # a chromosome absent from the gold set has no overlaps by definition
        gold_cluster = gold_cluster2range.get(chr_num, [])

        set_not_found_cluster = []
        for a in eval_cluster:
            set_found_cluster = []
            for b in gold_cluster:
                if (a[0] <= b[1]) and (b[0] <= a[1]):
                    set_found_cluster.append((a, b))
                    break

            if len(set_found_cluster) == 0:
                set_not_found_cluster.append(a)

        set_not_found_cluster_all[chr_num] = set_not_found_cluster

    return set_not_found_cluster_all


### eval support functions
|
|
|
def quantileNormalize(df_input):
    df = df_input.copy()
    # compute rank
    dic = {}
    for col in df:
        dic.update({col: sorted(df[col])})
    sorted_df = pd.DataFrame(dic)
    rank = sorted_df.mean(axis=1).tolist()
    # sort
    for col in df:
        t = np.searchsorted(np.sort(df[col]), df[col])
        df[col] = [rank[i] for i in t]
    return df
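
# Illustrative sketch for quantileNormalize (commented out): after quantile
# normalization, every column shares the same value distribution (the per-rank
# means of the sorted columns):
#
#   df = pd.DataFrame({'a': [1, 2, 3], 'b': [30, 10, 20]})
#   quantileNormalize(df)
#   #       a     b
#   # 0   5.5  16.5
#   # 1  11.0   5.5
#   # 2  16.5  11.0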
|
|
|
def get_cluster_count(method_hit_gwas, cluster_distance_threshold, cluster_compare_threshold, threshold_extend, gold_cluster_chr_range):
    gnn_cluster_chr_pos, gnn_cluster_chr_rs, \
        gnn_cluster_chr_pos_flatten, gnn_cluster_chr_cluster_idx_flatten, \
        gnn_cluster_chr_cluster_pos2idx_flatten, gnn_cluster_chr_range = get_cluster_from_gwas(
            method_hit_gwas,
            cluster_distance_threshold,
            threshold_extend=threshold_extend,
            cluster_compare_threshold=cluster_compare_threshold,
            verbose=False)

    total = sum([len(j) for i, j in gold_cluster_chr_range.items()])
    gnn_set_total = sum([len(j) for i, j in gnn_cluster_chr_range.items()])

    gnn_set_overlap_ref = 0
    gnn_set_overlap_query = 0
    for j in find_overlap_clusters(gnn_cluster_chr_range, gold_cluster_chr_range).values():
        gnn_set_overlap_ref += len(np.unique([set(i[1]) for i in j]))
        gnn_set_overlap_query += len(np.unique([set(i[0]) for i in j]))

    return {'set_overlap_ref': gnn_set_overlap_ref,
            'set_overlap_query': gnn_set_overlap_query,
            'set_total': gnn_set_total,
            'total_set': total}
|
|
|
## coarse-to-fine search: step by 100 SNPs until the cluster count passes k,
## then back off and step by 10, then scan the final stretch one SNP at a time
def get_top_k_clusters(query_rank, top_hits_k_range, cluster_distance_threshold, cluster_compare_threshold, threshold_extend, gold_cluster_chr_range):
    snp_k = 0
    k_to_cluster = {}
    k_to_closest_x = {}
    for k in top_hits_k_range:
        while True:
            out = get_cluster_count(query_rank[:snp_k], cluster_distance_threshold,
                                    cluster_compare_threshold, threshold_extend, gold_cluster_chr_range)
            if out['set_total'] < k:
                snp_k += 100
            else:
                snp_k -= 100
                while True:
                    out = get_cluster_count(query_rank[:snp_k], cluster_distance_threshold,
                                            cluster_compare_threshold, threshold_extend, gold_cluster_chr_range)
                    if out['set_total'] < k:
                        snp_k += 10
                    else:
                        closest_x = snp_k
                        closest_distance = abs(out['set_total'] - k)
                        for x in range(snp_k - 10, snp_k):
                            out = get_cluster_count(query_rank[:x], cluster_distance_threshold,
                                                    cluster_compare_threshold, threshold_extend, gold_cluster_chr_range)
                            if abs(out['set_total'] - k) <= closest_distance:
                                closest_x = x
                                closest_distance = abs(out['set_total'] - k)
                        break
                break

        k_to_cluster[k] = get_cluster_count(query_rank[:closest_x], cluster_distance_threshold,
                                            cluster_compare_threshold, threshold_extend, gold_cluster_chr_range)
        k_to_closest_x[k] = closest_x

    return k_to_cluster, k_to_closest_x
|
|
|
def storey_pi_estimator(gwas_data, bin_index):
    """
    Estimate pi0/pi1 using the Storey and Tibshirani (PNAS 2003) estimator.

    Args
    ====
    bin_index: array of indices for a particular bin
    """
    pvalue = gwas_data.loc[bin_index, 'P']  # extract p-values from a specific bin based on index

    # assert(pvalue.min() >= 0 and pvalue.max() <= 1), "Error: p-values should be between 0 and 1"
    total_tests = float(len(pvalue))
    lam = np.arange(0.05, 0.95, 0.05)
    counts = np.array([(pvalue > l).sum() for l in lam])
    # use an array (not a list) so boolean filtering below works
    pi0 = np.array([counts[l] / (total_tests * (1 - lam[l])) for l in range(len(lam))])

    # fit cubic spline
    if not np.all(np.isfinite(pi0)):
        print("Not all pi0 values are finite; filtering to finite indices...")
        finite_indices = np.isfinite(pi0)
        lam = lam[finite_indices]
        pi0 = pi0[finite_indices]

    cubic_spline = interpolate.CubicSpline(lam, pi0)
    pi0_est = cubic_spline(lam[-1])
    if pi0_est > 1:  # take care of out-of-bounds estimate
        pi0_est = 1
    return pi0_est
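
# Sanity-check sketch for storey_pi_estimator (commented out; synthetic data):
# for p-values drawn uniformly from [0, 1] (an all-null GWAS) the estimate
# should come out close to 1:
#
#   null = pd.DataFrame({'P': np.random.uniform(0, 1, 100_000)})
#   storey_pi_estimator(null, null.index)   # -> roughly 1.0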
|
|
|
def storey_ribshirani_integrate(gwas_data, column='pred', num_bins=100):
    num_bins = float(num_bins)
    quantiles = np.arange(0, 1 + 1 / (num_bins + 1), 1 / num_bins)
    predicted_tagged_variance_quantiles = gwas_data[column].quantile(quantiles)
    # widen the boundary quantiles (labels 0 and 1, i.e. the 0% and 100%
    # quantiles) so every value falls inside a bin
    predicted_tagged_variance_quantiles[0] = predicted_tagged_variance_quantiles[0] - 1
    predicted_tagged_variance_quantiles[1] = predicted_tagged_variance_quantiles[1] + 1
    predicted_tagged_variance_quantiles = predicted_tagged_variance_quantiles.drop_duplicates()
    num_bins = len(predicted_tagged_variance_quantiles) - 1
    bins = pd.cut(gwas_data[column], predicted_tagged_variance_quantiles, labels=np.arange(num_bins))  # create the labels
    gwas_data['bin_number'] = bins

    gwas_data['pi0'] = None

    if (gwas_data['P'].min() < 0) or (gwas_data['P'].max() > 1):
        print("detected p-values < 0 or > 1, please double check; clipping to [0, 1] for now...")
        gwas_data['P'] = gwas_data['P'].clip(lower=0, upper=1)

    # estimate pi0 within each bin
    for i in range(num_bins):
        bin_index = gwas_data['bin_number'] == i  # determine index of snps in bin number i
        if len(gwas_data[bin_index]) > 0:
            pi0 = storey_pi_estimator(gwas_data, bin_index)
            ## prevent exploding weights
            if pi0 < 1e-5:
                pi0 = 1e-5
            if pi0 > 1 - 1e-5:
                pi0 = 1 - 1e-5
            gwas_data.loc[bin_index, 'pi0'] = pi0
    if any(gwas_data['pi0'] == 1):  # if a bin is estimated to be all null, give it the largest non-one pi0 (i.e. the smallest non-null weight)
        one_index = gwas_data['pi0'] == 1
        largest_pi0 = gwas_data.loc[~one_index]['pi0'].max()
        gwas_data.loc[one_index, 'pi0'] = largest_pi0

    if any(gwas_data['pi0'] == 0):  # if a bin is estimated to be all alternative, give it the smallest non-zero pi0 (i.e. the largest non-null weight)
        zero_index = gwas_data['pi0'] == 0
        smallest_pi0 = gwas_data.loc[~zero_index]['pi0'].min()
        gwas_data.loc[zero_index, 'pi0'] = smallest_pi0

    # re-weight SNPs
    weights = (1 - gwas_data['pi0']) / gwas_data['pi0']

    ## an alternative guard against exploding p-values would be
    ## weights = np.maximum(1, weights.values)
    mean_weight = weights.mean()
    weights = weights / mean_weight  # normalize weights to have mean 1

    gwas_data['weights'] = weights
    gwas_data['P_weighted'] = gwas_data['P'] / weights  # reweight SNPs

    index = gwas_data['P_weighted'] > 1
    gwas_data.loc[index, 'P_weighted'] = gwas_data['P'][index]  ## fall back to the original p-value when the reweighted one exceeds 1
    gwas_data.loc[gwas_data['P_weighted'].isnull(), 'P_weighted'] = 1
    return gwas_data['P_weighted'].values
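
# End-to-end sketch of the reweighting pipeline (commented out; synthetic data,
# with 'pred' standing in for any per-SNP predicted score): bin SNPs by the
# score, estimate pi0 per bin, reweight the p-values FINDOR-style, then rescale
# with find_closest_x so the weighted distribution matches the raw one:
#
#   rng = np.random.default_rng(0)
#   gwas = pd.DataFrame({'P': rng.uniform(0, 1, 50_000),
#                        'pred': rng.normal(0, 1, 50_000)})
#   gwas['P_weighted'] = storey_ribshirani_integrate(gwas, column='pred')
#   scale = find_closest_x(gwas)
#   gwas['P_weighted'] = (gwas['P_weighted'] * scale).clip(upper=1)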