--- /dev/null
+++ b/kgwas/kgwas_data.py
@@ -0,0 +1,559 @@
+from torch_geometric.data import HeteroData
+import torch_geometric.transforms as T
+
+from sklearn.model_selection import train_test_split
+import pandas as pd
+import torch
+import numpy as np
+import pickle
+import os
+import tarfile
+import urllib.request
+import shutil
+from tqdm import tqdm
+import subprocess
+
+from .utils import ldsc_regression_weights, load_dict
+from .params import scdrs_traits
+
+class KGWAS_Data:
+    def __init__(self, data_path='./data/'):
+        self.data_path = data_path
+
+        # Ensure the data path exists
+        if not os.path.exists(data_path):
+            os.makedirs(data_path)
+
+        # Check whether the required data files exist in the data_path
+        required_files = [
+            'cell_kg/network/node_idx2id.pkl',
+            'cell_kg/network/edge_index.pkl',
+            'cell_kg/network/node_id2idx.pkl',
+            'cell_kg/node_emb/variant_emb/enformer_feat.pkl',
+            'cell_kg/node_emb/gene_emb/esm_feat.pkl',
+            'ld_score/filter_genotyped_ldscores.csv',
+            'ld_score/ldscores_from_data.csv',
+            'ld_score/ukb_white_ld_10MB_no_hla.pkl',
+            'ld_score/ukb_white_ld_10MB.pkl',
+            'misc_data/ukb_white_with_cm.bim',
+        ]
+        missing_files = [f for f in required_files if not os.path.exists(os.path.join(data_path, f))]
+
+        if missing_files:
+            print("Required data not found in the data_path. Downloading and extracting data...")
+            url = "https://dataverse.harvard.edu/api/access/datafile/10731230"
+            file_name = 'kgwas_core_data'
+            self._download_and_extract_data(url, file_name)
+        else:
+            print("All required data files are present.")
+
+    def download_all_data(self):
+        url = "https://dataverse.harvard.edu/api/access/datafile/XXXX"
+        file_name = 'kgwas_data'
+        self._download_and_extract_data(url, file_name)
+
+    def _merge_with_rsync(self, src, dst):
+        """Merge directories using rsync (must be available on the system)."""
+        try:
+            subprocess.run(
+                ["rsync", "-a", "--ignore-existing", src + "/", dst + "/"],
+                check=True,
+                stdout=subprocess.PIPE,
+                stderr=subprocess.PIPE,
+            )
+        except subprocess.CalledProcessError as e:
+            print(f"Error during rsync: {e.stderr.decode()}")
+
+    def _download_and_extract_data(self, url, file_name):
+        """Download, extract, and merge directories using rsync."""
+        tar_file_path = os.path.join(self.data_path, f"{file_name}.tar.gz")
+
+        # Download the file
+        print(f"Downloading {file_name}.tar.gz...")
+        self._download_with_progress(url, tar_file_path)
+        print("Download complete.")
+
+        # Extract the tar.gz file
+        print("Extracting files...")
+        with tarfile.open(tar_file_path, 'r:gz') as tar:
+            tar.extractall(self.data_path)
+        print("Extraction complete.")
+
+        # Clean up the tar.gz file
+        os.remove(tar_file_path)
+
+        # Merge extracted contents into the data_path directory
+        extracted_dir = os.path.join(self.data_path, file_name)
+        if os.path.exists(extracted_dir):
+            print(f"Merging extracted directory '{extracted_dir}' into '{self.data_path}'...")
+            self._merge_with_rsync(extracted_dir, self.data_path)
+
+            # Remove the now-merged extracted directory
+            shutil.rmtree(extracted_dir)
+
+    def _download_with_progress(self, url, file_path):
+        """Download a file with a progress bar."""
+        request = urllib.request.Request(url, headers={'User-Agent': 'Mozilla/5.0'})
+        response = urllib.request.urlopen(request)
+        # Content-Length may be absent; fall back to an indeterminate progress bar
+        content_length = response.getheader('Content-Length')
+        total_size = int(content_length.strip()) if content_length else None
+        block_size = 1024  # 1 KB
+
+        with open(file_path, 'wb') as file, tqdm(
+            total=total_size, unit='B', unit_scale=True, desc="Downloading"
+        ) as pbar:
+            while True:
+                buffer = response.read(block_size)
+                if not buffer:
+                    break
+                file.write(buffer)
+                pbar.update(len(buffer))
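+
+    # Expected layout under `data_path` once the core archive is in place
+    # (matching `required_files` above):
+    #   cell_kg/network/   -- graph structure: edge_index plus node id/index maps
+    #   cell_kg/node_emb/  -- pretrained node features (Enformer, ESM, TransE, ...)
+    #   ld_score/          -- precomputed LD score tables
+    #   misc_data/         -- genotype metadata (e.g., the UKB .bim file)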
+
+    def load_kg(self, snp_init_emb='enformer',
+                go_init_emb='random',
+                gene_init_emb='esm',
+                sample_edges=False,
+                sample_ratio=1):
+
+        data_path = self.data_path
+
+        ## Load KG
+
+        print('--loading KG--')
+        idx2id = load_dict(os.path.join(data_path, 'cell_kg/network/node_idx2id.pkl'))
+        edge_index_all = load_dict(os.path.join(data_path, 'cell_kg/network/edge_index.pkl'))
+        id2idx = load_dict(os.path.join(data_path, 'cell_kg/network/node_id2idx.pkl'))
+        self.id2idx = id2idx
+        self.idx2id = idx2id
+
+        data = HeteroData()
+
+        ## Load initialized embeddings
+
+        if snp_init_emb == 'random':
+            print('--using random SNP embedding--')
+            data['SNP'].x = torch.rand((len(idx2id['SNP']), 128), requires_grad=False)
+            snp_init_dim_size = 128
+        elif snp_init_emb == 'kg':
+            print('--using KG SNP embedding--')
+            id2idx_kg = load_dict(os.path.join(data_path, 'cell_kg/node_emb/transe_emb/transe_emb_id2idx_kg.pkl'))
+            kg_emb = load_dict(os.path.join(data_path, 'cell_kg/node_emb/transe_emb/transe_emb_inverse_triplets.pkl'))
+            node_map = idx2id['SNP']
+            data['SNP'].x = torch.vstack([torch.tensor(kg_emb[id2idx_kg[node_map[i]]]) if node_map[i] in id2idx_kg
+                                          else torch.rand(50, requires_grad=False) for i in range(len(node_map))])
+            snp_init_dim_size = 50
+        elif snp_init_emb == 'cadd':
+            print('--using CADD SNP embedding--')
+            df_variant = pd.read_csv(os.path.join(data_path, 'cell_kg/node_emb/variant_emb/cadd_feat.csv'))
+            df_variant = df_variant.set_index('Unnamed: 0')
+            variant_feat = df_variant.values
+            node_map = idx2id['SNP']
+            rs2idx_feat = dict(zip(df_variant.index.values, range(len(df_variant.index.values))))
+            data['SNP'].x = torch.vstack([torch.tensor(variant_feat[rs2idx_feat[node_map[i]]]) if node_map[i] in rs2idx_feat
+                                          else torch.rand(64, requires_grad=False) for i in range(len(node_map))]).float()
+            snp_init_dim_size = 64
+        elif snp_init_emb == 'baselineLD':
+            print('--using baselineLD SNP embedding--')
+            node_map = idx2id['SNP']
+            rs2idx_feat = load_dict(os.path.join(data_path, 'cell_kg/node_emb/variant_emb/baselineld_feat.pkl'))
+            data['SNP'].x = torch.vstack([torch.tensor(rs2idx_feat[node_map[i]]) if node_map[i] in rs2idx_feat
+                                          else torch.rand(70, requires_grad=False) for i in range(len(node_map))]).float()
+            snp_init_dim_size = 70
+        elif snp_init_emb == 'SLDSC':
+            print('--using SLDSC SNP embedding--')
+            node_map = idx2id['SNP']
+            rs2idx_feat = load_dict(os.path.join(data_path, 'cell_kg/node_emb/variant_emb/sldsc_feat.pkl'))
+            data['SNP'].x = torch.vstack([torch.tensor(rs2idx_feat[node_map[i]]) if node_map[i] in rs2idx_feat
+                                          else torch.rand(165, requires_grad=False) for i in range(len(node_map))]).float()
+            snp_init_dim_size = 165
+        elif snp_init_emb == 'enformer':
+            print('--using enformer SNP embedding--')
+            node_map = idx2id['SNP']
+            rs2idx_feat = load_dict(os.path.join(data_path, 'cell_kg/node_emb/variant_emb/enformer_feat.pkl'))
+            data['SNP'].x = torch.vstack([torch.tensor(rs2idx_feat[node_map[i]]) if node_map[i] in rs2idx_feat
+                                          else torch.rand(20, requires_grad=False) for i in range(len(node_map))]).float()
+            snp_init_dim_size = 20
+        else:
+            # Without this guard, an unknown option leaves snp_init_dim_size unbound
+            raise ValueError(f'Unknown snp_init_emb: {snp_init_emb}')
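+
+        # The GO and gene branches below follow the same pattern as the SNP
+        # branches above: nodes present in the pretrained feature table get
+        # their feature vector, nodes missing from it get a random vector of
+        # the same width, and `*_init_dim_size` records that width so the
+        # model can size its input projections.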
+
+        if go_init_emb == 'random':
+            print('--using random go embedding--')
+            for rel in ['CellularComponent', 'BiologicalProcess', 'MolecularFunction']:
+                data[rel].x = torch.rand((len(idx2id[rel]), 128), requires_grad=False)
+            go_init_dim_size = 128
+        elif go_init_emb == 'kg':
+            print('--using KG go embedding--')
+            id2idx_kg = load_dict(os.path.join(data_path, 'cell_kg/node_emb/transe_emb/transe_emb_id2idx_kg.pkl'))
+            kg_emb = load_dict(os.path.join(data_path, 'cell_kg/node_emb/transe_emb/transe_emb_inverse_triplets.pkl'))
+            for rel in ['CellularComponent', 'BiologicalProcess', 'MolecularFunction']:
+                node_map = idx2id[rel]
+                data[rel].x = torch.vstack([torch.tensor(kg_emb[id2idx_kg[node_map[i]]]) if node_map[i] in id2idx_kg
+                                            else torch.rand(50, requires_grad=False) for i in range(len(node_map))])
+            go_init_dim_size = 50
+        elif go_init_emb == 'biogpt':
+            print('--using biogpt go embedding--')
+            go2idx_feat = load_dict(os.path.join(data_path, 'cell_kg/node_emb/program_emb/biogpt_feat.pkl'))
+            for rel in ['CellularComponent', 'BiologicalProcess', 'MolecularFunction']:
+                node_map = idx2id[rel]
+                data[rel].x = torch.vstack([torch.tensor(go2idx_feat[node_map[i]]) if node_map[i] in go2idx_feat
+                                            else torch.rand(1600, requires_grad=False) for i in range(len(node_map))]).float()
+            go_init_dim_size = 1600
+        else:
+            raise ValueError(f'Unknown go_init_emb: {go_init_emb}')
+
+        if gene_init_emb == 'random':
+            print('--using random gene embedding--')
+            data['Gene'].x = torch.rand((len(idx2id['Gene']), 128), requires_grad=False)
+            gene_init_dim_size = 128
+        elif gene_init_emb == 'kg':
+            print('--using KG gene embedding--')
+            id2idx_kg = load_dict(os.path.join(data_path, 'cell_kg/node_emb/transe_emb/transe_emb_id2idx_kg.pkl'))
+            kg_emb = load_dict(os.path.join(data_path, 'cell_kg/node_emb/transe_emb/transe_emb_inverse_triplets.pkl'))
+            node_map = idx2id['Gene']
+            data['Gene'].x = torch.vstack([torch.tensor(kg_emb[id2idx_kg[node_map[i]]]) if node_map[i] in id2idx_kg
+                                           else torch.rand(50, requires_grad=False) for i in range(len(node_map))])
+            gene_init_dim_size = 50
+        elif gene_init_emb == 'esm':
+            print('--using ESM gene embedding--')
+            gene2idx_feat = load_dict(os.path.join(data_path, 'cell_kg/node_emb/gene_emb/esm_feat.pkl'))
+            node_map = idx2id['Gene']
+            data['Gene'].x = torch.vstack([torch.tensor(gene2idx_feat[node_map[i]]) if node_map[i] in gene2idx_feat
+                                           else torch.rand(5120, requires_grad=False) for i in range(len(node_map))]).float()
+            gene_init_dim_size = 5120
+        elif gene_init_emb == 'pops':
+            print('--using PoPs expression+PPI+pathways gene embedding--')
+            gene2idx_feat = load_dict(os.path.join(data_path, 'cell_kg/node_emb/gene_emb/pops_feat.pkl'))
+            node_map = idx2id['Gene']
+            data['Gene'].x = torch.vstack([torch.tensor(gene2idx_feat[node_map[i]]) if node_map[i] in gene2idx_feat
+                                           else torch.rand(57742, requires_grad=False) for i in range(len(node_map))]).float()
+            gene_init_dim_size = 57742
+        elif gene_init_emb == 'pops_expression':
+            print('--using PoPs expression only gene embedding--')
+            gene2idx_feat = load_dict(os.path.join(data_path, 'cell_kg/node_emb/gene_emb/pops_expression_feat.pkl'))
+            node_map = idx2id['Gene']
+            data['Gene'].x = torch.vstack([torch.tensor(gene2idx_feat[node_map[i]]) if node_map[i] in gene2idx_feat
+                                           else torch.rand(40546, requires_grad=False) for i in range(len(node_map))]).float()
+            gene_init_dim_size = 40546
+        else:
+            raise ValueError(f'Unknown gene_init_emb: {gene_init_emb}')
+
+        self.gene_init_dim_size = gene_init_dim_size
+        self.go_init_dim_size = go_init_dim_size
+        self.snp_init_dim_size = snp_init_dim_size
+
+        for rel_type, edge_list in edge_index_all.items():
+            if sample_edges:
+                edge_index = torch.tensor(edge_list)
+                num_edges = edge_index.size(1)
+                num_samples = int(num_edges * sample_ratio)
+                indices = torch.randperm(num_edges)[:num_samples]
+                sampled_edge_index = edge_index[:, indices]
+                print(rel_type, ' sampling ratio ', sample_ratio, ' from ', edge_index.shape[1], ' to ', sampled_edge_index.shape[1])
+                data[rel_type].edge_index = sampled_edge_index
+            else:
+                data[rel_type].edge_index = torch.tensor(edge_list)
+        data = T.ToUndirected()(data)
+        data = T.AddSelfLoops()(data)
+        self.data = data
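+
+    # Illustrative end-to-end usage (not executed here); note that
+    # `load_external_gwas` filters against `self.idx2id`, so `load_kg`
+    # must be called first:
+    #   data = KGWAS_Data(data_path='./data/')
+    #   data.load_kg(snp_init_emb='enformer', gene_init_emb='esm')
+    #   data.load_external_gwas(example_file=True)
+    #   data.process_gwas_file(label='chi')
+    #   data.prepare_split()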
+
+    def load_simulation_gwas(self, simulation_type, seed):
+        data_path = self.data_path
+        print('Using simulation data....')
+        small_cohort = 5000
+        num_causal_hits = 20000
+        heritability = 0.3
+        self.sample_size = small_cohort
+        if simulation_type == 'causal_link':
+            lr_uni = pd.read_csv(os.path.join(data_path, 'simulation_gwas/causal_link_simulation/' +
+                                              str(num_causal_hits) + '_' + str(seed) + '_' + str(heritability) + '_graph_funct_v2_ggi.fastGWA'), sep='\t')
+        elif simulation_type == 'causal':
+            lr_uni = pd.read_csv(os.path.join(data_path, 'simulation_gwas/causal_simulation/' +
+                                              str(num_causal_hits) + '_' + str(seed) + '_' + str(heritability) + '_' + str(small_cohort) + '_graph_funct_v2.fastGWA'), sep='\t')
+        elif simulation_type == 'null':
+            lr_uni = pd.read_csv(os.path.join(data_path, 'simulation_gwas/null_simulation/' +
+                                              str(num_causal_hits) + '_' + str(seed) + '_' + str(heritability) + '_' + str(small_cohort) + '.fastGWA'), sep='\t')
+        else:
+            raise ValueError(f'Unknown simulation_type: {simulation_type}')
+
+        if ('SNP' in lr_uni.columns.values) and ('ID' in lr_uni.columns.values):
+            self.lr_uni = lr_uni.rename(columns={'CHR': '#CHROM'})
+        else:
+            self.lr_uni = lr_uni.rename(columns={'CHR': '#CHROM', 'SNP': 'ID'})
+        self.seed = seed
+        self.pheno = 'simulation'
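+
+    # External summary statistics are accepted in any pandas-readable
+    # delimited format, as long as they carry CHR (chromosome), SNP
+    # (rsID), P (p-value), and N (sample size) columns; BETA/SE are used
+    # for the chi-square label when present (see `process_gwas_file`).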
+
+    def load_external_gwas(self, path=None, seed=42, example_file=False):
+        if example_file:
+            print('Loading example GWAS file...')
+            url = "https://dataverse.harvard.edu/api/access/datafile/10730346"
+            example_file_path = os.path.join(self.data_path, 'biochemistry_Creatinine_fastgwa_full_10000_1.fastGWA')
+
+            # Check if the example file is already downloaded
+            if not os.path.exists(example_file_path):
+                print('Example file not found locally. Downloading...')
+                self._download_with_progress(url, example_file_path)
+                print('Example file downloaded successfully.')
+            else:
+                print('Example file already exists locally.')
+
+            path = example_file_path
+
+        if path is None:
+            raise ValueError("A valid path must be provided or example_file must be set to True.")
+
+        print(f'Loading GWAS file from {path}...')
+
+        lr_uni = pd.read_csv(path, sep=None, engine='python')
+        if 'CHR' not in lr_uni.columns.values:
+            raise ValueError('CHR (chromosome) column not in the file!')
+        if 'SNP' not in lr_uni.columns.values:
+            raise ValueError('SNP column not in the file!')
+        if 'P' not in lr_uni.columns.values:
+            raise ValueError('P (p-value) column not in the file!')
+        if 'N' not in lr_uni.columns.values:
+            raise ValueError('N (sample size) column not in the file!')
+        lr_uni = lr_uni.rename(columns={'CHR': '#CHROM', 'SNP': 'ID'})
+
+        ## filter to the variant set of the current KG
+        old_variant_set_len = len(lr_uni)
+        lr_uni = lr_uni[lr_uni.ID.isin(list(self.idx2id['SNP'].values()))]
+        print('Number of SNPs in the KG:', len(self.idx2id['SNP']))
+        print('Number of SNPs in the GWAS:', old_variant_set_len)
+        print('Number of GWAS SNPs in the KG variant set:', len(lr_uni))
+
+        self.lr_uni = lr_uni
+        self.sample_size = lr_uni.N.values[0]
+        self.pheno = 'EXTERNAL'
+        self.seed = seed
+
+    def load_full_gwas(self, pheno, seed=42):
+        data_path = self.data_path
+        if pheno in scdrs_traits:
+            print('Using scdrs traits...')
+            self.pheno = pheno
+            lr_uni = pd.read_csv(os.path.join(data_path, 'scDRS_Data/sumstats_ukb_snps.csv'))
+            lr_uni = lr_uni[['CHR', 'SNP', 'POS', 'A1', 'A2', 'N', 'AF1', pheno]]
+            lr_uni = lr_uni[lr_uni[pheno].notnull()].reset_index(drop=True)
+            lr_uni = lr_uni.rename(columns={'CHR': '#CHROM', 'SNP': 'ID', pheno: 'chi'})
+            print('number of SNPs:', len(lr_uni))
+            self.lr_uni = lr_uni
+            self.seed = seed
+
+            with open(os.path.join(data_path, 'scDRS_data/trait2size.pkl'), 'rb') as f:
+                trait2size = pickle.load(f)
+            self.sample_size = trait2size[pheno]
+        else:
+            ## load GWAS files
+            self.pheno = pheno
+            lr_uni = pd.read_csv(os.path.join(data_path, 'full_gwas/' + str(self.pheno) + '_with_rel_fastgwa.fastGWA'), sep='\t')
+            lr_uni = lr_uni.rename(columns={'CHR': '#CHROM', 'SNP': 'ID'})
+
+            self.lr_uni = lr_uni
+            self.seed = seed
+            self.sample_size = 387113
+
+    def load_gwas_subsample(self, pheno, sample_size, seed):
+        data_path = self.data_path
+        binary = pheno in ['body_BALDING1', 'cancer_BREAST', 'disease_ALLERGY_ECZEMA_DIAGNOSED',
+                           'disease_HYPOTHYROIDISM_SELF_REP', 'other_MORNINGPERSON', 'pigment_SUNBURN']
+        ## load GWAS files
+        self.sample_size = sample_size
+        self.pheno = pheno
+        if sample_size > 3000:
+            lr_uni = pd.read_csv(os.path.join(data_path, 'subsample_gwas/' + str(self.pheno) +
+                                              '_fastgwa_full_' + str(sample_size) + '_' + str(seed) + '.fastGWA'), sep='\t')
+            lr_uni = lr_uni.rename(columns={'CHR': '#CHROM', 'SNP': 'ID'})
+        else:
+            ## use PLINK if sample size <= 3000
+            if binary:
+                lr_uni = pd.read_csv(os.path.join(data_path, 'subsample_gwas/' + str(self.pheno) +
+                                                  '_plink_' + str(sample_size) + '_' + str(seed) + '.PHENO1.glm.logistic.hybrid'), sep='\t')
+            else:
+                lr_uni = pd.read_csv(os.path.join(data_path, 'subsample_gwas/' + str(self.pheno) +
+                                                  '_plink_' + str(sample_size) + '_' + str(seed) + '.PHENO1.glm.linear'), sep='\t')
+        self.lr_uni = lr_uni
+        self.seed = seed
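+
+    # `process_gwas_file` converts the loaded summary statistics into a
+    # per-SNP regression target `y` (chi-square by default) and derives
+    # LDSC-style regression weights from the LD scores, the mean sample
+    # size n, the SNP count m, and an assumed heritability h_g_2; the
+    # weights are normalized to mean 1 so they can serve directly as
+    # per-SNP loss weights.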
+
+    def process_gwas_file(self, label='chi'):
+        data_path = self.data_path
+        lr_uni = self.lr_uni
+        ## LD scores
+
+        ld_scores = pd.read_csv(os.path.join(data_path, 'ld_score/filter_genotyped_ldscores.csv'))
+        w_ld_scores = pd.read_csv(os.path.join(data_path, 'ld_score/ldscores_from_data.csv'))
+
+        m = 15000000
+        if 'N' not in lr_uni.columns.values:
+            n = self.sample_size
+        else:
+            n = np.mean(lr_uni.N)
+        h_g_2 = 0.5
+        rs_id_2_ld_scores = dict(ld_scores.values)
+        rs_id_2_w_ld = dict(w_ld_scores.values)
+
+        ## use the min LD score for SNPs with no LD score
+        min_ld = min(rs_id_2_ld_scores.values())
+        lr_uni['ld_score'] = lr_uni.ID.apply(lambda x: rs_id_2_ld_scores[x] if x in rs_id_2_ld_scores else min_ld)
+        rs_id_2_ld_scores = dict(lr_uni[['ID', 'ld_score']].values)
+
+        min_ld = min(rs_id_2_w_ld.values())
+        ## the data LD score excludes the query SNP itself, so add 1 here
+        lr_uni['w_ld_score'] = 1 + lr_uni.ID.apply(lambda x: rs_id_2_w_ld[x] if x in rs_id_2_w_ld else min_ld)
+        rs_id_2_w_ld = dict(lr_uni[['ID', 'w_ld_score']].values)
+
+        print('Using ldsc weight...')
+        ld = np.array([rs_id_2_ld_scores[rs_id] for rs_id in lr_uni.ID.values])
+        w_ld = np.array([rs_id_2_w_ld[rs_id] for rs_id in lr_uni.ID.values])
+
+        ldsc_weight = ldsc_regression_weights(ld, w_ld, n, m, h_g_2)
+        ldsc_weight = ldsc_weight / np.mean(ldsc_weight)
+        print('ldsc_weight mean: ', np.mean(ldsc_weight))
+        self.rs_id_to_ldsc_weight = dict(zip(lr_uni.ID.values, ldsc_weight))
+
+        ## chi-square label
+        if label == 'chi':
+            if 'chi' in lr_uni.columns.values:
+                print('chi pre-computed...')
+                lr_uni['y'] = lr_uni['chi'].values
+            else:
+                if (self.pheno in ['body_BALDING1', 'cancer_BREAST', 'disease_ALLERGY_ECZEMA_DIAGNOSED', 'disease_HYPOTHYROIDISM_SELF_REP', 'other_MORNINGPERSON', 'pigment_SUNBURN']) and (self.sample_size <= 3000):
+                    lr_uni['y'] = lr_uni['Z_STAT'].values**2
+                    lr_uni['y'] = lr_uni.y.fillna(0)
+                else:
+                    if ('BETA' in lr_uni.columns.values) and ('SE' in lr_uni.columns.values):
+                        lr_uni['y'] = (lr_uni['BETA'] / lr_uni['SE']).values**2
+                        lr_uni['y'] = lr_uni.y.fillna(0)
+                    else:
+                        from scipy.stats import chi2
+                        ## convert from p-values
+                        lr_uni['y'] = chi2.ppf(1 - lr_uni['P'].values, 1)
+                        lr_uni['y'] = lr_uni.y.fillna(0)
+        elif label == 'residual-w-ld':
+            lr_uni['y'] = (lr_uni['BETA'] / lr_uni['SE']).values**2
+            lr_uni['y'] = lr_uni.y.fillna(0)
+            lr_uni['ld_weight'] = lr_uni.ID.apply(lambda x: self.rs_id_to_ldsc_weight[x])
+            import statsmodels.api as sm
+
+            X = lr_uni.w_ld_score.values
+            y = lr_uni.y.values
+            weights = lr_uni.ld_weight.values
+            X = sm.add_constant(X)
+            model = sm.WLS(y, X, weights=weights)
+            results = model.fit()
+            y_pred = results.params[0] + results.params[1] * lr_uni.w_ld_score.values
+            lr_uni['y'] = y - y_pred
+        elif label == 'residual-ld':
+            lr_uni['y'] = (lr_uni['BETA'] / lr_uni['SE']).values**2
+            lr_uni['y'] = lr_uni.y.fillna(0)
+            lr_uni['ld_weight'] = lr_uni.ID.apply(lambda x: self.rs_id_to_ldsc_weight[x])
+            import statsmodels.api as sm
+
+            X = lr_uni.ld_score.values
+            y = lr_uni.y.values
+            weights = lr_uni.ld_weight.values
+            X = sm.add_constant(X)
+            model = sm.WLS(y, X, weights=weights)
+            results = model.fit()
+            ## predict on the same covariate the model was fit on (ld_score)
+            y_pred = results.params[0] + results.params[1] * lr_uni.ld_score.values
+            lr_uni['y'] = y - y_pred
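+        # The '*-ols' variants below residualize the chi-square statistics
+        # with an unweighted OLS fit on the LD scores instead of the
+        # LDSC-weighted WLS fit used above.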
+        elif label == 'residual-ld-ols':
+            lr_uni['y'] = (lr_uni['BETA'] / lr_uni['SE']).values**2
+            lr_uni['y'] = lr_uni.y.fillna(0)
+            import statsmodels.api as sm
+
+            X = lr_uni.ld_score.values
+            y = lr_uni.y.values
+            X = sm.add_constant(X)
+            model = sm.OLS(y, X)
+            results = model.fit()
+            ## predict on the same covariate the model was fit on (ld_score)
+            y_pred = results.params[0] + results.params[1] * lr_uni.ld_score.values
+            lr_uni['y'] = y - y_pred
+        elif label == 'residual-ld-ols-abs':
+            lr_uni['y'] = (lr_uni['BETA'] / lr_uni['SE']).values**2
+            lr_uni['y'] = lr_uni.y.fillna(0)
+            import statsmodels.api as sm
+
+            X = lr_uni.ld_score.values
+            y = lr_uni.y.values
+            X = sm.add_constant(X)
+            model = sm.OLS(y, X)
+            results = model.fit()
+            ## predict on the same covariate the model was fit on (ld_score)
+            y_pred = results.params[0] + results.params[1] * lr_uni.ld_score.values
+            lr_uni['y'] = np.abs(y - y_pred)
+        elif label == 'residual-w-ld-ols':
+            lr_uni['y'] = (lr_uni['BETA'] / lr_uni['SE']).values**2
+            lr_uni['y'] = lr_uni.y.fillna(0)
+            import statsmodels.api as sm
+
+            X = lr_uni.w_ld_score.values
+            y = lr_uni.y.values
+            X = sm.add_constant(X)
+            model = sm.OLS(y, X)
+            results = model.fit()
+            y_pred = results.params[0] + results.params[1] * lr_uni.w_ld_score.values
+            lr_uni['y'] = y - y_pred
+
+        all_ids = lr_uni.ID.values
+        self.all_ids = np.array([self.id2idx['SNP'][i] for i in all_ids])
+        self.y = lr_uni.y.values
+
+        self.lr_uni = lr_uni
+
+    def prepare_split(self, test_set_fraction_data=0.05):
+
+        ## split SNPs into train/valid/test
+        train_val_ids, test_ids, y_train_val, y_test = train_test_split(self.all_ids, self.y, test_size=test_set_fraction_data, random_state=self.seed)
+        train_ids, val_ids, y_train, y_val = train_test_split(train_val_ids, y_train_val, test_size=0.05, random_state=self.seed)
+
+        self.train_input_nodes = ('SNP', train_ids)
+        self.val_input_nodes = ('SNP', val_ids)
+        self.test_input_nodes = ('SNP', test_ids)
+
+        ## -1 marks SNPs that belong to no split
+        y_snp = torch.zeros(self.data['SNP'].x.shape[0]) - 1
+        y_snp[train_ids] = torch.tensor(y_train).float()
+        y_snp[val_ids] = torch.tensor(y_val).float()
+        y_snp[test_ids] = torch.tensor(y_test).float()
+
+        self.data['SNP'].y = y_snp
+        for node_type in self.data.node_types:
+            self.data[node_type].n_id = torch.arange(self.data[node_type].x.shape[0])
+
+        self.data.train_mask = train_ids
+        self.data.val_mask = val_ids
+        self.data.test_mask = test_ids
+        self.data.all_mask = self.all_ids
+
+    def get_pheno_list(self):
+        return {"large_cohort": scdrs_traits,
+                "21_indep_traits": ['body_BALDING1',
+                                    'disease_ALLERGY_ECZEMA_DIAGNOSED',
+                                    'disease_HYPOTHYROIDISM_SELF_REP', 'pigment_SUNBURN',
+                                    '21001', '50', '30080', '30070', '30010', '30000',
+                                    'biochemistry_AlkalinePhosphatase',
+                                    'biochemistry_AspartateAminotransferase',
+                                    'biochemistry_Cholesterol', 'biochemistry_Creatinine',
+                                    'biochemistry_IGF1', 'biochemistry_Phosphate',
+                                    'biochemistry_Testosterone_Male', 'biochemistry_TotalBilirubin',
+                                    'biochemistry_TotalProtein', 'biochemistry_VitaminD',
+                                    'bmd_HEEL_TSCOREz']}
\ No newline at end of file