Diff of /mowgli/pl.py [000000] .. [061d85]

Switch to unified view

a b/mowgli/pl.py
1
import anndata as ad
2
import mudata as md
3
import numpy as np
4
import pandas as pd
5
import scanpy as sc
6
import seaborn as sns
7
from matplotlib import pyplot as plt
8
9
10
def clustermap(mdata: md.MuData, obsm: str = "W_OT", cmap="viridis", **kwds):
11
    """Wrapper around Scanpy's clustermap.
12
13
    Args:
14
        mdata (md.MuData): The input data
15
        obsm (str, optional): The obsm field to consider. Defaults to 'W_OT'.
16
        cmap (str, optional): The colormap. Defaults to 'viridis'.
17
    """
18
19
    # Create an AnnData with the joint embedding.
20
    joint_embedding = ad.AnnData(mdata.obsm[obsm], obs=mdata.obs)
21
22
    # Make the clustermap plot.
23
    sc.pl.clustermap(joint_embedding, cmap=cmap, **kwds)
24
25
26
def factor_violin(
27
    mdata: md.MuData,
28
    groupby: str,
29
    obsm: str = "W_OT",
30
    dim: int = 0,
31
    **kwds,
32
):
33
    """Make a violin plot of cells for a given latent dimension.
34
35
    Args:
36
        mdata (md.MuData): The input data
37
        dim (int, optional): The latent dimension. Defaults to 0.
38
        obsm (str, optional): The embedding. Defaults to 'W_OT'.
39
        groupby (str, optional): Observation groups.
40
    """
41
42
    # Create an AnnData with the joint embedding.
43
    joint_embedding = ad.AnnData(mdata.obsm[obsm], obs=mdata.obs)
44
45
    # Add the obs field that we're interested in.
46
    joint_embedding.obs["Factor " + str(dim)] = joint_embedding.X[:, dim]
47
48
    # Make the violin plot.
49
    sc.pl.violin(joint_embedding, keys="Factor " + str(dim), groupby=groupby, **kwds)
50
51
52
def heatmap(
53
    mdata: md.MuData,
54
    groupby: str,
55
    obsm: str = "W_OT",
56
    cmap: str = "viridis",
57
    sort_var: bool = False,
58
    save: str = None,
59
    **kwds,
60
) -> None:
61
    """Produce a heatmap of an embedding
62
63
    Args:
64
        mdata (md.MuData): Input data
65
        groupby (str): What to group by
66
        obsm (str): The embedding. Defaults to 'W_OT'.
67
        cmap (str, optional): Color map. Defaults to 'viridis'.
68
        sort_var (bool, optional):
69
            Sort dimensions by variance. Defaults to False.
70
    """
71
72
    # Create an AnnData with the joint embedding.
73
    joint_embedding = ad.AnnData(mdata.obsm[obsm], obs=mdata.obs)
74
75
    # Try to compute a dendrogram.
76
    try:
77
        sc.pp.pca(joint_embedding)
78
        sc.tl.dendrogram(joint_embedding, groupby=groupby, use_rep="X_pca")
79
    except Exception:
80
        print("Dendrogram not computed.")
81
        pass
82
83
    # Get the dimension names to show.
84
    if sort_var:
85
        idx = joint_embedding.X.std(0).argsort()[::-1]
86
        var_names = joint_embedding.var_names[idx]
87
    else:
88
        var_names = joint_embedding.var_names
89
90
    # PLot the heatmap.
91
    return sc.pl.heatmap(
92
        joint_embedding, var_names, groupby=groupby, cmap=cmap, save=save, **kwds
93
    )
94
95
96
def enrich(enr: pd.DataFrame, query_name: str, n_terms: int = 10):
97
    """Display a list of enriched terms.
98
99
    Args:
100
        enr (pd.DataFrame): The enrichment object returned by mowgli.tl.enrich()
101
        query_name (str): The name of the query, e.g. "dimension 0".
102
    """
103
104
    # Subset the enrichment object to the query of interest.
105
    sub_enr = enr[enr["query"] == query_name].head(n_terms)
106
    sub_enr["minlogp"] = -np.log10(sub_enr["p_value"])
107
108
    fig, ax = plt.subplots()
109
110
    # Display the enriched terms.
111
    ax.hlines(
112
        y=sub_enr["name"],
113
        xmin=0,
114
        xmax=sub_enr["minlogp"],
115
        color="lightgray",
116
        zorder=1,
117
        alpha=0.8,
118
    )
119
    sns.scatterplot(
120
        data=sub_enr,
121
        x="minlogp",
122
        y="name",
123
        hue="source",
124
        s=100,
125
        alpha=0.8,
126
        ax=ax,
127
        zorder=3,
128
    )
129
130
    ax.set_xlabel("$-log_{10}(p)$")
131
    ax.set_ylabel("Enriched terms")
132
133
    plt.show()
134
135
136
def top_features(
137
    mdata: md.MuData,
138
    mod: str = "rna",
139
    uns: str = "H_OT",
140
    dim: int = 0,
141
    n_top: int = 10,
142
    ax: plt.axes = None,
143
    palette: str = "Blues_r",
144
):
145
    """Display the top features for a given dimension.
146
147
    Args:
148
        mdata (md.MuData): The input mdata object
149
        mod (str, optional): The modality to consider. Defaults to 'rna'.
150
        uns (str, optional): The uns field to consider. Defaults to 'H_OT'.
151
        dim (int, optional): The latent dimension. Defaults to 0.
152
        n_top (int, optional): The number of top features to display. Defaults to 10.
153
        ax (plt.axes, optional): The axes to use. Defaults to None.
154
        palette (str, optional): The color palette to use. Defaults to 'Blues_r'.
155
156
    Returns:
157
        plt.axes: The axes used.
158
    """
159
160
    # Get the variable names.
161
    var_names = mdata[mod].var_names[mdata[mod].var.highly_variable]
162
163
    # Get the top features.
164
    idx_top_features = np.argsort(mdata[mod].uns[uns][:, dim])[::-1][:n_top]
165
    df = pd.DataFrame(
166
        {
167
            "features": var_names[idx_top_features],
168
            "weights": mdata[mod].uns[uns][idx_top_features, dim],
169
        }
170
    )
171
172
    # Display the top features.
173
    if ax is None:
174
        ax = sns.barplot(data=df, x="weights", y="features", palette=palette)
175
    else:
176
        sns.barplot(data=df, x="weights", y="features", palette=palette, ax=ax)
177
178
    return ax
179
180
181
def umap(
182
    mdata: md.MuData,
183
    dim: int | list = 0,
184
    rescale: bool = False,
185
    obsm: str = "W_OT",
186
    neighbours_key=None,
187
    **kwds,
188
):
189
    """Wrapper around Scanpy's sc.pl.umap. Computes UMAP for a given latent dimension and plots it.
190
    Args:
191
        mdata (md.MuData): The input data
192
        dim (int | list, optional): The latent dimension. Defaults to 0.
193
        rescale (bool, optional): If True, Rescale the color palette across all plots to the maximum value in the weight matrix. Defaults to False.
194
        obsm (str, optional): The embedding. Defaults to 'W_OT'.
195
        neighbours_key (str, optional): The key for the neighbours in `mdata.uns` to use to compute neighbors. Defaults to None.
196
    """
197
198
    adata_tmp = ad.AnnData(mdata.obsm[obsm], obs=pd.DataFrame(index=mdata.obs.index))
199
200
    if isinstance(dim, int):
201
        mowgli_cat = f"mowgli:{dim}"
202
203
    elif isinstance(dim, list):
204
        # clean dim of doubles and sort them
205
        dim = sorted(list(set(dim)))
206
        mowgli_cat = [f"mowgli:{x}" for x in dim]
207
208
    else:
209
        raise ValueError("dim must be an integer or a list of integers")
210
211
    adata_tmp.obs[mowgli_cat] = adata_tmp.X[:, dim]
212
213
    # check if neighbors exists
214
    if neighbours_key is None:
215
        print("Computing neighbors with scanpy default parameters")
216
        neighbours_key = "mowgli_neighbors"  # set the default neighbors key
217
        # compute neiughborts using all dimension in the mowgli matrix
218
        sc.pp.neighbors(adata_tmp, use_rep="X", key_added=neighbours_key)
219
220
    else:
221
        if neighbours_key not in mdata.uns.keys():
222
            raise ValueError(f"neighbours key {neighbours_key} not found in mdata.uns")
223
224
        adata_tmp.uns[neighbours_key] = mdata.uns[neighbours_key]
225
226
    # compute umap
227
    print("Computing UMAP")
228
    sc.tl.umap(adata_tmp, neighbors_key=neighbours_key)
229
230
    # plot umap
231
    if rescale:
232
        vmax = adata_tmp.X.max()
233
        sc.pl.umap(adata_tmp, color=mowgli_cat, size=18.5, alpha=0.4, vmax=vmax, **kwds)
234
    else:
235
        sc.pl.umap(adata_tmp, color=mowgli_cat, size=18.5, alpha=0.4, **kwds)