
shepherd/utils/train_utils.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.nn.parameter import Parameter
from allennlp.modules.attention.attention import Attention
from allennlp.nn import Activation

import umap
import pandas as pd

# Matplotlib
from matplotlib import pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
import plotly.express as px

####################################
# Evaluation utils

def mean_reciprocal_rank(correct_gene_ranks):
    # Mean of 1 / rank of the correct gene across patients; higher is better (max 1.0).
    return torch.mean(1 / correct_gene_ranks)

def average_rank(correct_gene_ranks):
    # Mean rank of the correct gene across patients; lower is better.
    return torch.mean(correct_gene_ranks)

def top_k_acc(correct_gene_ranks, k):
    # Fraction of patients whose correct gene is ranked within the top k.
    return torch.sum(correct_gene_ranks <= k) / len(correct_gene_ranks)

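# Illustrative sketch (not part of the original module): how the three metrics behave on a
# toy tensor of 1-indexed ranks. Intended to be called manually, e.g. from a REPL.
def _example_eval_metrics():
    ranks = torch.tensor([1.0, 3.0, 10.0])   # rank of the correct gene for three patients
    mrr = mean_reciprocal_rank(ranks)        # (1/1 + 1/3 + 1/10) / 3 ~= 0.478
    avg = average_rank(ranks)                # (1 + 3 + 10) / 3 ~= 4.67
    top5 = top_k_acc(ranks, k=5)             # 2 of 3 ranks are <= 5 -> ~0.67
    return mrr, avg, top5
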
###########################################
# The functions below are adapted from AllenNLP (allennlp.nn.util)

def masked_mean(
    vector: torch.Tensor, mask: torch.BoolTensor, dim: int, keepdim: bool = False
    ) -> torch.Tensor:
    """
    Calculates the mean along certain dimensions, counting only the unmasked values.

    # Parameters

    vector : `torch.Tensor`
        The vector to calculate the mean over.
    mask : `torch.BoolTensor`
        The mask of the vector. It must be broadcastable with the vector.
    dim : `int`
        The dimension to calculate the mean along.
    keepdim : `bool`
        Whether to keep the dimension.

    # Returns

    `torch.Tensor`
        A `torch.Tensor` containing the mean values.
    """
    # Zero out masked positions so they do not contribute to the sum.
    replaced_vector = vector.masked_fill(~mask, 0.0)

    value_sum = torch.sum(replaced_vector, dim=dim, keepdim=keepdim)
    value_count = torch.sum(mask, dim=dim, keepdim=keepdim)
    # Clamp the count to avoid division by zero when a row is fully masked.
    return value_sum / value_count.float().clamp(min=tiny_value_of_dtype(torch.float))

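# Illustrative sketch (not part of the original module): masked_mean over a padded batch,
# where False marks padding positions that should be ignored.
def _example_masked_mean():
    vector = torch.tensor([[1.0, 2.0, 100.0],
                           [4.0, 6.0, 8.0]])
    mask = torch.tensor([[True, True, False],
                         [True, True, True]])
    # Row 0 averages only its first two entries -> 1.5; row 1 averages all three -> 6.0.
    return masked_mean(vector, mask, dim=-1)
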
def tiny_value_of_dtype(dtype: torch.dtype):
    """
    Returns a moderately tiny value for a given PyTorch data type that is used to avoid numerical
    issues such as division by zero.
    This is different from `info_value_of_dtype(dtype).tiny` because that value can be so small
    that it still leads to NaN bugs.
    Only supports floating point dtypes.
    """
    if not dtype.is_floating_point:
        raise TypeError("Only supports floating point dtypes.")
    if dtype == torch.float or dtype == torch.double:
        return 1e-13
    elif dtype == torch.half:
        return 1e-4
    else:
        raise TypeError("Does not support dtype " + str(dtype))

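# masked_softmax's memory_efficient branch below calls min_value_of_dtype, which is neither
# defined nor imported in this file. A minimal sketch following the AllenNLP utilities
# (allennlp.nn.util) these functions come from, added here to fill the gap.
def info_value_of_dtype(dtype: torch.dtype):
    # Returns the `finfo` or `iinfo` object for a given PyTorch data type.
    if dtype == torch.bool:
        raise TypeError("Does not support torch.bool")
    elif dtype.is_floating_point:
        return torch.finfo(dtype)
    else:
        return torch.iinfo(dtype)

def min_value_of_dtype(dtype: torch.dtype):
    # Returns the minimum representable value for a given PyTorch data type.
    return info_value_of_dtype(dtype).min
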
def masked_softmax(
    vector: torch.Tensor,
    mask: torch.BoolTensor,
    dim: int = -1,
    memory_efficient: bool = False,
    ) -> torch.Tensor:
    """
    `torch.nn.functional.softmax(vector)` does not work if some elements of `vector` should be
    masked. This performs a softmax on just the non-masked portions of `vector`. Passing
    `None` in for the mask is also acceptable; you'll just get a regular softmax.

    `vector` can have an arbitrary number of dimensions; the only requirement is that `mask` is
    broadcastable to `vector`'s shape. If `mask` has fewer dimensions than `vector`, we will
    unsqueeze on dimension 1 until they match. If you need a different unsqueezing of your mask,
    do it yourself before passing the mask into this function.

    If `memory_efficient` is set to true, we will simply use a very large negative number for the
    masked positions so that their probabilities are approximately 0. This is not mathematically
    exact, but it works for most cases and consumes less memory.

    In the case that the input vector is completely masked and `memory_efficient` is false, this function
    returns an array of `0.0`. This behavior may cause `NaN` if it is used as the last layer of
    a model that uses categorical cross-entropy loss. If `memory_efficient` is true, this function
    will instead treat every element as equal and do a softmax over equal numbers.
    """
    if mask is None:
        result = torch.nn.functional.softmax(vector, dim=dim)
    else:
        # Unsqueeze the mask until it has the same number of dimensions as the vector.
        while mask.dim() < vector.dim():
            mask = mask.unsqueeze(1)
        if not memory_efficient:
            # To limit numerical errors from large vector elements outside the mask, we zero these out.
            result = torch.nn.functional.softmax(vector * mask, dim=dim)
            result = result * mask
            result = result / (
                result.sum(dim=dim, keepdim=True) + tiny_value_of_dtype(result.dtype)
            )
        else:
            # Fill masked positions with the most negative representable value before the softmax.
            masked_vector = vector.masked_fill(~mask, min_value_of_dtype(vector.dtype))
            result = torch.nn.functional.softmax(masked_vector, dim=dim)
    return result

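# Illustrative sketch (not part of the original module): softmax over a padded row of
# attention scores, where False marks a padding position.
def _example_masked_softmax():
    scores = torch.tensor([[2.0, 1.0, 5.0]])
    mask = torch.tensor([[True, True, False]])
    # The masked position gets ~0 probability; the remaining positions sum to ~1.
    return masked_softmax(scores, mask)      # ~tensor([[0.731, 0.269, 0.000]])
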
def weighted_sum(matrix: torch.Tensor, attention: torch.Tensor) -> torch.Tensor:
    """
    Takes a matrix of vectors and a set of weights over the rows in the matrix (which we call an
    "attention" vector), and returns a weighted sum of the rows in the matrix.  This is the typical
    computation performed after an attention mechanism.

    Note that while we call this a "matrix" of vectors and an attention "vector", we also handle
    higher-order tensors.  We always sum over the second-to-last dimension of the "matrix", and we
    assume that all dimensions in the "matrix" prior to the last dimension are matched in the
    "vector".  Non-matched dimensions in the "vector" must be `directly after the batch dimension`.

    For example, say I have a "matrix" with dimensions `(batch_size, num_queries, num_words,
    embedding_dim)`.  The attention "vector" then must have at least those dimensions, and could
    have more. Both:

        - `(batch_size, num_queries, num_words)` (distribution over words for each query)
        - `(batch_size, num_documents, num_queries, num_words)` (distribution over words in a
          query for each document)

    are valid input "vectors", producing tensors of shape:
    `(batch_size, num_queries, embedding_dim)` and
    `(batch_size, num_documents, num_queries, embedding_dim)` respectively.
    """
    # We'll special-case a few settings here, where there are efficient (but poorly-named)
    # operations in pytorch that already do the computation we need.
    if attention.dim() == 2 and matrix.dim() == 3:
        return attention.unsqueeze(1).bmm(matrix).squeeze(1)
    if attention.dim() == 3 and matrix.dim() == 3:
        return attention.bmm(matrix)
    if matrix.dim() - 1 < attention.dim():
        expanded_size = list(matrix.size())
        for i in range(attention.dim() - matrix.dim() + 1):
            matrix = matrix.unsqueeze(1)
            expanded_size.insert(i + 1, attention.size(i + 1))
        matrix = matrix.expand(*expanded_size)
    intermediate = attention.unsqueeze(-1).expand_as(matrix) * matrix
    return intermediate.sum(dim=-2)

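# Illustrative sketch (not part of the original module): attention-weighted pooling of a batch
# of embeddings, using the common (batch, rows, dim) x (batch, rows) case.
def _example_weighted_sum():
    matrix = torch.randn(2, 4, 8)              # (batch_size, num_rows, embedding_dim)
    attention = masked_softmax(
        torch.randn(2, 4),                     # one raw score per row
        torch.tensor([[True, True, True, False],
                      [True, True, True, True]]),
    )
    return weighted_sum(matrix, attention)     # shape (2, 8): one pooled embedding per batch element
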
###########################################
# Plots logged to Weights & Biases (wandb)

def mrr_vs_percent_overlap(correct_gene_rank, percent_overlap_train):
    # Scatter plot of each patient's correct-gene rank against the fraction of their phenotypes
    # that also appear in a single training patient.
    df = pd.DataFrame({"Percent of Phenotypes Found in Single Train Patient": percent_overlap_train.squeeze(), "Rank of Correct Gene": correct_gene_rank})
    fig = px.scatter(df, x="Percent of Phenotypes Found in Single Train Patient", y="Rank of Correct Gene")
    return fig

def plot_softmax(softmax):
    # Histogram of softmax values; each element of `softmax` is expected to be a scalar tensor.
    softmax = [s.detach().item() for s in softmax]
    df = pd.DataFrame({'softmax': softmax})
    fig = px.histogram(df, x="softmax")
    return fig

def fit_umap(embed, labels=None, n_neighbors=15, min_dist=0.1, n_components=2, metric='euclidean', random_state=3):
    # Project embeddings to 2D with UMAP and return a plotly scatter plot. If `labels` is
    # provided, it must be a dict of per-point annotations that includes a "Node Type" key,
    # which is used to color the points.
    labels = labels if labels is not None else {}
    embed = embed.detach().cpu()
    mapping = umap.UMAP(n_neighbors=n_neighbors, min_dist=min_dist, n_components=n_components, metric=metric, random_state=random_state).fit(embed)
    embedding = mapping.transform(embed)

    data = {"x": embedding[:, 0], "y": embedding[:, 1]}
    if len(labels) > 0: data.update(labels)
    df = pd.DataFrame(data)
    if len(labels) > 0: fig = px.scatter(df, x="x", y="y", color="Node Type", hover_data=list(labels.keys()))
    else: fig = px.scatter(df, x="x", y="y")
    return fig

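# Illustrative sketch (not part of the original module): calling fit_umap with per-node labels.
# The "Node Type" key is required whenever labels are provided, since it drives the point colors.
def _example_fit_umap():
    embed = torch.randn(100, 32)                                # 100 node embeddings of dim 32
    labels = {"Node Type": ["gene"] * 50 + ["phenotype"] * 50}  # one entry per embedding row
    return fit_umap(embed, labels=labels)
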
def plot_degree_vs_attention(attn_weights, phenotype_names, single_patient=False):
    # Scatter plot of node degree against attention weight; the second element of each entry in
    # `phenotype_names` is used as the node degree.
    if single_patient:
        phenotype_names = phenotype_names[0]
        attn_weights = attn_weights[0]
        data = [(w.item(), p_name[1]) for w, p_name in zip(attn_weights, phenotype_names)]
    else:
        data = [(attn_weights[i], phenotype_names[i]) for i in range(len(phenotype_names))]
        data = [(w.item(), p_name[1]) for attn_w, phen_name in data for w, p_name in zip(attn_w, phen_name)]
    attn_weights, degrees = zip(*data)
    df = pd.DataFrame({"Node Degree": degrees, "Attention Weight": attn_weights})
    fig = px.scatter(df, x="Node Degree", y="Attention Weight")
    return fig

def plot_nhops_to_gene_vs_attention(attn_weights, phenotype_names, nhops_g_p, single_patient=False):
    # Scatter plot of each phenotype's number of hops to the gene in the knowledge graph (KG)
    # against its attention weight.
    if single_patient:
        attn_weights = attn_weights[0]
        nhops_g_p = nhops_g_p[0]
        phenotype_names = phenotype_names[0]
        data = [(w.item(), hops) for w, hops in zip(attn_weights, nhops_g_p)]
    else:
        data = [(attn_weights[i], nhops_g_p[i]) for i in range(len(phenotype_names))]
        data = [(w.item(), hop) for attn_w, nhops in data for w, hop in zip(attn_w, nhops)]
    attn_weights, n_hops_g_p = zip(*data)
    df = pd.DataFrame({"Number of Hops to Gene in KG": n_hops_g_p, "Attention Weight": attn_weights})
    fig = px.scatter(df, x="Number of Hops to Gene in KG", y="Attention Weight")
    return fig

def plot_gene_rank_vs_x_intrain(corr_gene_ranks, in_train):
    # Bar plot of the average correct-gene rank, split by the binary `in_train` flag.
    # A value of -1 indicates that the corresponding group is empty.
    if sum(in_train == 1) == 0:
        values_in_train = -1
        err_in_train = 0
    else:
        values_in_train = torch.mean(corr_gene_ranks[in_train == 1].float())
        err_in_train = torch.std(corr_gene_ranks[in_train == 1].float())
    if sum(in_train == 0) == 0:
        values_not_in_train = -1
        err_not_in_train = 0
    else:
        values_not_in_train = torch.mean(corr_gene_ranks[in_train == 0].float())
        err_not_in_train = torch.std(corr_gene_ranks[in_train == 0].float())
    df = pd.DataFrame({"Average Rank of Correct Gene": [values_in_train, values_not_in_train], "In Train or Not": ["True", "False"], "Error Bars": [err_in_train, err_not_in_train]})
    fig = px.bar(df, x="In Train or Not", y="Average Rank of Correct Gene", error_y="Error Bars")
    return fig

def plot_gene_rank_vs_numtrain(corr_gene_ranks, correct_gene_nid, train_corr_gene_nid):
    # Scatter plot of each patient's correct-gene rank against the number of times that gene
    # appears in `train_corr_gene_nid` (0 if never seen during training).
    gene_counts = [train_corr_gene_nid.get(g, 0) for g in list(correct_gene_nid.numpy())]
    data = {"Rank of Correct Gene": corr_gene_ranks.cpu(), "Number of Times Seen": gene_counts, "Gene ID": correct_gene_nid}
    df = pd.DataFrame(data)
    fig = px.scatter(df, x="Number of Times Seen", y="Rank of Correct Gene", hover_data=list(data.keys()))
    return fig, gene_counts

def plot_gene_rank_vs_trainset(corr_gene_ranks, correct_gene_nid, gene_counts): # gene_counts has dimension num_genes x num_sets (corr, cand, sparse, target)
    # Bar plot of the average correct-gene rank, grouped by which training gene sets
    # (correct, candidate, sparse, target) the gene appears in.
    trainset_labels = ["-".join([str(int(l)) for l in list((gene_counts[i, :] > 0).numpy())]) for i in range(gene_counts.shape[0])]
    gene_ranks_dict = {}
    for l, r in zip(trainset_labels, corr_gene_ranks): # (corr, cand, sparse, target)
        flags = l.split("-")
        l_full = []
        if flags[0] == "1": l_full.append("Corr")
        if flags[1] == "1": l_full.append("Cand")
        if flags[2] == "1": l_full.append("Sparse")
        if flags[3] == "1": l_full.append("Target")
        if len(l_full) == 0: l_full = "None"
        else: l_full = "-".join(l_full)
        if l_full not in gene_ranks_dict: gene_ranks_dict[l_full] = []
        gene_ranks_dict[l_full].append(int(r))
    avg_gene_ranks = {l: np.mean(r) for l, r in gene_ranks_dict.items()}
    df = pd.DataFrame({"Train Set": list(avg_gene_ranks.keys()), "Average Rank of Correct Gene": list(avg_gene_ranks.values())})
    fig = px.bar(df, x="Train Set", y="Average Rank of Correct Gene")
    return fig

def plot_gene_rank_vs_fraction_phenotype(corr_gene_ranks, frac_p):
    # Scatter plot of the correct-gene rank against the fraction of phenotypes; rows where the
    # fraction is -1 (missing) are dropped.
    df = pd.DataFrame({"Rank of Correct Gene": corr_gene_ranks, "Fraction of Phenotypes": frac_p})
    df = df[df["Fraction of Phenotypes"] > -1]
    fig = px.scatter(df, x="Fraction of Phenotypes", y="Rank of Correct Gene")
    return fig

def plot_gene_rank_vs_hops(corr_gene_ranks, n_hops):
    # Scatter plots of the correct-gene rank against the mean and the minimum number of hops in the KG.
    mean_hops = []
    min_hops = []
    for hops in n_hops:
        if isinstance(hops, list): # gene to phenotype n_hops
            mean_hops.append(np.mean(hops))
            min_hops.append(np.min(hops))
        else: # phenotype to phenotype n_hops; zero entries are ignored
            filtered_hops = torch.cat([hops[i][hops[i] > 0] for i in range(hops.shape[0])]).float()
            mean_hops.append(torch.mean(filtered_hops).item())
            min_hops.append(torch.min(filtered_hops).item())

    df = pd.DataFrame({"Rank of Correct Gene": corr_gene_ranks, "Mean Number of Hops": mean_hops})
    fig_mean = px.scatter(df, x="Mean Number of Hops", y="Rank of Correct Gene")
    df = pd.DataFrame({"Rank of Correct Gene": corr_gene_ranks, "Min Number of Hops": min_hops})
    fig_min = px.scatter(df, x="Min Number of Hops", y="Rank of Correct Gene")
    return fig_mean, fig_min
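
# Illustrative sketch (not part of the original module): plot_gene_rank_vs_hops accepts, per
# patient, either a plain list of hop counts (gene-to-phenotype) or a 2D tensor of pairwise hop
# counts whose zero entries are ignored (phenotype-to-phenotype).
def _example_plot_gene_rank_vs_hops():
    ranks = torch.tensor([4, 17])
    n_hops = [
        [1, 2, 3],                               # list: gene-to-phenotype hops for patient 0
        torch.tensor([[0, 2], [2, 0]]),          # tensor: phenotype-to-phenotype hops for patient 1
    ]
    fig_mean, fig_min = plot_gene_rank_vs_hops(ranks, n_hops)
    return fig_mean, fig_min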