--- /dev/null
+++ b/shepherd/patient_nca_model.py
@@ -0,0 +1,480 @@
+# PyTorch Lightning
+import pytorch_lightning as pl
+
+# torch
+import torch
+import numpy as np
+from scipy.stats import rankdata
+import pandas as pd
+
+import time
+import wandb
+import sys
+from pathlib import Path
+
+# Our code
+from node_embedder_model import NodeEmbeder
+from task_heads.patient_nca import PatientNCA
+from utils.pretrain_utils import get_edges, calc_metrics
+from utils.train_utils import mean_reciprocal_rank, top_k_acc
+from utils.train_utils import fit_umap, plot_softmax
+
+sys.path.insert(0, '..') # add project_config to path
+import project_config
+
+class CombinedPatientNCA(pl.LightningModule):
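+    """End-to-end model coupling a pretrained knowledge-graph node embedder with a
+    patient NCA head.
+
+    The node embedder is restored from `node_ckpt` and fine-tuned jointly with the
+    patient head. `hparams['loss']` selects the objective: 'patient_disease_NCA'
+    ranks candidate diseases for each patient; otherwise patients are ranked
+    against each other ("patients-like-me").
+    """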
+
+    def __init__(self, edge_attr_dict, all_data, n_nodes=None, node_ckpt=None, hparams=None):
+        super().__init__()
+        self.save_hyperparameters('hparams') 
+
+        self.all_data = all_data
+
+        self.all_train_nodes = []
+        self.train_patient_nodes = []
+
+        print(f"Loading Node Embedder from {node_ckpt}")
+        
+        # NOTE: loads in saved hyperparameters
+        self.node_model = NodeEmbeder.load_from_checkpoint(checkpoint_path=node_ckpt,
+                                                           all_data=all_data,
+                                                           edge_attr_dict=edge_attr_dict, 
+                                                           num_nodes=n_nodes)
+    
+        # NOTE: this will only work with GATv2Conv
+        self.patient_model = PatientNCA(hparams, embed_dim=self.node_model.hparams.hp_dict['output']*self.node_model.hparams.hp_dict['n_heads'])
+
+
+    def forward(self, batch):
+        # Node Embedder
+        t0 = time.time()
+        outputs, gat_attn = self.node_model.forward(batch.n_id, batch.adjs)
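+        # prepend a zero row so that padding index 0 selects an all-zero embedding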
+        pad_outputs = torch.cat([torch.zeros(1, outputs.size(1), device=outputs.device), outputs]) 
+        t1 = time.time()
+
+        # get masks
+        phenotype_mask = (batch.batch_pheno_nid != 0)
+        if self.hparams.hparams['loss'] == 'patient_disease_NCA': disease_mask = (batch.batch_cand_disease_nid != 0)
+        else: disease_mask = None
+
+        # index into outputs using phenotype & disease batch node idx
+        batch_sz, max_n_phen = batch.batch_pheno_nid.shape
+        phenotype_embeddings = torch.index_select(pad_outputs, 0, batch.batch_pheno_nid.view(-1)).view(batch_sz, max_n_phen, -1)
+        if self.hparams.hparams['loss'] == 'patient_disease_NCA': 
+            batch_sz, max_n_dx = batch.batch_cand_disease_nid.shape
+            disease_embeddings = torch.index_select(pad_outputs, 0, batch.batch_cand_disease_nid.view(-1)).view(batch_sz, max_n_dx, -1)
+        else: disease_embeddings = None
+
+        t2 = time.time()
+        phenotype_embedding, disease_embeddings, phenotype_mask, disease_mask, attn_weights = self.patient_model.forward(phenotype_embeddings, disease_embeddings, phenotype_mask, disease_mask)
+        t3 = time.time()
+
+        if self.hparams.hparams['time']:
+            print(f'It takes {t1-t0:0.4f}s for the node model, {t2-t1:0.4f}s for indexing into the output, and {t3-t2:0.4f}s for the patient model forward.')
+        
+        return outputs, gat_attn, phenotype_embedding, disease_embeddings, phenotype_mask, disease_mask, attn_weights
+
+    def rank_diseases(self, disease_softmax, disease_mask, labels):
+        # collapse the per-disease mask to a per-patient flag: (batch, n_diseases) -> (batch, 1)
+        disease_mask = (disease_mask.sum(dim=1) > 0).unsqueeze(-1)
+        # push rows without any real candidate disease to the bottom of the ranking
+        disease_softmax = disease_softmax + (~disease_mask * -100000)
+        # rankdata on the negated scores assigns rank 1 to the highest similarity; ties share their average rank
+        disease_ranks = torch.tensor(np.apply_along_axis(lambda row: rankdata(row * -1, method='average'), axis=1, arr=disease_softmax.detach().cpu().numpy()))
+        if labels is None:
+            correct_disease_ranks = None
+        else:
+            disease_ranks = disease_ranks.to(labels.device)
+            correct_disease_ranks = [ranks[lab == 1] for ranks, lab in zip(disease_ranks, labels)]
+
+        return correct_disease_ranks
+
+    def rank_patients(self, patient_softmax, labels):
+        patient_ranks = torch.tensor(np.apply_along_axis(lambda row: rankdata(row * -1, method='average'), axis=1, arr=patient_softmax.detach().cpu().numpy()))
+        if labels is None:
+            correct_patient_ranks = None
+        else:
+            # a patient is never a positive match for itself: zero out the diagonal of the label matrix
+            labels = labels * ~torch.eye(labels.shape[0], dtype=torch.bool, device=labels.device)
+            patient_ranks = patient_ranks.to(labels.device)
+            correct_patient_ranks = [ranks[lab == 1] for ranks, lab in zip(patient_ranks, labels)]
+
+        return correct_patient_ranks, labels
+    
+    def _step(self, batch, step_type):
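+        # Shared train/val/test logic: embed the batch subgraph, score the patient
+        # against its candidates, and (outside of testing) also compute the node
+        # embedder's link-prediction loss and metrics.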
+        t0 = time.time()
+        if step_type != 'test':
+            batch = get_edges(batch, self.all_data, step_type)
+        t1 = time.time()
+
+        # forward pass
+        node_embeddings, gat_attn, phenotype_embedding, disease_embeddings, phenotype_mask, disease_mask, attn_weights = self.forward(batch)
+
+
+        # calculate patient embedding loss
+        use_candidate_list = self.hparams.hparams['only_hard_distractors']
+        if self.hparams.hparams['loss'] == 'patient_disease_NCA': labels = batch.disease_one_hot_labels
+        else: labels = batch.patient_labels
+        loss, softmax, labels, candidate_disease_idx, candidate_disease_embeddings = self.patient_model.calc_loss(batch, phenotype_embedding, disease_embeddings, disease_mask, labels, use_candidate_list)
+        if self.hparams.hparams['loss'] == 'patient_disease_NCA': correct_ranks = self.rank_diseases(softmax, disease_mask, labels)
+        else: correct_ranks, labels = self.rank_patients(softmax, labels)
+
+        # calculate node embedding loss
+        if step_type == 'test':
+            node_embedder_loss = 0
+            roc_score, ap_score, acc, f1 = 0, 0, 0, 0
+        else:
+            # Get link predictions
+            batch, raw_pred, pred = self.node_model.get_predictions(batch, node_embeddings)
+            link_labels = self.node_model.get_link_labels(batch.all_edge_types)
+            node_embedder_loss = self.node_model.calc_loss(pred, link_labels)
+
+            # Calculate metrics
+            metric_pred = torch.sigmoid(raw_pred)
+            roc_score, ap_score, acc, f1 = calc_metrics(metric_pred.cpu().detach().numpy(), link_labels.cpu().detach().numpy())
+
+        # Log histograms of the patient model's parameters (note: state_dict holds
+        # parameter values, not gradients, despite the metric name)
+        if self.hparams.hparams['plot_gradients']:
+            for k, v in self.patient_model.state_dict().items():
+                self.logger.experiment.log({f'gradients/{step_type}.gradients.{k}': wandb.Histogram(v.detach().cpu())})
+
+        return correct_ranks, softmax, labels, node_embedder_loss, loss, roc_score, ap_score, acc, f1, gat_attn, node_embeddings, phenotype_embedding, disease_embeddings, phenotype_mask, disease_mask, attn_weights, candidate_disease_idx, candidate_disease_embeddings
+
+    def training_step(self, batch, batch_idx):
+        correct_ranks, softmax, labels, node_embedder_loss, patient_loss, roc_score, ap_score, acc, f1, gat_attn, node_embeddings, phenotype_embedding, disease_embeddings, phenotype_mask, disease_mask, attn_weights, cand_disease_idx, cand_disease_embeddings = self._step(batch, 'train')
+
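+        # total loss: convex combination of the link-prediction and patient NCA
+        # objectives, weighted by hparams['lambda']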
+        loss = (self.hparams.hparams['lambda'] * node_embedder_loss) + ((1 - self.hparams.hparams['lambda']) * patient_loss)
+        self.log('train_loss/overall_loss', loss, prog_bar=True, on_epoch=True)
+        self.log('train_loss/patient_loss', patient_loss, prog_bar=True, on_epoch=True)
+        self.log('train_loss/node_embedder_loss', node_embedder_loss, prog_bar=True, on_epoch=True)
+
+        batch_results = {'loss': loss, 
+                "train/node.roc": roc_score, 
+                "train/node.ap": ap_score, "train/node.acc": acc, "train/node.f1": f1, 
+                'train/node.embed': node_embeddings.detach().cpu(), 
+                'train/patient.phenotype_embed': phenotype_embedding.detach().cpu(), 
+                'train/attention_weights': attn_weights.detach().cpu(),
+                'train/phenotype_names_degrees': batch.phenotype_names,
+                'train/correct_ranks': correct_ranks,
+                'train/disease_names':  batch.disease_names,
+                'train/corr_gene_names': batch.corr_gene_names,
+                "train/softmax": softmax.detach().cpu(),         
+                }
+
+        if self.hparams.hparams['loss'] == 'patient_disease_NCA':
+            batch_sz, n_diseases, embed_dim = disease_embeddings.shape
+            batch_disease_nid_reshaped = batch.batch_disease_nid.view(-1)
+            batch_results.update({
+                'train/batch_disease_nid': batch_disease_nid_reshaped.detach().cpu(),
+                'train/cand_disease_names': batch.cand_disease_names,
+                'train/batch_cand_disease_nid': cand_disease_idx.detach().cpu(),
+                'train/patient.disease_embed': cand_disease_embeddings.detach().cpu()
+            })
+
+        return batch_results
+
+    def validation_step(self, batch, batch_idx):
+        correct_ranks, softmax, labels, node_embedder_loss, patient_loss, roc_score, ap_score, acc, f1, gat_attn, node_embeddings, phenotype_embedding, disease_embeddings, phenotype_mask, disease_mask, attn_weights, cand_disease_idx, cand_disease_embeddings = self._step(batch, 'val')
+        loss = (self.hparams.hparams['lambda'] * node_embedder_loss) + ((1 - self.hparams.hparams['lambda']) * patient_loss)
+        self.log('val_loss/overall_loss', loss, prog_bar=True, on_epoch=True)
+        self.log('val_loss/patient_loss', patient_loss, prog_bar=True)
+        self.log('val_loss/node_embedder_loss', node_embedder_loss, prog_bar=True)
+
+        batch_results = {"loss/val_loss": loss, 
+                         "val/node.roc": roc_score, 
+                         "val/node.ap": ap_score, "val/node.acc": acc, 
+                         "val/node.f1": f1, 
+                         'val/node.embed': node_embeddings.detach().cpu(), 
+                         'val/patient.phenotype_embed': phenotype_embedding.detach().cpu(), 
+                         'val/attention_weights': attn_weights.detach().cpu(),
+                         'val/phenotype_names_degrees': batch.phenotype_names,
+                         'val/correct_ranks': correct_ranks, 
+                         'val/disease_names':  batch.disease_names,
+                         'val/corr_gene_names': batch.corr_gene_names,
+                         "val/softmax": softmax.detach().cpu(),
+                        }
+
+        if self.hparams.hparams['loss'] == 'patient_disease_NCA':
+            batch_sz, n_diseases, embed_dim = disease_embeddings.shape
+            batch_disease_nid_reshaped = batch.batch_disease_nid.view(-1)
+            batch_results.update({'val/batch_disease_nid': batch_disease_nid_reshaped.detach().cpu(),
+                                  'val/cand_disease_names': batch.cand_disease_names,
+                                  'val/batch_cand_disease_nid': cand_disease_idx.detach().cpu(),
+                                  'val/patient.disease_embed': cand_disease_embeddings.detach().cpu()
+                                })
+        return batch_results 
+
+    def write_results_to_file(self, batch, softmax, correct_ranks, labels, phenotype_mask, disease_mask, attn_weights,  gat_attn, node_embeddings, phenotype_embeddings, disease_embeddings, save=True, loop_type='predict'):
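+        # Writes similarity scores (scores.csv), per-phenotype attention weights
+        # (phenotype_attention.csv), per-layer GAT edge attention
+        # (gat_attn_layer=*.csv), and the node/phenotype/disease embeddings (.pth)
+        # to the run folder.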
+        
+        if save:
+            run_folder = Path(project_config.PROJECT_DIR) / 'checkpoints' / 'patient_NCA' / self.hparams.hparams['run_name'] / Path(self.test_dataloader.dataloader.dataset.filepath).stem
+            run_folder.mkdir(parents=True, exist_ok=True)
+            print('run_folder', run_folder)
+
+        # Save scores
+        if self.hparams.hparams['loss'] == 'patient_disease_NCA':
+            cand_disease_names = [d for d_list in batch['cand_disease_names'] for d in d_list]
+
+            all_sims, all_diseases, all_patient_ids = [], [], []
+            for patient_id, sims in zip(batch['patient_ids'], softmax):
+                sims = sims.tolist() 
+                all_sims.extend(sims)
+                all_diseases.extend(cand_disease_names)
+                all_patient_ids.extend([patient_id] * len(sims))
+            results_df = pd.DataFrame({'patient_id': all_patient_ids, 'diseases': all_diseases, 'similarities': all_sims})
+        else:
+            all_sims, all_cand_pats, all_patient_ids = [], [], []
+            for patient_id, sims in zip(batch['patient_ids'], softmax):
+                patient_mask = torch.Tensor([p_id != patient_id for p_id in batch['patient_ids']]).bool()
+                remaining_pats = [p_id for p_id in batch['patient_ids'] if p_id != patient_id]
+                all_sims.extend(sims[patient_mask].tolist())
+                all_cand_pats.extend(remaining_pats)
+                all_patient_ids.extend([patient_id] * len(remaining_pats))
+            results_df = pd.DataFrame({'patient_id': all_patient_ids, 'candidate_patients': all_cand_pats, 'similarities': all_sims})
+        print(results_df.head())
+        if save:
+            print('logging results to run dir: ', run_folder)
+            results_df.to_csv(Path(run_folder) / 'scores.csv', sep=',', index=False)
+
+        # Save phenotype information
+        if attn_weights is None:
+            phen_df = None
+        else:
+            all_patient_ids, all_phens, all_attn_weights, all_degrees = [], [], [], []
+            for patient_id, attn_w, phen_names, p_mask in zip(batch['patient_ids'], attn_weights, batch['phenotype_names'], phenotype_mask):
+                p_names, degrees = zip(*phen_names)
+                all_patient_ids.extend([patient_id] * len(phen_names))
+                all_degrees.extend(degrees)
+                all_phens.extend(p_names)
+                all_attn_weights.extend(attn_w[p_mask].tolist())
+            phen_df = pd.DataFrame({'patient_id': all_patient_ids, 'phenotypes': all_phens, 'degrees': all_degrees, 'attention':all_attn_weights})
+            print(phen_df.head())
+            if save:
+                phen_df.to_csv(Path(run_folder) / 'phenotype_attention.csv', sep=',', index=False)
+        
+        # Save GAT attention weights
+        # NOTE: assumes a 3-layer GAT model; one csv is written per layer
+        attn_dfs = []
+        for layer, edge_attn in enumerate(gat_attn):
+            edge_index, attn = edge_attn
+            edge_index = edge_index.cpu()
+            attn = attn.cpu()
+            gat_attn_df = pd.DataFrame({'source': edge_index[0, :], 'target': edge_index[1, :]})
+            for head in range(attn.shape[1]):
+                gat_attn_df[f'attn_{head}'] = attn[:, head]
+            attn_dfs.append(gat_attn_df)
+            print(f'gat_attn_df, layer={layer}', gat_attn_df.head())
+            if save:
+                gat_attn_df.to_csv(Path(run_folder) / f'gat_attn_layer={layer}.csv', sep=',', index=False)
+
+        # Save embeddings
+        if save:
+            torch.save(batch["n_id"].cpu(), Path(run_folder) /'node_embeddings_idx.pth')
+            torch.save(node_embeddings.cpu(), Path(run_folder) /'node_embeddings.pth')
+            torch.save(phenotype_embeddings.cpu(), Path(run_folder) /'phenotype_embeddings.pth')
+            if self.hparams.hparams['loss'] == 'patient_disease_NCA': torch.save(disease_embeddings.cpu(), Path(run_folder) /'disease_embeddings.pth')        
+        if self.hparams.hparams['loss'] == 'patient_disease_NCA': disease_embeddings = disease_embeddings.cpu()
+
+        return results_df, phen_df, attn_dfs, phenotype_embeddings.cpu(), disease_embeddings
+
+    def test_step(self, batch, batch_idx):
+        correct_ranks, softmax, labels, node_embedder_loss, patient_loss, roc_score, ap_score, acc, f1, gat_attn, node_embeddings, phenotype_embedding, disease_embeddings, phenotype_mask, disease_mask, attn_weights, cand_disease_idx, cand_disease_embeddings = self._step(batch, 'test')
+        batch_results = {'test/correct_ranks': correct_ranks, 
+                         'test/node.embed': node_embeddings.detach().cpu(), 
+                         'test/patient.phenotype_embed': phenotype_embedding.detach().cpu(), 
+                         'test/attention_weights': attn_weights.detach().cpu(),
+                         'test/phenotype_names_degrees': batch.phenotype_names,
+                         'test/disease_names': batch.disease_names,
+                         'test/corr_gene_names': batch.corr_gene_names,
+                         'test/gat_attn': gat_attn, # type = list
+                         "test/n_id": batch.n_id[:batch.batch_size].detach().cpu(),
+                         "test/patient_ids": batch.patient_ids, # type = list
+                         "test/softmax": softmax.detach().cpu(),
+                         "test/labels": labels.detach().cpu(),
+                         'test/phenotype_mask': phenotype_mask.detach().cpu(),
+                         'test/disease_mask': disease_mask.detach().cpu() if disease_mask is not None else None,
+                        }
+
+        if self.hparams.hparams['loss'] == 'patient_disease_NCA':
+            batch_sz, n_diseases, embed_dim = disease_embeddings.shape
+            batch_disease_nid_reshaped = batch.batch_disease_nid.view(-1)
+            batch_results.update({
+                'test/batch_disease_nid': batch_disease_nid_reshaped.detach().cpu(),
+                'test/cand_disease_names': batch.cand_disease_names,
+                'test/batch_cand_disease_nid': cand_disease_idx.detach().cpu(),
+                'test/patient.disease_embed': cand_disease_embeddings.detach().cpu()
+            })
+        else:
+            batch_results.update({
+                'test/patient.disease_embed': None, 
+                'test/batch_disease_nid': None,
+                'test/cand_disease_names': None
+
+            })
+        
+        return batch_results
+
+    
+    def inference(self, batch, batch_idx):
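+        # Unlike forward(), which embeds the sampled subgraph delivered with the
+        # batch, this runs the node model's predict() over the full graph data and
+        # then applies the same masking, indexing, and patient-head forward pass.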
+        outputs, gat_attn = self.node_model.predict(self.all_data)
+
+        pad_outputs = torch.cat([torch.zeros(1, outputs.size(1), device=outputs.device), outputs]) 
+
+        # get masks
+        phenotype_mask = (batch.batch_pheno_nid != 0)
+        if self.hparams.hparams['loss'] == 'patient_disease_NCA': disease_mask = (batch.batch_cand_disease_nid != 0)
+        else: disease_mask = None
+
+        # index into outputs using phenotype & disease batch node idx
+        batch_sz, max_n_phen = batch.batch_pheno_nid.shape
+        phenotype_embeddings = torch.index_select(pad_outputs, 0, batch.batch_pheno_nid.view(-1)).view(batch_sz, max_n_phen, -1)
+        if self.hparams.hparams['loss'] == 'patient_disease_NCA': 
+            batch_sz, max_n_dx = batch.batch_cand_disease_nid.shape
+            disease_embeddings = torch.index_select(pad_outputs, 0, batch.batch_cand_disease_nid.view(-1)).view(batch_sz, max_n_dx, -1)
+        else: disease_embeddings = None
+
+        phenotype_embedding, disease_embeddings, phenotype_mask, disease_mask, attn_weights = self.patient_model.forward(phenotype_embeddings, disease_embeddings, phenotype_mask, disease_mask)
+
+        return outputs, gat_attn, phenotype_embedding, disease_embeddings, phenotype_mask, disease_mask, attn_weights
+
+
+    def predict_step(self, batch, batch_idx):
+        node_embeddings, gat_attn, phenotype_embedding, disease_embeddings, phenotype_mask, disease_mask, attn_weights = self.inference(batch, batch_idx)
+
+        # calculate patient embedding loss
+        use_candidate_list = self.hparams.hparams['only_hard_distractors']
+        if self.hparams.hparams['loss'] == 'patient_disease_NCA': labels = batch.disease_one_hot_labels
+        else: labels = batch.patient_labels
+        loss, softmax, labels, candidate_disease_idx, candidate_disease_embeddings = self.patient_model.calc_loss(batch, phenotype_embedding, disease_embeddings, disease_mask, labels, use_candidate_list)
+        if labels.nelement() == 0:
+            correct_ranks = None
+        else:
+            if self.hparams.hparams['loss'] == 'patient_disease_NCA': correct_ranks = self.rank_diseases(softmax, disease_mask, labels)
+            else: correct_ranks, labels = self.rank_patients(softmax, labels)
+        
+
+        results_df, phen_df, attn_dfs, phenotype_embeddings, disease_embeddings = self.write_results_to_file(batch, softmax, correct_ranks, labels, phenotype_mask, disease_mask , attn_weights, gat_attn, node_embeddings, phenotype_embedding, disease_embeddings, save=True, loop_type='predict')
+        return results_df, phen_df, *attn_dfs, phenotype_embeddings, disease_embeddings
+
+    
+    def _epoch_end(self, outputs, loop_type):
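+        # Aggregate per-batch outputs; at test time, write all results to disk.
+        # Then optionally log embedding/softmax/attention plots and report top-k
+        # accuracy and MRR over the correct diseases/patients.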
+        correct_ranks = torch.cat([ranks for x in outputs for ranks in x[f'{loop_type}/correct_ranks']], dim=0)
+        correct_ranks_with_pad = [ranks if ranks.numel() > 0 else torch.tensor([-1]) for x in outputs for ranks in x[f'{loop_type}/correct_ranks']]
+
+        if loop_type == "test":
+            
+            batch_info = {"n_id": torch.cat([x[f'{loop_type}/n_id'] for x in outputs], dim=0),
+                          "patient_ids": [pat for x in outputs for pat in x[f'{loop_type}/patient_ids'] ],
+                          "phenotype_names": [pat for x in outputs for pat in x[f'{loop_type}/phenotype_names_degrees']],
+                          "cand_disease_names": [pat for x in outputs for pat in x[f'{loop_type}/cand_disease_names']] if outputs[0][f'{loop_type}/cand_disease_names'] is not None else None,
+                          }
+
+            softmax = [pat for x in outputs for pat in x[f'{loop_type}/softmax']] 
+            labels = [pat for x in outputs for pat in x[f'{loop_type}/labels']] 
+            phenotype_mask = [pat for x in outputs for pat in x[f'{loop_type}/phenotype_mask']] 
+            disease_mask = [pat for x in outputs for pat in x[f'{loop_type}/disease_mask']] if outputs[0][f'{loop_type}/disease_mask'] is not None else None
+            attn_weights = [pat for x in outputs for pat in x[f'{loop_type}/attention_weights']] 
+            gat_attn = [pat for x in outputs for pat in x[f'{loop_type}/gat_attn']] 
+            node_embeddings = torch.cat([x[f'{loop_type}/node.embed'] for x in outputs], dim=0)
+            phenotype_embedding = torch.cat([x[f'{loop_type}/patient.phenotype_embed'] for x in outputs], dim=0)
+            disease_embeddings = torch.cat([x[f'{loop_type}/patient.disease_embed'] for x in outputs], dim=0) if outputs[0][f'{loop_type}/patient.disease_embed'] is not None else None
+            if self.hparams.hparams['loss'] == 'patient_disease_NCA':
+                cand_disease_batch_nid = torch.cat([x[f'{loop_type}/batch_cand_disease_nid'] for x in outputs], dim=0)
+            else: cand_disease_batch_nid = None
+
+            results_df, phen_df, attn_dfs, phenotype_embeddings, disease_embeddings = self.write_results_to_file(batch_info, softmax, correct_ranks_with_pad, labels, phenotype_mask, disease_mask, attn_weights, gat_attn, node_embeddings, phenotype_embedding, disease_embeddings, save=True, loop_type='test')
+
+            print("Writing results for test...")
+            output_base = "/home/ml499/public_repos/SHEPHERD/shepherd/results/patients_like_me"
+            results_df.to_csv(str(output_base) + '_scores.csv', index=False)
+            print(results_df)
+
+
+        if self.hparams.hparams['plot_patient_embed']:
+            phenotype_embedding = torch.cat([x[f'{loop_type}/patient.phenotype_embed'] for x in outputs], dim=0)
+            correct_gene_names = ['None' if len(li) == 0 else '  |  '.join(li) for x in outputs for li in x[f'{loop_type}/corr_gene_names'] ] 
+            correct_disease_names = ['None' if len(li) == 0 else '  |  '.join(li) for x in outputs for li in x[f'{loop_type}/disease_names'] ] 
+
+            phenotype_names = [' | '.join([item[0] for item in li][0:6])  for x in outputs for li in x[f'{loop_type}/phenotype_names_degrees'] ] #only take first few for now because they don't all fit
+            patient_label = {
+                        "Phenotypes": phenotype_names ,
+                        "Node Type": correct_disease_names,
+                        "Correct Gene": correct_gene_names,
+                        "Correct Disease":  correct_disease_names
+            }
+            self.logger.experiment.log({f'{loop_type}/patient_embed': fit_umap(phenotype_embedding, patient_label)})
+        
+        if self.hparams.hparams['plot_disease_embed']:
+            # Plot embeddings of patient aggregated phenotype & diseases              
+            phenotype_embedding = torch.cat([x[f'{loop_type}/patient.phenotype_embed'] for x in outputs], dim=0)
+            disease_embeddings = torch.cat([x[f'{loop_type}/patient.disease_embed'] for x in outputs], dim=0) 
+            disease_batch_nid = torch.cat([x[f'{loop_type}/batch_disease_nid'] for x in outputs], dim=0)
+            cand_disease_batch_nid = torch.cat([x[f'{loop_type}/batch_cand_disease_nid'] for x in outputs], dim=0)
+            disease_mask = (disease_batch_nid != 0)
+            cand_disease_mask = (cand_disease_batch_nid != 0)
+    
+            phenotype_names = [' | '.join([item[0] for item in li][0:6])  for x in outputs for li in x[f'{loop_type}/phenotype_names_degrees'] ] #only take first few for now because they don't all fit
+            cand_disease_names = [item for x in outputs for li in x[f'{loop_type}/cand_disease_names'] for item in li] 
+            correct_disease_names = ['None' if len(li) == 0 else '  |  '.join(li) for x in outputs for li in x[f'{loop_type}/disease_names'] ] 
+
+            patient_emb = torch.cat([phenotype_embedding, disease_embeddings])
+
+            patient_label = {
+                        "Node Type": ["Patient Phenotype"] * phenotype_embedding.shape[0] + ['Disease'] * disease_embeddings.shape[0],
+                        "Name": phenotype_names + cand_disease_names,
+                        "Correct Disease": correct_disease_names + ['NA'] * disease_embeddings.shape[0]
+                        }
+            self.logger.experiment.log({f'{loop_type}/patient_embed': fit_umap(patient_emb, patient_label)})
+     
+        if 'plot_softmax' in self.hparams.hparams and self.hparams.hparams['plot_softmax']:
+            softmax = [pat for x in outputs for pat in x[f'{loop_type}/softmax']] 
+            softmax_diff = [s.max() - s.min() for s in softmax]
+            softmax_top2_diff = [torch.topk(s, 2).values.max() - torch.topk(s, 2).values.min() for s in softmax]
+            softmax_top5_diff = [torch.topk(s, 5).values.max() - torch.topk(s, 5).values.min() for s in softmax]
+            self.logger.experiment.log({f'{loop_type}/softmax_top2_diff': plot_softmax(softmax_top2_diff)})
+            self.logger.experiment.log({f'{loop_type}/softmax_top5_diff': plot_softmax(softmax_top5_diff)})
+            self.logger.experiment.log({f'{loop_type}/softmax_diff': plot_softmax(softmax_diff)})
+
+        if self.hparams.hparams['plot_attn_nhops']:
+            # plot phenotype attention vs n_hops to gene and degree
+            attn_weights = [torch.split(x[f'{loop_type}/attention_weights'],1) for x in outputs]
+            attn_weights = [w[w > 0] for batch_w in attn_weights for w in batch_w]
+            phenotype_names = [pat for x in outputs for pat in x[f'{loop_type}/phenotype_names_degrees']]
+            attn_weights_cpu_reshaped = torch.cat(attn_weights, dim=0)
+            self.logger.experiment.log({f"{loop_type}_attn/attention weights": wandb.Histogram(attn_weights_cpu_reshaped[attn_weights_cpu_reshaped != 0])})
+            self.logger.experiment.log({f"{loop_type}_attn/single patient attention weights": wandb.Histogram(attn_weights[0])})
+        
+        if loop_type == 'val':
+            self.log('patient.curr_epoch', self.current_epoch, prog_bar=False)
+
+        # top-k accuracy
+        top_1_acc = top_k_acc(correct_ranks, k=1)
+        top_3_acc = top_k_acc(correct_ranks, k=3)
+        top_5_acc = top_k_acc(correct_ranks, k=5)
+        top_10_acc = top_k_acc(correct_ranks, k=10)
+
+        # mean reciprocal rank
+        mrr = mean_reciprocal_rank(correct_ranks)
+
+        self.log(f'{loop_type}/top1_acc', top_1_acc, prog_bar=False)
+        self.log(f'{loop_type}/top3_acc', top_3_acc, prog_bar=False)
+        self.log(f'{loop_type}/top5_acc', top_5_acc, prog_bar=False)
+        self.log(f'{loop_type}/top10_acc', top_10_acc, prog_bar=False)
+        self.log(f'{loop_type}/mrr', mrr, prog_bar=False)
+
+    def training_epoch_end(self, outputs):
+        self._epoch_end(outputs, 'train')
+
+    def validation_epoch_end(self, outputs):
+        self._epoch_end(outputs, 'val')
+
+    def test_epoch_end(self, outputs):
+        self._epoch_end(outputs, 'test')
+
+    def configure_optimizers(self):
+        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.hparams['lr'])
+        return optimizer
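+
+
+if __name__ == "__main__":
+    # Minimal usage sketch (illustrative only): the hparams values, checkpoint
+    # path, and dataloaders below are placeholders, not the project's actual
+    # training entry point.
+    hparams = {'lr': 1e-4, 'lambda': 0.5, 'loss': 'patient_disease_NCA',
+               'only_hard_distractors': False, 'time': False, 'plot_gradients': False,
+               'plot_patient_embed': False, 'plot_disease_embed': False,
+               'plot_attn_nhops': False, 'run_name': 'demo'}
+    # all_data, edge_attr_dict, and n_nodes come from the preprocessed knowledge graph:
+    # model = CombinedPatientNCA(edge_attr_dict, all_data, n_nodes=n_nodes,
+    #                            node_ckpt='checkpoints/node_embedder.ckpt', hparams=hparams)
+    # trainer = pl.Trainer(max_epochs=100, gpus=1)
+    # trainer.fit(model, train_dataloader, val_dataloader)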