# shepherd/task_heads/patient_nca.py
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger

import wandb

from torch import nn
import torch
import torch.nn.functional as F
from torch.nn import TransformerEncoderLayer

import numpy as np
from scipy.stats import rankdata

from allennlp.modules.attention import CosineAttention, BilinearAttention, AdditiveAttention, DotProductAttention

from utils.loss_utils import NCALoss
from utils.train_utils import mean_reciprocal_rank, top_k_acc, masked_mean, masked_softmax, weighted_sum

class PatientNCA(pl.LightningModule):

    def __init__(self, hparams, embed_dim):
        super().__init__()
        self.hyperparameters = hparams

        # attention for collapsing set of phenotype embeddings
        self.attn_vector = nn.Parameter(torch.zeros((1, embed_dim), dtype=torch.float), requires_grad=True)
        nn.init.xavier_uniform_(self.attn_vector)

        if self.hyperparameters['attention_type'] == 'bilinear':
            self.attention = BilinearAttention(embed_dim, embed_dim)
        elif self.hyperparameters['attention_type'] == 'additive':
            self.attention = AdditiveAttention(embed_dim, embed_dim)
        elif self.hyperparameters['attention_type'] == 'dotpdt':
            self.attention = DotProductAttention()
        else:
            raise ValueError(f"Unsupported attention_type: {self.hyperparameters['attention_type']}")

        # projection layers
        self.phen_project = nn.Linear(embed_dim, embed_dim)
        self.phen_project2 = nn.Linear(embed_dim, embed_dim)
        if self.hyperparameters['loss'] == 'patient_disease_NCA':
            self.disease_project = nn.Linear(embed_dim, embed_dim)
            self.disease_project2 = nn.Linear(embed_dim, embed_dim)

        self.leaky_relu = nn.LeakyReLU(hparams['leaky_relu'])

        # NCA loss over the projected embeddings (patient-disease or patient-patient, depending on hparams['loss'])
        self.loss = NCALoss(softmax_scale=self.hyperparameters['softmax_scale'], only_hard_distractors=self.hyperparameters['only_hard_distractors'])

        # optional transformer encoder over the set of phenotype embeddings
        if 'n_transformer_layers' in hparams and hparams['n_transformer_layers'] > 0:
            encoder_layer = TransformerEncoderLayer(d_model=embed_dim, nhead=hparams['n_transformer_heads'])
            self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=hparams['n_transformer_layers'])

    def forward(self, phenotype_embeddings, disease_embeddings, phenotype_mask=None, disease_mask=None):
        assert phenotype_mask is not None
        if self.hyperparameters['loss'] == 'patient_disease_NCA':
            assert disease_mask is not None

        # optionally contextualize the phenotype embeddings with the transformer encoder
        # (src_key_padding_mask expects True at padding positions, hence the inverted mask)
        if 'n_transformer_layers' in self.hyperparameters and self.hyperparameters['n_transformer_layers'] > 0:
            phenotype_embeddings = self.transformer_encoder(phenotype_embeddings.transpose(0, 1), src_key_padding_mask=~phenotype_mask).transpose(0, 1)

        # attention weighted average of phenotype embeddings
        batched_attn = self.attn_vector.repeat(phenotype_embeddings.shape[0], 1)
        attn_weights = self.attention(batched_attn, phenotype_embeddings, phenotype_mask)
        phenotype_embedding = weighted_sum(phenotype_embeddings, attn_weights)

        # project embeddings
        phenotype_embedding = self.phen_project2(self.leaky_relu(self.phen_project(phenotype_embedding)))
        if self.hyperparameters['loss'] == 'patient_disease_NCA':
            disease_embeddings = self.disease_project2(self.leaky_relu(self.disease_project(disease_embeddings)))
        else:
            disease_embeddings = None

        return phenotype_embedding, disease_embeddings, phenotype_mask, disease_mask, attn_weights

    def calc_loss(self, batch, phenotype_embedding, disease_embeddings, disease_mask, labels, use_candidate_list):
        if self.hyperparameters['loss'] == 'patient_disease_NCA':
            loss, softmax, labels, candidate_disease_idx, candidate_disease_embeddings = self.loss(phenotype_embedding, disease_embeddings, batch.batch_disease_nid, batch.batch_cand_disease_nid, disease_mask, labels, use_candidate_list=use_candidate_list)
        elif self.hyperparameters['loss'] == 'patient_patient_NCA':
            loss, softmax, labels, candidate_disease_idx, candidate_disease_embeddings = self.loss(phenotype_embedding, None, None, None, None, labels, use_candidate_list=False)
        else:
            raise NotImplementedError
        return loss, softmax, labels, candidate_disease_idx, candidate_disease_embeddings

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hyperparameters['lr'])
        return optimizer
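
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the training pipeline):
# shows how PatientNCA might be instantiated and run on dummy inputs. The
# hparams keys mirror the ones read above; the chosen values, tensor shapes
# (batch, n_phenotypes, embed_dim), and the boolean padding mask are
# assumptions for illustration.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    embed_dim = 128
    hparams = {
        'attention_type': 'bilinear',
        'loss': 'patient_patient_NCA',
        'leaky_relu': 0.1,
        'softmax_scale': 1.0,
        'only_hard_distractors': False,
        'n_transformer_layers': 0,
        'n_transformer_heads': 2,
        'lr': 1e-4,
    }
    model = PatientNCA(hparams, embed_dim)

    # two patients, up to five phenotype embeddings each; True marks real (non-padding) phenotypes
    phenotype_embeddings = torch.randn(2, 5, embed_dim)
    phenotype_mask = torch.tensor([[True, True, True, False, False],
                                   [True, True, True, True, True]])

    phenotype_embedding, disease_embeddings, _, _, attn_weights = model(
        phenotype_embeddings, None, phenotype_mask=phenotype_mask)
    print(phenotype_embedding.shape, attn_weights.shape)  # (2, embed_dim), (2, 5)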