[28d9d2]: / mowgli / pl.py

Download this file

236 lines (189 with data), 7.2 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
import warnings

import anndata as ad
import mudata as md
import numpy as np
import pandas as pd
import scanpy as sc
import seaborn as sns
from matplotlib import pyplot as plt
def clustermap(mdata: md.MuData, obsm: str = "W_OT", cmap="viridis", **kwds):
    """Wrapper around Scanpy's clustermap.

    Builds a temporary AnnData whose ``X`` is ``mdata.obsm[obsm]`` (with
    ``mdata.obs`` as cell annotations) and plots it with ``sc.pl.clustermap``.

    Args:
        mdata (md.MuData): The input data.
        obsm (str, optional): The obsm field to consider. Defaults to 'W_OT'.
        cmap (str, optional): The colormap. Defaults to 'viridis'.
        **kwds: Forwarded unchanged to ``sc.pl.clustermap``.

    Returns:
        Whatever ``sc.pl.clustermap`` returns (per Scanpy's documentation, the
        seaborn ``ClusterGrid`` when ``show=False``, otherwise ``None``).
    """
    # Wrap the joint embedding in an AnnData so Scanpy can plot it.
    joint_embedding = ad.AnnData(mdata.obsm[obsm], obs=mdata.obs)

    # Hand the plot object back (as `heatmap` does) so callers passing
    # show=False can customize it.
    return sc.pl.clustermap(joint_embedding, cmap=cmap, **kwds)
def factor_violin(
    mdata: md.MuData,
    groupby: str,
    obsm: str = "W_OT",
    dim: int = 0,
    **kwds,
):
    """Make a violin plot of cells for a given latent dimension.

    Args:
        mdata (md.MuData): The input data.
        groupby (str): Column of ``mdata.obs`` used to group the cells.
        obsm (str, optional): The embedding. Defaults to 'W_OT'.
        dim (int, optional): The latent dimension. Defaults to 0.
        **kwds: Forwarded unchanged to ``sc.pl.violin``.

    Returns:
        Whatever ``sc.pl.violin`` returns (per Scanpy's documentation, the
        axes when ``show=False``, otherwise ``None``).
    """
    # Wrap the joint embedding in an AnnData so Scanpy can plot it.
    joint_embedding = ad.AnnData(mdata.obsm[obsm], obs=mdata.obs)

    # Expose the chosen dimension as an obs column so it can be plotted by name.
    key = f"Factor {dim}"
    joint_embedding.obs[key] = joint_embedding.X[:, dim]

    return sc.pl.violin(joint_embedding, keys=key, groupby=groupby, **kwds)
def heatmap(
    mdata: md.MuData,
    groupby: str,
    obsm: str = "W_OT",
    cmap: str = "viridis",
    sort_var: bool = False,
    save: str | bool | None = None,
    **kwds,
):
    """Produce a heatmap of an embedding.

    Args:
        mdata (md.MuData): Input data.
        groupby (str): Column of ``mdata.obs`` to group the cells by.
        obsm (str): The embedding. Defaults to 'W_OT'.
        cmap (str, optional): Color map. Defaults to 'viridis'.
        sort_var (bool, optional):
            Sort dimensions by decreasing standard deviation. Defaults to False.
        save (str | bool | None, optional): Forwarded to the ``save`` argument
            of ``sc.pl.heatmap``. Defaults to None.
        **kwds: Forwarded unchanged to ``sc.pl.heatmap``.

    Returns:
        Whatever ``sc.pl.heatmap`` returns (per Scanpy's documentation, a dict
        of axes when ``show=False``, otherwise ``None``).
    """
    # Wrap the joint embedding in an AnnData so Scanpy can plot it.
    joint_embedding = ad.AnnData(mdata.obsm[obsm], obs=mdata.obs)

    # Best effort: precompute a dendrogram (callers can request it through
    # `kwds`). A failure here must not prevent the heatmap itself, but the
    # reason is reported instead of being silently discarded.
    try:
        sc.pp.pca(joint_embedding)
        sc.tl.dendrogram(joint_embedding, groupby=groupby, use_rep="X_pca")
    except Exception as err:
        warnings.warn(f"Dendrogram not computed: {err}", stacklevel=2)

    # Choose which dimensions to show, and in which order.
    if sort_var:
        # Indices ordered from highest to lowest standard deviation.
        idx = joint_embedding.X.std(0).argsort()[::-1]
        var_names = joint_embedding.var_names[idx]
    else:
        var_names = joint_embedding.var_names

    # Plot the heatmap.
    return sc.pl.heatmap(
        joint_embedding, var_names, groupby=groupby, cmap=cmap, save=save, **kwds
    )
def enrich(enr: pd.DataFrame, query_name: str, n_terms: int = 10):
    """Display a list of enriched terms.

    Plots ``-log10(p_value)`` for the first ``n_terms`` rows of ``enr`` whose
    ``query`` column equals ``query_name``. Each term ``name`` gets its own
    row, and points are colored by ``source``.

    Args:
        enr (pd.DataFrame): The enrichment object returned by mowgli.tl.enrich().
            Must contain the columns ``query``, ``name``, ``p_value`` and
            ``source``.
        query_name (str): The name of the query, e.g. "dimension 0".
        n_terms (int, optional): Maximum number of terms to display.
            Defaults to 10.

    Returns:
        plt.Axes: The axes the terms were drawn on.
    """
    # Subset the enrichment object to the query of interest. `.copy()` makes
    # the column assignment below write to an independent frame rather than a
    # slice of `enr`, which would trigger pandas' SettingWithCopyWarning.
    sub_enr = enr[enr["query"] == query_name].head(n_terms).copy()
    sub_enr["minlogp"] = -np.log10(sub_enr["p_value"])

    _, ax = plt.subplots()

    # Draw a light stem from 0 to each term's score, behind the markers.
    ax.hlines(
        y=sub_enr["name"],
        xmin=0,
        xmax=sub_enr["minlogp"],
        color="lightgray",
        zorder=1,
        alpha=0.8,
    )
    # Draw one marker per term, colored by its source.
    sns.scatterplot(
        data=sub_enr,
        x="minlogp",
        y="name",
        hue="source",
        s=100,
        alpha=0.8,
        ax=ax,
        zorder=3,
    )
    ax.set_xlabel("$-log_{10}(p)$")
    ax.set_ylabel("Enriched terms")
    plt.show()
    return ax
def top_features(
    mdata: md.MuData,
    mod: str = "rna",
    uns: str = "H_OT",
    dim: int = 0,
    n_top: int = 10,
    ax: plt.Axes | None = None,
    palette: str = "Blues_r",
):
    """Display the top features for a given dimension.

    Args:
        mdata (md.MuData): The input mdata object
        mod (str, optional): The modality to consider. Defaults to 'rna'.
        uns (str, optional): The uns field to consider. Defaults to 'H_OT'.
        dim (int, optional): The latent dimension. Defaults to 0.
        n_top (int, optional): The number of top features to display. Defaults to 10.
        ax (plt.Axes, optional): The axes to draw on. When None, seaborn uses
            the current axes. Defaults to None.
        palette (str, optional): The color palette to use. Defaults to 'Blues_r'.

    Returns:
        plt.Axes: The axes used.
    """
    modality = mdata[mod]

    # Feature names are taken from the highly variable features only. This
    # assumes the rows of `modality.uns[uns]` align with them; confirm against
    # how the loadings matrix is produced.
    var_names = modality.var_names[modality.var.highly_variable]
    weights = modality.uns[uns][:, dim]

    # Indices of the `n_top` largest weights, in decreasing order.
    idx_top_features = np.argsort(weights)[::-1][:n_top]
    df = pd.DataFrame(
        {
            "features": var_names[idx_top_features],
            "weights": weights[idx_top_features],
        }
    )

    # seaborn falls back to the current axes when `ax` is None and returns
    # the axes it drew on in both cases, so a single call covers both paths.
    return sns.barplot(data=df, x="weights", y="features", palette=palette, ax=ax)
def umap(
    mdata: md.MuData,
    dim: int | list = 0,
    rescale: bool = False,
    obsm: str = "W_OT",
    neighbours_key=None,
    **kwds,
):
    """Wrapper around Scanpy's sc.pl.umap. Computes UMAP for a given latent dimension and plots it.

    Args:
        mdata (md.MuData): The input data
        dim (int | list, optional): The latent dimension(s) to color by.
            Duplicates in a list are dropped and the list is sorted.
            Defaults to 0.
        rescale (bool, optional): If True, rescale the color palette across
            all plots to the maximum value in the weight matrix. Defaults to False.
        obsm (str, optional): The embedding. Defaults to 'W_OT'.
        neighbours_key (str, optional): Key in ``mdata.uns`` of a precomputed
            neighbor graph to build the UMAP from. If None, neighbors are
            computed on the full embedding with Scanpy's default parameters.
            Defaults to None.
        **kwds: Forwarded unchanged to ``sc.pl.umap``.

    Returns:
        Whatever ``sc.pl.umap`` returns (per Scanpy's documentation, the axes
        when ``show=False``, otherwise ``None``).

    Raises:
        ValueError: If ``dim`` is neither an integer nor a list, or if
            ``neighbours_key`` is not in ``mdata.uns``.
    """
    adata_tmp = ad.AnnData(mdata.obsm[obsm], obs=pd.DataFrame(index=mdata.obs.index))

    # NumPy integers (e.g. from np.argmax) are accepted as well as plain ints.
    if isinstance(dim, (int, np.integer)):
        mowgli_cat = f"mowgli:{dim}"
    elif isinstance(dim, list):
        # Drop duplicate dimensions and sort them.
        dim = sorted(set(dim))
        mowgli_cat = [f"mowgli:{x}" for x in dim]
    else:
        raise ValueError("dim must be an integer or a list of integers")
    adata_tmp.obs[mowgli_cat] = adata_tmp.X[:, dim]

    if neighbours_key is None:
        print("Computing neighbors with scanpy default parameters")
        neighbours_key = "mowgli_neighbors"  # set the default neighbors key
        # Compute neighbors using all dimensions of the mowgli matrix.
        sc.pp.neighbors(adata_tmp, use_rep="X", key_added=neighbours_key)
    else:
        if neighbours_key not in mdata.uns:
            raise ValueError(f"neighbours key {neighbours_key} not found in mdata.uns")
        neighbors = mdata.uns[neighbours_key]
        adata_tmp.uns[neighbours_key] = neighbors
        # sc.tl.umap reads the graph from `.obsp` (under the keys recorded in
        # the uns entry) and the representation named in its params from
        # `.obsm`. adata_tmp starts with neither, so copy them from mdata.
        for field in ("connectivities_key", "distances_key"):
            obsp_key = neighbors.get(field)
            if obsp_key is not None and obsp_key in mdata.obsp:
                adata_tmp.obsp[obsp_key] = mdata.obsp[obsp_key]
        use_rep = neighbors.get("params", {}).get("use_rep")
        if use_rep is not None and use_rep != "X" and use_rep in mdata.obsm:
            adata_tmp.obsm[use_rep] = mdata.obsm[use_rep]

    print("Computing UMAP")
    sc.tl.umap(adata_tmp, neighbors_key=neighbours_key)

    plot_kwds = {"color": mowgli_cat, "size": 18.5, "alpha": 0.4}
    if rescale:
        # One shared color scale: the maximum over the whole embedding.
        plot_kwds["vmax"] = adata_tmp.X.max()
    return sc.pl.umap(adata_tmp, **plot_kwds, **kwds)