--- a
+++ b/src/multivelo/auxiliary.py
@@ -0,0 +1,1036 @@
+import numpy as np
+import matplotlib.pyplot as plt
+from scipy.sparse import coo_matrix, csr_matrix, diags
+from umap.umap_ import fuzzy_simplicial_set
+from anndata import AnnData
+import scanpy as sc
+import scvelo as scv
+import pandas as pd
+import scipy
+import os
+import sys
+from joblib import Parallel, delayed
+from tqdm.auto import tqdm
+from kb_python import main as kbm  # kb-python's CLI entry module; kbm.main() is invoked by do_count() below
+
+current_path = os.path.dirname(__file__)
+src_path = os.path.join(current_path, "..")
+
+# make both the package root and this directory importable
+sys.path.append(src_path)
+sys.path.append(current_path)
+
+from multivelo import mv_logging as logg
+from multivelo import settings
+
+from pyWNN import pyWNN
+
+
+def do_count(fastqs, input_loc, output_loc, whitelist_path=None,
+             tech="10XV3", strand=None, threads=8, memory=4):
+
+    """Get spliced and unspliced counts from fastq data.
+
+    Makes use of the kallisto-bustools count function:
+    https://www.kallistobus.tools/kb_usage/kb_count/
+    kb-python.readthedocs.io/en/latest/autoapi/kb_python/count/index.html
+
+    Parameters
+    ----------
+    fastqs: `List[str]`
+        The file locations of the fastqs to process.
+    input_loc: `str`
+        The folder location of the reference files.
+        The folder should contain an index file with the name "index.idx", a
+        transcripts-to-gene file with the name "t2g.txt", a cDNA
+        transcripts-to-capture file with the name "cdna_t2c.txt", and an
+        intron transcripts-to-capture file with the name "intron_t2c.txt".
+    output_loc: `str`
+        The desired folder location of the output of the function.
+    whitelist_path: `str` (default: `None`)
+        Path to a barcode whitelist to use in place of the selected
+        technology's default whitelist.
+    tech: `str` (default: `10XV3`)
+        The technology used to collect the single-cell data.
+    strand: `str` (default: `None`)
+        The strandedness with which to process the data.
+    threads: `int` (default: 8)
+        The number of threads to use for parallel processing.
+    memory: `int` (default: 4)
+        Maximum memory (in GB) to use while processing.
+
+    Returns
+    -------
+    adata_count: :class:`~anndata.AnnData`
+        An AnnData object containing all the spliced and unspliced counts,
+        as well as associated gene names.
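+
+    Examples
+    --------
+    A minimal sketch of a typical call; the reference folder, output folder,
+    and fastq paths below are hypothetical placeholders.
+
+    >>> adata_count = do_count(
+    ...     fastqs=["run1_R1.fastq.gz", "run1_R2.fastq.gz"],
+    ...     input_loc="kb_ref",
+    ...     output_loc="kb_out",
+    ...     tech="10XV3",
+    ...     threads=8,
+    ...     memory=16)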
+
+    """
+
+    # convert the number of threads and the amount of allocated memory
+    # into correctly-formatted strings for running kb count
+    thread_string = str(threads)
+    memory_string = str(memory) + "G"
+
+    # locations of the reference files kb count needs
+    index_loc = input_loc + "/index.idx"
+    t2g_loc = input_loc + "/t2g.txt"
+
+    cdna_t2c = input_loc + "/cdna_t2c.txt"
+    intron_t2c = input_loc + "/intron_t2c.txt"
+
+    # keep the original argv values so they can be restored after the run
+    orig_argv = sys.argv
+
+    # assemble the argument list for kb count; the first entry stands in for
+    # the program name, the second is the "count" subcommand
+    input_array = ["count",
+                   "count",
+                   "-i", index_loc, "-g", t2g_loc,
+                   "-x", tech,
+                   "-o", output_loc,
+                   "-t", thread_string, "-m", memory_string,
+                   "--workflow", "lamanno",
+                   "-c1", cdna_t2c,
+                   "-c2", intron_t2c,
+                   "--h5ad"]
+
+    # specify the strandedness of the run
+    if strand is not None:
+        input_array.append("--strand")
+        input_array.append(strand)
+
+    # specify the whitelist path of the run
+    if whitelist_path is not None:
+        input_array.append("-w")
+        input_array.append(whitelist_path)
+
+    # add the fastqs we're doing the run on
+    for fastq in fastqs:
+        input_array.append(fastq)
+
+    # use our assembled argument list as the parameters for kb count
+    sys.argv = input_array
+
+    # run kb count
+    kbm.main()
+
+    # set argv back to its original value
+    sys.argv = orig_argv
+
+    # read the AnnData object produced by kb count
+    path = output_loc + "/counts_unfiltered"
+    adata_count = sc.read(path + "/adata.h5ad")
+
+    return adata_count
+
+
+def prepare_gene_mat(var_dict, peaks, gene_mat, adata_atac_X_copy, i):
+
+    # add the counts of every peak assigned to this gene (column i)
+    for peak in peaks:
+        if peak in var_dict:
+            peak_index = var_dict[peak]
+
+            gene_mat[:, i] += adata_atac_X_copy[:, peak_index]
+
+
+def aggregate_peaks_10x(adata_atac, peak_annot_file, linkage_file,
+                        peak_dist=10000, min_corr=0.5, gene_body=False,
+                        return_dict=False, parallel=False, n_jobs=1):
+
+    """Peak to gene aggregation.
+
+    This function aggregates promoter and enhancer peaks to genes based on the
+    10X linkage file.
+
+    Parameters
+    ----------
+    adata_atac: :class:`~anndata.AnnData`
+        ATAC anndata object which stores raw peak counts.
+    peak_annot_file: `str`
+        Peak annotation file from 10X CellRanger ARC.
+    linkage_file: `str`
+        Peak-gene linkage file from 10X CellRanger ARC. This file stores
+        highly correlated peak-peak and peak-gene pair information.
+    peak_dist: `int` (default: 10000)
+        Maximum distance for peaks to be included for a gene.
+    min_corr: `float` (default: 0.5)
+        Minimum correlation for a peak to be considered as an enhancer.
+    gene_body: `bool` (default: `False`)
+        Whether to add gene body peaks to the associated promoters.
+    return_dict: `bool` (default: `False`)
+        Whether to return promoter and enhancer dictionaries.
+    parallel: `bool` (default: `False`)
+        Whether to aggregate peaks for multiple genes in parallel.
+    n_jobs: `int` (default: 1)
+        Number of jobs to use for parallel aggregation.
+
+    Returns
+    -------
+    A new ATAC anndata object which stores gene aggregated peak counts.
+    Additionally, if `return_dict==True`:
+        A dictionary which stores genes and promoter peaks.
+        And a dictionary which stores genes and enhancer peaks.
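+
+    Examples
+    --------
+    A minimal sketch of typical usage; the input file names are hypothetical
+    placeholders for the CellRanger ARC outputs.
+
+    >>> adata_atac = sc.read_10x_mtx('filtered_feature_bc_matrix',
+    ...                              gex_only=False)
+    >>> adata_atac = adata_atac[:, adata_atac.var['feature_types'] == 'Peaks']
+    >>> adata_atac = aggregate_peaks_10x(adata_atac,
+    ...                                  'atac_peak_annotation.tsv',
+    ...                                  'feature_linkage.bedpe',
+    ...                                  parallel=True, n_jobs=4)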
+ """ + promoter_dict = {} + distal_dict = {} + gene_body_dict = {} + corr_dict = {} + + # read annotations + with open(peak_annot_file) as f: + header = next(f) + tmp = header.split('\t') + if len(tmp) == 4: + cellranger_version = 1 + elif len(tmp) == 6: + cellranger_version = 2 + else: + raise ValueError('Peak annotation file should contain 4 columns ' + '(CellRanger ARC 1.0.0) or 6 columns (CellRanger ' + 'ARC 2.0.0)') + + logg.update(f'CellRanger ARC identified as {cellranger_version}.0.0', + v=1) + + if cellranger_version == 1: + for line in f: + tmp = line.rstrip().split('\t') + tmp1 = tmp[0].split('_') + peak = f'{tmp1[0]}:{tmp1[1]}-{tmp1[2]}' + if tmp[1] != '': + genes = tmp[1].split(';') + dists = tmp[2].split(';') + types = tmp[3].split(';') + for i, gene in enumerate(genes): + dist = dists[i] + annot = types[i] + if annot == 'promoter': + if gene not in promoter_dict: + promoter_dict[gene] = [peak] + else: + promoter_dict[gene].append(peak) + elif annot == 'distal': + if dist == '0': + if gene not in gene_body_dict: + gene_body_dict[gene] = [peak] + else: + gene_body_dict[gene].append(peak) + else: + if gene not in distal_dict: + distal_dict[gene] = [peak] + else: + distal_dict[gene].append(peak) + else: + for line in f: + tmp = line.rstrip().split('\t') + peak = f'{tmp[0]}:{tmp[1]}-{tmp[2]}' + gene = tmp[3] + dist = tmp[4] + annot = tmp[5] + if annot == 'promoter': + if gene not in promoter_dict: + promoter_dict[gene] = [peak] + else: + promoter_dict[gene].append(peak) + elif annot == 'distal': + if dist == '0': + if gene not in gene_body_dict: + gene_body_dict[gene] = [peak] + else: + gene_body_dict[gene].append(peak) + else: + if gene not in distal_dict: + distal_dict[gene] = [peak] + else: + distal_dict[gene].append(peak) + + # read linkages + with open(linkage_file) as f: + for line in f: + tmp = line.rstrip().split('\t') + if tmp[12] == "peak-peak": + peak1 = f'{tmp[0]}:{tmp[1]}-{tmp[2]}' + peak2 = f'{tmp[3]}:{tmp[4]}-{tmp[5]}' + tmp2 = tmp[6].split('><')[0][1:].split(';') + tmp3 = tmp[6].split('><')[1][:-1].split(';') + corr = float(tmp[7]) + for t2 in tmp2: + gene1 = t2.split('_') + for t3 in tmp3: + gene2 = t3.split('_') + # one of the peaks is in promoter, peaks belong to the + # same gene or are close in distance + if (((gene1[1] == "promoter") != + (gene2[1] == "promoter")) and + ((gene1[0] == gene2[0]) or + (float(tmp[11]) < peak_dist))): + + if gene1[1] == "promoter": + gene = gene1[0] + else: + gene = gene2[0] + if gene in corr_dict: + # peak 1 is in promoter, peak 2 is not in gene + # body -> peak 2 is added to gene 1 + if (peak2 not in corr_dict[gene] and + gene1[1] == "promoter" and + (gene2[0] not in gene_body_dict or + peak2 not in gene_body_dict[gene2[0]])): + + corr_dict[gene][0].append(peak2) + corr_dict[gene][1].append(corr) + # peak 2 is in promoter, peak 1 is not in gene + # body -> peak 1 is added to gene 2 + if (peak1 not in corr_dict[gene] and + gene2[1] == "promoter" and + (gene1[0] not in gene_body_dict or + peak1 not in gene_body_dict[gene1[0]])): + + corr_dict[gene][0].append(peak1) + corr_dict[gene][1].append(corr) + else: + # peak 1 is in promoter, peak 2 is not in gene + # body -> peak 2 is added to gene 1 + if (gene1[1] == "promoter" and + (gene2[0] not in + gene_body_dict + or peak2 not in + gene_body_dict[gene2[0]])): + + corr_dict[gene] = [[peak2], [corr]] + # peak 2 is in promoter, peak 1 is not in gene + # body -> peak 1 is added to gene 2 + if (gene2[1] == "promoter" and + (gene1[0] not in + gene_body_dict + or peak1 not in + 
gene_body_dict[gene1[0]])): + + corr_dict[gene] = [[peak1], [corr]] + elif tmp[12] == "peak-gene": + peak1 = f'{tmp[0]}:{tmp[1]}-{tmp[2]}' + tmp2 = tmp[6].split('><')[0][1:].split(';') + gene2 = tmp[6].split('><')[1][:-1] + corr = float(tmp[7]) + for t2 in tmp2: + gene1 = t2.split('_') + # peak 1 belongs to gene 2 or are close in distance + # -> peak 1 is added to gene 2 + if ((gene1[0] == gene2) or (float(tmp[11]) < peak_dist)): + gene = gene1[0] + if gene in corr_dict: + if (peak1 not in corr_dict[gene] and + gene1[1] != "promoter" and + (gene1[0] not in gene_body_dict or + peak1 not in gene_body_dict[gene1[0]])): + + corr_dict[gene][0].append(peak1) + corr_dict[gene][1].append(corr) + else: + if (gene1[1] != "promoter" and + (gene1[0] not in gene_body_dict or + peak1 not in gene_body_dict[gene1[0]])): + corr_dict[gene] = [[peak1], [corr]] + elif tmp[12] == "gene-peak": + peak2 = f'{tmp[3]}:{tmp[4]}-{tmp[5]}' + gene1 = tmp[6].split('><')[0][1:] + tmp3 = tmp[6].split('><')[1][:-1].split(';') + corr = float(tmp[7]) + for t3 in tmp3: + gene2 = t3.split('_') + # peak 2 belongs to gene 1 or are close in distance + # -> peak 2 is added to gene 1 + if ((gene1 == gene2[0]) or (float(tmp[11]) < peak_dist)): + gene = gene1 + if gene in corr_dict: + if (peak2 not in corr_dict[gene] and + gene2[1] != "promoter" and + (gene2[0] not in gene_body_dict or + peak2 not in gene_body_dict[gene2[0]])): + + corr_dict[gene][0].append(peak2) + corr_dict[gene][1].append(corr) + else: + if (gene2[1] != "promoter" and + (gene2[0] not in gene_body_dict or + peak2 not in gene_body_dict[gene2[0]])): + + corr_dict[gene] = [[peak2], [corr]] + + gene_dict = promoter_dict + enhancer_dict = {} + promoter_genes = list(promoter_dict.keys()) + logg.update(f'Found {len(promoter_genes)} genes with promoter peaks', 1) + for gene in promoter_genes: + if gene_body: # add gene-body peaks + if gene in gene_body_dict: + for peak in gene_body_dict[gene]: + if peak not in gene_dict[gene]: + gene_dict[gene].append(peak) + enhancer_dict[gene] = [] + if gene in corr_dict: # add enhancer peaks + for j, peak in enumerate(corr_dict[gene][0]): + corr = corr_dict[gene][1][j] + if corr > min_corr: + if peak not in gene_dict[gene]: + gene_dict[gene].append(peak) + enhancer_dict[gene].append(peak) + + # aggregate to genes + adata_atac_X_copy = adata_atac.X.A + gene_mat = np.zeros((adata_atac.shape[0], len(promoter_genes))) + var_names = adata_atac.var_names.to_numpy() + var_dict = {} + + for i, name in enumerate(var_names): + var_dict.update({name: i}) + + # if we only want to run one job at a time, then no parallelization + # is necessary + if n_jobs == 1: + parallel = False + + if parallel: + # if we want to run in parallel, modify the gene_mat variable with + # multiple cores, calling prepare_gene_mat with joblib.Parallel() + Parallel(n_jobs=n_jobs, + require='sharedmem')( + delayed(prepare_gene_mat)(var_dict, + gene_dict[promoter_genes[i]], + gene_mat, + adata_atac_X_copy, + i)for i in tqdm(range( + len(promoter_genes)))) + + else: + # if we aren't running in parallel, just call prepare_gene_mat + # from a for loop + for i, gene in tqdm(enumerate(promoter_genes), + total=len(promoter_genes)): + prepare_gene_mat(var_dict, + gene_dict[promoter_genes[i]], + gene_mat, + adata_atac_X_copy, + i) + + gene_mat[gene_mat < 0] = 0 + gene_mat = AnnData(X=csr_matrix(gene_mat)) + gene_mat.obs_names = pd.Index(list(adata_atac.obs_names)) + gene_mat.var_names = pd.Index(promoter_genes) + gene_mat = gene_mat[:, gene_mat.X.sum(0) > 0] + if return_dict: + 
return gene_mat, promoter_dict, enhancer_dict
+    else:
+        return gene_mat
+
+
+def tfidf_norm(adata_atac, scale_factor=1e4, copy=False):
+    """TF-IDF normalization.
+
+    This function normalizes counts in an AnnData object with TF-IDF.
+
+    Parameters
+    ----------
+    adata_atac: :class:`~anndata.AnnData`
+        ATAC anndata object.
+    scale_factor: `float` (default: 1e4)
+        Value to be multiplied after normalization.
+    copy: `bool` (default: `False`)
+        Whether to return a copy or modify `.X` directly.
+
+    Returns
+    -------
+    If `copy==True`, a new ATAC anndata object which stores normalized counts
+    in `.X`.
+    """
+    npeaks = adata_atac.X.sum(1)
+    npeaks_inv = csr_matrix(1.0/npeaks)
+    tf = adata_atac.X.multiply(npeaks_inv)
+    idf = diags(np.ravel(adata_atac.X.shape[0] / adata_atac.X.sum(0))).log1p()
+    if copy:
+        adata_atac_copy = adata_atac.copy()
+        adata_atac_copy.X = tf.dot(idf) * scale_factor
+        return adata_atac_copy
+    else:
+        adata_atac.X = tf.dot(idf) * scale_factor
+
+
+def gen_wnn(adata_rna, adata_adt, dims, nn, random_state=0):
+    """Computes inputs for KNN smoothing.
+
+    This function calculates the nn_idx and nn_dist matrices needed
+    to run knn_smooth_chrom().
+
+    Parameters
+    ----------
+    adata_rna: :class:`~anndata.AnnData`
+        RNA anndata object.
+    adata_adt: :class:`~anndata.AnnData`
+        Second-modality anndata object (e.g. ATAC or ADT counts). An LSI
+        embedding is computed from its `.X` with truncated SVD.
+    dims: `List[int]`
+        Number of dimensions to use for RNA (index 0) and the second
+        modality (index 1).
+    nn: `int`
+        Top N neighbors to extract for each cell in the connectivities matrix.
+    random_state: `int` (default: 0)
+        Seed for reproducibility. Note that the current implementation uses a
+        fixed seed of 42 for the PCA, SVD, and WNN steps.
+
+    Returns
+    -------
+    nn_idx: `np.ndarray`
+        KNN index matrix of size (cells, k).
+    nn_dist: `np.ndarray`
+        KNN distance matrix of size (cells, k).
+    """
+
+    # make a copy of the original adata objects so as to keep them unchanged
+    rna_copy = adata_rna.copy()
+    adt_copy = adata_adt.copy()
+
+    sc.tl.pca(rna_copy,
+              n_comps=dims[0],
+              random_state=np.random.RandomState(seed=42),
+              use_highly_variable=True)  # run PCA on RNA
+
+    lsi = scipy.sparse.linalg.svds(adt_copy.X, k=dims[1])  # run SVD on ADT
+
+    # get the lsi result (left singular vectors)
+    adt_copy.obsm['X_lsi'] = lsi[0]
+
+    # store the RNA PCA and the second-modality LSI on the same object
+    rna_copy.obsm['X_rna_pca'] = rna_copy.obsm.pop('X_pca')
+    rna_copy.obsm['X_adt_lsi'] = adt_copy.obsm['X_lsi']
+
+    # run WNN
+    WNNobj = pyWNN(rna_copy,
+                   reps=['X_rna_pca', 'X_adt_lsi'],
+                   npcs=dims,
+                   n_neighbors=nn,
+                   seed=42)
+
+    adata_seurat = WNNobj.compute_wnn(rna_copy)
+
+    # get the matrix storing the distances between each cell and its neighbors
+    cx = scipy.sparse.coo_matrix(adata_seurat.obsp["WNN_distance"])
+
+    # the number of cells
+    cells = adata_seurat.obsp['WNN_distance'].shape[0]
+
+    # define the shape of our final results
+    # and make the arrays that will hold the results
+    new_shape = (cells, nn)
+    nn_dist = np.zeros(shape=new_shape)
+    nn_idx = np.zeros(shape=new_shape)
+
+    # new_col defines what column we store data in
+    # our result arrays
+    new_col = 0
+
+    # loop through the entries of the sparse distance matrix
+    for i, j, v in zip(cx.row, cx.col, cx.data):
+
+        # store the distances between neighbor cells
+        nn_dist[i][new_col % nn] = v
+
+        # for each cell's row, store the row numbers of its neighbor cells
+        # (1-indexing instead of 0-indexing is a holdover from R's
+        # multimodal neighbors output)
+        nn_idx[i][new_col % nn] = int(j) + 1
+
+        new_col += 1
+
+    return nn_idx, nn_dist
+
+
+def knn_smooth_chrom(adata_atac, nn_idx=None, nn_dist=None, conn=None,
+                     n_neighbors=None):
+    """KNN smoothing.
+
+    This function smooths (imputes) the count matrix with k nearest neighbors.
+
+    The inputs can be either KNN index and distance matrices or a pre-computed
+    connectivities matrix (for example, from the adata_rna object).
+
+    Parameters
+    ----------
+    adata_atac: :class:`~anndata.AnnData`
+        ATAC anndata object.
+    nn_idx: `np.ndarray` (default: `None`)
+        KNN index matrix of size (cells, k).
+    nn_dist: `np.ndarray` (default: `None`)
+        KNN distance matrix of size (cells, k).
+    conn: `csr_matrix` (default: `None`)
+        Pre-computed connectivities matrix.
+    n_neighbors: `int` (default: `None`)
+        Top N neighbors to extract for each cell in the connectivities matrix.
+
+    Returns
+    -------
+    `.layers['Mc']` stores imputed values.
+    """
+    if nn_idx is not None and nn_dist is not None:
+        if nn_idx.shape[0] != adata_atac.shape[0]:
+            raise ValueError('Number of rows of KNN indices does not equal '
+                             'the number of observations.')
+        if nn_dist.shape[0] != adata_atac.shape[0]:
+            raise ValueError('Number of rows of KNN distances does not equal '
+                             'the number of observations.')
+        X = coo_matrix(([], ([], [])), shape=(nn_idx.shape[0], 1))
+        conn, sigma, rho, dists = fuzzy_simplicial_set(X, nn_idx.shape[1],
+                                                       None, None,
+                                                       knn_indices=nn_idx-1,
+                                                       knn_dists=nn_dist,
+                                                       return_dists=True)
+    elif conn is not None:
+        pass
+    else:
+        raise ValueError('Please input nearest neighbor indices and distances,'
+                         ' or a connectivities matrix of size n x n, with '
+                         'columns being neighbors.'
+                         ' For example, RNA connectivities can usually be '
+                         'found in adata.obsp.')
+
+    conn = conn.tocsr().copy()
+    n_counts = (conn > 0).sum(1).A1
+    if n_neighbors is not None and n_neighbors < n_counts.min():
+        # top_n_sparse is expected to come from elsewhere in the package;
+        # it keeps only the top-N connectivities per cell
+        conn = top_n_sparse(conn, n_neighbors)
+    conn.setdiag(1)
+    conn_norm = conn.multiply(1.0 / conn.sum(1)).tocsr()
+    adata_atac.layers['Mc'] = csr_matrix.dot(conn_norm, adata_atac.X)
+    adata_atac.obsp['connectivities'] = conn
+
+
+def calculate_qc_metrics(adata, **kwargs):
+    """Basic QC metrics.
+
+    This function calculates basic QC metrics with
+    `scanpy.pp.calculate_qc_metrics`.
+    Additionally, total counts and the ratio of unspliced and spliced
+    matrices, as well as the cell cycle scores (with
+    `scvelo.tl.score_genes_cell_cycle`), will be computed.
+
+    Parameters
+    ----------
+    adata: :class:`~anndata.AnnData`
+        RNA anndata object. Required fields: `unspliced` and `spliced`.
+    Additional keyword arguments are passed to
+    `scanpy.pp.calculate_qc_metrics`.
+
+    Returns
+    -------
+    Outputs of `scanpy.pp.calculate_qc_metrics` and
+    `scvelo.tl.score_genes_cell_cycle`.
+    total_unspliced, total_spliced: `.obs`
+        total counts of unspliced and spliced matrices.
+    unspliced_ratio: `.obs`
+        ratio of unspliced counts vs (unspliced + spliced counts).
+    cell_cycle_score: `.obs`
+        cell cycle score difference between G2M_score and S_score.
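+
+    Examples
+    --------
+    A minimal sketch of typical usage; `percent_top` is simply an example of
+    a keyword argument forwarded to `scanpy.pp.calculate_qc_metrics`.
+
+    >>> calculate_qc_metrics(adata_rna, percent_top=None)
+    >>> sc.pl.scatter(adata_rna, x='total_unspliced', y='total_spliced',
+    ...               color='cell_cycle_score')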
+    """
+    # store the scanpy QC metrics in adata unless the caller overrides inplace
+    kwargs.setdefault('inplace', True)
+    sc.pp.calculate_qc_metrics(adata, **kwargs)
+    if 'spliced' not in adata.layers:
+        raise ValueError('Spliced matrix not found in adata.layers')
+    if 'unspliced' not in adata.layers:
+        raise ValueError('Unspliced matrix not found in adata.layers')
+
+    logg.update(adata.layers['spliced'].shape, v=1)
+
+    total_s = np.nansum(adata.layers['spliced'].toarray(), axis=1)
+    total_u = np.nansum(adata.layers['unspliced'].toarray(), axis=1)
+
+    logg.update(total_u.shape, v=1)
+
+    adata.obs['total_unspliced'] = total_u
+    adata.obs['total_spliced'] = total_s
+    adata.obs['unspliced_ratio'] = total_u / (total_s + total_u)
+    scv.tl.score_genes_cell_cycle(adata)
+    adata.obs['cell_cycle_score'] = (adata.obs['G2M_score']
+                                     - adata.obs['S_score'])
+
+
+def ellipse_fit(adata,
+                genes,
+                color_by='quantile',
+                n_cols=8,
+                title=None,
+                figsize=None,
+                axis_on=False,
+                pointsize=2,
+                linewidth=2
+                ):
+    """Fit ellipses to unspliced and spliced phase portraits.
+
+    This function plots the ellipse fits on the unspliced-spliced phase
+    portraits.
+
+    Parameters
+    ----------
+    adata: :class:`~anndata.AnnData`
+        RNA anndata object. Required fields: `Mu` and `Ms`.
+    genes: `str`, list of `str`
+        List of genes to plot.
+    color_by: `str` (default: `quantile`)
+        Color by the four quantiles based on ellipse fit if `quantile`. Other
+        common values are leiden, louvain, celltype, etc.
+        If not `quantile`, the color field must be present in `.uns`, which
+        can be pre-computed with `scanpy.pl.scatter`.
+        For `quantile`, red, orange, green, and blue represent quantile left,
+        top, right, and bottom, respectively.
+        If `quantile_scores`, the `multivelo.compute_quantile_scores` function
+        must have been run.
+    n_cols: `int` (default: 8)
+        Number of columns to plot on each row.
+    title: `str` (default: `None`)
+        Title of the figure. Default is `Ellipse Fit`.
+    figsize: `tuple` (default: `None`)
+        Total figure size.
+    axis_on: `bool` (default: `False`)
+        Whether to show axis labels.
+    pointsize: `float` (default: 2)
+        Point size for scatter plots.
+    linewidth: `float` (default: 2)
+        Line width for ellipse.
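+
+    Examples
+    --------
+    A minimal sketch of typical usage; the gene names are hypothetical
+    placeholders.
+
+    >>> ellipse_fit(adata_rna, ['GeneA', 'GeneB', 'GeneC'],
+    ...             color_by='quantile', n_cols=3, pointsize=3)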
+ """ + by_quantile = color_by == 'quantile' + by_quantile_score = color_by == 'quantile_scores' + if not by_quantile and not by_quantile_score: + types = adata.obs[color_by].cat.categories + colors = adata.uns[f'{color_by}_colors'] + gn = len(genes) + if gn < n_cols: + n_cols = gn + fig, axs = plt.subplots(-(-gn // n_cols), n_cols, squeeze=False, + figsize=(2 * n_cols, 2.4 * (-(-gn // n_cols))) + if figsize is None else figsize) + count = 0 + for gene in genes: + u = np.array(adata[:, gene].layers['Mu']) + s = np.array(adata[:, gene].layers['Ms']) + row = count // n_cols + col = count % n_cols + non_zero = (u > 0) & (s > 0) + if np.sum(non_zero) < 10: + count += 1 + fig.delaxes(axs[row, col]) + continue + + mean_u, mean_s = np.mean(u[non_zero]), np.mean(s[non_zero]) + std_u, std_s = np.std(u[non_zero]), np.std(s[non_zero]) + u_ = (u - mean_u)/std_u + s_ = (s - mean_s)/std_s + X = np.reshape(s_[non_zero], (-1, 1)) + Y = np.reshape(u_[non_zero], (-1, 1)) + + # Ax^2 + Bxy + Cy^2 + Dx + Ey + 1 = 0 + A = np.hstack([X**2, X * Y, Y**2, X, Y]) + b = -np.ones_like(X) + x, res, _, _ = np.linalg.lstsq(A, b) + x = x.squeeze() + A, B, C, D, E = x + good_fit = B**2 - 4*A*C < 0 + theta = np.arctan(B/(A - C))/2 \ + if x[0] > x[2] \ + else np.pi/2 + np.arctan(B/(A - C))/2 + good_fit = good_fit & (theta < np.pi/2) & (theta > 0) + if not good_fit: + count += 1 + fig.delaxes(axs[row, col]) + continue + x_coord = np.linspace((-mean_s)/std_s, (np.max(s)-mean_s)/std_s, 500) + y_coord = np.linspace((-mean_u)/std_u, (np.max(u)-mean_u)/std_u, 500) + X_coord, Y_coord = np.meshgrid(x_coord, y_coord) + Z_coord = (A * X_coord**2 + B * X_coord * Y_coord + C * Y_coord**2 + + D * X_coord + E * Y_coord + 1) + + M0 = np.array([ + A, B/2, D/2, + B/2, C, E/2, + D/2, E/2, 1, + ]).reshape(3, 3) + M = np.array([ + A, B/2, + B/2, C, + ]).reshape(2, 2) + l1, l2 = np.sort(np.linalg.eigvals(M)) + xc = (B*E - 2*C*D)/(4*A*C - B**2) + yc = (B*D - 2*A*E)/(4*A*C - B**2) + slope_major = np.tan(theta) + theta2 = np.pi/2 + theta + slope_minor = np.tan(theta2) + a = np.sqrt(-np.linalg.det(M0)/np.linalg.det(M)/l2) + b = np.sqrt(-np.linalg.det(M0)/np.linalg.det(M)/l1) + xtop = xc + a*np.cos(theta) + ytop = yc + a*np.sin(theta) + xbot = xc - a*np.cos(theta) + ybot = yc - a*np.sin(theta) + xtop2 = xc + b*np.cos(theta2) + ytop2 = yc + b*np.sin(theta2) + xbot2 = xc - b*np.cos(theta2) + ybot2 = yc - b*np.sin(theta2) + mse = res[0] / np.sum(non_zero) + major = lambda x, y: (y - yc) - (slope_major * (x - xc)) + minor = lambda x, y: (y - yc) - (slope_minor * (x - xc)) + quant1 = (major(s_, u_) > 0) & (minor(s_, u_) < 0) + quant2 = (major(s_, u_) > 0) & (minor(s_, u_) > 0) + quant3 = (major(s_, u_) < 0) & (minor(s_, u_) > 0) + quant4 = (major(s_, u_) < 0) & (minor(s_, u_) < 0) + if (np.sum(quant1 | quant4) < 10) or (np.sum(quant2 | quant3) < 10): + count += 1 + continue + + if by_quantile: + axs[row, col].scatter(s_[quant1], u_[quant1], s=pointsize, + c='tab:red', alpha=0.6) + axs[row, col].scatter(s_[quant2], u_[quant2], s=pointsize, + c='tab:orange', alpha=0.6) + axs[row, col].scatter(s_[quant3], u_[quant3], s=pointsize, + c='tab:green', alpha=0.6) + axs[row, col].scatter(s_[quant4], u_[quant4], s=pointsize, + c='tab:blue', alpha=0.6) + elif by_quantile_score: + if 'quantile_scores' not in adata.layers: + raise ValueError('Please run multivelo.compute_quantile_scores' + ' first to compute quantile scores.') + axs[row, col].scatter(s_, u_, s=pointsize, + c=adata[:, gene].layers['quantile_scores'], + cmap='RdBu_r', alpha=0.7) + else: + for i in 
range(len(types)): + filt = adata.obs[color_by] == types[i] + axs[row, col].scatter(s_[filt], u_[filt], s=pointsize, + c=colors[i], alpha=0.7) + axs[row, col].contour(X_coord, Y_coord, Z_coord, levels=[0], + colors=('r'), linewidths=linewidth, alpha=0.7) + axs[row, col].scatter([xc], [yc], c='black', s=5, zorder=2) + axs[row, col].scatter([0], [0], c='black', s=5, zorder=2) + axs[row, col].plot([xtop, xbot], [ytop, ybot], color='b', + linestyle='dashed', linewidth=linewidth, alpha=0.7) + axs[row, col].plot([xtop2, xbot2], [ytop2, ybot2], color='g', + linestyle='dashed', linewidth=linewidth, alpha=0.7) + + axs[row, col].set_title(f'{gene} {mse:.3g}') + axs[row, col].set_xlabel('s') + axs[row, col].set_ylabel('u') + common_range = [(np.min([(-mean_s)/std_s, (-mean_u)/std_u]) + - (0.05*np.max(s)/std_s)), + (np.max([(np.max(s)-mean_s)/std_s, + (np.max(u)-mean_u)/std_u]) + + (0.05*np.max(s)/std_s))] + axs[row, col].set_xlim(common_range) + axs[row, col].set_ylim(common_range) + if not axis_on: + axs[row, col].xaxis.set_ticks_position('none') + axs[row, col].yaxis.set_ticks_position('none') + axs[row, col].get_xaxis().set_visible(False) + axs[row, col].get_yaxis().set_visible(False) + axs[row, col].xaxis.set_ticks_position('none') + axs[row, col].yaxis.set_ticks_position('none') + axs[row, col].set_frame_on(False) + count += 1 + + for i in range(col+1, n_cols): + fig.delaxes(axs[row, i]) + if title is not None: + fig.suptitle(title, fontsize=15) + else: + fig.suptitle('Ellipse Fit', fontsize=15) + fig.tight_layout(rect=[0, 0.1, 1, 0.98]) + + +def compute_quantile_scores(adata, + n_pcs=30, + n_neighbors=30 + ): + """Fit ellipses to unspliced and spliced phase portraits and compute + quantile scores. + + This function fit ellipses to unspliced-spliced phase portraits. The cells + are split into four groups (quantiles) based on the axes of the ellipse. + Then the function assigns each quantile a score: -3 for left, -1 for top, 1 + for right, and 3 for bottom. These gene-specific values are smoothed with a + connectivities matrix. This is similar to the RNA velocity gene time + assignment. + + In addition, a 2-bit tuple is assigned to each of the four quantiles, (0,0) + for left, (1,0) for top, (1,1) for right, and (0,1) for bottom. This is to + mimic the distance relationship between quantiles. + + Parameters + ---------- + adata: :class:`~anndata.AnnData` + RNA anndata object. Required fields: `Mu` and `Ms`. + n_pcs: `int` (default: 30) + Number of principal components to compute connectivities. + n_neighbors: `int` (default: 30) + Number of nearest neighbors to compute connectivities. 
+
+    Returns
+    -------
+    quantile_scores: `.layers`
+        gene-specific quantile scores
+    quantile_scores_1st_bit, quantile_scores_2nd_bit: `.layers`
+        2-bit assignment for gene quantiles
+    quantile_score_sum: `.obs`
+        aggregated quantile scores
+    quantile_genes: `.var`
+        genes with good quality ellipse fits
+    """
+    # build the KNN graph and connectivities with scanpy's Neighbors class
+    neighbors = sc.Neighbors(adata)
+    neighbors.compute_neighbors(n_neighbors=n_neighbors, knn=True, n_pcs=n_pcs)
+    conn = neighbors.connectivities
+    conn.setdiag(1)
+    conn_norm = conn.multiply(1.0 / conn.sum(1)).tocsr()
+
+    quantile_scores = np.zeros(adata.shape)
+    quantile_scores_2bit = np.zeros((adata.shape[0], adata.shape[1], 2))
+    quantile_gene = np.full(adata.n_vars, False)
+    quality_gene_idx = []
+    for idx, gene in enumerate(adata.var_names):
+        u = np.array(adata[:, gene].layers['Mu'])
+        s = np.array(adata[:, gene].layers['Ms'])
+        non_zero = (u > 0) & (s > 0)
+        if np.sum(non_zero) < 10:
+            continue
+
+        mean_u, mean_s = np.mean(u[non_zero]), np.mean(s[non_zero])
+        std_u, std_s = np.std(u[non_zero]), np.std(s[non_zero])
+        u_ = (u - mean_u)/std_u
+        s_ = (s - mean_s)/std_s
+        X = np.reshape(s_[non_zero], (-1, 1))
+        Y = np.reshape(u_[non_zero], (-1, 1))
+
+        # Ax^2 + Bxy + Cy^2 + Dx + Ey + 1 = 0
+        A = np.hstack([X**2, X * Y, Y**2, X, Y])
+        b = -np.ones_like(X)
+        x, res, _, _ = np.linalg.lstsq(A, b)
+        x = x.squeeze()
+        A, B, C, D, E = x
+        good_fit = B**2 - 4*A*C < 0
+        theta = np.arctan(B/(A - C))/2 \
+            if x[0] > x[2] \
+            else np.pi/2 + np.arctan(B/(A - C))/2
+        good_fit = good_fit & (theta < np.pi/2) & (theta > 0)
+        if not good_fit:
+            continue
+
+        x_coord = np.linspace((-mean_s)/std_s, (np.max(s)-mean_s)/std_s, 500)
+        y_coord = np.linspace((-mean_u)/std_u, (np.max(u)-mean_u)/std_u, 500)
+        X_coord, Y_coord = np.meshgrid(x_coord, y_coord)
+        M = np.array([
+            A, B/2,
+            B/2, C,
+        ]).reshape(2, 2)
+        l1, l2 = np.sort(np.linalg.eigvals(M))
+        xc = (B*E - 2*C*D)/(4*A*C - B**2)
+        yc = (B*D - 2*A*E)/(4*A*C - B**2)
+        slope_major = np.tan(theta)
+        theta2 = np.pi/2 + theta
+        slope_minor = np.tan(theta2)
+        major = lambda x, y: (y - yc) - (slope_major * (x - xc))
+        minor = lambda x, y: (y - yc) - (slope_minor * (x - xc))
+
+        quant1 = (major(s_, u_) > 0) & (minor(s_, u_) < 0)
+        quant2 = (major(s_, u_) > 0) & (minor(s_, u_) > 0)
+        quant3 = (major(s_, u_) < 0) & (minor(s_, u_) > 0)
+        quant4 = (major(s_, u_) < 0) & (minor(s_, u_) < 0)
+        if (np.sum(quant1 | quant4) < 10) or (np.sum(quant2 | quant3) < 10):
+            continue
+
+        quantile_scores[:, idx:idx+1] = ((-3.) * quant1 + (-1.) * quant2 + 1.
+                                         * quant3 + 3. * quant4)
+        quantile_scores_2bit[:, idx:idx+1, 0] = 1. * (quant1 | quant2)
+        quantile_scores_2bit[:, idx:idx+1, 1] = 1. * (quant2 | quant3)
+        quality_gene_idx.append(idx)
+
+    quantile_scores = csr_matrix.dot(conn_norm, quantile_scores)
+    quantile_scores_2bit[:, :, 0] = csr_matrix.dot(conn_norm,
+                                                   quantile_scores_2bit[:,
+                                                                        :, 0])
+    quantile_scores_2bit[:, :, 1] = csr_matrix.dot(conn_norm,
+                                                   quantile_scores_2bit[:,
+                                                                        :, 1])
+    adata.layers['quantile_scores'] = quantile_scores
+    adata.layers['quantile_scores_1st_bit'] = quantile_scores_2bit[:, :, 0]
+    adata.layers['quantile_scores_2nd_bit'] = quantile_scores_2bit[:, :, 1]
+    quantile_gene[quality_gene_idx] = True
+
+    if settings.VERBOSITY >= 1:
+        perc_good = np.sum(quantile_gene) / adata.n_vars * 100
+
+        logg.update(f'{np.sum(quantile_gene)}/{adata.n_vars} - '
+                    f'{perc_good:.3g}% genes have good ellipse fits', v=1)
+
+    adata.obs['quantile_score_sum'] = \
+        np.sum(adata[:, quantile_gene].layers['quantile_scores'], axis=1)
+    adata.var['quantile_genes'] = quantile_gene
+
+
+def cluster_by_quantile(adata,
+                        plot=False,
+                        n_clusters=None,
+                        affinity='euclidean',
+                        linkage='ward'
+                        ):
+    """Cluster genes based on 2-bit quantile scores.
+
+    This function clusters similar genes based on their 2-bit quantile score
+    assignments from the ellipse fits.
+    Hierarchical clustering is done with
+    `sklearn.cluster.AgglomerativeClustering`.
+
+    Parameters
+    ----------
+    adata: :class:`~anndata.AnnData`
+        RNA anndata object. Required fields: `Mu` and `Ms`.
+    plot: `bool` (default: `False`)
+        Plot the hierarchical clusters.
+    n_clusters: `int` (default: None)
+        The number of clusters to keep.
+    affinity: `str` (default: `euclidean`)
+        Metric used to compute linkage. Passed to
+        `sklearn.cluster.AgglomerativeClustering`.
+    linkage: `str` (default: `ward`)
+        Linkage criterion to use. Passed to
+        `sklearn.cluster.AgglomerativeClustering`.
+
+    Returns
+    -------
+    quantile_cluster: `.var`
+        cluster assignments of genes based on quantiles
+    """
+    from sklearn.cluster import AgglomerativeClustering
+    if 'quantile_scores_1st_bit' not in adata.layers.keys():
+        raise ValueError("Quantile scores not found. Please run the "
+                         "compute_quantile_scores function first.")
+    quantile_gene = adata.var['quantile_genes']
+    if plot or n_clusters is None:
+        cluster = AgglomerativeClustering(distance_threshold=0,
+                                          n_clusters=None,
+                                          affinity=affinity,
+                                          linkage=linkage)
+        cluster = cluster.fit(np.vstack((adata[:, quantile_gene]
+                                         .layers['quantile_scores_1st_bit'],
+                                         adata[:, quantile_gene]
+                                         .layers['quantile_scores_2nd_bit']))
+                              .transpose())
+
+        # https://scikit-learn.org/stable/auto_examples/cluster/plot_agglomerative_dendrogram.html
+        def plot_dendrogram(model, **kwargs):
+            from scipy.cluster.hierarchy import dendrogram
+            # count the number of samples under each node of the tree
+            counts = np.zeros(model.children_.shape[0])
+            n_samples = len(model.labels_)
+            for i, merge in enumerate(model.children_):
+                current_count = 0
+                for child_idx in merge:
+                    if child_idx < n_samples:
+                        current_count += 1
+                    else:
+                        current_count += counts[child_idx - n_samples]
+                counts[i] = current_count
+            linkage_matrix = np.column_stack([model.children_,
+                                              model.distances_,
+                                              counts]).astype(float)
+            dendrogram(linkage_matrix, **kwargs)
+
+        plot_dendrogram(cluster, truncate_mode='level', p=5, no_labels=True)
+
+    if n_clusters is not None:
+        n_clusters = int(n_clusters)
+        cluster = AgglomerativeClustering(n_clusters=n_clusters,
+                                          affinity=affinity,
+                                          linkage=linkage)
+        cluster = cluster.fit_predict(np.vstack((adata[:, quantile_gene].layers
+                                                 ['quantile_scores_1st_bit'],
+                                                 adata[:, quantile_gene].layers
+                                                 ['quantile_scores_2nd_bit']))
+                                      .transpose())
+        quantile_cluster = np.full(adata.n_vars, -1)
+        quantile_cluster[quantile_gene] = cluster
+        adata.var['quantile_cluster'] = quantile_cluster
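+
+
+# Example (illustrative only): a typical quantile-scoring workflow with the
+# two functions above, assuming `adata_rna` already has `Mu`/`Ms` layers.
+#
+#     compute_quantile_scores(adata_rna, n_pcs=30, n_neighbors=30)
+#     cluster_by_quantile(adata_rna, n_clusters=4)
+#     adata_rna.var['quantile_cluster'].value_counts()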