# shepherd/utils/train_utils.py

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.nn.parameter import Parameter
from allennlp.modules.attention.attention import Attention
from allennlp.nn import Activation

import umap
import pandas as pd

# Matplotlib
from matplotlib import pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
import plotly.express as px


####################################
# Evaluation utils

def mean_reciprocal_rank(correct_gene_ranks):
    return torch.mean(1 / correct_gene_ranks)


def average_rank(correct_gene_ranks):
    return torch.mean(correct_gene_ranks)


def top_k_acc(correct_gene_ranks, k):
    return torch.sum(correct_gene_ranks <= k) / len(correct_gene_ranks)
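

# Illustrative sanity check for the metrics above (not part of the original
# pipeline; the ranks below are made-up values, assumed to be 1-indexed ranks
# of the correct gene per patient).
def _example_evaluation_metrics():
    ranks = torch.tensor([1.0, 3.0, 10.0, 2.0])
    return {
        "mrr": mean_reciprocal_rank(ranks),   # mean of 1/rank
        "avg_rank": average_rank(ranks),      # mean rank
        "top5_acc": top_k_acc(ranks, k=5),    # fraction of ranks <= 5
    }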


###########################################
# below functions from AllenNLP

def masked_mean(
    vector: torch.Tensor, mask: torch.BoolTensor, dim: int, keepdim: bool = False
) -> torch.Tensor:
    """
    Calculates the mean of `vector` along dimension `dim`, considering only the positions
    where `mask` is `True`.

    # Parameters
    vector : `torch.Tensor`
        The vector to calculate the mean over.
    mask : `torch.BoolTensor`
        The mask of the vector. It must be broadcastable with `vector`.
    dim : `int`
        The dimension to calculate the mean along.
    keepdim : `bool`
        Whether to keep the reduced dimension.

    # Returns
    `torch.Tensor`
        A `torch.Tensor` containing the mean values.
    """
    replaced_vector = vector.masked_fill(~mask, 0.0)

    value_sum = torch.sum(replaced_vector, dim=dim, keepdim=keepdim)
    value_count = torch.sum(mask, dim=dim, keepdim=keepdim)
    return value_sum / value_count.float().clamp(min=tiny_value_of_dtype(torch.float))
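

# Minimal sketch of masked_mean on a padded batch; the shapes here are assumed
# for illustration only. Only the unmasked timesteps contribute to the mean.
def _example_masked_mean():
    vector = torch.randn(2, 4, 8)                                    # (batch, seq_len, dim)
    mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]]).bool().unsqueeze(-1)
    return masked_mean(vector, mask, dim=1)                          # -> (batch, dim)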


def tiny_value_of_dtype(dtype: torch.dtype):
    """
    Returns a moderately tiny value for a given PyTorch data type that is used to avoid numerical
    issues such as division by zero.
    This is different from `info_value_of_dtype(dtype).tiny` because that value can cause NaN bugs.
    Only supports floating point dtypes.
    """
    if not dtype.is_floating_point:
        raise TypeError("Only supports floating point dtypes.")
    if dtype == torch.float or dtype == torch.double:
        return 1e-13
    elif dtype == torch.half:
        return 1e-4
    else:
        raise TypeError("Does not support dtype " + str(dtype))


def masked_softmax(
    vector: torch.Tensor,
    mask: torch.BoolTensor,
    dim: int = -1,
    memory_efficient: bool = False,
) -> torch.Tensor:
    """
    `torch.nn.functional.softmax(vector)` does not work if some elements of `vector` should be
    masked. This performs a softmax on just the non-masked portions of `vector`. Passing
    `None` in for the mask is also acceptable; you'll just get a regular softmax.
    `vector` can have an arbitrary number of dimensions; the only requirement is that `mask` is
    broadcastable to `vector`'s shape. If `mask` has fewer dimensions than `vector`, we will
    unsqueeze on dimension 1 until they match. If you need a different unsqueezing of your mask,
    do it yourself before passing the mask into this function.
    If `memory_efficient` is set to true, we will simply use a very large negative number for those
    masked positions so that the probabilities of those positions would be approximately 0.
    This is not accurate in math, but works for most cases and consumes less memory.
    In the case that the input vector is completely masked and `memory_efficient` is false, this function
    returns an array of `0.0`. This behavior may cause `NaN` if this is used as the last layer of
    a model that uses categorical cross-entropy loss. Instead, if `memory_efficient` is true, this function
    will treat every element as equal, and do softmax over equal numbers.
    """
    if mask is None:
        result = torch.nn.functional.softmax(vector, dim=dim)
    else:
        while mask.dim() < vector.dim():
            mask = mask.unsqueeze(1)
        if not memory_efficient:
            # To limit numerical errors from large vector elements outside the mask, we zero these out.
            result = torch.nn.functional.softmax(vector * mask, dim=dim)
            result = result * mask
            result = result / (
                result.sum(dim=dim, keepdim=True) + tiny_value_of_dtype(result.dtype)
            )
        else:
            masked_vector = vector.masked_fill(~mask, min_value_of_dtype(vector.dtype))
            result = torch.nn.functional.softmax(masked_vector, dim=dim)
    return result


def weighted_sum(matrix: torch.Tensor, attention: torch.Tensor) -> torch.Tensor:
    """
    Takes a matrix of vectors and a set of weights over the rows in the matrix (which we call an
    "attention" vector), and returns a weighted sum of the rows in the matrix. This is the typical
    computation performed after an attention mechanism.
    Note that while we call this a "matrix" of vectors and an attention "vector", we also handle
    higher-order tensors. We always sum over the second-to-last dimension of the "matrix", and we
    assume that all dimensions in the "matrix" prior to the last dimension are matched in the
    "vector". Non-matched dimensions in the "vector" must be `directly after the batch dimension`.
    For example, say I have a "matrix" with dimensions `(batch_size, num_queries, num_words,
    embedding_dim)`. The attention "vector" then must have at least those dimensions, and could
    have more. Both:
    - `(batch_size, num_queries, num_words)` (distribution over words for each query)
    - `(batch_size, num_documents, num_queries, num_words)` (distribution over words in a
      query for each document)
    are valid input "vectors", producing tensors of shape:
    `(batch_size, num_queries, embedding_dim)` and
    `(batch_size, num_documents, num_queries, embedding_dim)` respectively.
    """
    # We'll special-case a few settings here, where there are efficient (but poorly-named)
    # operations in pytorch that already do the computation we need.
    if attention.dim() == 2 and matrix.dim() == 3:
        return attention.unsqueeze(1).bmm(matrix).squeeze(1)
    if attention.dim() == 3 and matrix.dim() == 3:
        return attention.bmm(matrix)
    if matrix.dim() - 1 < attention.dim():
        expanded_size = list(matrix.size())
        for i in range(attention.dim() - matrix.dim() + 1):
            matrix = matrix.unsqueeze(1)
            expanded_size.insert(i + 1, attention.size(i + 1))
        matrix = matrix.expand(*expanded_size)
    intermediate = attention.unsqueeze(-1).expand_as(matrix) * matrix
    return intermediate.sum(dim=-2)
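

# Minimal sketch of the usual attention readout built from the two helpers
# above; all shapes and scores here are made up for illustration.
# masked_softmax turns raw scores into a distribution over the unmasked rows,
# and weighted_sum pools the row vectors with that distribution.
def _example_attention_readout():
    matrix = torch.randn(2, 4, 8)                           # (batch, num_rows, dim)
    scores = torch.randn(2, 4)                              # raw attention scores
    mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]]).bool()
    weights = masked_softmax(scores, mask, dim=-1)          # (batch, num_rows)
    return weighted_sum(matrix, weights)                    # -> (batch, dim)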


###########################################
# wandb plots

def mrr_vs_percent_overlap(correct_gene_rank, percent_overlap_train):
    df = pd.DataFrame({"Percent of Phenotypes Found in Single Train Patient": percent_overlap_train.squeeze(),
                       "Rank of Correct Gene": correct_gene_rank})
    fig = px.scatter(df, x="Percent of Phenotypes Found in Single Train Patient", y="Rank of Correct Gene")
    return fig


def plot_softmax(softmax):
    softmax = [s.detach().item() for s in softmax]
    df = pd.DataFrame({"softmax": softmax})
    fig = px.histogram(df, x="softmax")
    return fig


def fit_umap(embed, labels={}, n_neighbors=15, min_dist=0.1, n_components=2, metric="euclidean", random_state=3):
    embed = embed.detach().cpu()
    mapping = umap.UMAP(n_neighbors=n_neighbors, min_dist=min_dist, n_components=n_components,
                        metric=metric, random_state=random_state).fit(embed)
    embedding = mapping.transform(embed)

    data = {"x": embedding[:, 0], "y": embedding[:, 1]}
    if len(labels) > 0:
        data.update(labels)
    df = pd.DataFrame(data)
    if len(labels) > 0:
        fig = px.scatter(df, x="x", y="y", color="Node Type", hover_data=list(labels.keys()))
    else:
        fig = px.scatter(df, x="x", y="y")
    return fig
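

# Usage note: when `labels` is non-empty, fit_umap colors points by a
# "Node Type" column, so the labels dict must include that key. The embedding
# and labels in this sketch are made up for illustration.
def _example_fit_umap():
    embed = torch.randn(100, 32)
    labels = {"Node Type": ["gene"] * 50 + ["phenotype"] * 50}
    return fit_umap(embed, labels=labels)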


def plot_degree_vs_attention(attn_weights, phenotype_names, single_patient=False):
    if single_patient:
        phenotype_names = phenotype_names[0]
        attn_weights = attn_weights[0]
        data = [(w.item(), p_name[1]) for w, p_name in zip(attn_weights, phenotype_names)]
    else:
        data = [(attn_weights[i], phenotype_names[i]) for i in range(len(phenotype_names))]
        data = [(w.item(), p_name[1]) for attn_w, phen_name in data for w, p_name in zip(attn_w, phen_name)]
    attn_weights, degrees = zip(*data)
    df = pd.DataFrame({"Node Degree": degrees, "Attention Weight": attn_weights})
    fig = px.scatter(df, x="Node Degree", y="Attention Weight")
    return fig


def plot_nhops_to_gene_vs_attention(attn_weights, phenotype_names, nhops_g_p, single_patient=False):
    if single_patient:
        attn_weights = attn_weights[0]
        nhops_g_p = nhops_g_p[0]
        phenotype_names = phenotype_names[0]
        data = [(w.item(), hops) for w, hops in zip(attn_weights, nhops_g_p)]
    else:
        data = [(attn_weights[i], nhops_g_p[i]) for i in range(len(phenotype_names))]
        data = [(w.item(), hop) for attn_w, nhops in data for w, hop in zip(attn_w, nhops)]
    attn_weights, n_hops_g_p = zip(*data)
    df = pd.DataFrame({"Number of Hops to Gene in KG": n_hops_g_p, "Attention Weight": attn_weights})
    fig = px.scatter(df, x="Number of Hops to Gene in KG", y="Attention Weight")
    return fig


def plot_gene_rank_vs_x_intrain(corr_gene_ranks, in_train):
    if sum(in_train == 1) == 0:
        values_in_train = -1
        err_in_train = 0
    else:
        values_in_train = torch.mean(corr_gene_ranks[in_train == 1].float())
        err_in_train = torch.std(corr_gene_ranks[in_train == 1].float())
    if sum(in_train == 0) == 0:
        values_not_in_train = -1
        err_not_in_train = 0
    else:
        values_not_in_train = torch.mean(corr_gene_ranks[in_train == 0].float())
        err_not_in_train = torch.std(corr_gene_ranks[in_train == 0].float())
    df = pd.DataFrame({"Average Rank of Correct Gene": [values_in_train, values_not_in_train],
                       "In Train or Not": ["True", "False"],
                       "Error Bars": [err_in_train, err_not_in_train]})
    fig = px.bar(df, x="In Train or Not", y="Average Rank of Correct Gene", error_y="Error Bars")
    return fig


def plot_gene_rank_vs_numtrain(corr_gene_ranks, correct_gene_nid, train_corr_gene_nid):
    gene_counts = [train_corr_gene_nid[g] if g in train_corr_gene_nid else 0 for g in list(correct_gene_nid.numpy())]
    data = {"Rank of Correct Gene": corr_gene_ranks.cpu(), "Number of Times Seen": gene_counts, "Gene ID": correct_gene_nid}
    df = pd.DataFrame(data)
    fig = px.scatter(df, x="Number of Times Seen", y="Rank of Correct Gene", hover_data=list(data.keys()))
    return fig, gene_counts


def plot_gene_rank_vs_trainset(corr_gene_ranks, correct_gene_nid, gene_counts):
    # gene_counts has dimension num_genes x num_sets (corr, cand, sparse, target)
    trainset_labels = ["-".join([str(int(l)) for l in list((gene_counts[i, :] > 0).numpy())]) for i in range(gene_counts.shape[0])]
    gene_ranks_dict = {}
    for l, r in zip(trainset_labels, corr_gene_ranks):  # (corr, cand, sparse, target)
        l_full = []
        if l.split("-")[0] == "1": l_full.append("Corr")
        if l.split("-")[1] == "1": l_full.append("Cand")
        if l.split("-")[2] == "1": l_full.append("Sparse")
        if l.split("-")[3] == "1": l_full.append("Target")
        if len(l_full) == 0:
            l_full = "None"
        else:
            l_full = "-".join(l_full)
        if l_full not in gene_ranks_dict:
            gene_ranks_dict[l_full] = []
        gene_ranks_dict[l_full].append(int(r))
    avg_gene_ranks = {l: np.mean(r) for l, r in gene_ranks_dict.items()}
    df = pd.DataFrame({"Train Set": list(avg_gene_ranks.keys()), "Average Rank of Correct Gene": list(avg_gene_ranks.values())})
    fig = px.bar(df, x="Train Set", y="Average Rank of Correct Gene")
    return fig


def plot_gene_rank_vs_fraction_phenotype(corr_gene_ranks, frac_p):
    df = pd.DataFrame({"Rank of Correct Gene": corr_gene_ranks, "Fraction of Phenotypes": frac_p})
    df = df[df["Fraction of Phenotypes"] > -1]
    fig = px.scatter(df, x="Fraction of Phenotypes", y="Rank of Correct Gene")
    return fig


def plot_gene_rank_vs_hops(corr_gene_ranks, n_hops):
    mean_hops = []
    min_hops = []
    for hops in n_hops:
        if isinstance(hops, list):  # gene to phenotype n_hops
            mean_hops.append(np.mean(hops))
            min_hops.append(np.min(hops))
        else:  # phenotype to phenotype n_hops
            filtered_hops = torch.cat([hops[i][hops[i] > 0] for i in range(hops.shape[0])]).float()
            mean_hops.append(torch.mean(filtered_hops).item())
            min_hops.append(torch.min(filtered_hops).item())

    df = pd.DataFrame({"Rank of Correct Gene": corr_gene_ranks, "Mean Number of Hops": mean_hops})
    fig_mean = px.scatter(df, x="Mean Number of Hops", y="Rank of Correct Gene")
    df = pd.DataFrame({"Rank of Correct Gene": corr_gene_ranks, "Min Number of Hops": min_hops})
    fig_min = px.scatter(df, x="Min Number of Hops", y="Rank of Correct Gene")
    return fig_mean, fig_min