# eval_utils.py
1
from math import inf
2
import os
3
import logging
4
import numpy as np
5
import scanpy as sc
6
import anndata as ad
7
import pandas as pd
8
9
import matplotlib
10
from matplotlib.figure import Figure
11
import matplotlib.pyplot as plt
12
from scipy.sparse.csr import spmatrix
13
from scipy.stats import chi2
14
from typing import Mapping, Sequence, Tuple, Iterable, Union
15
from scipy.sparse import issparse
16
from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score, silhouette_samples
17
from sklearn.neighbors import NearestNeighbors
18
19
import psutil
20
import scib
21
22
23
# Number of CPU cores used when a caller asks for "all cores" (n_jobs <= 0).
# Physical cores are preferred; psutil.cpu_count returns None when the count
# cannot be determined, in which case we fall back to logical cores and
# finally to a single core so that the value is always a usable int.
_cpu_count: Union[None, int] = psutil.cpu_count(logical=False)
if _cpu_count is None:
    _cpu_count = psutil.cpu_count(logical=True)
if _cpu_count is None:
    _cpu_count = 1
_logger = logging.getLogger(__name__)
27
28
29
def evaluate(adata: ad.AnnData,
             n_epoch: int,
             embedding_key: str = 'delta',
             n_neighbors: int = 15,
             resolutions: Iterable[float] = [0.01, 0.02, 0.04, 0.08, 0.16, 0.32, 0.64],
             clustering_method: str = "leiden",
             cell_type_col: str = "cell_types",
             batch_col: Union[str, None] = "batch_indices",
             color_by: Iterable[str] = None,
             return_fig: bool = False,
             plot_fname: str = "umap",
             plot_ftype: str = "jpg",
             plot_dir: Union[str, None] = None,
             plot_dpi: int = 300,
             min_dist: float = 0.3,
             spread: float = 1,
             n_jobs: int = 1,
             random_state: Union[None, int, np.random.RandomState, np.random.Generator] = 0,
             umap_kwargs: dict = dict()
             ) -> Mapping[str, Union[float, None, Figure]]:
    """Evaluates the clustering and batch correction performance of the given
    embeddings, and optionally plots the embeddings.

    Embeddings will be plotted if return_fig is True or plot_dir is provided.

    NOTE: Set n_jobs to 1 if you encounter pickling error.

    Side effects on adata: cell_type_col and batch_col are converted to
    categorical columns if they are not already, the kNN graph is (re)computed
    and stored, one clustering column per resolution is added to adata.obs,
    and the per-cell silhouette width is written to
    adata.obs['silhouette_width'].

    Args:
        adata: the dataset with the embedding to be evaluated.
        n_epoch: appended to plot_fname to form the file name of the saved
            plot. Only used if is drawing.
        embedding_key: the key to the embedding. Must be in adata.obsm, or
            'X' to use adata.X.
        n_neighbors: #neighbors used when computing neighborhood graph and
            calculating entropy of batch mixing / kBET.
        resolutions: resolutions used for clustering. If empty, clustering
            metrics are skipped.
        clustering_method: clustering method used. Should be one of 'leiden' or
            'louvain'.
        cell_type_col: a key in adata.obs to the cell type column. If falsy,
            all cell-type based metrics are skipped.
        batch_col: a key in adata.obs to the batch column. If falsy, all
            batch correction metrics are skipped.
        color_by: a list of adata.obs column keys to color the embeddings by.
            If None, defaults to the batch (when there is more than one batch)
            and cell type columns. Columns listed in adata.uns['color_by'] and
            the best clustering column are prepended. Only used if is drawing.
        return_fig: whether to return the Figure object. Useful for visualizing
            the plot.
        plot_fname: file name prefix of the generated plot, also the stem of
            the per cell type / batch silhouette table
            ('<plot_fname>.csv') written to plot_dir.
        plot_ftype: file type of the generated plot. Only used if is drawing.
        plot_dir: directory to save the generated plot. If None, do not save
            the plot.
        plot_dpi: dpi to save the plot.
        min_dist: the min_dist argument in sc.tl.umap. Only used if is drawing.
        spread: the spread argument in sc.tl.umap. Only used if is drawing.
        n_jobs: # jobs passed to calculate_entropy_batch_mixing and
            calculate_kbet.
        random_state: random state for knn calculation.
        umap_kwargs: other kwargs to pass to sc.pl.umap.

    Returns:
        A dict with keys "ari", "nmi" (best clustering agreement with the cell
        types, None if not computed), "asw" and "asw_2" (silhouette widths
        from sklearn and scib, 0 if there are fewer than two cell types),
        "ebm", "k_bet", "batch_asw", "batch_graph_score" (batch correction
        metrics, None if there are fewer than two batches) and "fig" (the
        plotted figure if drawing and return_fig is True, else None).
    """

    resolutions = list(resolutions)
    has_cell_types = bool(cell_type_col)

    # The metrics below assume discrete labels, so coerce both label columns
    # to categoricals up front.
    for label_col in (cell_type_col, batch_col):
        if label_col and adata.obs[label_col].dtype.name != 'category':
            adata.obs[label_col] = adata.obs[label_col].astype(str).astype('category')

    # calculate neighbors
    _get_knn_indices(adata, use_rep=embedding_key, n_neighbors=n_neighbors, random_state=random_state, calc_knn=True)

    # calculate clustering metrics
    if has_cell_types and len(resolutions) > 0:
        cluster_key, best_ari, best_nmi = clustering(adata, resolutions=resolutions, cell_type_col=cell_type_col,
                                                     batch_col=batch_col, clustering_method=clustering_method)
    else:
        cluster_key = best_ari = best_nmi = None

    # calculate silhouette widths (undefined with fewer than two labels)
    if has_cell_types and adata.obs[cell_type_col].nunique() > 1:
        embedding = adata.X if embedding_key == 'X' else adata.obsm[embedding_key]
        sw = silhouette_samples(embedding, adata.obs[cell_type_col])
        adata.obs['silhouette_width'] = sw
        asw = np.mean(sw)
        asw_2 = scib.me.silhouette(adata, group_key=cell_type_col, embed=embedding_key)

        if batch_col and plot_dir is not None:
            # Mean silhouette width per (cell type, batch) pair.
            sw_table = adata.obs.pivot_table(index=cell_type_col, columns=batch_col, values="silhouette_width",
                                             aggfunc="mean")
            sw_table.to_csv(os.path.join(plot_dir, f'{plot_fname}.csv'))
    else:
        asw = 0.
        asw_2 = 0.

    # calculate batch correction metrics
    need_batch = bool(batch_col) and adata.obs[batch_col].nunique() > 1
    ebm = k_bet = batch_asw = batch_graph_score = None
    if need_batch:
        ebm = calculate_entropy_batch_mixing(adata,
                                             use_rep=embedding_key,
                                             batch_col=batch_col,
                                             n_neighbors=n_neighbors,
                                             calc_knn=False,
                                             n_jobs=n_jobs,
                                             )
        # calculate_kbet returns (stat_mean, pvalue_mean, accept_rate); only
        # the acceptance rate is reported.
        k_bet = calculate_kbet(adata,
                               use_rep=embedding_key,
                               batch_col=batch_col,
                               n_neighbors=n_neighbors,
                               calc_knn=False,
                               n_jobs=n_jobs,
                               )[2]
        if has_cell_types:
            batch_asw = scib.me.silhouette_batch(adata, batch_key=batch_col, group_key=cell_type_col,
                                                 embed=embedding_key, verbose=False)
            # Computed inline so that the caller's cell_type_col is used as
            # the label column. Note that this rebuilds the neighbor graph
            # with scanpy's default settings.
            sc.pp.neighbors(adata, use_rep=embedding_key)
            batch_graph_score = scib.me.graph_connectivity(adata, label_key=cell_type_col)

    # plot UMAP embeddings
    fig = None
    if return_fig or plot_dir is not None:
        if color_by is None:
            default_cols = [batch_col, cell_type_col] if need_batch else [cell_type_col]
            color_by = [col for col in default_cols if col]
        color_by = list(color_by)
        if 'color_by' in adata.uns:
            for col in adata.uns['color_by']:
                if col not in color_by:
                    color_by.insert(0, col)
        if cluster_key is not None:
            color_by = [cluster_key] + color_by
        fig = draw_embeddings(adata=adata, color_by=color_by,
                              min_dist=min_dist, spread=spread,
                              ckpt_dir=plot_dir, fname=f'{plot_fname+str(n_epoch)}.{plot_ftype}', return_fig=return_fig,
                              dpi=plot_dpi,
                              umap_kwargs=umap_kwargs if umap_kwargs is not None else {})

    return dict(
        ari=best_ari,
        nmi=best_nmi,
        asw=asw,
        asw_2=asw_2,
        ebm=ebm,
        k_bet=k_bet,
        batch_asw=batch_asw,
        batch_graph_score=batch_graph_score,
        fig=fig
    )
184
185
def evaluate_ari(adata: ad.AnnData,
             n_epoch: int,
             embedding_key: str = 'delta',
             n_neighbors: int = 15,
             resolutions: Iterable[float] = [0.01, 0.02, 0.04, 0.08, 0.16, 0.32, 0.64],
             clustering_method: str = "leiden",
             cell_type_col: str = "cell_types",
             batch_col: Union[str, None] = "batch_indices",
             color_by: Iterable[str] = None,
             return_fig: bool = False,
             plot_fname: str = "umap",
             plot_ftype: str = "jpg",
             plot_dir: Union[str, None] = None,
             plot_dpi: int = 300,
             min_dist: float = 0.3,
             spread: float = 1,
             n_jobs: int = 1,
             random_state: Union[None, int, np.random.RandomState, np.random.Generator] = 0,
             umap_kwargs: dict = dict()
             ) -> Union[float, None]:
    """Computes only the best clustering ARI of the given embeddings.

    This is a lightweight variant of evaluate() that shares its signature so
    the two can be called interchangeably, but it only builds the kNN graph
    and runs the clustering sweep. No silhouette, batch correction metrics or
    plots are produced.

    Side effects on adata: cell_type_col and batch_col are converted to
    categorical columns if they are not already, the kNN graph is (re)computed
    and stored, and one clustering column per resolution is added to
    adata.obs.

    Args:
        adata: the dataset with the embedding to be evaluated.
        embedding_key: the key to the embedding. Must be in adata.obsm, or
            'X' to use adata.X.
        n_neighbors: #neighbors used when computing neighborhood graph.
        resolutions: resolutions used for clustering.
        clustering_method: clustering method used. Should be one of 'leiden' or
            'louvain'.
        cell_type_col: a key in adata.obs to the cell type column.
        batch_col: a key in adata.obs to the batch column.
        random_state: random state for knn calculation.
        n_epoch, color_by, return_fig, plot_fname, plot_ftype, plot_dir,
        plot_dpi, min_dist, spread, n_jobs, umap_kwargs: accepted for
            signature compatibility with evaluate(); unused here.

    Returns:
        The best ARI between the clusterings and the cell types, or None if
        cell_type_col is not in adata.obs or no resolution is given.
    """

    resolutions = list(resolutions)

    # Clustering agreement assumes discrete labels, so coerce both label
    # columns to categoricals up front.
    for label_col in (cell_type_col, batch_col):
        if label_col and adata.obs[label_col].dtype.name != 'category':
            adata.obs[label_col] = adata.obs[label_col].astype(str).astype('category')

    # calculate neighbors
    _get_knn_indices(adata, use_rep=embedding_key, n_neighbors=n_neighbors, random_state=random_state, calc_knn=True)

    # calculate clustering metrics
    if cell_type_col in adata.obs and len(resolutions) > 0:
        _, best_ari, _ = clustering(adata, resolutions=resolutions, cell_type_col=cell_type_col,
                                    batch_col=batch_col, clustering_method=clustering_method)
        return best_ari
    return None
267
268
def _eff_n_jobs(n_jobs: Union[None, int]) -> int:
    """Returns the effective number of parallel jobs.

    None means a single job, a positive value is used as is, and any value
    <= 0 means "all cores", i.e. the module-level _cpu_count (falling back to
    1 if that count is unknown).
    """
    if n_jobs is None:
        return 1
    if n_jobs > 0:
        return int(n_jobs)
    return _cpu_count or 1
273
274
275
def _calculate_kbet_for_one_chunk(knn_indices, attr_values, ideal_dist, n_neighbors):
    """Computes the kBET chi-square statistic and p-value for a chunk of cells.

    Args:
        knn_indices: 2D integer array; row i holds the indices of the
            neighbors of the i-th cell of the chunk. Only the first
            n_neighbors columns are used.
        attr_values: categorical batch labels of all cells, with one category
            per entry of ideal_dist, in the same order.
        ideal_dist: expected batch frequencies under perfect mixing.
        n_neighbors: neighborhood size the expected counts are based on.

    Returns:
        An array of shape (n_cells_in_chunk, 2) holding the chi-square
        statistic and its p-value for every cell.
    """
    dof = ideal_dist.size - 1
    # The expected counts are the same for every cell.
    expected_counts = ideal_dist * n_neighbors
    # A reused kNN graph may store more neighbors than requested; restrict to
    # n_neighbors so that observed and expected counts sum to the same total.
    neighborhoods = knn_indices[:, :n_neighbors]

    results = np.zeros((neighborhoods.shape[0], 2))
    for i, neighbor_ids in enumerate(neighborhoods):
        # NOTE: Do not use np.unique. Some of the batches may not be present in
        # the neighborhood; value_counts on a categorical keeps a (zero) count
        # for every category, in category order.
        observed_counts = pd.Series(attr_values[neighbor_ids]).value_counts(sort=False).values
        results[i, 0] = np.sum((observed_counts - expected_counts) ** 2 / expected_counts)

    # The survival function is 1 - cdf, without the cancellation error for
    # very small p-values.
    results[:, 1] = chi2.sf(results[:, 0], dof)
    return results
291
292
293
def _get_knn_indices(adata: ad.AnnData,
                     use_rep: str = "delta",
                     n_neighbors: int = 25,
                     random_state: int = 0,
                     calc_knn: bool = True
                     ) -> np.ndarray:
    """Returns the kNN index matrix of adata, computing it first if requested.

    Args:
        adata: annotated data matrix.
        use_rep: the representation to build the graph on. Must be "X" or a
            key in adata.obsm. Only used if calc_knn is True.
        n_neighbors: # nearest neighbors.
        random_state: random seed passed to the neighbor search. Only used if
            calc_knn is True.
        calc_knn: if True, compute the kNN graph and store distances and
            connectivities in adata.obsp, the kNN indices in
            adata.obsm['knn_indices'] and the bookkeeping entry in
            adata.uns['neighbors']. If False, reuse the stored graph, which
            must have been computed with at least n_neighbors neighbors.

    Returns:
        The first n_neighbors columns of adata.obsm['knn_indices'].
    """
    if calc_knn:
        assert use_rep == 'X' or use_rep in adata.obsm, f'{use_rep} not in adata.obsm and is not "X"'
        neighbors = sc.Neighbors(adata)
        neighbors.compute_neighbors(n_neighbors=n_neighbors, knn=True, use_rep=use_rep, random_state=random_state,
                                    write_knn_indices=True)
        adata.obsp['distances'] = neighbors.distances
        adata.obsp['connectivities'] = neighbors.connectivities
        adata.obsm['knn_indices'] = neighbors.knn_indices
        # Mirror the layout scanpy uses for adata.uns['neighbors'] so that
        # downstream scanpy tools (clustering, UMAP) can find the graph.
        adata.uns['neighbors'] = {
            'connectivities_key': 'connectivities',
            'distances_key': 'distances',
            'knn_indices_key': 'knn_indices',
            'params': {
                'n_neighbors': n_neighbors,
                'use_rep': use_rep,
                'metric': 'euclidean',
                'method': 'umap'
            }
        }
    else:
        assert 'neighbors' in adata.uns, 'No precomputed knn exists.'
        stored_n_neighbors = adata.uns['neighbors']['params']['n_neighbors']
        assert stored_n_neighbors >= n_neighbors, f"pre-computed n_neighbors is {stored_n_neighbors}, which is smaller than {n_neighbors}"

    # A reused graph may hold more neighbors than requested; callers expect
    # exactly n_neighbors per cell.
    return adata.obsm['knn_indices'][:, :n_neighbors]
324
325
def get_graph_connectivity(
        adata: ad.AnnData,
        use_rep: str = "delta",
        label_key: str = "cell_type"):
    """Computes the scib graph connectivity score of an embedding.

    NOTE: this calls sc.pp.neighbors with its default settings, which
    rebuilds the neighbor graph stored in adata.

    Args:
        adata: annotated data matrix.
        use_rep: the representation to build the neighbor graph on.
        label_key: a key in adata.obs to the cell label column handed to
            scib.me.graph_connectivity.

    Returns:
        The score returned by scib.me.graph_connectivity.
    """
    sc.pp.neighbors(adata, use_rep=use_rep)
    return scib.me.graph_connectivity(adata, label_key=label_key)
332
333
def calculate_kbet(
        adata: ad.AnnData,
        use_rep: str = "delta",
        batch_col: str = "batch_indices",
        n_neighbors: int = 25,
        alpha: float = 0.05,
        random_state: int = 0,
        n_jobs: Union[None, int] = None,
        calc_knn: bool = True
) -> Tuple[float, float, float]:
    """Calculates the kBET metric of the data.

    kBET measures if cells from different batches mix well in their local
    neighborhood.

    NOTE: if adata.obs[batch_col] is not categorical, it is converted to a
    categorical column in place.

    Args:
        adata: annotated data matrix.
        use_rep: the embedding to be used. Must exist in adata.obsm.
        batch_col: a key in adata.obs to the batch column.
        n_neighbors: # nearest neighbors.
        alpha: acceptance rate threshold. A cell is accepted if its kBET
            p-value is greater than or equal to alpha.
        random_state: random seed for the kNN computation. Only used if
            calc_knn is True.
        n_jobs: # jobs to generate. If <= 0, this is set to the number of
            physical cores.
        calc_knn: whether to re-calculate the kNN graph or reuse the one stored
            in adata.

    Returns:
        stat_mean: mean kBET chi-square statistic over all cells.
        pvalue_mean: mean kBET p-value over all cells.
        accept_rate: kBET Acceptance rate of the sample.
    """

    _logger.info('Calculating kbet...')
    assert batch_col in adata.obs
    if adata.obs[batch_col].dtype.name != "category":
        _logger.warning(f'Making the column {batch_col} of adata.obs categorical.')
        adata.obs[batch_col] = adata.obs[batch_col].astype('category')

    # Categories without any cell would get an expected count of zero and
    # turn every chi-square statistic into NaN, so drop them (on a local copy).
    batches = adata.obs[batch_col].cat.remove_unused_categories()

    ideal_dist = (
        batches.value_counts(normalize=True, sort=False).values
    )  # ideal no batch effect distribution
    nsample = adata.shape[0]
    nbatch = ideal_dist.size

    # Relabel the batches as 0..nbatch-1. rename_categories returns a new
    # Categorical; assigning to `.categories` is no longer supported by pandas.
    attr_values = batches.values.rename_categories(list(range(nbatch)))
    knn_indices = _get_knn_indices(adata, use_rep, n_neighbors, random_state, calc_knn)

    # partition into chunks of (almost) equal size, one per job
    n_jobs = max(1, min(_eff_n_jobs(n_jobs), nsample))
    chunk_sizes = np.full(n_jobs, nsample // n_jobs, dtype=int)
    chunk_sizes[: nsample % n_jobs] += 1
    starts = np.concatenate(([0], np.cumsum(chunk_sizes)))

    from joblib import Parallel, delayed, parallel_backend
    with parallel_backend("loky", n_jobs=n_jobs):
        kBET_arr = np.concatenate(
            Parallel()(
                delayed(_calculate_kbet_for_one_chunk)(
                    knn_indices[starts[i]: starts[i + 1], :], attr_values, ideal_dist, n_neighbors
                )
                for i in range(n_jobs)
            )
        )

    # Column 0 holds the chi-square statistics, column 1 the p-values.
    stat_mean, pvalue_mean = kBET_arr.mean(axis=0)
    accept_rate = (kBET_arr[:, 1] >= alpha).sum() / nsample

    return (stat_mean, pvalue_mean, accept_rate)
408
409
410
def _entropy(hist_data):
    """Returns the Shannon entropy (in nats) of the empirical distribution of
    the values in hist_data."""
    _, counts = np.unique(hist_data, return_counts=True)
    freqs = counts / counts.sum()
    # The tiny offset keeps the logarithm finite.
    return (-freqs * np.log(freqs + 1e-30)).sum()
414
415
416
def _entropy_batch_mixing_for_one_pool(batches, knn_indices, nsample, n_samples_per_pool):
    """Returns the mean neighborhood batch entropy of one random pool of cells.

    Draws n_samples_per_pool cells (with replacement, using NumPy's global
    random state) out of nsample and averages the entropy of the batch labels
    found in each drawn cell's neighborhood.

    Args:
        batches: batch label of every cell, in the same order as the rows of
            knn_indices.
        knn_indices: 2D integer array; row i holds the neighbor indices of
            cell i.
        nsample: total number of cells to draw from.
        n_samples_per_pool: number of cells in the pool.
    """
    # Work on a plain array so that the neighbor indices are always treated as
    # positions, never as index labels of a pandas Series.
    batch_labels = np.asarray(batches)
    sampled_cells = np.random.choice(nsample, size=n_samples_per_pool)
    return np.mean(
        [_entropy(batch_labels[knn_indices[cell]]) for cell in sampled_cells]
    )
425
426
427
def calculate_entropy_batch_mixing(
        adata: ad.AnnData,
        use_rep: str = "delta",
        batch_col: str = "batch_indices",
        n_neighbors: int = 50,
        n_pools: int = 50,
        n_samples_per_pool: int = 100,
        random_state: int = 0,
        n_jobs: Union[None, int] = None,
        calc_knn: bool = True
) -> float:
    """Calculates the entropy of batch mixing of the data.

    The entropy of batch mixing is the entropy of the batch labels within the
    neighborhood of a cell, averaged over randomly drawn pools of cells.
    Higher values mean that batches are better mixed locally.

    Args:
        adata: annotated data matrix.
        use_rep: the embedding to be used. Must exist in adata.obsm.
        batch_col: a key in adata.obs to the batch column.
        n_neighbors: # nearest neighbors.
        n_pools: #pools of cells to calculate entropy of batch mixing.
        n_samples_per_pool: #cells per pool to calculate within-pool entropy.
        random_state: random seed for the kNN computation. Only used if
            calc_knn is True; the random pools are not seeded by it.
        n_jobs: # jobs to generate. If <= 0, this is set to the number of
            physical cores.
        calc_knn: whether to re-calculate the kNN graph or reuse the one stored
            in adata.

    Returns:
        score: the mean entropy of batch mixing, averaged from n_pools samples.
    """

    _logger.info('Calculating batch mixing entropy...')
    nsample = adata.n_obs

    knn_indices = _get_knn_indices(adata, use_rep, n_neighbors, random_state, calc_knn)
    # Hand the workers a plain array: the kNN indices are positions, and a
    # pandas Series indexed by cell names must not be subscripted with them.
    batches = adata.obs[batch_col].to_numpy()
    # Normalise n_jobs as documented (None -> 1, <= 0 -> all physical cores).
    n_jobs = _eff_n_jobs(n_jobs)

    from joblib import Parallel, delayed, parallel_backend
    with parallel_backend("loky", n_jobs=n_jobs, inner_max_num_threads=1):
        score = np.mean(
            Parallel()(
                delayed(_entropy_batch_mixing_for_one_pool)(
                    batches, knn_indices, nsample, n_samples_per_pool
                )
                for _ in range(n_pools)
            )
        )
    return score
476
477
478
def clustering(
        adata: ad.AnnData,
        resolutions: Sequence[float],
        clustering_method: str = "leiden",
        cell_type_col: str = "cell_types",
        batch_col: str = "batch_indices"
) -> Tuple[str, float, float]:
    """Clusters the data and calculate agreement with cell type and batch
    variable.

    This method cluster the neighborhood graph (requires having run sc.pp.
    neighbors first) with "clustering_method" algorithm multiple times with the
    given resolutions, and return the best result in terms of ARI with cell
    type.
    The per-resolution metrics, including ARI with batch, are logged at DEBUG
    level but not returned.

    Every clustering is stored in adata.obs under the key
    '<clustering_method>_<resolution>'.

    Args:
        adata: the dataset to be clustered. adata.obsp should contain the keys
            'connectivities' and 'distances'.
        resolutions: a list of leiden/louvain resolution parameters. Will
            cluster with each resolution in the list and return the best result
            (in terms of ARI with cell type).
        clustering_method: Either "leiden" or "louvain".
        cell_type_col: a key in adata.obs to the cell type column.
        batch_col: a key in adata.obs to the batch column.

    Returns:
        best_cluster_key: a key in adata.obs to the best (in terms of ARI with
            cell type) cluster assignment column.
        best_ari: the best ARI with cell type.
        best_nmi: the best NMI with cell type over all resolutions (not
            necessarily reached at the resolution of best_ari).

    Raises:
        ValueError: if clustering_method is neither "leiden" nor "louvain".
    """

    assert len(resolutions) > 0, 'Must specify at least one resolution.'

    if clustering_method == 'leiden':
        clustering_func = sc.tl.leiden
    elif clustering_method == 'louvain':
        clustering_func = sc.tl.louvain
    else:
        raise ValueError("Please specify louvain or leiden for the clustering method argument.")
    assert cell_type_col in adata.obs, f"{cell_type_col} not in adata.obs"

    # Agreement with the batch variable is only meaningful with >1 batch.
    has_batches = batch_col in adata.obs and adata.obs[batch_col].nunique() > 1

    best_res, best_ari, best_nmi = None, -inf, -inf
    for res in resolutions:
        col = f'{clustering_method}_{res}'
        clustering_func(adata, resolution=res, key_added=col)
        ari = adjusted_rand_score(adata.obs[cell_type_col], adata.obs[col])
        nmi = normalized_mutual_info_score(adata.obs[cell_type_col], adata.obs[col])
        n_unique = adata.obs[col].nunique()
        if ari > best_ari:
            best_res, best_ari = res, ari
        if nmi > best_nmi:
            best_nmi = nmi
        if has_batches:
            ari_batch = adjusted_rand_score(adata.obs[batch_col], adata.obs[col])
            _logger.debug('Resolution: %5.3g\tARI: %7.4f\tNMI: %7.4f\tbARI: %7.4f\t# labels: %d',
                          res, ari, nmi, ari_batch, n_unique)
        else:
            _logger.debug('Resolution: %5.3g\tARI: %7.4f\tNMI: %7.4f\t# labels: %d',
                          res, ari, nmi, n_unique)

    return f'{clustering_method}_{best_res}', best_ari, best_nmi
542
543
544
def draw_embeddings(adata: ad.AnnData,
                    color_by: Union[str, Sequence[str], None] = None,
                    min_dist: float = 0.3,
                    spread: float = 1,
                    ckpt_dir: Union[str, None] = '.',
                    fname: str = "umap.pdf",
                    return_fig: bool = False,
                    dpi: int = 300,
                    umap_kwargs: dict = dict()
                    ) -> Union[None, Figure]:
    """Embeds, plots and optionally saves the neighborhood graph with UMAP.

    Requires having run sc.pp.neighbors first.

    Args:
        adata: the dataset to draw. adata.obsp should contain the keys
            'connectivities' and 'distances'.
        color_by: a str or a list of adata.obs keys to color the points in the
            scatterplot by. E.g. if both cell_type_col and batch_col is in
            color_by, then we would have two plots colored by cell type and
            batch variables, respectively.
        min_dist: The effective minimum distance between embedded points.
            Smaller values will result in a more clustered/clumped embedding
            where nearby points on the manifold are drawn closer together,
            while larger values will result on a more even dispersal of points.
        spread: The effective scale of embedded points. In combination with
            `min_dist` this determines how clustered/clumped the embedded
            points are.
        ckpt_dir: where to save the plot. Must be an existing directory. If
            None, do not save the plot.
        fname: file name of the saved plot. Only used if ckpt_dir is not None.
        return_fig: whether to return the Figure object. Useful for visualizing
            the plot. If False, the figure is cleared and closed.
        dpi: the dpi of the saved plot. Only used if ckpt_dir is not None.
        umap_kwargs: other kwargs to pass to sc.pl.umap. None is treated as
            no extra kwargs.

    Returns:
        If return_fig is True, return the figure containing the plot.
        Otherwise return None.
    """

    sc.tl.umap(adata, min_dist=min_dist, spread=spread)
    fig = sc.pl.umap(adata, color=color_by, show=False, return_fig=True, **(umap_kwargs or {}))
    if ckpt_dir is not None:
        assert os.path.exists(ckpt_dir), f'ckpt_dir {ckpt_dir} does not exist.'
        fig.savefig(
            os.path.join(ckpt_dir, fname),
            dpi=dpi, bbox_inches='tight'
        )
    if return_fig:
        return fig
    # Nobody else holds the figure: release it so repeated calls do not
    # accumulate open figures.
    fig.clf()
    plt.close(fig)
    return None
596
597
598
def set_figure_params(
        matplotlib_backend: str = 'agg',
        dpi: int = 120,
        frameon: bool = True,
        vector_friendly: bool = True,
        fontsize: int = 10,
        figsize: Sequence[int] = (10, 10)
):
    """Selects the matplotlib backend and sets scanpy's figure parameters.

    Args:
        matplotlib_backend: the backend to switch to. This can either be one
            of the standard backend names, which are case-insensitive:
            - interactive backends:
                GTK3Agg, GTK3Cairo, MacOSX, nbAgg,
                Qt4Agg, Qt4Cairo, Qt5Agg, Qt5Cairo,
                TkAgg, TkCairo, WebAgg, WX, WXAgg, WXCairo
            - non-interactive backends:
                agg, cairo, pdf, pgf, ps, svg, template
            or a string of the form: ``module://my.module.name``.
        dpi: resolution of rendered figures – this influences the size of
            figures in notebooks.
        frameon: add frames and axes labels to scatter plots.
        vector_friendly: plot scatter plots using `png` backend even when
            exporting as `pdf` or `svg`.
        fontsize: the fontsize for several `rcParams` entries.
        figsize: plt.rcParams['figure.figsize'].
    """
    matplotlib.use(matplotlib_backend)
    sc.set_figure_params(
        dpi=dpi,
        figsize=figsize,
        fontsize=fontsize,
        frameon=frameon,
        vector_friendly=vector_friendly,
    )