Diff of /kgwas/data.py [000000] .. [8790ab]

Switch to unified view

a b/kgwas/data.py
1
import os, sys
2
from scipy.sparse import csr_matrix
3
from scipy.sparse.csgraph import connected_components
4
from sklearn.preprocessing import OneHotEncoder
5
import pandas as pd
6
import numpy as np
7
import pickle
8
from tqdm import tqdm
9
from torch.utils.data import Dataset, DataLoader
10
from torch.utils.data.sampler import WeightedRandomSampler
11
import torch
12
from .params import main_data_path, cohort_data_path, kinship_path, withdraw_path, fam_path
13
from .utils import get_fields, get_row_last_values, remove_kinships, save_dict, load_dict, print_sys
14
15
16
class ukbb_cohort:
17
    
18
    def __init__(self, main_data_path, cohort_data_path, withdraw_path, keep_relatives = False):
        '''
        Load the cached cohort eids, building and caching them first if needed.

        Exclusions applied when the cohort is built from scratch:
        the original uk biobank nature paper supplementary S 3.4:
        22006: Genetic ethnic grouping -> to retain only white british ancestry

        https://www.frontiersin.org/articles/10.3389/fgene.2022.866042/full
        22018: genetic relatedness exclusions
        22019: sex chromosome aneuploidy
        31 <-> 22001: mismatch between self-reported sex and genetically determined sex
        22010: recommended genomic analysis exclusions, signs of insufficient data quality

        The eids surviving these filters are saved as 'cohort_with_relatives.pkl';
        the subset selected by the mask from remove_kinships is saved as
        'cohort_no_relatives.pkl'. Both files are always written together.

        Args:
            main_data_path: passed to get_fields to read the raw data fields.
            cohort_data_path: directory that holds the cached cohort files.
            withdraw_path: path of the participant withdrawal file. Filtering by
                this file is NOT implemented; a warning is printed if it exists.
            keep_relatives: if True, self.cohort is the cohort with relatives and
                the no-relatives eids are additionally loaded into self.no_rel_eid.
        '''
        self.keep_relatives = keep_relatives
        self.cohort_data_path = cohort_data_path
        self.main_data_path = main_data_path

        with_rel_path = os.path.join(cohort_data_path, 'cohort_with_relatives.pkl')
        no_rel_path = os.path.join(cohort_data_path, 'cohort_no_relatives.pkl')
        cohort_path = with_rel_path if keep_relatives else no_rel_path

        if not os.path.exists(cohort_path):
            print_sys('construct from scratch...')

            all_field_ids = [22006, 22018, 22019, 22001, 22010, 31]
            df_main = get_fields(all_field_ids, main_data_path)
            cur_size = len(df_main)
            print_sys('Total sample size: ' + str(cur_size))

            # Each step is (row mask computed on the current frame, log prefix).
            # The masks are callables because every filter must see the frame
            # left behind by the previous one.
            qc_steps = [
                (lambda df: df['22006-0.0'] == 1,
                 'Keeping only white british ancestry (ID: 22006)'),
                (lambda df: df['22018-0.0'].isnull(),
                 'Removing genetics related samples (ID: 22018)'),
                (lambda df: df['22019-0.0'].isnull(),
                 'Removing sex chromosome aneuploidy (ID: 22019)'),
                (lambda df: df['31-0.0'] == df['22001-0.0'],
                 'Removing samples with mismatched self-reported sex and genetic determined sex (ID: 31 <-> 22001)'),
                (lambda df: df['22010-0.0'].isnull(),
                 'Removing samples with genomic data quality (ID: 22010)'),
            ]
            for keep_mask, description in qc_steps:
                df_main = df_main[keep_mask(df_main)]
                print_sys(description + ', cutting from ' + str(cur_size) + ' to ' + str(len(df_main)))
                cur_size = len(df_main)

            save_dict(with_rel_path, df_main.eid.values)

            # remove_kinships returns a mask over df_main.eid that is used to
            # drop relatives before the second cohort file is written.
            kinship_mask = remove_kinships(df_main.eid)
            df_main = df_main[kinship_mask]
            save_dict(no_rel_path, df_main.eid.values)
        else:
            print_sys('Found local copy...')

        self.cohort = load_dict(cohort_path)
        print_sys('There are ' + str(len(self.cohort)) + ' samples!')

        if keep_relatives:
            self.no_rel_eid = load_dict(no_rel_path)

        if os.path.exists(withdraw_path):
            ## todo: when there is a withdraw file, implement this...
            # Until then, say so explicitly instead of silently ignoring the file.
            print_sys('Warning: withdraw file found at ' + str(withdraw_path) + ' but withdrawal filtering is not implemented; the cohort is left unchanged.')
86
        
87
    def get_covariates(self, to_plink = False, plink_num_pca = 15, return_full = False, plink_filter = False):
        '''
        Load (or build and cache) the covariate table and return it.

        covariates:

        31: sex
        21003: Age when attended assessment centre
        22009: pca
        54: assessment center
        batch from params.fam_path file

        The table is cached as 'covariates_all.pkl' in self.cohort_data_path.
        Assessment center and batch are stored one-hot encoded as 'center_<i>'
        and 'batch_<i>' columns.

        Args:
            to_plink: if False return the DataFrame; if True write/read a
                space-separated, header-less text file and return its contents.
            plink_num_pca: number of leading 'pca <i>' columns written to the
                plink file (only used when to_plink is True).
            return_full: with to_plink False, return all rows instead of only
                those whose eid is in self.cohort.
            plink_filter: with to_plink True, restrict the plink file to
                self.cohort (file name gets the '_null_removed' suffix).

        Returns:
            pandas.DataFrame, either the covariate table or the plink file content.
        '''
        covar_path = os.path.join(self.cohort_data_path, 'covariates_all.pkl')
        if os.path.exists(covar_path):
            print_sys('Found local copy...')
            self.covar = load_dict(covar_path)
        else:
            print_sys('construct co-variates from scratch...')
            df_covar = get_fields([31, 54, 21003, 22009], self.main_data_path)
            column_name_map = {'22009-0.' + str(i): 'pca ' + str(i) for i in range(1, 41)}
            column_name_map['31-0.0'] = 'sex'
            column_name_map['21003-0.0'] = 'age'
            column_name_map['54-0.0'] = 'assessment_center'
            self.covar = df_covar.rename(columns = column_name_map)

            enc = OneHotEncoder(handle_unknown='ignore')
            enc.fit(self.covar['assessment_center'].unique().reshape(-1,1))
            center_array = enc.transform(self.covar['assessment_center'].values.reshape(-1,1)).toarray()
            # Name every encoded column from the encoder's actual width (as is
            # done for the batch columns below) instead of assuming 22 centers.
            center_num = center_array.shape[1]
            center_one_hot = pd.DataFrame(center_array).astype('int').rename(columns = dict(zip(range(center_num), ['center_' + str(i) for i in range(center_num)])))

            self.covar = self.covar.drop(['21003-1.0', '21003-2.0', '21003-3.0', 'assessment_center', '54-1.0', '54-2.0', '54-3.0'], axis = 1)
            # join aligns on the row index of both frames.
            self.covar = self.covar.join(center_one_hot)

            df_fam = pd.read_csv(fam_path)
            enc = OneHotEncoder(handle_unknown='ignore')
            enc.fit(df_fam.trait.unique().reshape(-1,1))
            batch_one_hot = enc.transform(df_fam['trait'].values.reshape(-1,1)).toarray()
            batch_num = batch_one_hot.shape[1]
            id2batch = dict(zip(df_fam.fid.values, batch_one_hot.astype(int)))
            # eids missing from the fam file get an all-zero batch row.
            no_batch = np.zeros(batch_num).astype(int)
            batch_rows = self.covar['eid'].apply(lambda x: id2batch.get(x, no_batch)).values
            df_batch = pd.DataFrame(np.stack(batch_rows)).rename(columns = dict(zip(range(batch_num), ['batch_' + str(i) for i in range(batch_num)])))
            self.covar = self.covar.join(df_batch)

            save_dict(covar_path, self.covar)
            print_sys('Done! Saving...')

        if not to_plink:
            if return_full:
                return self.covar.reset_index(drop = True)
            else:
                return self.covar[self.covar.eid.isin(self.cohort)].reset_index(drop = True)
        else:
            plink_path = os.path.join(self.cohort_data_path, 'covar_pca' + str(plink_num_pca) + '_all_real_value')
            if plink_filter:
                plink_path += '_null_removed'
            plink_path += '.txt'

            if not os.path.exists(plink_path):
                pca_columns = [i for i in self.covar.columns.values if (i[:3]=='pca') and int(i.split()[-1]) <= plink_num_pca]
                # Collapse the one-hot blocks back to a single integer label per row.
                # NOTE(review): an all-zero batch row (eid not in the fam file) also maps to label 0.
                center = np.argmax(self.covar.loc[:, self.covar.columns.str.contains('center')].values, axis = 1)
                batch = np.argmax(self.covar.loc[:, self.covar.columns.str.contains('batch')].values, axis = 1)
                # First 43 columns; presumably eid, sex, age and the 40 PCs, verify against get_fields' column order.
                # Copy so the column assignments below do not write into a slice of the cached table.
                self.covar = self.covar.iloc[:, :43].copy()
                self.covar['assessment_center'] = center
                self.covar['batch'] = batch
                if plink_filter:
                    self.covar = self.covar[self.covar.eid.isin(self.cohort)].reset_index(drop = True)
                # eid is written twice (two leading id columns).
                self.covar[['eid', 'eid', 'age', 'sex', 'assessment_center', 'batch'] + pca_columns].to_csv(plink_path, header=None, index=None, sep=' ')
            self.covar_plink = pd.read_csv(plink_path, header = None, sep = ' ')
            return self.covar_plink
156
    
157
    def get_external_traits(self, trait_name, to_plink = False, to_str = True, random_seed = 42, sep_cohort = False, randomize = False, use_sample_size = False, sample_size = -1, randomize_seed = 42):
        '''
        Load an externally prepared trait, restrict it to the cohort and optionally
        export it as a plink phenotype file.

        The trait is read from '<trait_name>_<binary|continuous>.csv' under
        'full_gwas' (columns 'eid' and 'pheno' are used) and cached as
        '<trait_name>_pheno.pkl' in self.cohort_data_path. A fixed list of trait
        names is treated as binary; binary phenotypes are shifted by +1 and cast
        to int before caching.

        Args:
            trait_name: name of the trait file prefix.
            to_plink: if False return the DataFrame, otherwise write/read the
                plink text file and return its content.
            to_str: cast the eid column to str.
            random_seed: seed for the shuffle used when use_sample_size is True.
            sep_cohort: with use_sample_size, keep the samples AFTER the first
                sample_size shuffled ids instead of the first sample_size.
            randomize: permute the phenotype values across samples.
            use_sample_size: subsample the cohort to sample_size samples.
            sample_size: number of samples to keep when use_sample_size is True.
            randomize_seed: seed for the permutation when randomize is True.

        Returns:
            pandas.DataFrame (self.pheno, or self.pheno_plink when to_plink).
        '''
        binary_traits = ['body_BALDING1', 'cancer_BREAST', 'disease_ALLERGY_ECZEMA_DIAGNOSED', 'disease_HYPOTHYROIDISM_SELF_REP', 'other_MORNINGPERSON', 'pigment_SUNBURN']
        trait_type = 'binary' if trait_name in binary_traits else 'continuous'

        pheno_path = os.path.join(self.cohort_data_path, str(trait_name) + '_pheno.pkl')
        if os.path.exists(pheno_path):
            print_sys('Found local copy...')
            self.pheno = load_dict(pheno_path)
        else:
            print_sys('construct phenotype from scratch...')

            # NOTE(review): `data_path` is not bound by any import visible in this
            # part of the file - confirm it is defined at module level, otherwise
            # this branch raises NameError.
            self.pheno = pd.read_csv(os.path.join(data_path, 'full_gwas', trait_name+'_'+trait_type+'.csv'))
            self.pheno['eid'] = self.pheno.eid.astype('int')
            self.pheno = self.pheno[self.pheno['pheno'].notnull()]
            if trait_type == 'binary':
                self.pheno['pheno'] += 1
                self.pheno['pheno'] = self.pheno['pheno'].astype(int)
            save_dict(pheno_path, self.pheno)
            print_sys('Done! Saving...')

        # filtering to cohorts incl. with/without relatives
        self.pheno = self.pheno[self.pheno.eid.isin(self.cohort)].reset_index(drop = True)

        if to_str:
            self.pheno['eid'] = self.pheno['eid'].astype('str')
        if not to_plink:
            return self.pheno

        # The file name encodes every option that changes the file content.
        plink_path = os.path.join(self.cohort_data_path, str(trait_name) + '_plink')
        plink_path += '_with_relatives' if self.keep_relatives else '_no_relatives'
        if use_sample_size:
            plink_path += '_' + str(sample_size) + '_' + str(random_seed)
        if sep_cohort:
            plink_path += '_sep_cohort'
        if randomize:
            plink_path += '_randomize' + str(randomize_seed)
        plink_path += '.txt'

        if randomize:
            # Shuffle the phenotype values while keeping the eid order.
            self.pheno['pheno'] = self.pheno['pheno'].sample(frac = 1, random_state = randomize_seed).values

        if use_sample_size:
            print('random_seed:', random_seed)
            shuffled_ids = self.pheno.sample(frac = 1, random_state = random_seed).eid.values
            # NOTE(review): with the default sample_size = -1 the slices below keep
            # everything except / only the last shuffled id - confirm callers always
            # pass a positive sample_size together with use_sample_size.
            if sep_cohort:
                kept_ids = shuffled_ids[sample_size:]
            else:
                kept_ids = shuffled_ids[:sample_size]
            self.pheno = self.pheno[self.pheno.eid.isin(kept_ids)]

        if not os.path.exists(plink_path):
            print_sys('Saving...')
            # eid is written twice (two leading id columns), followed by the last column.
            self.pheno[['eid', 'eid', self.pheno.columns.values[-1]]].to_csv(plink_path, header=None, index=None, sep=' ')
        else:
            print_sys('Already existed! Loading...')

        self.pheno_plink = pd.read_csv(plink_path, header = None, sep = ' ')
        return self.pheno_plink
232
    
233
    
234
    
235
    
236
    def get_phenotype(self, field_id, aggregate = 'last_value', to_plink = False, to_str = True, normalize = 'None', frac = 1, random_seed = 42, fastgwa_match = False, icd10 = False, icd10_level = 2, sep_cohort = False, randomize = False, use_sample_size = False, sample_size = -1, randomize_seed = 42):
        '''
        Load a phenotype, restrict it to the cohort and optionally export it as a
        plink phenotype file.

        example:
        standing heights: 50

        The raw phenotype is cached as '<field_id>_pheno.pkl' in
        self.cohort_data_path. With icd10 = True, field_id is an ICD10 code column
        of the table returned by get_icd10; otherwise it is a raw data field id
        read through get_fields.

        Args:
            field_id: raw field id, or ICD10 code when icd10 is True.
            aggregate: how to merge multiple columns of a field; only 'last_value'
                is handled, any other value leaves the table unaggregated.
            to_plink: if False return the DataFrame, otherwise write/read the
                plink text file and return its content.
            to_str: cast the eid column to str.
            normalize: 'None', 'log', 'std' or 'quantile_normalization'; any other
                value leaves the values unchanged.
            frac: when != 1 (and use_sample_size is False) the data is split with
                train_test_split(test_size=frac); sep_cohort selects the test
                split, otherwise the remaining split is kept.
            random_seed: seed for the shuffles / splits.
            fastgwa_match: requires keep_relatives = True; further subsamples the
                kept split by the no-relatives / with-relatives ratio.
            icd10: treat field_id as an ICD10 code.
            icd10_level: ICD10 level passed to get_icd10.
            sep_cohort: keep the held-out part instead of the selected part.
            randomize: permute the phenotype values across samples.
            use_sample_size: subsample to sample_size samples (cases only when
                icd10 is True; all rows with value 1 are always kept).
            sample_size: number of samples (or cases) to keep.
            randomize_seed: seed for the permutation when randomize is True.

        Returns:
            pandas.DataFrame (self.pheno, or self.pheno_plink when to_plink).

        Raises:
            ValueError: fastgwa_match is True without keep_relatives, or combined
                with use_sample_size for a non-ICD10 phenotype.
            NotImplementedError: icd10, use_sample_size and sep_cohort together.
        '''
        # NOTE(review): the code below addresses the phenotype column as
        # str(field_id); for a non-ICD10 field with a single value column this
        # presumes get_fields already names it that way - confirm.
        pheno_col = str(field_id)

        pheno_path = os.path.join(self.cohort_data_path, pheno_col + '_pheno.pkl')
        if os.path.exists(pheno_path):
            print_sys('Found local copy...')
            self.pheno = load_dict(pheno_path)
        else:
            print_sys('construct phenotype from scratch...')
            if icd10:
                ## field_id is icd10 level
                icd10_df = self.get_icd10(to_plink = True, level = icd10_level, get_all = True)
                self.pheno = icd10_df[['FID', field_id]].rename(columns = {'FID': 'eid'})
                self.pheno['eid'] = self.pheno.eid.astype('int')
            else:
                ## from raw data field id
                self.pheno = get_fields([field_id], self.main_data_path)
            save_dict(pheno_path, self.pheno)
            print_sys('Done! Saving...')

        if len(self.pheno.columns.values) > 2:
            print_sys('There are multiple index for this phenotype... aggregate...')
            if aggregate == 'last_value':
                print_sys('Getting the latest measure...')
                tmp = pd.DataFrame()
                tmp['eid'] = self.pheno.loc[:, 'eid']
                tmp[pheno_col] = get_row_last_values(self.pheno.iloc[:, 1:])
                self.pheno = tmp
                missing = self.pheno[pheno_col].isnull()
                print_sys('There are ' + str(int(missing.sum())) + ' samples with NaN values. Removing them ...')
                self.pheno = self.pheno[~missing]

        if fastgwa_match:
            # get the number of without relatives:
            if not self.keep_relatives:
                raise ValueError('If you turned fastgwa_match = True, then keep_relatives = True!')
            self.rel_ratio = len(self.pheno[self.pheno.eid.isin(self.no_rel_eid)])/len(self.pheno[self.pheno.eid.isin(self.cohort)])

        # filtering to cohorts incl. with/without relatives
        self.pheno = self.pheno[self.pheno.eid.isin(self.cohort)].reset_index(drop = True)

        if normalize != 'None':
            y = self.pheno[pheno_col].values
            if normalize == 'log':
                y = np.log(y)
            elif normalize == 'std':
                y = (y - np.mean(y))/np.std(y)
            elif normalize == 'quantile_normalization':
                from sklearn.preprocessing import quantile_transform
                y = quantile_transform(y.reshape(-1,1), output_distribution = 'normal', random_state = 42).reshape(-1)
            self.pheno[pheno_col] = y

        if to_str:
            self.pheno['eid'] = self.pheno['eid'].astype('str')
        if not to_plink:
            return self.pheno

        # The file name encodes every option that changes the file content.
        plink_path = os.path.join(self.cohort_data_path, pheno_col + '_plink')
        plink_path += '_with_relatives' if self.keep_relatives else '_no_relatives'
        if normalize != 'None':
            plink_path += '_' + str(normalize)
        if use_sample_size:
            plink_path += '_' + str(sample_size) + '_' + str(random_seed)
        elif frac != 1:
            plink_path += '_' + str(frac) + '_' + str(random_seed)
        if fastgwa_match:
            plink_path += '_match'
        if sep_cohort:
            plink_path += '_sep_cohort'
        if randomize:
            plink_path += '_randomize' + str(randomize_seed)
        plink_path += '.txt'

        if randomize:
            # Shuffle the phenotype values while keeping the eid order.
            self.pheno[pheno_col] = self.pheno[pheno_col].sample(frac = 1, random_state = randomize_seed).values

        if use_sample_size:
            if icd10:
                # Subsample rows with value 2, keep every row with value 1.
                case_rows = self.pheno[self.pheno[pheno_col] == 2]
                shuffled_case_ids = case_rows.sample(frac = 1, random_state = random_seed).eid.values
                control_ids = self.pheno[self.pheno[pheno_col] == 1].eid.values
                kept_ids = np.concatenate((shuffled_case_ids[:sample_size], control_ids))
                self.pheno = self.pheno[self.pheno.eid.isin(kept_ids)]
                if sep_cohort:
                    raise NotImplementedError
            else:
                print('random_seed', random_seed)
                shuffled_ids = self.pheno.sample(frac = 1, random_state = random_seed).eid.values
                if fastgwa_match:
                    raise ValueError('Not used anymore...')
                if sep_cohort:
                    kept_ids = shuffled_ids[sample_size:]
                else:
                    kept_ids = shuffled_ids[:sample_size]
                self.pheno = self.pheno[self.pheno.eid.isin(kept_ids)]
        elif frac != 1:
            from sklearn.model_selection import train_test_split
            all_ids, y = self.pheno.eid.values, self.pheno[pheno_col].values
            train_val_ids, test_ids, y_train_val, _ = train_test_split(all_ids, y, test_size=frac, random_state=random_seed)
            if fastgwa_match:
                # Split the retained part again by the relatives ratio.
                train_val_ids, test_ids, _, _ = train_test_split(train_val_ids, y_train_val, test_size=1-self.rel_ratio, random_state=42)
            if sep_cohort:
                self.pheno = self.pheno[self.pheno.eid.isin(test_ids)]
            else:
                self.pheno = self.pheno[self.pheno.eid.isin(train_val_ids)]

        if not os.path.exists(plink_path):
            # eid is written twice (two leading id columns), followed by the last column.
            self.pheno[['eid', 'eid', self.pheno.columns.values[-1]]].to_csv(plink_path, header=None, index=None, sep=' ')
        else:
            print_sys('Already existed! Loading...')

        self.pheno_plink = pd.read_csv(plink_path, header = None, sep = ' ')
        return self.pheno_plink
366
            
367
        
368
    def get_icd10(self, to_plink = False, level = 2, get_all = False):
        '''
        Load (or build and cache) per-sample ICD10 code lists and optionally turn
        them into a plink-style indicator table.

        icd10: 41270

        The code lists are cached as 'icd10.pkl' in self.cohort_data_path with the
        columns 'eid', 'level3' (all non-null codes of field 41270) and 'level2'
        (unique first three characters of those codes).

        Args:
            to_plink: if False return the code-list DataFrame; otherwise return
                the indicator table (1 = code absent, 2 = code present).
            level: selects the 'level<level>' column used for the indicator table.
            get_all: use every sample instead of only those in self.cohort.

        Returns:
            pandas.DataFrame: self.pheno, or self.icd10_plink with the columns
            'FID', 'IID' and one column per ICD10 code that has more than 100 cases.
        '''
        pheno_path = os.path.join(self.cohort_data_path, 'icd10.pkl')
        level_str = 'level' + str(level)
        if os.path.exists(pheno_path):
            print_sys('Found local copy...')
            self.icd10 = load_dict(pheno_path)
        else:
            print_sys('construct from scratch...')
            icd10_raw_concat = get_fields([41270], self.main_data_path)
            icd10_columns = icd10_raw_concat.columns.values[1:]
            # One (eid, array of non-null codes) pair per sample.
            icd10_tuple = icd10_raw_concat.apply(lambda x: (x.eid, x[icd10_columns][x[icd10_columns].notnull()].values), axis = 1)
            icd10 = pd.DataFrame(list(icd10_tuple.values)).rename(columns = {0: 'eid', 1: 'level3'})
            icd10['level2'] = icd10['level3'].apply(lambda x: np.unique([i[:3] for i in x]))
            save_dict(pheno_path, icd10)
            print_sys('Done! Saving...')
            self.icd10 = icd10

        if get_all:
            self.pheno = self.icd10.reset_index(drop = True)
        else:
            self.pheno = self.icd10[self.icd10.eid.isin(self.cohort)].reset_index(drop = True)

        if not to_plink:
            return self.pheno

        # NOTE(review): get_all = True shares the '_with_relatives_' cache file with
        # keep_relatives = True although the two select different samples; whichever
        # call runs first decides the cached content - confirm this is intended.
        if self.keep_relatives or get_all:
            plink_path = os.path.join(self.cohort_data_path, 'icd10_plink_with_relatives_' + level_str + '.txt')
        else:
            plink_path = os.path.join(self.cohort_data_path, 'icd10_plink_no_relatives_' + level_str + '.txt')

        if os.path.exists(plink_path):
            print_sys("Found local copy...")
            self.icd10_plink = pd.read_csv(plink_path, sep=' ')
        else:
            print_sys('transforming to plink files... takes around 1 min...')
            unique_icd10 = np.unique([item for sublist in self.pheno[level_str].values for item in sublist])
            icd10_2_idx = {code: idx for idx, code in enumerate(unique_icd10)}
            idx_2_icd10 = dict(enumerate(unique_icd10))

            self.pheno[level_str + '_idx'] = self.pheno[level_str].apply(lambda x: [icd10_2_idx[i] for i in x])

            # Sample x code indicator matrix; each row's code indices are set at once.
            tmp = np.zeros((len(self.pheno), len(unique_icd10)), dtype=np.int8)
            for row, code_indices in enumerate(self.pheno[level_str + '_idx'].values):
                tmp[row, code_indices] = 1

            icd10_plink = pd.DataFrame(tmp).rename(columns = idx_2_icd10)
            icd102sample_size = dict(icd10_plink.sum(axis = 0))
            # Strictly more than 100 cases; the log messages below state the same.
            icd_100 = [i for i,j in icd102sample_size.items() if j > 100]
            # Shift 0/1 to 1/2 before writing.
            icd10_plink = icd10_plink + 1
            icd10_plink['IID'] = self.pheno.eid.values
            icd10_plink['FID'] = self.pheno.eid.values
            icd10_plink = icd10_plink.loc[:, ['FID', 'IID'] + icd_100]
            print_sys('Only using ICD10 codes with more than 100 cases...')
            print_sys('There are ' + str(len(icd_100)) + ' ICD10 codes with more than 100 cases.')
            icd10_plink.to_csv(plink_path, index=None, sep=' ')
            self.icd10_plink = icd10_plink

        return self.icd10_plink