|
a |
|
b/mowgli/pl.py |
|
|
1 |
import anndata as ad |
|
|
2 |
import mudata as md |
|
|
3 |
import numpy as np |
|
|
4 |
import pandas as pd |
|
|
5 |
import scanpy as sc |
|
|
6 |
import seaborn as sns |
|
|
7 |
from matplotlib import pyplot as plt |
|
|
8 |
|
|
|
9 |
|
|
|
10 |
def clustermap(mdata: md.MuData, obsm: str = "W_OT", cmap="viridis", **kwds):
    """Plot a clustered heatmap of a joint embedding via ``sc.pl.clustermap``.

    The embedding stored in ``mdata.obsm[obsm]`` is wrapped in a temporary
    AnnData (cells as observations, embedding dimensions as variables) that
    shares ``mdata.obs``, and handed to Scanpy.

    Args:
        mdata (md.MuData): The input data.
        obsm (str, optional): Key of the embedding in ``mdata.obsm``.
            Defaults to 'W_OT'.
        cmap (str, optional): Colormap name. Defaults to 'viridis'.
        **kwds: Forwarded unchanged to ``sc.pl.clustermap``.
    """
    embedding = mdata.obsm[obsm]
    adata_joint = ad.AnnData(embedding, obs=mdata.obs)

    sc.pl.clustermap(adata_joint, cmap=cmap, **kwds)
|
|
24 |
|
|
|
25 |
|
|
|
26 |
def factor_violin(
    mdata: md.MuData,
    groupby: str,
    obsm: str = "W_OT",
    dim: int = 0,
    **kwds,
):
    """Draw a violin plot of per-cell values along one latent dimension.

    The chosen column of ``mdata.obsm[obsm]`` is stored as an ``obs`` column
    named ``"Factor <dim>"`` on a temporary AnnData, then plotted with
    ``sc.pl.violin`` grouped by ``groupby``.

    Args:
        mdata (md.MuData): The input data.
        groupby (str): Column of ``mdata.obs`` used to group cells (required).
        obsm (str, optional): Key of the embedding in ``mdata.obsm``.
            Defaults to 'W_OT'.
        dim (int, optional): Index of the latent dimension. Defaults to 0.
        **kwds: Forwarded unchanged to ``sc.pl.violin``.
    """
    factor_key = "Factor " + str(dim)

    adata_joint = ad.AnnData(mdata.obsm[obsm], obs=mdata.obs)
    adata_joint.obs[factor_key] = adata_joint.X[:, dim]

    sc.pl.violin(adata_joint, keys=factor_key, groupby=groupby, **kwds)
|
|
50 |
|
|
|
51 |
|
|
|
52 |
def heatmap(
    mdata: md.MuData,
    groupby: str,
    obsm: str = "W_OT",
    cmap: str = "viridis",
    sort_var: bool = False,
    save: str | None = None,
    **kwds,
) -> dict | None:
    """Produce a heatmap of an embedding, with cells grouped by ``groupby``.

    A dendrogram over the groups is attempted first (PCA on the embedding,
    then ``sc.tl.dendrogram``). Failure there is non-fatal: a warning is
    emitted and the heatmap is still drawn.

    Args:
        mdata (md.MuData): Input data.
        groupby (str): Column of ``mdata.obs`` to group cells by.
        obsm (str, optional): Key of the embedding in ``mdata.obsm``.
            Defaults to 'W_OT'.
        cmap (str, optional): Color map. Defaults to 'viridis'.
        sort_var (bool, optional): Order dimensions by decreasing standard
            deviation (equivalently, variance). Defaults to False.
        save (str | None, optional): Forwarded to ``sc.pl.heatmap``.
            Defaults to None.
        **kwds: Forwarded unchanged to ``sc.pl.heatmap``.

    Returns:
        Whatever ``sc.pl.heatmap`` returns (per Scanpy's docs, a dict of
        Axes when ``show=False``, otherwise None).
    """
    import warnings

    # Create an AnnData with the joint embedding.
    joint_embedding = ad.AnnData(mdata.obsm[obsm], obs=mdata.obs)

    # Best-effort dendrogram: the plot is still useful without it, so the
    # failure is reported (with its cause) rather than raised.
    try:
        sc.pp.pca(joint_embedding)
        sc.tl.dendrogram(joint_embedding, groupby=groupby, use_rep="X_pca")
    except Exception as err:
        warnings.warn(f"Dendrogram not computed: {err}", stacklevel=2)

    # Get the dimension names to show.
    if sort_var:
        idx = joint_embedding.X.std(0).argsort()[::-1]
        var_names = joint_embedding.var_names[idx]
    else:
        var_names = joint_embedding.var_names

    # Plot the heatmap.
    return sc.pl.heatmap(
        joint_embedding, var_names, groupby=groupby, cmap=cmap, save=save, **kwds
    )
|
|
94 |
|
|
|
95 |
|
|
|
96 |
def enrich(enr: pd.DataFrame, query_name: str, n_terms: int = 10):
    """Display a lollipop chart of the top enriched terms for one query.

    Args:
        enr (pd.DataFrame): The enrichment object returned by
            mowgli.tl.enrich(). Must contain the columns "query", "name",
            "source" and "p_value".
        query_name (str): The name of the query, e.g. "dimension 0".
        n_terms (int, optional): Number of terms to show, taken from the top
            of ``enr`` in its existing row order. Defaults to 10.
    """

    # Subset the enrichment object to the query of interest. The explicit
    # copy makes the column assignment below act on an independent frame
    # instead of a slice of `enr` (avoids pandas' SettingWithCopyWarning).
    sub_enr = enr[enr["query"] == query_name].head(n_terms).copy()
    sub_enr["minlogp"] = -np.log10(sub_enr["p_value"])

    _, ax = plt.subplots()

    # Stems from 0 to -log10(p), drawn beneath the scatter markers.
    ax.hlines(
        y=sub_enr["name"],
        xmin=0,
        xmax=sub_enr["minlogp"],
        color="lightgray",
        zorder=1,
        alpha=0.8,
    )
    sns.scatterplot(
        data=sub_enr,
        x="minlogp",
        y="name",
        hue="source",
        s=100,
        alpha=0.8,
        ax=ax,
        zorder=3,
    )

    ax.set_xlabel("$-log_{10}(p)$")
    ax.set_ylabel("Enriched terms")

    plt.show()
|
|
134 |
|
|
|
135 |
|
|
|
136 |
def top_features(
    mdata: md.MuData,
    mod: str = "rna",
    uns: str = "H_OT",
    dim: int = 0,
    n_top: int = 10,
    ax: plt.Axes | None = None,
    palette: str = "Blues_r",
):
    """Display the top-weighted features for a given latent dimension.

    Rows of ``mdata[mod].uns[uns]`` are matched positionally to the features
    flagged in ``mdata[mod].var.highly_variable``, so the two must align.

    Args:
        mdata (md.MuData): The input mdata object.
        mod (str, optional): The modality to consider. Defaults to 'rna'.
        uns (str, optional): The uns field holding the (features x dims)
            weight matrix. Defaults to 'H_OT'.
        dim (int, optional): The latent dimension. Defaults to 0.
        n_top (int, optional): The number of top features to display.
            Defaults to 10.
        ax (plt.Axes | None, optional): The axes to draw on. When None,
            seaborn draws on the current axes. Defaults to None.
        palette (str, optional): The color palette to use.
            Defaults to 'Blues_r'.

    Returns:
        plt.Axes: The axes used.
    """
    adata = mdata[mod]

    # Feature names for the rows of the weight matrix.
    var_names = adata.var_names[adata.var.highly_variable]

    # Rank features by their weight along `dim`, largest first.
    weights = adata.uns[uns][:, dim]
    idx_top_features = np.argsort(weights)[::-1][:n_top]
    df = pd.DataFrame(
        {
            "features": var_names[idx_top_features],
            "weights": weights[idx_top_features],
        }
    )

    # seaborn falls back to the current axes when `ax` is None.
    return sns.barplot(data=df, x="weights", y="features", palette=palette, ax=ax)
|
|
179 |
|
|
|
180 |
|
|
|
181 |
def _copy_neighbors(mdata: md.MuData, adata: ad.AnnData, neighbours_key: str) -> None:
    """Copy a precomputed Scanpy neighbour graph from ``mdata`` onto ``adata``.

    Scanpy stores only metadata (including the names of the graph matrices)
    in ``.uns[key]``; the matrices themselves live in ``.obsp``, and
    ``sc.tl.umap`` also re-reads the representation named in
    ``params["use_rep"]`` from ``.obsm``. All of these are copied when present.
    """
    neighbors = mdata.uns[neighbours_key]
    adata.uns[neighbours_key] = neighbors

    # The default "neighbors" key uses these fixed obsp names; custom keys
    # record theirs under "<kind>_key".
    for name_key, default in (
        ("connectivities_key", "connectivities"),
        ("distances_key", "distances"),
    ):
        obsp_key = neighbors.get(name_key, default)
        if obsp_key in mdata.obsp:
            adata.obsp[obsp_key] = mdata.obsp[obsp_key]

    use_rep = neighbors.get("params", {}).get("use_rep")
    if use_rep is not None and use_rep != "X" and use_rep in mdata.obsm:
        adata.obsm[use_rep] = mdata.obsm[use_rep]


def umap(
    mdata: md.MuData,
    dim: int | list = 0,
    rescale: bool = False,
    obsm: str = "W_OT",
    neighbours_key=None,
    **kwds,
):
    """Wrapper around Scanpy's sc.pl.umap. Computes UMAP and colors it by latent dimension(s).

    Args:
        mdata (md.MuData): The input data.
        dim (int | list, optional): The latent dimension, or a list/tuple of
            dimensions (duplicates are dropped and the rest sorted).
            Defaults to 0.
        rescale (bool, optional): If True, cap the color scale of all plots at
            the maximum value of the embedding matrix. Defaults to False.
        obsm (str, optional): The embedding. Defaults to 'W_OT'.
        neighbours_key (str, optional): Key in ``mdata.uns`` of a precomputed
            neighbour graph to reuse. When None, neighbours are computed on the
            embedding with Scanpy defaults. Defaults to None.
        **kwds: Forwarded to ``sc.pl.umap``; may override the default
            ``size`` and ``alpha`` (and ``vmax`` when ``rescale`` is True).

    Raises:
        ValueError: If ``dim`` is not an integer or a list/tuple, or if
            ``neighbours_key`` is not found in ``mdata.uns``.
    """

    adata_tmp = ad.AnnData(mdata.obsm[obsm], obs=pd.DataFrame(index=mdata.obs.index))

    if isinstance(dim, (int, np.integer)):
        mowgli_cat = f"mowgli:{dim}"

    elif isinstance(dim, (list, tuple)):
        # Drop duplicate dimensions and sort them so panels come out in order.
        dim = sorted(set(dim))
        mowgli_cat = [f"mowgli:{x}" for x in dim]

    else:
        raise ValueError("dim must be an integer or a list of integers")

    adata_tmp.obs[mowgli_cat] = adata_tmp.X[:, dim]

    # Either build a neighbour graph or reuse the one stored on mdata.
    if neighbours_key is None:
        print("Computing neighbors with scanpy default parameters")
        neighbours_key = "mowgli_neighbors"  # set the default neighbors key
        # Compute neighbors using all dimensions of the mowgli matrix.
        sc.pp.neighbors(adata_tmp, use_rep="X", key_added=neighbours_key)

    else:
        if neighbours_key not in mdata.uns:
            raise ValueError(f"neighbours key {neighbours_key} not found in mdata.uns")

        _copy_neighbors(mdata, adata_tmp, neighbours_key)

    # Compute the UMAP layout.
    print("Computing UMAP")
    sc.tl.umap(adata_tmp, neighbors_key=neighbours_key)

    # Plot; user-supplied keywords take precedence over these defaults.
    plot_kwds = {"size": 18.5, "alpha": 0.4}
    if rescale:
        plot_kwds["vmax"] = adata_tmp.X.max()
    plot_kwds.update(kwds)
    sc.pl.umap(adata_tmp, color=mowgli_cat, **plot_kwds)