Switch to unified view

a b/src/multivelo/auxiliary.py
1
import numpy as np
2
import matplotlib.pyplot as plt
3
from scipy.sparse import coo_matrix, csr_matrix, diags
4
from umap.umap_ import fuzzy_simplicial_set
5
from anndata import AnnData
6
import scanpy as sc
7
import scvelo as scv
8
import pandas as pd
9
from tqdm.auto import tqdm
10
import scipy
11
import os
12
import sys
13
from joblib import Parallel, delayed
14
from tqdm.auto import tqdm
15
16
current_path = os.path.dirname(__file__)
17
src_path = os.path.join(current_path, "..")
18
sys.path.append(src_path)
19
20
from multivelo import mv_logging as logg
21
from multivelo import settings
22
23
current_path = os.path.dirname(__file__)
24
25
sys.path.append(current_path)
26
27
from pyWNN import pyWNN
28
29
def do_count(fastqs, input_loc, output_loc, whitelist_path=None,
             tech="10XV3", strand=None, threads=8, memory=4):
    """Get spliced and unspliced counts from fastq data.

    Makes use of the kallisto-bustools count function:
    https://www.kallistobus.tools/kb_usage/kb_count/
    kb-python.readthedocs.io/en/latest/autoapi/kb_python/count/index.html

    The kb command line is assembled as a list, temporarily installed as
    ``sys.argv`` and executed through ``kbm.main()``. ``sys.argv`` is always
    restored afterwards, even when the kb run fails.

    Parameters
    ----------
    fastqs: `List[str]`
        The file locations of the fastqs to process.
    input_loc: `str`
        The folder location of the reference files.
        The folder should contain an index file with the name "index.idx", a
        transcripts-to-gene file with the name "t2g.txt", a cDNA
        transcripts-to-capture file with the name "cdna_t2c.txt", and an
        intron transcripts-to-capture file with the name "intron_t2c.txt".
    output_loc: `str`
        The desired folder location of the output of the function.
    whitelist_path: `str` (default: `None`)
        Path to a barcode whitelist to use to replace the selected technology's
        whitelist.
    tech: `str` (default: `10XV3`)
        The technology used to collect the single-cell data.
    strand: `str` (default: `None`)
        The strandedness desired to process the data.
    threads: `int` (default: 8)
        The number of threads to use for parallel processing.
    memory: `int` (default: 4)
        Maximum memory (in GB) to use while processing.

    Returns
    -------
    adata_count: :class:`~anndata.AnnData`
        The AnnData object read from
        ``<output_loc>/counts_unfiltered/adata.h5ad``, containing the spliced
        and unspliced counts as well as associated gene names.
    """
    # locations of the reference files inside input_loc
    index_loc = input_loc + "/index.idx"
    t2g_loc = input_loc + "/t2g.txt"
    cdna_t2c = input_loc + "/cdna_t2c.txt"
    intron_t2c = input_loc + "/intron_t2c.txt"

    # assemble the argument vector for kb count; threads and memory are
    # converted to the string forms the command line expects.
    # presumably the first "count" only fills the argv[0] (program name) slot
    # and the second one is the actual sub-command -- confirm against kb
    input_array = ["count",
                   "count",
                   "-i", index_loc, "-g", t2g_loc,
                   "-x", tech,
                   "-o", output_loc,
                   "-t", str(threads), "-m", str(memory) + "G",
                   "--workflow", "lamanno",
                   "-c1", cdna_t2c,
                   "-c2", intron_t2c,
                   "--h5ad"]

    # specify the stranded-ness of the run
    if strand is not None:
        input_array += ["--strand", strand]

    # specify the whitelist path of the run
    if whitelist_path is not None:
        input_array += ["-w", whitelist_path]

    # add the fastqs we are doing the run on
    input_array.extend(fastqs)

    # keep the caller's argv so it can be put back after the run
    orig_argv = sys.argv
    sys.argv = input_array
    try:
        # NOTE(review): `kbm` is not imported in the visible part of this
        # file; it presumably refers to kb_python's `main` module -- confirm
        # that the import exists.
        kbm.main()
    finally:
        # restore argv even if kb count raised, so a failed run does not
        # leave the process with a foreign command line
        sys.argv = orig_argv

    # read back the AnnData object written by kb count
    path = output_loc + "/counts_unfiltered"
    adata_count = sc.read(path + "/adata.h5ad")

    return adata_count
124
125
126
def prepare_gene_mat(var_dict, peaks, gene_mat, adata_atac_X_copy, i):
    """Accumulate the columns of a gene's peaks into column `i` of `gene_mat`.

    Parameters
    ----------
    var_dict: `dict`
        Maps a peak name to its column index in `adata_atac_X_copy`.
    peaks: iterable of `str`
        Peak names assigned to the gene. Names missing from `var_dict` are
        skipped.
    gene_mat: array
        Output matrix, modified in place.
    adata_atac_X_copy: array
        Peak count matrix whose columns are addressed through `var_dict`.
    i: `int`
        Column of `gene_mat` that receives the sum.
    """
    # resolve the known peaks to column indices lazily, in their given order
    known_columns = (var_dict[name] for name in peaks if name in var_dict)
    for column in known_columns:
        gene_mat[:, i] += adata_atac_X_copy[:, column]
133
134
135
def _record_annotated_peak(peak, gene, dist, annot,
                           promoter_dict, gene_body_dict, distal_dict):
    """File one annotated peak under its gene in the matching dictionary."""
    if annot == 'promoter':
        promoter_dict.setdefault(gene, []).append(peak)
    elif annot == 'distal':
        # a "distal" peak at distance '0' is kept with the gene-body peaks
        target = gene_body_dict if dist == '0' else distal_dict
        target.setdefault(gene, []).append(peak)


def aggregate_peaks_10x(adata_atac, peak_annot_file, linkage_file,
                        peak_dist=10000, min_corr=0.5, gene_body=False,
                        return_dict=False, parallel=False, n_jobs=1):
    """Peak to gene aggregation.

    This function aggregates promoter and enhancer peaks to genes based on the
    10X linkage file.

    Parameters
    ----------
    adata_atac: :class:`~anndata.AnnData`
        ATAC anndata object which stores raw peak counts.
    peak_annot_file: `str`
        Peak annotation file from 10X CellRanger ARC.
    linkage_file: `str`
        Peak-gene linkage file from 10X CellRanger ARC. This file stores highly
        correlated peak-peak and peak-gene pair information.
    peak_dist: `int` (default: 10000)
        Maximum distance for peaks to be included for a gene.
    min_corr: `float` (default: 0.5)
        Minimum correlation for a peak to be considered as enhancer.
    gene_body: `bool` (default: `False`)
        Whether to add gene body peaks to the associated promoters.
    return_dict: `bool` (default: `False`)
        Whether to return promoter and enhancer dictionaries.
    parallel: `bool` (default: `False`)
        Whether to fill the gene matrix with `joblib.Parallel`. Ignored
        (treated as `False`) when `n_jobs == 1`.
    n_jobs: `int` (default: 1)
        Number of jobs passed to `joblib.Parallel` when `parallel` is set.

    Returns
    -------
    A new ATAC anndata object which stores gene aggregated peak counts.
    Additionally, if `return_dict==True`:
        A dictionary which stores genes and promoter peaks. Note that this is
        the same object the aggregation works on, so by the time it is
        returned it also holds the gene-body peaks (if `gene_body`) and the
        enhancer peaks that were added to each gene.
        And a dictionary which stores genes and enhancer peaks.

    Raises
    ------
    ValueError
        If the header of the peak annotation file has neither 4 nor 6
        tab-separated columns.
    """
    promoter_dict = {}
    distal_dict = {}
    gene_body_dict = {}
    # gene -> [[linked peaks], [matching correlations]]
    corr_dict = {}

    # read annotations
    with open(peak_annot_file) as f:
        header = next(f)
        tmp = header.split('\t')
        # the column count of the header tells the two file layouts apart
        if len(tmp) == 4:
            cellranger_version = 1
        elif len(tmp) == 6:
            cellranger_version = 2
        else:
            raise ValueError('Peak annotation file should contain 4 columns '
                             '(CellRanger ARC 1.0.0) or 6 columns (CellRanger '
                             'ARC 2.0.0)')

        logg.update(f'CellRanger ARC identified as {cellranger_version}.0.0',
                    v=1)

        if cellranger_version == 1:
            # one line per peak; genes, distances and types are
            # ';'-separated parallel lists
            for line in f:
                tmp = line.rstrip().split('\t')
                tmp1 = tmp[0].split('_')
                peak = f'{tmp1[0]}:{tmp1[1]}-{tmp1[2]}'
                if tmp[1] != '':
                    genes = tmp[1].split(';')
                    dists = tmp[2].split(';')
                    types = tmp[3].split(';')
                    for i, gene in enumerate(genes):
                        _record_annotated_peak(peak, gene, dists[i], types[i],
                                               promoter_dict, gene_body_dict,
                                               distal_dict)
        else:
            # one line per peak-gene pair
            for line in f:
                tmp = line.rstrip().split('\t')
                peak = f'{tmp[0]}:{tmp[1]}-{tmp[2]}'
                _record_annotated_peak(peak, tmp[3], tmp[4], tmp[5],
                                       promoter_dict, gene_body_dict,
                                       distal_dict)

    # read linkages
    # NOTE(review): the `peakN not in corr_dict[gene]` tests below compare a
    # peak string against the two inner lists of corr_dict[gene], so they are
    # always true and never filter duplicates. They are kept as-is because
    # testing against corr_dict[gene][0] would change which correlation value
    # is seen for a repeated peak; duplicates are filtered later, when peaks
    # are appended to gene_dict.
    with open(linkage_file) as f:
        for line in f:
            tmp = line.rstrip().split('\t')
            if tmp[12] == "peak-peak":
                peak1 = f'{tmp[0]}:{tmp[1]}-{tmp[2]}'
                peak2 = f'{tmp[3]}:{tmp[4]}-{tmp[5]}'
                tmp2 = tmp[6].split('><')[0][1:].split(';')
                tmp3 = tmp[6].split('><')[1][:-1].split(';')
                corr = float(tmp[7])
                for t2 in tmp2:
                    gene1 = t2.split('_')
                    for t3 in tmp3:
                        gene2 = t3.split('_')
                        # exactly one of the peaks is in a promoter, and the
                        # peaks belong to the same gene or are close in
                        # distance
                        if (((gene1[1] == "promoter") !=
                             (gene2[1] == "promoter")) and
                            ((gene1[0] == gene2[0]) or
                             (float(tmp[11]) < peak_dist))):

                            # the gene owning the promoter receives the peak
                            if gene1[1] == "promoter":
                                gene = gene1[0]
                            else:
                                gene = gene2[0]
                            if gene in corr_dict:
                                # peak 1 is in promoter, peak 2 is not in gene
                                # body -> peak 2 is added to gene 1
                                if (peak2 not in corr_dict[gene] and
                                    gene1[1] == "promoter" and
                                    (gene2[0] not in gene_body_dict or
                                     peak2 not in gene_body_dict[gene2[0]])):

                                    corr_dict[gene][0].append(peak2)
                                    corr_dict[gene][1].append(corr)
                                # peak 2 is in promoter, peak 1 is not in gene
                                # body -> peak 1 is added to gene 2
                                if (peak1 not in corr_dict[gene] and
                                    gene2[1] == "promoter" and
                                    (gene1[0] not in gene_body_dict or
                                     peak1 not in gene_body_dict[gene1[0]])):

                                    corr_dict[gene][0].append(peak1)
                                    corr_dict[gene][1].append(corr)
                            else:
                                # peak 1 is in promoter, peak 2 is not in gene
                                # body -> peak 2 is added to gene 1
                                if (gene1[1] == "promoter" and
                                    (gene2[0] not in gene_body_dict or
                                     peak2 not in gene_body_dict[gene2[0]])):

                                    corr_dict[gene] = [[peak2], [corr]]
                                # peak 2 is in promoter, peak 1 is not in gene
                                # body -> peak 1 is added to gene 2
                                if (gene2[1] == "promoter" and
                                    (gene1[0] not in gene_body_dict or
                                     peak1 not in gene_body_dict[gene1[0]])):

                                    corr_dict[gene] = [[peak1], [corr]]
            elif tmp[12] == "peak-gene":
                peak1 = f'{tmp[0]}:{tmp[1]}-{tmp[2]}'
                tmp2 = tmp[6].split('><')[0][1:].split(';')
                gene2 = tmp[6].split('><')[1][:-1]
                corr = float(tmp[7])
                for t2 in tmp2:
                    gene1 = t2.split('_')
                    # peak 1 belongs to gene 2 or they are close in distance;
                    # the peak is then filed under gene1[0]
                    if ((gene1[0] == gene2) or (float(tmp[11]) < peak_dist)):
                        gene = gene1[0]
                        if gene in corr_dict:
                            if (peak1 not in corr_dict[gene] and
                                gene1[1] != "promoter" and
                                (gene1[0] not in gene_body_dict or
                                 peak1 not in gene_body_dict[gene1[0]])):

                                corr_dict[gene][0].append(peak1)
                                corr_dict[gene][1].append(corr)
                        else:
                            if (gene1[1] != "promoter" and
                                (gene1[0] not in gene_body_dict or
                                 peak1 not in gene_body_dict[gene1[0]])):
                                corr_dict[gene] = [[peak1], [corr]]
            elif tmp[12] == "gene-peak":
                peak2 = f'{tmp[3]}:{tmp[4]}-{tmp[5]}'
                gene1 = tmp[6].split('><')[0][1:]
                tmp3 = tmp[6].split('><')[1][:-1].split(';')
                corr = float(tmp[7])
                for t3 in tmp3:
                    gene2 = t3.split('_')
                    # peak 2 belongs to gene 1 or they are close in distance
                    # -> peak 2 is added to gene 1
                    if ((gene1 == gene2[0]) or (float(tmp[11]) < peak_dist)):
                        gene = gene1
                        if gene in corr_dict:
                            if (peak2 not in corr_dict[gene] and
                                gene2[1] != "promoter" and
                                (gene2[0] not in gene_body_dict or
                                 peak2 not in gene_body_dict[gene2[0]])):

                                corr_dict[gene][0].append(peak2)
                                corr_dict[gene][1].append(corr)
                        else:
                            if (gene2[1] != "promoter" and
                                (gene2[0] not in gene_body_dict or
                                 peak2 not in gene_body_dict[gene2[0]])):

                                corr_dict[gene] = [[peak2], [corr]]

    # gene_dict is an alias of promoter_dict (not a copy): everything appended
    # below is also visible through promoter_dict
    gene_dict = promoter_dict
    enhancer_dict = {}
    promoter_genes = list(promoter_dict.keys())
    logg.update(f'Found {len(promoter_genes)} genes with promoter peaks', 1)
    for gene in promoter_genes:
        if gene_body:  # add gene-body peaks
            if gene in gene_body_dict:
                for peak in gene_body_dict[gene]:
                    if peak not in gene_dict[gene]:
                        gene_dict[gene].append(peak)
        enhancer_dict[gene] = []
        if gene in corr_dict:  # add enhancer peaks
            linked_peaks, linked_corrs = corr_dict[gene]
            for peak, corr in zip(linked_peaks, linked_corrs):
                if corr > min_corr:
                    if peak not in gene_dict[gene]:
                        gene_dict[gene].append(peak)
                        enhancer_dict[gene].append(peak)

    # aggregate to genes
    # densify the peak counts; `.toarray()` is available on every scipy sparse
    # container, and an already-dense `.X` is used as it is
    counts = adata_atac.X
    if hasattr(counts, 'toarray'):
        adata_atac_X_copy = counts.toarray()
    else:
        adata_atac_X_copy = np.asarray(counts)
    gene_mat = np.zeros((adata_atac.shape[0], len(promoter_genes)))
    var_names = adata_atac.var_names.to_numpy()
    # peak name -> column index in the dense count matrix
    var_dict = {name: i for i, name in enumerate(var_names)}

    # if we only want to run one job at a time, then no parallelization
    # is necessary
    if n_jobs == 1:
        parallel = False

    if parallel:
        # fill gene_mat from several workers; require='sharedmem' lets all of
        # them write into the same gene_mat array
        Parallel(n_jobs=n_jobs,
                 require='sharedmem')(
                 delayed(prepare_gene_mat)(var_dict,
                                           gene_dict[promoter_genes[i]],
                                           gene_mat,
                                           adata_atac_X_copy,
                                           i)
                 for i in tqdm(range(len(promoter_genes))))

    else:
        # if we aren't running in parallel, just call prepare_gene_mat
        # from a for loop
        for i, gene in tqdm(enumerate(promoter_genes),
                            total=len(promoter_genes)):
            prepare_gene_mat(var_dict,
                             gene_dict[gene],
                             gene_mat,
                             adata_atac_X_copy,
                             i)

    gene_mat[gene_mat < 0] = 0
    gene_mat = AnnData(X=csr_matrix(gene_mat))
    gene_mat.obs_names = pd.Index(list(adata_atac.obs_names))
    gene_mat.var_names = pd.Index(promoter_genes)
    # drop genes that received no counts at all
    gene_mat = gene_mat[:, gene_mat.X.sum(0) > 0]
    if return_dict:
        return gene_mat, promoter_dict, enhancer_dict
    else:
        return gene_mat
420
421
422
def tfidf_norm(adata_atac, scale_factor=1e4, copy=False):
    """TF-IDF normalization.

    Normalizes the counts stored in `.X` of an AnnData object with TF-IDF.

    Parameters
    ----------
    adata_atac: :class:`~anndata.AnnData`
        ATAC anndata object.
    scale_factor: `float` (default: 1e4)
        Value the normalized matrix is multiplied by.
    copy: `bool` (default: `False`)
        Whether to return a copy or modify `.X` directly.

    Returns
    -------
    If `copy==True`, a new ATAC anndata object which stores normalized counts
    in `.X`. Otherwise `.X` of the input is overwritten and nothing is
    returned.
    """
    counts = adata_atac.X

    # term frequency: every row is divided by its row total
    row_totals = counts.sum(1)
    tf = counts.multiply(csr_matrix(1.0/row_totals))

    # inverse document frequency: log1p(number of rows / column total),
    # laid out as a diagonal matrix so it can be applied column-wise
    column_totals = counts.sum(0)
    idf = diags(np.ravel(counts.shape[0] / column_totals)).log1p()

    if not copy:
        adata_atac.X = tf.dot(idf) * scale_factor
        return None

    normalized = adata_atac.copy()
    normalized.X = tf.dot(idf) * scale_factor
    return normalized
451
452
453
def gen_wnn(adata_rna, adata_adt, dims, nn, random_state=0):
    """Computes inputs for KNN smoothing.

    This function calculates the nn_idx and nn_dist matrices needed
    to run knn_smooth_chrom().

    Parameters
    ----------
    adata_rna: :class:`~anndata.AnnData`
        RNA anndata object.
    adata_adt: :class:`~anndata.AnnData`
        Anndata object of the second modality; its `.X` is decomposed with
        a truncated SVD.
    dims: `List[int]`
        Number of components for the RNA PCA (index=0) and for the SVD of
        the second modality (index=1).
    nn: `int`
        Number of neighbors requested from WNN; also the number of columns
        of the returned matrices.
    random_state: `int` (default: 0)
        Accepted but not used by the current body; the PCA and WNN seeds are
        fixed to 42.

    Returns
    -------
    nn_idx: `np.ndarray`
        KNN index matrix of size (cells, nn), holding 1-based cell indices.
    nn_dist: `np.ndarray`
        KNN distance matrix of size (cells, nn).
    """
    # operate on copies so the caller's objects are left unchanged
    rna_work = adata_rna.copy()
    adt_work = adata_adt.copy()

    # PCA on the RNA modality
    sc.tl.pca(rna_work,
              n_comps=dims[0],
              random_state=np.random.RandomState(seed=42),
              use_highly_variable=True)

    # truncated SVD on the second modality; the first element of the result
    # is stored as its embedding
    svd_result = scipy.sparse.linalg.svds(adt_work.X, k=dims[1])
    adt_work.obsm['X_lsi'] = svd_result[0]

    # gather both embeddings on the RNA object under the names WNN reads
    rna_work.obsm['X_rna_pca'] = rna_work.obsm.pop('X_pca')
    rna_work.obsm['X_adt_lsi'] = adt_work.obsm['X_lsi']

    # run WNN
    wnn_runner = pyWNN(rna_work,
                       reps=['X_rna_pca', 'X_adt_lsi'],
                       npcs=dims,
                       n_neighbors=nn,
                       seed=42)
    adata_seurat = wnn_runner.compute_wnn(rna_work)

    # cell-to-neighbor distances in coordinate form
    distances = scipy.sparse.coo_matrix(adata_seurat.obsp["WNN_distance"])
    n_cells = adata_seurat.obsp['WNN_distance'].shape[0]

    # result arrays, one row per cell and one column per neighbor
    nn_dist = np.zeros(shape=(n_cells, nn))
    nn_idx = np.zeros(shape=(n_cells, nn))

    # walk over the stored entries; a running entry counter modulo nn
    # selects the column each entry is written to
    entries = zip(distances.row, distances.col, distances.data)
    for seen, (cell, neighbor, dist) in enumerate(entries):
        slot = seen % nn
        nn_dist[cell][slot] = dist
        # neighbor indices are stored 1-based
        # (a holdover from R multimodalneighbors())
        nn_idx[cell][slot] = int(neighbor) + 1

    return nn_idx, nn_dist
534
535
536
def knn_smooth_chrom(adata_atac, nn_idx=None, nn_dist=None, conn=None,
                     n_neighbors=None):
    """KNN smoothing.

    Smooths (imputes) the count matrix with k nearest neighbors. The neighbor
    information is given either as KNN index and distance matrices or as a
    pre-computed connectivities matrix (for example from an adata_rna
    object). When both index and distance matrices are supplied they take
    precedence over `conn`.

    Parameters
    ----------
    adata_atac: :class:`~anndata.AnnData`
        ATAC anndata object.
    nn_idx: `np.ndarray` (default: `None`)
        KNN index matrix of size (cells, k), with 1-based indices.
    nn_dist: `np.ndarray` (default: `None`)
        KNN distance matrix of size (cells, k).
    conn: `csr_matrix` (default: `None`)
        Pre-computed connectivities matrix.
    n_neighbors: `int` (default: `None`)
        Top N neighbors to extract for each cell in the connectivities matrix.

    Returns
    -------
    `.layers['Mc']` stores imputed values and `.obsp['connectivities']` the
    connectivities matrix that was used.

    Raises
    ------
    ValueError
        If neither the KNN matrices nor `conn` are given, or if the KNN
        matrices do not have one row per observation.
    """
    knn_supplied = nn_idx is not None and nn_dist is not None

    if not knn_supplied and conn is None:
        raise ValueError('Please input nearest neighbor indices and distances,'
                         ' or a connectivities matrix of size n x n, with '
                         'columns being neighbors.'
                         ' For example, RNA connectivities can usually be '
                         'found in adata.obsp.')

    if knn_supplied:
        n_obs = adata_atac.shape[0]
        if nn_idx.shape[0] != n_obs:
            raise ValueError('Number of rows of KNN indices does not equal to '
                             'number of observations.')
        if nn_dist.shape[0] != n_obs:
            raise ValueError('Number of rows of KNN distances does not equal '
                             'to number of observations.')
        # an empty placeholder matrix is passed as data; the neighbor graph
        # itself comes from the supplied indices (shifted to 0-based) and
        # distances
        placeholder = coo_matrix(([], ([], [])), shape=(nn_idx.shape[0], 1))
        conn, sigma, rho, dists = fuzzy_simplicial_set(
            placeholder, nn_idx.shape[1], None, None,
            knn_indices=nn_idx-1,
            knn_dists=nn_dist,
            return_dists=True)

    # work on a private CSR copy so the caller's matrix is not modified
    conn = conn.tocsr().copy()
    neighbors_per_cell = (conn > 0).sum(1).A1
    # only trim when every cell has more than n_neighbors connections
    if n_neighbors is not None and n_neighbors < neighbors_per_cell.min():
        conn = top_n_sparse(conn, n_neighbors)
    # every cell also counts as its own neighbor
    conn.setdiag(1)
    # rescale each row to sum to one before averaging over neighbors
    row_totals = conn.sum(1)
    conn_norm = conn.multiply(1.0 / row_totals).tocsr()
    adata_atac.layers['Mc'] = csr_matrix.dot(conn_norm, adata_atac.X)
    adata_atac.obsp['connectivities'] = conn
591
592
593
def calculate_qc_metrics(adata, **kwargs):
    """Basic QC metrics.

    This function calls `scanpy.pp.calculate_qc_metrics` and additionally
    computes the per-observation total counts and ratio of the unspliced and
    spliced matrices, as well as the cell cycle scores (with
    `scvelo.tl.score_genes_cell_cycle`).

    Parameters
    ----------
    adata: :class:`~anndata.AnnData`
        RNA anndata object. Required layers: `unspliced` and `spliced`.
    Additional parameters passed to `scanpy.pp.calculate_qc_metrics`.

    Returns
    -------
    Nothing is returned; the following columns are written to `adata.obs`,
    next to whatever `scvelo.tl.score_genes_cell_cycle` adds.
    total_unspliced, total_spliced: `.obs`
        total counts of unspliced and spliced matrices.
    unspliced_ratio: `.obs`
        ratio of unspliced counts vs (unspliced + spliced counts).
    cell_cycle_score: `.obs`
        cell cycle score difference between G2M_score and S_score.

    Raises
    ------
    ValueError
        If the `spliced` or `unspliced` layer is missing.
    """
    # NOTE(review): the return value of this call is discarded, so its
    # results are only kept if the caller's kwargs make it write in place
    # -- confirm against the scanpy documentation
    sc.pp.calculate_qc_metrics(adata, **kwargs)

    # both layers must be present; spliced is checked first
    missing_layer_errors = {
        'spliced': 'Spliced matrix not found in adata.layers',
        'unspliced': 'Unspliced matrix not found in adata.layers',
    }
    for layer_name, message in missing_layer_errors.items():
        if layer_name not in adata.layers:
            raise ValueError(message)

    logg.update(adata.layers['spliced'].shape, v=1)

    # row-wise totals of the densified layers, with NaN entries ignored
    spliced_dense = adata.layers['spliced'].toarray()
    unspliced_dense = adata.layers['unspliced'].toarray()
    total_s = np.nansum(spliced_dense, axis=1)
    total_u = np.nansum(unspliced_dense, axis=1)

    logg.update(total_u.shape, v=1)

    adata.obs['total_unspliced'] = total_u
    adata.obs['total_spliced'] = total_s
    adata.obs['unspliced_ratio'] = total_u / (total_s + total_u)

    # the cell cycle scoring provides the G2M_score and S_score columns
    # read below
    scv.tl.score_genes_cell_cycle(adata)
    g2m_score = adata.obs['G2M_score']
    s_score = adata.obs['S_score']
    adata.obs['cell_cycle_score'] = g2m_score - s_score
637
638
639
def ellipse_fit(adata,
640
                genes,
641
                color_by='quantile',
642
                n_cols=8,
643
                title=None,
644
                figsize=None,
645
                axis_on=False,
646
                pointsize=2,
647
                linewidth=2
648
                ):
649
    """Fit ellipses to unspliced and spliced phase portraits.
650
651
    This function plots the ellipse fits on the unspliced-spliced phase
652
    portraits.
653
654
    Parameters
655
    ----------
656
    adata: :class:`~anndata.AnnData`
657
        RNA anndata object. Required fields: `Mu` and `Ms`.
658
    genes: `str`,  list of `str`
659
        List of genes to plot.
660
    color_by: `str` (default: `quantile`)
661
        Color by the four quantiles based on ellipse fit if `quantile`. Other
662
        common values are leiden, louvain, celltype, etc.
663
        If not `quantile`, the color field must be present in `.uns`, which
664
        can be pre-computed with `scanpy.pl.scatter`.
665
        For `quantile`, red, orange, green, and blue represent quantile left,
666
        top, right, and bottom, respectively.
667
        If `quantile_scores`, `multivelo.compute_quantile_scores` function
668
        must have been run.
669
    n_cols: `int` (default: 8)
670
        Number of columns to plot on each row.
671
    figsize: `tuple` (default: `None`)
672
        Total figure size.
673
    title: `tuple` (default: `None`)
674
        Title of the figure. Default is `Ellipse Fit`.
675
    axis_on: `bool` (default: `False`)
676
        Whether to show axis labels.
677
    pointsize: `float` (default: 2)
678
        Point size for scatter plots.
679
    linewidth: `float` (default: 2)
680
        Line width for ellipse.
681
    """
682
    by_quantile = color_by == 'quantile'
683
    by_quantile_score = color_by == 'quantile_scores'
684
    if not by_quantile and not by_quantile_score:
685
        types = adata.obs[color_by].cat.categories
686
        colors = adata.uns[f'{color_by}_colors']
687
    gn = len(genes)
688
    if gn < n_cols:
689
        n_cols = gn
690
    fig, axs = plt.subplots(-(-gn // n_cols), n_cols, squeeze=False,
691
                            figsize=(2 * n_cols, 2.4 * (-(-gn // n_cols)))
692
                            if figsize is None else figsize)
693
    count = 0
694
    for gene in genes:
695
        u = np.array(adata[:, gene].layers['Mu'])
696
        s = np.array(adata[:, gene].layers['Ms'])
697
        row = count // n_cols
698
        col = count % n_cols
699
        non_zero = (u > 0) & (s > 0)
700
        if np.sum(non_zero) < 10:
701
            count += 1
702
            fig.delaxes(axs[row, col])
703
            continue
704
705
        mean_u, mean_s = np.mean(u[non_zero]), np.mean(s[non_zero])
706
        std_u, std_s = np.std(u[non_zero]), np.std(s[non_zero])
707
        u_ = (u - mean_u)/std_u
708
        s_ = (s - mean_s)/std_s
709
        X = np.reshape(s_[non_zero], (-1, 1))
710
        Y = np.reshape(u_[non_zero], (-1, 1))
711
712
        # Ax^2 + Bxy + Cy^2 + Dx + Ey + 1 = 0
713
        A = np.hstack([X**2, X * Y, Y**2, X, Y])
714
        b = -np.ones_like(X)
715
        x, res, _, _ = np.linalg.lstsq(A, b)
716
        x = x.squeeze()
717
        A, B, C, D, E = x
718
        good_fit = B**2 - 4*A*C < 0
719
        theta = np.arctan(B/(A - C))/2 \
720
            if x[0] > x[2] \
721
            else np.pi/2 + np.arctan(B/(A - C))/2
722
        good_fit = good_fit & (theta < np.pi/2) & (theta > 0)
723
        if not good_fit:
724
            count += 1
725
            fig.delaxes(axs[row, col])
726
            continue
727
        x_coord = np.linspace((-mean_s)/std_s, (np.max(s)-mean_s)/std_s, 500)
728
        y_coord = np.linspace((-mean_u)/std_u, (np.max(u)-mean_u)/std_u, 500)
729
        X_coord, Y_coord = np.meshgrid(x_coord, y_coord)
730
        Z_coord = (A * X_coord**2 + B * X_coord * Y_coord + C * Y_coord**2 +
731
                   D * X_coord + E * Y_coord + 1)
732
733
        M0 = np.array([
734
             A, B/2, D/2,
735
             B/2, C, E/2,
736
             D/2, E/2, 1,
737
        ]).reshape(3, 3)
738
        M = np.array([
739
            A, B/2,
740
            B/2, C,
741
        ]).reshape(2, 2)
742
        l1, l2 = np.sort(np.linalg.eigvals(M))
743
        xc = (B*E - 2*C*D)/(4*A*C - B**2)
744
        yc = (B*D - 2*A*E)/(4*A*C - B**2)
745
        slope_major = np.tan(theta)
746
        theta2 = np.pi/2 + theta
747
        slope_minor = np.tan(theta2)
748
        a = np.sqrt(-np.linalg.det(M0)/np.linalg.det(M)/l2)
749
        b = np.sqrt(-np.linalg.det(M0)/np.linalg.det(M)/l1)
750
        xtop = xc + a*np.cos(theta)
751
        ytop = yc + a*np.sin(theta)
752
        xbot = xc - a*np.cos(theta)
753
        ybot = yc - a*np.sin(theta)
754
        xtop2 = xc + b*np.cos(theta2)
755
        ytop2 = yc + b*np.sin(theta2)
756
        xbot2 = xc - b*np.cos(theta2)
757
        ybot2 = yc - b*np.sin(theta2)
758
        mse = res[0] / np.sum(non_zero)
759
        major = lambda x, y: (y - yc) - (slope_major * (x - xc))
760
        minor = lambda x, y: (y - yc) - (slope_minor * (x - xc))
761
        quant1 = (major(s_, u_) > 0) & (minor(s_, u_) < 0)
762
        quant2 = (major(s_, u_) > 0) & (minor(s_, u_) > 0)
763
        quant3 = (major(s_, u_) < 0) & (minor(s_, u_) > 0)
764
        quant4 = (major(s_, u_) < 0) & (minor(s_, u_) < 0)
765
        if (np.sum(quant1 | quant4) < 10) or (np.sum(quant2 | quant3) < 10):
766
            count += 1
767
            continue
768
769
        if by_quantile:
770
            axs[row, col].scatter(s_[quant1], u_[quant1], s=pointsize,
771
                                  c='tab:red', alpha=0.6)
772
            axs[row, col].scatter(s_[quant2], u_[quant2], s=pointsize,
773
                                  c='tab:orange', alpha=0.6)
774
            axs[row, col].scatter(s_[quant3], u_[quant3], s=pointsize,
775
                                  c='tab:green', alpha=0.6)
776
            axs[row, col].scatter(s_[quant4], u_[quant4], s=pointsize,
777
                                  c='tab:blue', alpha=0.6)
778
        elif by_quantile_score:
779
            if 'quantile_scores' not in adata.layers:
780
                raise ValueError('Please run multivelo.compute_quantile_scores'
781
                                 ' first to compute quantile scores.')
782
            axs[row, col].scatter(s_, u_, s=pointsize,
783
                                  c=adata[:, gene].layers['quantile_scores'],
784
                                  cmap='RdBu_r', alpha=0.7)
785
        else:
786
            for i in range(len(types)):
787
                filt = adata.obs[color_by] == types[i]
788
                axs[row, col].scatter(s_[filt], u_[filt], s=pointsize,
789
                                      c=colors[i], alpha=0.7)
790
        axs[row, col].contour(X_coord, Y_coord, Z_coord, levels=[0],
791
                              colors=('r'), linewidths=linewidth, alpha=0.7)
792
        axs[row, col].scatter([xc], [yc], c='black', s=5, zorder=2)
793
        axs[row, col].scatter([0], [0], c='black', s=5, zorder=2)
794
        axs[row, col].plot([xtop, xbot], [ytop, ybot], color='b',
795
                           linestyle='dashed', linewidth=linewidth, alpha=0.7)
796
        axs[row, col].plot([xtop2, xbot2], [ytop2, ybot2], color='g',
797
                           linestyle='dashed', linewidth=linewidth, alpha=0.7)
798
799
        axs[row, col].set_title(f'{gene} {mse:.3g}')
800
        axs[row, col].set_xlabel('s')
801
        axs[row, col].set_ylabel('u')
802
        common_range = [(np.min([(-mean_s)/std_s, (-mean_u)/std_u])
803
                        - (0.05*np.max(s)/std_s)),
804
                        (np.max([(np.max(s)-mean_s)/std_s,
805
                                 (np.max(u)-mean_u)/std_u])
806
                        + (0.05*np.max(s)/std_s))]
807
        axs[row, col].set_xlim(common_range)
808
        axs[row, col].set_ylim(common_range)
809
        if not axis_on:
810
            axs[row, col].xaxis.set_ticks_position('none')
811
            axs[row, col].yaxis.set_ticks_position('none')
812
            axs[row, col].get_xaxis().set_visible(False)
813
            axs[row, col].get_yaxis().set_visible(False)
814
            axs[row, col].xaxis.set_ticks_position('none')
815
            axs[row, col].yaxis.set_ticks_position('none')
816
            axs[row, col].set_frame_on(False)
817
        count += 1
818
819
    for i in range(col+1, n_cols):
820
        fig.delaxes(axs[row, i])
821
    if title is not None:
822
        fig.suptitle(title, fontsize=15)
823
    else:
824
        fig.suptitle('Ellipse Fit', fontsize=15)
825
    fig.tight_layout(rect=[0, 0.1, 1, 0.98])
826
827
828
def compute_quantile_scores(adata,
                            n_pcs=30,
                            n_neighbors=30
                            ):
    """Fit ellipses to unspliced and spliced phase portraits and compute
        quantile scores.

    This function fits an ellipse to each gene's unspliced-spliced phase
    portrait. The cells are split into four groups (quantiles) by the two
    axes of the ellipse. Each quantile is then given a score: -3 for left, -1
    for top, 1 for right, and 3 for bottom. These gene-specific values are
    smoothed with a row-normalized connectivities matrix. This is similar to
    the RNA velocity gene time assignment.

    In addition, a 2-bit tuple is assigned to each of the four quantiles:
    (1,0) for the quantile scored -3, (1,1) for -1, (0,1) for 1, and (0,0) for
    3. Neighboring quantiles differ in exactly one bit, which mimics the
    distance relationship between quantiles.

    Genes are skipped (all scores stay 0, not flagged in `quantile_genes`)
    when fewer than 10 cells have both `Mu` and `Ms` above zero, when either
    standardization scale is not positive, when the fitted conic is not an
    ellipse with orientation strictly between 0 and pi/2, or when either side
    of the minor axis holds fewer than 10 cells.

    Parameters
    ----------
    adata: :class:`~anndata.AnnData`
        RNA anndata object. Required fields: `Mu` and `Ms`.
    n_pcs: `int` (default: 30)
        Number of principal components to compute connectivities.
    n_neighbors: `int` (default: 30)
        Number of nearest neighbors to compute connectivities.

    Returns
    -------
    quantile_scores: `.layers`
        gene-specific quantile scores
    quantile_scores_1st_bit, quantile_scores_2nd_bit: `.layers`
        2-bit assignment for gene quantiles
    quantile_score_sum: `.obs`
        aggregated quantile scores
    quantile_genes: `.var`
        genes with good quality ellipse fits
    """
    neighbors = Neighbors(adata)
    neighbors.compute_neighbors(n_neighbors=n_neighbors, knn=True, n_pcs=n_pcs)
    conn = neighbors.connectivities
    # Every cell counts as its own neighbor before the rows are normalized.
    conn.setdiag(1)
    conn_norm = conn.multiply(1.0 / conn.sum(1)).tocsr()

    quantile_scores = np.zeros(adata.shape)
    quantile_scores_2bit = np.zeros((adata.shape[0], adata.shape[1], 2))
    quantile_gene = np.full(adata.n_vars, False)
    quality_gene_idx = []
    for idx, gene in enumerate(adata.var_names):
        u = np.array(adata[:, gene].layers['Mu'])
        s = np.array(adata[:, gene].layers['Ms'])
        non_zero = (u > 0) & (s > 0)
        if np.sum(non_zero) < 10:
            continue

        # Standardize with statistics of the expressing cells only.
        mean_u, mean_s = np.mean(u[non_zero]), np.mean(s[non_zero])
        std_u, std_s = np.std(u[non_zero]), np.std(s[non_zero])
        # A zero (or NaN) scale would push inf/NaN into the least-squares
        # solve below; such a gene cannot yield a usable fit.
        if not (std_u > 0 and std_s > 0):
            continue
        u_ = (u - mean_u)/std_u
        s_ = (s - mean_s)/std_s
        X = np.reshape(s_[non_zero], (-1, 1))
        Y = np.reshape(u_[non_zero], (-1, 1))

        # Ax^2 + Bxy + Cy^2 + Dx + Ey + 1 = 0
        # The lstsq call is left exactly as in the ellipse-fit plotting code
        # of this module so both produce the same coefficients.
        design = np.hstack([X**2, X * Y, Y**2, X, Y])
        target = -np.ones_like(X)
        coef = np.linalg.lstsq(design, target)[0].squeeze()
        A, B, C, D, E = coef

        # Negative discriminant: the fitted conic is an ellipse.
        good_fit = B**2 - 4*A*C < 0
        theta = np.arctan(B/(A - C))/2 \
            if A > C \
            else np.pi/2 + np.arctan(B/(A - C))/2
        good_fit = good_fit & (theta < np.pi/2) & (theta > 0)
        if not good_fit:
            continue

        # Center of the conic and slopes of its two perpendicular axes.
        xc = (B*E - 2*C*D)/(4*A*C - B**2)
        yc = (B*D - 2*A*E)/(4*A*C - B**2)
        slope_major = np.tan(theta)
        slope_minor = np.tan(np.pi/2 + theta)

        # Signed vertical offset of every cell from each axis line.
        side_major = (u_ - yc) - (slope_major * (s_ - xc))
        side_minor = (u_ - yc) - (slope_minor * (s_ - xc))

        quant1 = (side_major > 0) & (side_minor < 0)
        quant2 = (side_major > 0) & (side_minor > 0)
        quant3 = (side_major < 0) & (side_minor > 0)
        quant4 = (side_major < 0) & (side_minor < 0)
        if (np.sum(quant1 | quant4) < 10) or (np.sum(quant2 | quant3) < 10):
            continue

        quantile_scores[:, idx:idx+1] = ((-3.) * quant1 + (-1.) * quant2 + 1.
                                         * quant3 + 3. * quant4)
        # Bits per quantile: quant1=(1,0), quant2=(1,1), quant3=(0,1),
        # quant4=(0,0).
        quantile_scores_2bit[:, idx:idx+1, 0] = 1. * (quant1 | quant2)
        quantile_scores_2bit[:, idx:idx+1, 1] = 1. * (quant2 | quant3)
        quality_gene_idx.append(idx)

    # Smooth all scores over the neighborhood graph.
    quantile_scores = conn_norm.dot(quantile_scores)
    quantile_scores_2bit[:, :, 0] = conn_norm.dot(quantile_scores_2bit[:, :, 0])
    quantile_scores_2bit[:, :, 1] = conn_norm.dot(quantile_scores_2bit[:, :, 1])
    adata.layers['quantile_scores'] = quantile_scores
    adata.layers['quantile_scores_1st_bit'] = quantile_scores_2bit[:, :, 0]
    adata.layers['quantile_scores_2nd_bit'] = quantile_scores_2bit[:, :, 1]
    quantile_gene[quality_gene_idx] = True

    # Always compute the percentage: the message below is formatted
    # regardless of the verbosity level, so a conditional assignment would
    # leave the name undefined.
    n_good = np.sum(quantile_gene)
    perc_good = n_good / adata.n_vars * 100 if adata.n_vars else 0.0
    logg.update(f'{n_good}/{adata.n_vars} - {perc_good:.3g}% '
                'genes have good ellipse fits', v=1)

    adata.obs['quantile_score_sum'] = \
        np.sum(adata[:, quantile_gene].layers['quantile_scores'], axis=1)
    adata.var['quantile_genes'] = quantile_gene
954
955
956
def cluster_by_quantile(adata,
                        plot=False,
                        n_clusters=None,
                        affinity='euclidean',
                        linkage='ward'
                        ):
    """Cluster genes based on 2-bit quantile scores.

    This function clusters similar genes based on their 2-bit quantile score
    assignments from the ellipse fit.
    Hierarchical clustering is done with
    `sklearn.cluster.AgglomerativeClustering`. Each gene flagged in
    `.var['quantile_genes']` is described by its first-bit layer values over
    all cells followed by its second-bit layer values over all cells.

    When `plot` is true or `n_clusters` is `None`, the full tree is fitted and
    drawn as a dendrogram. Cluster labels are only written when `n_clusters`
    is given.

    Parameters
    ----------
    adata: :class:`~anndata.AnnData`
        RNA anndata object. Required fields: `quantile_scores_1st_bit` and
        `quantile_scores_2nd_bit` in `.layers`, and `quantile_genes` in
        `.var`.
    plot: `bool` (default: `False`)
        Plot the hierarchical clusters.
    n_clusters: `int` (default: None)
        The number of clusters to keep.
    affinity: `str` (default: `euclidean`)
        Metric used to compute linkage. Passed to
        `sklearn.cluster.AgglomerativeClustering` as `metric` when the
        installed scikit-learn accepts that keyword, otherwise as `affinity`.
    linkage: `str` (default: `ward`)
        Linkage criterion to use. Passed to
        `sklearn.cluster.AgglomerativeClustering`.

    Returns
    -------
    quantile_cluster: `.var`
        cluster assignments of genes based on quantiles (-1 for genes that
        are not flagged in `quantile_genes`); written only when `n_clusters`
        is given

    Raises
    ------
    ValueError
        If the `quantile_scores_1st_bit` layer is missing.
    """
    # Validate the input before paying for the scikit-learn import.
    if 'quantile_scores_1st_bit' not in adata.layers.keys():
        raise ValueError("Quantile scores not found. Please run "
                         "compute_quantile_scores function first.")

    import inspect
    from sklearn.cluster import AgglomerativeClustering

    # scikit-learn renamed `affinity` to `metric` (deprecated in 1.2, removed
    # in 1.4); choose whichever keyword the installed version understands.
    init_params = inspect.signature(AgglomerativeClustering).parameters
    metric_key = 'metric' if 'metric' in init_params else 'affinity'
    metric_kwargs = {metric_key: affinity}

    quantile_gene = adata.var['quantile_genes']
    fitted_genes = adata[:, quantile_gene]
    # One row per gene: first-bit values of all cells, then second-bit values.
    features = np.vstack((fitted_genes.layers['quantile_scores_1st_bit'],
                          fitted_genes.layers['quantile_scores_2nd_bit'])
                         ).transpose()

    if plot or n_clusters is None:
        # distance_threshold=0 with n_clusters=None builds the complete tree
        # and makes `distances_` available for the dendrogram.
        cluster = AgglomerativeClustering(distance_threshold=0,
                                          n_clusters=None,
                                          linkage=linkage,
                                          **metric_kwargs)
        cluster = cluster.fit(features)

        # https://scikit-learn.org/stable/auto_examples/cluster/plot_agglomerative_dendrogram.html
        def plot_dendrogram(model, **kwargs):
            from scipy.cluster.hierarchy import dendrogram
            # Number of original samples under each merge node.
            counts = np.zeros(model.children_.shape[0])
            n_samples = len(model.labels_)
            for i, merge in enumerate(model.children_):
                current_count = 0
                for child_idx in merge:
                    if child_idx < n_samples:
                        # Leaf node: a single original sample.
                        current_count += 1
                    else:
                        current_count += counts[child_idx - n_samples]
                counts[i] = current_count
            linkage_matrix = np.column_stack([model.children_,
                                              model.distances_,
                                              counts]).astype(float)
            dendrogram(linkage_matrix, **kwargs)

        plot_dendrogram(cluster, truncate_mode='level', p=5, no_labels=True)

    if n_clusters is not None:
        n_clusters = int(n_clusters)
        cluster = AgglomerativeClustering(n_clusters=n_clusters,
                                          linkage=linkage,
                                          **metric_kwargs)
        labels = cluster.fit_predict(features)
        # Genes without a good ellipse fit keep the placeholder label -1.
        quantile_cluster = np.full(adata.n_vars, -1)
        quantile_cluster[quantile_gene] = labels
        adata.var['quantile_cluster'] = quantile_cluster