--- a
+++ b/shepherd/utils/train_utils.py
@@ -0,0 +1,321 @@
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn.parameter import Parameter
+
+from allennlp.modules.attention.attention import Attention
+from allennlp.nn import Activation
+
+import umap
+import pandas as pd
+
+# Matplotlib
+from matplotlib import pyplot as plt
+from matplotlib.backends.backend_pdf import PdfPages
+import plotly.express as px
+
+####################################
+# Evaluation utils
+
+def mean_reciprocal_rank(correct_gene_ranks):
+    return torch.mean(1/correct_gene_ranks)
+
+def average_rank(correct_gene_ranks):
+    return torch.mean(correct_gene_ranks)
+
+def top_k_acc(correct_gene_ranks, k):
+    return torch.sum(correct_gene_ranks <= k) / len(correct_gene_ranks)
+
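+# A minimal usage sketch of the metrics above (illustrative values only):
+#   ranks = torch.tensor([1., 3., 10.])
+#   mean_reciprocal_rank(ranks)  # (1/1 + 1/3 + 1/10) / 3 ≈ 0.478
+#   top_k_acc(ranks, k=5)        # 2 of the 3 ranks are <= 5 -> ≈ 0.667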
+
+###########################################
+
+# below functions from AllenNLP
+
+def masked_mean(
+    vector: torch.Tensor, mask: torch.BoolTensor, dim: int, keepdim: bool = False
+    ) -> torch.Tensor:
+    """
+    Calculates the mean along the given dimension, considering only the positions where `mask` is true.
+    # Parameters
+    vector : `torch.Tensor`
+        The vector to calculate the mean over.
+    mask : `torch.BoolTensor`
+        The mask of the vector. It must be broadcastable with `vector`.
+    dim : `int`
+        The dimension along which to calculate the mean.
+    keepdim : `bool`
+        Whether to keep the reduced dimension.
+    # Returns
+    `torch.Tensor`
+        A `torch.Tensor` containing the mean values.
+    """
+    replaced_vector = vector.masked_fill(~mask, 0.0)
+
+    value_sum = torch.sum(replaced_vector, dim=dim, keepdim=keepdim)
+    value_count = torch.sum(mask, dim=dim, keepdim=keepdim)
+    return value_sum / value_count.float().clamp(min=tiny_value_of_dtype(torch.float))
+
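+# Illustrative sketch of masked_mean (values chosen purely for exposition):
+#   v = torch.tensor([[1.0, 2.0, 3.0]])
+#   m = torch.tensor([[True, True, False]])
+#   masked_mean(v, m, dim=-1)  # -> tensor([1.5]); the masked-out 3.0 is ignored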
+
+def tiny_value_of_dtype(dtype: torch.dtype):
+    """
+    Returns a moderately tiny value for a given PyTorch data type that is used to avoid numerical
+    issues such as division by zero.
+    This is different from `info_value_of_dtype(dtype).tiny` because that value can cause NaN bugs.
+    Only supports floating point dtypes.
+    """
+    if not dtype.is_floating_point:
+        raise TypeError("Only supports floating point dtypes.")
+    if dtype == torch.float or dtype == torch.double:
+        return 1e-13
+    elif dtype == torch.half:
+        return 1e-4
+    else:
+        raise TypeError("Does not support dtype " + str(dtype))
+
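+
+# `masked_softmax` below calls `min_value_of_dtype`, which is neither defined nor
+# imported in this file; the two helpers below are the corresponding AllenNLP
+# utilities, reproduced here so that the module is self-contained.
+def info_value_of_dtype(dtype: torch.dtype):
+    """
+    Returns the `finfo` or `iinfo` object of a given PyTorch data type. Does not allow torch.bool.
+    """
+    if dtype == torch.bool:
+        raise TypeError("Does not support torch.bool")
+    elif dtype.is_floating_point:
+        return torch.finfo(dtype)
+    else:
+        return torch.iinfo(dtype)
+
+
+def min_value_of_dtype(dtype: torch.dtype):
+    """
+    Returns the minimum value of a given PyTorch data type. Does not allow torch.bool.
+    """
+    return info_value_of_dtype(dtype).min
+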
+
+def masked_softmax(
+    vector: torch.Tensor,
+    mask: torch.BoolTensor,
+    dim: int = -1,
+    memory_efficient: bool = False,
+    ) -> torch.Tensor:
+    """
+    `torch.nn.functional.softmax(vector)` does not work if some elements of `vector` should be
+    masked.  This performs a softmax on just the non-masked portions of `vector`.  Passing
+    `None` in for the mask is also acceptable; you'll just get a regular softmax.
+    `vector` can have an arbitrary number of dimensions; the only requirement is that `mask` is
+    broadcastable to `vector's` shape.  If `mask` has fewer dimensions than `vector`, we will
+    unsqueeze on dimension 1 until they match.  If you need a different unsqueezing of your mask,
+    do it yourself before passing the mask into this function.
+    If `memory_efficient` is set to true, we will simply use a very large negative number for those
+    masked positions so that the probabilities of those positions would be approximately 0.
+    This is not accurate in math, but works for most cases and consumes less memory.
+    In the case that the input vector is completely masked and `memory_efficient` is false, this function
+    returns an array of `0.0`. This behavior may cause `NaN` if this is used as the last layer of
+    a model that uses categorical cross-entropy loss. Instead, if `memory_efficient` is true, this function
+    will treat every element as equal, and do softmax over equal numbers.
+    """
+    if mask is None:
+        result = torch.nn.functional.softmax(vector, dim=dim)
+    else:
+        while mask.dim() < vector.dim():
+            mask = mask.unsqueeze(1)
+        if not memory_efficient:
+            # To limit numerical errors from large vector elements outside the mask, we zero these out.
+            result = torch.nn.functional.softmax(vector * mask, dim=dim)
+            result = result * mask
+            result = result / (
+                result.sum(dim=dim, keepdim=True) + tiny_value_of_dtype(result.dtype)
+            )
+        else:
+            masked_vector = vector.masked_fill(~mask, min_value_of_dtype(vector.dtype))
+            result = torch.nn.functional.softmax(masked_vector, dim=dim)
+    return result
+
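+# Illustrative sketch of masked_softmax (example values only):
+#   scores = torch.tensor([[2.0, 1.0, 3.0]])
+#   mask = torch.tensor([[True, True, False]])
+#   masked_softmax(scores, mask)  # -> approximately tensor([[0.731, 0.269, 0.000]])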
+
+def weighted_sum(matrix: torch.Tensor, attention: torch.Tensor) -> torch.Tensor:
+    """
+    Takes a matrix of vectors and a set of weights over the rows in the matrix (which we call an
+    "attention" vector), and returns a weighted sum of the rows in the matrix.  This is the typical
+    computation performed after an attention mechanism.
+    Note that while we call this a "matrix" of vectors and an attention "vector", we also handle
+    higher-order tensors.  We always sum over the second-to-last dimension of the "matrix", and we
+    assume that all dimensions in the "matrix" prior to the last dimension are matched in the
+    "vector".  Non-matched dimensions in the "vector" must be `directly after the batch dimension`.
+    For example, say I have a "matrix" with dimensions `(batch_size, num_queries, num_words,
+    embedding_dim)`.  The attention "vector" then must have at least those dimensions, and could
+    have more. Both:
+        - `(batch_size, num_queries, num_words)` (distribution over words for each query)
+        - `(batch_size, num_documents, num_queries, num_words)` (distribution over words in a
+          query for each document)
+    are valid input "vectors", producing tensors of shape:
+    `(batch_size, num_queries, embedding_dim)` and
+    `(batch_size, num_documents, num_queries, embedding_dim)` respectively.
+    """
+    # We'll special-case a few settings here, where there are efficient (but poorly-named)
+    # operations in pytorch that already do the computation we need.
+    if attention.dim() == 2 and matrix.dim() == 3:
+        return attention.unsqueeze(1).bmm(matrix).squeeze(1)
+    if attention.dim() == 3 and matrix.dim() == 3:
+        return attention.bmm(matrix)
+    if matrix.dim() - 1 < attention.dim():
+        expanded_size = list(matrix.size())
+        for i in range(attention.dim() - matrix.dim() + 1):
+            matrix = matrix.unsqueeze(1)
+            expanded_size.insert(i + 1, attention.size(i + 1))
+        matrix = matrix.expand(*expanded_size)
+    intermediate = attention.unsqueeze(-1).expand_as(matrix) * matrix
+    return intermediate.sum(dim=-2)
+
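+# Illustrative sketch of weighted_sum (example shapes only):
+#   matrix = torch.randn(8, 5, 64)                        # (batch_size, num_words, embedding_dim)
+#   attention = torch.softmax(torch.randn(8, 5), dim=-1)  # (batch_size, num_words)
+#   weighted_sum(matrix, attention).shape                 # -> torch.Size([8, 64])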
+
+
+###########################################
+# wandb plots
+
+def mrr_vs_percent_overlap(correct_gene_rank, percent_overlap_train):
+    df = pd.DataFrame({"Percent of Phenotypes Found in Single Train Patient": percent_overlap_train.squeeze(), "Rank of Correct Gene": correct_gene_rank})
+    fig = px.scatter(df, x = "Percent of Phenotypes Found in Single Train Patient", y = "Rank of Correct Gene")
+    return fig
+
+def plot_softmax(softmax):
+    softmax = [s.detach().item() for s in softmax]
+    df = pd.DataFrame({'softmax':softmax})
+    fig = px.histogram(df, x="softmax")
+    return fig
+
+def fit_umap(embed, labels=None, n_neighbors=15, min_dist=0.1, n_components=2, metric='euclidean', random_state=3):
+    labels = labels or {}  # avoid a mutable default argument
+    embed = embed.detach().cpu().numpy()
+    mapping = umap.UMAP(n_neighbors=n_neighbors, min_dist=min_dist, n_components=n_components, metric=metric, random_state=random_state).fit(embed)
+    embedding = mapping.transform(embed)
+
+    data = {"x": embedding[:, 0], "y": embedding[:, 1]}
+    if len(labels) > 0: data.update(labels)
+    df = pd.DataFrame(data)
+    if len(labels) > 0: fig = px.scatter(df, x = "x", y = "y", color = "Node Type", hover_data=list(labels.keys()))
+    else: fig = px.scatter(df, x = "x", y = "y")
+    return fig
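+
+# Illustrative call of fit_umap (hypothetical tensors/labels; note that point coloring
+# assumes a "Node Type" key in `labels`):
+#   fig = fit_umap(node_embeddings, labels={"Node Type": node_types, "Name": node_names})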
+
+def plot_degree_vs_attention(attn_weights, phenotype_names, single_patient=False):
+    if single_patient:
+        phenotype_names = phenotype_names[0]
+        attn_weights = attn_weights[0]
+        # each phenotype entry appears to be a (name, degree) pair; index 1 is the node degree
+        data = [(w.item(), p_name[1]) for w, p_name in zip(attn_weights, phenotype_names)]
+    else:
+        # flatten across patients, pairing each attention weight with its phenotype's node degree
+        data = [(attn_weights[i], phenotype_names[i]) for i in range(len(phenotype_names))]
+        data = [(w.item(), p_name[1]) for attn_w, phen_name in data for w, p_name in zip(attn_w, phen_name)]
+    attn_weights, degrees = zip(*data)
+    df = pd.DataFrame({"Node Degree": degrees, "Attention Weight": attn_weights})
+    fig = px.scatter(df, x = "Node Degree", y = "Attention Weight")
+    return fig
+
+def plot_nhops_to_gene_vs_attention(attn_weights, phenotype_names, nhops_g_p, single_patient=False):
+    if single_patient:
+        attn_weights = attn_weights[0]
+        nhops_g_p = nhops_g_p[0]
+        phenotype_names = phenotype_names[0]
+        data = [(w.item(), hops) for w, hops in zip(attn_weights, nhops_g_p)]
+    else:
+        data = [(attn_weights[i], nhops_g_p[i]) for i in range(len(phenotype_names))]
+        data = [(w.item(), hop) for attn_w, nhops in data for w, hop in zip(attn_w, nhops)] 
+    attn_weights, n_hops_g_p = zip(*data)
+    df = pd.DataFrame({"Number of Hops to Gene in KG": n_hops_g_p, "Attention Weight": attn_weights})
+    fig = px.scatter(df, x = "Number of Hops to Gene in KG", y = "Attention Weight")
+    return fig
+
+def plot_gene_rank_vs_x_intrain(corr_gene_ranks, in_train):
+    # use .item() so the DataFrame holds plain floats rather than 0-d tensors
+    if sum(in_train == 1) == 0:
+        values_in_train = -1
+        err_in_train = 0
+    else:
+        values_in_train = torch.mean(corr_gene_ranks[in_train == 1].float()).item()
+        err_in_train = torch.std(corr_gene_ranks[in_train == 1].float()).item()
+    if sum(in_train == 0) == 0:
+        values_not_in_train = -1
+        err_not_in_train = 0
+    else:
+        values_not_in_train = torch.mean(corr_gene_ranks[in_train == 0].float()).item()
+        err_not_in_train = torch.std(corr_gene_ranks[in_train == 0].float()).item()
+    df = pd.DataFrame({"Average Rank of Correct Gene": [values_in_train, values_not_in_train], "In Train or Not": ["True", "False"], "Error Bars": [err_in_train, err_not_in_train]})
+    fig = px.bar(df, x = "In Train or Not", y = "Average Rank of Correct Gene", error_y = "Error Bars")
+    return fig
+
+def plot_gene_rank_vs_numtrain(corr_gene_ranks, correct_gene_nid, train_corr_gene_nid):
+    gene_counts = [train_corr_gene_nid[g] if g in train_corr_gene_nid else 0 for g in list(correct_gene_nid.numpy())]
+    data = {"Rank of Correct Gene": corr_gene_ranks.cpu(), "Number of Times Seen": gene_counts, "Gene ID": correct_gene_nid}
+    df = pd.DataFrame(data)
+    fig = px.scatter(df, x = "Number of Times Seen", y = "Rank of Correct Gene", hover_data=list(data.keys()))
+    return fig, gene_counts
+
+
+def plot_gene_rank_vs_trainset(corr_gene_ranks, correct_gene_nid, gene_counts): # gene_counts has dimension num_genes x num_sets (corr, cand, sparse, target)
+    trainset_labels = ["-".join([str(int(l)) for l in list((gene_counts[i, :] > 0).numpy())]) for i in range(gene_counts.shape[0])]
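+    # e.g. a gene present in the correct and candidate train sets but not the sparse or
+    # target sets gets the label "1-1-0-0", which is rendered below as "Corr-Cand"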
+    gene_ranks_dict = {}
+    for l, r in zip(trainset_labels, corr_gene_ranks): # (corr, cand, sparse, target)
+        l_full = []
+        if l.split("-")[0] == "1": l_full.append("Corr")
+        if l.split("-")[1] == "1": l_full.append("Cand")
+        if l.split("-")[2] == "1": l_full.append("Sparse")
+        if l.split("-")[3] == "1": l_full.append("Target")
+        if len(l_full) == 0: l_full = "None"
+        else: l_full = "-".join(l_full)
+        if l_full not in gene_ranks_dict: gene_ranks_dict[l_full] = []
+        gene_ranks_dict[l_full].append(int(r))
+    avg_gene_ranks = {l: np.mean(r) for l, r in gene_ranks_dict.items()}
+    df = pd.DataFrame({"Train Set": list(avg_gene_ranks.keys()), "Average Rank of Correct Gene": list(avg_gene_ranks.values())})
+    fig = px.bar(df, x = "Train Set", y = "Average Rank of Correct Gene")
+    return fig
+
+
+def plot_gene_rank_vs_fraction_phenotype(corr_gene_ranks, frac_p):
+    df = pd.DataFrame({"Rank of Correct Gene": corr_gene_ranks, "Fraction of Phenotypes": frac_p})
+    df = df[df["Fraction of Phenotypes"] > -1]
+    fig = px.scatter(df, x = "Fraction of Phenotypes", y = "Rank of Correct Gene")
+    return fig
+
+
+def plot_gene_rank_vs_hops(corr_gene_ranks, n_hops):
+    mean_hops = []
+    min_hops = []
+    for hops in n_hops:
+        if isinstance(hops, list): # gene to phenotype n_hops
+            mean_hops.append(np.mean(hops))
+            min_hops.append(np.min(hops))
+        else: # phenotype to phenotype n_hops
+            filtered_hops = torch.cat([hops[i][hops[i] > 0] for i in range(hops.shape[0])]).float()
+            mean_hops.append(torch.mean(filtered_hops).item())
+            min_hops.append(torch.min(filtered_hops).item())
+        
+    df = pd.DataFrame({"Rank of Correct Gene": corr_gene_ranks, "Mean Number of Hops": mean_hops})
+    fig_mean = px.scatter(df, x = "Mean Number of Hops", y = "Rank of Correct Gene")
+    df = pd.DataFrame({"Rank of Correct Gene": corr_gene_ranks, "Min Number of Hops": min_hops})
+    fig_min = px.scatter(df, x = "Min Number of Hops", y = "Rank of Correct Gene")
+    return fig_mean, fig_min