shepherd/gene_prioritization_model.py
# PyTorch Lightning
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger

# torch
from torch import nn
import torch
import torch.nn.functional as F
import numpy as np
from scipy.stats import rankdata
import pandas as pd
from pathlib import Path
import time
import wandb
import sys

sys.path.insert(0, '..') # add project_config to path

from node_embedder_model import NodeEmbeder
from task_heads.gp_aligner import GPAligner

import project_config

# utils
from utils.pretrain_utils import get_edges, calc_metrics
from utils.loss_utils import MultisimilarityCriterion
from utils.train_utils import mean_reciprocal_rank, top_k_acc, average_rank
from utils.train_utils import fit_umap, mrr_vs_percent_overlap, plot_gene_rank_vs_x_intrain, plot_gene_rank_vs_hops, plot_degree_vs_attention, plot_nhops_to_gene_vs_attention, plot_gene_rank_vs_fraction_phenotype, plot_gene_rank_vs_numtrain, plot_gene_rank_vs_trainset
from utils.train_utils import weighted_sum


class CombinedGPAligner(pl.LightningModule):

    def __init__(self, edge_attr_dict, all_data, n_nodes=None, node_ckpt=None, hparams=None, node_hparams=None, spl_pca=None, spl_gate=None):
        super().__init__()
        print('Initializing Model')

        self.save_hyperparameters('hparams', ignore=["spl_pca", "spl_gate"]) # spl_pca and spl_gate never get used

        #print('Saved combined model hyperparameters: ', self.hparams)

        self.all_data = all_data

        self.all_train_nodes = {}
        self.train_patient_nodes = {}
        self.train_sparse_nodes = {}
        self.train_target_batch = {}
        self.train_corr_gene_nid = {}

        print(f"Loading Node Embedder from {node_ckpt}")

        # NOTE: loads in saved hyperparameters
        self.node_model = NodeEmbeder.load_from_checkpoint(checkpoint_path=node_ckpt,
                                                           all_data=all_data,
                                                           edge_attr_dict=edge_attr_dict,
                                                           num_nodes=n_nodes)

        self.patient_model = self.get_patient_model()
        print('End Patient Model Initialization')

    def get_patient_model(self):
        # NOTE: this will only work with GATv2Conv
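        # GAT concatenates its attention heads, so node embeddings have width
        # output_dim * n_heads; the aligner's embed_dim must match that.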
        model = GPAligner(self.hparams.hparams, embed_dim=self.node_model.hparams.hp_dict['output']*self.node_model.hparams.hp_dict['n_heads'])
        return model

    def forward(self, batch, step_type):
        # Node Embedder
        t0 = time.time()
        outputs, gat_attn = self.node_model.forward(batch.n_id, batch.adjs)
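        # Prepend a zero row so that padded node id 0 indexes an all-zero
        # embedding; the masks below then exclude padding downstream.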
        pad_outputs = torch.cat([torch.zeros(1, outputs.size(1), device=outputs.device), outputs])
        t1 = time.time()

        # get masks
        phenotype_mask = (batch.batch_pheno_nid != 0)
        gene_mask = (batch.batch_cand_gene_nid != 0)

        # index into outputs using phenotype & gene batch node idx
        batch_sz, max_n_phen = batch.batch_pheno_nid.shape
        phenotype_embeddings = torch.index_select(pad_outputs, 0, batch.batch_pheno_nid.view(-1)).view(batch_sz, max_n_phen, -1)
        batch_sz, max_n_cand_genes = batch.batch_cand_gene_nid.shape
        cand_gene_embeddings = torch.index_select(pad_outputs, 0, batch.batch_cand_gene_nid.view(-1)).view(batch_sz, max_n_cand_genes, -1)

        if self.hparams.hparams['augment_genes']:
            print("Augmenting genes...", self.hparams.hparams['aug_gene_w'])
            _, max_n_sim_cand_genes, k_sim_genes = batch.batch_sim_gene_nid.shape
            sim_gene_embeddings = torch.index_select(pad_outputs, 0, batch.batch_sim_gene_nid.view(-1)).view(batch_sz, max_n_sim_cand_genes, k_sim_genes, -1)
            agg_sim_gene_embedding = weighted_sum(sim_gene_embeddings, batch.batch_sim_gene_sims)
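            # Mixing weight for augmentation: with 'aug_gene_by_deg', the weight
            # decays exponentially with candidate gene degree, so low-degree genes
            # borrow more from their similar genes; genes with no similar-gene
            # hits (all sims zero) get weight 0 and keep their own embedding.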
            if self.hparams.hparams['aug_gene_by_deg']:
                print("Augmenting gene by degree...")
                aug_gene_w = self.hparams.hparams['aug_gene_w'] * torch.exp(-self.hparams.hparams['aug_gene_w'] * batch.batch_cand_gene_degs) + (1 - self.hparams.hparams['aug_gene_w'] - 0.1)
                aug_gene_w = (aug_gene_w * (torch.sum(batch.batch_sim_gene_sims, dim=-1) > 0)).unsqueeze(-1)
            else:
                aug_gene_w = (self.hparams.hparams['aug_gene_w'] * (torch.sum(batch.batch_sim_gene_sims, dim=-1) > 0)).unsqueeze(-1)
            cand_gene_embeddings = torch.mul(1 - aug_gene_w, cand_gene_embeddings) + torch.mul(aug_gene_w, agg_sim_gene_embedding)

        # Patient Embedder, with or without disease information
        if self.hparams.hparams['use_diseases']:
            disease_mask = (batch.batch_disease_nid != 0)
            batch_sz, max_n_dx = batch.batch_disease_nid.shape
            disease_embeddings = torch.index_select(pad_outputs, 0, batch.batch_disease_nid.view(-1)).view(batch_sz, max_n_dx, -1)
            t2 = time.time()
            phenotype_embedding, candidate_gene_embeddings, disease_embeddings, gene_mask, phenotype_mask, disease_mask, attn_weights = self.patient_model.forward(phenotype_embeddings, cand_gene_embeddings, disease_embeddings, phenotype_mask, gene_mask, disease_mask)
            t3 = time.time()
        else:
            t2 = time.time()
            phenotype_embedding, candidate_gene_embeddings, disease_embeddings, gene_mask, phenotype_mask, disease_mask, attn_weights = self.patient_model.forward(phenotype_embeddings, cand_gene_embeddings, phenotype_mask=phenotype_mask, gene_mask=gene_mask)
            t3 = time.time()

        if self.hparams.hparams['time']:
            print(f'It takes {t1-t0:0.4f}s for the node model, {t2-t1:0.4f}s for indexing into the output, and {t3-t2:0.4f}s for the patient model forward.')

        return outputs, gat_attn, phenotype_embedding, candidate_gene_embeddings, disease_embeddings, gene_mask, phenotype_mask, disease_mask, attn_weights

    def _step(self, batch, step_type):
        t0 = time.time()
        if step_type != 'test':
            batch = get_edges(batch, self.all_data, step_type)
        t1 = time.time()

        # Forward pass
        node_embeddings, gat_attn, phenotype_embedding, candidate_gene_embeddings, disease_embeddings, gene_mask, phenotype_mask, disease_mask, attn_weights = self.forward(batch, step_type)
        t2 = time.time()

        # Calculate similarities between patient phenotypes & candidate genes/diseases
        alpha = self.hparams.hparams['alpha']
        use_candidate_list = step_type != 'train'
        cand_gene_to_phenotypes_spl = batch.batch_cand_gene_to_phenotypes_spl if use_candidate_list else batch.batch_concat_cand_gene_to_phenotypes_spl
        disease_nid = batch.batch_disease_nid if self.hparams.hparams['use_diseases'] else None
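        # At val/test time each patient is scored against their own candidate
        # list; at train time the batch-concatenated candidate tensor is used,
        # pooling candidates across the batch.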

        # calculate similarity between phenotypes & genes for all genes in the manual candidate list
        phen_gene_sims, raw_phen_gene_sims, phen_gene_mask, phen_gene_one_hot_labels = self.patient_model._calc_similarity(phenotype_embedding, candidate_gene_embeddings, None, batch.batch_cand_gene_nid, batch.batch_corr_gene_nid, disease_nid, batch.one_hot_labels, gene_mask, phenotype_mask, disease_mask, True, batch.batch_cand_gene_to_phenotypes_spl, alpha)

        # calculate similarity for the loss function
        if self.hparams.hparams['loss'] == 'gene_multisimilarity' and use_candidate_list: # in this case, the similarities are the same
            sims = phen_gene_sims
            mask = phen_gene_mask
            one_hot_labels = phen_gene_one_hot_labels
        else:
            if self.hparams.hparams['loss'] == 'disease_multisimilarity': candidate_gene_embeddings = None
            elif self.hparams.hparams['loss'] == 'gene_multisimilarity': disease_embeddings = None
            sims, raw_sims, mask, one_hot_labels = self.patient_model._calc_similarity(phenotype_embedding, candidate_gene_embeddings, disease_embeddings, batch.batch_cand_gene_nid, batch.batch_corr_gene_nid, disease_nid, batch.one_hot_labels, gene_mask, phenotype_mask, disease_mask, use_candidate_list, cand_gene_to_phenotypes_spl, alpha)

        ## Rank genes
        correct_gene_ranks, phen_gene_sims = self.patient_model._rank_genes(phen_gene_sims, phen_gene_mask, phen_gene_one_hot_labels)
        t3 = time.time()
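        # correct_gene_ranks holds, per patient, the rank of the true causal
        # gene among its candidates; epoch-end MRR/top-k metrics use these ranks.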

        ## Calculate patient embedding loss
        loss = self.patient_model.calc_loss(sims, mask, one_hot_labels)
        t4 = time.time()

        ## Calculate node embedding loss
        if step_type == 'test':
            node_embedder_loss = 0
            roc_score, ap_score, acc, f1 = 0, 0, 0, 0
        else:
            # Get link predictions
            batch, raw_pred, pred = self.node_model.get_predictions(batch, node_embeddings)
            link_labels = self.node_model.get_link_labels(batch.all_edge_types)
            node_embedder_loss = self.node_model.calc_loss(pred, link_labels)

            # Calculate metrics
            metric_pred = torch.sigmoid(raw_pred)
            roc_score, ap_score, acc, f1 = calc_metrics(metric_pred.cpu().detach().numpy(), link_labels.cpu().detach().numpy())
        t5 = time.time()

        ## calc time
        if self.hparams.hparams['time']:
            print(f'It takes {t1-t0:0.4f}s to get edges, {t2-t1:0.4f}s for the forward pass, {t3-t2:0.4f}s to rank genes, {t4-t3:0.4f}s to calc patient loss, and {t5-t4:0.4f}s to calc the node loss.')

        ## Plot parameter histograms (NOTE: logs state_dict values, not true gradients)
        if self.hparams.hparams['plot_gradients']:
            for k, v in self.patient_model.state_dict().items():
                self.logger.experiment.log({f'gradients/{step_type}.gradients.{k}': wandb.Histogram(v.detach().cpu())})

        return node_embedder_loss, loss, correct_gene_ranks, roc_score, ap_score, acc, f1, node_embeddings, gat_attn, phenotype_embedding, candidate_gene_embeddings, attn_weights, phen_gene_sims, raw_phen_gene_sims, gene_mask, phenotype_mask

    def training_step(self, batch, batch_idx):
        print('training step')
        node_embedder_loss, patient_loss, correct_gene_ranks, roc_score, ap_score, acc, f1, node_embeddings, gat_attn, phenotype_embedding, candidate_gene_embeddings, attn_weights, phen_gene_sims, raw_phen_gene_sims, gene_mask, phenotype_mask = self._step(batch, 'train')
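
        # Total loss is a convex combination: lambda weighs the node embedder's
        # link-prediction loss against the patient-level gene-ranking loss.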
        loss = (self.hparams.hparams['lambda'] * node_embedder_loss) + ((1 - self.hparams.hparams['lambda']) * patient_loss)
        self.log('train_loss/patient.train_overall_loss', loss, prog_bar=True, on_epoch=True)
        self.log('train_loss/patient.train_patient_loss', patient_loss, prog_bar=True, on_epoch=True)
        self.log('train_loss/patient.train_node_embedder_loss', node_embedder_loss, prog_bar=True, on_epoch=True)

        return {'loss': loss,
                'train/train_correct_gene_ranks': correct_gene_ranks,
                "train/node.train_roc": roc_score,
                "train/node.train_ap": ap_score,
                "train/node.train_acc": acc,
                "train/node.train_f1": f1,
                'train/one_hot_labels': batch.one_hot_labels.detach().cpu(),
                'train/attention_weights': attn_weights.detach().cpu() if attn_weights is not None else None,
                'train/phen_gene_sims': phen_gene_sims.detach().cpu(),
                'train/phenotype_names_degrees': batch.phenotype_names,
                }

    def validation_step(self, batch, batch_idx):
        node_embedder_loss, patient_loss, correct_gene_ranks, roc_score, ap_score, acc, f1, node_embeddings, gat_attn, phenotype_embedding, candidate_gene_embeddings, attn_weights, phen_gene_sims, raw_phen_gene_sims, gene_mask, phenotype_mask = self._step(batch, 'val')
        loss = (self.hparams.hparams['lambda'] * node_embedder_loss) + ((1 - self.hparams.hparams['lambda']) * patient_loss)
        self.log('val_loss/patient.val_overall_loss', loss, prog_bar=True, on_epoch=True)
        self.log('val_loss/patient.val_patient_loss', patient_loss, prog_bar=True)
        self.log('val_loss/patient.val_node_embedder_loss', node_embedder_loss, prog_bar=True)

        return {'loss/val_loss': loss,
                'val/val_correct_gene_ranks': correct_gene_ranks,
                "val/node.val_roc": roc_score,
                "val/node.val_ap": ap_score,
                "val/node.val_acc": acc,
                "val/node.val_f1": f1,
                'val/one_hot_labels': batch.one_hot_labels.detach().cpu(),
                'val/attention_weights': attn_weights.detach().cpu() if attn_weights is not None else None,
                'val/phen_gene_sims': phen_gene_sims.detach().cpu(),
                'val/phenotype_names_degrees': batch.phenotype_names,
                }

    def write_results_to_file(self, batch_info, phen_gene_sims, gene_mask, phenotype_mask, attn_weights, correct_gene_ranks, gat_attn, node_embeddings, phenotype_embeddings, save=True, loop_type='predict'):
        # NOTE: only saves a single batch - to run at inference time, make sure the batch size includes all patients

        # Save GAT attention weights
        # NOTE: assumes 3 layers to the model
        attn_dfs = []
        for edge_index, attn in gat_attn:
            edge_index = edge_index.cpu()
            attn = attn.cpu()
            gat_attn_df = pd.DataFrame({'source': edge_index[0,:], 'target': edge_index[1,:]})
            for head in range(attn.shape[1]):
                gat_attn_df[f'attn_{head}'] = attn[:,head]
            attn_dfs.append(gat_attn_df)

        # Save scores
        all_sims, all_genes, all_patient_ids = [], [], []
        for patient_id, sims, genes, g_mask in zip(batch_info["patient_ids"], phen_gene_sims, batch_info["cand_gene_names"], gene_mask):
            nonpadded_sims = sims[g_mask].tolist()
            all_sims.extend(nonpadded_sims)
            all_genes.extend(genes)
            all_patient_ids.extend([patient_id] * len(genes))
        results_df = pd.DataFrame({'patient_id': all_patient_ids, 'genes': all_genes, 'similarities': all_sims})

        # Save phenotype information
        if attn_weights is None:
            phen_df = None
        else:
            all_patient_ids, all_phens, all_attn_weights, all_degrees = [], [], [], []
            for patient_id, attn_w, phen_names, p_mask in zip(batch_info["patient_ids"], attn_weights, batch_info["phenotype_names"], phenotype_mask):
                p_names, degrees = zip(*phen_names)
                all_patient_ids.extend([patient_id] * len(phen_names))
                all_degrees.extend(degrees)
                all_phens.extend(p_names)
                all_attn_weights.extend(attn_w[p_mask].tolist())
            phen_df = pd.DataFrame({'patient_id': all_patient_ids, 'phenotypes': all_phens, 'degrees': all_degrees, 'attention': all_attn_weights})
            print(phen_df.head())

        return results_df, phen_df, attn_dfs, phenotype_embeddings.cpu(), None

    def test_step(self, batch, batch_idx):
        node_embedder_loss, patient_loss, correct_gene_ranks, roc_score, ap_score, acc, f1, node_embeddings, gat_attn, phenotype_embedding, candidate_gene_embeddings, attn_weights, phen_gene_sims, raw_phen_gene_sims, gene_mask, phenotype_mask = self._step(batch, 'test')

        return {'test/test_correct_gene_ranks': correct_gene_ranks,
                'test/node.embed': node_embeddings.detach().cpu(),
                'test/patient.phenotype_embed': phenotype_embedding.detach().cpu(),
                'test/one_hot_labels': batch.one_hot_labels.detach().cpu(),
                'test/attention_weights': attn_weights.detach().cpu() if attn_weights is not None else None,
                'test/phen_gene_sims': phen_gene_sims.detach().cpu(),
                'test/phenotype_names_degrees': batch.phenotype_names, # type = list
                'test/gene_mask': gene_mask.detach().cpu(),
                'test/phenotype_mask': phenotype_mask.detach().cpu(),
                "test/patient_ids": batch.patient_ids, # type = list
                "test/cand_gene_names": batch.cand_gene_names, # type = list
                'test/gat_attn': gat_attn, # type = list
                "test/n_id": batch.n_id[:batch.batch_size].detach().cpu(),
                }

    def inference(self, batch, batch_idx):
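        # Mirrors forward(), but embeds the full graph via node_model.predict
        # on all_data rather than the sampled batch subgraph.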
        t0 = time.time()
        outputs, gat_attn = self.node_model.predict(self.all_data)
        pad_outputs = torch.cat([torch.zeros(1, outputs.size(1), device=outputs.device), outputs])
        t1 = time.time()

        # get masks
        phenotype_mask = (batch.batch_pheno_nid != 0)
        gene_mask = (batch.batch_cand_gene_nid != 0)

        # index into outputs using phenotype & gene batch node idx
        batch_sz, max_n_phen = batch.batch_pheno_nid.shape
        phenotype_embeddings = torch.index_select(pad_outputs, 0, batch.batch_pheno_nid.view(-1)).view(batch_sz, max_n_phen, -1)
        batch_sz, max_n_cand_genes = batch.batch_cand_gene_nid.shape
        cand_gene_embeddings = torch.index_select(pad_outputs, 0, batch.batch_cand_gene_nid.view(-1)).view(batch_sz, max_n_cand_genes, -1)

        if self.hparams.hparams['augment_genes']:
            print("Augmenting genes at inference...", self.hparams.hparams['aug_gene_w'])
            _, max_n_sim_cand_genes, k_sim_genes = batch.batch_sim_gene_nid.shape
            sim_gene_embeddings = torch.index_select(pad_outputs, 0, batch.batch_sim_gene_nid.view(-1)).view(batch_sz, max_n_sim_cand_genes, k_sim_genes, -1)
            agg_sim_gene_embedding = weighted_sum(sim_gene_embeddings, batch.batch_sim_gene_sims)
            aug_gene_w = (self.hparams.hparams['aug_gene_w'] * (torch.sum(batch.batch_sim_gene_sims, dim=-1) > 0)).unsqueeze(-1)
            cand_gene_embeddings = torch.mul(1 - aug_gene_w, cand_gene_embeddings) + torch.mul(aug_gene_w, agg_sim_gene_embedding)

        # Patient Embedder, with or without disease information
        if self.hparams.hparams['use_diseases']:
            disease_mask = (batch.batch_disease_nid != 0)
            batch_sz, max_n_dx = batch.batch_disease_nid.shape
            disease_embeddings = torch.index_select(pad_outputs, 0, batch.batch_disease_nid.view(-1)).view(batch_sz, max_n_dx, -1)
            t2 = time.time()
            phenotype_embedding, candidate_gene_embeddings, disease_embeddings, gene_mask, phenotype_mask, disease_mask, attn_weights = self.patient_model.forward(phenotype_embeddings, cand_gene_embeddings, disease_embeddings, phenotype_mask, gene_mask, disease_mask)
            t3 = time.time()
        else:
            t2 = time.time()
            phenotype_embedding, candidate_gene_embeddings, disease_embeddings, gene_mask, phenotype_mask, disease_mask, attn_weights = self.patient_model.forward(phenotype_embeddings, cand_gene_embeddings, phenotype_mask=phenotype_mask, gene_mask=gene_mask)
            t3 = time.time()

        if self.hparams.hparams['time']:
            print(f'It takes {t1-t0:0.4f}s for the node model, {t2-t1:0.4f}s for indexing into the output, and {t3-t2:0.4f}s for the patient model forward.')

        return outputs, gat_attn, phenotype_embedding, candidate_gene_embeddings, disease_embeddings, gene_mask, phenotype_mask, disease_mask, attn_weights

    def predict_step(self, batch, batch_idx):
        node_embeddings, gat_attn, phenotype_embedding, candidate_gene_embeddings, disease_embeddings, gene_mask, phenotype_mask, disease_mask, attn_weights = self.inference(batch, batch_idx)

        # Calculate similarities between patient phenotypes & candidate genes/diseases
        alpha = self.hparams.hparams['alpha']
        use_candidate_list = True
        disease_nid = batch.batch_disease_nid if self.hparams.hparams['use_diseases'] else None

        # calculate similarity between phenotypes & genes for all genes in the manual candidate list
        phen_gene_sims, raw_phen_gene_sims, phen_gene_mask, phen_gene_one_hot_labels = self.patient_model._calc_similarity(phenotype_embedding, candidate_gene_embeddings, None, batch.batch_cand_gene_nid, batch.batch_corr_gene_nid, disease_nid, batch.one_hot_labels, gene_mask, phenotype_mask, disease_mask, True, batch.batch_cand_gene_to_phenotypes_spl, alpha)

        # Rank genes
        correct_gene_ranks, phen_gene_sims = self.patient_model._rank_genes(phen_gene_sims, phen_gene_mask, phen_gene_one_hot_labels)

        results_df, phen_df, attn_dfs, phenotype_embeddings, disease_embeddings = self.write_results_to_file(batch, phen_gene_sims, gene_mask, phenotype_mask, attn_weights, correct_gene_ranks, gat_attn, node_embeddings, phenotype_embedding, save=True)
        return results_df, phen_df, *attn_dfs, phenotype_embeddings, disease_embeddings

    def _epoch_end(self, outputs, loop_type):

        correct_gene_ranks = torch.cat([x[f'{loop_type}/{loop_type}_correct_gene_ranks'] for x in outputs], dim=0)

        if loop_type == "test":

            batch_info = {"n_id": torch.cat([x[f'{loop_type}/n_id'] for x in outputs], dim=0),
                          "patient_ids": [pat for x in outputs for pat in x[f'{loop_type}/patient_ids']],
                          "phenotype_names": [pat for x in outputs for pat in x[f'{loop_type}/phenotype_names_degrees']],
                          "cand_gene_names": [pat for x in outputs for pat in x[f'{loop_type}/cand_gene_names']],
                          "one_hot_labels": [pat for x in outputs for pat in x[f'{loop_type}/one_hot_labels']],
                          }
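
            # flatten per-batch step outputs into per-patient lists for saving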
            phen_gene_sims = [pat for x in outputs for pat in x[f'{loop_type}/phen_gene_sims']]
            gene_mask = [pat for x in outputs for pat in x[f'{loop_type}/gene_mask']]
            phenotype_mask = [pat for x in outputs for pat in x[f'{loop_type}/phenotype_mask']]
            attn_weights = [pat for x in outputs for pat in x[f'{loop_type}/attention_weights']]
            gat_attn = [pat for x in outputs for pat in x[f'{loop_type}/gat_attn']]
            node_embeddings = torch.cat([x[f'{loop_type}/node.embed'] for x in outputs], dim=0)
            phenotype_embedding = torch.cat([x[f'{loop_type}/patient.phenotype_embed'] for x in outputs], dim=0)

            results_df, phen_df, attn_dfs, phenotype_embeddings, disease_embeddings = self.write_results_to_file(batch_info, phen_gene_sims, gene_mask, phenotype_mask, attn_weights, correct_gene_ranks, gat_attn, node_embeddings, phenotype_embedding, loop_type=loop_type)

            print("Writing results for test...")
            # NOTE: hard-coded output prefix; adjust for your environment
            output_base = "/home/ml499/public_repos/SHEPHERD/shepherd/results/gp"
            results_df.to_csv(output_base + '_scores.csv', index=False)
            print(results_df)

            phen_df.to_csv(output_base + '_phenotype_attention.csv', sep=',', index=False)
            print(phen_df)

        # Plot embeddings
        if loop_type != "train" and len(self.train_patient_nodes) > 0 and self.hparams.hparams['plot_intrain']:

            correct_gene_nid = torch.cat([x[f'{loop_type}/corr_gene_nid_orig'] for x in outputs], dim=0)
            assert correct_gene_ranks.shape[0] == correct_gene_nid.shape[0]

            # Rank of gene vs. number of train patients with the causal gene
            gene_rank_corr_gene_fig, gene_rank_corr_gene_counts = plot_gene_rank_vs_numtrain(correct_gene_ranks, correct_gene_nid, self.train_corr_gene_nid)
            gene_rank_cand_gene_fig, gene_rank_cand_gene_counts = plot_gene_rank_vs_numtrain(correct_gene_ranks, correct_gene_nid, self.train_patient_nodes)
            gene_rank_sparse_fig, gene_rank_sparse_counts = plot_gene_rank_vs_numtrain(correct_gene_ranks, correct_gene_nid, self.train_sparse_nodes)
            gene_rank_target_fig, gene_rank_target_counts = plot_gene_rank_vs_numtrain(correct_gene_ranks, correct_gene_nid, self.train_target_batch)
            self.logger.experiment.log({f'{loop_type}/gene_rank_vs_num_train_corr_genes': gene_rank_corr_gene_fig})
            self.logger.experiment.log({f'{loop_type}/gene_rank_vs_num_train_cand_genes': gene_rank_cand_gene_fig})
            self.logger.experiment.log({f'{loop_type}/gene_rank_vs_num_train_sparse': gene_rank_sparse_fig})
            self.logger.experiment.log({f'{loop_type}/gene_rank_vs_num_train_target': gene_rank_target_fig})

            gene_nid_trainset = torch.stack([torch.tensor(gene_rank_corr_gene_counts),
                                             torch.tensor(gene_rank_cand_gene_counts),
                                             torch.tensor(gene_rank_sparse_counts),
                                             torch.tensor(gene_rank_target_counts)], dim=1)
            self.logger.experiment.log({f'{loop_type}/gene_rank_vs_trainset': plot_gene_rank_vs_trainset(correct_gene_ranks, correct_gene_nid, gene_nid_trainset)})

        if self.hparams.hparams['plot_PG_embed']:
            # NOTE: patient_emb and patient_label are not collected in this loop;
            # enabling this flag requires gathering them from the step outputs first
            self.logger.experiment.log({f'{loop_type}/patient_embed': fit_umap(patient_emb, patient_label)})

        # plot % overlap with train patients
        if loop_type != 'train' and self.hparams.hparams['mrr_vs_percent_overlap']:
            max_percent_overlap_train = torch.cat([torch.tensor(x[f'{loop_type}/max_percent_phen_overlap_train']) for x in outputs], dim=0)
            self.logger.experiment.log({f'{loop_type}/mrr_vs_percent_overlap': mrr_vs_percent_overlap(correct_gene_ranks.detach().cpu(), max_percent_overlap_train.detach().cpu())})

        if self.hparams.hparams['plot_frac_rank']:

            # Rank of gene vs. fraction of phenotypes with a direct edge to the disease
            frac_p_with_direct_edge_to_dx = [pat[0][0] for x in outputs for pat in x[f'{loop_type}/frac_p_with_direct_edge_to_dx']] # NOTE: currently only selects the first disease
            self.logger.experiment.log({f'{loop_type}/gene_rank_vs_frac_p_with_direct_edge_to_dx': plot_gene_rank_vs_fraction_phenotype(correct_gene_ranks.cpu(), frac_p_with_direct_edge_to_dx)})

            # Rank of gene vs. fraction of phenotypes with a direct edge to the gene
            frac_p_with_direct_edge_to_g = [pat[0][0] for x in outputs for pat in x[f'{loop_type}/frac_p_with_direct_edge_to_g']] # NOTE: currently only selects the first gene
            self.logger.experiment.log({f'{loop_type}/frac_p_with_direct_edge_to_g': plot_gene_rank_vs_fraction_phenotype(correct_gene_ranks.cpu(), frac_p_with_direct_edge_to_g)})

        if self.hparams.hparams['plot_nhops_rank']:

            # Rank of gene vs. hops from disease
            nhops_g_d = [pat[0] for x in outputs for pat in x[f'{loop_type}/n_hops_g_d']] # NOTE: currently only selects the first disease
            fig_mean, fig_min = plot_gene_rank_vs_hops(correct_gene_ranks.cpu(), nhops_g_d)
            self.logger.experiment.log({f'{loop_type}/gene_rank_vs_mean_n_hops_g_d': fig_mean})
            self.logger.experiment.log({f'{loop_type}/gene_rank_vs_min_n_hops_g_d': fig_min})

            # Rank of gene vs. mean/min hops from phenotypes
            nhops_g_p = [pat[0] for x in outputs for pat in x[f'{loop_type}/n_hops_g_p']] # NOTE: currently only selects the first gene
            fig_mean, fig_min = plot_gene_rank_vs_hops(correct_gene_ranks.cpu(), nhops_g_p)
            self.logger.experiment.log({f'{loop_type}/gene_rank_vs_mean_n_hops_g_p': fig_mean})
            self.logger.experiment.log({f'{loop_type}/gene_rank_vs_min_n_hops_g_p': fig_min})

            # Rank of gene vs. distance between phenotypes
            nhops_p_p = [torch.tensor(pat) for x in outputs for pat in x[f'{loop_type}/n_hops_p_p']]
            fig_mean, fig_min = plot_gene_rank_vs_hops(correct_gene_ranks.cpu(), nhops_p_p)
            self.logger.experiment.log({f'{loop_type}/gene_rank_vs_mean_n_hops_p_p': fig_mean})
            self.logger.experiment.log({f'{loop_type}/gene_rank_vs_min_n_hops_p_p': fig_min})

        if self.hparams.hparams['plot_attn_nhops']:

            # plot phenotype attention vs. n_hops to gene and degree
            # NOTE: relies on nhops_g_p computed in the 'plot_nhops_rank' block above
            attn_weights = [torch.split(x[f'{loop_type}/attention_weights'], 1) for x in outputs]
            attn_weights = [w[w > 0] for batch_w in attn_weights for w in batch_w]
            phenotype_names = [pat for x in outputs for pat in x[f'{loop_type}/phenotype_names_degrees']]
            attn_weights_cpu_reshaped = torch.cat(attn_weights, dim=0)
            self.logger.experiment.log({f"{loop_type}_attn/attention weights": wandb.Histogram(attn_weights_cpu_reshaped[attn_weights_cpu_reshaped != 0])})
            self.logger.experiment.log({f"{loop_type}_attn/single patient attention weights": wandb.Histogram(attn_weights[0])})
            self.logger.experiment.log({f"{loop_type}_attn/n hops to gene vs attention weights": plot_nhops_to_gene_vs_attention(attn_weights, phenotype_names, nhops_g_p)})
            self.logger.experiment.log({f"{loop_type}_attn/single patient n hops to gene vs attention weights": plot_nhops_to_gene_vs_attention(attn_weights, phenotype_names, nhops_g_p, single_patient=True)})
            self.logger.experiment.log({f"{loop_type}_attn/degree vs attention weights": plot_degree_vs_attention(attn_weights, phenotype_names)})
            self.logger.experiment.log({f"{loop_type}_attn/single patient degree vs attention weights": plot_degree_vs_attention(attn_weights, phenotype_names, single_patient=True)})
            data = [[p_name[0], w.item(), p_name[1], n_hops_to_g] for w, p_name, n_hops_to_g in zip(attn_weights[0], phenotype_names[0], nhops_g_p[0])]
            self.logger.experiment.log({f"{loop_type}_attn/phenotypes": wandb.Table(data=data, columns=["HPO Code", "Attention Weight", "Degree", "Num Hops to Gene"])})

        if self.hparams.hparams['plot_phen_gene_sims']:

            all_phen_gene_sims, all_pg_spl, all_correct_sims, all_incorrect_sims = [], [], [], []
            for x in outputs:
                phen_gene_sims = x[f'{loop_type}/phen_gene_sims']
                one_hot_labels = x[f'{loop_type}/one_hot_labels']
                all_correct_sims.append(phen_gene_sims[one_hot_labels == 1])
                all_incorrect_sims.append(phen_gene_sims[one_hot_labels != 1])
                all_phen_gene_sims.append(phen_gene_sims.view(-1))

            phen_gene_sims_reshaped = torch.cat(all_phen_gene_sims)
            correct_phen_gene_sims = torch.cat(all_correct_sims)
            incorrect_phen_gene_sims = torch.cat(all_incorrect_sims)

            if len(all_pg_spl) > 0: pg_spl_reshaped = torch.cat(all_pg_spl)
            else: pg_spl_reshaped = []

            # -100000 appears to be the fill value for masked (padding) similarities
            self.logger.experiment.log({f"{loop_type}_pg_similarities/phenotype-gene similarities": wandb.Histogram(phen_gene_sims_reshaped[phen_gene_sims_reshaped != -100000])})
            self.logger.experiment.log({f"{loop_type}_pg_similarities/phenotype-correct gene similarities": wandb.Histogram(correct_phen_gene_sims[correct_phen_gene_sims != -100000])})
            self.logger.experiment.log({f"{loop_type}_pg_similarities/phenotype-incorrect gene similarities": wandb.Histogram(incorrect_phen_gene_sims[incorrect_phen_gene_sims != -100000])})

            if len(pg_spl_reshaped) > 0: self.logger.experiment.log({f"{loop_type}_pg_similarities/pg spl": wandb.Histogram(pg_spl_reshaped[pg_spl_reshaped != 0])})

            phen_gene_sims_patient = outputs[0][f'{loop_type}/phen_gene_sims'][0,:]
            one_hot_labels_patient = outputs[0][f'{loop_type}/one_hot_labels'][0,:]
            correct_phen_gene_sims_patient = phen_gene_sims_patient[one_hot_labels_patient == 1]

            assert len(correct_phen_gene_sims_patient) == 1
            incorrect_phen_gene_sims_patient = phen_gene_sims_patient[one_hot_labels_patient != 1]

            self.logger.experiment.log({f"{loop_type}_pg_similarities/single patient phenotype-gene similarities": wandb.Histogram(phen_gene_sims_patient[phen_gene_sims_patient != -100000])})
            self.logger.experiment.log({f"{loop_type}_pg_similarities/single patient phenotype-correct gene similarities": wandb.Histogram(correct_phen_gene_sims_patient[correct_phen_gene_sims_patient != -100000])})
            self.logger.experiment.log({f"{loop_type}_pg_similarities/single patient phenotype-incorrect gene similarities": wandb.Histogram(incorrect_phen_gene_sims_patient[incorrect_phen_gene_sims_patient != -100000])})

            if len(pg_spl_reshaped) > 0: self.logger.experiment.log({f"{loop_type}_pg_similarities/single patient pg spl": wandb.Histogram(pg_spl_reshaped[pg_spl_reshaped != 0])})

        # top k accuracy
        top_1_acc = top_k_acc(correct_gene_ranks, k=1)
        top_3_acc = top_k_acc(correct_gene_ranks, k=3)
        top_5_acc = top_k_acc(correct_gene_ranks, k=5)
        top_10_acc = top_k_acc(correct_gene_ranks, k=10)
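        # (top-k accuracy = fraction of patients whose causal gene ranks in the
        #  top k; MRR below = mean over patients of 1 / rank of the causal gene)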

        # mean reciprocal rank
        mrr = mean_reciprocal_rank(correct_gene_ranks)
        avg_rank = average_rank(correct_gene_ranks)

        self.log(f'{loop_type}/gp_{loop_type}_epoch_top1_acc', top_1_acc, prog_bar=False)
        self.log(f'{loop_type}/gp_{loop_type}_epoch_top3_acc', top_3_acc, prog_bar=False)
        self.log(f'{loop_type}/gp_{loop_type}_epoch_top5_acc', top_5_acc, prog_bar=False)
        self.log(f'{loop_type}/gp_{loop_type}_epoch_top10_acc', top_10_acc, prog_bar=False)
        self.log(f'{loop_type}/gp_{loop_type}_epoch_mrr', mrr, prog_bar=False)
        self.log(f'{loop_type}/gp_{loop_type}_epoch_avg_rank', avg_rank, prog_bar=False)

        if loop_type == 'val':
            self.log('curr_epoch', self.current_epoch, prog_bar=False)

    def training_epoch_end(self, outputs):

        if self.hparams.hparams['plot_intrain']:
            # accumulate per-node counts across epochs (dict keys are int node ids)
            all_train_nodes, counts = torch.unique(torch.cat([x['train/n_id'] for x in outputs], dim=0), return_counts=True)
            curr_all_train_nodes = {n.item(): c.item() + self.all_train_nodes.get(n.item(), 0) for n, c in zip(all_train_nodes, counts)}
            self.all_train_nodes.update(curr_all_train_nodes)

            train_sparse_nodes, counts = torch.unique(torch.cat([x['train/sparse_idx'] for x in outputs], dim=0), return_counts=True)
            curr_train_sparse_nodes = {n.item(): c.item() + self.train_sparse_nodes.get(n.item(), 0) for n, c in zip(train_sparse_nodes, counts)}
            self.train_sparse_nodes.update(curr_train_sparse_nodes)

            train_target_batch, counts = torch.unique(torch.cat([x['train/target_batch'] for x in outputs], dim=0), return_counts=True)
            curr_train_target_batch = {n.item(): c.item() + self.train_target_batch.get(n.item(), 0) for n, c in zip(train_target_batch, counts)}
            self.train_target_batch.update(curr_train_target_batch)

            train_patient_nodes, counts = torch.unique(torch.cat([x['train/cand_gene_nid_orig'] for x in outputs], dim=0), return_counts=True)
            self.train_patient_nodes = {n.item(): c.item() for n, c in zip(train_patient_nodes, counts)}

            train_corr_gene_nids, counts = torch.unique(torch.cat([x['train/corr_gene_nid_orig'] for x in outputs], dim=0), return_counts=True)
            self.train_corr_gene_nid = {g.item(): c.item() for g, c in zip(train_corr_gene_nids, counts)}

        self._epoch_end(outputs, 'train')

    def validation_epoch_end(self, outputs):
        self._epoch_end(outputs, 'val')

    def test_epoch_end(self, outputs):
        self._epoch_end(outputs, 'test')

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.hparams['lr'])
        return optimizer
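

# Usage sketch (hypothetical; the checkpoint path and dataloader names below are
# illustrative, not part of this file): the module trains like any LightningModule.
#
#   model = CombinedGPAligner(edge_attr_dict, all_data, n_nodes=n_nodes,
#                             node_ckpt='checkpoints/node_embedder.ckpt',
#                             hparams=hparams)
#   trainer = pl.Trainer(max_epochs=hparams['max_epochs'],
#                        logger=WandbLogger(project='shepherd'))
#   trainer.fit(model, train_dataloader, val_dataloader)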