import os, sys
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import connected_components
from sklearn.preprocessing import OneHotEncoder
import pandas as pd
import numpy as np
import pickle
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import WeightedRandomSampler
import torch
from .params import main_data_path, cohort_data_path, kinship_path, withdraw_path, fam_path
from .utils import get_fields, get_row_last_values, remove_kinships, save_dict, load_dict, print_sys


class ukbb_cohort:
    """UK Biobank cohort builder and phenotype/covariate extractor.

    On first use the cohort is constructed from the raw main dataset by
    applying the standard genetic QC exclusions and cached as a pickle under
    ``cohort_data_path``; subsequent instantiations load the cached eid list.
    The ``get_*`` methods likewise cache their intermediate tables and can
    emit PLINK-formatted text files (FID IID value, space-separated, no
    header) for downstream GWAS tooling.
    """

    def __init__(self, main_data_path, cohort_data_path, withdraw_path, keep_relatives = False):
        """Build (or load) the QC-filtered cohort eid list.

        Parameters
        ----------
        main_data_path : path to the raw UKBB main dataset (consumed by
            ``get_fields``).
        cohort_data_path : directory where cohort/phenotype caches live.
        withdraw_path : path to a participant-withdrawal list; currently
            unused (see TODO below).
        keep_relatives : if True, keep KING-related samples in the cohort
            (the no-relatives eid list is still loaded for matching).
        """
        self.keep_relatives = keep_relatives
        self.cohort_data_path = cohort_data_path
        self.main_data_path = main_data_path

        if keep_relatives:
            cohort_path = os.path.join(cohort_data_path, 'cohort_with_relatives.pkl')
        else:
            cohort_path = os.path.join(cohort_data_path, 'cohort_no_relatives.pkl')

        if not os.path.exists(cohort_path):
            print_sys('construct from scratch...')
            '''
            exclusions:
            the original uk biobank nature paper supplementary S 3.4:
            22006: Genetic ethnic grouping -> to retain only white british ancestry

            https://www.frontiersin.org/articles/10.3389/fgene.2022.866042/full
            22018: genetic relatedness exclusions
            22019: sex chromosome aneuploidy
            31 <-> 22001: mismatch between self-reported sex and genetically determined sex
            22010: recommended genomic analysis exclusions, signs of insufficient data quality

            (optional) further remove relatives based on KING relative scores, choose the first one in relative group
            remove the list of eids who do not want to be in the study anymore
            '''

            all_field_ids = [22006, 22018, 22019, 22001, 22010, 31]
            df_main = get_fields(all_field_ids, main_data_path)
            cur_size = len(df_main)
            print_sys('Total sample size: ' + str(cur_size))

            # Each filter below narrows the cohort and logs the size change.
            df_main = df_main[df_main['22006-0.0'] == 1]
            print_sys('Keeping only white british ancestry (ID: 22006), cutting from ' + str(cur_size) + ' to ' + str(len(df_main)))
            cur_size = len(df_main)

            df_main = df_main[df_main['22018-0.0'].isnull()]
            print_sys('Removing genetics related samples (ID: 22018), cutting from ' + str(cur_size) + ' to ' + str(len(df_main)))
            cur_size = len(df_main)

            df_main = df_main[df_main['22019-0.0'].isnull()]
            print_sys('Removing sex chromosome aneuploidy (ID: 22019), cutting from ' + str(cur_size) + ' to ' + str(len(df_main)))
            cur_size = len(df_main)

            df_main = df_main[df_main['31-0.0'] == df_main['22001-0.0']]
            print_sys('Removing samples with mismatched self-reported sex and genetic determined sex (ID: 31 <-> 22001), cutting from ' + str(cur_size) + ' to ' + str(len(df_main)))
            cur_size = len(df_main)

            df_main = df_main[df_main['22010-0.0'].isnull()]
            print_sys('Removing samples with genomic data quality (ID: 22010), cutting from ' + str(cur_size) + ' to ' + str(len(df_main)))
            cur_size = len(df_main)

            # Cache both variants so either keep_relatives setting hits the cache next time.
            save_dict(os.path.join(cohort_data_path, 'cohort_with_relatives.pkl'), df_main.eid.values)

            kinship_mask = remove_kinships(df_main.eid)
            df_main = df_main[kinship_mask]
            save_dict(os.path.join(cohort_data_path, 'cohort_no_relatives.pkl'), df_main.eid.values)
        else:
            print_sys('Found local copy...')

        self.cohort = load_dict(cohort_path)
        print_sys('There are ' + str(len(self.cohort)) + ' samples!')

        if keep_relatives:
            # Needed by fastgwa_match in get_phenotype to compute the relatives ratio.
            self.no_rel_eid = load_dict(os.path.join(cohort_data_path, 'cohort_no_relatives.pkl'))

        if os.path.exists(withdraw_path):
            ## todo: when there is a withdraw file, implement this...
            pass

    def get_covariates(self, to_plink = False, plink_num_pca = 15, return_full = False, plink_filter = False):
        '''
        covariates:

        31: sex
        21003: Age when attended assessment centre
        22009: pca
        54: assessment center
        batch from params.fam_path file

        Returns a DataFrame (to_plink = False) restricted to the cohort
        (or the full sample set when return_full = True), or a PLINK
        covariate table (to_plink = True) with the first plink_num_pca PCA
        columns plus age/sex/center/batch, cached as a .txt file.
        '''
        covar_path = os.path.join(self.cohort_data_path, 'covariates_all.pkl')
        if os.path.exists(covar_path):
            print_sys('Found local copy...')
            self.covar = load_dict(covar_path)
        else:
            print_sys('construct co-variates from scratch...')
            df_covar = get_fields([31, 54, 21003, 22009], self.main_data_path)
            column_name_map = {'22009-0.' + str(i): 'pca ' + str(i) for i in range(1, 41)}
            column_name_map['31-0.0'] = 'sex'
            column_name_map['21003-0.0'] = 'age'
            column_name_map['54-0.0'] = 'assessment_center'
            self.covar = df_covar.rename(columns = column_name_map)

            # One-hot encode the assessment center (22 distinct centers assumed
            # by the rename below -- TODO confirm against the data release).
            enc = OneHotEncoder(handle_unknown='ignore')
            enc.fit(self.covar['assessment_center'].unique().reshape(-1,1))
            center_array = enc.transform(self.covar['assessment_center'].values.reshape(-1,1)).toarray()
            center_one_hot = pd.DataFrame(center_array).astype('int').rename(columns = dict(zip(range(22), ['center_' + str(i) for i in range(22)])))

            # Keep only the first-visit age/center measurements.
            self.covar = self.covar.drop(['21003-1.0', '21003-2.0', '21003-3.0', 'assessment_center', '54-1.0', '54-2.0', '54-3.0'], axis = 1)
            self.covar = self.covar.join(center_one_hot)

            # Genotyping batch comes from the .fam file; samples absent from it
            # get an all-zero batch vector.
            df_fam = pd.read_csv(fam_path)
            enc = OneHotEncoder(handle_unknown='ignore')
            enc.fit(df_fam.trait.unique().reshape(-1,1))
            batch_one_hot = enc.transform(df_fam['trait'].values.reshape(-1,1)).toarray()
            batch_num = batch_one_hot.shape[1]
            id2batch = dict(zip(df_fam.fid.values, batch_one_hot.astype(int)))
            df_batch = pd.DataFrame(np.stack(self.covar['eid'].apply(lambda x: id2batch[x] if x in id2batch else np.zeros(batch_one_hot.shape[1]).astype(int)).values)).rename(columns = dict(zip(range(batch_num), ['batch_' + str(i) for i in range(batch_num)])))
            self.covar = self.covar.join(df_batch)

            save_dict(covar_path, self.covar)
            print_sys('Done! Saving...')

        if not to_plink:
            if return_full:
                return self.covar.reset_index(drop = True)
            else:
                return self.covar[self.covar.eid.isin(self.cohort)].reset_index(drop = True)
        else:
            plink_path = os.path.join(self.cohort_data_path, 'covar_pca' + str(plink_num_pca) + '_all_real_value')
            if plink_filter:
                plink_path += '_null_removed'
            plink_path += '.txt'

            if not os.path.exists(plink_path):
                # 'pca N' columns with N <= plink_num_pca.
                pca_columns = [i for i in self.covar.columns.values if (i[:3]=='pca') and int(i.split()[-1]) <= plink_num_pca]
                #center_one_hot_columns = ['center_' + str(i) for i in range(22)]
                #batch_columns = ['batch_' + str(i) for i in range(batch_num)]
                #self.covar[['eid', 'eid', 'age', 'sex'] + pca_columns + center_one_hot_columns + batch_columns].to_csv(plink_path, header=None, index=None, sep=' ')
                # Collapse the one-hot center/batch blocks back to integer codes
                # for PLINK (real-valued covariate columns).
                center = np.argmax(self.covar.loc[:, self.covar.columns.str.contains('center')].values, axis = 1)
                batch = np.argmax(self.covar.loc[:, self.covar.columns.str.contains('batch')].values, axis = 1)
                self.covar = self.covar.iloc[:, :43]
                self.covar['assessment_center'] = center
                self.covar['batch'] = batch
                if plink_filter:
                    self.covar = self.covar[self.covar.eid.isin(self.cohort)].reset_index(drop = True)
                # PLINK expects FID IID followed by covariates; eid doubles as both.
                self.covar[['eid', 'eid', 'age', 'sex', 'assessment_center', 'batch'] + pca_columns].to_csv(plink_path, header=None, index=None, sep=' ')
            self.covar_plink = pd.read_csv(plink_path, header = None, sep = ' ')
            return self.covar_plink

    def get_external_traits(self, trait_name, to_plink = False, to_str = True, random_seed = 42, sep_cohort = False, randomize = False, use_sample_size = False, sample_size = -1, randomize_seed = 42):
        '''
        example:
        standing heights: 50

        Loads an externally-defined GWAS trait (from the full_gwas CSVs),
        restricted to the cohort. With to_plink = True returns a cached
        FID/IID/pheno table; use_sample_size takes the first sample_size
        shuffled ids (or the remainder when sep_cohort = True), and
        randomize permutes phenotype values for null experiments.
        '''
        if trait_name in ['body_BALDING1', 'cancer_BREAST', 'disease_ALLERGY_ECZEMA_DIAGNOSED', 'disease_HYPOTHYROIDISM_SELF_REP', 'other_MORNINGPERSON', 'pigment_SUNBURN']:
            trait_type = 'binary'
        else:
            trait_type = 'continuous'

        pheno_path = os.path.join(self.cohort_data_path, str(trait_name) + '_pheno.pkl')
        if os.path.exists(pheno_path):
            print_sys('Found local copy...')
            self.pheno = load_dict(pheno_path)
        else:
            print_sys('construct phenotype from scratch...')

            # BUGFIX: the original referenced an undefined name `data_path`
            # (NameError on a cold cache). Use the stored main data path.
            # NOTE(review): assumes the 'full_gwas' folder lives under
            # main_data_path -- confirm against the data layout.
            self.pheno = pd.read_csv(os.path.join(self.main_data_path, 'full_gwas', trait_name+'_'+trait_type+'.csv'))
            self.pheno['eid'] = self.pheno.eid.astype('int')
            self.pheno = self.pheno[self.pheno['pheno'].notnull()]
            if trait_type == 'binary':
                # PLINK binary coding: 1 = control, 2 = case.
                self.pheno['pheno'] += 1
                self.pheno['pheno'] = self.pheno['pheno'].astype(int)
            save_dict(pheno_path, self.pheno)
            print_sys('Done! Saving...')



        # filtering to cohorts incl. with/without relatives
        self.pheno = self.pheno[self.pheno.eid.isin(self.cohort)].reset_index(drop = True)

        if to_str:
            self.pheno['eid'] = self.pheno['eid'].astype('str')
        if not to_plink:
            return self.pheno
        else:
            # Encode every option into the cache filename so distinct
            # configurations never collide.
            plink_path = os.path.join(self.cohort_data_path, str(trait_name) + '_plink')
            if self.keep_relatives:
                plink_path = plink_path + '_with_relatives'
            else:
                plink_path = plink_path + '_no_relatives'

            if use_sample_size:
                plink_path = plink_path + '_' + str(sample_size) + '_' + str(random_seed)

            if sep_cohort:
                plink_path += '_sep_cohort'

            if randomize:
                plink_path += '_randomize' + str(randomize_seed)

            plink_path = plink_path + '.txt'

            if randomize:
                # Permute phenotype values across samples (null phenotype).
                self.pheno['pheno'] = self.pheno['pheno'].sample(frac = 1, random_state = randomize_seed).values

            if use_sample_size:
                from sklearn.model_selection import train_test_split
                print('random_seed:', random_seed)
                pheno_shuffle = self.pheno.sample(frac = 1, random_state = random_seed)
                all_ids, y = pheno_shuffle.eid.values, pheno_shuffle['pheno'].values
                train_val_ids, test_ids, y_train_val, y_test = all_ids[:sample_size], all_ids[sample_size:], y[:sample_size], y[sample_size:]
                if sep_cohort:
                    self.pheno = self.pheno[self.pheno.eid.isin(test_ids)]
                else:
                    self.pheno = self.pheno[self.pheno.eid.isin(train_val_ids)]

            if not os.path.exists(plink_path):
                print_sys('Saving...')
                self.pheno[['eid', 'eid', self.pheno.columns.values[-1]]].to_csv(plink_path, header=None, index=None, sep=' ')
            else:
                print_sys('Already existed! Loading...')

            self.pheno_plink = pd.read_csv(plink_path, header = None, sep = ' ')
            return self.pheno_plink




    def get_phenotype(self, field_id, aggregate = 'last_value', to_plink = False, to_str = True, normalize = 'None', frac = 1, random_seed = 42, fastgwa_match = False, icd10 = False, icd10_level = 2, sep_cohort = False, randomize = False, use_sample_size = False, sample_size = -1, randomize_seed = 42):
        '''
        example:
        standing heights: 50

        Extracts one phenotype by UKBB field id (or an ICD10 code column
        when icd10 = True), aggregates repeated measures, drops NaNs,
        restricts to the cohort, optionally normalizes ('log', 'std',
        'quantile_normalization'), and supports the same subsampling /
        randomization / plink-export options as get_external_traits.
        '''
        pheno_path = os.path.join(self.cohort_data_path, str(field_id) + '_pheno.pkl')
        if os.path.exists(pheno_path):
            print_sys('Found local copy...')
            self.pheno = load_dict(pheno_path)
        else:
            print_sys('construct phenotype from scratch...')
            if icd10:
                ## field_id is icd10 level
                icd10_df = self.get_icd10(to_plink = True, level = icd10_level, get_all = True)
                self.pheno = icd10_df[['FID', field_id]].rename(columns = {'FID': 'eid'})
                self.pheno['eid'] = self.pheno.eid.astype('int')
            else:
                ## from raw data field id
                self.pheno = get_fields([field_id], self.main_data_path)
            save_dict(pheno_path, self.pheno)
            print_sys('Done! Saving...')

        if len(self.pheno.columns.values) > 2:
            print_sys('There are multiple index for this phenotype... aggregate...')
            if aggregate == 'last_value':
                print_sys('Getting the latest measure...')
                tmp = pd.DataFrame()
                tmp['eid'] = self.pheno.loc[:, 'eid']
                tmp[str(field_id)] = get_row_last_values(self.pheno.iloc[:, 1:])
                self.pheno = tmp
        print_sys('There are ' + str(len(self.pheno[self.pheno[str(field_id)].isnull()])) + ' samples with NaN values. Removing them ...')
        self.pheno = self.pheno[self.pheno[str(field_id)].notnull()]

        if fastgwa_match:
            # get the number of without relatives:
            if not self.keep_relatives:
                raise ValueError('If you turned fastgwa_match = True, then keep_relatives = True!')
            self.rel_ratio = len(self.pheno[self.pheno.eid.isin(self.no_rel_eid)])/len(self.pheno[self.pheno.eid.isin(self.cohort)])

        # filtering to cohorts incl. with/without relatives
        self.pheno = self.pheno[self.pheno.eid.isin(self.cohort)].reset_index(drop = True)

        if normalize != 'None':
            y = self.pheno[str(field_id)].values
            if normalize == 'log':
                y = np.log(y)
            elif normalize == 'std':
                y = (y - np.mean(y))/np.std(y)
            elif normalize == 'quantile_normalization':
                from sklearn.preprocessing import quantile_transform
                y = quantile_transform(y.reshape(-1,1), output_distribution = 'normal', random_state = 42).reshape(-1)
            self.pheno[str(field_id)] = y

        if to_str:
            self.pheno['eid'] = self.pheno['eid'].astype('str')
        if not to_plink:
            return self.pheno
        else:
            # Encode every option into the cache filename (see get_external_traits).
            plink_path = os.path.join(self.cohort_data_path, str(field_id) + '_plink')
            if self.keep_relatives:
                plink_path = plink_path + '_with_relatives'
            else:
                plink_path = plink_path + '_no_relatives'

            if normalize != 'None':
                plink_path = plink_path + '_' + str(normalize)
            if use_sample_size:
                plink_path = plink_path + '_' + str(sample_size) + '_' + str(random_seed)
            else:
                if frac != 1:
                    plink_path = plink_path + '_' + str(frac) + '_' + str(random_seed)

            if fastgwa_match:
                plink_path += '_match'

            if sep_cohort:
                plink_path += '_sep_cohort'

            if randomize:
                plink_path += '_randomize' + str(randomize_seed)


            plink_path = plink_path + '.txt'

            if randomize:
                # Permute phenotype values across samples (null phenotype).
                self.pheno[str(field_id)] = self.pheno[str(field_id)].sample(frac = 1, random_state = randomize_seed).values


            if use_sample_size:
                from sklearn.model_selection import train_test_split

                if icd10:
                    # Subsample cases only; all controls (coded 1) are kept.
                    df_cases = self.pheno[self.pheno[str(field_id)] == 2]
                    df_cases_shuffle = df_cases.sample(frac = 1, random_state = random_seed)
                    all_ids, y = df_cases_shuffle.eid.values, df_cases_shuffle[str(field_id)].values
                    train_val_ids, test_ids, y_train_val, y_test = all_ids[:sample_size], all_ids[sample_size:], y[:sample_size], y[sample_size:]
                    train_val_ids = np.concatenate((train_val_ids, self.pheno[self.pheno[str(field_id)] == 1].eid.values))
                    self.pheno = self.pheno[self.pheno.eid.isin(train_val_ids)]
                    if sep_cohort:
                        raise NotImplementedError
                else:
                    print('random_seed', random_seed)
                    pheno_shuffle = self.pheno.sample(frac = 1, random_state = random_seed)
                    all_ids, y = pheno_shuffle.eid.values, pheno_shuffle[str(field_id)].values
                    train_val_ids, test_ids, y_train_val, y_test = all_ids[:sample_size], all_ids[sample_size:], y[:sample_size], y[sample_size:]
                    if fastgwa_match:
                        raise ValueError('Not used anymore...')
                    if sep_cohort:
                        self.pheno = self.pheno[self.pheno.eid.isin(test_ids)]
                    else:
                        self.pheno = self.pheno[self.pheno.eid.isin(train_val_ids)]
            else:
                if frac!=1:
                    from sklearn.model_selection import train_test_split
                    all_ids, y = self.pheno.eid.values, self.pheno[str(field_id)].values
                    train_val_ids, test_ids, y_train_val, y_test = train_test_split(all_ids, y, test_size=frac, random_state=random_seed)
                    if fastgwa_match:
                        # Further shrink to match the no-relatives sample ratio.
                        train_val_ids, test_ids, y_train_val, y_test = train_test_split(train_val_ids, y_train_val, test_size=1-self.rel_ratio, random_state=42)
                    if sep_cohort:
                        self.pheno = self.pheno[self.pheno.eid.isin(test_ids)]
                    else:
                        self.pheno = self.pheno[self.pheno.eid.isin(train_val_ids)]

            if not os.path.exists(plink_path):
                self.pheno[['eid', 'eid', self.pheno.columns.values[-1]]].to_csv(plink_path, header=None, index=None, sep=' ')
            else:
                print_sys('Already existed! Loading...')

            self.pheno_plink = pd.read_csv(plink_path, header = None, sep = ' ')
            return self.pheno_plink


    def get_icd10(self, to_plink = False, level = 2, get_all = False):
        '''
        icd10: 41270

        Collects each sample's ICD10 diagnosis codes (level3 = full code,
        level2 = first three characters). With to_plink = True returns a
        cached 1/2-coded case-control matrix restricted to ICD10 codes with
        more than 100 cases.
        '''
        pheno_path = os.path.join(self.cohort_data_path, 'icd10.pkl')
        level_str = 'level' + str(level)
        if os.path.exists(pheno_path):
            print_sys('Found local copy...')
            self.icd10 = load_dict(pheno_path)
        else:
            print_sys('construct from scratch...')
            icd10_raw_concat = get_fields([41270], self.main_data_path)
            # Per sample: tuple of (eid, array of non-null diagnosis codes).
            icd10_columns = icd10_raw_concat.columns.values[1:]
            icd10_tuple = icd10_raw_concat.apply(lambda x: (x.eid, x[icd10_columns][x[icd10_columns].notnull()].values), axis = 1)
            icd10 = pd.DataFrame(list(icd10_tuple.values)).rename(columns = {0: 'eid', 1: 'level3'})
            icd10['level2'] = icd10['level3'].apply(lambda x: np.unique([i[:3] for i in x]))
            save_dict(pheno_path, icd10)
            print_sys('Done! Saving...')
            self.icd10 = icd10
        if get_all:
            self.pheno = self.icd10.reset_index(drop = True)
        else:
            self.pheno = self.icd10[self.icd10.eid.isin(self.cohort)].reset_index(drop = True)
        if not to_plink:
            return self.pheno
        else:
            if self.keep_relatives or get_all:
                plink_path = os.path.join(self.cohort_data_path, 'icd10_plink_with_relatives_' + level_str + '.txt')
            else:
                plink_path = os.path.join(self.cohort_data_path, 'icd10_plink_no_relatives_' + level_str + '.txt')

            if os.path.exists(plink_path):
                print_sys("Found local copy...")
                self.icd10_plink = pd.read_csv(plink_path, sep=' ')
            else:
                print_sys('transforming to plink files... takes around 1 min...')
                unique_icd10 = np.unique([item for sublist in self.pheno[level_str].values for item in sublist])
                icd10_2_idx = dict(zip(unique_icd10, range(len(unique_icd10))))
                idx_2_icd10 = dict(zip(range(len(unique_icd10)), unique_icd10))

                self.pheno[level_str + '_idx'] = self.pheno[level_str].apply(lambda x: [icd10_2_idx[i] for i in x])

                # Dense sample x code indicator matrix.
                tmp = np.zeros((len(self.pheno), len(unique_icd10)), dtype=np.int8)
                for idx, i in enumerate(self.pheno[level_str + '_idx'].values):
                    tmp[idx, i] = 1

                icd10_plink = pd.DataFrame(tmp).rename(columns = idx_2_icd10)
                icd102sample_size = dict(icd10_plink.sum(axis = 0))
                icd_100 = [i for i,j in icd102sample_size.items() if j > 100]
                # Shift 0/1 to PLINK's 1 = control, 2 = case coding.
                icd10_plink = icd10_plink + 1
                icd10_plink['IID'] = self.pheno.eid.values
                icd10_plink['FID'] = self.pheno.eid.values
                icd10_plink = icd10_plink.loc[:, ['FID', 'IID'] + icd_100]
                print_sys('Only using ICD10 codes with at least 100 cases...')
                print_sys('There are ' + str(len(icd_100)) + ' ICD10 codes with at least 100 cases.')
                icd10_plink.to_csv(plink_path, index=None, sep=' ')
                self.icd10_plink = icd10_plink

            return self.icd10_plink