[bdbb47]: shepherd/task_heads/patient_nca.py

import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
import wandb
from torch import nn
import torch
import torch.nn.functional as F
from torch.nn import TransformerEncoderLayer
import numpy as np
from scipy.stats import rankdata
from allennlp.modules.attention import CosineAttention, BilinearAttention, AdditiveAttention, DotProductAttention
from utils.loss_utils import NCALoss
from utils.train_utils import mean_reciprocal_rank, top_k_acc, masked_mean, masked_softmax, weighted_sum


class PatientNCA(pl.LightningModule):

    def __init__(self, hparams, embed_dim):
        super().__init__()
        self.hyperparameters = hparams

        # attention for collapsing set of phenotype embeddings
        self.attn_vector = nn.Parameter(torch.zeros((1, embed_dim), dtype=torch.float), requires_grad=True)
        nn.init.xavier_uniform_(self.attn_vector)
        if self.hyperparameters['attention_type'] == 'bilinear':
            self.attention = BilinearAttention(embed_dim, embed_dim)
        elif self.hyperparameters['attention_type'] == 'additive':
            self.attention = AdditiveAttention(embed_dim, embed_dim)
        elif self.hyperparameters['attention_type'] == 'dotpdt':
            self.attention = DotProductAttention()

        # projection layers
        self.phen_project = nn.Linear(embed_dim, embed_dim)
        self.phen_project2 = nn.Linear(embed_dim, embed_dim)
        if self.hyperparameters['loss'] == 'patient_disease_NCA':
            self.disease_project = nn.Linear(embed_dim, embed_dim)
            self.disease_project2 = nn.Linear(embed_dim, embed_dim)

        self.leaky_relu = nn.LeakyReLU(hparams['leaky_relu'])

        self.loss = NCALoss(softmax_scale=self.hyperparameters['softmax_scale'],
                            only_hard_distractors=self.hyperparameters['only_hard_distractors'])

        if 'n_transformer_layers' in hparams and hparams['n_transformer_layers'] > 0:
            encoder_layer = TransformerEncoderLayer(d_model=embed_dim, nhead=hparams['n_transformer_heads'])
            self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=hparams['n_transformer_layers'])

    def forward(self, phenotype_embeddings, disease_embeddings, phenotype_mask=None, disease_mask=None):
        assert phenotype_mask is not None
        if self.hyperparameters['loss'] == 'patient_disease_NCA':
            assert disease_mask is not None

        if 'n_transformer_layers' in self.hyperparameters and self.hyperparameters['n_transformer_layers'] > 0:
            phenotype_embeddings = self.transformer_encoder(phenotype_embeddings.transpose(0, 1),
                                                            src_key_padding_mask=~phenotype_mask).transpose(0, 1)

        # attention-weighted average of phenotype embeddings
        batched_attn = self.attn_vector.repeat(phenotype_embeddings.shape[0], 1)
        attn_weights = self.attention(batched_attn, phenotype_embeddings, phenotype_mask)
        phenotype_embedding = weighted_sum(phenotype_embeddings, attn_weights)

        # project embeddings
        phenotype_embedding = self.phen_project2(self.leaky_relu(self.phen_project(phenotype_embedding)))
        if self.hyperparameters['loss'] == 'patient_disease_NCA':
            disease_embeddings = self.disease_project2(self.leaky_relu(self.disease_project(disease_embeddings)))
        else:
            disease_embeddings = None

        return phenotype_embedding, disease_embeddings, phenotype_mask, disease_mask, attn_weights

    def calc_loss(self, batch, phenotype_embedding, disease_embeddings, disease_mask, labels, use_candidate_list):
        if self.hyperparameters['loss'] == 'patient_disease_NCA':
            loss, softmax, labels, candidate_disease_idx, candidate_disease_embeddings = self.loss(
                phenotype_embedding, disease_embeddings, batch.batch_disease_nid, batch.batch_cand_disease_nid,
                disease_mask, labels, use_candidate_list=use_candidate_list)
        elif self.hyperparameters['loss'] == 'patient_patient_NCA':
            loss, softmax, labels, candidate_disease_idx, candidate_disease_embeddings = self.loss(
                phenotype_embedding, None, None, None, None, labels, use_candidate_list=False)
        else:
            raise NotImplementedError
        return loss, softmax, labels, candidate_disease_idx, candidate_disease_embeddings

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hyperparameters['lr'])
        return optimizer
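
For context, here is a minimal sketch of how this head might be instantiated and run on dummy tensors. The hparams keys mirror the ones read in the code above, but the concrete values (embedding size, batch shapes, learning rate, attention type) are illustrative placeholders rather than values from the project's configs, and the import path is assumed to match the repo layout with utils on the Python path.

# Hypothetical smoke test; all hparam values below are placeholders, not SHEPHERD defaults.
import torch
from task_heads.patient_nca import PatientNCA  # import path assumed; adjust to your layout

hparams = {
    'attention_type': 'bilinear',        # 'bilinear' | 'additive' | 'dotpdt'
    'loss': 'patient_disease_NCA',       # or 'patient_patient_NCA'
    'leaky_relu': 0.1,
    'softmax_scale': 1.0,
    'only_hard_distractors': False,
    'n_transformer_layers': 0,           # > 0 enables the optional transformer encoder
    'n_transformer_heads': 2,
    'lr': 1e-4,
}

embed_dim, batch_size, n_phen, n_dis = 64, 2, 5, 3
model = PatientNCA(hparams, embed_dim)

phen_emb = torch.randn(batch_size, n_phen, embed_dim)   # one embedding per phenotype term
dis_emb = torch.randn(batch_size, n_dis, embed_dim)     # candidate disease embeddings
phen_mask = torch.ones(batch_size, n_phen, dtype=torch.bool)
dis_mask = torch.ones(batch_size, n_dis, dtype=torch.bool)

patient_emb, dis_proj, _, _, attn = model(phen_emb, dis_emb, phen_mask, dis_mask)
print(patient_emb.shape)   # (batch_size, embed_dim) after attention pooling and projection
print(dis_proj.shape)      # (batch_size, n_dis, embed_dim) after the disease projection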