|
# kgwas/kgwas_data.py
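"""Data loading utilities for KGWAS.

Defines `KGWAS_Data`, which downloads the core KGWAS data bundle, builds the
heterogeneous knowledge graph with configurable node embeddings, loads GWAS
summary statistics (simulated, subsampled, full-cohort, or external), and
prepares train/validation/test splits for model training.
"""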
|
|
from torch_geometric.data import HeteroData
import torch_geometric.transforms as T

from sklearn.model_selection import train_test_split
import pandas as pd
import torch
import numpy as np
import pickle
import os
import tarfile
import urllib.request
import shutil
from tqdm import tqdm
import subprocess

from .utils import ldsc_regression_weights, load_dict
from .params import scdrs_traits


class KGWAS_Data:
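    """Container for the KGWAS knowledge graph and GWAS summary statistics.

    On construction, checks `data_path` for the required core files and
    downloads the core bundle from Harvard Dataverse if any are missing.
    """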
|
|
    def __init__(self, data_path='./data/'):
        self.data_path = data_path

        # Ensure the data path exists
        if not os.path.exists(data_path):
            os.makedirs(data_path)

        # Check whether the required core files already exist in data_path
        required_files = [
            'cell_kg/network/node_idx2id.pkl',
            'cell_kg/network/edge_index.pkl',
            'cell_kg/network/node_id2idx.pkl',
            'cell_kg/node_emb/variant_emb/enformer_feat.pkl',
            'cell_kg/node_emb/gene_emb/esm_feat.pkl',
            'ld_score/filter_genotyped_ldscores.csv',
            'ld_score/ldscores_from_data.csv',
            'ld_score/ukb_white_ld_10MB_no_hla.pkl',
            'ld_score/ukb_white_ld_10MB.pkl',
            'misc_data/ukb_white_with_cm.bim',
        ]
        missing_files = [f for f in required_files if not os.path.exists(os.path.join(data_path, f))]

        if missing_files:
            print("Relevant data not found in the data_path. Downloading and extracting data...")
            url = "https://dataverse.harvard.edu/api/access/datafile/10731230"
            file_name = 'kgwas_core_data'
            self._download_and_extract_data(url, file_name)
        else:
            print("All required data files are present.")

    def download_all_data(self):
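        """Download and extract the full KGWAS data bundle.

        Note: the Dataverse datafile ID in the URL below is an unfilled
        placeholder.
        """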
|
|
        url = "https://dataverse.harvard.edu/api/access/datafile/XXXX"
        file_name = 'kgwas_data'
        self._download_and_extract_data(url, file_name)

    def _merge_with_rsync(self, src, dst):
        """Merge the contents of directory `src` into `dst` using rsync,
        keeping any files that already exist in `dst`."""
        try:
            subprocess.run(
                ["rsync", "-a", "--ignore-existing", src + "/", dst + "/"],
                check=True,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
            )
        except subprocess.CalledProcessError as e:
            print(f"Error during rsync: {e.stderr.decode()}")

    def _download_and_extract_data(self, url, file_name):
        """Download `file_name`.tar.gz from `url`, extract it, and merge the
        extracted directory into the data path."""
        tar_file_path = os.path.join(self.data_path, f"{file_name}.tar.gz")

        # Download the file
        print(f"Downloading {file_name}.tar.gz...")
        self._download_with_progress(url, tar_file_path)
        print("Download complete.")

        # Extract the tar.gz file
        print("Extracting files...")
        with tarfile.open(tar_file_path, 'r:gz') as tar:
            tar.extractall(self.data_path)
        print("Extraction complete.")

        # Clean up the tar.gz file
        os.remove(tar_file_path)

        # Merge extracted contents into the data_path directory
        extracted_dir = os.path.join(self.data_path, file_name)
        if os.path.exists(extracted_dir):
            print(f"Merging extracted directory '{extracted_dir}' into '{self.data_path}'...")
            self._merge_with_rsync(extracted_dir, self.data_path)

            # Remove the extracted directory now that its contents are merged
            shutil.rmtree(extracted_dir)

    def _download_with_progress(self, url, file_path):
        """Download a file with a tqdm progress bar."""
        request = urllib.request.Request(url, headers={'User-Agent': 'Mozilla/5.0'})
        response = urllib.request.urlopen(request)
        total_size = int(response.getheader('Content-Length').strip())
        block_size = 1024  # 1 KB

        with open(file_path, 'wb') as file, tqdm(
            total=total_size, unit='B', unit_scale=True, desc="Downloading"
        ) as pbar:
            while True:
                buffer = response.read(block_size)
                if not buffer:
                    break
                file.write(buffer)
                pbar.update(len(buffer))

    def load_kg(self, snp_init_emb='enformer',
                go_init_emb='random',
                gene_init_emb='esm',
                sample_edges=False,
                sample_ratio=1):
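        """Build the heterogeneous knowledge graph as a PyG `HeteroData` object.

        snp_init_emb: 'random', 'kg', 'cadd', 'baselineLD', 'SLDSC', or 'enformer'.
        go_init_emb: 'random', 'kg', or 'biogpt'.
        gene_init_emb: 'random', 'kg', 'esm', 'pops', or 'pops_expression'.
        sample_edges / sample_ratio: optionally subsample each relation's edges.
        """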
|
|
        data_path = self.data_path

        ## Load KG
        print('--loading KG---')
        idx2id = load_dict(os.path.join(data_path, 'cell_kg/network/node_idx2id.pkl'))
        edge_index_all = load_dict(os.path.join(data_path, 'cell_kg/network/edge_index.pkl'))
        id2idx = load_dict(os.path.join(data_path, 'cell_kg/network/node_id2idx.pkl'))
        self.id2idx = id2idx
        self.idx2id = idx2id

        data = HeteroData()
|
|
        ## Load initial node embeddings; unmatched nodes fall back to random
        ## vectors of the same dimension
        if snp_init_emb == 'random':
            print('--using random SNP embedding--')
            data['SNP'].x = torch.rand((len(idx2id['SNP']), 128), requires_grad=False)
            snp_init_dim_size = 128
        elif snp_init_emb == 'kg':
            print('--using KG SNP embedding--')
            id2idx_kg = load_dict(os.path.join(data_path, 'cell_kg/node_emb/transe_emb/transe_emb_id2idx_kg.pkl'))
            kg_emb = load_dict(os.path.join(data_path, 'cell_kg/node_emb/transe_emb/transe_emb_inverse_triplets.pkl'))
            node_map = idx2id['SNP']
            data['SNP'].x = torch.vstack([torch.tensor(kg_emb[id2idx_kg[node_map[i]]]) if node_map[i] in id2idx_kg
                                          else torch.rand(50, requires_grad=False) for i in range(len(node_map))])
            snp_init_dim_size = 50
        elif snp_init_emb == 'cadd':
            print('--using CADD SNP embedding--')
            df_variant = pd.read_csv(os.path.join(data_path, 'cell_kg/node_emb/variant_emb/cadd_feat.csv'))
            df_variant = df_variant.set_index('Unnamed: 0')
            variant_feat = df_variant.values
            node_map = idx2id['SNP']
            rs2idx_feat = dict(zip(df_variant.index.values, range(len(df_variant.index.values))))
            data['SNP'].x = torch.vstack([torch.tensor(variant_feat[rs2idx_feat[node_map[i]]]) if node_map[i] in rs2idx_feat
                                          else torch.rand(64, requires_grad=False) for i in range(len(node_map))]).float()
            snp_init_dim_size = 64
        elif snp_init_emb == 'baselineLD':
            print('--using baselineLD SNP embedding--')
            node_map = idx2id['SNP']
            rs2idx_feat = load_dict(os.path.join(data_path, 'cell_kg/node_emb/variant_emb/baselineld_feat.pkl'))
            data['SNP'].x = torch.vstack([torch.tensor(rs2idx_feat[node_map[i]]) if node_map[i] in rs2idx_feat
                                          else torch.rand(70, requires_grad=False) for i in range(len(node_map))]).float()
            snp_init_dim_size = 70
        elif snp_init_emb == 'SLDSC':
            print('--using SLDSC SNP embedding--')
            node_map = idx2id['SNP']
            rs2idx_feat = load_dict(os.path.join(data_path, 'cell_kg/node_emb/variant_emb/sldsc_feat.pkl'))
            data['SNP'].x = torch.vstack([torch.tensor(rs2idx_feat[node_map[i]]) if node_map[i] in rs2idx_feat
                                          else torch.rand(165, requires_grad=False) for i in range(len(node_map))]).float()
            snp_init_dim_size = 165
        elif snp_init_emb == 'enformer':
            print('--using enformer SNP embedding--')
            node_map = idx2id['SNP']
            rs2idx_feat = load_dict(os.path.join(data_path, 'cell_kg/node_emb/variant_emb/enformer_feat.pkl'))
            data['SNP'].x = torch.vstack([torch.tensor(rs2idx_feat[node_map[i]]) if node_map[i] in rs2idx_feat
                                          else torch.rand(20, requires_grad=False) for i in range(len(node_map))]).float()
            snp_init_dim_size = 20
        else:
            raise ValueError(f"Unknown snp_init_emb: {snp_init_emb}")

        if go_init_emb == 'random':
            print('--using random GO embedding--')
            for rel in ['CellularComponent', 'BiologicalProcess', 'MolecularFunction']:
                data[rel].x = torch.rand((len(idx2id[rel]), 128), requires_grad=False)
            go_init_dim_size = 128
        elif go_init_emb == 'kg':
            print('--using KG GO embedding--')
            id2idx_kg = load_dict(os.path.join(data_path, 'cell_kg/node_emb/transe_emb/transe_emb_id2idx_kg.pkl'))
            kg_emb = load_dict(os.path.join(data_path, 'cell_kg/node_emb/transe_emb/transe_emb_inverse_triplets.pkl'))
            for rel in ['CellularComponent', 'BiologicalProcess', 'MolecularFunction']:
                node_map = idx2id[rel]
                data[rel].x = torch.vstack([torch.tensor(kg_emb[id2idx_kg[node_map[i]]]) if node_map[i] in id2idx_kg
                                            else torch.rand(50, requires_grad=False) for i in range(len(node_map))])
            go_init_dim_size = 50
        elif go_init_emb == 'biogpt':
            print('--using BioGPT GO embedding--')
            go2idx_feat = load_dict(os.path.join(data_path, 'cell_kg/node_emb/program_emb/biogpt_feat.pkl'))
            for rel in ['CellularComponent', 'BiologicalProcess', 'MolecularFunction']:
                node_map = idx2id[rel]
                data[rel].x = torch.vstack([torch.tensor(go2idx_feat[node_map[i]]) if node_map[i] in go2idx_feat
                                            else torch.rand(1600, requires_grad=False) for i in range(len(node_map))]).float()
            go_init_dim_size = 1600
        else:
            raise ValueError(f"Unknown go_init_emb: {go_init_emb}")

        if gene_init_emb == 'random':
            print('--using random gene embedding--')
            data['Gene'].x = torch.rand((len(idx2id['Gene']), 128), requires_grad=False)
            gene_init_dim_size = 128
        elif gene_init_emb == 'kg':
            print('--using KG gene embedding--')
            id2idx_kg = load_dict(os.path.join(data_path, 'cell_kg/node_emb/transe_emb/transe_emb_id2idx_kg.pkl'))
            kg_emb = load_dict(os.path.join(data_path, 'cell_kg/node_emb/transe_emb/transe_emb_inverse_triplets.pkl'))
            node_map = idx2id['Gene']
            data['Gene'].x = torch.vstack([torch.tensor(kg_emb[id2idx_kg[node_map[i]]]) if node_map[i] in id2idx_kg
                                           else torch.rand(50, requires_grad=False) for i in range(len(node_map))])
            gene_init_dim_size = 50
        elif gene_init_emb == 'esm':
            print('--using ESM gene embedding--')
            gene2idx_feat = load_dict(os.path.join(data_path, 'cell_kg/node_emb/gene_emb/esm_feat.pkl'))
            node_map = idx2id['Gene']
            data['Gene'].x = torch.vstack([torch.tensor(gene2idx_feat[node_map[i]]) if node_map[i] in gene2idx_feat
                                           else torch.rand(5120, requires_grad=False) for i in range(len(node_map))]).float()
            gene_init_dim_size = 5120
        elif gene_init_emb == 'pops':
            print('--using PoPS expression+PPI+pathways gene embedding--')
            gene2idx_feat = load_dict(os.path.join(data_path, 'cell_kg/node_emb/gene_emb/pops_feat.pkl'))
            node_map = idx2id['Gene']
            data['Gene'].x = torch.vstack([torch.tensor(gene2idx_feat[node_map[i]]) if node_map[i] in gene2idx_feat
                                           else torch.rand(57742, requires_grad=False) for i in range(len(node_map))]).float()
            gene_init_dim_size = 57742
        elif gene_init_emb == 'pops_expression':
            print('--using PoPS expression-only gene embedding--')
            gene2idx_feat = load_dict(os.path.join(data_path, 'cell_kg/node_emb/gene_emb/pops_expression_feat.pkl'))
            node_map = idx2id['Gene']
            data['Gene'].x = torch.vstack([torch.tensor(gene2idx_feat[node_map[i]]) if node_map[i] in gene2idx_feat
                                           else torch.rand(40546, requires_grad=False) for i in range(len(node_map))]).float()
            gene_init_dim_size = 40546
        else:
            raise ValueError(f"Unknown gene_init_emb: {gene_init_emb}")

        self.gene_init_dim_size = gene_init_dim_size
        self.go_init_dim_size = go_init_dim_size
        self.snp_init_dim_size = snp_init_dim_size

        ## Copy edge indices (optionally subsampled) into the graph
        for rel_type, edges in edge_index_all.items():
            if sample_edges:
                edge_index = torch.tensor(edges)
                num_edges = edge_index.size(1)
                num_samples = int(num_edges * sample_ratio)
                indices = torch.randperm(num_edges)[:num_samples]
                sampled_edge_index = edge_index[:, indices]
                print(rel_type, ' sampling ratio ', sample_ratio, ' from ', edge_index.shape[1], ' to ', sampled_edge_index.shape[1])
                data[rel_type].edge_index = sampled_edge_index
            else:
                data[rel_type].edge_index = torch.tensor(edges)

        # Make the graph undirected and add self-loops
        data = T.ToUndirected()(data)
        data = T.AddSelfLoops()(data)
        self.data = data

    def load_simulation_gwas(self, simulation_type, seed):
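        """Load simulated GWAS summary statistics.

        simulation_type: 'causal_link', 'causal', or 'null'. The simulated
        cohorts have 5,000 samples with heritability 0.3.
        """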
|
|
        data_path = self.data_path
        print('Using simulation data....')
        small_cohort = 5000
        num_causal_hits = 20000
        heritability = 0.3
        self.sample_size = small_cohort
        if simulation_type == 'causal_link':
            lr_uni = pd.read_csv(os.path.join(data_path, 'simulation_gwas/causal_link_simulation/' + str(num_causal_hits) + '_' + str(seed) + '_' + str(heritability) + '_graph_funct_v2_ggi.fastGWA'), sep='\t')
        elif simulation_type == 'causal':
            lr_uni = pd.read_csv(os.path.join(data_path, 'simulation_gwas/causal_simulation/' + str(num_causal_hits) + '_' + str(seed) + '_' + str(heritability) + '_' + str(small_cohort) + '_graph_funct_v2.fastGWA'), sep='\t')
        elif simulation_type == 'null':
            lr_uni = pd.read_csv(os.path.join(data_path, 'simulation_gwas/null_simulation/' + str(num_causal_hits) + '_' + str(seed) + '_' + str(heritability) + '_' + str(small_cohort) + '.fastGWA'), sep='\t')
        else:
            raise ValueError(f"Unknown simulation_type: {simulation_type}")

        # Standardize column names; keep 'ID' if the file already has one
        if ('SNP' in lr_uni.columns.values) and ('ID' in lr_uni.columns.values):
            self.lr_uni = lr_uni.rename(columns={'CHR': '#CHROM'})
        else:
            self.lr_uni = lr_uni.rename(columns={'CHR': '#CHROM', 'SNP': 'ID'})
        self.seed = seed
        self.pheno = 'simulation'

    def load_external_gwas(self, path=None, seed=42, example_file=False):
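        """Load an external GWAS summary-statistics file.

        The file must contain CHR, SNP, P, and N columns; SNPs are filtered
        to the KG variant set (so `load_kg` must be called first). If
        `example_file` is True, a Creatinine example GWAS is downloaded.
        """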
|
|
        if example_file:
            print('Loading example GWAS file...')
            url = "https://dataverse.harvard.edu/api/access/datafile/10730346"
            example_file_path = os.path.join(self.data_path, 'biochemistry_Creatinine_fastgwa_full_10000_1.fastGWA')

            # Check if the example file is already downloaded
            if not os.path.exists(example_file_path):
                print('Example file not found locally. Downloading...')
                self._download_with_progress(url, example_file_path)
                print('Example file downloaded successfully.')
            else:
                print('Example file already exists locally.')

            path = example_file_path

        if path is None:
            raise ValueError("A valid path must be provided or example_file must be set to True.")

        print(f'Loading GWAS file from {path}...')

        # sep=None lets pandas sniff the delimiter (tab, comma, whitespace)
        lr_uni = pd.read_csv(path, sep=None, engine='python')
        if 'CHR' not in lr_uni.columns.values:
            raise ValueError('CHR (chromosome) column not in the file!')
        if 'SNP' not in lr_uni.columns.values:
            raise ValueError('SNP column not in the file!')
        if 'P' not in lr_uni.columns.values:
            raise ValueError('P column not in the file!')
        if 'N' not in lr_uni.columns.values:
            raise ValueError('N (sample size) column not in the file!')
        lr_uni = lr_uni.rename(columns={'CHR': '#CHROM', 'SNP': 'ID'})

        ## Filter to the current KG variant set
        old_variant_set_len = len(lr_uni)
        lr_uni = lr_uni[lr_uni.ID.isin(list(self.idx2id['SNP'].values()))]
        print('Number of SNPs in the KG:', len(self.idx2id['SNP']))
        print('Number of SNPs in the GWAS:', old_variant_set_len)
        print('Number of SNPs in the KG variant set:', len(lr_uni))

        self.lr_uni = lr_uni
        self.sample_size = lr_uni.N.values[0]
        self.pheno = 'EXTERNAL'
        self.seed = seed

    def load_full_gwas(self, pheno, seed=42):
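        """Load full-cohort GWAS summary statistics for `pheno`.

        Traits in `scdrs_traits` are read from the scDRS sumstats table
        (pre-computed chi-square statistics); all other phenotypes are read
        from per-trait fastGWA files with a fixed cohort size of 387,113.
        """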
|
|
        data_path = self.data_path
        if pheno in scdrs_traits:
            print('Using scDRS traits...')
            self.pheno = pheno
            lr_uni = pd.read_csv(os.path.join(data_path, 'scDRS_Data/sumstats_ukb_snps.csv'))
            lr_uni = lr_uni[['CHR', 'SNP', 'POS', 'A1', 'A2', 'N', 'AF1', pheno]]
            lr_uni = lr_uni[lr_uni[pheno].notnull()].reset_index(drop=True)
            lr_uni = lr_uni.rename(columns={'CHR': '#CHROM', 'SNP': 'ID', pheno: 'chi'})
            print('number of SNPs:', len(lr_uni))
            self.lr_uni = lr_uni
            self.seed = seed

            with open(os.path.join(data_path, 'scDRS_data/trait2size.pkl'), 'rb') as f:
                trait2size = pickle.load(f)
            self.sample_size = trait2size[pheno]
        else:
            ## Load full-cohort fastGWA files
            self.pheno = pheno
            lr_uni = pd.read_csv(os.path.join(data_path, 'full_gwas/' + str(self.pheno) + '_with_rel_fastgwa.fastGWA'), sep='\t')
            lr_uni = lr_uni.rename(columns={'CHR': '#CHROM', 'SNP': 'ID'})

            self.lr_uni = lr_uni
            self.seed = seed
            self.sample_size = 387113

    def load_gwas_subsample(self, pheno, sample_size, seed):
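        """Load subsampled GWAS results for `pheno` at a given `sample_size`.

        Cohorts above 3,000 samples use fastGWA output; smaller cohorts use
        PLINK output (logistic for binary traits, linear otherwise).
        """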
|
|
        data_path = self.data_path
        binary_traits = ['body_BALDING1', 'cancer_BREAST', 'disease_ALLERGY_ECZEMA_DIAGNOSED',
                         'disease_HYPOTHYROIDISM_SELF_REP', 'other_MORNINGPERSON', 'pigment_SUNBURN']
        binary = pheno in binary_traits

        ## Load GWAS files
        self.sample_size = sample_size
        self.pheno = pheno
        if sample_size > 3000:
            lr_uni = pd.read_csv(os.path.join(data_path, 'subsample_gwas/' + str(self.pheno) +
                                 '_fastgwa_full_' + str(sample_size) + '_' + str(seed) + '.fastGWA'), sep='\t')
            lr_uni = lr_uni.rename(columns={'CHR': '#CHROM', 'SNP': 'ID'})
        else:
            ## Use PLINK output if sample size <= 3000
            if binary:
                lr_uni = pd.read_csv(os.path.join(data_path, 'subsample_gwas/' + str(self.pheno) +
                                     '_plink_' + str(sample_size) + '_' + str(seed) + '.PHENO1.glm.logistic.hybrid'), sep='\t')
            else:
                lr_uni = pd.read_csv(os.path.join(data_path, 'subsample_gwas/' + str(self.pheno) +
                                     '_plink_' + str(sample_size) + '_' + str(seed) + '.PHENO1.glm.linear'), sep='\t')
        self.lr_uni = lr_uni
        self.seed = seed

    def process_gwas_file(self, label='chi'):
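        """Compute per-SNP labels and LDSC regression weights.

        label='chi' uses the chi-square association statistic; the
        'residual-*' options instead regress chi-square on LD scores
        (weighted or ordinary least squares) and use the residual.
        """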
|
|
        data_path = self.data_path
        lr_uni = self.lr_uni

        ## LD scores
        ld_scores = pd.read_csv(os.path.join(data_path, 'ld_score/filter_genotyped_ldscores.csv'))
        w_ld_scores = pd.read_csv(os.path.join(data_path, 'ld_score/ldscores_from_data.csv'))

        m = 15000000  # number of SNPs (M) in the LDSC weight formula
        if 'N' not in lr_uni.columns.values:
            n = self.sample_size
        else:
            n = np.mean(lr_uni.N)
        h_g_2 = 0.5  # assumed total SNP heritability

        rs_id_2_ld_scores = dict(ld_scores.values)
        rs_id_2_w_ld = dict(w_ld_scores.values)

        ## Use the minimum LD score for SNPs with no LD score
        min_ld = min(rs_id_2_ld_scores.values())
        lr_uni['ld_score'] = lr_uni.ID.apply(lambda x: rs_id_2_ld_scores[x] if x in rs_id_2_ld_scores else min_ld)
        rs_id_2_ld_scores = dict(lr_uni[['ID', 'ld_score']].values)

        min_ld = min(rs_id_2_w_ld.values())
        ## The data LD score excludes the query SNP itself, so add 1
        lr_uni['w_ld_score'] = 1 + lr_uni.ID.apply(lambda x: rs_id_2_w_ld[x] if x in rs_id_2_w_ld else min_ld)
        rs_id_2_w_ld = dict(lr_uni[['ID', 'w_ld_score']].values)

        print('Using ldsc weight...')
        ld = np.array([rs_id_2_ld_scores[rs_id] for rs_id in lr_uni.ID.values])
        w_ld = np.array([rs_id_2_w_ld[rs_id] for rs_id in lr_uni.ID.values])

        ldsc_weight = ldsc_regression_weights(ld, w_ld, n, m, h_g_2)
        ldsc_weight = ldsc_weight / np.mean(ldsc_weight)  # normalize to mean 1
        print('ldsc_weight mean: ', np.mean(ldsc_weight))
        self.rs_id_to_ldsc_weight = dict(zip(lr_uni.ID.values, ldsc_weight))

        ## Chi-square label
        if label == 'chi':
            if 'chi' in lr_uni.columns.values:
                print('chi pre-computed...')
                lr_uni['y'] = lr_uni['chi'].values
            else:
                if self.pheno in ['body_BALDING1', 'cancer_BREAST', 'disease_ALLERGY_ECZEMA_DIAGNOSED',
                                  'disease_HYPOTHYROIDISM_SELF_REP', 'other_MORNINGPERSON',
                                  'pigment_SUNBURN'] and (self.sample_size <= 3000):
                    # PLINK logistic output: chi-square from the Z statistic
                    lr_uni['y'] = lr_uni['Z_STAT'].values**2
                    lr_uni['y'] = lr_uni.y.fillna(0)
                else:
                    if ('BETA' in lr_uni.columns.values) and ('SE' in lr_uni.columns.values):
                        lr_uni['y'] = (lr_uni['BETA'] / lr_uni['SE']).values**2
                        lr_uni['y'] = lr_uni.y.fillna(0)
                    else:
                        from scipy.stats import chi2
                        ## Convert from p-values; isf(p) equals ppf(1 - p) but
                        ## keeps precision for very small GWAS p-values
                        lr_uni['y'] = chi2.isf(lr_uni['P'].values, 1)
                        lr_uni['y'] = lr_uni.y.fillna(0)

        elif label == 'residual-w-ld':
            lr_uni['y'] = (lr_uni['BETA'] / lr_uni['SE']).values**2
            lr_uni['y'] = lr_uni.y.fillna(0)
            lr_uni['ld_weight'] = lr_uni.ID.apply(lambda x: self.rs_id_to_ldsc_weight[x])
            import statsmodels.api as sm

            # LDSC-weighted least squares of chi-square on the weighted LD
            # score; the label is the residual
            X = lr_uni.w_ld_score.values
            y = lr_uni.y.values
            weights = lr_uni.ld_weight.values
            X = sm.add_constant(X)
            model = sm.WLS(y, X, weights=weights)
            results = model.fit()
            y_pred = results.params[0] + results.params[1] * lr_uni.w_ld_score.values
            lr_uni['y'] = y - y_pred
        elif label == 'residual-ld':
            lr_uni['y'] = (lr_uni['BETA'] / lr_uni['SE']).values**2
            lr_uni['y'] = lr_uni.y.fillna(0)
            lr_uni['ld_weight'] = lr_uni.ID.apply(lambda x: self.rs_id_to_ldsc_weight[x])
            import statsmodels.api as sm

            X = lr_uni.ld_score.values
            y = lr_uni.y.values
            weights = lr_uni.ld_weight.values
            X = sm.add_constant(X)
            model = sm.WLS(y, X, weights=weights)
            results = model.fit()
            # Predict from ld_score, the covariate the model was fit on
            y_pred = results.params[0] + results.params[1] * lr_uni.ld_score.values
            lr_uni['y'] = y - y_pred
        elif label == 'residual-ld-ols':
            lr_uni['y'] = (lr_uni['BETA'] / lr_uni['SE']).values**2
            lr_uni['y'] = lr_uni.y.fillna(0)
            import statsmodels.api as sm

            X = lr_uni.ld_score.values
            y = lr_uni.y.values
            X = sm.add_constant(X)
            model = sm.OLS(y, X)
            results = model.fit()
            y_pred = results.params[0] + results.params[1] * lr_uni.ld_score.values
            lr_uni['y'] = y - y_pred
        elif label == 'residual-ld-ols-abs':
            lr_uni['y'] = (lr_uni['BETA'] / lr_uni['SE']).values**2
            lr_uni['y'] = lr_uni.y.fillna(0)
            import statsmodels.api as sm

            X = lr_uni.ld_score.values
            y = lr_uni.y.values
            X = sm.add_constant(X)
            model = sm.OLS(y, X)
            results = model.fit()
            y_pred = results.params[0] + results.params[1] * lr_uni.ld_score.values
            lr_uni['y'] = np.abs(y - y_pred)
        elif label == 'residual-w-ld-ols':
            lr_uni['y'] = (lr_uni['BETA'] / lr_uni['SE']).values**2
            lr_uni['y'] = lr_uni.y.fillna(0)
            import statsmodels.api as sm

            X = lr_uni.w_ld_score.values
            y = lr_uni.y.values
            X = sm.add_constant(X)
            model = sm.OLS(y, X)
            results = model.fit()
            y_pred = results.params[0] + results.params[1] * lr_uni.w_ld_score.values
            lr_uni['y'] = y - y_pred

        id2y = dict(lr_uni[['ID', 'y']].values)
        all_ids = lr_uni.ID.values
        self.all_ids = np.array([self.id2idx['SNP'][i] for i in all_ids])
        self.y = lr_uni.y.values

        self.lr_uni = lr_uni

    def prepare_split(self, test_set_fraction_data=0.05):
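        """Split SNPs into train/validation/test sets and attach labels and
        masks to the HeteroData object.

        Holds out `test_set_fraction_data` of SNPs for testing and 5% of the
        remainder for validation.
        """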
|
|
        ## Split SNPs into train/validation/test
        train_val_ids, test_ids, y_train_val, y_test = train_test_split(self.all_ids, self.y, test_size=test_set_fraction_data, random_state=self.seed)
        train_ids, val_ids, y_train, y_val = train_test_split(train_val_ids, y_train_val, test_size=0.05, random_state=self.seed)

        self.train_input_nodes = ('SNP', train_ids)
        self.val_input_nodes = ('SNP', val_ids)
        self.test_input_nodes = ('SNP', test_ids)

        # Label vector over all SNP nodes; -1 marks SNPs without a label
        y_snp = torch.zeros(self.data['SNP'].x.shape[0]) - 1
        y_snp[train_ids] = torch.tensor(y_train).float()
        y_snp[val_ids] = torch.tensor(y_val).float()
        y_snp[test_ids] = torch.tensor(y_test).float()

        self.data['SNP'].y = y_snp
        for node_type in self.data.node_types:
            self.data[node_type].n_id = torch.arange(self.data[node_type].x.shape[0])

        self.data.train_mask = train_ids
        self.data.val_mask = val_ids
        self.data.test_mask = test_ids
        self.data.all_mask = self.all_ids

    def get_pheno_list(self):
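        """Return the supported phenotypes: the scDRS large-cohort traits and
        the 21 independent traits used for subsampled GWAS."""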
|
|
        return {"large_cohort": scdrs_traits,
                "21_indep_traits": ['body_BALDING1',
                                    'disease_ALLERGY_ECZEMA_DIAGNOSED',
                                    'disease_HYPOTHYROIDISM_SELF_REP', 'pigment_SUNBURN',
                                    '21001', '50', '30080', '30070', '30010', '30000',
                                    'biochemistry_AlkalinePhosphatase',
                                    'biochemistry_AspartateAminotransferase',
                                    'biochemistry_Cholesterol', 'biochemistry_Creatinine',
                                    'biochemistry_IGF1', 'biochemistry_Phosphate',
                                    'biochemistry_Testosterone_Male', 'biochemistry_TotalBilirubin',
                                    'biochemistry_TotalProtein', 'biochemistry_VitaminD',
                                    'bmd_HEEL_TSCOREz']}
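

# Minimal usage sketch (assumes the default data layout and that the core
# bundle can be downloaded; the embedding/label choices are illustrative):
#
#     data = KGWAS_Data(data_path='./data/')
#     data.load_kg(snp_init_emb='enformer', gene_init_emb='esm')
#     data.load_external_gwas(example_file=True)
#     data.process_gwas_file(label='chi')
#     data.prepare_split(test_set_fraction_data=0.05)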