"""Steady-state (chromatin-adjusted) velocity fitting for multivelo."""
import os
import sys
import warnings
import numpy as np
from scipy import sparse
from scipy.sparse import csr_matrix, issparse
from scanpy import Neighbors
import pandas as pd
from tqdm.auto import tqdm
from joblib import Parallel, delayed

current_path = os.path.dirname(__file__)
src_path = os.path.join(current_path, "..")
sys.path.append(src_path)

from multivelo import mv_logging as logg
from multivelo import settings

# Percentile used to pick the extreme cells that define the steady-state
# slope. The same value is recorded in ``uns['velo_s_params']['perc']``.
_STEADY_STATE_PERC = 95


def _flat_float64(mat):
    """Return ``mat`` as a flat float64 array, densifying sparse input."""
    if sparse.issparse(mat):
        # ``.toarray()`` instead of ``.A``: the attribute shortcut is not
        # available on recent SciPy sparse containers.
        mat = mat.toarray()
    return np.ravel(np.array(mat, dtype=np.float64))


class ChromatinVelocity:
    """Per-gene steady-state regression of unspliced on spliced counts.

    Parameters
    ----------
    c, u, s:
        Chromatin, unspliced and spliced values of one gene across cells.
        Dense or scipy-sparse; they are flattened to 1-D float64.
    ss, us:
        Second-order moments of the gene (or `None`). They are required by
        :meth:`compute_stochastic`.
    gene:
        Gene identifier, stored as-is.
    save_plot, plot_dir, extra_color:
        Plot settings. They are only stored on the instance; no plotting is
        done by the methods of this class.
    fit_args: `dict` (default: `None`)
        May contain `outlier`, `fig_size` and `point_size`. Missing entries
        (or `None`) fall back to 99.8, [8, 6] and 7.
    rna_only: `bool`
        If `True`, chromatin values are neither offset-corrected nor used to
        tighten the cell selection.
    r2_adjusted: `bool`
        If `True`, R-squared is computed from a slope refitted on all
        filtered cells instead of the steady-state slope.

    Attributes set by the fit
    -------------------------
    residual, residual2, loss, r2, steady_state_func, low_quality
    """

    def __init__(self, c, u, s,
                 ss, us,
                 gene=None,
                 save_plot=False,
                 plot_dir=None,
                 fit_args=None,
                 rna_only=False,
                 extra_color=None,
                 r2_adjusted=True,
                 ):
        fit_args = {} if fit_args is None else fit_args

        self.gene = gene

        # fitting arguments
        self.rna_only = rna_only
        self.outlier = np.clip(fit_args.get('outlier', 99.8), 80, 100)
        self.r2_adjusted = r2_adjusted

        # plot parameters
        self.save_plot = save_plot
        self.extra_color = extra_color
        self.fig_size = fit_args.get('fig_size', [8, 6])
        self.point_size = fit_args.get('point_size', 7)
        self.plot_path = ('plots_steady_state' if plot_dir is None
                          else plot_dir)

        # input
        self.c_all = _flat_float64(c)
        self.u_all = _flat_float64(u)
        self.s_all = _flat_float64(s)
        self.ss_all = None if ss is None else _flat_float64(ss)
        self.us_all = None if us is None else _flat_float64(us)
        # Taken after flattening: len() is not defined for sparse input.
        self.total_n = len(self.u_all)

        # adjust offset (chromatin is left untouched in RNA-only mode)
        self.offset_c = 0 if self.rna_only else np.min(self.c_all)
        self.offset_u = np.min(self.u_all)
        self.offset_s = np.min(self.s_all)
        self.c_all -= self.offset_c
        self.u_all -= self.offset_u
        self.s_all -= self.offset_s

        # remove zero counts
        self.non_zero = ((self.c_all > 0) | (self.u_all > 0)
                         | (self.s_all > 0))
        # remove outliers
        self.non_outlier = (
            (self.c_all <= np.percentile(self.c_all, self.outlier))
            & (self.u_all <= np.percentile(self.u_all, self.outlier))
            & (self.s_all <= np.percentile(self.s_all, self.outlier)))
        keep = self.non_zero & self.non_outlier
        self.c = self.c_all[keep]
        self.u = self.u_all[keep]
        self.s = self.s_all[keep]
        self.ss = None if self.ss_all is None else self.ss_all[keep]
        self.us = None if self.us_all is None else self.us_all[keep]
        self.low_quality = len(self.u) < 10

        logg.update(f'{len(self.u)} cells passed filter and will be used to '
                    'fit regressions.', v=2)

        # 4 rate parameters
        self.alpha_c = 0.1
        self.alpha = 0.0
        self.beta = 0.0
        self.gamma = 0.0

        # other parameters or results
        self.loss = np.inf
        self.r2 = 0
        self.residual = None
        self.residual2 = None
        self.steady_state_func = None

        # select cells for regression
        if self.u.size == 0:
            # np.max() raises on empty arrays; nothing can be fitted anyway.
            self.w_sub = np.zeros(0, dtype=bool)
            self.low_quality = True
            return
        rna_high = ((self.u >= 0.1 * np.max(self.u))
                    & (self.s >= 0.1 * np.max(self.s)))
        w_sub = (self.c >= 0.1 * np.max(self.c)) & rna_high
        if not self.rna_only:
            # Tighten the chromatin threshold to mean + std of the
            # pre-selected cells.
            c_sub = self.c[w_sub]
            w_sub = (self.c >= np.mean(c_sub) + np.std(c_sub)) & rna_high
        self.w_sub = w_sub
        if np.sum(self.w_sub) < 10:
            self.low_quality = True

    def _steady_state_candidates(self):
        """Return ``(u, s, mask)`` for the steady-state slope fit.

        In RNA-only mode ``u``/``s`` are all filtered cells; otherwise they
        are restricted to ``w_sub``. ``mask`` flags the cells at or above the
        95th percentile (of the ``w_sub`` cells) in either dimension.
        """
        u_high = self.u[self.w_sub]
        s_high = self.s[self.w_sub]
        u_cut = np.percentile(u_high, _STEADY_STATE_PERC)
        s_cut = np.percentile(s_high, _STEADY_STATE_PERC)
        if self.rna_only:
            # steady-state slope
            return self.u, self.s, (self.u >= u_cut) | (self.s >= s_cut)
        # chromatin adjusted steady-state slope
        return u_high, s_high, (u_high >= u_cut) | (s_high >= s_cut)

    def compute_deterministic(self):
        """Fit the steady-state slope from first moments.

        Sets `steady_state_func`, `residual` (over all cells), `loss` (mean
        squared residual) and `r2`.
        """
        u_base, s_base, extreme = self._steady_state_candidates()
        ss_u = u_base[extreme]
        ss_s = s_base[extreme]
        gamma = np.dot(ss_u, ss_s) / np.dot(ss_s, ss_s)
        self.steady_state_func = lambda x: gamma*x
        residual = self.u_all - self.steady_state_func(self.s_all)
        self.residual = residual
        self.loss = np.dot(self.residual, self.residual) / len(self.u_all)

        if self.r2_adjusted:
            # Use a separate name: rebinding ``gamma`` would silently change
            # the slope used by the ``steady_state_func`` closure above.
            gamma_r2 = np.dot(self.u, self.s) / np.dot(self.s, self.s)
            residual = self.u_all - gamma_r2 * self.s_all

        total = self.u_all - np.mean(self.u_all)
        self.r2 = 1 - np.dot(residual, residual) / np.dot(total, total)

    def compute_stochastic(self):
        """Refit the slope using first and second moments together.

        Runs :meth:`compute_deterministic` first (which sets `r2`), then
        overwrites `steady_state_func`, `residual` and `loss`, and sets
        `residual2`.

        Raises
        ------
        ValueError
            If the instance was created without `ss` or `us`.
        """
        if self.ss is None or self.us is None:
            raise ValueError('Second-order moments (ss and us) are required '
                             'for the stochastic fit')

        self.compute_deterministic()

        var_ss = 2 * self.ss - self.s
        cov_us = 2 * self.us + self.u
        s_all_ = 2 * self.s_all**2 - (2 * self.ss_all - self.s_all)
        u_all_ = (2 * self.us_all + self.u_all) - 2 * self.u_all*self.s_all
        gamma2 = np.dot(cov_us, var_ss) / np.dot(var_ss, var_ss)
        residual2 = cov_us - gamma2 * var_ss
        std_first = np.std(self.residual)
        std_second = np.std(residual2)

        # combine first and second moments and recompute gamma
        u_base, s_base, extreme = self._steady_state_candidates()
        # Non-extreme cells are zeroed (not dropped) in the first-moment part.
        ss_u = u_base * extreme
        ss_s = s_base * extreme
        if not self.rna_only:
            var_ss = var_ss[self.w_sub]
            cov_us = cov_us[self.w_sub]
        a = np.hstack((ss_s / std_first, var_ss / std_second))
        b = np.hstack((ss_u / std_first, cov_us / std_second))
        gamma = np.dot(b, a) / np.dot(a, a)
        self.steady_state_func = lambda x: gamma*x
        self.residual = self.u_all - self.steady_state_func(self.s_all)
        self.residual2 = u_all_ - self.steady_state_func(s_all_)
        self.loss = np.dot(self.residual, self.residual) / len(self.u_all)

    def get_velocity(self):
        """Return the first-moment residual (`None` before fitting)."""
        return self.residual

    def get_variance_velocity(self):
        """Return the second-moment residual (`None` unless stochastic)."""
        return self.residual2

    def get_r2(self):
        """Return the R-squared of the fit."""
        return self.r2

    def get_loss(self):
        """Return the mean squared residual (`inf` before fitting)."""
        return self.loss


def regress_func(c, u, s, ss, us, m, sp, pdir, fa, gene, ro, extra):
    """Fit one gene and return ``(velocity, variance_velocity, r2, loss)``.

    Parameters are positional shorthands: chromatin/unspliced/spliced values
    (`c`, `u`, `s`), second moments (`ss`, `us`), mode (`m`), save_plot
    (`sp`), plot_dir (`pdir`), fit_args (`fa`), gene name, rna_only (`ro`)
    and extra colors (`extra`).

    Genes whose 90th percentile is zero in any used modality, or that
    `ChromatinVelocity` flags as low quality, are skipped: zero vectors are
    returned together with ``r2 = 0`` and ``loss = inf``.
    `variance_velocity` is `None` in deterministic mode.
    """
    percentiles = [np.percentile(u, 90), np.percentile(s, 90)]
    if not ro:
        percentiles.append(np.percentile(c, 90))
    if any(p == 0 for p in percentiles):
        logg.update(f'low quality gene {gene}, skipping', v=1)
        return np.zeros(len(u)), np.zeros(len(u)), 0, np.inf

    cvc = ChromatinVelocity(c,
                            u,
                            s,
                            ss,
                            us,
                            save_plot=sp,
                            plot_dir=pdir,
                            fit_args=fa,
                            gene=gene,
                            rna_only=ro,
                            extra_color=extra)
    if cvc.low_quality:
        return np.zeros(len(u)), np.zeros(len(u)), 0, np.inf

    if m == 'deterministic':
        cvc.compute_deterministic()
    elif m == 'stochastic':
        cvc.compute_stochastic()
    velocity = cvc.get_velocity()
    r2 = cvc.get_r2()
    loss = cvc.get_loss()
    variance_velocity = (None if m == 'deterministic'
                         else cvc.get_variance_velocity())
    return velocity, variance_velocity, r2, loss


def _category_colors(adata, key):
    """Return per-cell colors for categorical ``adata.obs[key]``.

    Returns `None` when `key` is not a string, is not a categorical obs
    column, or ``adata.uns[key + '_colors']`` is missing.
    """
    if not isinstance(key, str) or key not in adata.obs:
        return None
    column = adata.obs[key]
    if column.dtype.name != 'category' or key + '_colors' not in adata.uns:
        return None
    ngroups = len(column.cat.categories)
    return column.cat.rename_categories(
        adata.uns[key + '_colors'][:ngroups]).to_numpy()


def _column_sums(layer):
    """Return the per-gene (column) sums of a dense or sparse layer."""
    if sparse.issparse(layer):
        return np.asarray(layer.sum(0)).reshape(-1)
    return np.sum(layer, axis=0)


def _dense_layer(adata, genes, key):
    """Return layer `key` of ``adata[:, genes]`` as a dense array."""
    layer = adata[:, genes].layers[key]
    return layer.toarray() if sparse.issparse(layer) else layer


def velocity_chrom(adata_rna,
                   adata_atac=None,
                   gene_list=None,
                   mode='stochastic',
                   parallel=True,
                   n_jobs=None,
                   save_plot=False,
                   plot_dir=None,
                   rna_only=False,
                   extra_color_key=None,
                   min_r2=1e-2,
                   outlier=99.8,
                   n_pcs=30,
                   n_neighbors=30,
                   fig_size=(8, 6),
                   point_size=7
                   ):

    """Multi-omic steady-state model.

    This function incorporates chromatin accessibilities into RNA steady-state
    velocity.

    Parameters
    ----------
    adata_rna: :class:`~anndata.AnnData`
        RNA anndata object. Required layers: `Mu` and `Ms`. If
        `.obsp['connectivities']` is missing it is computed here. The
        stochastic mode additionally needs `.uns['neighbors']` and the
        `spliced` and `unspliced` layers.
    adata_atac: :class:`~anndata.AnnData` (default: `None`)
        ATAC anndata object. Required fields: `Mc`. If `None`, an all-ones
        placeholder is used and `rna_only` is forced to `True`.
    gene_list: `str`, list of `str` (default: highly variable genes)
        Genes to use for model fitting.
    mode: `str` (default: `'stochastic'`)
        Fitting method.
        `'stochastic'`: computing steady-state ratio with the first and second
        moments.
        `'deterministic'`: computing steady-state ratio with the first moments.
        `'dynamical'` only logs a message and returns `None`.
    parallel: `bool` (default: `True`)
        Whether to fit genes in a parallel fashion (recommended).
    n_jobs: `int` (default: available threads)
        Number of parallel jobs. Values that are not positive integers or
        exceed the CPU count fall back to the CPU count.
    save_plot: `bool` (default: `False`)
        Whether to save the fitted gene portrait figures as files. This will
        take some disk space.
    plot_dir: `str` (default: `plots_steady_state`)
        Directory to save the plots.
    rna_only: `bool` (default: `False`)
        Whether to only use RNA for fitting (RNA velocity).
    extra_color_key: `str` (default: `None`)
        Extra color key used for plotting. Common choices are `leiden`,
        `celltype`, etc.
        The key must be a categorical `.obs` column of one of the anndata
        objects, with its colors present in `.uns[key + '_colors']` of the
        same object. The colors can be pre-computed with the
        `scanpy.pl.scatter` function.
    min_r2: `float` (default: 1e-2)
        Minimum R-squared value for selecting velocity genes.
    outlier: `float` (default: 99.8)
        The percentile to mark as outlier that will be excluded when fitting
        the model.
    n_pcs: `int` (default: 30)
        Number of principal components to compute distance smoothing neighbors.
        This can be different from the one used for expression smoothing.
    n_neighbors: `int` (default: 30)
        Number of nearest neighbors for distance smoothing.
        This can be different from the one used for expression smoothing.
    fig_size: `tuple` (default: (8,6))
        Size of each figure when saved.
    point_size: `float` (default: 7)
        Marker point size for plotting.

    Returns
    -------
    A new :class:`~anndata.AnnData` restricted to the successfully fitted
    genes, with the following fields added:

    fit_r2: `.var`
        R-squared of regression fit
    fit_loss: `.var`
        loss of model fit
    velo_s: `.layers`
        velocities in spliced space
    variance_velo_s: `.layers`
        variance velocities based on second moments in spliced space
        (stochastic mode only)
    velo_s_genes: `.var`
        velocity genes
    velo_s_params: `.uns`
        fitting arguments used
    ATAC: `.layers`
        KNN smoothed chromatin accessibilities copied from adata_atac
    _RNA_conn: `.obsp`
        row-normalized connectivities used to smooth the velocities

    Raises
    ------
    ValueError
        On an unknown mode, mismatching RNA/ATAC objects, an unusable
        `extra_color_key` or `gene_list`, or when no gene could be fitted.
    """

    fit_args = {}
    fit_args['min_r2'] = min_r2
    fit_args['outlier'] = outlier
    fit_args['n_pcs'] = n_pcs
    fit_args['n_neighbors'] = n_neighbors
    fit_args['fig_size'] = list(fig_size)
    fit_args['point_size'] = point_size
    if mode == 'dynamical':
        logg.update('You do not need to run mv.velocity for chromatin '
                    'dynamical model', v=0)
        return
    elif mode == 'stochastic' or mode == 'deterministic':
        fit_args['mode'] = mode
    else:
        raise ValueError('Unknown mode. Must be either stochastic or '
                         'deterministic')

    all_genes = adata_rna.var_names
    if adata_atac is None:
        import anndata as ad
        rna_only = True
        adata_atac = ad.AnnData(X=np.ones(adata_rna.shape), obs=adata_rna.obs,
                                var=adata_rna.var)
        adata_atac.layers['Mc'] = np.ones(adata_rna.shape)
    if adata_rna.shape != adata_atac.shape:
        raise ValueError('Shape of RNA and ATAC adata objects do not match:'
                         f'{adata_rna.shape} {adata_atac.shape}')
    if not np.all(adata_rna.obs_names == adata_atac.obs_names):
        raise ValueError('obs_names of RNA and ATAC adata objects do not '
                         'match, please check if they are consistent')
    if not np.all(all_genes == adata_atac.var_names):
        raise ValueError('var_names of RNA and ATAC adata objects do not '
                         'match, please check if they are consistent')

    if extra_color_key is None:
        extra_color = None
    else:
        # RNA takes precedence; ATAC is validated against its own obs/uns.
        extra_color = _category_colors(adata_rna, extra_color_key)
        if extra_color is None:
            extra_color = _category_colors(adata_atac, extra_color_key)
        if extra_color is None:
            raise ValueError('Currently, extra_color_key must be a single '
                             'string of categories and available in adata '
                             'obs, and its colors can be found in adata uns')

    if ('connectivities' not in adata_rna.obsp.keys() or
            (adata_rna.obsp['connectivities'] > 0).sum(1).min()
            > (n_neighbors-1)):
        neighbors = Neighbors(adata_rna)
        neighbors.compute_neighbors(n_neighbors=n_neighbors,
                                    knn=True, n_pcs=n_pcs)
        rna_conn = neighbors.connectivities
    else:
        rna_conn = adata_rna.obsp['connectivities'].copy()
    # Include each cell in its own neighborhood, then row-normalize.
    rna_conn.setdiag(1)
    rna_conn = rna_conn.multiply(1.0 / rna_conn.sum(1)).tocsr()

    if gene_list is None:
        if 'highly_variable' in adata_rna.var:
            gene_list = adata_rna.var_names[
                adata_rna.var['highly_variable']].values
        else:
            # Keep genes whose column sums are not NaN in every modality.
            valid = (~np.isnan(_column_sums(adata_rna.layers['Mu']))
                     & ~np.isnan(_column_sums(adata_rna.layers['Ms']))
                     & ~np.isnan(_column_sums(adata_atac.layers['Mc'])))
            gene_list = adata_rna.var_names.values[valid]
    elif isinstance(gene_list, (list, np.ndarray, pd.Index, pd.Series)):
        gene_list = np.array([x for x in gene_list if x in all_genes])
    elif isinstance(gene_list, str):
        gene_list = np.array([gene_list]) if gene_list in all_genes else []
    else:
        raise ValueError('Invalid gene list, must be one of (str, np.ndarray,'
                         ' pd.Index, pd.Series)')
    gn = len(gene_list)
    if gn == 0:
        raise ValueError('None of the genes specified are in the adata object')

    logg.update(f'{gn} genes will be fitted', v=1)

    Mss, Mus = None, None
    if mode == 'stochastic':
        Mss, Mus = second_order_moments(adata_rna)
        # The moments cover every gene of adata_rna; reorder their columns
        # to gene_list so that column i matches u_mat/s_mat/c_mat below.
        gene_idx = adata_rna.var_names.get_indexer(gene_list)
        Mss, Mus = Mss[:, gene_idx], Mus[:, gene_idx]

    velo_s = np.zeros((adata_rna.n_obs, gn))
    variance_velo_s = np.zeros((adata_rna.n_obs, gn))
    r2s = np.zeros(gn)
    losses = np.zeros(gn)

    u_mat = _dense_layer(adata_rna, gene_list, 'Mu')
    s_mat = _dense_layer(adata_rna, gene_list, 'Ms')
    c_mat = _dense_layer(adata_atac, gene_list, 'Mc')

    if parallel:
        n_cpu = os.cpu_count() or 1  # cpu_count() may return None
        if (n_jobs is None or not isinstance(n_jobs, int) or
                n_jobs < 1 or n_jobs > n_cpu):
            n_jobs = n_cpu
        if n_jobs > gn:
            n_jobs = gn
        batches = -(-gn // n_jobs)  # ceiling division
        if n_jobs > 1:
            logg.update(f'running {n_jobs} jobs in parallel', v=1)
    else:
        n_jobs = 1
        batches = gn
    if n_jobs == 1:
        parallel = False

    def _job_args(i):
        # Positional arguments of regress_func for the i-th gene.
        return (c_mat[:, i],
                u_mat[:, i],
                s_mat[:, i],
                None if Mss is None else Mss[:, i],
                None if Mus is None else Mus[:, i],
                mode,
                save_plot,
                plot_dir,
                fit_args,
                gene_list[i],
                rna_only,
                extra_color)

    verb = 51 if settings.VERBOSITY >= 2 else 0
    with tqdm(total=gn) as pbar:
        for group in range(batches):
            start = group * n_jobs
            gene_indices = range(start, min(gn, start + n_jobs))
            if parallel:
                res = Parallel(n_jobs=n_jobs, backend='loky', verbose=verb)(
                    delayed(regress_func)(*_job_args(i))
                    for i in gene_indices)
            else:
                res = []
                for i in gene_indices:
                    logg.update(f'@@@@@fitting {gene_list[i]}', v=1)
                    res.append(regress_func(*_job_args(i)))

            for i, r in zip(gene_indices, res):
                velocity, variance_velocity, r2, loss = r
                r2s[i] = r2
                losses[i] = loss
                velo_s[:, i] = smooth_scale(rna_conn, velocity)
                if mode == 'stochastic':
                    variance_velo_s[:, i] = smooth_scale(rna_conn,
                                                         variance_velocity)
            pbar.update(len(gene_indices))

    # Skipped genes are marked with an infinite loss by regress_func.
    filt = losses != np.inf
    if np.sum(filt) == 0:
        raise ValueError('None of the genes were fitted due to low quality, '
                         'not returning')
    adata_copy = adata_rna[:, gene_list[filt]].copy()
    adata_copy.layers['ATAC'] = c_mat[:, filt]
    adata_copy.var['fit_loss'] = losses[filt]
    adata_copy.var['fit_r2'] = r2s[filt]
    adata_copy.layers['velo_s'] = velo_s[:, filt]
    if mode == 'stochastic':
        adata_copy.layers['variance_velo_s'] = variance_velo_s[:, filt]
    v_genes = adata_copy.var['fit_r2'] >= min_r2
    adata_copy.var['velo_s_genes'] = v_genes
    adata_copy.uns['velo_s_params'] = {'mode': mode, 'fit_offset': False,
                                       'perc': _STEADY_STATE_PERC}
    adata_copy.uns['velo_s_params'].update(fit_args)
    adata_copy.obsp['_RNA_conn'] = rna_conn
    return adata_copy


def smooth_scale(conn, vector):
    """Smooth `vector` with `conn` and rescale to the original range.

    Computes ``conn.dot(vector)`` and maps the result linearly so that its
    minimum and maximum equal those of `vector`. If the smoothed values are
    all equal (e.g. the all-zero vector of a skipped gene), no scaling is
    possible and they are only shifted to ``min(vector)``, which avoids a
    division by zero.
    """
    max_to = np.max(vector)
    min_to = np.min(vector)
    v = conn.dot(vector.T).T
    max_from = np.max(v)
    min_from = np.min(v)
    if max_from == min_from:
        return (v - min_from) + min_to
    res = ((v - min_from) * (max_to - min_to) / (max_from - min_from)) + min_to
    return res


###############################################################################
# The following functions are taken directly from scVelo preprocessing
# [Bergen et al., 2020] (https://github.com/theislab/scvelo)
###############################################################################

def _prune_neighbors(graph, n_neighbors, keep_largest):
    """Return a copy of sparse `graph` with at most `n_neighbors` per row.

    Rows holding more positive entries than the limit have their surplus
    entries zeroed and removed. `keep_largest` selects whether the largest
    (connectivities) or the smallest (distances) values are kept. The limit
    is capped at the smallest per-row neighbor count.
    """
    G = graph.copy()
    # asarray/ravel handles both np.matrix and ndarray row sums.
    n_counts = np.asarray((G > 0).sum(1)).ravel()
    n_neighbors = (
        n_counts.min() if n_neighbors is None else min(n_counts.min(),
                                                       n_neighbors)
    )
    rows = np.where(n_counts > n_neighbors)[0]
    # Offsets of each row inside ``G.data``.
    cumsum_neighs = np.insert(n_counts.cumsum(), 0, 0)
    dat = G.data

    for row in rows:
        n0, n1 = cumsum_neighs[row], cumsum_neighs[row + 1]
        order = dat[n0:n1].argsort()
        if keep_largest:
            order = order[::-1]
        rm_idx = n0 + order[n_neighbors:]
        dat[rm_idx] = 0
    G.eliminate_zeros()
    return G


def select_connectivities(connectivities, n_neighbors=None):
    """Keep the `n_neighbors` strongest connectivities of every cell."""
    return _prune_neighbors(connectivities, n_neighbors, keep_largest=True)


def select_distances(dist, n_neighbors=None):
    """Keep the `n_neighbors` smallest distances of every cell."""
    return _prune_neighbors(dist, n_neighbors, keep_largest=False)


def get_neighs(adata, mode="distances"):
    """Return the neighbor graph `mode` from `.obsp` or `.uns['neighbors']`.

    Raises
    ------
    ValueError
        If `mode` is found in neither location.
    """
    if hasattr(adata, "obsp") and mode in adata.obsp.keys():
        return adata.obsp[mode]
    elif "neighbors" in adata.uns.keys() and mode in adata.uns["neighbors"]:
        return adata.uns["neighbors"][mode]
    else:
        raise ValueError("The selected mode is not valid.")


def get_n_neighs(adata):
    """Return `n_neighbors` from `.uns['neighbors']['params']` (0 if unset)."""
    return (adata.uns.get("neighbors", {}).get("params", {})
            .get("n_neighbors", 0))


def get_connectivities(adata, mode="connectivities", n_neighbors=None,
                       recurse_neighbors=False):
    """Return a row-normalized binary neighbor graph including self-loops.

    Arguments
    ---------
    adata: `AnnData`
        Annotated data matrix.
    mode: `str` (default: `'connectivities'`)
        Which graph to read, see `get_neighs`.
    n_neighbors: `int` (default: `None`)
        If smaller than the stored `n_neighbors`, the graph is pruned to
        this many neighbors per cell first.
    recurse_neighbors: `bool` (default: `False`)
        Whether to also add neighbors of neighbors.

    Returns
    -------
    CSR matrix of dtype float32, or `None` if `adata.uns` has no
    `neighbors` entry.
    """
    if "neighbors" not in adata.uns.keys():
        return None

    C = get_neighs(adata, mode)
    if n_neighbors is not None and n_neighbors < get_n_neighs(adata):
        if mode == "connectivities":
            C = select_connectivities(C, n_neighbors)
        else:
            C = select_distances(C, n_neighbors)
    connectivities = C > 0
    with warnings.catch_warnings():
        # setdiag changes the sparsity structure, which emits a warning.
        warnings.simplefilter("ignore")
        connectivities.setdiag(1)
        if recurse_neighbors:
            connectivities += connectivities.dot(connectivities * 0.5)
            connectivities.data = np.clip(connectivities.data, 0, 1)
        connectivities = connectivities.multiply(1.0 /
                                                 connectivities.sum(1))
    return connectivities.tocsr().astype(np.float32)


def second_order_moments(adata, adjusted=False):
    """Computes second order moments for stochastic velocity estimation.

    Arguments
    ---------
    adata: `AnnData`
        Annotated data matrix. Needs `.uns['neighbors']` and the `spliced`
        and `unspliced` layers (plus `Ms` and `Mu` when `adjusted`).
    adjusted: `bool` (default: `False`)
        If `True`, return ``2 * Mss - Ms`` and ``2 * Mus - Mu`` instead.

    Returns
    -------
    Mss: Second order moments for spliced abundances
    Mus: Second order moments for spliced with unspliced abundances

    Raises
    ------
    ValueError
        If no neighborhood graph has been computed.
    """

    if "neighbors" not in adata.uns:
        raise ValueError(
            "You need to run `pp.neighbors` first to compute a neighborhood "
            "graph."
        )

    connectivities = get_connectivities(adata)
    s, u = csr_matrix(adata.layers["spliced"]), \
        csr_matrix(adata.layers["unspliced"])
    if s.shape[0] == 1:
        s, u = s.T, u.T
    # ``.toarray()`` instead of ``.A``: the attribute shortcut is not
    # available on recent SciPy sparse containers.
    Mss = csr_matrix.dot(connectivities,
                         s.multiply(s)).astype(np.float32).toarray()
    Mus = csr_matrix.dot(connectivities,
                         s.multiply(u)).astype(np.float32).toarray()
    if adjusted:
        Mss = 2 * Mss - adata.layers["Ms"].reshape(Mss.shape)
        Mus = 2 * Mus - adata.layers["Mu"].reshape(Mus.shape)
    return Mss, Mus