shepherd/task_heads/gp_aligner.py
|
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
import wandb

from torch import nn
import torch
import torch.nn.functional as F
from torch.nn import TransformerEncoderLayer

import numpy as np
from scipy.stats import rankdata

from allennlp.modules.attention import CosineAttention, BilinearAttention, AdditiveAttention, DotProductAttention

from utils.loss_utils import MultisimilarityCriterion, _construct_labels, unique, _construct_disease_labels
from utils.train_utils import masked_mean, masked_softmax, weighted_sum, plot_degree_vs_attention, mean_reciprocal_rank, top_k_acc
|
class GPAligner(pl.LightningModule):

    def __init__(self, hparams, embed_dim):
        super().__init__()
        self.hyperparameters = hparams
        print('GPAligner embedding dimension: ', embed_dim)

        # attention for collapsing the set of phenotype embeddings into a single patient embedding
        self.attn_vector = nn.Parameter(torch.zeros((1, embed_dim), dtype=torch.float), requires_grad=True)
        nn.init.xavier_uniform_(self.attn_vector)

        if self.hyperparameters['attention_type'] == 'bilinear':
            self.attention = BilinearAttention(embed_dim, embed_dim)
        elif self.hyperparameters['attention_type'] == 'additive':
            self.attention = AdditiveAttention(embed_dim, embed_dim)
        elif self.hyperparameters['attention_type'] == 'dotpdt':
            self.attention = DotProductAttention()

        # decoder scores the pooled phenotype embedding against candidate gene/disease embeddings
        if self.hyperparameters['decoder_type'] == 'dotpdt':
            self.decoder = DotProductAttention(normalize=False)
        elif self.hyperparameters['decoder_type'] == 'bilinear':
            self.decoder = BilinearAttention(embed_dim, embed_dim, activation=torch.tanh, normalize=False)
        else:
            raise NotImplementedError

        # projection layers
        self.phen_project = nn.Linear(embed_dim, embed_dim)
        self.gene_project = nn.Linear(embed_dim, embed_dim)
        self.phen_project2 = nn.Linear(embed_dim, embed_dim)
        self.gene_project2 = nn.Linear(embed_dim, embed_dim)

        # optional disease projection layers
        if self.hyperparameters['use_diseases']:
            self.disease_project = nn.Linear(embed_dim, embed_dim)
            self.disease_project2 = nn.Linear(embed_dim, embed_dim)

        self.leaky_relu = nn.LeakyReLU(hparams['leaky_relu'])

        # multi-similarity loss over the similarity scores produced by the decoder
        self.loss = MultisimilarityCriterion(hparams['pos_weight'], hparams['neg_weight'],
                                             hparams['margin'], hparams['thresh'],
                                             embed_dim, hparams['only_hard_distractors'])

        # optional transformer encoder over the phenotype set, applied before attention pooling in forward()
        if 'n_transformer_layers' in hparams and hparams['n_transformer_layers'] > 0:
            encoder_layer = TransformerEncoderLayer(d_model=embed_dim, nhead=hparams['n_transformer_heads'])
            self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=hparams['n_transformer_layers'])
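    # A minimal sketch of the hparams dict this constructor reads. The keys are the
    # ones accessed above; the example values are illustrative assumptions, not the
    # project's defaults.
    #
    # hparams = {
    #     'attention_type': 'bilinear',      # 'bilinear' | 'additive' | 'dotpdt'
    #     'decoder_type': 'bilinear',        # 'bilinear' | 'dotpdt'
    #     'use_diseases': False,
    #     'leaky_relu': 0.1,                 # negative slope for nn.LeakyReLU
    #     'pos_weight': 1.0,
    #     'neg_weight': 1.0,
    #     'margin': 0.1,
    #     'thresh': 0.5,
    #     'only_hard_distractors': False,
    #     'n_transformer_layers': 0,         # > 0 enables the transformer encoder
    #     'n_transformer_heads': 2,
    # }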
|
|
|
    def forward(self, phenotype_embeddings, candidate_gene_embeddings, disease_embeddings=None, phenotype_mask=None, gene_mask=None, disease_mask=None):
        assert phenotype_mask is not None
        assert gene_mask is not None
        if self.hyperparameters['use_diseases']:
            assert disease_mask is not None

        # optionally contextualize the phenotype embeddings with the transformer encoder
        if 'n_transformer_layers' in self.hyperparameters and self.hyperparameters['n_transformer_layers'] > 0:
            phenotype_embeddings = self.transformer_encoder(phenotype_embeddings.transpose(0, 1), src_key_padding_mask=~phenotype_mask).transpose(0, 1)

        # attention-weighted average of phenotype embeddings
        batched_attn = self.attn_vector.repeat(phenotype_embeddings.shape[0], 1)
        attn_weights = self.attention(batched_attn, phenotype_embeddings, phenotype_mask)
        phenotype_embedding = weighted_sum(phenotype_embeddings, attn_weights)

        # project embeddings
        phenotype_embedding = self.phen_project2(self.leaky_relu(self.phen_project(phenotype_embedding)))
        candidate_gene_embeddings = self.gene_project2(self.leaky_relu(self.gene_project(candidate_gene_embeddings)))

        if self.hyperparameters['use_diseases']:
            disease_embeddings = self.disease_project2(self.leaky_relu(self.disease_project(disease_embeddings)))
        else:
            disease_embeddings = None
            disease_mask = None

        return phenotype_embedding, candidate_gene_embeddings, disease_embeddings, gene_mask, phenotype_mask, disease_mask, attn_weights
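    # Illustrative call of forward(), assuming `model = GPAligner(hparams, embed_dim=128)`
    # has already been constructed; the shapes below are assumptions for the sketch:
    # phenotype_embeddings is (batch, n_phenotypes, embed_dim), candidate_gene_embeddings
    # is (batch, n_candidate_genes, embed_dim), and the masks are boolean tensors that
    # are True for real (non-padding) entries.
    #
    # phen_emb, gene_emb, dis_emb, gene_mask, phen_mask, dis_mask, attn_weights = model(
    #     phenotype_embeddings=torch.randn(8, 12, 128),
    #     candidate_gene_embeddings=torch.randn(8, 20, 128),
    #     phenotype_mask=torch.ones(8, 12, dtype=torch.bool),
    #     gene_mask=torch.ones(8, 20, dtype=torch.bool),
    # )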
|
|
|
    def _calc_similarity(self, phenotype_embeddings, candidate_gene_embeddings, disease_embeddings, batch_cand_gene_nid, batch_corr_gene_nid, batch_disease_nid, one_hot_labels, gene_mask, phenotype_mask, disease_mask, use_candidate_list, cand_gene_to_phenotypes_spl, alpha):
        # Normalize embeddings (within each individual patient)
        phenotype_embeddings = F.normalize(phenotype_embeddings, p=2, dim=1)
        batch_sz = phenotype_embeddings.shape[0]
        if disease_embeddings is not None:
            disease_embeddings = F.normalize(disease_embeddings.squeeze(), p=2, dim=1)
        if candidate_gene_embeddings is not None:
            batch_sz, n_cand_genes, embed_dim = candidate_gene_embeddings.shape
            candidate_gene_embeddings = F.normalize(candidate_gene_embeddings.view(batch_sz * n_cand_genes, -1), p=2, dim=1).view(batch_sz, n_cand_genes, embed_dim)

        # Only use each patient's own candidate genes/diseases
        if self.hyperparameters['only_hard_distractors'] or use_candidate_list:
            if disease_embeddings is None:  # only use genes
                mask = gene_mask
                raw_sims = self.decoder(phenotype_embeddings, candidate_gene_embeddings)
                if cand_gene_to_phenotypes_spl is not None:
                    sims = alpha * raw_sims + (1 - alpha) * cand_gene_to_phenotypes_spl
                else:
                    sims = raw_sims

            elif candidate_gene_embeddings is None:  # only use diseases
                raise NotImplementedError

            else:
                raise NotImplementedError

        # Otherwise, use the entire batch as candidate genes/diseases
        else:
            if disease_embeddings is None:  # only use genes
                candidate_gene_idx, candidate_gene_embeddings, one_hot_labels = _construct_labels(candidate_gene_embeddings, batch_cand_gene_nid, batch_corr_gene_nid, gene_mask)
                raw_sims = self.decoder(phenotype_embeddings, candidate_gene_embeddings.unsqueeze(0).repeat(batch_sz, 1, 1))
                if cand_gene_to_phenotypes_spl is not None:
                    sims = alpha * raw_sims + (1 - alpha) * cand_gene_to_phenotypes_spl
                else:
                    sims = raw_sims
                mask = None

            elif candidate_gene_embeddings is None:  # only use diseases
                candidate_embeddings, one_hot_labels = _construct_disease_labels(disease_embeddings, batch_disease_nid)
                raw_sims = self.decoder(phenotype_embeddings, candidate_embeddings.unsqueeze(0).repeat(batch_sz, 1, 1))
                if batch_disease_nid.shape[1] > 1:
                    raw_sims = raw_sims[batch_disease_nid[:, 0].squeeze() != 0]  # drop rows where the patient has no disease node (nid == 0)
                    one_hot_labels = one_hot_labels[batch_disease_nid[:, 0].squeeze() != 0]
                else:
                    raw_sims = raw_sims[batch_disease_nid.squeeze() != 0]  # drop rows where the patient has no disease node (nid == 0)
                    one_hot_labels = one_hot_labels[batch_disease_nid.squeeze() != 0]
                sims = raw_sims
                mask = None

            else:  # use genes + diseases
                raise NotImplementedError

        return sims, raw_sims, mask, one_hot_labels
|
|
|
    def _rank_genes(self, phen_gene_sims, gene_mask, one_hot_labels):
        phen_gene_sims = phen_gene_sims * gene_mask
        padded_phen_gene_sims = phen_gene_sims + (~gene_mask * -100000)  # push padded genes to the bottom of the ranking
        # negate so the highest similarity gets rank 1 (ties share the average rank)
        gene_ranks = torch.tensor(np.apply_along_axis(lambda row: rankdata(row * -1, method='average'), axis=1, arr=padded_phen_gene_sims.detach().cpu().numpy()))
        if one_hot_labels is None:
            correct_gene_ranks = None
        else:
            gene_ranks = gene_ranks.to(one_hot_labels.device)
            correct_gene_ranks = gene_ranks[one_hot_labels == 1]
        return correct_gene_ranks, padded_phen_gene_sims
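    # Worked example of the ranking above (numbers invented for illustration): for a
    # patient with similarities [0.9, 0.2, 0.9] and no padded genes,
    # rankdata([-0.9, -0.2, -0.9], method='average') gives [1.5, 3.0, 1.5], so the two
    # top-scoring genes tie for ranks 1-2 and the correct gene's rank is read off
    # wherever one_hot_labels == 1.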
|
|
|
    def calc_loss(self, sims, mask, one_hot_labels):
        # thin wrapper around the multi-similarity criterion configured in __init__
        return self.loss(sims, mask, one_hot_labels)
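
# Sketch of how these pieces fit together for a single scoring step (the training
# loop that drives them lives outside this section; the argument values below are
# assumptions for illustration, not the project's actual wiring):
#
# phen_emb, gene_emb, dis_emb, gene_mask, phen_mask, dis_mask, attn_weights = model(...)
# sims, raw_sims, mask, labels = model._calc_similarity(
#     phen_emb, gene_emb, dis_emb,
#     batch_cand_gene_nid, batch_corr_gene_nid, batch_disease_nid,
#     one_hot_labels, gene_mask, phen_mask, dis_mask,
#     use_candidate_list=True, cand_gene_to_phenotypes_spl=None, alpha=1.0)
# correct_gene_ranks, padded_sims = model._rank_genes(sims, gene_mask, labels)
# loss = model.calc_loss(sims, mask, labels)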
|