Diff of /kgwas/kgwas_data.py [000000] .. [8790ab]

from torch_geometric.data import HeteroData
import torch_geometric.transforms as T

from sklearn.model_selection import train_test_split
import pandas as pd
import torch
import numpy as np
import pickle
import os
import tarfile
import urllib.request
import shutil
from tqdm import tqdm
import subprocess

from .utils import ldsc_regression_weights, load_dict
from .params import scdrs_traits

class KGWAS_Data:
    def __init__(self, data_path='./data/'):
        self.data_path = data_path

        # Ensure the data path exists
        if not os.path.exists(data_path):
            os.makedirs(data_path)

        # Check if relevant data exists in the data_path
        required_files = [
            'cell_kg/network/node_idx2id.pkl',
            'cell_kg/network/edge_index.pkl',
            'cell_kg/network/node_id2idx.pkl',
            'cell_kg/node_emb/variant_emb/enformer_feat.pkl',
            'cell_kg/node_emb/gene_emb/esm_feat.pkl',
            'ld_score/filter_genotyped_ldscores.csv',
            'ld_score/ldscores_from_data.csv',
            'ld_score/ukb_white_ld_10MB_no_hla.pkl',
            'ld_score/ukb_white_ld_10MB.pkl',
            'misc_data/ukb_white_with_cm.bim',
        ]
        missing_files = [f for f in required_files if not os.path.exists(os.path.join(data_path, f))]

        if missing_files:
            print("Relevant data not found in the data_path. Downloading and extracting data...")
            url = "https://dataverse.harvard.edu/api/access/datafile/10731230"
            file_name = 'kgwas_core_data'
            self._download_and_extract_data(url, file_name)
        else:
            print("All required data files are present.")

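    # Usage sketch (illustrative): constructing the object checks for the core
    # data bundle and downloads it from Harvard Dataverse if any required file
    # is missing, e.g.
    #   data = KGWAS_Data(data_path='./data/')
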
    def download_all_data(self):
        # NOTE: the datafile id in this URL is a placeholder ('XXXX') in the
        # source and must be replaced with a real Dataverse file id.
        url = "https://dataverse.harvard.edu/api/access/datafile/XXXX"
        file_name = 'kgwas_data'
        self._download_and_extract_data(url, file_name)

    def _merge_with_rsync(self, src, dst):
        """Merge directories using rsync."""
        try:
            subprocess.run(
                ["rsync", "-a", "--ignore-existing", src + "/", dst + "/"],
                check=True,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
            )
        except subprocess.CalledProcessError as e:
            print(f"Error during rsync: {e.stderr.decode()}")

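    # Portable fallback sketch (hypothetical helper, not part of the class):
    # `rsync -a --ignore-existing` merges `src` into `dst` without overwriting
    # files that already exist at the destination; on systems without rsync,
    # the same behaviour can be approximated in pure Python:
    #
    #   def _merge_ignore_existing(src, dst):
    #       for root, _, files in os.walk(src):
    #           target = os.path.join(dst, os.path.relpath(root, src))
    #           os.makedirs(target, exist_ok=True)
    #           for f in files:
    #               if not os.path.exists(os.path.join(target, f)):
    #                   shutil.copy2(os.path.join(root, f), os.path.join(target, f))
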
    def _download_and_extract_data(self, url, file_name):
        """Download, extract, and merge directories using rsync."""
        tar_file_path = os.path.join(self.data_path, f"{file_name}.tar.gz")

        # Download the file
        print(f"Downloading {file_name}.tar.gz...")
        self._download_with_progress(url, tar_file_path)
        print("Download complete.")

        # Extract the tar.gz file
        print("Extracting files...")
        with tarfile.open(tar_file_path, 'r:gz') as tar:
            tar.extractall(self.data_path)
        print("Extraction complete.")

        # Clean up the tar.gz file
        os.remove(tar_file_path)

        # Merge extracted contents into the data_path directory
        extracted_dir = os.path.join(self.data_path, file_name)
        if os.path.exists(extracted_dir):
            print(f"Merging extracted directory '{extracted_dir}' into '{self.data_path}'...")
            self._merge_with_rsync(extracted_dir, self.data_path)

            # Remove the now-merged extracted directory
            shutil.rmtree(extracted_dir)

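    # Hardening note (assumption, not in the original code): `extractall`
    # trusts archive contents; on Python 3.12+ the call above could opt into
    # the safer extraction filter, e.g.
    #   tar.extractall(self.data_path, filter='data')
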
    def _download_with_progress(self, url, file_path):
        """Download a file with a progress bar."""
        request = urllib.request.Request(url, headers={'User-Agent': 'Mozilla/5.0'})
        response = urllib.request.urlopen(request)
        # Content-Length can be absent (e.g. chunked responses); fall back to
        # an indeterminate progress bar instead of crashing on None.
        content_length = response.getheader('Content-Length')
        total_size = int(content_length.strip()) if content_length else None
        block_size = 1024  # 1 KB

        with open(file_path, 'wb') as file, tqdm(
            total=total_size, unit='B', unit_scale=True, desc="Downloading"
        ) as pbar:
            while True:
                buffer = response.read(block_size)
                if not buffer:
                    break
                file.write(buffer)
                pbar.update(len(buffer))

    def load_kg(self, snp_init_emb='enformer',
                go_init_emb='random',
                gene_init_emb='esm',
                sample_edges=False,
                sample_ratio=1):

        data_path = self.data_path

        ## Load KG
        print('--loading KG---')
        idx2id = load_dict(os.path.join(data_path, 'cell_kg/network/node_idx2id.pkl'))
        edge_index_all = load_dict(os.path.join(data_path, 'cell_kg/network/edge_index.pkl'))
        id2idx = load_dict(os.path.join(data_path, 'cell_kg/network/node_id2idx.pkl'))
        self.id2idx = id2idx
        self.idx2id = idx2id

        data = HeteroData()

        ## Load initialized embeddings; for each node type, nodes without a
        ## precomputed embedding fall back to a random vector of the same size.
        if snp_init_emb == 'random':
            print('--using random SNP embedding--')
            data['SNP'].x = torch.rand((len(idx2id['SNP']), 128), requires_grad=False)
            snp_init_dim_size = 128
        elif snp_init_emb == 'kg':
            print('--using KG SNP embedding--')
            id2idx_kg = load_dict(os.path.join(data_path, 'cell_kg/node_emb/transe_emb/transe_emb_id2idx_kg.pkl'))
            kg_emb = load_dict(os.path.join(data_path, 'cell_kg/node_emb/transe_emb/transe_emb_inverse_triplets.pkl'))
            node_map = idx2id['SNP']
            data['SNP'].x = torch.vstack([torch.tensor(kg_emb[id2idx_kg[node_map[i]]]) if node_map[i] in id2idx_kg
                                          else torch.rand(50, requires_grad=False) for i in range(len(node_map))])
            snp_init_dim_size = 50
        elif snp_init_emb == 'cadd':
            print('--using CADD SNP embedding--')
            df_variant = pd.read_csv(os.path.join(data_path, 'cell_kg/node_emb/variant_emb/cadd_feat.csv'))
            df_variant = df_variant.set_index('Unnamed: 0')
            variant_feat = df_variant.values
            node_map = idx2id['SNP']
            rs2idx_feat = dict(zip(df_variant.index.values, range(len(df_variant.index.values))))
            data['SNP'].x = torch.vstack([torch.tensor(variant_feat[rs2idx_feat[node_map[i]]]) if node_map[i] in rs2idx_feat
                                          else torch.rand(64, requires_grad=False) for i in range(len(node_map))]).float()
            snp_init_dim_size = 64
        elif snp_init_emb == 'baselineLD':
            print('--using baselineLD SNP embedding--')
            node_map = idx2id['SNP']
            rs2idx_feat = load_dict(os.path.join(data_path, 'cell_kg/node_emb/variant_emb/baselineld_feat.pkl'))
            data['SNP'].x = torch.vstack([torch.tensor(rs2idx_feat[node_map[i]]) if node_map[i] in rs2idx_feat
                                          else torch.rand(70, requires_grad=False) for i in range(len(node_map))]).float()
            snp_init_dim_size = 70
        elif snp_init_emb == 'SLDSC':
            print('--using SLDSC SNP embedding--')
            node_map = idx2id['SNP']
            rs2idx_feat = load_dict(os.path.join(data_path, 'cell_kg/node_emb/variant_emb/sldsc_feat.pkl'))
            data['SNP'].x = torch.vstack([torch.tensor(rs2idx_feat[node_map[i]]) if node_map[i] in rs2idx_feat
                                          else torch.rand(165, requires_grad=False) for i in range(len(node_map))]).float()
            snp_init_dim_size = 165
        elif snp_init_emb == 'enformer':
            print('--using enformer SNP embedding--')
            node_map = idx2id['SNP']
            rs2idx_feat = load_dict(os.path.join(data_path, 'cell_kg/node_emb/variant_emb/enformer_feat.pkl'))
            data['SNP'].x = torch.vstack([torch.tensor(rs2idx_feat[node_map[i]]) if node_map[i] in rs2idx_feat
                                          else torch.rand(20, requires_grad=False) for i in range(len(node_map))]).float()
            snp_init_dim_size = 20
        else:
            raise ValueError(f'Unknown snp_init_emb: {snp_init_emb}')

        if go_init_emb == 'random':
            print('--using random go embedding--')
            for rel in ['CellularComponent', 'BiologicalProcess', 'MolecularFunction']:
                data[rel].x = torch.rand((len(idx2id[rel]), 128), requires_grad=False)
            go_init_dim_size = 128
        elif go_init_emb == 'kg':
            print('--using KG go embedding--')
            id2idx_kg = load_dict(os.path.join(data_path, 'cell_kg/node_emb/transe_emb/transe_emb_id2idx_kg.pkl'))
            kg_emb = load_dict(os.path.join(data_path, 'cell_kg/node_emb/transe_emb/transe_emb_inverse_triplets.pkl'))
            for rel in ['CellularComponent', 'BiologicalProcess', 'MolecularFunction']:
                node_map = idx2id[rel]
                data[rel].x = torch.vstack([torch.tensor(kg_emb[id2idx_kg[node_map[i]]]) if node_map[i] in id2idx_kg
                                            else torch.rand(50, requires_grad=False) for i in range(len(node_map))])
            go_init_dim_size = 50
        elif go_init_emb == 'biogpt':
            print('--using biogpt go embedding--')
            go2idx_feat = load_dict(os.path.join(data_path, 'cell_kg/node_emb/program_emb/biogpt_feat.pkl'))
            for rel in ['CellularComponent', 'BiologicalProcess', 'MolecularFunction']:
                node_map = idx2id[rel]
                data[rel].x = torch.vstack([torch.tensor(go2idx_feat[node_map[i]]) if node_map[i] in go2idx_feat
                                            else torch.rand(1600, requires_grad=False) for i in range(len(node_map))]).float()
            go_init_dim_size = 1600
        else:
            raise ValueError(f'Unknown go_init_emb: {go_init_emb}')

        if gene_init_emb == 'random':
            print('--using random gene embedding--')
            data['Gene'].x = torch.rand((len(idx2id['Gene']), 128), requires_grad=False)
            gene_init_dim_size = 128
        elif gene_init_emb == 'kg':
            print('--using KG gene embedding--')
            id2idx_kg = load_dict(os.path.join(data_path, 'cell_kg/node_emb/transe_emb/transe_emb_id2idx_kg.pkl'))
            kg_emb = load_dict(os.path.join(data_path, 'cell_kg/node_emb/transe_emb/transe_emb_inverse_triplets.pkl'))
            node_map = idx2id['Gene']
            data['Gene'].x = torch.vstack([torch.tensor(kg_emb[id2idx_kg[node_map[i]]]) if node_map[i] in id2idx_kg
                                           else torch.rand(50, requires_grad=False) for i in range(len(node_map))])
            gene_init_dim_size = 50
        elif gene_init_emb == 'esm':
            print('--using ESM gene embedding--')
            gene2idx_feat = load_dict(os.path.join(data_path, 'cell_kg/node_emb/gene_emb/esm_feat.pkl'))
            node_map = idx2id['Gene']
            data['Gene'].x = torch.vstack([torch.tensor(gene2idx_feat[node_map[i]]) if node_map[i] in gene2idx_feat
                                           else torch.rand(5120, requires_grad=False) for i in range(len(node_map))]).float()
            gene_init_dim_size = 5120
        elif gene_init_emb == 'pops':
            print('--using PoPs expression+PPI+pathways gene embedding--')
            gene2idx_feat = load_dict(os.path.join(data_path, 'cell_kg/node_emb/gene_emb/pops_feat.pkl'))
            node_map = idx2id['Gene']
            data['Gene'].x = torch.vstack([torch.tensor(gene2idx_feat[node_map[i]]) if node_map[i] in gene2idx_feat
                                           else torch.rand(57742, requires_grad=False) for i in range(len(node_map))]).float()
            gene_init_dim_size = 57742
        elif gene_init_emb == 'pops_expression':
            print('--using PoPs expression-only gene embedding--')
            gene2idx_feat = load_dict(os.path.join(data_path, 'cell_kg/node_emb/gene_emb/pops_expression_feat.pkl'))
            node_map = idx2id['Gene']
            data['Gene'].x = torch.vstack([torch.tensor(gene2idx_feat[node_map[i]]) if node_map[i] in gene2idx_feat
                                           else torch.rand(40546, requires_grad=False) for i in range(len(node_map))]).float()
            gene_init_dim_size = 40546
        else:
            raise ValueError(f'Unknown gene_init_emb: {gene_init_emb}')

        self.gene_init_dim_size = gene_init_dim_size
        self.go_init_dim_size = go_init_dim_size
        self.snp_init_dim_size = snp_init_dim_size

        for rel_type, edges in edge_index_all.items():
            if sample_edges:
                edge_index = torch.tensor(edges)
                num_edges = edge_index.size(1)
                num_samples = int(num_edges * sample_ratio)
                indices = torch.randperm(num_edges)[:num_samples]
                sampled_edge_index = edge_index[:, indices]
                print(rel_type, ' sampling ratio ', sample_ratio, ' from ', edge_index.shape[1], ' to ', sampled_edge_index.shape[1])
                data[rel_type].edge_index = sampled_edge_index
            else:
                data[rel_type].edge_index = torch.tensor(edges)
        data = T.ToUndirected()(data)
        data = T.AddSelfLoops()(data)
        self.data = data

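    # Resulting object sketch (illustrative shapes, assuming the default
    # enformer/random/esm initializations): after `load_kg()`, `self.data` is
    # a torch_geometric HeteroData whose edges are undirected with self-loops:
    #   data.load_kg()
    #   data.data['SNP'].x.shape    # (num_snps, 20)    enformer features
    #   data.data['Gene'].x.shape   # (num_genes, 5120) ESM features
    #   data.data.edge_types        # list of (src, relation, dst) triples
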
    def load_simulation_gwas(self, simulation_type, seed):
        data_path = self.data_path
        print('Using simulation data....')
        small_cohort = 5000
        num_causal_hits = 20000
        heritability = 0.3
        self.sample_size = small_cohort
        if simulation_type == 'causal_link':
            lr_uni = pd.read_csv(os.path.join(data_path, 'simulation_gwas/causal_link_simulation/' + str(num_causal_hits) + '_' + str(seed) + '_' + str(heritability) + '_graph_funct_v2_ggi.fastGWA'), sep='\t')
        elif simulation_type == 'causal':
            lr_uni = pd.read_csv(os.path.join(data_path, 'simulation_gwas/causal_simulation/' + str(num_causal_hits) + '_' + str(seed) + '_' + str(heritability) + '_' + str(small_cohort) + '_graph_funct_v2.fastGWA'), sep='\t')
        elif simulation_type == 'null':
            lr_uni = pd.read_csv(os.path.join(data_path, 'simulation_gwas/null_simulation/' + str(num_causal_hits) + '_' + str(seed) + '_' + str(heritability) + '_' + str(small_cohort) + '.fastGWA'), sep='\t')
        else:
            raise ValueError(f'Unknown simulation_type: {simulation_type}')

        # If an 'ID' column already exists, keep it; otherwise rename 'SNP' to 'ID'.
        if ('SNP' in lr_uni.columns.values) and ('ID' in lr_uni.columns.values):
            self.lr_uni = lr_uni.rename(columns={'CHR': '#CHROM'})
        else:
            self.lr_uni = lr_uni.rename(columns={'CHR': '#CHROM', 'SNP': 'ID'})
        self.seed = seed
        self.pheno = 'simulation'

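    # Path sketch: with simulation_type='causal' and seed=1, the constants
    # above resolve to
    #   simulation_gwas/causal_simulation/20000_1_0.3_5000_graph_funct_v2.fastGWA
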
    def load_external_gwas(self, path=None, seed=42, example_file=False):
        if example_file:
            print('Loading example GWAS file...')
            url = "https://dataverse.harvard.edu/api/access/datafile/10730346"
            example_file_path = os.path.join(self.data_path, 'biochemistry_Creatinine_fastgwa_full_10000_1.fastGWA')

            # Check if the example file is already downloaded
            if not os.path.exists(example_file_path):
                print('Example file not found locally. Downloading...')
                self._download_with_progress(url, example_file_path)
                print('Example file downloaded successfully.')
            else:
                print('Example file already exists locally.')

            path = example_file_path

        if path is None:
            raise ValueError("A valid path must be provided, or example_file must be set to True.")

        print(f'Loading GWAS file from {path}...')

        lr_uni = pd.read_csv(path, sep=None, engine='python')
        if 'CHR' not in lr_uni.columns.values:
            raise ValueError('CHR (chromosome) column not in the file!')
        if 'SNP' not in lr_uni.columns.values:
            raise ValueError('SNP column not in the file!')
        if 'P' not in lr_uni.columns.values:
            raise ValueError('P column not in the file!')
        if 'N' not in lr_uni.columns.values:
            raise ValueError('N (sample size) column not in the file!')
        lr_uni = lr_uni.rename(columns={'CHR': '#CHROM', 'SNP': 'ID'})

        ## Filter to the current KG variant set
        old_variant_set_len = len(lr_uni)
        lr_uni = lr_uni[lr_uni.ID.isin(list(self.idx2id['SNP'].values()))]
        print('Number of SNPs in the KG:', len(self.idx2id['SNP']))
        print('Number of SNPs in the GWAS:', old_variant_set_len)
        print('Number of GWAS SNPs retained in the KG variant set:', len(lr_uni))

        self.lr_uni = lr_uni
        # Assume N is constant across SNPs and take it from the first row
        self.sample_size = lr_uni.N.values[0]
        self.pheno = 'EXTERNAL'
        self.seed = seed

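    # Expected input sketch (hypothetical rows): the external GWAS file must
    # be a delimited table with at least CHR, SNP, P and N columns, e.g.
    #   CHR  SNP      P       N
    #   1    rs12345  0.0321  50000
    #   2    rs67890  0.4567  50000
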
    def load_full_gwas(self, pheno, seed=42):
        data_path = self.data_path
        if pheno in scdrs_traits:
            print('Using scdrs traits...')
            self.pheno = pheno
            lr_uni = pd.read_csv(os.path.join(data_path, 'scDRS_Data/sumstats_ukb_snps.csv'))
            lr_uni = lr_uni[['CHR', 'SNP', 'POS', 'A1', 'A2', 'N', 'AF1', pheno]]
            lr_uni = lr_uni[lr_uni[pheno].notnull()].reset_index(drop=True)
            lr_uni = lr_uni.rename(columns={'CHR': '#CHROM', 'SNP': 'ID', pheno: 'chi'})
            print('number of SNPs:', len(lr_uni))
            self.lr_uni = lr_uni
            self.seed = seed

            trait2size = pickle.load(open(os.path.join(data_path, 'scDRS_data/trait2size.pkl'), 'rb'))
            self.sample_size = trait2size[pheno]

        else:
            ## Load GWAS files
            self.pheno = pheno
            lr_uni = pd.read_csv(os.path.join(data_path, 'full_gwas/' + str(self.pheno) + '_with_rel_fastgwa.fastGWA'), sep='\t')
            lr_uni = lr_uni.rename(columns={'CHR': '#CHROM', 'SNP': 'ID'})

            self.lr_uni = lr_uni
            self.seed = seed
            # Sample size of the precomputed full-cohort GWAS
            self.sample_size = 387113

    def load_gwas_subsample(self, pheno, sample_size, seed):
        data_path = self.data_path
        if pheno in ['body_BALDING1', 'cancer_BREAST', 'disease_ALLERGY_ECZEMA_DIAGNOSED', 'disease_HYPOTHYROIDISM_SELF_REP', 'other_MORNINGPERSON', 'pigment_SUNBURN']:
            binary = True
        else:
            binary = False
        ## Load GWAS files
        self.sample_size = sample_size
        self.pheno = pheno
        if sample_size > 3000:
            lr_uni = pd.read_csv(os.path.join(data_path, 'subsample_gwas/' + str(self.pheno) +
                     '_fastgwa_full_' + str(sample_size) + '_' + str(seed) + '.fastGWA'), sep='\t')
            lr_uni = lr_uni.rename(columns={'CHR': '#CHROM', 'SNP': 'ID'})
        else:
            ## Use PLINK results if sample size <= 3000
            if binary:
                lr_uni = pd.read_csv(os.path.join(data_path, 'subsample_gwas/' + str(self.pheno) +
                         '_plink_' + str(sample_size) + '_' + str(seed) + '.PHENO1.glm.logistic.hybrid'), sep='\t')
            else:
                lr_uni = pd.read_csv(os.path.join(data_path, 'subsample_gwas/' + str(self.pheno) +
                         '_plink_' + str(sample_size) + '_' + str(seed) + '.PHENO1.glm.linear'), sep='\t')
        self.lr_uni = lr_uni
        self.seed = seed

    def process_gwas_file(self, label='chi'):
        data_path = self.data_path
        lr_uni = self.lr_uni

        ## LD scores
        ld_scores = pd.read_csv(os.path.join(data_path, 'ld_score/filter_genotyped_ldscores.csv'))
        w_ld_scores = pd.read_csv(os.path.join(data_path, 'ld_score/ldscores_from_data.csv'))

        m = 15000000  # number of SNPs (M) assumed in the LDSC weighting
        if 'N' not in lr_uni.columns.values:
            n = self.sample_size
        else:
            n = np.mean(lr_uni.N)
        h_g_2 = 0.5  # assumed total SNP heritability for the LDSC weights
        rs_id_2_ld_scores = dict(ld_scores.values)
        rs_id_2_w_ld = dict(w_ld_scores.values)

        ## Use the minimum LD score for SNPs without an LD score
        min_ld = min(rs_id_2_ld_scores.values())
        lr_uni['ld_score'] = lr_uni.ID.apply(lambda x: rs_id_2_ld_scores[x] if x in rs_id_2_ld_scores else min_ld)
        rs_id_2_ld_scores = dict(lr_uni[['ID', 'ld_score']].values)

        min_ld = min(rs_id_2_w_ld.values())
        ## The data-derived LD score excludes the query SNP itself, so add 1
        lr_uni['w_ld_score'] = 1 + lr_uni.ID.apply(lambda x: rs_id_2_w_ld[x] if x in rs_id_2_w_ld else min_ld)
        rs_id_2_w_ld = dict(lr_uni[['ID', 'w_ld_score']].values)

        print('Using ldsc weight...')
        ld = np.array([rs_id_2_ld_scores[rs_id] for rs_id in lr_uni.ID.values])
        w_ld = np.array([rs_id_2_w_ld[rs_id] for rs_id in lr_uni.ID.values])

        ldsc_weight = ldsc_regression_weights(ld, w_ld, n, m, h_g_2)
        ldsc_weight = ldsc_weight / np.mean(ldsc_weight)
        print('ldsc_weight mean: ', np.mean(ldsc_weight))
        self.rs_id_to_ldsc_weight = dict(zip(lr_uni.ID.values, ldsc_weight))

        ## Chi-square label
        if label == 'chi':
            if 'chi' in lr_uni.columns.values:
                print('chi pre-computed...')
                lr_uni['y'] = lr_uni['chi'].values
            else:
                if self.pheno in (['body_BALDING1', 'cancer_BREAST', 'disease_ALLERGY_ECZEMA_DIAGNOSED', 'disease_HYPOTHYROIDISM_SELF_REP', 'other_MORNINGPERSON', 'pigment_SUNBURN']) and (self.sample_size <= 3000):
                    lr_uni['y'] = lr_uni['Z_STAT'].values**2
                    lr_uni['y'] = lr_uni.y.fillna(0)
                else:
                    if ('BETA' in lr_uni.columns.values) and ('SE' in lr_uni.columns.values):
                        lr_uni['y'] = (lr_uni['BETA'] / lr_uni['SE']).values**2
                        lr_uni['y'] = lr_uni.y.fillna(0)
                    else:
                        from scipy.stats import chi2
                        ## Convert p-values to 1-dof chi-square statistics, e.g.
                        ## chi2.isf(5e-8, 1) ≈ 29.7; isf(p) equals ppf(1 - p) but
                        ## stays finite for the tiny p-values common in GWAS
                        lr_uni['y'] = chi2.isf(lr_uni['P'].values, 1)
                        lr_uni['y'] = lr_uni.y.fillna(0)

        elif label == 'residual-w-ld':
            lr_uni['y'] = (lr_uni['BETA'] / lr_uni['SE']).values**2
            lr_uni['y'] = lr_uni.y.fillna(0)
            lr_uni['ld_weight'] = lr_uni.ID.apply(lambda x: self.rs_id_to_ldsc_weight[x])
            import statsmodels.api as sm

            X = lr_uni.w_ld_score.values
            y = lr_uni.y.values
            weights = lr_uni.ld_weight.values
            X = sm.add_constant(X)
            model = sm.WLS(y, X, weights=weights)
            results = model.fit()
            y_pred = results.params[0] + results.params[1] * lr_uni.w_ld_score.values
            lr_uni['y'] = y - y_pred
        elif label == 'residual-ld':
            lr_uni['y'] = (lr_uni['BETA'] / lr_uni['SE']).values**2
            lr_uni['y'] = lr_uni.y.fillna(0)
            lr_uni['ld_weight'] = lr_uni.ID.apply(lambda x: self.rs_id_to_ldsc_weight[x])
            import statsmodels.api as sm

            X = lr_uni.ld_score.values
            y = lr_uni.y.values
            weights = lr_uni.ld_weight.values
            X = sm.add_constant(X)
            model = sm.WLS(y, X, weights=weights)
            results = model.fit()
            # Residualize against the same covariate the model was fit on
            y_pred = results.params[0] + results.params[1] * lr_uni.ld_score.values
            lr_uni['y'] = y - y_pred
        elif label == 'residual-ld-ols':
            lr_uni['y'] = (lr_uni['BETA'] / lr_uni['SE']).values**2
            lr_uni['y'] = lr_uni.y.fillna(0)
            import statsmodels.api as sm

            X = lr_uni.ld_score.values
            y = lr_uni.y.values
            X = sm.add_constant(X)
            model = sm.OLS(y, X)
            results = model.fit()
            y_pred = results.params[0] + results.params[1] * lr_uni.ld_score.values
            lr_uni['y'] = y - y_pred
        elif label == 'residual-ld-ols-abs':
            lr_uni['y'] = (lr_uni['BETA'] / lr_uni['SE']).values**2
            lr_uni['y'] = lr_uni.y.fillna(0)
            import statsmodels.api as sm

            X = lr_uni.ld_score.values
            y = lr_uni.y.values
            X = sm.add_constant(X)
            model = sm.OLS(y, X)
            results = model.fit()
            y_pred = results.params[0] + results.params[1] * lr_uni.ld_score.values
            lr_uni['y'] = np.abs(y - y_pred)
        elif label == 'residual-w-ld-ols':
            lr_uni['y'] = (lr_uni['BETA'] / lr_uni['SE']).values**2
            lr_uni['y'] = lr_uni.y.fillna(0)
            import statsmodels.api as sm

            X = lr_uni.w_ld_score.values
            y = lr_uni.y.values
            X = sm.add_constant(X)
            model = sm.OLS(y, X)
            results = model.fit()
            y_pred = results.params[0] + results.params[1] * lr_uni.w_ld_score.values
            lr_uni['y'] = y - y_pred

        id2y = dict(lr_uni[['ID', 'y']].values)
        all_ids = lr_uni.ID.values
        self.all_ids = np.array([self.id2idx['SNP'][i] for i in all_ids])
        self.y = lr_uni.y.values
        #idx2y = dict(zip(self.all_ids, y))

        self.lr_uni = lr_uni

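    # For reference, a plausible shape of the imported
    # `ldsc_regression_weights(ld, w_ld, n, m, h_g_2)` helper, following the
    # regression weights of LD score regression (Bulik-Sullivan et al., 2015).
    # This is a sketch only; the actual implementation lives in .utils:
    #
    #   def ldsc_regression_weights(ld, w_ld, n, m, h_g_2):
    #       ld = np.maximum(ld, 1.0)
    #       w_ld = np.maximum(w_ld, 1.0)
    #       c = n * h_g_2 / m
    #       het_w = 1.0 / (2 * (1.0 + c * ld) ** 2)  # heteroskedasticity weight
    #       oc_w = 1.0 / w_ld                        # LD overcounting weight
    #       return het_w * oc_w
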
    def prepare_split(self, test_set_fraction_data=0.05):

        ## Split SNPs into train/val/test sets
        train_val_ids, test_ids, y_train_val, y_test = train_test_split(self.all_ids, self.y, test_size=test_set_fraction_data, random_state=self.seed)
        train_ids, val_ids, y_train, y_val = train_test_split(train_val_ids, y_train_val, test_size=0.05, random_state=self.seed)

        self.train_input_nodes = ('SNP', train_ids)
        self.val_input_nodes = ('SNP', val_ids)
        self.test_input_nodes = ('SNP', test_ids)

        # Initialize all SNP labels to -1 (unlabeled), then fill in each split
        y_snp = torch.zeros(self.data['SNP'].x.shape[0]) - 1
        y_snp[train_ids] = torch.tensor(y_train).float()
        y_snp[val_ids] = torch.tensor(y_val).float()
        y_snp[test_ids] = torch.tensor(y_test).float()

        self.data['SNP'].y = y_snp
        for i in self.data.node_types:
            self.data[i].n_id = torch.arange(self.data[i].x.shape[0])

        self.data.train_mask = train_ids
        self.data.val_mask = val_ids
        self.data.test_mask = test_ids
        self.data.all_mask = self.all_ids
        #data = data.to(args.device)

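    # Downstream usage sketch (hypothetical; the KGWAS training loop is not
    # part of this file): the ('SNP', indices) tuples above match the
    # `input_nodes` format of torch_geometric's NeighborLoader, e.g.
    #
    #   from torch_geometric.loader import NeighborLoader
    #   loader = NeighborLoader(
    #       data.data,
    #       num_neighbors=[10, 10],              # two-hop neighborhood sampling
    #       input_nodes=data.train_input_nodes,
    #       batch_size=512,
    #       shuffle=True,
    #   )
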
    def get_pheno_list(self):
        return {"large_cohort": scdrs_traits,
                "21_indep_traits": ['body_BALDING1',
                    'disease_ALLERGY_ECZEMA_DIAGNOSED',
                    'disease_HYPOTHYROIDISM_SELF_REP', 'pigment_SUNBURN',
                    '21001', '50', '30080', '30070', '30010', '30000',
                    'biochemistry_AlkalinePhosphatase',
                    'biochemistry_AspartateAminotransferase',
                    'biochemistry_Cholesterol', 'biochemistry_Creatinine',
                    'biochemistry_IGF1', 'biochemistry_Phosphate',
                    'biochemistry_Testosterone_Male', 'biochemistry_TotalBilirubin',
                    'biochemistry_TotalProtein', 'biochemistry_VitaminD',
                    'bmd_HEEL_TSCOREz']}
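
# End-to-end usage sketch (illustrative; the phenotype/file choices below are
# assumptions based on the loaders above, not a prescribed pipeline):
#
#   data = KGWAS_Data(data_path='./data/')
#   data.load_kg(snp_init_emb='enformer', gene_init_emb='esm')
#   data.load_external_gwas(example_file=True)   # downloads a demo fastGWA file
#   data.process_gwas_file(label='chi')
#   data.prepare_split(test_set_fraction_data=0.05)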