--- a +++ b/mowgli/pl.py @@ -0,0 +1,235 @@ +import anndata as ad +import mudata as md +import numpy as np +import pandas as pd +import scanpy as sc +import seaborn as sns +from matplotlib import pyplot as plt + + +def clustermap(mdata: md.MuData, obsm: str = "W_OT", cmap="viridis", **kwds): + """Wrapper around Scanpy's clustermap. + + Args: + mdata (md.MuData): The input data + obsm (str, optional): The obsm field to consider. Defaults to 'W_OT'. + cmap (str, optional): The colormap. Defaults to 'viridis'. + """ + + # Create an AnnData with the joint embedding. + joint_embedding = ad.AnnData(mdata.obsm[obsm], obs=mdata.obs) + + # Make the clustermap plot. + sc.pl.clustermap(joint_embedding, cmap=cmap, **kwds) + + +def factor_violin( + mdata: md.MuData, + groupby: str, + obsm: str = "W_OT", + dim: int = 0, + **kwds, +): + """Make a violin plot of cells for a given latent dimension. + + Args: + mdata (md.MuData): The input data + dim (int, optional): The latent dimension. Defaults to 0. + obsm (str, optional): The embedding. Defaults to 'W_OT'. + groupby (str, optional): Observation groups. + """ + + # Create an AnnData with the joint embedding. + joint_embedding = ad.AnnData(mdata.obsm[obsm], obs=mdata.obs) + + # Add the obs field that we're interested in. + joint_embedding.obs["Factor " + str(dim)] = joint_embedding.X[:, dim] + + # Make the violin plot. + sc.pl.violin(joint_embedding, keys="Factor " + str(dim), groupby=groupby, **kwds) + + +def heatmap( + mdata: md.MuData, + groupby: str, + obsm: str = "W_OT", + cmap: str = "viridis", + sort_var: bool = False, + save: str = None, + **kwds, +) -> None: + """Produce a heatmap of an embedding + + Args: + mdata (md.MuData): Input data + groupby (str): What to group by + obsm (str): The embedding. Defaults to 'W_OT'. + cmap (str, optional): Color map. Defaults to 'viridis'. + sort_var (bool, optional): + Sort dimensions by variance. Defaults to False. + """ + + # Create an AnnData with the joint embedding. + joint_embedding = ad.AnnData(mdata.obsm[obsm], obs=mdata.obs) + + # Try to compute a dendrogram. + try: + sc.pp.pca(joint_embedding) + sc.tl.dendrogram(joint_embedding, groupby=groupby, use_rep="X_pca") + except Exception: + print("Dendrogram not computed.") + pass + + # Get the dimension names to show. + if sort_var: + idx = joint_embedding.X.std(0).argsort()[::-1] + var_names = joint_embedding.var_names[idx] + else: + var_names = joint_embedding.var_names + + # PLot the heatmap. + return sc.pl.heatmap( + joint_embedding, var_names, groupby=groupby, cmap=cmap, save=save, **kwds + ) + + +def enrich(enr: pd.DataFrame, query_name: str, n_terms: int = 10): + """Display a list of enriched terms. + + Args: + enr (pd.DataFrame): The enrichment object returned by mowgli.tl.enrich() + query_name (str): The name of the query, e.g. "dimension 0". + """ + + # Subset the enrichment object to the query of interest. + sub_enr = enr[enr["query"] == query_name].head(n_terms) + sub_enr["minlogp"] = -np.log10(sub_enr["p_value"]) + + fig, ax = plt.subplots() + + # Display the enriched terms. + ax.hlines( + y=sub_enr["name"], + xmin=0, + xmax=sub_enr["minlogp"], + color="lightgray", + zorder=1, + alpha=0.8, + ) + sns.scatterplot( + data=sub_enr, + x="minlogp", + y="name", + hue="source", + s=100, + alpha=0.8, + ax=ax, + zorder=3, + ) + + ax.set_xlabel("$-log_{10}(p)$") + ax.set_ylabel("Enriched terms") + + plt.show() + + +def top_features( + mdata: md.MuData, + mod: str = "rna", + uns: str = "H_OT", + dim: int = 0, + n_top: int = 10, + ax: plt.axes = None, + palette: str = "Blues_r", +): + """Display the top features for a given dimension. + + Args: + mdata (md.MuData): The input mdata object + mod (str, optional): The modality to consider. Defaults to 'rna'. + uns (str, optional): The uns field to consider. Defaults to 'H_OT'. + dim (int, optional): The latent dimension. Defaults to 0. + n_top (int, optional): The number of top features to display. Defaults to 10. + ax (plt.axes, optional): The axes to use. Defaults to None. + palette (str, optional): The color palette to use. Defaults to 'Blues_r'. + + Returns: + plt.axes: The axes used. + """ + + # Get the variable names. + var_names = mdata[mod].var_names[mdata[mod].var.highly_variable] + + # Get the top features. + idx_top_features = np.argsort(mdata[mod].uns[uns][:, dim])[::-1][:n_top] + df = pd.DataFrame( + { + "features": var_names[idx_top_features], + "weights": mdata[mod].uns[uns][idx_top_features, dim], + } + ) + + # Display the top features. + if ax is None: + ax = sns.barplot(data=df, x="weights", y="features", palette=palette) + else: + sns.barplot(data=df, x="weights", y="features", palette=palette, ax=ax) + + return ax + + +def umap( + mdata: md.MuData, + dim: int | list = 0, + rescale: bool = False, + obsm: str = "W_OT", + neighbours_key=None, + **kwds, +): + """Wrapper around Scanpy's sc.pl.umap. Computes UMAP for a given latent dimension and plots it. + Args: + mdata (md.MuData): The input data + dim (int | list, optional): The latent dimension. Defaults to 0. + rescale (bool, optional): If True, Rescale the color palette across all plots to the maximum value in the weight matrix. Defaults to False. + obsm (str, optional): The embedding. Defaults to 'W_OT'. + neighbours_key (str, optional): The key for the neighbours in `mdata.uns` to use to compute neighbors. Defaults to None. + """ + + adata_tmp = ad.AnnData(mdata.obsm[obsm], obs=pd.DataFrame(index=mdata.obs.index)) + + if isinstance(dim, int): + mowgli_cat = f"mowgli:{dim}" + + elif isinstance(dim, list): + # clean dim of doubles and sort them + dim = sorted(list(set(dim))) + mowgli_cat = [f"mowgli:{x}" for x in dim] + + else: + raise ValueError("dim must be an integer or a list of integers") + + adata_tmp.obs[mowgli_cat] = adata_tmp.X[:, dim] + + # check if neighbors exists + if neighbours_key is None: + print("Computing neighbors with scanpy default parameters") + neighbours_key = "mowgli_neighbors" # set the default neighbors key + # compute neiughborts using all dimension in the mowgli matrix + sc.pp.neighbors(adata_tmp, use_rep="X", key_added=neighbours_key) + + else: + if neighbours_key not in mdata.uns.keys(): + raise ValueError(f"neighbours key {neighbours_key} not found in mdata.uns") + + adata_tmp.uns[neighbours_key] = mdata.uns[neighbours_key] + + # compute umap + print("Computing UMAP") + sc.tl.umap(adata_tmp, neighbors_key=neighbours_key) + + # plot umap + if rescale: + vmax = adata_tmp.X.max() + sc.pl.umap(adata_tmp, color=mowgli_cat, size=18.5, alpha=0.4, vmax=vmax, **kwds) + else: + sc.pl.umap(adata_tmp, color=mowgli_cat, size=18.5, alpha=0.4, **kwds)