Switch to side-by-side view

--- a
+++ b/src/multivelo/steady_chrom_func.py
@@ -0,0 +1,631 @@
+import os
+import sys
+import warnings
+import numpy as np
+from scipy import sparse
+from scipy.sparse import csr_matrix, issparse
+from scanpy import Neighbors
+import pandas as pd
+from tqdm.auto import tqdm
+from joblib import Parallel, delayed
+
+current_path = os.path.dirname(__file__)
+src_path = os.path.join(current_path, "..")
+sys.path.append(src_path)
+
+from multivelo import mv_logging as logg
+from multivelo import settings
+
+
+class ChromatinVelocity:
+    def __init__(self, c, u, s,
+                 ss, us,
+                 gene=None,
+                 save_plot=False,
+                 plot_dir=None,
+                 fit_args=None,
+                 rna_only=False,
+                 extra_color=None,
+                 r2_adjusted=True,
+                 ):
+
+        self.gene = gene
+
+        # fitting arguments
+        self.rna_only = rna_only
+        self.outlier = np.clip(fit_args['outlier'], 80, 100)
+        self.r2_adjusted = r2_adjusted
+
+        # plot parameters
+        self.save_plot = save_plot
+        self.extra_color = extra_color
+        self.fig_size = fit_args['fig_size']
+        self.point_size = fit_args['point_size']
+        if plot_dir is None:
+            self.plot_path = 'plots_steady_state'
+        else:
+            self.plot_path = plot_dir
+
+        # input
+        self.total_n = len(u)
+        if sparse.issparse(c):
+            c = c.A
+        if sparse.issparse(u):
+            u = u.A
+        if sparse.issparse(s):
+            s = s.A
+        if ss is not None and sparse.issparse(ss):
+            ss = ss.A
+        if us is not None and sparse.issparse(us):
+            us = us.A
+        self.c_all = np.ravel(np.array(c, dtype=np.float64))
+        self.u_all = np.ravel(np.array(u, dtype=np.float64))
+        self.s_all = np.ravel(np.array(s, dtype=np.float64))
+        if ss is not None:
+            self.ss_all = np.ravel(np.array(ss, dtype=np.float64))
+        if us is not None:
+            self.us_all = np.ravel(np.array(us, dtype=np.float64))
+
+        # adjust offset
+        self.offset_c, self.offset_u, self.offset_s = np.min(self.c_all), \
+            np.min(self.u_all), \
+            np.min(self.s_all)
+        self.offset_c = 0 if self.rna_only else self.offset_c
+        self.c_all -= self.offset_c
+        self.u_all -= self.offset_u
+        self.s_all -= self.offset_s
+        # remove zero counts
+        self.non_zero = np.ravel(self.c_all > 0) | np.ravel(self.u_all > 0) | \
+            np.ravel(self.s_all > 0)
+        # remove outliers
+        self.non_outlier = np.ravel(self.c_all <=
+                                    np.percentile(self.c_all, self.outlier))
+        self.non_outlier &= np.ravel(self.u_all <=
+                                     np.percentile(self.u_all, self.outlier))
+        self.non_outlier &= np.ravel(self.s_all <=
+                                     np.percentile(self.s_all, self.outlier))
+        self.c = self.c_all[self.non_zero & self.non_outlier]
+        self.u = self.u_all[self.non_zero & self.non_outlier]
+        self.s = self.s_all[self.non_zero & self.non_outlier]
+        self.ss = (None if ss is None
+                   else self.ss_all[self.non_zero & self.non_outlier])
+        self.us = (None if us is None
+                   else self.us_all[self.non_zero & self.non_outlier])
+        self.low_quality = len(self.u) < 10
+
+        logg.update(f'{len(self.u)} cells passed filter and will be used to '
+                    'fit regressions.', v=2)
+
+        # 4 rate parameters
+        self.alpha_c = 0.1
+        self.alpha = 0.0
+        self.beta = 0.0
+        self.gamma = 0.0
+
+        # other parameters or results
+        self.loss = np.inf
+        self.r2 = 0
+        self.residual = None
+        self.residual2 = None
+        self.steady_state_func = None
+
+        # select cells for regression
+        w_sub = (self.c >= 0.1 * np.max(self.c)) & \
+                (self.u >= 0.1 * np.max(self.u)) & \
+                (self.s >= 0.1 * np.max(self.s))
+        c_sub = self.c[w_sub]
+        if not self.rna_only:
+            w_sub = (self.c >= np.mean(c_sub)+np.std(c_sub)) & \
+                    (self.u >= 0.1 * np.max(self.u)) & \
+                    (self.s >= 0.1 * np.max(self.s))
+        self.w_sub = w_sub
+        if np.sum(self.w_sub) < 10:
+            self.low_quality = True
+
+    def compute_deterministic(self):
+        if self.rna_only:
+            # steady-state slope
+            wu = self.u >= np.percentile(self.u[self.w_sub], 95)
+            ws = self.s >= np.percentile(self.s[self.w_sub], 95)
+            ss_u = self.u[wu | ws]
+            ss_s = self.s[wu | ws]
+        else:
+            # chromatin adjusted steady-state slope
+            u_high = self.u[self.w_sub]
+            s_high = self.s[self.w_sub]
+            wu_high = u_high >= np.percentile(u_high, 95)
+            ws_high = s_high >= np.percentile(s_high, 95)
+            ss_u = u_high[wu_high | ws_high]
+            ss_s = s_high[wu_high | ws_high]
+        gamma = np.dot(ss_u, ss_s) / np.dot(ss_s, ss_s)
+        self.steady_state_func = lambda x: gamma*x
+        residual = self.u_all - self.steady_state_func(self.s_all)
+        self.residual = residual
+        self.loss = np.dot(self.residual, self.residual) / len(self.u_all)
+
+        if self.r2_adjusted:
+            gamma = np.dot(self.u, self.s) / np.dot(self.s, self.s)
+            residual = self.u_all - gamma * self.s_all
+
+        total = self.u_all - np.mean(self.u_all)
+        self.r2 = 1 - np.dot(residual, residual) / np.dot(total, total)
+
+    def compute_stochastic(self):
+        self.compute_deterministic()
+
+        var_ss = 2 * self.ss - self.s
+        cov_us = 2 * self.us + self.u
+        s_all_ = 2 * self.s_all**2 - (2 * self.ss_all - self.s_all)
+        u_all_ = (2 * self.us_all + self.u_all) - 2 * self.u_all*self.s_all
+        gamma2 = np.dot(cov_us, var_ss) / np.dot(var_ss, var_ss)
+        residual2 = cov_us - gamma2 * var_ss
+        std_first = np.std(self.residual)
+        std_second = np.std(residual2)
+
+        # combine first and second moments and recompute gamma
+        if self.rna_only:
+            # steady-state slope
+            wu = self.u >= np.percentile(self.u[self.w_sub], 95)
+            ws = self.s >= np.percentile(self.s[self.w_sub], 95)
+            ss_u = self.u * (wu | ws)
+            ss_s = self.s * (wu | ws)
+            a = np.hstack((ss_s / std_first, var_ss / std_second))
+            b = np.hstack((ss_u / std_first, cov_us / std_second))
+        else:
+            # chromatin adjusted steady-state slope
+            u_high = self.u[self.w_sub]
+            s_high = self.s[self.w_sub]
+            wu_high = u_high >= np.percentile(u_high, 95)
+            ws_high = s_high >= np.percentile(s_high, 95)
+            ss_u = u_high * (wu_high | ws_high)
+            ss_s = s_high * (wu_high | ws_high)
+            a = np.hstack((ss_s / std_first, var_ss[self.w_sub] / std_second))
+            b = np.hstack((ss_u / std_first, cov_us[self.w_sub] / std_second))
+        gamma = np.dot(b, a) / np.dot(a, a)
+        self.steady_state_func = lambda x: gamma*x
+        self.residual = self.u_all - self.steady_state_func(self.s_all)
+        self.residual2 = u_all_ - self.steady_state_func(s_all_)
+        self.loss = np.dot(self.residual, self.residual) / len(self.u_all)
+
+    def get_velocity(self):
+        return self.residual
+
+    def get_variance_velocity(self):
+        return self.residual2
+
+    def get_r2(self):
+        return self.r2
+
+    def get_loss(self):
+        return self.loss
+
+
+def regress_func(c, u, s, ss, us, m, sp, pdir, fa, gene, ro, extra):
+
+    c_90 = np.percentile(c, 90)
+    u_90 = np.percentile(u, 90)
+    s_90 = np.percentile(s, 90)
+    low_quality = ((u_90 == 0 or s_90 == 0) if ro
+                   else (c_90 == 0 or u_90 == 0 or s_90 == 0))
+    if low_quality:
+        logg.update(f'low quality gene {gene}, skipping', v=1)
+        return np.zeros(len(u)), np.zeros(len(u)), 0, np.inf
+
+    cvc = ChromatinVelocity(c,
+                            u,
+                            s,
+                            ss,
+                            us,
+                            save_plot=sp,
+                            plot_dir=pdir,
+                            fit_args=fa,
+                            gene=gene,
+                            rna_only=ro,
+                            extra_color=extra)
+    if cvc.low_quality:
+        return np.zeros(len(u)), np.zeros(len(u)), 0, np.inf
+
+    if m == 'deterministic':
+        cvc.compute_deterministic()
+    elif m == 'stochastic':
+        cvc.compute_stochastic()
+    velocity = cvc.get_velocity()
+    r2 = cvc.get_r2()
+    loss = cvc.get_loss()
+    variance_velocity = (None if m == 'deterministic'
+                         else cvc.get_variance_velocity())
+    return velocity, variance_velocity, r2, loss
+
+
+def velocity_chrom(adata_rna,
+                   adata_atac=None,
+                   gene_list=None,
+                   mode='stochastic',
+                   parallel=True,
+                   n_jobs=None,
+                   save_plot=False,
+                   plot_dir=None,
+                   rna_only=False,
+                   extra_color_key=None,
+                   min_r2=1e-2,
+                   outlier=99.8,
+                   n_pcs=30,
+                   n_neighbors=30,
+                   fig_size=(8, 6),
+                   point_size=7
+                   ):
+
+    """Multi-omic steady-state model.
+
+    This function incorporates chromatin accessibilities into RNA steady-state
+    velocity.
+
+    Parameters
+    ----------
+    adata_rna: :class:`~anndata.AnnData`
+        RNA anndata object. Required fields: `Mu`, `Ms`, and `connectivities`.
+    adata_atac: :class:`~anndata.AnnData` (default: `None`)
+        ATAC anndata object. Required fields: `Mc`.
+    gene_list: `str`,  list of `str` (default: highly variable genes)
+        Genes to use for model fitting.
+    mode: `str` (default: `'stochastic'`)
+        Fitting method.
+        `'stochastic'`: computing steady-state ratio with the first and second
+        moments.
+        `'deterministic'`: computing steady-state ratio with the first moments.
+    parallel: `bool` (default: `True`)
+        Whether to fit genes in a parallel fashion (recommended).
+    n_jobs: `int` (default: available threads)
+        Number of parallel jobs.
+    save_plot: `bool` (default: `False`)
+        Whether to save the fitted gene portrait figures as files. This will
+        take some disk space.
+    plot_dir: `str` (default: `plots` for multiome and `rna_plots` for
+    RNA-only)
+        Directory to save the plots.
+    rna_only: `bool` (default: `False`)
+        Whether to only use RNA for fitting (RNA velocity).
+    extra_color_key: `str` (default: `None`)
+        Extra color key used for plotting. Common choices are `leiden`,
+        `celltype`, etc.
+        The colors for each category must be present in one of anndatas, which
+        can be pre-computed.
+        with `scanpy.pl.scatter` function.
+    min_r2: `float` (default: 1e-2)
+        Minimum R-squared value for selecting velocity genes.
+    outlier: `float` (default: 99.8)
+        The percentile to mark as outlier that will be excluded when fitting
+        the model.
+    n_pcs: `int` (default: 30)
+        Number of principal components to compute distance smoothing neighbors.
+        This can be different from the one used for expression smoothing.
+    n_neighbors: `int` (default: 30)
+        Number of nearest neighbors for distance smoothing.
+        This can be different from the one used for expression smoothing.
+    fig_size: `tuple` (default: (8,6))
+        Size of each figure when saved.
+    point_size: `float` (default: 7)
+        Marker point size for plotting.
+
+    Returns
+    -------
+    fit_r2: `.var`
+        R-squared of regression fit
+    fit_loss: `.var`
+        loss of model fit
+    velo_s: `.layers`
+        velocities in spliced space
+    variance_velo_s: `.layers`
+        variance velocities based on second moments in spliced space
+    velo_s_genes: `.var`
+        velocity genes
+    velo_s_params: `.var`
+        fitting arguments used
+    ATAC: `.layers`
+        KNN smoothed chromatin accessibilities copied from adata_atac
+    """
+
+    fit_args = {}
+    fit_args['min_r2'] = min_r2
+    fit_args['outlier'] = outlier
+    fit_args['n_pcs'] = n_pcs
+    fit_args['n_neighbors'] = n_neighbors
+    fit_args['fig_size'] = list(fig_size)
+    fit_args['point_size'] = point_size
+    if mode == 'dynamical':
+        logg.update('You do not need to run mv.velocity for chromatin '
+                    'dynamical model', v=0)
+        return
+    elif mode == 'stochastic' or mode == 'deterministic':
+        fit_args['mode'] = mode
+    else:
+        raise ValueError('Unknown mode. Must be either stochastic or '
+                         'deterministic')
+
+    all_genes = adata_rna.var_names
+    if adata_atac is None:
+        import anndata as ad
+        rna_only = True
+        adata_atac = ad.AnnData(X=np.ones(adata_rna.shape), obs=adata_rna.obs,
+                                var=adata_rna.var)
+        adata_atac.layers['Mc'] = np.ones(adata_rna.shape)
+    if adata_rna.shape != adata_atac.shape:
+        raise ValueError('Shape of RNA and ATAC adata objects do not match:'
+                         f'{adata_rna.shape} {adata_atac.shape}')
+    if not np.all(adata_rna.obs_names == adata_atac.obs_names):
+        raise ValueError('obs_names of RNA and ATAC adata objects do not '
+                         'match, please check if they are consistent')
+    if not np.all(all_genes == adata_atac.var_names):
+        raise ValueError('var_names of RNA and ATAC adata objects do not '
+                         'match, please check if they are consistent')
+    if extra_color_key is None:
+        extra_color = None
+    elif (isinstance(extra_color_key, str) and extra_color_key in adata_rna.obs
+          and adata_rna.obs[extra_color_key].dtype.name == 'category'):
+        ngroups = len(adata_rna.obs[extra_color_key].cat.categories)
+        extra_color = adata_rna.obs[extra_color_key].cat.rename_categories(
+            adata_rna.uns[extra_color_key+'_colors'][:ngroups]).to_numpy()
+    elif (isinstance(extra_color_key, str)
+          and extra_color_key in adata_atac.obs and
+          adata_rna.obs[extra_color_key].dtype.name == 'category'):
+        ngroups = len(adata_atac.obs[extra_color_key].cat.categories)
+        extra_color = adata_atac.obs[extra_color_key].cat.rename_categories(
+            adata_atac.uns[extra_color_key+'_colors'][:ngroups]).to_numpy()
+    else:
+        raise ValueError('Currently, extra_color_key must be a single string '
+                         'of categories and available in adata obs, and its '
+                         'colors can be found in adata uns')
+    if ('connectivities' not in adata_rna.obsp.keys() or
+            (adata_rna.obsp['connectivities'] > 0).sum(1).min()
+            > (n_neighbors-1)):
+        neighbors = Neighbors(adata_rna)
+        neighbors.compute_neighbors(n_neighbors=n_neighbors,
+                                    knn=True, n_pcs=n_pcs)
+        rna_conn = neighbors.connectivities
+    else:
+        rna_conn = adata_rna.obsp['connectivities'].copy()
+    rna_conn.setdiag(1)
+    rna_conn = rna_conn.multiply(1.0 / rna_conn.sum(1)).tocsr()
+
+    Mss, Mus = None, None
+    if mode == 'stochastic':
+        Mss, Mus = second_order_moments(adata_rna)
+
+    if gene_list is None:
+        if 'highly_variable' in adata_rna.var:
+            gene_list = adata_rna.var_names[
+                adata_rna.var['highly_variable']].values
+        else:
+            gene_list = adata_rna.var_names.values[
+                (~np.isnan(np.asarray(adata_rna.layers['Mu'].sum(0))
+                           .reshape(-1)
+                           if sparse.issparse(adata_rna.layers['Mu'])
+                           else np.sum(adata_rna.layers['Mu'], axis=0)))
+                & (~np.isnan(np.asarray(adata_rna.layers['Ms'].sum(0))
+                             .reshape(-1)
+                             if sparse.issparse(adata_rna.layers['Ms'])
+                             else np.sum(adata_rna.layers['Ms'], axis=0)))
+                & (~np.isnan(np.asarray(adata_atac.layers['Mc'].sum(0))
+                             .reshape(-1)
+                             if sparse.issparse(adata_atac.layers['Mc'])
+                             else np.sum(adata_atac.layers['Mc'], axis=0)))]
+    elif isinstance(gene_list, (list, np.ndarray, pd.Index, pd.Series)):
+        gene_list = np.array([x for x in gene_list if x in all_genes])
+    elif isinstance(gene_list, str):
+        gene_list = np.array([gene_list]) if gene_list in all_genes else []
+    else:
+        raise ValueError('Invalid gene list, must be one of (str, np.ndarray,'
+                         ' pd.Index, pd.Series)')
+    gn = len(gene_list)
+    if gn == 0:
+        raise ValueError('None of the genes specified are in the adata object')
+
+    logg.update(f'{gn} genes will be fitted', v=1)
+
+    velo_s = np.zeros((adata_rna.n_obs, gn))
+    variance_velo_s = np.zeros((adata_rna.n_obs, gn))
+    r2s = np.zeros(gn)
+    losses = np.zeros(gn)
+
+    u_mat = (adata_rna[:, gene_list].layers['Mu'].A
+             if sparse.issparse(adata_rna.layers['Mu'])
+             else adata_rna[:, gene_list].layers['Mu'])
+    s_mat = (adata_rna[:, gene_list].layers['Ms'].A
+             if sparse.issparse(adata_rna.layers['Ms'])
+             else adata_rna[:, gene_list].layers['Ms'])
+    c_mat = (adata_atac[:, gene_list].layers['Mc'].A
+             if sparse.issparse(adata_atac.layers['Mc'])
+             else adata_atac[:, gene_list].layers['Mc'])
+    if parallel:
+        if (n_jobs is None or not isinstance(n_jobs, int) or
+                n_jobs < 0 or n_jobs > os.cpu_count()):
+            n_jobs = os.cpu_count()
+        if n_jobs > gn:
+            n_jobs = gn
+        batches = -(-gn // n_jobs)
+        if n_jobs > 1:
+            logg.update(f'running {n_jobs} jobs in parallel', v=1)
+    else:
+        n_jobs = 1
+        batches = gn
+    if n_jobs == 1:
+        parallel = False
+
+    pbar = tqdm(total=gn)
+    for group in range(batches):
+        gene_indices = range(group * n_jobs, np.min([gn, (group+1) * n_jobs]))
+        if parallel:
+            verb = 51 if settings.VERBOSITY >= 2 else 0
+
+            res = Parallel(n_jobs=n_jobs, backend='loky', verbose=verb)(
+                delayed(regress_func)(
+                    c_mat[:, i],
+                    u_mat[:, i],
+                    s_mat[:, i],
+                    None if mode == 'deterministic' else Mss[:, i],
+                    None if mode == 'deterministic' else Mus[:, i],
+                    mode,
+                    save_plot,
+                    plot_dir,
+                    fit_args,
+                    gene_list[i],
+                    rna_only,
+                    extra_color)
+                for i in gene_indices)
+
+            for i, r in zip(gene_indices, res):
+                velocity, variance_velocity, r2, loss = r
+                r2s[i] = r2
+                losses[i] = loss
+                velo_s[:, i] = smooth_scale(rna_conn, velocity)
+                if mode == 'stochastic':
+                    variance_velo_s[:, i] = smooth_scale(rna_conn,
+                                                         variance_velocity)
+
+        else:
+            i = group
+            gene = gene_list[i]
+            logg.update(f'@@@@@fitting {gene}', v=1)
+            velocity, variance_velocity, r2, loss = \
+                regress_func(c_mat[:, i],
+                             u_mat[:, i],
+                             s_mat[:, i],
+                             None
+                             if mode == 'deterministic' else Mss[:, i],
+                             None if mode == 'deterministic' else Mus[:, i],
+                             mode,
+                             save_plot,
+                             plot_dir,
+                             fit_args,
+                             gene_list[i],
+                             rna_only,
+                             extra_color)
+            r2s[i] = r2
+            losses[i] = loss
+            velo_s[:, i] = smooth_scale(rna_conn, velocity)
+            if mode == 'stochastic':
+                variance_velo_s[:, i] = smooth_scale(rna_conn,
+                                                     variance_velocity)
+        pbar.update(len(gene_indices))
+    pbar.close()
+
+    filt = losses != np.inf
+    if np.sum(filt) == 0:
+        raise ValueError('None of the genes were fitted due to low quality, '
+                         'not returning')
+    adata_copy = adata_rna[:, gene_list[filt]].copy()
+    adata_copy.layers['ATAC'] = c_mat[:, filt]
+    adata_copy.var['fit_loss'] = losses[filt]
+    adata_copy.var['fit_r2'] = r2s[filt]
+    adata_copy.layers['velo_s'] = velo_s[:, filt]
+    if mode == 'stochastic':
+        adata_copy.layers['variance_velo_s'] = variance_velo_s[:, filt]
+    v_genes = adata_copy.var['fit_r2'] >= min_r2
+    adata_copy.var['velo_s_genes'] = v_genes
+    adata_copy.uns['velo_s_params'] = {'mode': mode, 'fit_offset': False,
+                                       'perc': 95}
+    adata_copy.uns['velo_s_params'].update(fit_args)
+    adata_copy.obsp['_RNA_conn'] = rna_conn
+    return adata_copy
+
+
+def smooth_scale(conn, vector):
+    max_to = np.max(vector)
+    min_to = np.min(vector)
+    v = conn.dot(vector.T).T
+    max_from = np.max(v)
+    min_from = np.min(v)
+    res = ((v - min_from) * (max_to - min_to) / (max_from - min_from)) + min_to
+    return res
+
+
+###############################################################################
+# The following functions are taken directly from scVelo preprocessing
+# [Bergen et al., 2020] (https://github.com/theislab/scvelo)
+###############################################################################
+
+def select_connectivities(connectivities, n_neighbors=None):
+    C = connectivities.copy()
+    n_counts = (C > 0).sum(1).A1 if issparse(C) else (C > 0).sum(1)
+    n_neighbors = (
+        n_counts.min() if n_neighbors is None else min(n_counts.min(),
+                                                       n_neighbors)
+    )
+    rows = np.where(n_counts > n_neighbors)[0]
+    cumsum_neighs = np.insert(n_counts.cumsum(), 0, 0)
+    dat = C.data
+
+    for row in rows:
+        n0, n1 = cumsum_neighs[row], cumsum_neighs[row + 1]
+        rm_idx = n0 + dat[n0:n1].argsort()[::-1][n_neighbors:]
+        dat[rm_idx] = 0
+    C.eliminate_zeros()
+    return C
+
+
+def get_neighs(adata, mode="distances"):
+    if hasattr(adata, "obsp") and mode in adata.obsp.keys():
+        return adata.obsp[mode]
+    elif "neighbors" in adata.uns.keys() and mode in adata.uns["neighbors"]:
+        return adata.uns["neighbors"][mode]
+    else:
+        raise ValueError("The selected mode is not valid.")
+
+
+def get_n_neighs(adata):
+    return (adata.uns.get("neighbors", {}).get("params", {})
+            .get("n_neighbors", 0))
+
+
+def get_connectivities(adata, mode="connectivities", n_neighbors=None,
+                       recurse_neighbors=False):
+    if "neighbors" in adata.uns.keys():
+        C = get_neighs(adata, mode)
+        if n_neighbors is not None and n_neighbors < get_n_neighs(adata):
+            if mode == "connectivities":
+                C = select_connectivities(C, n_neighbors)
+            else:
+                C = select_distances(C, n_neighbors)
+        connectivities = C > 0
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore")
+            connectivities.setdiag(1)
+            if recurse_neighbors:
+                connectivities += connectivities.dot(connectivities * 0.5)
+                connectivities.data = np.clip(connectivities.data, 0, 1)
+            connectivities = connectivities.multiply(1.0 /
+                                                     connectivities.sum(1))
+        return connectivities.tocsr().astype(np.float32)
+    else:
+        return None
+
+
+def second_order_moments(adata, adjusted=False):
+    """Computes second order moments for stochastic velocity estimation.
+    Arguments
+    ---------
+    adata: `AnnData`
+        Annotated data matrix.
+    Returns
+    -------
+    Mss: Second order moments for spliced abundances
+    Mus: Second order moments for spliced with unspliced abundances
+    """
+
+    if "neighbors" not in adata.uns:
+        raise ValueError(
+            "You need to run `pp.neighbors` first to compute a neighborhood "
+            "graph."
+        )
+
+    connectivities = get_connectivities(adata)
+    s, u = csr_matrix(adata.layers["spliced"]), \
+        csr_matrix(adata.layers["unspliced"])
+    if s.shape[0] == 1:
+        s, u = s.T, u.T
+    Mss = csr_matrix.dot(connectivities, s.multiply(s)).astype(np.float32).A
+    Mus = csr_matrix.dot(connectivities, s.multiply(u)).astype(np.float32).A
+    if adjusted:
+        Mss = 2 * Mss - adata.layers["Ms"].reshape(Mss.shape)
+        Mus = 2 * Mus - adata.layers["Mu"].reshape(Mus.shape)
+    return Mss, Mus