#pytorch lightning
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
# torch
from torch import nn
import torch
import torch.nn.functional as F
import numpy as np
from scipy.stats import rankdata
import pandas as pd
from pathlib import Path
import time
import wandb
import sys
sys.path.insert(0, '..') # add project_config to path
from node_embedder_model import NodeEmbeder
from task_heads.gp_aligner import GPAligner
import project_config
# import utils
from utils.pretrain_utils import get_edges, calc_metrics
from utils.loss_utils import MultisimilarityCriterion
from utils.train_utils import mean_reciprocal_rank, top_k_acc, average_rank
from utils.train_utils import fit_umap, mrr_vs_percent_overlap, plot_gene_rank_vs_x_intrain, plot_gene_rank_vs_hops, plot_degree_vs_attention, plot_nhops_to_gene_vs_attention, plot_gene_rank_vs_fraction_phenotype, plot_gene_rank_vs_numtrain, plot_gene_rank_vs_trainset
from utils.train_utils import weighted_sum
class CombinedGPAligner(pl.LightningModule):
    """Jointly train a checkpointed node embedder and a patient/gene aligner head.

    The node embedder (``NodeEmbeder``) is restored from ``node_ckpt`` and
    produces node embeddings. The patient head (``GPAligner``) consumes the
    phenotype, candidate-gene and (optionally) disease embeddings of each
    patient and scores the candidate genes. The training loss is
    ``lambda * node_embedder_loss + (1 - lambda) * patient_loss`` with
    ``lambda = hparams['lambda']``.
    """

    # Similarity value filtered out before plotting similarity histograms;
    # presumably the padding value produced by GPAligner._calc_similarity -- TODO confirm.
    _PAD_SIM = -100000

    def __init__(self, edge_attr_dict, all_data, n_nodes=None, node_ckpt=None, hparams=None, node_hparams=None, spl_pca=None, spl_gate=None):
        """Load the node embedder from a checkpoint and build the patient head.

        Args:
            edge_attr_dict: forwarded to ``NodeEmbeder.load_from_checkpoint``.
            all_data: graph data; forwarded to the node embedder, passed to
                ``get_edges`` for train/val steps and to ``NodeEmbeder.predict``
                at inference.
            n_nodes: forwarded to the node embedder as ``num_nodes``.
            node_ckpt: path of the node-embedder checkpoint to load.
            hparams: hyperparameter mapping of this module; stored with
                ``save_hyperparameters`` and read back as ``self.hparams.hparams``.
            node_hparams: not used here (the node embedder restores the
                hyperparameters saved in its checkpoint).
            spl_pca, spl_gate: not used; kept for call-site compatibility.
        """
        super().__init__()
        print('Initializing Model')
        self.save_hyperparameters('hparams', ignore=["spl_pca", "spl_gate"])  # spl_pca and spl_gate never get used
        self.all_data = all_data

        # Node-id -> occurrence-count tables filled by training_epoch_end when
        # hparams['plot_intrain'] is set; read by the in-train plots of _epoch_end.
        self.all_train_nodes = {}
        self.train_patient_nodes = {}
        self.train_sparse_nodes = {}
        self.train_target_batch = {}
        self.train_corr_gene_nid = {}

        print(f"Loading Node Embedder from {node_ckpt}")
        # NOTE: loads in saved hyperparameters
        self.node_model = NodeEmbeder.load_from_checkpoint(checkpoint_path=node_ckpt,
                                                            all_data=all_data,
                                                            edge_attr_dict=edge_attr_dict,
                                                            num_nodes=n_nodes)
        self.patient_model = self.get_patient_model()
        print('End Patient Model Initialization')

    def get_patient_model(self):
        """Build the GPAligner head with ``embed_dim = output * n_heads`` of the node embedder.

        The product assumes the last node-embedder layer concatenates its
        attention heads (hence the GATv2Conv-only note) -- confirm if the
        node embedder architecture changes.
        """
        # NOTE: this will only work with GATv2Conv
        node_hp = self.node_model.hparams.hp_dict
        return GPAligner(self.hparams.hparams, embed_dim=node_hp['output'] * node_hp['n_heads'])

    # ------------------------------------------------------------------
    # Shared embedding helpers (used by both forward and inference)
    # ------------------------------------------------------------------

    @staticmethod
    def _pad_node_outputs(outputs):
        """Prepend an all-zero row: row ``i`` of the result is row ``i - 1`` of ``outputs``.

        The batch id tensors use 0 as padding (see the ``!= 0`` masks), so a
        padded slot gathers the zero vector.
        """
        pad = torch.zeros(1, outputs.size(1), device=outputs.device, dtype=outputs.dtype)
        return torch.cat([pad, outputs])

    @staticmethod
    def _lookup_embeddings(pad_outputs, nid):
        """Gather ``pad_outputs`` rows for every id in ``nid``; result shape is ``nid.shape + (D,)``."""
        gathered = torch.index_select(pad_outputs, 0, nid.reshape(-1))
        # explicit feature width (not -1) so an empty id tensor still reshapes
        return gathered.view(*nid.shape, pad_outputs.size(1))

    def _augment_gene_embeddings(self, batch, pad_outputs, cand_gene_embeddings, at_inference):
        """Blend each candidate gene embedding with a weighted sum of its similar genes' embeddings.

        Candidates whose similarity scores sum to 0 keep their own embedding
        (blend weight 0). During training/validation the blend weight can
        additionally depend on gene degree (``aug_gene_by_deg``); the inference
        path never applies that degree weighting.
        """
        hp = self.hparams.hparams
        w = hp['aug_gene_w']
        print("Augmenting genes at inference..." if at_inference else "Augmenting genes...", w)
        sim_gene_embeddings = self._lookup_embeddings(pad_outputs, batch.batch_sim_gene_nid)
        agg_sim_gene_embedding = weighted_sum(sim_gene_embeddings, batch.batch_sim_gene_sims)
        has_similar = torch.sum(batch.batch_sim_gene_sims, dim=-1) > 0

        # NOTE(review): inference never used degree-based weighting in the original
        # code; kept as-is -- confirm this asymmetry is intended.
        if not at_inference and hp['aug_gene_by_deg']:
            print("Augmenting gene by degree...")
            aug_gene_w = w * torch.exp(-w * batch.batch_cand_gene_degs) + (1 - w - 0.1)
            aug_gene_w = aug_gene_w * has_similar
        else:
            aug_gene_w = w * has_similar
        aug_gene_w = aug_gene_w.unsqueeze(-1)
        return torch.mul(1 - aug_gene_w, cand_gene_embeddings) + torch.mul(aug_gene_w, agg_sim_gene_embedding)

    def _gather_patient_embeddings(self, batch, pad_outputs, at_inference):
        """Index padded node outputs into per-patient phenotype/gene/(disease) embeddings and masks.

        Returns:
            (phenotype_embeddings, cand_gene_embeddings, disease_embeddings,
            phenotype_mask, gene_mask, disease_mask); the disease entries are
            ``None`` when ``hparams['use_diseases']`` is false.
        """
        hp = self.hparams.hparams
        phenotype_mask = batch.batch_pheno_nid != 0
        gene_mask = batch.batch_cand_gene_nid != 0
        phenotype_embeddings = self._lookup_embeddings(pad_outputs, batch.batch_pheno_nid)
        cand_gene_embeddings = self._lookup_embeddings(pad_outputs, batch.batch_cand_gene_nid)
        if hp['augment_genes']:
            cand_gene_embeddings = self._augment_gene_embeddings(batch, pad_outputs, cand_gene_embeddings, at_inference)

        disease_embeddings = disease_mask = None
        if hp['use_diseases']:
            disease_mask = batch.batch_disease_nid != 0
            disease_embeddings = self._lookup_embeddings(pad_outputs, batch.batch_disease_nid)
        return phenotype_embeddings, cand_gene_embeddings, disease_embeddings, phenotype_mask, gene_mask, disease_mask

    def _run_patient_model(self, phenotype_embeddings, cand_gene_embeddings, disease_embeddings, phenotype_mask, gene_mask, disease_mask):
        """Call GPAligner.forward, passing disease inputs only when ``use_diseases`` is set.

        Returns GPAligner's 7-tuple: (phenotype_embedding, candidate_gene_embeddings,
        disease_embeddings, gene_mask, phenotype_mask, disease_mask, attn_weights).
        """
        if self.hparams.hparams['use_diseases']:
            return self.patient_model.forward(phenotype_embeddings, cand_gene_embeddings, disease_embeddings,
                                              phenotype_mask, gene_mask, disease_mask)
        return self.patient_model.forward(phenotype_embeddings, cand_gene_embeddings,
                                          phenotype_mask=phenotype_mask, gene_mask=gene_mask)

    def _print_forward_timing(self, t0, t1, t2, t3):
        """Print the node-model / indexing / patient-model timings when ``hparams['time']`` is set."""
        if self.hparams.hparams['time']:
            print(f'It takes {t1-t0:0.4f}s for the node model, {t2-t1:0.4f}s for indexing into the output, and {t3-t2:0.4f}s for the patient model forward.')

    # ------------------------------------------------------------------
    # Forward passes
    # ------------------------------------------------------------------

    def forward(self, batch, step_type):
        """Embed the sampled subgraph of ``batch`` and run the patient head on it.

        ``step_type`` is accepted for interface compatibility and not used.

        Returns:
            (outputs, gat_attn, phenotype_embedding, candidate_gene_embeddings,
            disease_embeddings, gene_mask, phenotype_mask, disease_mask, attn_weights)
        """
        t0 = time.time()
        outputs, gat_attn = self.node_model.forward(batch.n_id, batch.adjs)
        pad_outputs = self._pad_node_outputs(outputs)
        t1 = time.time()
        embeddings = self._gather_patient_embeddings(batch, pad_outputs, at_inference=False)
        t2 = time.time()
        patient_out = self._run_patient_model(*embeddings)
        t3 = time.time()
        self._print_forward_timing(t0, t1, t2, t3)
        return (outputs, gat_attn, *patient_out)

    def inference(self, batch, batch_idx):
        """Like ``forward``, but embeds the full graph via ``NodeEmbeder.predict(self.all_data)``.

        Unlike ``forward``, degree-based gene augmentation is never applied.
        Returns the same 9-tuple as ``forward``.
        """
        t0 = time.time()  # previously undefined -> NameError when hparams['time'] was set
        outputs, gat_attn = self.node_model.predict(self.all_data)
        pad_outputs = self._pad_node_outputs(outputs)
        t1 = time.time()
        embeddings = self._gather_patient_embeddings(batch, pad_outputs, at_inference=True)
        t2 = time.time()
        patient_out = self._run_patient_model(*embeddings)
        t3 = time.time()
        self._print_forward_timing(t0, t1, t2, t3)
        return (outputs, gat_attn, *patient_out)

    # ------------------------------------------------------------------
    # Steps
    # ------------------------------------------------------------------

    def _step(self, batch, step_type):
        """Run a train/val/test step: forward pass, gene ranking, patient loss, node loss/metrics.

        For ``step_type == 'test'`` no edges are fetched and the node-embedder
        loss and link metrics are reported as 0.

        Returns:
            16-tuple (node_embedder_loss, patient_loss, correct_gene_ranks, roc_score,
            ap_score, acc, f1, node_embeddings, gat_attn, phenotype_embedding,
            candidate_gene_embeddings, attn_weights, phen_gene_sims,
            raw_phen_gene_sims, gene_mask, phenotype_mask).
        """
        hp = self.hparams.hparams
        t0 = time.time()
        if step_type != 'test':
            # attach the edges the node embedder's link loss/metrics are computed on
            batch = get_edges(batch, self.all_data, step_type)
        t1 = time.time()

        # Forward pass
        (node_embeddings, gat_attn, phenotype_embedding, candidate_gene_embeddings, disease_embeddings,
         gene_mask, phenotype_mask, disease_mask, attn_weights) = self.forward(batch, step_type)
        t2 = time.time()

        # Calculate similarities between patient phenotypes & candidate genes/diseases
        alpha = hp['alpha']
        use_candidate_list = step_type != 'train'
        cand_gene_to_phenotypes_spl = (batch.batch_cand_gene_to_phenotypes_spl if use_candidate_list
                                       else batch.batch_concat_cand_gene_to_phenotypes_spl)
        disease_nid = batch.batch_disease_nid if hp['use_diseases'] else None

        # similarity between phenotypes & every gene in the manual candidate list (used for ranking)
        phen_gene_sims, raw_phen_gene_sims, phen_gene_mask, phen_gene_one_hot_labels = self.patient_model._calc_similarity(
            phenotype_embedding, candidate_gene_embeddings, None, batch.batch_cand_gene_nid, batch.batch_corr_gene_nid,
            disease_nid, batch.one_hot_labels, gene_mask, phenotype_mask, disease_mask, True,
            batch.batch_cand_gene_to_phenotypes_spl, alpha)

        # similarity used by the loss function
        loss_type = hp['loss']
        if loss_type == 'gene_multisimilarity' and use_candidate_list:
            # identical inputs to the ranking similarity above -> reuse it
            sims, mask, one_hot_labels = phen_gene_sims, phen_gene_mask, phen_gene_one_hot_labels
        else:
            # Separate locals so the embeddings returned by _step are not replaced by None.
            loss_gene_embeddings = None if loss_type == 'disease_multisimilarity' else candidate_gene_embeddings
            loss_disease_embeddings = None if loss_type == 'gene_multisimilarity' else disease_embeddings
            sims, _, mask, one_hot_labels = self.patient_model._calc_similarity(
                phenotype_embedding, loss_gene_embeddings, loss_disease_embeddings, batch.batch_cand_gene_nid,
                batch.batch_corr_gene_nid, disease_nid, batch.one_hot_labels, gene_mask, phenotype_mask,
                disease_mask, use_candidate_list, cand_gene_to_phenotypes_spl, alpha)

        ## Rank genes
        correct_gene_ranks, phen_gene_sims = self.patient_model._rank_genes(phen_gene_sims, phen_gene_mask, phen_gene_one_hot_labels)
        t3 = time.time()

        ## Calculate patient embedding loss
        loss = self.patient_model.calc_loss(sims, mask, one_hot_labels)
        t4 = time.time()

        ## Calculate node embedding loss
        if step_type == 'test':
            node_embedder_loss = 0
            roc_score, ap_score, acc, f1 = 0, 0, 0, 0
        else:
            # Get link predictions
            batch, raw_pred, pred = self.node_model.get_predictions(batch, node_embeddings)
            link_labels = self.node_model.get_link_labels(batch.all_edge_types)
            node_embedder_loss = self.node_model.calc_loss(pred, link_labels)
            # Calculate metrics on sigmoid probabilities
            metric_pred = torch.sigmoid(raw_pred)
            roc_score, ap_score, acc, f1 = calc_metrics(metric_pred.cpu().detach().numpy(), link_labels.cpu().detach().numpy())
        t5 = time.time()

        if hp['time']:
            print(f'It takes {t1-t0:0.4f}s to get edges, {t2-t1:0.4f}s for the forward pass, {t3-t2:0.4f}s to rank genes, {t4-t3:0.4f}s to calc patient loss, and {t5-t4:0.4f}s to calc the node loss.')

        if hp['plot_gradients']:
            # NOTE(review): despite the key name this logs parameter VALUES from
            # state_dict(), not gradients (backward has not run for this step yet).
            for k, v in self.patient_model.state_dict().items():
                self.logger.experiment.log({f'gradients/{step_type}.gradients.{k}': wandb.Histogram(v.detach().cpu())})

        return (node_embedder_loss, loss, correct_gene_ranks, roc_score, ap_score, acc, f1, node_embeddings, gat_attn,
                phenotype_embedding, candidate_gene_embeddings, attn_weights, phen_gene_sims, raw_phen_gene_sims,
                gene_mask, phenotype_mask)

    def _combined_loss(self, node_embedder_loss, patient_loss):
        """Return ``lambda * node_embedder_loss + (1 - lambda) * patient_loss``."""
        lam = self.hparams.hparams['lambda']
        return (lam * node_embedder_loss) + ((1 - lam) * patient_loss)

    @staticmethod
    def _cpu_or_none(tensor):
        """Detach and move ``tensor`` to CPU, passing ``None`` through."""
        return None if tensor is None else tensor.detach().cpu()

    def training_step(self, batch, batch_idx):
        """Lightning training step; returns the combined loss plus outputs for ``training_epoch_end``."""
        (node_embedder_loss, patient_loss, correct_gene_ranks, roc_score, ap_score, acc, f1, node_embeddings,
         gat_attn, phenotype_embedding, candidate_gene_embeddings, attn_weights, phen_gene_sims,
         raw_phen_gene_sims, gene_mask, phenotype_mask) = self._step(batch, 'train')
        loss = self._combined_loss(node_embedder_loss, patient_loss)
        self.log('train_loss/patient.train_overall_loss', loss, prog_bar=True, on_epoch=True)
        self.log('train_loss/patient.train_patient_loss', patient_loss, prog_bar=True, on_epoch=True)
        self.log('train_loss/patient.train_node_embedder_loss', node_embedder_loss, prog_bar=True, on_epoch=True)
        return {'loss': loss,
                'train/train_correct_gene_ranks': correct_gene_ranks,
                "train/node.train_roc": roc_score,
                "train/node.train_ap": ap_score,
                "train/node.train_acc": acc,
                "train/node.train_f1": f1,
                'train/one_hot_labels': batch.one_hot_labels.detach().cpu(),
                'train/attention_weights': self._cpu_or_none(attn_weights),
                'train/phen_gene_sims': phen_gene_sims.detach().cpu(),
                'train/phenotype_names_degrees': batch.phenotype_names,
                }

    def validation_step(self, batch, batch_idx):
        """Lightning validation step; logs losses and returns outputs for ``validation_epoch_end``."""
        (node_embedder_loss, patient_loss, correct_gene_ranks, roc_score, ap_score, acc, f1, node_embeddings,
         gat_attn, phenotype_embedding, candidate_gene_embeddings, attn_weights, phen_gene_sims,
         raw_phen_gene_sims, gene_mask, phenotype_mask) = self._step(batch, 'val')
        loss = self._combined_loss(node_embedder_loss, patient_loss)
        self.log('val_loss/patient.val_overall_loss', loss, prog_bar=True, on_epoch=True)
        self.log('val_loss/patient.val_patient_loss', patient_loss, prog_bar=True)
        self.log('val_loss/patient.val_node_embedder_loss', node_embedder_loss, prog_bar=True)
        return {'loss/val_loss': loss,
                'val/val_correct_gene_ranks': correct_gene_ranks,
                "val/node.val_roc": roc_score,
                "val/node.val_ap": ap_score,
                "val/node.val_acc": acc,
                "val/node.val_f1": f1,
                'val/one_hot_labels': batch.one_hot_labels.detach().cpu(),
                'val/attention_weights': self._cpu_or_none(attn_weights),
                'val/phen_gene_sims': phen_gene_sims.detach().cpu(),
                'val/phenotype_names_degrees': batch.phenotype_names,
                }

    def write_results_to_file(self, batch_info, phen_gene_sims, gene_mask, phenotype_mask, attn_weights, correct_gene_ranks, gat_attn, node_embeddings, phenotype_embeddings, save=True, loop_type='predict'):
        """Convert model outputs into pandas tables (nothing is written to disk here).

        NOTE: only covers the patients passed in -- to run at inference time,
        make sure the batch includes all patients.

        Args:
            batch_info: mapping (or subscriptable batch) with "patient_ids",
                "cand_gene_names" and "phenotype_names" (per patient, a list of
                (phenotype name, degree) pairs).
            phen_gene_sims, gene_mask: per-patient similarity rows and the
                boolean masks selecting their non-padded entries.
            phenotype_mask, attn_weights: per-patient phenotype masks and
                attention weights; ``attn_weights`` may be ``None``.
            gat_attn: iterable of ``(edge_index, attn)`` pairs, one per GAT
                layer (and per batch when collected across batches).
            phenotype_embeddings: returned moved to CPU.
            correct_gene_ranks, node_embeddings, save, loop_type: unused; kept
                for interface compatibility.

        Returns:
            (results_df, phen_df or None, attn_dfs, phenotype_embeddings on CPU, None)
        """
        # One DataFrame per (edge_index, attn) pair: source/target plus one column per head.
        attn_dfs = []
        for edge_index, attn in gat_attn:
            edge_index = edge_index.detach().cpu().numpy()
            attn = attn.detach().cpu().numpy()
            layer_df = pd.DataFrame({'source': edge_index[0, :], 'target': edge_index[1, :]})
            for head in range(attn.shape[1]):
                layer_df[f'attn_{head}'] = attn[:, head]
            attn_dfs.append(layer_df)

        # Gene scores: one row per (patient, non-padded candidate gene).
        all_sims, all_genes, all_patient_ids = [], [], []
        for patient_id, sims, genes, g_mask in zip(batch_info["patient_ids"], phen_gene_sims, batch_info["cand_gene_names"], gene_mask):
            all_sims.extend(sims[g_mask].tolist())
            all_genes.extend(genes)
            all_patient_ids.extend([patient_id] * len(genes))
        results_df = pd.DataFrame({'patient_id': all_patient_ids, 'genes': all_genes, 'similarities': all_sims})

        # Phenotype attention: one row per (patient, phenotype) when attention weights exist.
        phen_df = None
        if attn_weights is not None:
            phen_patient_ids, all_phens, all_attn_weights, all_degrees = [], [], [], []
            for patient_id, attn_w, phen_names, p_mask in zip(batch_info["patient_ids"], attn_weights, batch_info["phenotype_names"], phenotype_mask):
                # zip(*[]) cannot be unpacked into two names, so guard empty lists
                p_names, degrees = zip(*phen_names) if len(phen_names) > 0 else ((), ())
                phen_patient_ids.extend([patient_id] * len(phen_names))
                all_degrees.extend(degrees)
                all_phens.extend(p_names)
                all_attn_weights.extend(attn_w[p_mask].tolist())
            phen_df = pd.DataFrame({'patient_id': phen_patient_ids, 'phenotypes': all_phens, 'degrees': all_degrees, 'attention': all_attn_weights})
            print(phen_df.head())

        return results_df, phen_df, attn_dfs, phenotype_embeddings.cpu(), None

    def test_step(self, batch, batch_idx):
        """Lightning test step; returns everything ``_epoch_end`` needs to build result tables."""
        (node_embedder_loss, patient_loss, correct_gene_ranks, roc_score, ap_score, acc, f1, node_embeddings,
         gat_attn, phenotype_embedding, candidate_gene_embeddings, attn_weights, phen_gene_sims,
         raw_phen_gene_sims, gene_mask, phenotype_mask) = self._step(batch, 'test')
        return {'test/test_correct_gene_ranks': correct_gene_ranks,
                'test/node.embed': node_embeddings.detach().cpu(),
                'test/patient.phenotype_embed': phenotype_embedding.detach().cpu(),
                'test/one_hot_labels': batch.one_hot_labels.detach().cpu(),
                'test/attention_weights': self._cpu_or_none(attn_weights),
                'test/phen_gene_sims': phen_gene_sims.detach().cpu(),
                'test/phenotype_names_degrees': batch.phenotype_names,  # type = list
                'test/gene_mask': gene_mask.detach().cpu(),
                'test/phenotype_mask': phenotype_mask.detach().cpu(),
                "test/patient_ids": batch.patient_ids,  # type = list
                "test/cand_gene_names": batch.cand_gene_names,  # type = list
                'test/gat_attn': gat_attn,  # type = list
                "test/n_id": batch.n_id[:batch.batch_size].detach().cpu(),
                }

    def predict_step(self, batch, batch_idx):
        """Score the candidate genes of ``batch`` using full-graph node embeddings.

        Returns:
            (results_df, phen_df, *attn_dfs, phenotype_embeddings on CPU, None)
        """
        (node_embeddings, gat_attn, phenotype_embedding, candidate_gene_embeddings, disease_embeddings,
         gene_mask, phenotype_mask, disease_mask, attn_weights) = self.inference(batch, batch_idx)

        # similarity between phenotypes & every gene in the manual candidate list
        alpha = self.hparams.hparams['alpha']
        disease_nid = batch.batch_disease_nid if self.hparams.hparams['use_diseases'] else None
        phen_gene_sims, raw_phen_gene_sims, phen_gene_mask, phen_gene_one_hot_labels = self.patient_model._calc_similarity(
            phenotype_embedding, candidate_gene_embeddings, None, batch.batch_cand_gene_nid, batch.batch_corr_gene_nid,
            disease_nid, batch.one_hot_labels, gene_mask, phenotype_mask, disease_mask, True,
            batch.batch_cand_gene_to_phenotypes_spl, alpha)

        # Rank genes
        correct_gene_ranks, phen_gene_sims = self.patient_model._rank_genes(phen_gene_sims, phen_gene_mask, phen_gene_one_hot_labels)
        results_df, phen_df, attn_dfs, phenotype_embeddings, disease_embeddings = self.write_results_to_file(
            batch, phen_gene_sims, gene_mask, phenotype_mask, attn_weights, correct_gene_ranks, gat_attn,
            node_embeddings, phenotype_embedding, save=True)
        return results_df, phen_df, *attn_dfs, phenotype_embeddings, disease_embeddings

    # ------------------------------------------------------------------
    # Epoch-end aggregation and logging
    # ------------------------------------------------------------------

    def _write_test_results(self, outputs, loop_type, correct_gene_ranks):
        """Collect per-batch test outputs, build the result tables and write them as CSV files.

        The output prefix is ``hparams['results_output_base']`` when present,
        otherwise the historical hard-coded path.
        """
        def gather(key):
            # flatten a per-batch list/tensor output into one per-patient list
            return [item for x in outputs for item in x[f'{loop_type}/{key}']]

        batch_info = {"n_id": torch.cat([x[f'{loop_type}/n_id'] for x in outputs], dim=0),
                      "patient_ids": gather('patient_ids'),
                      "phenotype_names": gather('phenotype_names_degrees'),
                      "cand_gene_names": gather('cand_gene_names'),
                      "one_hot_labels": gather('one_hot_labels'),
                      }
        phen_gene_sims = gather('phen_gene_sims')
        gene_mask = gather('gene_mask')
        phenotype_mask = gather('phenotype_mask')
        # attention weights are None when the patient head returns none; iterating None would crash
        batch_attn = [x[f'{loop_type}/attention_weights'] for x in outputs]
        attn_weights = None if any(a is None for a in batch_attn) else [w for a in batch_attn for w in a]
        gat_attn = gather('gat_attn')
        node_embeddings = torch.cat([x[f'{loop_type}/node.embed'] for x in outputs], dim=0)
        phenotype_embedding = torch.cat([x[f'{loop_type}/patient.phenotype_embed'] for x in outputs], dim=0)

        results_df, phen_df, _, _, _ = self.write_results_to_file(
            batch_info, phen_gene_sims, gene_mask, phenotype_mask, attn_weights, correct_gene_ranks, gat_attn,
            node_embeddings, phenotype_embedding, loop_type=loop_type)

        print("Writing results for test...")
        # NOTE(review): machine-specific default kept for backward compatibility;
        # prefer setting hparams['results_output_base'].
        output_base = self.hparams.hparams.get('results_output_base', "/home/ml499/public_repos/SHEPHERD/shepherd/results/gp")
        results_df.to_csv(f'{output_base}_scores.csv', index=False)
        print(results_df)
        if phen_df is not None:
            phen_df.to_csv(f'{output_base}_phenotype_attention.csv', sep=',', index=False)
            print(phen_df)

    def _log_intrain_plots(self, outputs, loop_type, correct_gene_ranks):
        """Log rank-of-correct-gene vs. how often the gene appeared in training batches.

        NOTE(review): requires f'{loop_type}/corr_gene_nid_orig' in the step
        outputs, which the visible step methods do not return.
        """
        log = self.logger.experiment.log
        correct_gene_nid = torch.cat([x[f'{loop_type}/corr_gene_nid_orig'] for x in outputs], dim=0)
        assert correct_gene_ranks.shape[0] == correct_gene_nid.shape[0]
        # Rank of gene vs. number of train patients with causal gene
        gene_rank_corr_gene_fig, gene_rank_corr_gene_counts = plot_gene_rank_vs_numtrain(correct_gene_ranks, correct_gene_nid, self.train_corr_gene_nid)
        gene_rank_cand_gene_fig, gene_rank_cand_gene_counts = plot_gene_rank_vs_numtrain(correct_gene_ranks, correct_gene_nid, self.train_patient_nodes)
        gene_rank_sparse_fig, gene_rank_sparse_counts = plot_gene_rank_vs_numtrain(correct_gene_ranks, correct_gene_nid, self.train_sparse_nodes)
        gene_rank_target_fig, gene_rank_target_counts = plot_gene_rank_vs_numtrain(correct_gene_ranks, correct_gene_nid, self.train_target_batch)
        log({f'{loop_type}/gene_rank_vs_num_train_corr_genes': gene_rank_corr_gene_fig})
        log({f'{loop_type}/gene_rank_vs_num_train_cand_genes': gene_rank_cand_gene_fig})
        log({f'{loop_type}/gene_rank_vs_num_train_sparse': gene_rank_sparse_fig})
        log({f'{loop_type}/gene_rank_vs_num_train_target': gene_rank_target_fig})
        gene_nid_trainset = torch.stack([torch.tensor(gene_rank_corr_gene_counts),
                                         torch.tensor(gene_rank_cand_gene_counts),
                                         torch.tensor(gene_rank_sparse_counts),
                                         torch.tensor(gene_rank_target_counts)], dim=1)
        log({f'{loop_type}/gene_rank_vs_trainset': plot_gene_rank_vs_trainset(correct_gene_ranks, correct_gene_nid, gene_nid_trainset)})

    def _log_frac_rank_plots(self, outputs, loop_type, correct_gene_ranks):
        """Log rank of correct gene vs. fraction of phenotypes with a direct edge to the disease / gene."""
        log = self.logger.experiment.log
        ranks_cpu = correct_gene_ranks.cpu()
        # NOTE: Currently only selects the first disease.
        frac_p_with_direct_edge_to_dx = [pat[0][0] for x in outputs for pat in x[f'{loop_type}/frac_p_with_direct_edge_to_dx']]
        log({f'{loop_type}/gene_rank_vs_frac_p_with_direct_edge_to_dx': plot_gene_rank_vs_fraction_phenotype(ranks_cpu, frac_p_with_direct_edge_to_dx)})
        # NOTE: Currently only selects the first gene.
        frac_p_with_direct_edge_to_g = [pat[0][0] for x in outputs for pat in x[f'{loop_type}/frac_p_with_direct_edge_to_g']]
        log({f'{loop_type}/frac_p_with_direct_edge_to_g': plot_gene_rank_vs_fraction_phenotype(ranks_cpu, frac_p_with_direct_edge_to_g)})

    def _log_nhops_rank_plots(self, outputs, loop_type, correct_gene_ranks):
        """Log rank of correct gene vs. mean/min hop distances (gene-disease, gene-phenotype, phenotype-phenotype)."""
        log = self.logger.experiment.log
        ranks_cpu = correct_gene_ranks.cpu()
        # Rank of gene vs. hops from disease (NOTE: currently only the first disease)
        nhops_g_d = [pat[0] for x in outputs for pat in x[f'{loop_type}/n_hops_g_d']]
        fig_mean, fig_min = plot_gene_rank_vs_hops(ranks_cpu, nhops_g_d)
        log({f'{loop_type}/gene_rank_vs_mean_n_hops_g_d': fig_mean})
        log({f'{loop_type}/gene_rank_vs_min_n_hops_g_d': fig_min})
        # Rank of gene vs. mean/min hops from phenotypes (NOTE: currently only the first gene)
        nhops_g_p = [pat[0] for x in outputs for pat in x[f'{loop_type}/n_hops_g_p']]
        fig_mean, fig_min = plot_gene_rank_vs_hops(ranks_cpu, nhops_g_p)
        log({f'{loop_type}/gene_rank_vs_mean_n_hops_g_p': fig_mean})
        log({f'{loop_type}/gene_rank_vs_min_n_hops_g_p': fig_min})
        # Rank of gene vs. distance between phenotypes
        nhops_p_p = [torch.tensor(pat) for x in outputs for pat in x[f'{loop_type}/n_hops_p_p']]
        fig_mean, fig_min = plot_gene_rank_vs_hops(ranks_cpu, nhops_p_p)
        log({f'{loop_type}/gene_rank_vs_mean_n_hops_p_p': fig_mean})
        log({f'{loop_type}/gene_rank_vs_min_n_hops_p_p': fig_min})

    def _log_attention_plots(self, outputs, loop_type):
        """Log phenotype-attention histograms and attention vs. hops-to-gene / degree plots."""
        log = self.logger.experiment.log
        per_batch = [torch.split(x[f'{loop_type}/attention_weights'], 1) for x in outputs]
        # keep only strictly positive weights per patient (zeros presumably mark padding -- confirm)
        attn_weights = [w[w > 0] for batch_w in per_batch for w in batch_w]
        phenotype_names = [pat for x in outputs for pat in x[f'{loop_type}/phenotype_names_degrees']]
        # computed here so this plot no longer depends on plot_nhops_rank having run
        nhops_g_p = [pat[0] for x in outputs for pat in x[f'{loop_type}/n_hops_g_p']]
        attn_flat = torch.cat(attn_weights, dim=0)
        log({f"{loop_type}_attn/attention weights": wandb.Histogram(attn_flat[attn_flat != 0])})
        log({f"{loop_type}_attn/single patient attention weights": wandb.Histogram(attn_weights[0])})
        log({f"{loop_type}_attn/n hops to gene vs attention weights": plot_nhops_to_gene_vs_attention(attn_weights, phenotype_names, nhops_g_p)})
        log({f"{loop_type}_attn/single patient n hops to gene vs attention weights": plot_nhops_to_gene_vs_attention(attn_weights, phenotype_names, nhops_g_p, single_patient=True)})
        log({f"{loop_type}_attn/degree vs attention weights": plot_degree_vs_attention(attn_weights, phenotype_names)})
        log({f"{loop_type}_attn/single patient degree vs attention weights": plot_degree_vs_attention(attn_weights, phenotype_names, single_patient=True)})
        data = [[p_name[0], w.item(), p_name[1], n_hops_to_g] for w, p_name, n_hops_to_g in zip(attn_weights[0], phenotype_names[0], nhops_g_p[0])]
        log({f"{loop_type}_attn/phenotypes": wandb.Table(data=data, columns=["HPO Code", "Attention Weight", "Degree", "Num Hops to Gene"])})

    def _log_phen_gene_sim_plots(self, outputs, loop_type):
        """Log histograms of phenotype-gene similarities (all / correct / incorrect), overall and for the first patient."""
        log = self.logger.experiment.log
        pad = self._PAD_SIM
        all_sims, correct_sims, incorrect_sims = [], [], []
        for x in outputs:
            sims = x[f'{loop_type}/phen_gene_sims']
            labels = x[f'{loop_type}/one_hot_labels']
            correct_sims.append(sims[labels == 1])
            incorrect_sims.append(sims[labels != 1])
            all_sims.append(sims.reshape(-1))
        all_sims = torch.cat(all_sims)
        correct_sims = torch.cat(correct_sims)
        incorrect_sims = torch.cat(incorrect_sims)
        log({f"{loop_type}_pg_similarities/phenotype-gene similarities": wandb.Histogram(all_sims[all_sims != pad])})
        log({f"{loop_type}_pg_similarities/phenotype-correct gene similarities": wandb.Histogram(correct_sims[correct_sims != pad])})
        log({f"{loop_type}_pg_similarities/phenotype-incorrect gene similarities": wandb.Histogram(incorrect_sims[incorrect_sims != pad])})

        # first patient of the first batch
        sims_patient = outputs[0][f'{loop_type}/phen_gene_sims'][0, :]
        labels_patient = outputs[0][f'{loop_type}/one_hot_labels'][0, :]
        correct_patient = sims_patient[labels_patient == 1]
        assert len(correct_patient) == 1
        incorrect_patient = sims_patient[labels_patient != 1]
        log({f"{loop_type}_pg_similarities/single patient phenotype-gene similarities": wandb.Histogram(sims_patient[sims_patient != pad])})
        log({f"{loop_type}_pg_similarities/single patient phenotype-correct gene similarities": wandb.Histogram(correct_patient[correct_patient != pad])})
        log({f"{loop_type}_pg_similarities/single patient phenotype-incorrect gene similarities": wandb.Histogram(incorrect_patient[incorrect_patient != pad])})

    def _log_rank_metrics(self, correct_gene_ranks, loop_type):
        """Log top-k accuracy (k = 1, 3, 5, 10), mean reciprocal rank and average rank of the correct gene."""
        for k in (1, 3, 5, 10):
            self.log(f'{loop_type}/gp_{loop_type}_epoch_top{k}_acc', top_k_acc(correct_gene_ranks, k=k), prog_bar=False)
        self.log(f'{loop_type}/gp_{loop_type}_epoch_mrr', mean_reciprocal_rank(correct_gene_ranks), prog_bar=False)
        self.log(f'{loop_type}/gp_{loop_type}_epoch_avg_rank', average_rank(correct_gene_ranks), prog_bar=False)

    def _epoch_end(self, outputs, loop_type):
        """Aggregate step outputs of one epoch: write test results, log optional plots and rank metrics.

        Args:
            outputs: list of per-batch dicts returned by the corresponding step.
            loop_type: 'train', 'val' or 'test' (the key prefix of ``outputs``).
        """
        hp = self.hparams.hparams
        correct_gene_ranks = torch.cat([x[f'{loop_type}/{loop_type}_correct_gene_ranks'] for x in outputs], dim=0)

        if loop_type == "test":
            self._write_test_results(outputs, loop_type, correct_gene_ranks)

        # Plot embeddings
        if loop_type != "train" and len(self.train_patient_nodes) > 0 and hp['plot_intrain']:
            self._log_intrain_plots(outputs, loop_type, correct_gene_ranks)

        if hp['plot_PG_embed']:
            # The original referenced undefined names (patient_emb, patient_label) here.
            raise NotImplementedError("plot_PG_embed: patient embeddings/labels are not collected by the step methods")

        # plot % overlap with train patients
        if loop_type != 'train' and hp['mrr_vs_percent_overlap']:
            # key now follows loop_type (was hard-coded to 'val/')
            max_percent_overlap_train = torch.cat([torch.tensor(x[f'{loop_type}/max_percent_phen_overlap_train']) for x in outputs], dim=0)
            self.logger.experiment.log({f'{loop_type}/mrr_vs_percent_overlap': mrr_vs_percent_overlap(correct_gene_ranks.detach().cpu(), max_percent_overlap_train.detach().cpu())})

        if hp['plot_frac_rank']:
            self._log_frac_rank_plots(outputs, loop_type, correct_gene_ranks)
        if hp['plot_nhops_rank']:
            self._log_nhops_rank_plots(outputs, loop_type, correct_gene_ranks)
        if hp['plot_attn_nhops']:
            self._log_attention_plots(outputs, loop_type)
        if hp['plot_phen_gene_sims']:
            self._log_phen_gene_sim_plots(outputs, loop_type)

        self._log_rank_metrics(correct_gene_ranks, loop_type)
        if loop_type == 'val':
            self.log('curr_epoch', self.current_epoch, prog_bar=False)

    @staticmethod
    def _count_ids(tensors):
        """Return ``{id: count}`` over the concatenation of ``tensors`` (ids/counts as Python scalars)."""
        ids, counts = torch.unique(torch.cat(tensors, dim=0), return_counts=True)
        return dict(zip(ids.tolist(), counts.tolist()))

    @classmethod
    def _add_counts(cls, store, tensors):
        """Add the id counts of ``tensors`` into ``store`` in place."""
        for node, count in cls._count_ids(tensors).items():
            store[node] = store.get(node, 0) + count

    def training_epoch_end(self, outputs):
        """Update the train node-count tables (if ``plot_intrain``) and run the shared epoch-end logging.

        NOTE(review): the ``train/n_id``, ``train/sparse_idx``,
        ``train/target_batch``, ``train/cand_gene_nid_orig`` and
        ``train/corr_gene_nid_orig`` keys are not returned by the visible
        ``training_step``; enabling plot_intrain requires adding them there.
        """
        if self.hparams.hparams['plot_intrain']:
            # accumulated across epochs (the original int-vs-tensor key check never matched)
            self._add_counts(self.all_train_nodes, [x['train/n_id'] for x in outputs])
            self._add_counts(self.train_sparse_nodes, [x['train/sparse_idx'] for x in outputs])
            self._add_counts(self.train_target_batch, [x['train/target_batch'] for x in outputs])
            # replaced each epoch
            self.train_patient_nodes = self._count_ids([x['train/cand_gene_nid_orig'] for x in outputs])
            self.train_corr_gene_nid = self._count_ids([x['train/corr_gene_nid_orig'] for x in outputs])
        self._epoch_end(outputs, 'train')

    def validation_epoch_end(self, outputs):
        """Run the shared epoch-end aggregation for validation outputs."""
        self._epoch_end(outputs, 'val')

    def test_epoch_end(self, outputs):
        """Run the shared epoch-end aggregation (including CSV export) for test outputs."""
        self._epoch_end(outputs, 'test')

    def configure_optimizers(self):
        """Adam over all parameters (node embedder and patient head) with ``hparams['lr']``."""
        return torch.optim.Adam(self.parameters(), lr=self.hparams.hparams['lr'])