--- a
+++ b/kgwas/kgwas.py
@@ -0,0 +1,273 @@
+from copy import deepcopy
+from tqdm import tqdm
+import os
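+# run CUDA kernels synchronously so errors surface at the failing call (debugging aid; slows GPU code)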
+os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
+import pandas as pd
+import numpy as np
+import pickle
+import subprocess
+
+import torch
+import torch.nn.functional as F
+import torch.optim as optim
+
+import torch.multiprocessing
+torch.multiprocessing.set_sharing_strategy('file_system')
+from torch_geometric.loader import NeighborLoader
+from .utils import print_sys, compute_metrics, save_dict, \
+                        load_dict, load_pretrained, save_model, \
+                        evaluate_minibatch_clean, process_data, \
+                        get_network_weight, generate_viz
+from .eval_utils import storey_ribshirani_integrate, get_clumps_gold_label, get_meta_clumps, \
+                        get_mega_clump_query, get_curve, find_closest_x
+from .model import HeteroGNN
+
+class KGWAS:
+    def __init__(self,
+                data,
+                weight_bias_track = False,
+                device = 'cuda',
+                proj_name = 'KGWAS',
+                exp_name = 'KGWAS',
+                seed = 42):
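+        """Seed RNGs, resolve the compute device, and optionally start W&B logging.
+
+        `data` is the KGWAS data object carrying the heterogeneous graph and the
+        GWAS summary statistics; `device` falls back to CPU when CUDA is unavailable.
+        """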
+        torch.manual_seed(seed)
+        torch.cuda.manual_seed(seed)
+        np.random.seed(seed)
+        self.seed = seed
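+        # cuDNN disabled, presumably for reproducibility alongside the fixed seeds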
+        torch.backends.cudnn.enabled = False
+        use_cuda = torch.cuda.is_available()
+        self.device = device if use_cuda else "cpu"
+
+        self.data = data
+        self.data_path = data.data_path
+        if weight_bias_track:
+            import wandb
+            wandb.init(project=proj_name, name=exp_name)
+            self.wandb = wandb
+        else:
+            self.wandb = False
+        self.exp_name = exp_name
+        
+        
+    def initialize_model(self, gnn_num_layers = 2, gnn_hidden_dim = 128, gnn_backbone = 'GAT', gnn_aggr = 'sum', gat_num_head = 1, no_relu = False):
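+        """Construct the heterogeneous GNN encoder.
+
+        The hyperparameters are kept in self.config so that save_model can
+        persist them and load_pretrained can rebuild the same architecture.
+        """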
+
+        self.config = {
+            'gnn_num_layers': gnn_num_layers,
+            'gnn_hidden_dim': gnn_hidden_dim,
+            'gnn_backbone': gnn_backbone,
+            'gnn_aggr': gnn_aggr,
+            'gat_num_head': gat_num_head,
+            'no_relu': no_relu
+        }
+
+        self.gnn_num_layers = gnn_num_layers
+        self.model = HeteroGNN(self.data.data, gnn_hidden_dim, 1, 
+                               gnn_num_layers, gnn_backbone, gnn_aggr,
+                               self.data.snp_init_dim_size,
+                               self.data.gene_init_dim_size,
+                               self.data.go_init_dim_size,
+                               gat_num_head,
+                               no_relu = no_relu,
+                              ).to(self.device)
+    
+    
+    def load_pretrained(self, path):
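+        """Restore a trained model and its saved predictions from a checkpoint directory."""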
+        with open(os.path.join(path, 'config.pkl'), 'rb') as f:
+            config = pickle.load(f)
+            
+        self.initialize_model(**config)
+        self.config = config
+        
+        self.model = load_pretrained(path, self.model)
+        self.best_model = self.model
+        self.kgwas_res = pd.read_csv(os.path.join(path, 'pred.csv'), sep = None, engine = 'python')
+        self.save_name = os.path.basename(os.path.normpath(path))
+
+    def train(self, batch_size = 512, num_workers = 0, lr = 1e-4, 
+                    weight_decay = 5e-4, epoch = 10, save_best_model = True, 
+                    save_name = None, data_to_cuda = False):
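+        """Train the GNN and produce calibrated per-SNP p-values.
+
+        Minimizes an LDSC-weighted MSE, early-stops on validation Pearson r,
+        scores every SNP with the best model, and writes the predictions plus
+        calibrated KGWAS p-values to disk.
+        """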
+        total_epoch = epoch
+        if save_name is None:
+            save_name = self.exp_name
+        self.save_name = save_name
+        print_sys('Creating data loader...')
+        kwargs = {'batch_size': batch_size, 'num_workers': num_workers, 'drop_last': True}
+        eval_kwargs = {'batch_size': 512, 'num_workers': num_workers, 'drop_last': False}
+        
+        if data_to_cuda:
+            self.data.data = self.data.data.to(self.device)
+        
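+        # num_neighbors=[-1] per GNN layer: take every neighbor at each hop (no sampling)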
+        self.train_loader = NeighborLoader(self.data.data, num_neighbors=[-1] * self.gnn_num_layers, 
+                                      sampler = None,
+                                      input_nodes=self.data.train_input_nodes, **kwargs)
+        self.val_loader = NeighborLoader(self.data.data, num_neighbors=[-1] * self.gnn_num_layers,
+                                    input_nodes=self.data.val_input_nodes, **eval_kwargs)  # don't drop the last val batch
+        self.test_loader = NeighborLoader(self.data.data, num_neighbors=[-1] * self.gnn_num_layers,
+                                    input_nodes=self.data.test_input_nodes, **eval_kwargs)
+        
+        X_infer = self.data.lr_uni.ID.values
+        #print_sys('# of to-infer SNPs: ' + str(len(X_infer)))
+        infer_idx = np.array([self.data.id2idx['SNP'][i] for i in X_infer])
+        infer_input_nodes = ('SNP', infer_idx)
+
+        self.infer_loader = NeighborLoader(self.data.data, num_neighbors=[-1] * self.gnn_num_layers,
+                                    input_nodes=infer_input_nodes, **eval_kwargs)
+        
+        ## model training
+        optimizer = optim.Adam(self.model.parameters(), lr=lr, weight_decay = weight_decay)
+
+        loss_fct = F.mse_loss
+        earlystop_validation_metric = 'pearsonr'
+        binary_output = False
+        earlystop_direction = 'ascend'
+        best_val = -1000  # best validation Pearson r seen so far
+
+        self.best_model = deepcopy(self.model).to(self.device)
+        print_sys('Start Training...')
+        for epoch in range(total_epoch):
+            self.model.train()
+
+            for step, batch in enumerate(tqdm(self.train_loader, desc=f"Training Progress Epoch {epoch+1}/{total_epoch}", total=len(self.train_loader))):
+                optimizer.zero_grad()
+                # when data_to_cuda is set, the full graph already lives on the GPU,
+                # so only CPU-resident batches need the transfer
+                if not data_to_cuda:
+                    batch = batch.to(self.device)
+                bs_batch = batch['SNP'].batch_size
+
+                out = self.model(batch.x_dict, batch.edge_index_dict, bs_batch)
+                pred = out.reshape(-1)
+
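+                # fetch each SNP's LD score regression weight by rsID for the weighted MSE below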
+                y_batch = batch['SNP'].y[:bs_batch]
+                rs_id = [self.data.idx2id['SNP'][i.item()] for i in batch['SNP']['n_id'][:bs_batch]]
+                ld_weight = torch.tensor([self.data.rs_id_to_ldsc_weight[i] for i in rs_id]).to(self.device)
+
+                loss = torch.mean(ld_weight * (pred - y_batch)**2)
+
+                if self.wandb:
+                    self.wandb.log({'training_loss': loss.item()})
+
+                loss.backward()
+                optimizer.step()
+
+                if (step % 500 == 0) and (step >= 500):
+                    log = "Epoch {} Step {} Train Loss: {:.4f}" 
+                    print_sys(log.format(epoch + 1, step + 1, loss.item()))
+
+            val_res = evaluate_minibatch_clean(self.val_loader, self.model, self.device)
+            val_metrics = compute_metrics(val_res, binary_output, -1, -1, loss_fct)
+
+
+            log = "Epoch {}: Validation MSE: {:.4f} " \
+                      "Validation Pearson: {:.4f}. "
+            print_sys(log.format(epoch + 1, val_metrics['mse'], 
+                         val_metrics['pearsonr']))
+
+            if self.wandb:
+                for i,j in val_metrics.items():
+                    self.wandb.log({'val_' + i: j})
+                
+            if val_metrics[earlystop_validation_metric] > best_val:
+                best_val = val_metrics[earlystop_validation_metric]
+                self.best_model = deepcopy(self.model)
+                best_epoch = epoch
+
+
+        if save_best_model:
+            save_model_path = self.data_path + '/model/'
+            print_sys('Saving models to ' + os.path.join(save_model_path, save_name))
+            save_model(self.best_model, self.config, os.path.join(save_model_path, save_name))
+
+
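+        # evaluate the early-stopped model on the held-out test split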
+        test_res = evaluate_minibatch_clean(self.test_loader, self.best_model, self.device)
+        test_metric = compute_metrics(test_res, binary_output, -1, -1, loss_fct)
+        if self.wandb:
+            for i,j in test_metric.items():
+                self.wandb.log({'test_' + i: j})    
+
+
+        infer_res = evaluate_minibatch_clean(self.infer_loader, self.best_model, self.device)
+
+        self.data.lr_uni['pred'] = infer_res['pred']
+        lr_uni_to_save = deepcopy(self.data.lr_uni)
+
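+        # convert |prediction| into weighted p-values; the helper's name suggests a
+        # Storey-Tibshirani-style procedure over 500 score bins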
+        self.data.lr_uni['abs_pred'] = np.abs(self.data.lr_uni['pred'])
+
+        self.data.lr_uni['SR_P_val'] = storey_ribshirani_integrate(self.data.lr_uni, column = 'abs_pred', num_bins = 500)
+        self.data.lr_uni['SR'] = -(np.log10(self.data.lr_uni['SR_P_val'].astype(float).values))
+        lr_uni_to_save['P_weighted'] = self.data.lr_uni['SR_P_val']
+
+        ## calibration
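+        # rescale the weighted p-values by a data-driven factor and clip them to [0, 1]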
+        scale_factor = find_closest_x(lr_uni_to_save)
+        lr_uni_to_save['KGWAS_P'] = scale_factor * lr_uni_to_save['P_weighted']
+        lr_uni_to_save['KGWAS_P'] = lr_uni_to_save['KGWAS_P'].clip(lower=0, upper=1)
+
+        os.makedirs(self.data_path + '/model_pred/new_experiments/', exist_ok=True)
+        lr_uni_to_save.to_csv(self.data_path + '/model_pred/new_experiments/' + save_name + '_pred.csv', index = False, sep = '\t')
+        print_sys('KGWAS prediction and p-values saved to ' + self.data_path + '/model_pred/new_experiments/' + save_name + '_pred.csv')
+        if save_best_model:
+            lr_uni_to_save.to_csv(self.data_path + '/model/' + save_name + '/pred.csv', index = False, sep = '\t')
+        self.kgwas_res = lr_uni_to_save
+
+    def run_magma(self, path_to_magma, bfile):
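+        """Run MAGMA gene-level analysis on the KGWAS p-values.
+
+        `path_to_magma` points at the MAGMA executable; `bfile` is the PLINK
+        binary fileset prefix MAGMA uses as its reference data.
+        """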
+        if 'N' in self.kgwas_res.columns:
+            n_value = self.kgwas_res['N'].values[0]
+        else:
+            n_value = input("Please provide the sample size (N) for the GWAS analysis: ")
+        
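+        # MAGMA gene annotation file (SNP-to-gene mapping) hosted on Harvard Dataverse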
+        url = "https://dataverse.harvard.edu/api/access/datafile/10731670"
+        annot_file_path = os.path.join(self.data_path, 'gene_annotation.genes.annot')
+
+        # Download the annotation file only if it is not already cached locally
+        if not os.path.exists(annot_file_path):
+            print('Annotation file not found locally. Downloading...')
+            self.data._download_with_progress(url, annot_file_path)
+            print('Annotation file downloaded successfully.')
+        else:
+            print('Annotation file already exists locally.')
+
+        gene_annot = annot_file_path
+
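+        # MAGMA's --pval input expects a whitespace-delimited file with SNP and P columns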
+        magma_path = self.data_path + '/model_pred/new_experiments/' + self.save_name + '_magma_format.csv'
+        self.kgwas_res[['ID', 'KGWAS_P']].rename(columns = {'ID': 'SNP', 'KGWAS_P': 'P'}).to_csv(magma_path, index = False, sep = '\t')
+
+        # Construct the MAGMA command
+        command = [
+            path_to_magma,
+            "--bfile", bfile,
+            "--gene-annot", gene_annot,
+            "--pval", magma_path, f"N={n_value}",
+            "--out", self.data_path + '/model_pred/new_experiments/' + self.save_name + '_magma_out'
+        ]
+        
+        try:
+            # Run the command with real-time output
+            process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
+            print("Running MAGMA...")
+
+            # Stream stdout line by line
+            for line in process.stdout:
+                print(line, end="")  # Print each line as it's received
+
+            # Wait for the process to complete and capture stderr
+            stderr = process.communicate()[1]
+
+            if process.returncode == 0:
+                print("MAGMA command executed successfully.")
+            else:
+                print("MAGMA encountered an error.")
+                print("Error message:", stderr)
+        except FileNotFoundError:
+            print("MAGMA executable not found. Ensure it is in the specified path.")
+        except Exception as e:
+            print(f"An unexpected error occurred: {e}")
+
+
+    def get_disease_critical_network(self, variant_threshold = 5e-8, 
+                magma_path = None, magma_threshold = 0.05, program_threshold = 0.05,
+                K_neighbors = 3, num_cpus = 1):
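+        """Extract the disease-critical network around significant variants.
+
+        Returns the learned KG edge weights, a per-variant interpretation table,
+        and the disease-critical subnetwork assembled by generate_viz.
+        """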
+        df_network_weight = get_network_weight(self, self.data)
+        df_variant_interpretation, disease_critical_network = generate_viz(
+            self, df_network_weight, self.data_path, variant_threshold, magma_path,
+            magma_threshold, program_threshold, K_neighbors, num_cpus)
+        return df_network_weight, df_variant_interpretation, disease_critical_network
\ No newline at end of file