from typing import List, Literal, Optional, Tuple
import pytest
import numpy as np
import pandas as pd
from scipy.sparse.linalg import LinearOperator
from anndata import AnnData
from tests._utils import ATOL, RTOL, CompoundProblemWithMixin, MockSolverOutput
class TestBaseAnalysisMixin:
    """Tests for the analysis mixin of compound problems.

    Covers sampling from transport maps, transport interpolation, cell-level
    transition matrices, feature correlation, and entropy computation. Solutions
    are injected via :class:`MockSolverOutput` so no solver runs are needed.
    """

    @pytest.mark.parametrize("n_samples", [10, 42])
    @pytest.mark.parametrize("account_for_unbalancedness", [True, False])
    @pytest.mark.parametrize("interpolation_parameter", [None, 0.1, 5])
    def test_sample_from_tmap_pipeline(
        self,
        gt_temporal_adata: AnnData,
        n_samples: int,
        account_for_unbalancedness: bool,
        interpolation_parameter: Optional[float],
    ):
        """``_sample_from_tmap`` validates its arguments and draws ``n_samples`` cells."""
        source_dim = len(gt_temporal_adata[gt_temporal_adata.obs["day"] == 10])
        target_dim = len(gt_temporal_adata[gt_temporal_adata.obs["day"] == 10.5])
        problem = CompoundProblemWithMixin(gt_temporal_adata)
        problem = problem.prepare(key="day", subset=[(10, 10.5)], policy="sequential", xy_callback="local-pca")
        problem[10, 10.5]._solution = MockSolverOutput(gt_temporal_adata.uns["tmap_10_105"])
        # keyword arguments shared by all three branches below
        kwargs = {
            "source_dim": source_dim,
            "target_dim": target_dim,
            "account_for_unbalancedness": account_for_unbalancedness,
            "interpolation_parameter": interpolation_parameter,
        }
        if interpolation_parameter is not None and not 0 <= interpolation_parameter <= 1:
            # an interpolation parameter outside [0, 1] must be rejected
            with pytest.raises(ValueError, match=r"^Expected interpolation"):
                problem._sample_from_tmap(10, 10.5, n_samples, **kwargs)
        elif interpolation_parameter is None and account_for_unbalancedness:
            # accounting for unbalancedness requires an interpolation parameter
            with pytest.raises(ValueError, match=r"^When accounting for unbalancedness"):
                problem._sample_from_tmap(10, 10.5, n_samples, **kwargs)
        else:
            result = problem._sample_from_tmap(10, 10.5, n_samples, **kwargs)
            assert isinstance(result, tuple)
            assert isinstance(result[0], np.ndarray)
            assert isinstance(result[1], list)
            assert isinstance(result[1][0], np.ndarray)
            # exactly `n_samples` target cells are drawn in total
            assert len(np.concatenate(result[1])) == n_samples

    @pytest.mark.parametrize("forward", [True, False])
    @pytest.mark.parametrize("scale_by_marginals", [True, False])
    def test_interpolate_transport(self, gt_temporal_adata: AnnData, forward: bool, scale_by_marginals: bool):
        """``_interpolate_transport`` chains explicit steps into a linear operator."""
        problem = CompoundProblemWithMixin(gt_temporal_adata)
        problem = problem.prepare(
            key="day", subset=[(10, 10.5), (10.5, 11), (10, 11)], policy="explicit", xy_callback="local-pca"
        )
        problem[(10.0, 10.5)]._solution = MockSolverOutput(gt_temporal_adata.uns["tmap_10_105"])
        problem[(10.5, 11.0)]._solution = MockSolverOutput(gt_temporal_adata.uns["tmap_105_11"])
        problem[(10.0, 11.0)]._solution = MockSolverOutput(gt_temporal_adata.uns["tmap_10_11"])
        # BUGFIX: forward the parametrized `scale_by_marginals` instead of a
        # hard-coded `True`, so both parametrizations are actually exercised.
        tmap = problem._interpolate_transport(
            [(10, 11)], scale_by_marginals=scale_by_marginals, explicit_steps=[(10.0, 11.0)]
        )
        assert isinstance(tmap, LinearOperator)
        # NOTE(review): `forward` is parametrized but never used — confirm whether
        # `_interpolate_transport` should receive it.

    # TODO(@MUCDK) add regression test after discussing with @giovp what this function should be
    # doing / it is more generic
    def test_cell_transition_aggregation_cell_forward(self, gt_temporal_adata: AnnData):
        """Forward cell-level transitions match a manual row-normalized aggregation.

        The method under test performs the same computation but has to
        instantiate the transport matrix.
        """
        config = gt_temporal_adata.uns
        key_1 = config["key_1"]
        key_2 = config["key_2"]
        problem = CompoundProblemWithMixin(gt_temporal_adata)
        problem = problem.prepare(key="day", subset=[(10, 10.5)], policy="explicit", xy_callback="local-pca")
        assert set(problem.problems.keys()) == {(key_1, key_2)}
        problem[key_1, key_2]._solution = MockSolverOutput(gt_temporal_adata.uns["tmap_10_105"])
        ctr = problem._cell_transition(
            key="day",
            source=10,
            target=10.5,
            source_groups=None,
            target_groups="cell_type",
            forward=True,
            aggregation_mode="cell",
        )
        adata_early = gt_temporal_adata[gt_temporal_adata.obs["day"] == 10]
        adata_late = gt_temporal_adata[gt_temporal_adata.obs["day"] == 10.5]
        transition_matrix_indexed = pd.DataFrame(
            index=adata_early.obs.index, columns=adata_late.obs.index, data=gt_temporal_adata.uns["tmap_10_105"]
        )
        # manually aggregate target cells by cell type, then row-normalize
        unique_cell_types_late = adata_late.obs["cell_type"].cat.categories
        df_res = pd.DataFrame(index=adata_early.obs.index)
        for ct in unique_cell_types_late:
            cols_cell_type = adata_late[adata_late.obs["cell_type"] == ct].obs.index
            df_res[ct] = transition_matrix_indexed[cols_cell_type].sum(axis=1)
        df_res = df_res.div(df_res.sum(axis=1), axis=0)
        # compare on a common row/column ordering
        ctr_ordered = ctr.sort_index().sort_index(axis=1)
        df_res_ordered = df_res.sort_index().sort_index(axis=1)
        np.testing.assert_allclose(
            ctr_ordered.values.astype(float), df_res_ordered.values.astype(float), rtol=RTOL, atol=ATOL
        )

    def test_cell_transition_aggregation_cell_backward(self, gt_temporal_adata: AnnData):
        """Backward cell-level transitions match a manual column-normalized aggregation.

        The method under test performs the same computation but has to
        instantiate the transport matrix.
        """
        config = gt_temporal_adata.uns
        key_1 = config["key_1"]
        key_2 = config["key_2"]
        problem = CompoundProblemWithMixin(gt_temporal_adata)
        problem = problem.prepare(key="day", subset=[(10, 10.5)], policy="explicit", xy_callback="local-pca")
        assert set(problem.problems.keys()) == {(key_1, key_2)}
        problem[key_1, key_2]._solution = MockSolverOutput(gt_temporal_adata.uns["tmap_10_105"])
        ctr = problem._cell_transition(
            key="day",
            source=10,
            target=10.5,
            source_groups="cell_type",
            target_groups=None,
            forward=False,
            aggregation_mode="cell",
        )
        adata_early = gt_temporal_adata[gt_temporal_adata.obs["day"] == 10]
        adata_late = gt_temporal_adata[gt_temporal_adata.obs["day"] == 10.5]
        transition_matrix_indexed = pd.DataFrame(
            index=adata_early.obs.index, columns=adata_late.obs.index, data=gt_temporal_adata.uns["tmap_10_105"]
        )
        # manually aggregate source cells by cell type, then column-normalize
        unique_cell_types_early = adata_early.obs["cell_type"].cat.categories
        df_res = pd.DataFrame(columns=adata_late.obs.index)
        for ct in unique_cell_types_early:
            rows_cell_type = adata_early[adata_early.obs["cell_type"] == ct].obs.index
            df_res.loc[ct] = transition_matrix_indexed.loc[rows_cell_type].sum(axis=0)
        df_res = df_res.div(df_res.sum(axis=0), axis=1)
        # compare on a common row/column ordering
        ctr_ordered = ctr.sort_index().sort_index(axis=1)
        df_res_ordered = df_res.sort_index().sort_index(axis=1)
        np.testing.assert_allclose(
            ctr_ordered.values.astype(float), df_res_ordered.values.astype(float), rtol=RTOL, atol=ATOL
        )

    @pytest.mark.parametrize("corr_method", ["pearson", "spearman"])
    @pytest.mark.parametrize("significance_method", ["fisher", "perm_test"])
    def test_compute_feature_correlation(
        self,
        adata_time: AnnData,
        corr_method: Literal["pearson", "spearman"],
        significance_method: Literal["fisher", "perm_test"],
    ):
        """Feature correlations are well-formed: no NaNs, corr in [-1, 1], q-values in [0, 1]."""
        key_added = "test"
        rng = np.random.RandomState(42)
        adata_time = adata_time[adata_time.obs["time"].isin((0, 1))].copy()
        n0 = adata_time[adata_time.obs["time"] == 0].n_obs
        n1 = adata_time[adata_time.obs["time"] == 1].n_obs
        # random doubly-positive transport map, normalized to sum to 1
        tmap = rng.uniform(1e-6, 1, size=(n0, n1))
        tmap /= tmap.sum().sum()
        problem = CompoundProblemWithMixin(adata_time)
        problem = problem.prepare(key="time", xy_callback="local-pca", policy="sequential")
        problem[0, 1]._solution = MockSolverOutput(tmap)
        adata_time.obs[key_added] = np.hstack((np.zeros(n0), problem.pull(source=0, target=1).squeeze()))
        res = problem.compute_feature_correlation(
            obs_key=key_added, corr_method=corr_method, significance_method=significance_method
        )
        assert isinstance(res, pd.DataFrame)
        assert res.isnull().values.sum() == 0
        assert np.all(res[f"{key_added}_corr"] >= -1.0)
        assert np.all(res[f"{key_added}_corr"] <= 1.0)
        assert np.all(res[f"{key_added}_qval"] >= 0)
        assert np.all(res[f"{key_added}_qval"] <= 1.0)

    @pytest.mark.parametrize("corr_method", ["pearson", "spearman"])
    @pytest.mark.parametrize("features", [10, None])
    @pytest.mark.parametrize("method", ["fisher", "perm_test"])
    def test_compute_feature_correlation_subset(
        self,
        adata_time: AnnData,
        features: Optional[int],
        corr_method: Literal["pearson", "spearman"],
        method: Literal["fisher", "perm_test"],
    ):
        """Passing an explicit feature list restricts the result to exactly those features."""
        key_added = "test"
        rng = np.random.RandomState(42)
        adata_time = adata_time[adata_time.obs["time"].isin((0, 1))].copy()
        n0 = adata_time[adata_time.obs["time"] == 0].n_obs
        n1 = adata_time[adata_time.obs["time"] == 1].n_obs
        tmap = rng.uniform(1e-6, 1, size=(n0, n1))
        tmap /= tmap.sum().sum()
        problem = CompoundProblemWithMixin(adata_time)
        problem = problem.prepare(key="time", xy_callback="local-pca", policy="sequential")
        problem[0, 1]._solution = MockSolverOutput(tmap)
        adata_time.obs[key_added] = np.hstack((np.zeros(n0), problem.pull(source=0, target=1).squeeze()))
        if isinstance(features, int):
            # subset case: pass the first `features` var names explicitly
            features = list(adata_time.var_names)[:features]
            features_validation = features
        else:
            # `None` means all features
            features_validation = list(adata_time.var_names)
        res = problem.compute_feature_correlation(
            obs_key=key_added,
            annotation={"celltype": ["A"]},
            corr_method=corr_method,
            significance_method=method,
            features=features,
        )
        assert isinstance(res, pd.DataFrame)
        assert res.isnull().values.sum() == 0
        assert set(res.index) == set(features_validation)

    @pytest.mark.parametrize(
        "features",
        [
            ("human", ["KLF12", "ZNF143"]),
            ("mouse", ["Zic5"]),
            ("drosophila", ["Cf2", "Dlip3", "Dref"]),
            ("error", [None]),
        ],
    )
    def test_compute_feature_correlation_transcription_factors(
        self,
        adata_time: AnnData,
        features: Tuple[str, List[str]],
    ):
        """Named transcription-factor sets select the expected features; unknown names raise."""
        key_added = "test"
        rng = np.random.RandomState(42)
        adata_time = adata_time[adata_time.obs["time"].isin((0, 1))].copy()
        n0 = adata_time[adata_time.obs["time"] == 0].n_obs
        n1 = adata_time[adata_time.obs["time"] == 1].n_obs
        tmap = rng.uniform(1e-6, 1, size=(n0, n1))
        tmap /= tmap.sum().sum()
        problem = CompoundProblemWithMixin(adata_time)
        problem = problem.prepare(key="time", xy_callback="local-pca", policy="sequential")
        problem[0, 1]._solution = MockSolverOutput(tmap)
        adata_time.obs[key_added] = np.hstack((np.zeros(n0), problem.pull(source=0, target=1).squeeze()))
        if features[0] == "error":
            # use pytest.raises for consistency with the rest of the suite
            with pytest.raises(NotImplementedError):
                problem.compute_feature_correlation(
                    obs_key=key_added, annotation={"celltype": ["A"]}, features=features[0]
                )
        else:
            res = problem.compute_feature_correlation(
                obs_key=key_added, annotation={"celltype": ["A"]}, features=features[0]
            )
            assert res.isnull().values.sum() == 0
            assert isinstance(res, pd.DataFrame)
            assert set(res.index) == set(features[1])

    @pytest.mark.parametrize("forward", [True, False])
    @pytest.mark.parametrize("key_added", [None, "test"])
    @pytest.mark.parametrize("batch_size", [None, 2])
    @pytest.mark.parametrize("c", [0.0, 0.1])
    def test_compute_entropy_pipeline(
        self, adata_time: AnnData, forward: bool, key_added: Optional[str], batch_size: Optional[int], c: float
    ):
        """Entropy is either returned as a DataFrame or written to ``adata.obs[key_added]``."""
        rng = np.random.RandomState(42)
        adata_time = adata_time[adata_time.obs["time"].isin((0, 1))].copy()
        n0 = adata_time[adata_time.obs["time"] == 0].n_obs
        n1 = adata_time[adata_time.obs["time"] == 1].n_obs
        tmap = rng.uniform(1e-6, 1, size=(n0, n1))
        tmap /= tmap.sum().sum()
        problem = CompoundProblemWithMixin(adata_time)
        problem = problem.prepare(key="time", xy_callback="local-pca", policy="sequential")
        problem[0, 1]._solution = MockSolverOutput(tmap)
        out = problem.compute_entropy(
            source=0, target=1, forward=forward, key_added=key_added, batch_size=batch_size, c=c
        )
        if key_added is None:
            assert isinstance(out, pd.DataFrame)
            assert len(out) == n0
        else:
            assert out is None
            assert key_added in adata_time.obs
            # the distribution the entropy was computed for has no NaNs ...
            assert np.sum(adata_time[adata_time.obs["time"] == int(1 - forward)].obs[key_added].isna()) == 0
            # ... while the opposite side is entirely NaN.
            # BUGFIX: the comparison must be parenthesized — previously this parsed as
            # `(np.sum(...) == n1) if forward else n0`, so the backward case asserted
            # the truthy integer `n0` and always passed.
            assert np.sum(adata_time[adata_time.obs["time"] == int(forward)].obs[key_added].isna()) == (
                n1 if forward else n0
            )

    @pytest.mark.parametrize("forward", [True, False])
    @pytest.mark.parametrize("batch_size", [None, 2, 15])
    def test_compute_entropy_regression(self, adata_time: AnnData, forward: bool, batch_size: Optional[int]):
        """``compute_entropy`` agrees with ``scipy.stats.entropy`` on the dense transport map."""
        from scipy import stats

        rng = np.random.RandomState(42)
        adata_time = adata_time[adata_time.obs["time"].isin((0, 1))].copy()
        n0 = adata_time[adata_time.obs["time"] == 0].n_obs
        n1 = adata_time[adata_time.obs["time"] == 1].n_obs
        tmap = rng.uniform(1e-6, 1, size=(n0, n1))
        tmap /= tmap.sum().sum()
        problem = CompoundProblemWithMixin(adata_time)
        problem = problem.prepare(key="time", xy_callback="local-pca", policy="sequential")
        problem[0, 1]._solution = MockSolverOutput(tmap)
        moscot_out = problem.compute_entropy(source=0, target=1, forward=forward, batch_size=batch_size, key_added=None)
        # 1e-10 matches the epsilon used internally to avoid log(0)
        gt_out = stats.entropy(tmap + 1e-10, axis=1 if forward else 0)
        gt_out = np.expand_dims(gt_out, axis=1) if forward else np.expand_dims(gt_out, axis=0).T
        np.testing.assert_allclose(
            np.array(moscot_out, dtype=float), np.array(gt_out, dtype=float), rtol=RTOL, atol=ATOL
        )

    def test_seed_reproducible(self, adata_time: AnnData):
        """Equal seeds give identical permutation-test results; different seeds differ."""
        key_added = "test"
        rng = np.random.RandomState(42)
        adata_time = adata_time[adata_time.obs["time"].isin((0, 1))].copy()
        n0 = adata_time[adata_time.obs["time"] == 0].n_obs
        n1 = adata_time[adata_time.obs["time"] == 1].n_obs
        tmap = rng.uniform(1e-6, 1, size=(n0, n1))
        tmap /= tmap.sum().sum()
        problem = CompoundProblemWithMixin(adata_time)
        problem = problem.prepare(key="time", xy_callback="local-pca", policy="sequential")
        problem[0, 1]._solution = MockSolverOutput(tmap)
        adata_time.obs[key_added] = np.hstack((np.zeros(n0), problem.pull(source=0, target=1).squeeze()))
        res_a = problem.compute_feature_correlation(
            obs_key=key_added, n_perms=10, n_jobs=1, seed=0, significance_method="perm_test"
        )
        res_b = problem.compute_feature_correlation(
            obs_key=key_added, n_perms=10, n_jobs=1, seed=0, significance_method="perm_test"
        )
        res_c = problem.compute_feature_correlation(
            obs_key=key_added, n_perms=10, n_jobs=1, seed=2, significance_method="perm_test"
        )
        assert res_a is not res_b
        np.testing.assert_array_equal(res_a.index, res_b.index)
        np.testing.assert_array_equal(res_a.columns, res_b.columns)
        np.testing.assert_allclose(res_a.values, res_b.values)
        assert not np.allclose(res_a.values, res_c.values)

    def test_seed_reproducible_parallelized(self, adata_time: AnnData):
        """Seeding is reproducible even when permutations run on multiple threads."""
        key_added = "test"
        rng = np.random.RandomState(42)
        adata_time = adata_time[adata_time.obs["time"].isin((0, 1))].copy()
        n0 = adata_time[adata_time.obs["time"] == 0].n_obs
        n1 = adata_time[adata_time.obs["time"] == 1].n_obs
        tmap = rng.uniform(1e-6, 1, size=(n0, n1))
        tmap /= tmap.sum().sum()
        problem = CompoundProblemWithMixin(adata_time)
        problem = problem.prepare(key="time", xy_callback="local-pca", policy="sequential")
        problem[0, 1]._solution = MockSolverOutput(tmap)
        adata_time.obs[key_added] = np.hstack((np.zeros(n0), problem.pull(source=0, target=1).squeeze()))
        # BUGFIX: use `significance_method=` (as in `test_seed_reproducible` above)
        # instead of `method=`, so the permutation test is actually selected.
        res_a = problem.compute_feature_correlation(
            obs_key=key_added, n_perms=10, n_jobs=2, backend="threading", seed=0, significance_method="perm_test"
        )
        res_b = problem.compute_feature_correlation(
            obs_key=key_added, n_perms=10, n_jobs=2, backend="threading", seed=0, significance_method="perm_test"
        )
        assert res_a is not res_b
        np.testing.assert_array_equal(res_a.index, res_b.index)
        np.testing.assert_array_equal(res_a.columns, res_b.columns)
        np.testing.assert_allclose(res_a.values, res_b.values)

    @pytest.mark.parametrize("corr_method", ["pearson", "spearman"])
    def test_confidence_level(self, adata_time: AnnData, corr_method: Literal["pearson", "spearman"]):
        """A 95% confidence interval is contained in the 99% interval."""
        key_added = "test"
        rng = np.random.RandomState(42)
        adata_time = adata_time[adata_time.obs["time"].isin((0, 1))].copy()
        n0 = adata_time[adata_time.obs["time"] == 0].n_obs
        n1 = adata_time[adata_time.obs["time"] == 1].n_obs
        tmap = rng.uniform(1e-6, 1, size=(n0, n1))
        tmap /= tmap.sum().sum()
        problem = CompoundProblemWithMixin(adata_time)
        problem = problem.prepare(key="time", xy_callback="local-pca", policy="sequential")
        problem[0, 1]._solution = MockSolverOutput(tmap)
        adata_time.obs[key_added] = np.hstack((np.zeros(n0), problem.pull(source=0, target=1).squeeze()))
        res_narrow = problem.compute_feature_correlation(
            obs_key=key_added, corr_method=corr_method, confidence_level=0.95
        )
        res_wide = problem.compute_feature_correlation(
            obs_key=key_added, corr_method=corr_method, confidence_level=0.99
        )
        assert np.all(res_narrow[f"{key_added}_ci_low"] >= res_wide[f"{key_added}_ci_low"])
        assert np.all(res_narrow[f"{key_added}_ci_high"] <= res_wide[f"{key_added}_ci_high"])