# shepherd/task_heads/gp_aligner.py
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
import wandb

from torch import nn
import torch
import torch.nn.functional as F
from torch.nn import TransformerEncoderLayer

import numpy as np
from scipy.stats import rankdata

from allennlp.modules.attention import CosineAttention, BilinearAttention, AdditiveAttention, DotProductAttention

from utils.loss_utils import MultisimilarityCriterion, _construct_labels, unique, _construct_disease_labels
from utils.train_utils import masked_mean, masked_softmax, weighted_sum, plot_degree_vs_attention, mean_reciprocal_rank, top_k_acc


class GPAligner(pl.LightningModule):
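    """Gene-phenotype aligner task head.

    Optionally contextualizes a patient's set of phenotype embeddings with a
    transformer encoder, collapses them into a single patient embedding via a
    learned attention vector, and projects phenotype, candidate gene, and
    (optionally) disease embeddings into a shared space where candidates are
    scored by a similarity decoder and trained with a multi-similarity loss.
    """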
    def __init__(self, hparams, embed_dim):
        super().__init__()
        self.hyperparameters = hparams
        print('GPAligner embedding dimension: ', embed_dim)

        # attention vector for collapsing the set of phenotype embeddings
        self.attn_vector = nn.Parameter(torch.zeros((1, embed_dim), dtype=torch.float), requires_grad=True)
        nn.init.xavier_uniform_(self.attn_vector)

        if self.hyperparameters['attention_type'] == 'bilinear':
            self.attention = BilinearAttention(embed_dim, embed_dim)
        elif self.hyperparameters['attention_type'] == 'additive':
            self.attention = AdditiveAttention(embed_dim, embed_dim)
        elif self.hyperparameters['attention_type'] == 'dotpdt':
            self.attention = DotProductAttention()
        else:
            raise NotImplementedError

        if self.hyperparameters['decoder_type'] == "dotpdt":
            self.decoder = DotProductAttention(normalize=False)
        elif self.hyperparameters['decoder_type'] == "bilinear":
            self.decoder = BilinearAttention(embed_dim, embed_dim, activation=torch.tanh, normalize=False)
        else:
            raise NotImplementedError

        # projection layers
        self.phen_project = nn.Linear(embed_dim, embed_dim)
        self.gene_project = nn.Linear(embed_dim, embed_dim)
        self.phen_project2 = nn.Linear(embed_dim, embed_dim)
        self.gene_project2 = nn.Linear(embed_dim, embed_dim)

        # optional disease projection layers
        if self.hyperparameters['use_diseases']:
            self.disease_project = nn.Linear(embed_dim, embed_dim)
            self.disease_project2 = nn.Linear(embed_dim, embed_dim)

        self.leaky_relu = nn.LeakyReLU(hparams['leaky_relu'])

        self.loss = MultisimilarityCriterion(hparams['pos_weight'], hparams['neg_weight'],
                                             hparams['margin'], hparams['thresh'],
                                             embed_dim, hparams['only_hard_distractors'])

        # optional transformer encoder applied to the phenotype set before attention pooling
        if 'n_transformer_layers' in hparams and hparams['n_transformer_layers'] > 0:
            encoder_layer = TransformerEncoderLayer(d_model=embed_dim, nhead=hparams['n_transformer_heads'])
            self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=hparams['n_transformer_layers'])

    def forward(self, phenotype_embeddings, candidate_gene_embeddings, disease_embeddings=None, phenotype_mask=None, gene_mask=None, disease_mask=None): 
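        """Pool the phenotype set with attention and project all embeddings.

        Expects batch-first tensors: phenotype_embeddings and
        candidate_gene_embeddings are (batch, set_size, embed_dim), and each mask
        is a boolean tensor marking the real (non-padded) members of the set.
        disease_embeddings/disease_mask are required only when
        hparams['use_diseases'] is set.

        Returns the pooled phenotype embedding, the projected candidate gene and
        disease embeddings, the gene/phenotype/disease masks (disease outputs are
        None when diseases are disabled), and the phenotype attention weights.
        """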
        assert phenotype_mask is not None
        assert gene_mask is not None
        if self.hyperparameters['use_diseases']:
            assert disease_mask is not None

        # optionally contextualize the phenotype set with the transformer encoder
        if 'n_transformer_layers' in self.hyperparameters and self.hyperparameters['n_transformer_layers'] > 0:
            phenotype_embeddings = self.transformer_encoder(phenotype_embeddings.transpose(0, 1), src_key_padding_mask=~phenotype_mask).transpose(0, 1)

        # attention-weighted average of the phenotype embeddings
        batched_attn = self.attn_vector.repeat(phenotype_embeddings.shape[0], 1)
        attn_weights = self.attention(batched_attn, phenotype_embeddings, phenotype_mask)
        phenotype_embedding = weighted_sum(phenotype_embeddings, attn_weights)

        # project embeddings
        phenotype_embedding = self.phen_project2(self.leaky_relu(self.phen_project(phenotype_embedding)))
        candidate_gene_embeddings = self.gene_project2(self.leaky_relu(self.gene_project(candidate_gene_embeddings)))

        if self.hyperparameters['use_diseases']:
            disease_embeddings = self.disease_project2(self.leaky_relu(self.disease_project(disease_embeddings)))
        else:
            disease_embeddings = None
            disease_mask = None

        return phenotype_embedding, candidate_gene_embeddings, disease_embeddings, gene_mask, phenotype_mask, disease_mask, attn_weights

    def _calc_similarity(self, phenotype_embeddings, candidate_gene_embeddings, disease_embeddings, batch_cand_gene_nid,  batch_corr_gene_nid, batch_disease_nid, one_hot_labels, gene_mask, phenotype_mask, disease_mask, use_candidate_list, cand_gene_to_phenotypes_spl, alpha): 
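        """Score each patient's pooled phenotype embedding against candidate embeddings.

        All embeddings are L2-normalized first. When `only_hard_distractors` or
        `use_candidate_list` is set, only the patient's own candidate genes are
        scored; otherwise the candidates from the whole batch are pooled via
        `_construct_labels` / `_construct_disease_labels`. If
        `cand_gene_to_phenotypes_spl` is provided, the decoder similarities are
        blended with those precomputed scores using weight `alpha`.
        """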
        # normalize embeddings (within each individual patient)
        phenotype_embeddings = F.normalize(phenotype_embeddings, p=2, dim=1)
        batch_sz = phenotype_embeddings.shape[0]
        if disease_embeddings is not None:
            disease_embeddings = F.normalize(disease_embeddings.squeeze(), p=2, dim=1)
        if candidate_gene_embeddings is not None:
            batch_sz, n_cand_genes, embed_dim = candidate_gene_embeddings.shape
            candidate_gene_embeddings = F.normalize(candidate_gene_embeddings.view(batch_sz * n_cand_genes, -1), p=2, dim=1).view(batch_sz, n_cand_genes, embed_dim)

        # only use each patient's own candidate genes/diseases
        if self.hyperparameters['only_hard_distractors'] or use_candidate_list:
            if disease_embeddings is None:  # only use genes
                mask = gene_mask
                raw_sims = self.decoder(phenotype_embeddings, candidate_gene_embeddings)
                if cand_gene_to_phenotypes_spl is not None:
                    sims = alpha * raw_sims + (1 - alpha) * cand_gene_to_phenotypes_spl
                else:
                    sims = raw_sims

            elif candidate_gene_embeddings is None:  # only use diseases
                raise NotImplementedError

            else:
                raise NotImplementedError

        # otherwise, use the entire batch as candidate genes/diseases
        else:
            if disease_embeddings is None:  # only use genes
                candidate_gene_idx, candidate_gene_embeddings, one_hot_labels = _construct_labels(candidate_gene_embeddings, batch_cand_gene_nid, batch_corr_gene_nid, gene_mask)
                raw_sims = self.decoder(phenotype_embeddings, candidate_gene_embeddings.unsqueeze(0).repeat(batch_sz, 1, 1))
                if cand_gene_to_phenotypes_spl is not None:
                    sims = alpha * raw_sims + (1 - alpha) * cand_gene_to_phenotypes_spl
                else:
                    sims = raw_sims
                mask = None

            elif candidate_gene_embeddings is None:  # only use diseases
                candidate_embeddings, one_hot_labels = _construct_disease_labels(disease_embeddings, batch_disease_nid)
                raw_sims = self.decoder(phenotype_embeddings, candidate_embeddings.unsqueeze(0).repeat(batch_sz, 1, 1))
                # remove rows for patients without an associated disease (nid == 0)
                if batch_disease_nid.shape[1] > 1:
                    has_disease = batch_disease_nid[:, 0].squeeze() != 0
                else:
                    has_disease = batch_disease_nid.squeeze() != 0
                raw_sims = raw_sims[has_disease]
                one_hot_labels = one_hot_labels[has_disease]
                sims = raw_sims
                mask = None

            else:  # use genes + diseases
                raise NotImplementedError

        return sims, raw_sims, mask, one_hot_labels

    def _rank_genes(self, phen_gene_sims, gene_mask, one_hot_labels):
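        """Rank each patient's candidate genes by similarity to the phenotype embedding.

        Padded candidates are pushed to the bottom of the ranking and ties receive
        their average rank. Returns the ranks of the correct genes (None when no
        labels are provided) together with the padded similarity matrix.
        """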
        # zero out padded candidates, then push them to the bottom of the ranking
        phen_gene_sims = phen_gene_sims * gene_mask
        padded_phen_gene_sims = phen_gene_sims + (~gene_mask * -100000)  # rank padded values last
        # rank genes per patient by descending similarity; ties get their average rank
        gene_ranks = torch.tensor(np.apply_along_axis(lambda row: rankdata(row * -1, method='average'), axis=1, arr=padded_phen_gene_sims.detach().cpu().numpy()))
        if one_hot_labels is None:
            correct_gene_ranks = None
        else:
            gene_ranks = gene_ranks.to(one_hot_labels.device)
            correct_gene_ranks = gene_ranks[one_hot_labels == 1]
        return correct_gene_ranks, padded_phen_gene_sims

    def calc_loss(self, sims, mask, one_hot_labels):
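        """Compute the multi-similarity loss over the candidate similarity scores."""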
        return self.loss(sims, mask, one_hot_labels)
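
# Illustrative usage sketch (not part of the original module): construct a GPAligner with
# placeholder hparams and run a forward pass on random tensors. Assumes the shepherd
# utils and allennlp dependencies imported above are available; the hparams values below
# are illustrative placeholders, not tuned settings.
if __name__ == "__main__":
    hparams = {
        'attention_type': 'bilinear',
        'decoder_type': 'bilinear',
        'use_diseases': False,
        'leaky_relu': 0.1,
        'pos_weight': 2.0,      # placeholder multi-similarity loss weights/thresholds
        'neg_weight': 40.0,
        'margin': 0.1,
        'thresh': 0.5,
        'only_hard_distractors': False,
    }
    embed_dim, batch_sz, n_phen, n_genes = 64, 4, 10, 20
    model = GPAligner(hparams, embed_dim)

    phenotypes = torch.randn(batch_sz, n_phen, embed_dim)
    genes = torch.randn(batch_sz, n_genes, embed_dim)
    phenotype_mask = torch.ones(batch_sz, n_phen, dtype=torch.bool)
    gene_mask = torch.ones(batch_sz, n_genes, dtype=torch.bool)

    phen_emb, gene_emb, _, _, _, _, attn_weights = model(
        phenotypes, genes, phenotype_mask=phenotype_mask, gene_mask=gene_mask)
    print(phen_emb.shape, gene_emb.shape, attn_weights.shape)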