b/eval_utils.py
from math import inf
import os
import logging

import numpy as np
import scanpy as sc
import anndata as ad
import pandas as pd

import matplotlib
from matplotlib.figure import Figure
import matplotlib.pyplot as plt
from scipy.sparse import issparse, spmatrix
from scipy.stats import chi2
from typing import Mapping, Sequence, Tuple, Iterable, Union
from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score, silhouette_samples
from sklearn.neighbors import NearestNeighbors

import psutil
import scib


# Prefer the number of physical cores; fall back to logical cores when psutil
# cannot determine the physical count.
_cpu_count: Union[None, int] = psutil.cpu_count(logical=False)
if _cpu_count is None:
    _cpu_count = psutil.cpu_count(logical=True)
_logger = logging.getLogger(__name__)


def evaluate(adata: ad.AnnData,
             n_epoch: int,
             embedding_key: str = 'delta',
             n_neighbors: int = 15,
             resolutions: Iterable[float] = [0.01, 0.02, 0.04, 0.08, 0.16, 0.32, 0.64],
             clustering_method: str = "leiden",
             cell_type_col: str = "cell_types",
             batch_col: Union[str, None] = "batch_indices",
             color_by: Iterable[str] = None,
             return_fig: bool = False,
             plot_fname: str = "umap",
             plot_ftype: str = "jpg",
             plot_dir: Union[str, None] = None,
             plot_dpi: int = 300,
             min_dist: float = 0.3,
             spread: float = 1,
             n_jobs: int = 1,
             random_state: Union[None, int, np.random.RandomState, np.random.Generator] = 0,
             umap_kwargs: dict = dict()
             ) -> Mapping[str, Union[float, None, Figure]]:
    """Evaluates the clustering and batch correction performance of the given
    embedding, and optionally plots the embedding.

    The embedding is plotted if return_fig is True or plot_dir is provided.

    NOTE: Set n_jobs to 1 if you encounter a pickling error.

    Args:
        adata: the dataset with the embedding to be evaluated.
        n_epoch: the current epoch number, appended to the plot file name.
        embedding_key: the key to the embedding. Must be in adata.obsm.
        n_neighbors: # neighbors used when computing the neighborhood graph and
            calculating the entropy of batch mixing / kBET.
        resolutions: a sequence of resolutions used for clustering.
        clustering_method: clustering method used. Should be one of 'leiden' or
            'louvain'.
        cell_type_col: a key in adata.obs to the cell type column.
        batch_col: a key in adata.obs to the batch column.
        color_by: a list of adata.obs column keys to color the embeddings by.
            If None, defaults to the batch (if needed) and cell type columns.
            Columns listed in adata.uns['color_by'], if any, are always
            prepended. Only used if drawing.
        return_fig: whether to return the Figure object. Useful for visualizing
            the plot.
        plot_fname: file name of the generated plot. Only used if drawing.
        plot_ftype: file type of the generated plot. Only used if drawing.
        plot_dir: directory to save the generated plot. If None, do not save
            the plot.
        plot_dpi: dpi to save the plot.
        min_dist: the min_dist argument in sc.tl.umap. Only used if drawing.
        spread: the spread argument in sc.tl.umap. Only used if drawing.
        n_jobs: # jobs to spawn. If <= 0, this is set to the number of
            physical cores.
        random_state: random state for the kNN calculation.
        umap_kwargs: other kwargs to pass to sc.pl.umap.

    Returns:
        A dict storing the ari, nmi, asw, asw_2, ebm, k_bet, batch_asw and
        batch_graph_score of the cell embedding, under the keys "ari", "nmi",
        "asw", "asw_2", "ebm", "k_bet", "batch_asw" and "batch_graph_score",
        respectively. If drawing and return_fig is True, also stores the
        plotted figure under the key "fig".
    """

    if cell_type_col and not pd.api.types.is_categorical_dtype(adata.obs[cell_type_col]):
        #_logger.warning("scETM.evaluate assumes discrete cell types. Converting cell_type_col to categorical.")
        adata.obs[cell_type_col] = adata.obs[cell_type_col].astype(str).astype('category')
    if batch_col and not pd.api.types.is_categorical_dtype(adata.obs[batch_col]):
        #_logger.warning("scETM.evaluate assumes discrete batches. Converting batch_col to categorical.")
        adata.obs[batch_col] = adata.obs[batch_col].astype(str).astype('category')

    # calculate neighbors
    _get_knn_indices(adata, use_rep=embedding_key, n_neighbors=n_neighbors, random_state=random_state, calc_knn=True)

    # calculate clustering metrics
    if cell_type_col in adata.obs and len(resolutions) > 0:
        cluster_key, best_ari, best_nmi = clustering(adata, resolutions=resolutions, cell_type_col=cell_type_col, batch_col=batch_col, clustering_method=clustering_method)
    else:
        cluster_key = best_ari = best_nmi = None

    if cell_type_col in adata.obs and adata.obs[cell_type_col].nunique() > 1:
        sw = silhouette_samples(adata.X if embedding_key == 'X' else adata.obsm[embedding_key],
                                adata.obs[cell_type_col])
        adata.obs['silhouette_width'] = sw
        asw = np.mean(sw)
        #print(f'{embedding_key}_ASW: {asw:7.4f}')

        asw_2 = scib.me.silhouette(adata, group_key=cell_type_col, embed=embedding_key)

        if batch_col and cell_type_col:
            sw_table = adata.obs.pivot_table(index=cell_type_col, columns=batch_col, values="silhouette_width",
                                             aggfunc="mean")
            #print(f'SW: {sw_table}')
            if plot_dir is not None:
                sw_table.to_csv(os.path.join(plot_dir, f'{plot_fname}.csv'))
    else:
        asw = 0.
        asw_2 = 0.

    # calculate batch correction metrics
    need_batch = batch_col and adata.obs[batch_col].nunique() > 1
    if need_batch:
        ebm = calculate_entropy_batch_mixing(adata,
                                             use_rep=embedding_key,
                                             batch_col=batch_col,
                                             n_neighbors=n_neighbors,
                                             calc_knn=False,
                                             n_jobs=n_jobs,
                                             )
        #print(f'{embedding_key}_BE: {ebm:7.4f}')
        k_bet = calculate_kbet(adata,
                               use_rep=embedding_key,
                               batch_col=batch_col,
                               n_neighbors=n_neighbors,
                               calc_knn=False,
                               n_jobs=n_jobs,
                               )[2]
        #print(f'{embedding_key}_kBET: {k_bet:7.4f}')
        batch_asw = scib.me.silhouette_batch(adata, batch_key=batch_col, group_key=cell_type_col, embed=embedding_key, verbose=False)
        batch_graph_score = get_graph_connectivity(adata, use_rep=embedding_key, cell_type_col=cell_type_col)
    else:
        ebm = k_bet = batch_asw = batch_graph_score = None

    # plot UMAP embeddings
    if return_fig or plot_dir is not None:
        if color_by is None:
            color_by = [batch_col, cell_type_col] if need_batch else [cell_type_col]
        color_by = list(color_by)
        if 'color_by' in adata.uns:
            for col in adata.uns['color_by']:
                if col not in color_by:
                    color_by.insert(0, col)
        if cluster_key is not None:
            color_by = [cluster_key] + color_by
        fig = draw_embeddings(adata=adata, color_by=color_by,
                              min_dist=min_dist, spread=spread,
                              ckpt_dir=plot_dir, fname=f'{plot_fname}{n_epoch}.{plot_ftype}', return_fig=return_fig,
                              dpi=plot_dpi,
                              umap_kwargs=umap_kwargs)
    else:
        fig = None

    return dict(
        ari=best_ari,
        nmi=best_nmi,
        asw=asw,
        asw_2=asw_2,
        ebm=ebm,
        k_bet=k_bet,
        batch_asw=batch_asw,
        batch_graph_score=batch_graph_score,
        fig=fig
    )
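
# Illustrative usage sketch (not part of the original module; `adata`, the 'delta'
# embedding and the 'figures' directory are assumed placeholders, column names are
# this module's defaults): evaluate an embedding after epoch 10 and save the UMAP
# to an existing directory.
#
#   metrics = evaluate(adata, n_epoch=10, embedding_key='delta',
#                      cell_type_col='cell_types', batch_col='batch_indices',
#                      plot_dir='figures', n_jobs=1)
#   print(metrics['ari'], metrics['nmi'], metrics['ebm'])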


def evaluate_ari(adata: ad.AnnData,
                 n_epoch: int,
                 embedding_key: str = 'delta',
                 n_neighbors: int = 15,
                 resolutions: Iterable[float] = [0.01, 0.02, 0.04, 0.08, 0.16, 0.32, 0.64],
                 clustering_method: str = "leiden",
                 cell_type_col: str = "cell_types",
                 batch_col: Union[str, None] = "batch_indices",
                 color_by: Iterable[str] = None,
                 return_fig: bool = False,
                 plot_fname: str = "umap",
                 plot_ftype: str = "jpg",
                 plot_dir: Union[str, None] = None,
                 plot_dpi: int = 300,
                 min_dist: float = 0.3,
                 spread: float = 1,
                 n_jobs: int = 1,
                 random_state: Union[None, int, np.random.RandomState, np.random.Generator] = 0,
                 umap_kwargs: dict = dict()
                 ) -> Union[float, None]:
    """Evaluates only the clustering performance (ARI) of the given embedding.

    This is a lightweight variant of `evaluate`: it computes the neighborhood
    graph, clusters it at the given resolutions, and returns the best ARI with
    the cell type labels. The plotting- and batch-metric-related arguments are
    accepted for signature compatibility with `evaluate` but are not used.

    Args:
        adata: the dataset with the embedding to be evaluated.
        embedding_key: the key to the embedding. Must be in adata.obsm.
        n_neighbors: # neighbors used when computing the neighborhood graph.
        resolutions: a sequence of resolutions used for clustering.
        clustering_method: clustering method used. Should be one of 'leiden' or
            'louvain'.
        cell_type_col: a key in adata.obs to the cell type column.
        batch_col: a key in adata.obs to the batch column.
        random_state: random state for the kNN calculation.

    Returns:
        The best ARI with cell type over the given resolutions, or None if
        cell_type_col is not in adata.obs or no resolution is given.
    """

    if cell_type_col and not pd.api.types.is_categorical_dtype(adata.obs[cell_type_col]):
        #_logger.warning("scETM.evaluate assumes discrete cell types. Converting cell_type_col to categorical.")
        adata.obs[cell_type_col] = adata.obs[cell_type_col].astype(str).astype('category')
    if batch_col and not pd.api.types.is_categorical_dtype(adata.obs[batch_col]):
        #_logger.warning("scETM.evaluate assumes discrete batches. Converting batch_col to categorical.")
        adata.obs[batch_col] = adata.obs[batch_col].astype(str).astype('category')

    # calculate neighbors
    _get_knn_indices(adata, use_rep=embedding_key, n_neighbors=n_neighbors, random_state=random_state, calc_knn=True)

    # calculate clustering metrics
    if cell_type_col in adata.obs and len(resolutions) > 0:
        cluster_key, best_ari, best_nmi = clustering(adata, resolutions=resolutions, cell_type_col=cell_type_col, batch_col=batch_col, clustering_method=clustering_method)
    else:
        cluster_key = best_ari = best_nmi = None

    return best_ari


def _eff_n_jobs(n_jobs: Union[None, int]) -> int:
    """Returns the effective job count: 1 if n_jobs is None, the number of
    physical cores (_cpu_count) if n_jobs <= 0, otherwise n_jobs."""
    if n_jobs is None:
        return 1
    return int(n_jobs) if n_jobs > 0 else _cpu_count


def _calculate_kbet_for_one_chunk(knn_indices, attr_values, ideal_dist, n_neighbors):
    """Computes the kBET chi-square statistic and p-value for each cell in a chunk
    of the kNN index matrix."""
    dof = ideal_dist.size - 1

    ns = knn_indices.shape[0]
    results = np.zeros((ns, 2))
    for i in range(ns):
        # NOTE: Do not use np.unique. Some of the batches may not be present in
        # the neighborhood.
        observed_counts = pd.Series(attr_values[knn_indices[i, :]]).value_counts(sort=False).values
        expected_counts = ideal_dist * n_neighbors
        stat = np.sum((observed_counts - expected_counts) ** 2 / expected_counts)
        p_value = 1 - chi2.cdf(stat, dof)
        results[i, 0] = stat
        results[i, 1] = p_value

    return results


def _get_knn_indices(adata: ad.AnnData,
                     use_rep: str = "delta",
                     n_neighbors: int = 25,
                     random_state: int = 0,
                     calc_knn: bool = True
                     ) -> np.ndarray:
    """Computes (or reuses) the kNN graph of adata and returns the kNN index matrix."""
    if calc_knn:
        assert use_rep == 'X' or use_rep in adata.obsm, f'{use_rep} not in adata.obsm and is not "X"'
        neighbors = sc.Neighbors(adata)
        neighbors.compute_neighbors(n_neighbors=n_neighbors, knn=True, use_rep=use_rep, random_state=random_state,
                                    write_knn_indices=True)
        adata.obsp['distances'] = neighbors.distances
        adata.obsp['connectivities'] = neighbors.connectivities
        adata.obsm['knn_indices'] = neighbors.knn_indices
        adata.uns['neighbors'] = {
            'connectivities_key': 'connectivities',
            'distances_key': 'distances',
            'knn_indices_key': 'knn_indices',
            'params': {
                'n_neighbors': n_neighbors,
                'use_rep': use_rep,
                'metric': 'euclidean',
                'method': 'umap'
            }
        }
    else:
        assert 'neighbors' in adata.uns, 'No precomputed knn exists.'
        assert adata.uns['neighbors']['params']['n_neighbors'] >= n_neighbors, \
            f"pre-computed n_neighbors is {adata.uns['neighbors']['params']['n_neighbors']}, which is smaller than {n_neighbors}"

    return adata.obsm['knn_indices']
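
# Illustrative sketch (not part of the original module; `adata` and the 'delta'
# embedding are assumed placeholders): `evaluate` computes the kNN graph once with
# calc_knn=True and then reuses it for the batch metrics with calc_knn=False. A
# caller can follow the same pattern, keeping n_neighbors consistent so the
# precomputed-neighbors assertion holds:
#
#   _get_knn_indices(adata, use_rep='delta', n_neighbors=15, calc_knn=True)
#   ebm = calculate_entropy_batch_mixing(adata, use_rep='delta', n_neighbors=15, calc_knn=False)
#   kbet_accept = calculate_kbet(adata, use_rep='delta', n_neighbors=15, calc_knn=False)[2]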


def get_graph_connectivity(
        adata: ad.AnnData,
        use_rep: str = "delta",
        cell_type_col: str = "cell_type",
):
    """Computes the scib graph connectivity score of the embedding in adata.obsm[use_rep]."""

    sc.pp.neighbors(adata, use_rep=use_rep)
    score = scib.me.graph_connectivity(adata, label_key=cell_type_col)
    return score


def calculate_kbet(
        adata: ad.AnnData,
        use_rep: str = "delta",
        batch_col: str = "batch_indices",
        n_neighbors: int = 25,
        alpha: float = 0.05,
        random_state: int = 0,
        n_jobs: Union[None, int] = None,
        calc_knn: bool = True
) -> Tuple[float, float, float]:
    """Calculates the kBET metric of the data.

    kBET measures whether cells from different batches mix well in their local
    neighborhood.

    Args:
        adata: annotated data matrix.
        use_rep: the embedding to be used. Must exist in adata.obsm.
        batch_col: a key in adata.obs to the batch column.
        n_neighbors: # nearest neighbors.
        alpha: acceptance rate threshold. A cell is accepted if its kBET
            p-value is greater than or equal to alpha.
        random_state: random seed for the kNN calculation.
        n_jobs: # jobs to spawn. If <= 0, this is set to the number of
            physical cores.
        calc_knn: whether to re-calculate the kNN graph or reuse the one stored
            in adata.

    Returns:
        stat_mean: mean kBET chi-square statistic over all cells.
        pvalue_mean: mean kBET p-value over all cells.
        accept_rate: kBET acceptance rate of the sample.
    """

    _logger.info('Calculating kbet...')
    assert batch_col in adata.obs
    if adata.obs[batch_col].dtype.name != "category":
        _logger.warning(f'Making the column {batch_col} of adata.obs categorical.')
        adata.obs[batch_col] = adata.obs[batch_col].astype('category')

    ideal_dist = (
        adata.obs[batch_col].value_counts(normalize=True, sort=False).values
    )  # ideal no-batch-effect distribution
    nsample = adata.shape[0]
    nbatch = ideal_dist.size

    attr_values = adata.obs[batch_col].values.copy()
    attr_values = attr_values.rename_categories(range(nbatch))
    knn_indices = _get_knn_indices(adata, use_rep, n_neighbors, random_state, calc_knn)

    # partition cells into n_jobs roughly equal-sized chunks
    n_jobs = min(_eff_n_jobs(n_jobs), nsample)
    starts = np.zeros(n_jobs + 1, dtype=int)
    quotient = nsample // n_jobs
    remainder = nsample % n_jobs
    for i in range(n_jobs):
        starts[i + 1] = starts[i] + quotient + (1 if i < remainder else 0)

    from joblib import Parallel, delayed, parallel_backend
    with parallel_backend("loky", n_jobs=n_jobs):
        kBET_arr = np.concatenate(
            Parallel()(
                delayed(_calculate_kbet_for_one_chunk)(
                    knn_indices[starts[i]: starts[i + 1], :], attr_values, ideal_dist, n_neighbors
                )
                for i in range(n_jobs)
            )
        )

    res = kBET_arr.mean(axis=0)
    stat_mean = res[0]
    pvalue_mean = res[1]
    accept_rate = (kBET_arr[:, 1] >= alpha).sum() / nsample

    return (stat_mean, pvalue_mean, accept_rate)
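
# Worked example (toy numbers, not from the original code): for one cell whose 25
# nearest neighbors contain 10 cells from batch 0 and 15 from batch 1, while the
# dataset-wide batch proportions are [0.5, 0.5], the per-cell kBET test is a
# chi-square goodness-of-fit test with 1 degree of freedom:
#
#   observed = np.array([10, 15])
#   expected = np.array([0.5, 0.5]) * 25          # [12.5, 12.5]
#   stat = np.sum((observed - expected) ** 2 / expected)   # == 1.0
#   p_value = 1 - chi2.cdf(stat, 1)               # ~0.32, accepted at alpha=0.05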


def _entropy(hist_data):
    """Computes the Shannon entropy of the labels in hist_data."""
    _, counts = np.unique(hist_data, return_counts=True)
    freqs = counts / counts.sum()
    return (-freqs * np.log(freqs + 1e-30)).sum()


def _entropy_batch_mixing_for_one_pool(batches, knn_indices, nsample, n_samples_per_pool):
    """Averages the neighborhood batch-label entropy over one random pool of cells."""
    indices = np.random.choice(
        np.arange(nsample), size=n_samples_per_pool)
    return np.mean(
        [
            _entropy(batches[knn_indices[indices[i]]])
            for i in range(n_samples_per_pool)
        ]
    )


def calculate_entropy_batch_mixing(
        adata: ad.AnnData,
        use_rep: str = "delta",
        batch_col: str = "batch_indices",
        n_neighbors: int = 50,
        n_pools: int = 50,
        n_samples_per_pool: int = 100,
        random_state: int = 0,
        n_jobs: Union[None, int] = None,
        calc_knn: bool = True
) -> float:
    """Calculates the entropy of batch mixing of the data.

    The entropy of batch mixing measures whether cells from different batches
    mix well in their local neighborhood.

    Args:
        adata: annotated data matrix.
        use_rep: the embedding to be used. Must exist in adata.obsm.
        batch_col: a key in adata.obs to the batch column.
        n_neighbors: # nearest neighbors.
        n_pools: # pools of cells to calculate entropy of batch mixing.
        n_samples_per_pool: # cells per pool to calculate within-pool entropy.
        random_state: random seed for the kNN calculation.
        n_jobs: # jobs to spawn. If <= 0, this is set to the number of
            physical cores.
        calc_knn: whether to re-calculate the kNN graph or reuse the one stored
            in adata.

    Returns:
        score: the mean entropy of batch mixing, averaged over n_pools pools.
    """

    _logger.info('Calculating batch mixing entropy...')
    nsample = adata.n_obs

    knn_indices = _get_knn_indices(adata, use_rep, n_neighbors, random_state, calc_knn)

    from joblib import Parallel, delayed, parallel_backend
    with parallel_backend("loky", n_jobs=n_jobs, inner_max_num_threads=1):
        score = np.mean(
            Parallel()(
                delayed(_entropy_batch_mixing_for_one_pool)(
                    adata.obs[batch_col].values, knn_indices, nsample, n_samples_per_pool
                )
                for _ in range(n_pools)
            )
        )
    return score
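
# Worked example (toy labels, not from the original code): a neighborhood drawn
# evenly from two batches reaches the maximum entropy log(2), while a neighborhood
# drawn from a single batch has entropy 0.
#
#   _entropy(np.array([0, 1, 0, 1, 0, 1]))   # ~0.693 == np.log(2)
#   _entropy(np.array([0, 0, 0, 0, 0, 0]))   # ~0.0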


def clustering(
        adata: ad.AnnData,
        resolutions: Sequence[float],
        clustering_method: str = "leiden",
        cell_type_col: str = "cell_types",
        batch_col: str = "batch_indices"
) -> Tuple[str, float, float]:
    """Clusters the data and calculates agreement with the cell type and batch
    variables.

    This method clusters the neighborhood graph (requires having run
    sc.pp.neighbors first) with the "clustering_method" algorithm once per
    given resolution, and returns the best result in terms of ARI with cell
    type.
    Other metrics, such as ARI with batch, are logged but not returned.
    (TODO: also return these metrics)

    Args:
        adata: the dataset to be clustered. adata.obsp should contain the keys
            'connectivities' and 'distances'.
        resolutions: a list of leiden/louvain resolution parameters. Will
            cluster with each resolution in the list and return the best result
            (in terms of ARI with cell type).
        clustering_method: Either "leiden" or "louvain".
        cell_type_col: a key in adata.obs to the cell type column.
        batch_col: a key in adata.obs to the batch column.

    Returns:
        best_cluster_key: a key in adata.obs to the best (in terms of ARI with
            cell type) cluster assignment column.
        best_ari: the best ARI with cell type.
        best_nmi: the best NMI with cell type.
    """

    assert len(resolutions) > 0, 'Must specify at least one resolution.'

    if clustering_method == 'leiden':
        clustering_func = sc.tl.leiden
    elif clustering_method == 'louvain':
        clustering_func = sc.tl.louvain
    else:
        raise ValueError("Please specify louvain or leiden for the clustering method argument.")
    #_logger.info(f'Performing {clustering_method} clustering')
    assert cell_type_col in adata.obs, f"{cell_type_col} not in adata.obs"
    best_res, best_ari, best_nmi = None, -inf, -inf
    for res in resolutions:
        col = f'{clustering_method}_{res}'
        clustering_func(adata, resolution=res, key_added=col)
        ari = adjusted_rand_score(adata.obs[cell_type_col], adata.obs[col])
        nmi = normalized_mutual_info_score(adata.obs[cell_type_col], adata.obs[col])
        n_unique = adata.obs[col].nunique()
        if ari > best_ari:
            best_res = res
            best_ari = ari
        if nmi > best_nmi:
            best_nmi = nmi
        if batch_col in adata.obs and adata.obs[batch_col].nunique() > 1:
            ari_batch = adjusted_rand_score(adata.obs[batch_col], adata.obs[col])
            _logger.info(f'Resolution: {res:5.3g}\tARI: {ari:7.4f}\tNMI: {nmi:7.4f}\tbARI: {ari_batch:7.4f}\t# labels: {n_unique}')
        else:
            _logger.info(f'Resolution: {res:5.3g}\tARI: {ari:7.4f}\tNMI: {nmi:7.4f}\t# labels: {n_unique}')

    return f'{clustering_method}_{best_res}', best_ari, best_nmi
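
# Illustrative sketch (not part of the original module; `adata` and the 'delta'
# embedding are assumed placeholders, column names are this module's defaults):
# compute the neighborhood graph first, then sweep a few resolutions and keep the
# best-ARI assignment.
#
#   _get_knn_indices(adata, use_rep='delta', n_neighbors=15, calc_knn=True)
#   best_key, best_ari, best_nmi = clustering(adata, resolutions=[0.1, 0.2, 0.4])
#   labels = adata.obs[best_key]   # cluster labels of the best-ARI resolution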


def draw_embeddings(adata: ad.AnnData,
                    color_by: Union[str, Sequence[str], None] = None,
                    min_dist: float = 0.3,
                    spread: float = 1,
                    ckpt_dir: str = '.',
                    fname: str = "umap.pdf",
                    return_fig: bool = False,
                    dpi: int = 300,
                    umap_kwargs: dict = dict()
                    ) -> Union[None, Figure]:
    """Embeds, plots and optionally saves the neighborhood graph with UMAP.

    Requires having run sc.pp.neighbors first.

    Args:
        adata: the dataset to draw. adata.obsp should contain the keys
            'connectivities' and 'distances'.
        color_by: a str or a list of adata.obs keys to color the points in the
            scatterplot by. E.g. if both cell_type_col and batch_col are in
            color_by, then we would have two plots colored by the cell type and
            batch variables, respectively.
        min_dist: the effective minimum distance between embedded points.
            Smaller values will result in a more clustered/clumped embedding
            where nearby points on the manifold are drawn closer together,
            while larger values will result in a more even dispersal of points.
        spread: the effective scale of embedded points. In combination with
            `min_dist` this determines how clustered/clumped the embedded
            points are.
        ckpt_dir: where to save the plot. If None, do not save the plot.
        fname: file name of the saved plot. Only used if ckpt_dir is not None.
        return_fig: whether to return the Figure object. Useful for visualizing
            the plot.
        dpi: the dpi of the saved plot. Only used if ckpt_dir is not None.
        umap_kwargs: other kwargs to pass to sc.pl.umap.

    Returns:
        If return_fig is True, the figure containing the plot.
    """

    #_logger.info(f'Plotting UMAP embeddings...')
    sc.tl.umap(adata, min_dist=min_dist, spread=spread)
    fig = sc.pl.umap(adata, color=color_by, show=False, return_fig=True, **umap_kwargs)
    if ckpt_dir is not None:
        assert os.path.exists(ckpt_dir), f'ckpt_dir {ckpt_dir} does not exist.'
        fig.savefig(
            os.path.join(ckpt_dir, fname),
            dpi=dpi, bbox_inches='tight'
        )
    if return_fig:
        return fig
    fig.clf()
    plt.close(fig)
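
# Illustrative sketch (not part of the original module; the 'figures' directory is
# an assumed placeholder and must already exist, since the function asserts it):
# save a UMAP colored by cell type and batch without keeping the figure in memory.
#
#   draw_embeddings(adata, color_by=['cell_types', 'batch_indices'],
#                   ckpt_dir='figures', fname='umap_epoch10.jpg')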


def set_figure_params(
        matplotlib_backend: str = 'agg',
        dpi: int = 120,
        frameon: bool = True,
        vector_friendly: bool = True,
        fontsize: int = 10,
        figsize: Sequence[int] = (10, 10)
):
    """Sets figure parameters for matplotlib and scanpy.

    Args:
        matplotlib_backend: the backend to switch to. This can either be one of
            the standard backend names, which are case-insensitive:
            - interactive backends:
                GTK3Agg, GTK3Cairo, MacOSX, nbAgg,
                Qt4Agg, Qt4Cairo, Qt5Agg, Qt5Cairo,
                TkAgg, TkCairo, WebAgg, WX, WXAgg, WXCairo
            - non-interactive backends:
                agg, cairo, pdf, pgf, ps, svg, template
            or a string of the form: ``module://my.module.name``.
        dpi: resolution of rendered figures; this influences the size of
            figures in notebooks.
        frameon: add frames and axes labels to scatter plots.
        vector_friendly: plot scatter plots using `png` backend even when
            exporting as `pdf` or `svg`.
        fontsize: the fontsize for several `rcParams` entries.
        figsize: plt.rcParams['figure.figsize'].
    """
    matplotlib.use(matplotlib_backend)
    sc.set_figure_params(dpi=dpi, figsize=figsize, fontsize=fontsize, frameon=frameon, vector_friendly=vector_friendly)
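
# Illustrative sketch (assumed values): when running evaluation headlessly, e.g. on
# a compute cluster, switch to the non-interactive 'agg' backend and a smaller
# canvas before calling `evaluate` or `draw_embeddings`.
#
#   set_figure_params(matplotlib_backend='agg', dpi=150, figsize=(6, 6))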