--- /dev/null
+++ b/shepherd/task_heads/patient_nca.py
@@ -0,0 +1,87 @@
+import pytorch_lightning as pl
+
+import torch
+from torch import nn
+from torch.nn import TransformerEncoderLayer
+
+from allennlp.modules.attention import CosineAttention, BilinearAttention, AdditiveAttention, DotProductAttention
+
+from utils.loss_utils import NCALoss
+from utils.train_utils import weighted_sum
+
+
+class PatientNCA(pl.LightningModule):
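+    """Patient embedding head trained with a neighborhood components
+    analysis (NCA) style loss.
+
+    Collapses a variable-length set of phenotype embeddings into a single
+    patient embedding via attention (optionally contextualized first by a
+    Transformer encoder) and projects patient embeddings (and, in
+    patient-disease mode, disease embeddings) into a shared space.
+
+    Hypothetical hparams dict (these are the keys read in this module;
+    the values are illustrative only):
+        {'attention_type': 'bilinear', 'loss': 'patient_disease_NCA',
+         'leaky_relu': 0.1, 'softmax_scale': 1.0,
+         'only_hard_distractors': False, 'lr': 1e-4,
+         'n_transformer_layers': 0, 'n_transformer_heads': 4}
+    """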
+
+    def __init__(self, hparams, embed_dim):
+        super().__init__()
+        self.hyperparameters = hparams
+
+        # attention for collapsing set of phenotype embeddings
+        self.attn_vector = nn.Parameter(torch.zeros((1, embed_dim), dtype=torch.float), requires_grad=True)
+        nn.init.xavier_uniform_(self.attn_vector)
+        
+        if self.hyperparameters['attention_type'] == 'bilinear':
+            self.attention = BilinearAttention(embed_dim, embed_dim)
+        elif self.hyperparameters['attention_type'] == 'additive':
+            self.attention = AdditiveAttention(embed_dim, embed_dim)
+        elif self.hyperparameters['attention_type'] == 'dotpdt':
+            self.attention = DotProductAttention()
+        elif self.hyperparameters['attention_type'] == 'cosine':
+            self.attention = CosineAttention()
+        else:
+            raise ValueError(f"Unknown attention_type: {self.hyperparameters['attention_type']}")
+
+        # projection layers
+        self.phen_project = nn.Linear(embed_dim, embed_dim)
+        self.phen_project2 = nn.Linear(embed_dim, embed_dim)
+        if self.hyperparameters['loss'] == 'patient_disease_NCA':
+            self.disease_project = nn.Linear(embed_dim, embed_dim)
+            self.disease_project2 = nn.Linear(embed_dim, embed_dim)
+
+        self.leaky_relu = nn.LeakyReLU(hparams['leaky_relu'])
+
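+        # NCA-style objective: softmax over scaled similarities between the
+        # anchor embedding and its candidates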
+        self.loss = NCALoss(softmax_scale=self.hyperparameters['softmax_scale'], only_hard_distractors=self.hyperparameters['only_hard_distractors'])
+        
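+        # optional Transformer encoder to contextualize the phenotype set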
+        if 'n_transformer_layers' in hparams and hparams['n_transformer_layers'] > 0:
+            encoder_layer = TransformerEncoderLayer(d_model=embed_dim, nhead=hparams['n_transformer_heads'])
+            self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=hparams['n_transformer_layers'])
+
+
+    def forward(self, phenotype_embeddings, disease_embeddings, phenotype_mask=None, disease_mask=None): 
+        assert phenotype_mask is not None
+        if self.hyperparameters['loss'] == 'patient_disease_NCA':
+            assert disease_mask is not None
+        
+        if 'n_transformer_layers' in self.hyperparameters and self.hyperparameters['n_transformer_layers'] > 0:
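+            # nn.TransformerEncoder expects (seq_len, batch, dim) inputs and a
+            # key padding mask that is True at padded positions, hence the
+            # transposes and the inverted mask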
+            phenotype_embeddings = self.transformer_encoder(phenotype_embeddings.transpose(0, 1), src_key_padding_mask=~phenotype_mask).transpose(0, 1)
+
+        # attention weighted average of phenotype embeddings
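+        # a single learned query vector is broadcast across the batch; the
+        # attention weights are normalized over valid (unmasked) phenotypes only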
+        batched_attn = self.attn_vector.repeat(phenotype_embeddings.shape[0], 1)
+        attn_weights = self.attention(batched_attn, phenotype_embeddings, phenotype_mask)
+        phenotype_embedding = weighted_sum(phenotype_embeddings, attn_weights)
+        
+        # project embeddings
+        phenotype_embedding = self.phen_project2(self.leaky_relu(self.phen_project(phenotype_embedding)))
+        if self.hyperparameters['loss'] == 'patient_disease_NCA':
+            disease_embeddings = self.disease_project2(self.leaky_relu(self.disease_project(disease_embeddings)))
+        else:
+            disease_embeddings = None
+
+        return phenotype_embedding, disease_embeddings, phenotype_mask, disease_mask, attn_weights
+
+    def calc_loss(self, batch, phenotype_embedding, disease_embeddings, disease_mask, labels, use_candidate_list):
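+        # patient_disease_NCA ranks candidate diseases against the patient
+        # embedding; patient_patient_NCA contrasts patient embeddings within the batch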
+        if self.hyperparameters['loss'] == 'patient_disease_NCA':
+            loss, softmax, labels, candidate_disease_idx, candidate_disease_embeddings = self.loss(phenotype_embedding, disease_embeddings, batch.batch_disease_nid, batch.batch_cand_disease_nid, disease_mask, labels, use_candidate_list=use_candidate_list)
+        elif self.hyperparameters['loss'] == 'patient_patient_NCA':
+            loss, softmax, labels, candidate_disease_idx, candidate_disease_embeddings = self.loss(phenotype_embedding, None, None, None, None, labels, use_candidate_list=False)
+        else:
+            raise NotImplementedError
+        return loss, softmax, labels, candidate_disease_idx, candidate_disease_embeddings
+
+
+    def configure_optimizers(self):
+        optimizer = torch.optim.Adam(self.parameters(), lr=self.hyperparameters['lr'])
+        return optimizer