Switch to side-by-side view

--- a
+++ b/shepherd/task_heads/gp_aligner.py
@@ -0,0 +1,162 @@
+
+import pytorch_lightning as pl
+from pytorch_lightning.loggers import WandbLogger
+import wandb
+
+from torch import nn
+import torch
+import torch.nn.functional as F
+from torch.nn import TransformerEncoderLayer
+
+import numpy as np
+from scipy.stats import rankdata
+
+from allennlp.modules.attention import CosineAttention, BilinearAttention, AdditiveAttention, DotProductAttention
+
+
+from utils.loss_utils import MultisimilarityCriterion, _construct_labels, unique, _construct_disease_labels
+from utils.train_utils import masked_mean, masked_softmax, weighted_sum, plot_degree_vs_attention, mean_reciprocal_rank, top_k_acc
+
+class GPAligner(pl.LightningModule):
+
+    def __init__(self, hparams, embed_dim):
+        super().__init__()
+        self.hyperparameters = hparams
+        print('GPAligner embedding dimension: ', embed_dim)
+
+        # attention for collapsing set of phenotype embeddings
+        self.attn_vector = nn.Parameter(torch.zeros((1, embed_dim), dtype=torch.float), requires_grad=True)   
+        nn.init.xavier_uniform_(self.attn_vector)
+        
+        if self.hyperparameters['attention_type'] == 'bilinear':
+            self.attention = BilinearAttention(embed_dim, embed_dim)
+        elif self.hyperparameters['attention_type'] == 'additive':
+            self.attention = AdditiveAttention(embed_dim, embed_dim)
+        elif self.hyperparameters['attention_type'] == 'dotpdt':
+            self.attention = DotProductAttention()
+        
+        if self.hyperparameters['decoder_type'] == "dotpdt": 
+            self.decoder = DotProductAttention(normalize=False)
+        elif self.hyperparameters['decoder_type'] == "bilinear": 
+            self.decoder = BilinearAttention(embed_dim, embed_dim, activation=torch.tanh, normalize=False)
+        else:
+            raise NotImplementedError
+
+        # projection layers
+        self.phen_project = nn.Linear(embed_dim, embed_dim) 
+        self.gene_project = nn.Linear(embed_dim, embed_dim)
+        self.phen_project2 = nn.Linear(embed_dim, embed_dim)
+        self.gene_project2 = nn.Linear(embed_dim, embed_dim)
+
+        # optional disease projection layer
+        if self.hyperparameters['use_diseases']:
+            self.disease_project = nn.Linear(embed_dim, embed_dim)
+            self.disease_project2 = nn.Linear(embed_dim, embed_dim)
+
+        self.leaky_relu = nn.LeakyReLU(hparams['leaky_relu'])
+
+        self.loss = MultisimilarityCriterion(hparams['pos_weight'], hparams['neg_weight'], 
+                                hparams['margin'], hparams['thresh'], 
+                                embed_dim, hparams['only_hard_distractors']) 
+
+        if 'n_transformer_layers' in hparams and hparams['n_transformer_layers'] > 0:
+            encoder_layer = TransformerEncoderLayer(d_model=embed_dim, nhead=hparams['n_transformer_heads'])
+            self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=hparams['n_transformer_layers'])
+
+
+    def forward(self, phenotype_embeddings, candidate_gene_embeddings, disease_embeddings=None, phenotype_mask=None, gene_mask=None, disease_mask=None): 
+        assert phenotype_mask != None
+        assert gene_mask != None
+        if self.hyperparameters['use_diseases']: assert disease_mask != None
+
+        if 'n_transformer_layers' in self.hyperparameters and self.hyperparameters['n_transformer_layers'] > 0:
+            phenotype_embeddings = self.transformer_encoder(phenotype_embeddings.transpose(0, 1), src_key_padding_mask=~phenotype_mask).transpose(0, 1)
+
+        # attention weighted average of phenotype embeddings
+        batched_attn = self.attn_vector.repeat(phenotype_embeddings.shape[0],1)
+        attn_weights = self.attention(batched_attn, phenotype_embeddings, phenotype_mask)
+        phenotype_embedding = weighted_sum(phenotype_embeddings, attn_weights)
+
+        # project embeddings
+        phenotype_embedding = self.phen_project2(self.leaky_relu(self.phen_project(phenotype_embedding)))
+        candidate_gene_embeddings = self.gene_project2(self.leaky_relu(self.gene_project(candidate_gene_embeddings)))
+        
+
+        if self.hyperparameters['use_diseases']: 
+            disease_embeddings = self.disease_project2(self.leaky_relu(self.disease_project(disease_embeddings)))
+        else:
+            disease_embeddings = None
+            disease_mask = None
+
+        return phenotype_embedding, candidate_gene_embeddings, disease_embeddings, gene_mask, phenotype_mask, disease_mask, attn_weights 
+
+
+    def _calc_similarity(self, phenotype_embeddings, candidate_gene_embeddings, disease_embeddings, batch_cand_gene_nid,  batch_corr_gene_nid, batch_disease_nid, one_hot_labels, gene_mask, phenotype_mask, disease_mask, use_candidate_list, cand_gene_to_phenotypes_spl, alpha): 
+        # Normalize Embeddings (within each individual patient)
+        phenotype_embeddings = F.normalize(phenotype_embeddings, p=2, dim=1) 
+        batch_sz = phenotype_embeddings.shape[0]
+        if disease_embeddings != None: disease_embeddings = F.normalize(disease_embeddings.squeeze(), p=2, dim=1) 
+        if candidate_gene_embeddings != None:
+            batch_sz, n_cand_genes, embed_dim = candidate_gene_embeddings.shape
+            candidate_gene_embeddings = F.normalize(candidate_gene_embeddings.view(batch_sz*n_cand_genes,-1), p=2, dim=1).view(batch_sz, n_cand_genes, embed_dim)
+
+        # Only use each patient's candidate genes/diseases
+        if self.hyperparameters['only_hard_distractors'] or use_candidate_list:
+            if disease_embeddings == None: # only use genes
+                mask = gene_mask
+                one_hot_labels = one_hot_labels
+                raw_sims = self.decoder(phenotype_embeddings, candidate_gene_embeddings)
+                if cand_gene_to_phenotypes_spl != None:
+                    sims = alpha * raw_sims + (1 - alpha) * cand_gene_to_phenotypes_spl
+                else: sims = raw_sims
+            
+            elif candidate_gene_embeddings == None: # only use diseases
+                raise NotImplementedError
+            
+            else:
+                raise NotImplementedError
+        
+        # Otherwise, use entire batch as candidate genes/diseases
+        else:
+            if disease_embeddings == None: #only use genes
+                candidate_gene_idx, candidate_gene_embeddings, one_hot_labels = _construct_labels(candidate_gene_embeddings, batch_cand_gene_nid, batch_corr_gene_nid, gene_mask)
+                raw_sims = self.decoder(phenotype_embeddings, candidate_gene_embeddings.unsqueeze(0).repeat(batch_sz,1,1))
+                if cand_gene_to_phenotypes_spl != None:
+                    sims = alpha * raw_sims + (1 - alpha) * cand_gene_to_phenotypes_spl
+                else: sims = raw_sims
+                mask = None
+                
+            elif candidate_gene_embeddings == None: #only use diseases
+                candidate_embeddings, one_hot_labels = _construct_disease_labels(disease_embeddings, batch_disease_nid)
+                raw_sims = self.decoder(phenotype_embeddings, candidate_embeddings.unsqueeze(0).repeat(batch_sz,1,1))
+                if batch_disease_nid.shape[1] > 1:
+                    raw_sims = raw_sims[batch_disease_nid[:,0].squeeze() != 0] # remove rows where the patient doesn't have 
+                    one_hot_labels = one_hot_labels[batch_disease_nid[:,0].squeeze() != 0]
+                else:
+                    raw_sims = raw_sims[batch_disease_nid.squeeze() != 0] # remove rows where the patient doesn't have 
+                    one_hot_labels = one_hot_labels[batch_disease_nid.squeeze() != 0]
+                sims = raw_sims
+                mask = None
+
+            else: # use genes + diseases
+                raise NotImplementedError
+
+        return sims, raw_sims, mask, one_hot_labels
+
+
+    def _rank_genes(self, phen_gene_sims, gene_mask, one_hot_labels):
+        phen_gene_sims = phen_gene_sims * gene_mask
+        padded_phen_gene_sims = phen_gene_sims + (~gene_mask * -100000) # we want to rank the padded values last
+        gene_ranks = torch.tensor(np.apply_along_axis(lambda row: rankdata(row * -1, method='average'), axis=1, arr=padded_phen_gene_sims.detach().cpu().numpy()))
+        if one_hot_labels is None: correct_gene_ranks = None
+        else: 
+            gene_ranks = gene_ranks.to(one_hot_labels.device)
+            correct_gene_ranks = gene_ranks[one_hot_labels == 1]
+        return correct_gene_ranks, padded_phen_gene_sims
+
+    def calc_loss(self, sims, mask, one_hot_labels):
+        return self.loss(sims, mask, one_hot_labels)
+
+
+
+