--- /dev/null
+++ b/shepherd/predict.py
@@ -0,0 +1,175 @@
+# General
+import numpy as np
+import pickle
+import argparse
+import os
+import sys
+import time
+import pandas as pd
+
+sys.path.insert(0, '..') # add project_config to path
+
+# Pytorch
+import torch
+
+# Pytorch lightning
+import pytorch_lightning as pl
+from pytorch_lightning.loggers import WandbLogger
+
+
+# Own code
+import project_config
+from shepherd.dataset import PatientDataset
+from shepherd.samplers import PatientNeighborSampler
+
+
+import preprocess
+from hparams import get_predict_hparams
+from train import get_model, load_patient_datasets, get_dataloaders
+
+# Force synchronous CUDA kernel launches so CUDA errors surface at the offending call
+os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
+
+
+'''
+Example Commands:
+
+python predict.py \
+--run_type causal_gene_discovery \
+--patient_data test_predict \
+--edgelist KG_edgelist_mask.txt \
+--node_map KG_node_map.txt \
+--saved_node_embeddings_path checkpoints/pretrain.ckpt \
+--best_ckpt checkpoints/causal_gene_discovery.ckpt 
+
+python predict.py \
+--run_type patients_like_me \
+--patient_data test_predict \
+--edgelist KG_edgelist_mask.txt \
+--node_map KG_node_map.txt \
+--saved_node_embeddings_path checkpoints/pretrain.ckpt \
+--best_ckpt checkpoints/patients_like_me.ckpt 
+
+python predict.py \
+--run_type disease_characterization \
+--patient_data test_predict \
+--edgelist KG_edgelist_mask.txt \
+--node_map KG_node_map.txt \
+--saved_node_embeddings_path checkpoints/pretrain.ckpt \
+--best_ckpt checkpoints/disease_characterization.ckpt 
+'''
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="Predict using SHEPHERD")
+    parser.add_argument("--edgelist", type=str, default=None, help="File with KG edge list")
+    parser.add_argument("--node_map", type=str, default=None, help="File with KG node map")
+    parser.add_argument("--patient_data", type=str, default="disease_simulated", help="Name of the patient dataset to run inference on")
+    parser.add_argument("--run_type", type=str, choices=["causal_gene_discovery", "disease_characterization", "patients_like_me"], help="Prediction task to run")
+    parser.add_argument("--saved_node_embeddings_path", type=str, default=None, help="Path to pretrained node embedding checkpoint")
+    parser.add_argument("--best_ckpt", type=str, default=None, help="Path to the best performing model checkpoint")
+    return parser.parse_args()
+
+
+@torch.no_grad()
+def predict(args):
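+    """Run inference for the requested task and save scores, phenotype attention, and embeddings to the results directory."""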
+    
+    # Hyperparameters
+    hparams = get_predict_hparams(args)
+
+    # Seed
+    pl.seed_everything(hparams['seed'])
+
+    # Read KG
+    all_data, edge_attr_dict, nodes = preprocess.preprocess_graph(args)
+    n_nodes = len(nodes["node_idx"].unique())
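+    # Indices of gene/protein, phenotype, and disease nodes, passed to the sampler as the relevant node set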
+    gene_phen_dis_node_idx = torch.LongTensor(nodes.loc[nodes['node_type'].isin(['gene/protein', 'effect/phenotype', 'disease']), 'node_idx'].values)
+
+
+    # Load precomputed shortest-path-length (SPL) information
+    print('Loading SPL...')
+    spl = np.load(project_config.PROJECT_DIR / 'patients' / hparams['spl'])
+    if (project_config.PROJECT_DIR / 'patients' / hparams['spl_index']).exists():
+        with open(str(project_config.PROJECT_DIR / 'patients' / hparams['spl_index']), "rb") as input_file:
+            spl_indexing_dict = pickle.load(input_file)
+    else:
+        spl_indexing_dict = None
+    print('Loaded SPL information')
+    
+    # Get patient dataset
+    dataset = PatientDataset(project_config.PROJECT_DIR / 'patients' / hparams['test_data'], time=hparams['time'])
+    print(f'There are {len(dataset)} patients in the test dataset')
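+    # Run inference over all patients in a single batch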
+    hparams.update({'inference_batch_size': len(dataset)})
+    print('batch size: ', hparams['inference_batch_size'])
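+
+    # Map gene/protein node indices to their positions in the SPL matrix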
+    nid_to_spl_dict = {nid: idx for idx, nid in enumerate(nodes[nodes["node_type"] == "gene/protein"]["node_idx"].tolist())}
+
+    # Get dataloader
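+    # sizes: per-hop neighbor sampling fan-out; -1 samples all neighbors at the first hop (PyG NeighborSampler convention)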
+    dataloader = PatientNeighborSampler('predict', all_data.edge_index, all_data.edge_index[:, all_data.test_mask],
+                                        sizes=[-1, 10, 5], patient_dataset=dataset, batch_size=hparams['inference_batch_size'], sparse_sample=0,
+                                        all_edge_attributes=all_data.edge_attr, n_nodes=n_nodes, relevant_node_idx=gene_phen_dis_node_idx,
+                                        n_cand_diseases=hparams['test_n_cand_diseases'], use_diseases=hparams['use_diseases'],
+                                        nid_to_spl_dict=nid_to_spl_dict, gp_spl=spl, spl_indexing_dict=spl_indexing_dict,
+                                        shuffle=False, num_workers=hparams['num_workers'],
+                                        hparams=hparams, pin_memory=hparams['pin_memory'])
+    
+    # Create Weights & Biases Logger
+    run_name = 'test'
+    wandb_logger = WandbLogger(name=run_name, project='rare_disease_dx_combined', entity='rare_disease_dx', save_dir=hparams['wandb_save_dir'],
+                    id="_".join(run_name.split(":")), resume="allow") 
+
+    # Get patient model
+    model = get_model(args, hparams, None, all_data, edge_attr_dict, n_nodes, load_from_checkpoint=True)
+
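+    # Note: `gpus` follows the pre-2.0 PyTorch Lightning Trainer API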
+    trainer = pl.Trainer(gpus=hparams['n_gpus'])
+    
+    t1 = time.time()
+    results = trainer.predict(model, dataloaders=dataloader)
+    t2 = time.time()
+    print(f"Predicting took {t2 - t1:0.4f} seconds", len(dataset), "patients")
+
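+    # trainer.predict returns one tuple of outputs per batch; zip(*...) transposes them into per-output tuples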
+    scores_dfs, attn_dfs, gat_attn_df_1, gat_attn_df_2, gat_attn_df_3, phenotype_embeddings, disease_embeddings = zip(*results)
+    
+    print('---- RESULTS ----')
+    results_dir = project_config.PROJECT_DIR / 'results'
+    results_dir.mkdir(exist_ok=True)
+    output_base = results_dir / str(args.best_ckpt).replace('/', '.').split('.ckpt')[0]
+
+    # Save scores
+    scores_df = pd.concat(scores_dfs).reset_index(drop=True)
+    scores_df.to_csv(str(output_base) + '_scores.csv', index=False)
+    print(scores_df)
+
+    # Save patient phenotype attention
+    attn_df = pd.concat(attn_dfs).reset_index(drop=True)
+    attn_df.to_csv(str(output_base) + '_phenotype_attn.csv', index=False)
+    print(attn_df)
+
+    # Save patient phenotype embeddings (inference runs as a single batch, so unwrap the 1-tuple)
+    if isinstance(phenotype_embeddings, tuple):
+        phenotype_embeddings = phenotype_embeddings[0]
+    torch.save(phenotype_embeddings, str(output_base) + '_phenotype_embeddings.pth')
+    print("Phenotype embeddings", phenotype_embeddings)
+
+    # Save disease embeddings (only produced for disease characterization)
+    if args.run_type == "disease_characterization":
+        if isinstance(disease_embeddings, tuple):
+            disease_embeddings = disease_embeddings[0]
+        torch.save(disease_embeddings, str(output_base) + '_disease_embeddings.pth')
+        print("Disease embeddings", disease_embeddings)
+
+
+if __name__ == "__main__":
+    
+    # Parse command-line arguments
+    args = parse_args()
+
+    # Perform prediction
+    predict(args)