--- a
+++ b/gwas/classic_gwas_eval.py
@@ -0,0 +1,244 @@
+import argparse
+from copy import deepcopy
+from tqdm import tqdm
+import os
+import subprocess
+import sys
+sys.path.append('../')
+import pandas as pd
+import numpy as np
+from sklearn.model_selection import train_test_split
+from sklearn.metrics import precision_score
+
+from kgwas.data import ukbb_cohort
+from kgwas.utils import print_sys, evaluate, compute_metrics, save_dict, \
+                        load_dict, get_args, load_pretrained, save_model, \
+                        get_gwas_results
+from kgwas.params import main_data_path, cohort_data_path, kinship_path, withdraw_path, gwas_result_path
+
+parser = argparse.ArgumentParser()
+parser.add_argument('--model', type=str, default='plink', choices = ['plink', 'fastgwa_full', 'fastgwa_match', 'gold_label', 'fastgwa_gold'])
+parser.add_argument('--threshold', type=float, default=5e-8)
+
+parser.add_argument('--data_split_seed', type=int, default=42)
+parser.add_argument('--wandb', action="store_true", default=False)
+parser.add_argument('--normalize_y', type=str, default='None', choices = ['quantile_normalization', 'std', 'log', 'None'])
+parser.add_argument('--covar_file', type=str, default='real_pca15', choices = ['full', 'old', 'real_value', 'real_pca15'])
+
+parser.add_argument('--test_set_fraction', type=float, default=0.7)
+parser.add_argument('--pheno', type=str, default='50') 
+parser.add_argument('--icd10', action="store_true", default=False)
+parser.add_argument('--external_cohort', action="store_true", default=False)
+parser.add_argument('--randomize', action="store_true", default=False)
+parser.add_argument('--use_sample_size', action="store_true", default=False)
+parser.add_argument('--sample_size', type=int, default=-1)
+parser.add_argument('--randomize_seed', type=int, default=42)
+parser.add_argument('--non_redundant_traits', action="store_true", default=False)
+
+parser.add_argument('--simulate', action="store_true", default=False)
+parser.add_argument('--num_causal_hits', type=int, default=1000)
+parser.add_argument('--heritability', type=float, default=0.1)
+parser.add_argument('--simulate_seed', type=int, default=42)
+parser.add_argument('--simulate_graph', action="store_true", default=False)
+parser.add_argument('--simulate_graph_func', action="store_true", default=False)
+parser.add_argument('--small_cohort', type=int, default=-1)
+parser.add_argument('--null_simulation', action="store_true", default=False)
+parser.add_argument('--network_plant', action="store_true", default=False)
+
+
+args = parser.parse_args()
+print(args)
+
+if args.icd10:
+    binary = True
+else:
+    if args.non_redundant_traits and (args.pheno in ['body_BALDING1', 'cancer_BREAST', 'disease_ALLERGY_ECZEMA_DIAGNOSED', 'disease_HYPOTHYROIDISM_SELF_REP', 'other_MORNINGPERSON', 'pigment_SUNBURN']):
+        print('Using binary traits...')
+        binary = True
+    else:
+        print('Using continuous traits...')
+        binary = False    
+    
+if args.simulate:
+    if args.simulate_graph:
+        name = str(args.num_causal_hits) + '_' + str(args.simulate_seed) + '_' + str(args.heritability) + '_graph_v2'
+        pheno_path = '/dfs/project/datasets/20220524-ukbiobank/data/genetics/simulate_data/result/causal_snp_' + str(args.num_causal_hits) + '_' + str(args.simulate_seed) + '_' + str(args.heritability) + '_graph_v2.phen'
+    elif args.network_plant:
+        if args.small_cohort != -1:
+            name = str(args.num_causal_hits) + '_' + str(args.simulate_seed) + '_' + str(args.heritability)+ '_' + str(args.small_cohort) + '_graph_funct_v2_ggi'
+            pheno_path = '/dfs/project/datasets/20220524-ukbiobank/data/genetics/simulate_data/result/causal_snp_' + str(args.num_causal_hits) + '_' + str(args.simulate_seed) + '_' + str(args.heritability) + '_' + str(args.small_cohort) + '_graph_funct_v2_ggi.phen'
+            if not os.path.exists(pheno_path):
+                print('creating pheno_path at: ' + pheno_path)
+                pd.read_csv('/dfs/project/datasets/20220524-ukbiobank/data/genetics/simulate_data/result/causal_snp_' + str(args.num_causal_hits) + '_' + str(args.simulate_seed) + '_' + str(args.heritability) + '_graph_funct_v2_ggi.phen', header = None).sample(n = args.small_cohort, random_state = 42).to_csv(pheno_path, index = False, header = None)
+                
+        else:
+            name = str(args.num_causal_hits) + '_' + str(args.simulate_seed) + '_' + str(args.heritability)+ '_graph_funct_v2_ggi'
+            pheno_path = '/dfs/project/datasets/20220524-ukbiobank/data/genetics/simulate_data/result/causal_snp_' + str(args.num_causal_hits) + '_' + str(args.simulate_seed) + '_' + str(args.heritability) + '_graph_funct_v2_ggi.phen'
+    elif args.simulate_graph_func:
+        if args.small_cohort != -1:
+            pheno_path = '/dfs/project/datasets/20220524-ukbiobank/data/genetics/simulate_data/result/causal_snp_' + str(args.num_causal_hits) + '_' + str(args.simulate_seed) + '_' + str(args.heritability) + '_' + str(args.small_cohort) + '_graph_funct_v2.phen'
+            if not os.path.exists(pheno_path):
+                print('creating pheno_path at: ' + pheno_path)
+                pd.read_csv('/dfs/project/datasets/20220524-ukbiobank/data/genetics/simulate_data/result/causal_snp_' + str(args.num_causal_hits) + '_' + str(args.simulate_seed) + '_' + str(args.heritability) + '_graph_funct_v2.phen', header = None).sample(n = args.small_cohort, random_state = 42).to_csv(pheno_path, index = False, header = None)
+
+            name = str(args.num_causal_hits) + '_' + str(args.simulate_seed) + '_' + str(args.heritability) + '_' + str(args.small_cohort)+ '_graph_funct_v2'                
+        else:
+            name = str(args.num_causal_hits) + '_' + str(args.simulate_seed) + '_' + str(args.heritability)+ '_graph_funct_v2'
+            pheno_path = '/dfs/project/datasets/20220524-ukbiobank/data/genetics/simulate_data/result/causal_snp_' + str(args.num_causal_hits) + '_' + str(args.simulate_seed) + '_' + str(args.heritability) + '_graph_funct_v2.phen'
+    else:
+        if args.small_cohort != -1:
+            if args.null_simulation:
+                pheno_path = '/dfs/project/datasets/20220524-ukbiobank/data/genetics/simulate_data/result/causal_snp_' + str(args.simulate_seed) + '_0_' + str(args.small_cohort) + '_null.phen'
+                name = str(args.simulate_seed) + '_0_' + str(args.small_cohort) + '_null'
+            else:
+                pheno_path = '/dfs/project/datasets/20220524-ukbiobank/data/genetics/simulate_data/result/causal_snp_' + str(args.num_causal_hits) + '_' + str(args.simulate_seed) + '_' + str(args.heritability) + '_' + str(args.small_cohort) + '.phen'
+                if not os.path.exists(pheno_path):
+                    print('creating pheno_path at: ' + pheno_path)
+                    pd.read_csv('/dfs/project/datasets/20220524-ukbiobank/data/genetics/simulate_data/result/causal_snp_' + str(args.num_causal_hits) + '_' + str(args.simulate_seed) + '_' + str(args.heritability) + '.phen', header = None).sample(n = args.small_cohort, random_state = 42).to_csv(pheno_path, index = False, header = None)
+                    
+                name = str(args.num_causal_hits) + '_' + str(args.simulate_seed) + '_' + str(args.heritability) + '_' + str(args.small_cohort)
+        else:
+            if args.null_simulation:
+                pheno_path = '/dfs/project/datasets/20220524-ukbiobank/data/genetics/simulate_data/result/causal_snp_' + str(args.simulate_seed) + '_0_null.phen'
+                name = str(args.simulate_seed) + '_0_null'
+            else:
+                pheno_path = '/dfs/project/datasets/20220524-ukbiobank/data/genetics/simulate_data/result/causal_snp_' + str(args.num_causal_hits) + '_' + str(args.simulate_seed) + '_' + str(args.heritability) + '.phen'
+                name = str(args.num_causal_hits) + '_' + str(args.simulate_seed) + '_' + str(args.heritability)
+        
+    
+    if (args.null_simulation) or ((args.small_cohort != -1) and (args.small_cohort <= 3000)):
+        gwas_result_path_file = os.path.join(gwas_result_path, name + '.PHENO1.glm.linear')
+        if os.path.exists(gwas_result_path_file):
+            print_sys('local file detected... not running again...')
+        else:
+            print_sys('Running GWAS on PLINK... Takes > 1 hour...')
+            cmd = "bash ./plink_python_interface_gwas.sh " + pheno_path + ' ' + name
+            p = subprocess.Popen(cmd, shell = True, stdout=subprocess.PIPE, bufsize=1)
+            for line in iter(p.stdout.readline, b''):
+                print_sys(line)
+            p.stdout.close()
+            p.wait()
+            
+    else:
+        gwas_result_path_file = os.path.join(gwas_result_path, name + '.fastGWA')
+        print(gwas_result_path_file)
+    
+        if os.path.exists(gwas_result_path_file):
+            print_sys('local file detected... not running again...')
+        else:
+            print_sys('Running GWAS on fastGWA... Takes > 1 hour...')
+            cmd = "bash ./fastgwa_python_interface.sh " + pheno_path + ' ' + name
+            p = subprocess.Popen(cmd, shell = True, stdout=subprocess.PIPE, bufsize=1)
+            for line in iter(p.stdout.readline, b''):
+                print_sys(line)
+            p.stdout.close()
+            p.wait()
+            
+else:    
+    if args.model not in ['gold_label', 'fastgwa_gold']:
+        if args.use_sample_size:
+            name = str(args.pheno) + '_' + args.model + '_' + str(args.sample_size) + '_' + str(args.data_split_seed)
+        else:
+            name = str(args.pheno) + '_' + args.model + '_' + str(args.test_set_fraction) + '_' + str(args.data_split_seed)
+        if args.threshold != 5e-8:
+            name += ('_' + str(args.threshold))
+    elif args.model == 'fastgwa_gold':
+        name = str(args.pheno) + '_with_rel_fastgwa'
+        args.test_set_fraction = 1
+    else:
+        name = str(args.pheno) + '_no_rel'
+        args.test_set_fraction = 1
+
+    if args.external_cohort:
+        name += '_sep_cohort'
+
+    if args.randomize:
+        name += '_randomize' + str(args.randomize_seed)
+
+    if args.wandb:
+        import wandb
+        wandb.init(project='KGWAS', name=name)
+        wandb.config.update(args)
+
+
+    if args.model in ['fastgwa_full', 'fastgwa_match', 'fastgwa_gold']:
+        keep_relatives = True
+    else:
+        keep_relatives = False
+
+    if args.model =='fastgwa_match':
+        fastgwa_match = True
+    else:
+        fastgwa_match = False
+
+    ukbb_data = ukbb_cohort(main_data_path, cohort_data_path, withdraw_path, keep_relatives=keep_relatives)
+    
+    if (not args.icd10) and (not args.non_redundant_traits):
+        args.pheno = int(args.pheno)
+    
+    if args.non_redundant_traits:
+        pheno = ukbb_data.get_external_traits(args.pheno, to_plink = True, random_seed = args.data_split_seed, sep_cohort = args.external_cohort, randomize = args.randomize, use_sample_size = args.use_sample_size, sample_size = args.sample_size, randomize_seed = args.randomize_seed)
+    else:
+        pheno = ukbb_data.get_phenotype(args.pheno, to_plink = True, normalize = args.normalize_y, frac = args.test_set_fraction, random_seed = args.data_split_seed, fastgwa_match = fastgwa_match, icd10 = args.icd10, sep_cohort = args.external_cohort, randomize = args.randomize, use_sample_size = args.use_sample_size, sample_size = args.sample_size, randomize_seed = args.randomize_seed)
+ 
+    if args.model == 'plink':
+        if args.use_sample_size:
+            pheno_path = os.path.join(cohort_data_path, str(args.pheno) + '_plink_no_relatives_' + str(args.sample_size) + '_' + str(args.data_split_seed))
+        else:
+            pheno_path = os.path.join(cohort_data_path, str(args.pheno) + '_plink_no_relatives_' + str(args.test_set_fraction) + '_' + str(args.data_split_seed))
+    elif args.model == 'fastgwa_full':
+        if args.use_sample_size:
+            pheno_path = os.path.join(cohort_data_path, str(args.pheno) + '_plink_with_relatives_' + str(args.sample_size) + '_' + str(args.data_split_seed))
+        else:
+            pheno_path = os.path.join(cohort_data_path, str(args.pheno) + '_plink_with_relatives_' + str(args.test_set_fraction) + '_' + str(args.data_split_seed))
+    elif args.model == 'fastgwa_match':
+        pheno_path = os.path.join(cohort_data_path, str(args.pheno) + '_plink_with_relatives_' + str(args.test_set_fraction) + '_' + str(args.data_split_seed) + '_match')
+    elif args.model == 'gold_label':
+        pheno_path = os.path.join(cohort_data_path, str(args.pheno) + '_plink_no_relatives')
+    elif args.model == 'fastgwa_gold':
+        pheno_path = os.path.join(cohort_data_path, str(args.pheno) + '_plink_with_relatives')
+
+    if args.external_cohort:
+        pheno_path += '_sep_cohort'
+
+    if args.randomize:
+        pheno_path += '_randomize' + str(args.randomize_seed)
+
+    pheno_path += '.txt'
+
+## feed to PLINK/FastGWA to run GWAS on K% of data
+
+    if args.model in ['plink', 'gold_label']:
+        if binary:
+            gwas_result_path_file = os.path.join(gwas_result_path, name + '.PHENO1.glm.logistic.hybrid')
+        else:
+            gwas_result_path_file = os.path.join(gwas_result_path, name + '.PHENO1.glm.linear')
+        if os.path.exists(gwas_result_path_file):
+            print_sys('local file detected... not running again...')
+        else:
+            print_sys('Running GWAS on PLINK... Takes > 1 hour...')
+            cmd = "bash ./plink_python_interface_gwas.sh " + pheno_path + ' ' + name
+            p = subprocess.Popen(cmd, shell = True, stdout=subprocess.PIPE, bufsize=1)
+            for line in iter(p.stdout.readline, b''):
+                print_sys(line)
+            p.stdout.close()
+            p.wait()
+
+    elif args.model in ['fastgwa_full', 'fastgwa_match', 'fastgwa_gold']:
+        gwas_result_path_file = os.path.join(gwas_result_path, name + '.fastGWA')
+        if os.path.exists(gwas_result_path_file):
+            print_sys('local file detected... not running again...')
+        else:
+            print_sys('Running GWAS on fastGWA... Takes > 1 hour...')
+            cmd = "bash ./fastgwa_python_interface.sh " + pheno_path + ' ' + name
+            p = subprocess.Popen(cmd, shell = True, stdout=subprocess.PIPE, bufsize=1)
+            for line in iter(p.stdout.readline, b''):
+                print_sys(line)
+            p.stdout.close()
+            p.wait()
+        
+        
+    df_gwas = pd.read_csv(gwas_result_path_file, sep = '\t')
+    if args.model in ['plink', 'gold_label']:
+        df_gwas['SNP'] = df_gwas['ID']
+    df_gwas.to_csv(gwas_result_path_file, index = False, sep = '\t')
\ No newline at end of file