# shepherd/gene_prioritization_model.py

# pytorch lightning
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger

# torch
from torch import nn
import torch
import torch.nn.functional as F

import numpy as np
from scipy.stats import rankdata
import pandas as pd
from pathlib import Path
import time
import wandb
import sys

sys.path.insert(0, '..')  # add project_config to path

from node_embedder_model import NodeEmbeder
from task_heads.gp_aligner import GPAligner

import project_config

# import utils
from utils.pretrain_utils import get_edges, calc_metrics
from utils.loss_utils import MultisimilarityCriterion
from utils.train_utils import mean_reciprocal_rank, top_k_acc, average_rank
from utils.train_utils import fit_umap, mrr_vs_percent_overlap, plot_gene_rank_vs_x_intrain, plot_gene_rank_vs_hops, plot_degree_vs_attention, plot_nhops_to_gene_vs_attention, plot_gene_rank_vs_fraction_phenotype, plot_gene_rank_vs_numtrain, plot_gene_rank_vs_trainset
from utils.train_utils import weighted_sum

class CombinedGPAligner(pl.LightningModule):
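    """End-to-end gene prioritization model (a high-level summary inferred from this module).

    Combines a pretrained knowledge-graph node embedder (NodeEmbeder) with a
    patient-level phenotype/gene aligner (GPAligner): node embeddings for a
    patient's phenotypes and candidate genes are pooled into a patient
    embedding, candidate genes are scored by similarity to it, and both models
    are trained jointly with a loss weighted by hparams['lambda'].
    """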

    def __init__(self, edge_attr_dict, all_data, n_nodes=None, node_ckpt=None, hparams=None, node_hparams=None, spl_pca=[], spl_gate=[]):
        super().__init__()
        print('Initializing Model')

        self.save_hyperparameters('hparams', ignore=["spl_pca", "spl_gate"])  # spl_pca and spl_gate are never used

        self.all_data = all_data

        # bookkeeping of node/gene occurrence counts seen during training (populated in training_epoch_end)
        self.all_train_nodes = {}
        self.train_patient_nodes = {}
        self.train_sparse_nodes = {}
        self.train_target_batch = {}
        self.train_corr_gene_nid = {}

        print(f"Loading Node Embedder from {node_ckpt}")

        # NOTE: loads in saved hyperparameters
        self.node_model = NodeEmbeder.load_from_checkpoint(checkpoint_path=node_ckpt,
                                                           all_data=all_data,
                                                           edge_attr_dict=edge_attr_dict,
                                                           num_nodes=n_nodes)

        self.patient_model = self.get_patient_model()
        print('End Patient Model Initialization')

    def get_patient_model(self):
        # NOTE: this will only work with GATv2Conv
        # embed_dim = node-embedder output dim * number of attention heads
        model = GPAligner(self.hparams.hparams, embed_dim=self.node_model.hparams.hp_dict['output'] * self.node_model.hparams.hp_dict['n_heads'])
        return model

    def forward(self, batch, step_type):
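        """Embed the sampled subgraph, gather phenotype/gene/disease embeddings, and run the patient aligner.

        `pad_outputs` prepends a zero row so that the padding node id 0 indexes
        a zero embedding; the masks mark which positions hold real nodes.
        Returns the node embeddings, GAT attention, pooled patient (phenotype)
        embedding, per-entity embeddings, masks, and phenotype attention weights.
        """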
        # Node Embedder
        t0 = time.time()
        outputs, gat_attn = self.node_model.forward(batch.n_id, batch.adjs)
        # prepend a zero row so that padded node id 0 maps to a zero embedding
        pad_outputs = torch.cat([torch.zeros(1, outputs.size(1), device=outputs.device), outputs])
        t1 = time.time()

        # get masks (node id 0 is padding)
        phenotype_mask = (batch.batch_pheno_nid != 0)
        gene_mask = (batch.batch_cand_gene_nid != 0)

        # index into outputs using phenotype & gene batch node idx
        batch_sz, max_n_phen = batch.batch_pheno_nid.shape
        phenotype_embeddings = torch.index_select(pad_outputs, 0, batch.batch_pheno_nid.view(-1)).view(batch_sz, max_n_phen, -1)
        batch_sz, max_n_cand_genes = batch.batch_cand_gene_nid.shape
        cand_gene_embeddings = torch.index_select(pad_outputs, 0, batch.batch_cand_gene_nid.view(-1)).view(batch_sz, max_n_cand_genes, -1)

        if self.hparams.hparams['augment_genes']:
            print("Augmenting genes...", self.hparams.hparams['aug_gene_w'])
            _, max_n_sim_cand_genes, k_sim_genes = batch.batch_sim_gene_nid.shape
            sim_gene_embeddings = torch.index_select(pad_outputs, 0, batch.batch_sim_gene_nid.view(-1)).view(batch_sz, max_n_sim_cand_genes, self.hparams.hparams['n_sim_genes'], -1)
            agg_sim_gene_embedding = weighted_sum(sim_gene_embeddings, batch.batch_sim_gene_sims)
            if self.hparams.hparams['aug_gene_by_deg']:
                # down-weight augmentation for high-degree genes: w * exp(-w * deg) + (1 - w - 0.1)
                print("Augmenting gene by degree...")
                aug_gene_w = self.hparams.hparams['aug_gene_w'] * torch.exp(-self.hparams.hparams['aug_gene_w'] * batch.batch_cand_gene_degs) + (1 - self.hparams.hparams['aug_gene_w'] - 0.1)
                aug_gene_w = (aug_gene_w * (torch.sum(batch.batch_sim_gene_sims, dim=-1) > 0)).unsqueeze(-1)
            else:
                aug_gene_w = (self.hparams.hparams['aug_gene_w'] * (torch.sum(batch.batch_sim_gene_sims, dim=-1) > 0)).unsqueeze(-1)
            # convex combination of each candidate gene's own embedding and its similar-gene aggregate
            cand_gene_embeddings = torch.mul(1 - aug_gene_w, cand_gene_embeddings) + torch.mul(aug_gene_w, agg_sim_gene_embedding)

        # Patient Embedder, with or without disease information
        if self.hparams.hparams['use_diseases']:
            disease_mask = (batch.batch_disease_nid != 0)
            batch_sz, max_n_dx = batch.batch_disease_nid.shape
            disease_embeddings = torch.index_select(pad_outputs, 0, batch.batch_disease_nid.view(-1)).view(batch_sz, max_n_dx, -1)
            t2 = time.time()
            phenotype_embedding, candidate_gene_embeddings, disease_embeddings, gene_mask, phenotype_mask, disease_mask, attn_weights = self.patient_model.forward(phenotype_embeddings, cand_gene_embeddings, disease_embeddings, phenotype_mask, gene_mask, disease_mask)
            t3 = time.time()
        else:
            t2 = time.time()
            phenotype_embedding, candidate_gene_embeddings, disease_embeddings, gene_mask, phenotype_mask, disease_mask, attn_weights = self.patient_model.forward(phenotype_embeddings, cand_gene_embeddings, phenotype_mask=phenotype_mask, gene_mask=gene_mask)
            t3 = time.time()

        if self.hparams.hparams['time']:
            print(f'It takes {t1-t0:0.4f}s for the node model, {t2-t1:0.4f}s for indexing into the output, and {t3-t2:0.4f}s for the patient model forward.')

        return outputs, gat_attn, phenotype_embedding, candidate_gene_embeddings, disease_embeddings, gene_mask, phenotype_mask, disease_mask, attn_weights

    def _step(self, batch, step_type):
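        """Shared train/val/test logic: forward pass, similarity scoring, gene ranking, and both losses.

        At test time the node-embedder (link prediction) loss and its metrics
        are skipped and reported as 0.
        """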
        t0 = time.time()
        if step_type != 'test':
            batch = get_edges(batch, self.all_data, step_type)
        t1 = time.time()

        # Forward pass
        node_embeddings, gat_attn, phenotype_embedding, candidate_gene_embeddings, disease_embeddings, gene_mask, phenotype_mask, disease_mask, attn_weights = self.forward(batch, step_type)
        t2 = time.time()

        # Calculate similarities between patient phenotypes & candidate genes/diseases
        alpha = self.hparams.hparams['alpha']
        use_candidate_list = step_type != 'train'
        cand_gene_to_phenotypes_spl = batch.batch_cand_gene_to_phenotypes_spl if use_candidate_list else batch.batch_concat_cand_gene_to_phenotypes_spl
        disease_nid = batch.batch_disease_nid if self.hparams.hparams['use_diseases'] else None

        # calculate similarity between phenotypes & genes for all genes in the manual candidate list
        phen_gene_sims, raw_phen_gene_sims, phen_gene_mask, phen_gene_one_hot_labels = self.patient_model._calc_similarity(phenotype_embedding, candidate_gene_embeddings, None, batch.batch_cand_gene_nid, batch.batch_corr_gene_nid, disease_nid, batch.one_hot_labels, gene_mask, phenotype_mask, disease_mask, True, batch.batch_cand_gene_to_phenotypes_spl, alpha)

        # calculate similarity for the loss function
        if self.hparams.hparams['loss'] == 'gene_multisimilarity' and use_candidate_list:  # in this case, the similarities are the same
            sims = phen_gene_sims
            mask = phen_gene_mask
            one_hot_labels = phen_gene_one_hot_labels
        else:
            if self.hparams.hparams['loss'] == 'disease_multisimilarity':
                candidate_gene_embeddings = None
            elif self.hparams.hparams['loss'] == 'gene_multisimilarity':
                disease_embeddings = None
            sims, raw_sims, mask, one_hot_labels = self.patient_model._calc_similarity(phenotype_embedding, candidate_gene_embeddings, disease_embeddings, batch.batch_cand_gene_nid, batch.batch_corr_gene_nid, disease_nid, batch.one_hot_labels, gene_mask, phenotype_mask, disease_mask, use_candidate_list, cand_gene_to_phenotypes_spl, alpha)

        ## Rank genes
        correct_gene_ranks, phen_gene_sims = self.patient_model._rank_genes(phen_gene_sims, phen_gene_mask, phen_gene_one_hot_labels)
        t3 = time.time()

        ## Calculate patient embedding loss
        loss = self.patient_model.calc_loss(sims, mask, one_hot_labels)
        t4 = time.time()

        ## Calculate node embedding loss
        if step_type == 'test':
            node_embedder_loss = 0
            roc_score, ap_score, acc, f1 = 0, 0, 0, 0
        else:
            # Get link predictions
            batch, raw_pred, pred = self.node_model.get_predictions(batch, node_embeddings)
            link_labels = self.node_model.get_link_labels(batch.all_edge_types)
            node_embedder_loss = self.node_model.calc_loss(pred, link_labels)

            # Calculate metrics
            metric_pred = torch.sigmoid(raw_pred)
            roc_score, ap_score, acc, f1 = calc_metrics(metric_pred.cpu().detach().numpy(), link_labels.cpu().detach().numpy())
        t5 = time.time()

        ## calc time
        if self.hparams.hparams['time']:
            print(f'It takes {t1-t0:0.4f}s to get edges, {t2-t1:0.4f}s for the forward pass, {t3-t2:0.4f}s to rank genes, {t4-t3:0.4f}s to calc patient loss, and {t5-t4:0.4f}s to calc the node loss.')

        ## Plot gradients
        if self.hparams.hparams['plot_gradients']:
            for k, v in self.patient_model.state_dict().items():
                self.logger.experiment.log({f'gradients/{step_type}.gradients.{k}': wandb.Histogram(v.detach().cpu())})

        return node_embedder_loss, loss, correct_gene_ranks, roc_score, ap_score, acc, f1, node_embeddings, gat_attn, phenotype_embedding, candidate_gene_embeddings, attn_weights, phen_gene_sims, raw_phen_gene_sims, gene_mask, phenotype_mask

    def training_step(self, batch, batch_idx):
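        """Run one training batch and log the lambda-weighted combination of node-embedder and patient losses."""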
        print('training step')
        node_embedder_loss, patient_loss, correct_gene_ranks, roc_score, ap_score, acc, f1, node_embeddings, gat_attn, phenotype_embedding, candidate_gene_embeddings, attn_weights, phen_gene_sims, raw_phen_gene_sims, gene_mask, phenotype_mask = self._step(batch, 'train')

        # total loss: convex combination of the node-embedder (link prediction) loss and the patient (gene ranking) loss
        loss = (self.hparams.hparams['lambda'] * node_embedder_loss) + ((1 - self.hparams.hparams['lambda']) * patient_loss)
        self.log('train_loss/patient.train_overall_loss', loss, prog_bar=True, on_epoch=True)
        self.log('train_loss/patient.train_patient_loss', patient_loss, prog_bar=True, on_epoch=True)
        self.log('train_loss/patient.train_node_embedder_loss', node_embedder_loss, prog_bar=True, on_epoch=True)

        # NOTE: these flattened views are currently unused
        batch_sz, n_candidates, embed_dim = candidate_gene_embeddings.shape
        candidate_gene_embeddings_flattened = candidate_gene_embeddings.view(batch_sz * n_candidates, embed_dim)
        one_hot_labels_flattened = batch.one_hot_labels.view(batch_sz * n_candidates)

        return {'loss': loss,
                'train/train_correct_gene_ranks': correct_gene_ranks,
                'train/node.train_roc': roc_score,
                'train/node.train_ap': ap_score,
                'train/node.train_acc': acc,
                'train/node.train_f1': f1,
                'train/one_hot_labels': batch.one_hot_labels.detach().cpu(),
                'train/attention_weights': attn_weights.detach().cpu() if attn_weights is not None else None,
                'train/phen_gene_sims': phen_gene_sims.detach().cpu(),
                'train/phenotype_names_degrees': batch.phenotype_names,
                }

    def validation_step(self, batch, batch_idx):
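        """Run one validation batch; mirrors training_step but keys its outputs under 'val/'."""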
        node_embedder_loss, patient_loss, correct_gene_ranks, roc_score, ap_score, acc, f1, node_embeddings, gat_attn, phenotype_embedding, candidate_gene_embeddings, attn_weights, phen_gene_sims, raw_phen_gene_sims, gene_mask, phenotype_mask = self._step(batch, 'val')
        loss = (self.hparams.hparams['lambda'] * node_embedder_loss) + ((1 - self.hparams.hparams['lambda']) * patient_loss)
        self.log('val_loss/patient.val_overall_loss', loss, prog_bar=True, on_epoch=True)
        self.log('val_loss/patient.val_patient_loss', patient_loss, prog_bar=True)
        self.log('val_loss/patient.val_node_embedder_loss', node_embedder_loss, prog_bar=True)

        return {'loss/val_loss': loss,
                'val/val_correct_gene_ranks': correct_gene_ranks,
                'val/node.val_roc': roc_score,
                'val/node.val_ap': ap_score,
                'val/node.val_acc': acc,
                'val/node.val_f1': f1,
                'val/one_hot_labels': batch.one_hot_labels.detach().cpu(),
                'val/attention_weights': attn_weights.detach().cpu() if attn_weights is not None else None,
                'val/phen_gene_sims': phen_gene_sims.detach().cpu(),
                'val/phenotype_names_degrees': batch.phenotype_names,
                }

    def write_results_to_file(self, batch_info, phen_gene_sims, gene_mask, phenotype_mask, attn_weights, correct_gene_ranks, gat_attn, node_embeddings, phenotype_embeddings, save=True, loop_type='predict'):
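        """Assemble per-patient gene scores, phenotype attention, and GAT attention into DataFrames.

        Despite the name, this method only builds and returns the DataFrames;
        the caller (e.g. _epoch_end) is responsible for writing them to disk.
        """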
        # NOTE: only saves a single batch - to run at inference time, make sure the batch size includes all patients

        # Save GAT attention weights
        # NOTE: assumes 3 layers to the model
        attn_dfs = []
        for edge_index, attn in gat_attn:
            edge_index = edge_index.cpu()
            attn = attn.cpu()
            gat_attn_df = pd.DataFrame({'source': edge_index[0, :], 'target': edge_index[1, :]})
            for head in range(attn.shape[1]):
                gat_attn_df[f'attn_{head}'] = attn[:, head]
            attn_dfs.append(gat_attn_df)

        # Save scores
        all_sims, all_genes, all_patient_ids = [], [], []
        for patient_id, sims, genes, g_mask in zip(batch_info["patient_ids"], phen_gene_sims, batch_info["cand_gene_names"], gene_mask):
            nonpadded_sims = sims[g_mask].tolist()
            all_sims.extend(nonpadded_sims)
            all_genes.extend(genes)
            all_patient_ids.extend([patient_id] * len(genes))
        results_df = pd.DataFrame({'patient_id': all_patient_ids, 'genes': all_genes, 'similarities': all_sims})

        # Save phenotype information
        if attn_weights is None:
            phen_df = None
        else:
            all_patient_ids, all_phens, all_attn_weights, all_degrees = [], [], [], []
            for patient_id, attn_w, phen_names, p_mask in zip(batch_info["patient_ids"], attn_weights, batch_info["phenotype_names"], phenotype_mask):
                p_names, degrees = zip(*phen_names)
                all_patient_ids.extend([patient_id] * len(phen_names))
                all_degrees.extend(degrees)
                all_phens.extend(p_names)
                all_attn_weights.extend(attn_w[p_mask].tolist())
            phen_df = pd.DataFrame({'patient_id': all_patient_ids, 'phenotypes': all_phens, 'degrees': all_degrees, 'attention': all_attn_weights})
            print(phen_df.head())

        return results_df, phen_df, attn_dfs, phenotype_embeddings.cpu(), None

    def test_step(self, batch, batch_idx):
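        """Run one test batch and return everything _epoch_end needs to write results and plots."""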
        node_embedder_loss, patient_loss, correct_gene_ranks, roc_score, ap_score, acc, f1, node_embeddings, gat_attn, phenotype_embedding, candidate_gene_embeddings, attn_weights, phen_gene_sims, raw_phen_gene_sims, gene_mask, phenotype_mask = self._step(batch, 'test')

        return {'test/test_correct_gene_ranks': correct_gene_ranks,
                'test/node.embed': node_embeddings.detach().cpu(),
                'test/patient.phenotype_embed': phenotype_embedding.detach().cpu(),
                'test/one_hot_labels': batch.one_hot_labels.detach().cpu(),
                'test/attention_weights': attn_weights.detach().cpu() if attn_weights is not None else None,
                'test/phen_gene_sims': phen_gene_sims.detach().cpu(),
                'test/phenotype_names_degrees': batch.phenotype_names,  # type = list
                'test/gene_mask': gene_mask.detach().cpu(),
                'test/phenotype_mask': phenotype_mask.detach().cpu(),
                'test/patient_ids': batch.patient_ids,  # type = list
                'test/cand_gene_names': batch.cand_gene_names,  # type = list
                'test/gat_attn': gat_attn,  # type = list
                'test/n_id': batch.n_id[:batch.batch_size].detach().cpu(),
                }

    def inference(self, batch, batch_idx):
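        """Like forward, but embeds the full graph via the node model's predict() instead of sampled subgraphs."""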
        t0 = time.time()
        outputs, gat_attn = self.node_model.predict(self.all_data)
        # prepend a zero row so that padded node id 0 maps to a zero embedding
        pad_outputs = torch.cat([torch.zeros(1, outputs.size(1), device=outputs.device), outputs])
        t1 = time.time()

        # get masks (node id 0 is padding)
        phenotype_mask = (batch.batch_pheno_nid != 0)
        gene_mask = (batch.batch_cand_gene_nid != 0)

        # index into outputs using phenotype & gene batch node idx
        batch_sz, max_n_phen = batch.batch_pheno_nid.shape
        phenotype_embeddings = torch.index_select(pad_outputs, 0, batch.batch_pheno_nid.view(-1)).view(batch_sz, max_n_phen, -1)
        batch_sz, max_n_cand_genes = batch.batch_cand_gene_nid.shape
        cand_gene_embeddings = torch.index_select(pad_outputs, 0, batch.batch_cand_gene_nid.view(-1)).view(batch_sz, max_n_cand_genes, -1)

        if self.hparams.hparams['augment_genes']:
            print("Augmenting genes at inference...", self.hparams.hparams['aug_gene_w'])
            _, max_n_sim_cand_genes, k_sim_genes = batch.batch_sim_gene_nid.shape
            sim_gene_embeddings = torch.index_select(pad_outputs, 0, batch.batch_sim_gene_nid.view(-1)).view(batch_sz, max_n_sim_cand_genes, self.hparams.hparams['n_sim_genes'], -1)
            agg_sim_gene_embedding = weighted_sum(sim_gene_embeddings, batch.batch_sim_gene_sims)
            aug_gene_w = (self.hparams.hparams['aug_gene_w'] * (torch.sum(batch.batch_sim_gene_sims, dim=-1) > 0)).unsqueeze(-1)
            cand_gene_embeddings = torch.mul(1 - aug_gene_w, cand_gene_embeddings) + torch.mul(aug_gene_w, agg_sim_gene_embedding)

        # Patient Embedder, with or without disease information
        if self.hparams.hparams['use_diseases']:
            disease_mask = (batch.batch_disease_nid != 0)
            batch_sz, max_n_dx = batch.batch_disease_nid.shape
            disease_embeddings = torch.index_select(pad_outputs, 0, batch.batch_disease_nid.view(-1)).view(batch_sz, max_n_dx, -1)
            t2 = time.time()
            phenotype_embedding, candidate_gene_embeddings, disease_embeddings, gene_mask, phenotype_mask, disease_mask, attn_weights = self.patient_model.forward(phenotype_embeddings, cand_gene_embeddings, disease_embeddings, phenotype_mask, gene_mask, disease_mask)
            t3 = time.time()
        else:
            t2 = time.time()
            phenotype_embedding, candidate_gene_embeddings, disease_embeddings, gene_mask, phenotype_mask, disease_mask, attn_weights = self.patient_model.forward(phenotype_embeddings, cand_gene_embeddings, phenotype_mask=phenotype_mask, gene_mask=gene_mask)
            t3 = time.time()

        if self.hparams.hparams['time']:
            print(f'It takes {t1-t0:0.4f}s for the node model, {t2-t1:0.4f}s for indexing into the output, and {t3-t2:0.4f}s for the patient model forward.')

        return outputs, gat_attn, phenotype_embedding, candidate_gene_embeddings, disease_embeddings, gene_mask, phenotype_mask, disease_mask, attn_weights

    def predict_step(self, batch, batch_idx):
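        """Score and rank candidate genes for a batch of patients at inference time."""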
        node_embeddings, gat_attn, phenotype_embedding, candidate_gene_embeddings, disease_embeddings, gene_mask, phenotype_mask, disease_mask, attn_weights = self.inference(batch, batch_idx)

        # Calculate similarities between patient phenotypes & candidate genes/diseases
        alpha = self.hparams.hparams['alpha']
        use_candidate_list = True
        disease_nid = batch.batch_disease_nid if self.hparams.hparams['use_diseases'] else None

        # calculate similarity between phenotypes & genes for all genes in the manual candidate list
        phen_gene_sims, raw_phen_gene_sims, phen_gene_mask, phen_gene_one_hot_labels = self.patient_model._calc_similarity(phenotype_embedding, candidate_gene_embeddings, None, batch.batch_cand_gene_nid, batch.batch_corr_gene_nid, disease_nid, batch.one_hot_labels, gene_mask, phenotype_mask, disease_mask, True, batch.batch_cand_gene_to_phenotypes_spl, alpha)

        # Rank genes
        correct_gene_ranks, phen_gene_sims = self.patient_model._rank_genes(phen_gene_sims, phen_gene_mask, phen_gene_one_hot_labels)

        results_df, phen_df, attn_dfs, phenotype_embeddings, disease_embeddings = self.write_results_to_file(batch, phen_gene_sims, gene_mask, phenotype_mask, attn_weights, correct_gene_ranks, gat_attn, node_embeddings, phenotype_embedding, save=True)
        return results_df, phen_df, *attn_dfs, phenotype_embeddings, disease_embeddings

    def _epoch_end(self, outputs, loop_type):
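        """Aggregate step outputs at epoch end: write test results, log diagnostic plots, and compute ranking metrics.

        Ranking metrics (standard definitions assumed for the utils.train_utils
        helpers): top-k accuracy is the fraction of patients whose causal gene
        is ranked within the top k; MRR is the mean of 1/rank over patients.
        """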
        correct_gene_ranks = torch.cat([x[f'{loop_type}/{loop_type}_correct_gene_ranks'] for x in outputs], dim=0)

        if loop_type == "test":

            batch_info = {"n_id": torch.cat([x[f'{loop_type}/n_id'] for x in outputs], dim=0),
                          "patient_ids": [pat for x in outputs for pat in x[f'{loop_type}/patient_ids']],
                          "phenotype_names": [pat for x in outputs for pat in x[f'{loop_type}/phenotype_names_degrees']],
                          "cand_gene_names": [pat for x in outputs for pat in x[f'{loop_type}/cand_gene_names']],
                          "one_hot_labels": [pat for x in outputs for pat in x[f'{loop_type}/one_hot_labels']],
                          }

            phen_gene_sims = [pat for x in outputs for pat in x[f'{loop_type}/phen_gene_sims']]
            gene_mask = [pat for x in outputs for pat in x[f'{loop_type}/gene_mask']]
            phenotype_mask = [pat for x in outputs for pat in x[f'{loop_type}/phenotype_mask']]
            attn_weights = [pat for x in outputs for pat in x[f'{loop_type}/attention_weights']]
            gat_attn = [pat for x in outputs for pat in x[f'{loop_type}/gat_attn']]
            node_embeddings = torch.cat([x[f'{loop_type}/node.embed'] for x in outputs], dim=0)
            phenotype_embedding = torch.cat([x[f'{loop_type}/patient.phenotype_embed'] for x in outputs], dim=0)

            results_df, phen_df, attn_dfs, phenotype_embeddings, disease_embeddings = self.write_results_to_file(batch_info, phen_gene_sims, gene_mask, phenotype_mask, attn_weights, correct_gene_ranks, gat_attn, node_embeddings, phenotype_embedding, loop_type=loop_type)

            print("Writing results for test...")
            # NOTE: hardcoded output path - adjust for your environment
            output_base = "/home/ml499/public_repos/SHEPHERD/shepherd/results/gp"
            results_df.to_csv(str(output_base) + '_scores.csv', index=False)
            print(results_df)

            phen_df.to_csv(str(output_base) + '_phenotype_attention.csv', sep=',', index=False)
            print(phen_df)

        # Plot embeddings
        if loop_type != "train" and len(self.train_patient_nodes) > 0 and self.hparams.hparams['plot_intrain']:

            correct_gene_nid = torch.cat([x[f'{loop_type}/corr_gene_nid_orig'] for x in outputs], dim=0)
            assert correct_gene_ranks.shape[0] == correct_gene_nid.shape[0]

            # Rank of gene vs. number of train patients with the causal gene
            gene_rank_corr_gene_fig, gene_rank_corr_gene_counts = plot_gene_rank_vs_numtrain(correct_gene_ranks, correct_gene_nid, self.train_corr_gene_nid)
            gene_rank_cand_gene_fig, gene_rank_cand_gene_counts = plot_gene_rank_vs_numtrain(correct_gene_ranks, correct_gene_nid, self.train_patient_nodes)
            gene_rank_sparse_fig, gene_rank_sparse_counts = plot_gene_rank_vs_numtrain(correct_gene_ranks, correct_gene_nid, self.train_sparse_nodes)
            gene_rank_target_fig, gene_rank_target_counts = plot_gene_rank_vs_numtrain(correct_gene_ranks, correct_gene_nid, self.train_target_batch)
            self.logger.experiment.log({f'{loop_type}/gene_rank_vs_num_train_corr_genes': gene_rank_corr_gene_fig})
            self.logger.experiment.log({f'{loop_type}/gene_rank_vs_num_train_cand_genes': gene_rank_cand_gene_fig})
            self.logger.experiment.log({f'{loop_type}/gene_rank_vs_num_train_sparse': gene_rank_sparse_fig})
            self.logger.experiment.log({f'{loop_type}/gene_rank_vs_num_train_target': gene_rank_target_fig})

            gene_nid_trainset = torch.stack([torch.tensor(gene_rank_corr_gene_counts),
                                             torch.tensor(gene_rank_cand_gene_counts),
                                             torch.tensor(gene_rank_sparse_counts),
                                             torch.tensor(gene_rank_target_counts)], dim=1)
            self.logger.experiment.log({f'{loop_type}/gene_rank_vs_trainset': plot_gene_rank_vs_trainset(correct_gene_ranks, correct_gene_nid, gene_nid_trainset)})

        if self.hparams.hparams['plot_PG_embed']:
            # NOTE: patient_emb and patient_label are not defined in this scope; this branch raises a NameError if enabled
            self.logger.experiment.log({f'{loop_type}/patient_embed': fit_umap(patient_emb, patient_label)})

        # plot % overlap with train patients
        if loop_type != 'train' and self.hparams.hparams['mrr_vs_percent_overlap']:
            max_percent_overlap_train = torch.cat([torch.tensor(x[f'{loop_type}/max_percent_phen_overlap_train']) for x in outputs], dim=0)
            self.logger.experiment.log({f'{loop_type}/mrr_vs_percent_overlap': mrr_vs_percent_overlap(correct_gene_ranks.detach().cpu(), max_percent_overlap_train.detach().cpu())})

        if self.hparams.hparams['plot_frac_rank']:

            # Rank of gene vs. fraction of phenotypes attached to the disease
            frac_p_with_direct_edge_to_dx = [pat[0][0] for x in outputs for pat in x[f'{loop_type}/frac_p_with_direct_edge_to_dx']]  # NOTE: only selects the first disease
            self.logger.experiment.log({f'{loop_type}/gene_rank_vs_frac_p_with_direct_edge_to_dx': plot_gene_rank_vs_fraction_phenotype(correct_gene_ranks.cpu(), frac_p_with_direct_edge_to_dx)})

            # Rank of gene vs. fraction of phenotypes attached to the gene
            frac_p_with_direct_edge_to_g = [pat[0][0] for x in outputs for pat in x[f'{loop_type}/frac_p_with_direct_edge_to_g']]  # NOTE: only selects the first gene
            self.logger.experiment.log({f'{loop_type}/frac_p_with_direct_edge_to_g': plot_gene_rank_vs_fraction_phenotype(correct_gene_ranks.cpu(), frac_p_with_direct_edge_to_g)})

        if self.hparams.hparams['plot_nhops_rank']:

            # Rank of gene vs. hops from disease
            nhops_g_d = [pat[0] for x in outputs for pat in x[f'{loop_type}/n_hops_g_d']]  # NOTE: only selects the first disease
            fig_mean, fig_min = plot_gene_rank_vs_hops(correct_gene_ranks.cpu(), nhops_g_d)
            self.logger.experiment.log({f'{loop_type}/gene_rank_vs_mean_n_hops_g_d': fig_mean})
            self.logger.experiment.log({f'{loop_type}/gene_rank_vs_min_n_hops_g_d': fig_min})

            # Rank of gene vs. mean/min hops from phenotypes
            nhops_g_p = [pat[0] for x in outputs for pat in x[f'{loop_type}/n_hops_g_p']]  # NOTE: only selects the first gene
            fig_mean, fig_min = plot_gene_rank_vs_hops(correct_gene_ranks.cpu(), nhops_g_p)
            self.logger.experiment.log({f'{loop_type}/gene_rank_vs_mean_n_hops_g_p': fig_mean})
            self.logger.experiment.log({f'{loop_type}/gene_rank_vs_min_n_hops_g_p': fig_min})

            # Rank of gene vs. distance between phenotypes
            nhops_p_p = [torch.tensor(pat) for x in outputs for pat in x[f'{loop_type}/n_hops_p_p']]
            fig_mean, fig_min = plot_gene_rank_vs_hops(correct_gene_ranks.cpu(), nhops_p_p)
            self.logger.experiment.log({f'{loop_type}/gene_rank_vs_mean_n_hops_p_p': fig_mean})
            self.logger.experiment.log({f'{loop_type}/gene_rank_vs_min_n_hops_p_p': fig_min})

        if self.hparams.hparams['plot_attn_nhops']:
            # NOTE: nhops_g_p is only defined if 'plot_nhops_rank' is also enabled

            # plot phenotype attention vs. n_hops to gene and degree
            attn_weights = [torch.split(x[f'{loop_type}/attention_weights'], 1) for x in outputs]
            attn_weights = [w[w > 0] for batch_w in attn_weights for w in batch_w]
            phenotype_names = [pat for x in outputs for pat in x[f'{loop_type}/phenotype_names_degrees']]
            attn_weights_cpu_reshaped = torch.cat(attn_weights, dim=0)
            self.logger.experiment.log({f"{loop_type}_attn/attention weights": wandb.Histogram(attn_weights_cpu_reshaped[attn_weights_cpu_reshaped != 0])})
            self.logger.experiment.log({f"{loop_type}_attn/single patient attention weights": wandb.Histogram(attn_weights[0])})
            self.logger.experiment.log({f"{loop_type}_attn/n hops to gene vs attention weights": plot_nhops_to_gene_vs_attention(attn_weights, phenotype_names, nhops_g_p)})
            self.logger.experiment.log({f"{loop_type}_attn/single patient n hops to gene vs attention weights": plot_nhops_to_gene_vs_attention(attn_weights, phenotype_names, nhops_g_p, single_patient=True)})
            self.logger.experiment.log({f"{loop_type}_attn/degree vs attention weights": plot_degree_vs_attention(attn_weights, phenotype_names)})
            self.logger.experiment.log({f"{loop_type}_attn/single patient degree vs attention weights": plot_degree_vs_attention(attn_weights, phenotype_names, single_patient=True)})
            data = [[p_name[0], w.item(), p_name[1], n_hops_to_g] for w, p_name, n_hops_to_g in zip(attn_weights[0], phenotype_names[0], nhops_g_p[0])]
            self.logger.experiment.log({f"{loop_type}_attn/phenotypes": wandb.Table(data=data, columns=["HPO Code", "Attention Weight", "Degree", "Num Hops to Gene"])})

        if self.hparams.hparams['plot_phen_gene_sims']:

            all_phen_gene_sims, all_pg_spl, all_correct_sims, all_incorrect_sims = [], [], [], []
            for x in outputs:
                phen_gene_sims = x[f'{loop_type}/phen_gene_sims']
                one_hot_labels = x[f'{loop_type}/one_hot_labels']
                all_correct_sims.append(phen_gene_sims[one_hot_labels == 1])
                all_incorrect_sims.append(phen_gene_sims[one_hot_labels != 1])
                all_phen_gene_sims.append(phen_gene_sims.view(-1))

            phen_gene_sims_reshaped = torch.cat(all_phen_gene_sims)
            correct_phen_gene_sims = torch.cat(all_correct_sims)
            incorrect_phen_gene_sims = torch.cat(all_incorrect_sims)

            # NOTE: all_pg_spl is never populated, so the 'pg spl' histograms below are effectively skipped
            if len(all_pg_spl) > 0:
                pg_spl_reshaped = torch.cat(all_pg_spl)
            else:
                pg_spl_reshaped = []

            # -100000 marks padded (masked-out) similarities
            self.logger.experiment.log({f"{loop_type}_pg_similarities/phenotype-gene similarities": wandb.Histogram(phen_gene_sims_reshaped[phen_gene_sims_reshaped != -100000])})
            self.logger.experiment.log({f"{loop_type}_pg_similarities/phenotype-correct gene similarities": wandb.Histogram(correct_phen_gene_sims[correct_phen_gene_sims != -100000])})
            self.logger.experiment.log({f"{loop_type}_pg_similarities/phenotype-incorrect gene similarities": wandb.Histogram(incorrect_phen_gene_sims[incorrect_phen_gene_sims != -100000])})

            if len(pg_spl_reshaped) > 0:
                self.logger.experiment.log({f"{loop_type}_pg_similarities/pg spl": wandb.Histogram(pg_spl_reshaped[pg_spl_reshaped != 0])})

            phen_gene_sims_patient = outputs[0][f'{loop_type}/phen_gene_sims'][0, :]
            one_hot_labels_patient = outputs[0][f'{loop_type}/one_hot_labels'][0, :]
            correct_phen_gene_sims_patient = phen_gene_sims_patient[one_hot_labels_patient == 1]

            assert len(correct_phen_gene_sims_patient) == 1
            incorrect_phen_gene_sims_patient = phen_gene_sims_patient[one_hot_labels_patient != 1]

            self.logger.experiment.log({f"{loop_type}_pg_similarities/single patient phenotype-gene similarities": wandb.Histogram(phen_gene_sims_patient[phen_gene_sims_patient != -100000])})
            self.logger.experiment.log({f"{loop_type}_pg_similarities/single patient phenotype-correct gene similarities": wandb.Histogram(correct_phen_gene_sims_patient[correct_phen_gene_sims_patient != -100000])})
            self.logger.experiment.log({f"{loop_type}_pg_similarities/single patient phenotype-incorrect gene similarities": wandb.Histogram(incorrect_phen_gene_sims_patient[incorrect_phen_gene_sims_patient != -100000])})

            if len(pg_spl_reshaped) > 0:
                self.logger.experiment.log({f"{loop_type}_pg_similarities/single patient pg spl": wandb.Histogram(pg_spl_reshaped[pg_spl_reshaped != 0])})

        # top k accuracy
        top_1_acc = top_k_acc(correct_gene_ranks, k=1)
        top_3_acc = top_k_acc(correct_gene_ranks, k=3)
        top_5_acc = top_k_acc(correct_gene_ranks, k=5)
        top_10_acc = top_k_acc(correct_gene_ranks, k=10)

        # mean reciprocal rank
        mrr = mean_reciprocal_rank(correct_gene_ranks)
        avg_rank = average_rank(correct_gene_ranks)

        self.log(f'{loop_type}/gp_{loop_type}_epoch_top1_acc', top_1_acc, prog_bar=False)
        self.log(f'{loop_type}/gp_{loop_type}_epoch_top3_acc', top_3_acc, prog_bar=False)
        self.log(f'{loop_type}/gp_{loop_type}_epoch_top5_acc', top_5_acc, prog_bar=False)
        self.log(f'{loop_type}/gp_{loop_type}_epoch_top10_acc', top_10_acc, prog_bar=False)
        self.log(f'{loop_type}/gp_{loop_type}_epoch_mrr', mrr, prog_bar=False)
        self.log(f'{loop_type}/gp_{loop_type}_epoch_avg_rank', avg_rank, prog_bar=False)

        if loop_type == 'val':
            self.log('curr_epoch', self.current_epoch, prog_bar=False)

    def training_epoch_end(self, outputs):
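        """Track how often each node / candidate gene / causal gene appears in training, then run shared epoch-end logic."""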
        if self.hparams.hparams['plot_intrain']:
            # accumulate per-node counts across epochs (dict keys are node ids as ints, so look up via .item())
            all_train_nodes, counts = torch.unique(torch.cat([x['train/n_id'] for x in outputs], dim=0), return_counts=True)
            curr_all_train_nodes = {n.item(): c.item() if n.item() not in self.all_train_nodes else c.item() + self.all_train_nodes[n.item()] for n, c in zip(all_train_nodes, counts)}
            self.all_train_nodes.update(curr_all_train_nodes)

            train_sparse_nodes, counts = torch.unique(torch.cat([x['train/sparse_idx'] for x in outputs], dim=0), return_counts=True)
            curr_train_sparse_nodes = {n.item(): c.item() if n.item() not in self.train_sparse_nodes else c.item() + self.train_sparse_nodes[n.item()] for n, c in zip(train_sparse_nodes, counts)}
            self.train_sparse_nodes.update(curr_train_sparse_nodes)

            train_target_batch, counts = torch.unique(torch.cat([x['train/target_batch'] for x in outputs], dim=0), return_counts=True)
            curr_train_target_batch = {n.item(): c.item() if n.item() not in self.train_target_batch else c.item() + self.train_target_batch[n.item()] for n, c in zip(train_target_batch, counts)}
            self.train_target_batch.update(curr_train_target_batch)

            train_patient_nodes, counts = torch.unique(torch.cat([x['train/cand_gene_nid_orig'] for x in outputs], dim=0), return_counts=True)
            self.train_patient_nodes = {n.item(): c.item() for n, c in zip(train_patient_nodes, counts)}

            train_corr_gene_nids, counts = torch.unique(torch.cat([x['train/corr_gene_nid_orig'] for x in outputs], dim=0), return_counts=True)
            self.train_corr_gene_nid = {g.item(): c.item() for g, c in zip(train_corr_gene_nids, counts)}

        self._epoch_end(outputs, 'train')

    def validation_epoch_end(self, outputs):
        self._epoch_end(outputs, 'val')

    def test_epoch_end(self, outputs):
        self._epoch_end(outputs, 'test')

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.hparams['lr'])
        return optimizer
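

# Example usage - a minimal sketch only. The real entry point, dataloaders, and
# hyperparameter dicts live elsewhere in this repo; the names below
# (edge_attr_dict, all_data, n_nodes, hparams, train_loader, val_loader) are
# placeholders for whatever that script constructs.
#
#   model = CombinedGPAligner(edge_attr_dict, all_data, n_nodes=n_nodes,
#                             node_ckpt='path/to/node_embedder.ckpt',
#                             hparams=hparams)
#   trainer = pl.Trainer(logger=WandbLogger(), max_epochs=100)
#   trainer.fit(model, train_loader, val_loader)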