--- a
+++ b/kgwas/data.py
@@ -0,0 +1,426 @@
import os, sys
import pickle

import numpy as np
import pandas as pd
import torch
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import connected_components
from sklearn.preprocessing import OneHotEncoder
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import WeightedRandomSampler
from tqdm import tqdm

from .params import main_data_path, cohort_data_path, kinship_path, withdraw_path, fam_path
# data_path is referenced by get_external_traits but was never imported
# (NameError on first cache miss) — TODO confirm it is defined in .params.
from .params import data_path
from .utils import get_fields, get_row_last_values, remove_kinships, save_dict, load_dict, print_sys
+
+
class ukbb_cohort:
    """UK Biobank cohort handler.

    On construction, builds (or loads from a cached pickle under
    ``cohort_data_path``) the list of sample eids passing the standard
    genetic QC exclusions, optionally also removing relatives via KING
    kinship scores. The other methods fetch covariates and phenotypes for
    this cohort, optionally exporting them in PLINK text format.
    """
    
    def __init__(self, main_data_path, cohort_data_path, withdraw_path, keep_relatives = False):
        """Build or load the QC-passing cohort.

        main_data_path: path to the main UK Biobank phenotype table
            (forwarded to ``get_fields``).
        cohort_data_path: directory where cached cohort / covariate /
            phenotype files are read and written.
        withdraw_path: path to a participant-withdrawal list
            (handling is a TODO below — currently a no-op).
        keep_relatives: if True keep related individuals; if False drop
            all but one member of each KING-related group.
        """
        self.keep_relatives = keep_relatives
        self.cohort_data_path = cohort_data_path
        self.main_data_path = main_data_path
        
        # Two cached cohorts: before and after relative removal.
        if keep_relatives:
            cohort_path = os.path.join(cohort_data_path, 'cohort_with_relatives.pkl')
        else:
            cohort_path = os.path.join(cohort_data_path, 'cohort_no_relatives.pkl')
        
        if not os.path.exists(cohort_path):
            print_sys('construct from scratch...')
            '''
            exclusions:
            the original uk biobank nature paper supplementary S 3.4:
            22006: Genetic ethnic grouping -> to retain only white british ancestry

            https://www.frontiersin.org/articles/10.3389/fgene.2022.866042/full
            22018: genetic relatedness exclusions
            22019: sex chromosome aneuploidy 
            31 <-> 22001: mismatch between self-reported sex and genetically determined sex 
            22010: recommended genomic analysis exclusions, signs of insufficient data quality

            (optional) further remove relatives based on KING relative scores, choose the first one in relative group
            remove the list of eids who do not want to be in the study anymore            
            '''
            
            all_field_ids = [22006, 22018, 22019, 22001, 22010, 31]
            df_main = get_fields(all_field_ids, main_data_path)
            cur_size = len(df_main)
            print_sys('Total sample size: ' + str(cur_size))
            # 22006 == 1 marks the white-British genetic ethnic grouping.
            df_main = df_main[df_main['22006-0.0'] == 1]
            print_sys('Keeping only white british ancestry (ID: 22006), cutting from ' + str(cur_size) + ' to ' + str(len(df_main)))
            cur_size = len(df_main)

            # For the exclusion fields (22018/22019/22010) a null value means
            # "not flagged", so we keep the nulls and drop flagged samples.
            df_main = df_main[df_main['22018-0.0'].isnull()]
            print_sys('Removing genetics related samples (ID: 22018), cutting from ' + str(cur_size) + ' to ' + str(len(df_main)))
            cur_size = len(df_main)

            df_main = df_main[df_main['22019-0.0'].isnull()]
            print_sys('Removing sex chromosome aneuploidy (ID: 22019), cutting from ' + str(cur_size) + ' to ' + str(len(df_main)))
            cur_size = len(df_main)

            df_main = df_main[df_main['31-0.0'] == df_main['22001-0.0']]
            print_sys('Removing samples with mismatched self-reported sex and genetic determined sex (ID: 31 <-> 22001), cutting from ' + str(cur_size) + ' to ' + str(len(df_main)))
            cur_size = len(df_main)

            df_main = df_main[df_main['22010-0.0'].isnull()]
            print_sys('Removing samples with genomic data quality (ID: 22010), cutting from ' + str(cur_size) + ' to ' + str(len(df_main)))
            cur_size = len(df_main)
            
            # Cache the with-relatives cohort before kinship pruning so both
            # variants are available after one construction pass.
            save_dict(os.path.join(cohort_data_path, 'cohort_with_relatives.pkl'), df_main.eid.values)
            
            kinship_mask = remove_kinships(df_main.eid)
            df_main = df_main[kinship_mask]
            save_dict(os.path.join(cohort_data_path, 'cohort_no_relatives.pkl'), df_main.eid.values)
        else:
            print_sys('Found local copy...')
            
        self.cohort = load_dict(cohort_path)
        print_sys('There are ' + str(len(self.cohort)) + ' samples!')
        
        if keep_relatives:
            # Used by get_phenotype(fastgwa_match=True) to compute the
            # relatives/no-relatives size ratio. NOTE(review): assumes the
            # no-relatives cache already exists (it is written above on the
            # first from-scratch build) — verify for partially-built caches.
            self.no_rel_eid = load_dict(os.path.join(cohort_data_path, 'cohort_no_relatives.pkl'))
            
        if os.path.exists(withdraw_path):
            ## todo: when there is a withdraw file, implement this...
            pass
+        
+    def get_covariates(self, to_plink = False, plink_num_pca = 15, return_full = False, plink_filter = False):
+        '''
+        covariates:
+
+        31: sex
+        21003: Age when attended assessment centre
+        22009: pca
+        54: assessment center
+        batch from params.fam_path file
+        '''
+        covar_path = os.path.join(self.cohort_data_path, 'covariates_all.pkl')
+        if os.path.exists(covar_path):
+            print_sys('Found local copy...')
+            self.covar = load_dict(covar_path)
+        else:
+            print_sys('construct co-variates from scratch...')
+            df_covar = get_fields([31, 54, 21003, 22009], self.main_data_path)
+            column_name_map = {'22009-0.' + str(i): 'pca ' + str(i) for i in range(1, 41)}
+            column_name_map['31-0.0'] = 'sex'
+            column_name_map['21003-0.0'] = 'age'
+            column_name_map['54-0.0'] = 'assessment_center'
+            self.covar = df_covar.rename(columns = column_name_map)
+            
+            enc = OneHotEncoder(handle_unknown='ignore')
+            enc.fit(self.covar['assessment_center'].unique().reshape(-1,1))
+            center_array = enc.transform(self.covar['assessment_center'].values.reshape(-1,1)).toarray()
+            center_one_hot = pd.DataFrame(center_array).astype('int').rename(columns = dict(zip(range(22), ['center_' + str(i) for i in range(22)])))
+            
+            self.covar = self.covar.drop(['21003-1.0', '21003-2.0', '21003-3.0', 'assessment_center', '54-1.0', '54-2.0', '54-3.0'], axis = 1)
+            self.covar = self.covar.join(center_one_hot)
+            
+            df_fam = pd.read_csv(fam_path)
+            enc = OneHotEncoder(handle_unknown='ignore')
+            enc.fit(df_fam.trait.unique().reshape(-1,1))
+            batch_one_hot = enc.transform(df_fam['trait'].values.reshape(-1,1)).toarray()
+            batch_num = batch_one_hot.shape[1]
+            id2batch = dict(zip(df_fam.fid.values, batch_one_hot.astype(int)))
+            df_batch = pd.DataFrame(np.stack(self.covar['eid'].apply(lambda x: id2batch[x] if x in id2batch else np.zeros(batch_one_hot.shape[1]).astype(int)).values)).rename(columns = dict(zip(range(batch_num), ['batch_' + str(i) for i in range(batch_num)])))
+            self.covar = self.covar.join(df_batch)
+            
+            save_dict(covar_path, self.covar)
+            print_sys('Done! Saving...')
+            
+        if not to_plink:
+            if return_full:
+                return self.covar.reset_index(drop = True)
+            else:
+                return self.covar[self.covar.eid.isin(self.cohort)].reset_index(drop = True)
+        else:
+            plink_path = os.path.join(self.cohort_data_path, 'covar_pca' + str(plink_num_pca) + '_all_real_value')
+            if plink_filter:
+                plink_path += '_null_removed'
+            plink_path += '.txt'
+            
+            if not os.path.exists(plink_path):
+                pca_columns = [i for i in self.covar.columns.values if (i[:3]=='pca') and int(i.split()[-1]) <= plink_num_pca]
+                #center_one_hot_columns = ['center_' + str(i) for i in range(22)]
+                #batch_columns = ['batch_' + str(i) for i in range(batch_num)]
+                #self.covar[['eid', 'eid', 'age', 'sex'] + pca_columns + center_one_hot_columns + batch_columns].to_csv(plink_path, header=None, index=None, sep=' ')
+                center = np.argmax(self.covar.loc[:, self.covar.columns.str.contains('center')].values, axis = 1)
+                batch = np.argmax(self.covar.loc[:, self.covar.columns.str.contains('batch')].values, axis = 1)
+                self.covar = self.covar.iloc[:, :43]
+                self.covar['assessment_center'] = center
+                self.covar['batch'] = batch
+                if plink_filter:
+                    self.covar = self.covar[self.covar.eid.isin(self.cohort)].reset_index(drop = True)
+                self.covar[['eid', 'eid', 'age', 'sex', 'assessment_center', 'batch'] + pca_columns].to_csv(plink_path, header=None, index=None, sep=' ')
+            self.covar_plink = pd.read_csv(plink_path, header = None, sep = ' ')
+            return self.covar_plink
+    
+    def get_external_traits(self, trait_name, to_plink = False, to_str = True, random_seed = 42, sep_cohort = False, randomize = False, use_sample_size = False, sample_size = -1, randomize_seed = 42):
+        '''
+        example:
+        standing heights: 50
+        '''
+        if trait_name in ['body_BALDING1', 'cancer_BREAST', 'disease_ALLERGY_ECZEMA_DIAGNOSED', 'disease_HYPOTHYROIDISM_SELF_REP', 'other_MORNINGPERSON', 'pigment_SUNBURN']:
+            trait_type = 'binary'
+        else:
+            trait_type = 'continuous'
+
+        pheno_path = os.path.join(self.cohort_data_path, str(trait_name) + '_pheno.pkl')
+        if os.path.exists(pheno_path):
+            print_sys('Found local copy...')
+            self.pheno = load_dict(pheno_path)
+        else:
+            print_sys('construct phenotype from scratch...')
+            
+            self.pheno = pd.read_csv(os.path.join(data_path, 'full_gwas', trait_name+'_'+trait_type+'.csv'))
+            self.pheno['eid'] = self.pheno.eid.astype('int')
+            self.pheno = self.pheno[self.pheno['pheno'].notnull()]
+            if trait_type == 'binary':
+                self.pheno['pheno'] += 1
+                self.pheno['pheno'] = self.pheno['pheno'].astype(int)
+            save_dict(pheno_path, self.pheno)
+            print_sys('Done! Saving...')
+            
+            
+        
+        # filtering to cohorts incl. with/without relatives            
+        self.pheno = self.pheno[self.pheno.eid.isin(self.cohort)].reset_index(drop = True)
+            
+        if to_str:
+            self.pheno['eid'] = self.pheno['eid'].astype('str')
+        if not to_plink:
+            return self.pheno
+        else:
+            plink_path = os.path.join(self.cohort_data_path, str(trait_name) + '_plink')
+            if self.keep_relatives:
+                plink_path = plink_path + '_with_relatives'
+            else:
+                plink_path = plink_path + '_no_relatives'
+                
+            if use_sample_size:
+                plink_path = plink_path + '_' + str(sample_size) + '_' + str(random_seed)
+            
+            if sep_cohort:
+                plink_path += '_sep_cohort'
+             
+            if randomize:
+                plink_path += '_randomize' + str(randomize_seed)
+            
+            plink_path = plink_path + '.txt'
+            
+            if randomize:
+                self.pheno['pheno'] = self.pheno['pheno'].sample(frac = 1, random_state = randomize_seed).values
+            
+            if use_sample_size:
+                from sklearn.model_selection import train_test_split
+                print('random_seed:', random_seed)
+                pheno_shuffle = self.pheno.sample(frac = 1, random_state = random_seed)
+                all_ids, y = pheno_shuffle.eid.values, pheno_shuffle['pheno'].values
+                train_val_ids, test_ids, y_train_val, y_test = all_ids[:sample_size], all_ids[sample_size:], y[:sample_size], y[sample_size:]
+                if sep_cohort:
+                    self.pheno = self.pheno[self.pheno.eid.isin(test_ids)]
+                else:
+                    self.pheno = self.pheno[self.pheno.eid.isin(train_val_ids)]
+                        
+            if not os.path.exists(plink_path):
+                print_sys('Saving...')
+                self.pheno[['eid', 'eid', self.pheno.columns.values[-1]]].to_csv(plink_path, header=None, index=None, sep=' ')
+            else:
+                print_sys('Already existed! Loading...')
+            
+            self.pheno_plink = pd.read_csv(plink_path, header = None, sep = ' ')
+            return self.pheno_plink
+    
+    
+    
+    
    def get_phenotype(self, field_id, aggregate = 'last_value', to_plink = False, to_str = True, normalize = 'None', frac = 1, random_seed = 42, fastgwa_match = False, icd10 = False, icd10_level = 2, sep_cohort = False, randomize = False, use_sample_size = False, sample_size = -1, randomize_seed = 42):
        '''
        Fetch a phenotype by UKB field id (e.g. 50 = standing height) or by
        ICD10 code, restricted to the cohort; optionally export it as a PLINK
        phenotype file.

        field_id: UKB data field id, or an ICD10 code string when icd10=True.
        aggregate: how to collapse repeated measurements across visits; only
            'last_value' (latest non-null visit) is implemented.
        to_plink: if True, write (or load) a [FID, IID, pheno] text file and
            return it instead of the raw dataframe.
        to_str: cast eid to str in the returned raw dataframe.
        normalize: 'None', 'log', 'std' or 'quantile_normalization'.
        frac: test-split fraction used when use_sample_size is False.
        random_seed: seed for the train/test split or shuffle.
        fastgwa_match: downsample to match the no-relatives sample size
            (requires keep_relatives=True; only supported with frac != 1).
        icd10: treat field_id as an ICD10 code produced by get_icd10.
        icd10_level: ICD10 hierarchy level used when icd10=True.
        sep_cohort: keep the held-out (test) samples instead of the training
            samples.
        randomize: permute phenotype values across samples (null phenotype).
        use_sample_size: subsample to `sample_size` samples (for ICD10
            phenotypes: `sample_size` cases plus all controls).
        sample_size: number of samples kept when use_sample_size is True.
        randomize_seed: seed for the permutation when randomize is True.
        '''
        pheno_path = os.path.join(self.cohort_data_path, str(field_id) + '_pheno.pkl')
        if os.path.exists(pheno_path):
            print_sys('Found local copy...')
            self.pheno = load_dict(pheno_path)
        else:
            print_sys('construct phenotype from scratch...')
            if icd10:
                ## field_id is icd10 level
                # Case/control matrix from get_icd10; FID doubles as eid.
                icd10_df = self.get_icd10(to_plink = True, level = icd10_level, get_all = True)
                self.pheno = icd10_df[['FID', field_id]].rename(columns = {'FID': 'eid'})
                self.pheno['eid'] = self.pheno.eid.astype('int')
            else:
                ## from raw data field id
                self.pheno = get_fields([field_id], self.main_data_path)
            save_dict(pheno_path, self.pheno)
            print_sys('Done! Saving...')
        
        # More than two columns means repeated measurements across visits.
        if len(self.pheno.columns.values) > 2:
            print_sys('There are multiple index for this phenotype... aggregate...')
            if aggregate == 'last_value':
                print_sys('Getting the latest measure...')
                tmp = pd.DataFrame()
                tmp['eid'] = self.pheno.loc[:, 'eid']
                tmp[str(field_id)] = get_row_last_values(self.pheno.iloc[:, 1:])
                self.pheno = tmp
                print_sys('There are ' + str(len(self.pheno[self.pheno[str(field_id)].isnull()])) + ' samples with NaN values. Removing them ...')
                self.pheno = self.pheno[self.pheno[str(field_id)].notnull()]
                
        if fastgwa_match:
            # get the number of without relatives:
            # Ratio of no-relatives to full-cohort sample counts; used below
            # (frac branch) to downsample a relatives-included cohort.
            if not self.keep_relatives:
                raise ValueError('If you turned fastgwa_match = True, then keep_relatives = True!')
            self.rel_ratio = len(self.pheno[self.pheno.eid.isin(self.no_rel_eid)])/len(self.pheno[self.pheno.eid.isin(self.cohort)])
        
        # filtering to cohorts incl. with/without relatives            
        self.pheno = self.pheno[self.pheno.eid.isin(self.cohort)].reset_index(drop = True)

        if normalize != 'None':
            y = self.pheno[str(field_id)].values
            if normalize == 'log':
                # NOTE(review): assumes strictly positive values — np.log of
                # zero/negative yields -inf/nan here; confirm per trait.
                y = np.log(y)
            elif normalize == 'std':
                y = (y - np.mean(y))/np.std(y)
            elif normalize == 'quantile_normalization':
                from sklearn.preprocessing import quantile_transform
                y = quantile_transform(y.reshape(-1,1), output_distribution = 'normal', random_state = 42).reshape(-1)
            self.pheno[str(field_id)] = y
            
        if to_str:
            self.pheno['eid'] = self.pheno['eid'].astype('str')
        if not to_plink:
            return self.pheno
        else:
            # Every option that changes the file content is encoded in the
            # cached file name so stale caches are never reused.
            plink_path = os.path.join(self.cohort_data_path, str(field_id) + '_plink')
            if self.keep_relatives:
                plink_path = plink_path + '_with_relatives'
            else:
                plink_path = plink_path + '_no_relatives'
                
            if normalize != 'None':
                plink_path = plink_path + '_' + str(normalize)
            if use_sample_size:
                plink_path = plink_path + '_' + str(sample_size) + '_' + str(random_seed)
            else:
                if frac != 1:
                    plink_path = plink_path + '_' + str(frac) + '_' + str(random_seed)
                
            if fastgwa_match:
                plink_path += '_match'
                
            if sep_cohort:
                plink_path += '_sep_cohort'
             
            if randomize:
                plink_path += '_randomize' + str(randomize_seed)
                
            
            plink_path = plink_path + '.txt'
            
            if randomize:
                # Permute only the labels; the eid order stays fixed.
                self.pheno[str(field_id)] = self.pheno[str(field_id)].sample(frac = 1, random_state = randomize_seed).values
            
            
            if use_sample_size:
                from sklearn.model_selection import train_test_split
                
                if icd10:
                    # Binary ICD10 coding: 2 = case, 1 = control. Subsample
                    # `sample_size` cases and keep all controls.
                    df_cases = self.pheno[self.pheno[str(field_id)] == 2]
                    df_cases_shuffle = df_cases.sample(frac = 1, random_state = random_seed)
                    all_ids, y = df_cases_shuffle.eid.values, df_cases_shuffle[str(field_id)].values
                    train_val_ids, test_ids, y_train_val, y_test = all_ids[:sample_size], all_ids[sample_size:], y[:sample_size], y[sample_size:]
                    train_val_ids = np.concatenate((train_val_ids, self.pheno[self.pheno[str(field_id)] == 1].eid.values))
                    self.pheno = self.pheno[self.pheno.eid.isin(train_val_ids)]
                    if sep_cohort:
                        raise NotImplementedError
                else:
                    print('random_seed', random_seed)
                    pheno_shuffle = self.pheno.sample(frac = 1, random_state = random_seed)
                    all_ids, y = pheno_shuffle.eid.values, pheno_shuffle[str(field_id)].values
                    train_val_ids, test_ids, y_train_val, y_test = all_ids[:sample_size], all_ids[sample_size:], y[:sample_size], y[sample_size:]
                    if fastgwa_match:
                        raise ValueError('Not used anymore...')
                    if sep_cohort:
                        self.pheno = self.pheno[self.pheno.eid.isin(test_ids)]
                    else:
                        self.pheno = self.pheno[self.pheno.eid.isin(train_val_ids)]
            else:
                if frac!=1:
                    from sklearn.model_selection import train_test_split
                    all_ids, y = self.pheno.eid.values, self.pheno[str(field_id)].values
                    train_val_ids, test_ids, y_train_val, y_test = train_test_split(all_ids, y, test_size=frac, random_state=random_seed)
                    if fastgwa_match:
                        # Further shrink the train split so its size matches the
                        # no-relatives cohort (self.rel_ratio computed above).
                        train_val_ids, test_ids, y_train_val, y_test = train_test_split(train_val_ids, y_train_val, test_size=1-self.rel_ratio, random_state=42)
                    if sep_cohort:
                        self.pheno = self.pheno[self.pheno.eid.isin(test_ids)]
                    else:
                        self.pheno = self.pheno[self.pheno.eid.isin(train_val_ids)]
                        
            if not os.path.exists(plink_path):
                # PLINK phenotype format: FID IID phenotype.
                self.pheno[['eid', 'eid', self.pheno.columns.values[-1]]].to_csv(plink_path, header=None, index=None, sep=' ')
            else:
                print_sys('Already existed! Loading...')
            
            self.pheno_plink = pd.read_csv(plink_path, header = None, sep = ' ')
            return self.pheno_plink
+            
+        
    def get_icd10(self, to_plink = False, level = 2, get_all = False):
        '''
        Fetch per-sample ICD10 diagnoses (UKB field 41270).

        level: 2 truncates codes to their 3-character block (e.g. 'I25');
            3 keeps the full codes as reported.
        to_plink: if True, return a wide case/control matrix with FID/IID
            columns where 1 = control and 2 = case, keeping only codes with
            more than 100 cases.
        get_all: if True, do not restrict to the QC cohort (the plink cache
            is then written under the with-relatives file name).
        '''
        pheno_path = os.path.join(self.cohort_data_path, 'icd10.pkl')
        level_str = 'level' + str(level)
        if os.path.exists(pheno_path):
            print_sys('Found local copy...')
            self.icd10 = load_dict(pheno_path)
        else:
            print_sys('construct from scratch...')
            # Collapse the per-visit 41270 columns into one array of non-null
            # ICD10 codes per sample.
            icd10_raw_concat = get_fields([41270], self.main_data_path)
            icd10_columns = icd10_raw_concat.columns.values[1:]
            icd10_tuple = icd10_raw_concat.apply(lambda x: (x.eid, x[icd10_columns][x[icd10_columns].notnull()].values), axis = 1)
            icd10 = pd.DataFrame(list(icd10_tuple.values)).rename(columns = {0: 'eid', 1: 'level3'})
            # Level-2 codes are the unique 3-character prefixes of the full codes.
            icd10['level2'] = icd10['level3'].apply(lambda x: np.unique([i[:3] for i in x]))
            save_dict(pheno_path, icd10)
            print_sys('Done! Saving...')
            self.icd10 = icd10
        if get_all:
            self.pheno = self.icd10.reset_index(drop = True)
        else:
            self.pheno = self.icd10[self.icd10.eid.isin(self.cohort)].reset_index(drop = True)
        if not to_plink:
            return self.pheno
        else:            
            if self.keep_relatives or get_all:
                plink_path = os.path.join(self.cohort_data_path, 'icd10_plink_with_relatives_' + level_str + '.txt')
            else:
                plink_path = os.path.join(self.cohort_data_path, 'icd10_plink_no_relatives_' + level_str + '.txt')
            
            if os.path.exists(plink_path):
                print_sys("Found local copy...")
                self.icd10_plink = pd.read_csv(plink_path, sep=' ')
            else:
                print_sys('transforming to plink files... takes around 1 min...')
                # Map each distinct code to a column index and back.
                unique_icd10 = np.unique([item for sublist in self.pheno[level_str].values for item in sublist])
                icd10_2_idx = dict(zip(unique_icd10, range(len(unique_icd10))))
                idx_2_icd10 = dict(zip(range(len(unique_icd10)), unique_icd10))

                self.pheno[level_str + '_idx'] = self.pheno[level_str].apply(lambda x: [icd10_2_idx[i] for i in x])

                # Dense sample x code indicator matrix (0/1), filled row by row.
                tmp = np.zeros((len(self.pheno), len(unique_icd10)), dtype=np.int8)
                for idx, i in enumerate(self.pheno[level_str + '_idx'].values):
                    tmp[idx, i] = 1

                icd10_plink = pd.DataFrame(tmp).rename(columns = idx_2_icd10)
                icd102sample_size = dict(icd10_plink.sum(axis = 0))
                # NOTE(review): the filter is strictly > 100 cases even though
                # the log message below says "at least 100".
                icd_100 = [i for i,j in icd102sample_size.items() if j > 100]
                # Shift 0/1 to PLINK's 1 = control / 2 = case coding.
                icd10_plink = icd10_plink + 1
                icd10_plink['IID'] = self.pheno.eid.values
                icd10_plink['FID'] = self.pheno.eid.values
                icd10_plink = icd10_plink.loc[:, ['FID', 'IID'] + icd_100]
                print_sys('Only using ICD10 codes with at least 100 cases...')
                print_sys('There are ' + str(len(icd_100)) + ' ICD10 codes with at least 100 cases.')
                icd10_plink.to_csv(plink_path, index=None, sep=' ')
                self.icd10_plink = icd10_plink
                
            return self.icd10_plink
\ No newline at end of file