--- /dev/null
+++ b/tests/problems/generic/test_mixins.py
@@ -0,0 +1,419 @@
+from typing import List, Literal, Optional, Tuple
+
+import pytest
+
+import numpy as np
+import pandas as pd
+from scipy.sparse.linalg import LinearOperator
+
+from anndata import AnnData
+
+from tests._utils import ATOL, RTOL, CompoundProblemWithMixin, MockSolverOutput
+
+
+class TestBaseAnalysisMixin:
+    @pytest.mark.parametrize("n_samples", [10, 42])
+    @pytest.mark.parametrize("account_for_unbalancedness", [True, False])
+    @pytest.mark.parametrize("interpolation_parameter", [None, 0.1, 5])
+    def test_sample_from_tmap_pipeline(
+        self,
+        gt_temporal_adata: AnnData,
+        n_samples: int,
+        account_for_unbalancedness: bool,
+        interpolation_parameter: Optional[float],
+    ):
+        source_dim = gt_temporal_adata[gt_temporal_adata.obs["day"] == 10].n_obs
+        target_dim = gt_temporal_adata[gt_temporal_adata.obs["day"] == 10.5].n_obs
+        problem = CompoundProblemWithMixin(gt_temporal_adata)
+        problem = problem.prepare(key="day", subset=[(10, 10.5)], policy="sequential", xy_callback="local-pca")
+        problem[10, 10.5]._solution = MockSolverOutput(gt_temporal_adata.uns["tmap_10_105"])
+
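+        # invalid parameter combinations must raise; otherwise check the sampled output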
+        if interpolation_parameter is not None and not 0 <= interpolation_parameter <= 1:
+            with pytest.raises(ValueError, match=r"^Expected interpolation"):
+                problem._sample_from_tmap(
+                    10,
+                    10.5,
+                    n_samples,
+                    source_dim=source_dim,
+                    target_dim=target_dim,
+                    account_for_unbalancedness=account_for_unbalancedness,
+                    interpolation_parameter=interpolation_parameter,
+                )
+        elif interpolation_parameter is None and account_for_unbalancedness:
+            with pytest.raises(ValueError, match=r"^When accounting for unbalancedness"):
+                problem._sample_from_tmap(
+                    10,
+                    10.5,
+                    n_samples,
+                    source_dim=source_dim,
+                    target_dim=target_dim,
+                    account_for_unbalancedness=account_for_unbalancedness,
+                    interpolation_parameter=interpolation_parameter,
+                )
+        else:
+            result = problem._sample_from_tmap(
+                10,
+                10.5,
+                n_samples,
+                source_dim=source_dim,
+                target_dim=target_dim,
+                account_for_unbalancedness=account_for_unbalancedness,
+                interpolation_parameter=interpolation_parameter,
+            )
+            assert isinstance(result, tuple)
+            assert isinstance(result[0], np.ndarray)
+            assert isinstance(result[1], list)
+            assert isinstance(result[1][0], np.ndarray)
+            assert len(np.concatenate(result[1])) == n_samples
+
+    @pytest.mark.parametrize("forward", [True, False])
+    @pytest.mark.parametrize("scale_by_marginals", [True, False])
+    def test_interpolate_transport(self, gt_temporal_adata: AnnData, forward: bool, scale_by_marginals: bool):
+        problem = CompoundProblemWithMixin(gt_temporal_adata)
+        problem = problem.prepare(
+            key="day", subset=[(10, 10.5), (10.5, 11), (10, 11)], policy="explicit", xy_callback="local-pca"
+        )
+        problem[(10.0, 10.5)]._solution = MockSolverOutput(gt_temporal_adata.uns["tmap_10_105"])
+        problem[(10.5, 11.0)]._solution = MockSolverOutput(gt_temporal_adata.uns["tmap_105_11"])
+        problem[(10.0, 11.0)]._solution = MockSolverOutput(gt_temporal_adata.uns["tmap_10_11"])
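+        # explicit_steps presumably pins the interpolation to the stored (10.0, 11.0) solution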
+        tmap = problem._interpolate_transport(
+            [(10, 11)], scale_by_marginals=scale_by_marginals, explicit_steps=[(10.0, 11.0)]
+        )
+
+        assert isinstance(tmap, LinearOperator)
+        # TODO(@MUCDK): add a regression test after discussing with @giovp what this function
+        # should do, since it is fairly generic
+
+    def test_cell_transition_aggregation_cell_forward(self, gt_temporal_adata: AnnData):
+        # the reference computation below does the same but has to materialize the full transport matrix
+        config = gt_temporal_adata.uns
+        key_1 = config["key_1"]
+        key_2 = config["key_2"]
+        problem = CompoundProblemWithMixin(gt_temporal_adata)
+        problem = problem.prepare(key="day", subset=[(10, 10.5)], policy="explicit", xy_callback="local-pca")
+        assert set(problem.problems.keys()) == {(key_1, key_2)}
+        problem[key_1, key_2]._solution = MockSolverOutput(gt_temporal_adata.uns["tmap_10_105"])
+
+        ctr = problem._cell_transition(
+            key="day",
+            source=10,
+            target=10.5,
+            source_groups=None,
+            target_groups="cell_type",
+            forward=True,
+            aggregation_mode="cell",
+        )
+
+        adata_early = gt_temporal_adata[gt_temporal_adata.obs["day"] == 10]
+        adata_late = gt_temporal_adata[gt_temporal_adata.obs["day"] == 10.5]
+
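+        # reference computation: aggregate transport mass per late cell type, then row-normalize per early cell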
+        transition_matrix_indexed = pd.DataFrame(
+            index=adata_early.obs.index, columns=adata_late.obs.index, data=gt_temporal_adata.uns["tmap_10_105"]
+        )
+        unique_cell_types_late = adata_late.obs["cell_type"].cat.categories
+        df_res = pd.DataFrame(index=adata_early.obs.index)
+        for ct in unique_cell_types_late:
+            cols_cell_type = adata_late[adata_late.obs["cell_type"] == ct].obs.index
+            df_res[ct] = transition_matrix_indexed[cols_cell_type].sum(axis=1)
+
+        df_res = df_res.div(df_res.sum(axis=1), axis=0)
+
+        ctr_ordered = ctr.sort_index().sort_index(axis=1)
+        df_res_ordered = df_res.sort_index().sort_index(axis=1)
+        np.testing.assert_allclose(
+            ctr_ordered.values.astype(float), df_res_ordered.values.astype(float), rtol=RTOL, atol=ATOL
+        )
+
+    def test_cell_transition_aggregation_cell_backward(self, gt_temporal_adata: AnnData):
+        # the reference computation below does the same but has to materialize the full transport matrix
+        config = gt_temporal_adata.uns
+        key_1 = config["key_1"]
+        key_2 = config["key_2"]
+        problem = CompoundProblemWithMixin(gt_temporal_adata)
+        problem = problem.prepare(key="day", subset=[(10, 10.5)], policy="explicit", xy_callback="local-pca")
+        assert set(problem.problems.keys()) == {(key_1, key_2)}
+        problem[key_1, key_2]._solution = MockSolverOutput(gt_temporal_adata.uns["tmap_10_105"])
+
+        ctr = problem._cell_transition(
+            key="day",
+            source=10,
+            target=10.5,
+            source_groups="cell_type",
+            target_groups=None,
+            forward=False,
+            aggregation_mode="cell",
+        )
+
+        adata_early = gt_temporal_adata[gt_temporal_adata.obs["day"] == 10]
+        adata_late = gt_temporal_adata[gt_temporal_adata.obs["day"] == 10.5]
+
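+        # reference computation: aggregate transport mass per early cell type, then column-normalize per late cell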
+        transition_matrix_indexed = pd.DataFrame(
+            index=adata_early.obs.index, columns=adata_late.obs.index, data=gt_temporal_adata.uns["tmap_10_105"]
+        )
+        unique_cell_types_early = adata_early.obs["cell_type"].cat.categories
+        df_res = pd.DataFrame(columns=adata_late.obs.index)
+        for ct in unique_cell_types_early:
+            rows_cell_type = adata_early[adata_early.obs["cell_type"] == ct].obs.index
+            df_res.loc[ct] = transition_matrix_indexed.loc[rows_cell_type].sum(axis=0)
+
+        df_res = df_res.div(df_res.sum(axis=0), axis=1)
+
+        ctr_ordered = ctr.sort_index().sort_index(axis=1)
+        df_res_ordered = df_res.sort_index().sort_index(axis=1)
+        np.testing.assert_allclose(
+            ctr_ordered.values.astype(float), df_res_ordered.values.astype(float), rtol=RTOL, atol=ATOL
+        )
+
+    @pytest.mark.parametrize("corr_method", ["pearson", "spearman"])
+    @pytest.mark.parametrize("significance_method", ["fisher", "perm_test"])
+    def test_compute_feature_correlation(
+        self,
+        adata_time: AnnData,
+        corr_method: Literal["pearson", "spearman"],
+        significance_method: Literal["fisher", "perm_test"],
+    ):
+        key_added = "test"
+        rng = np.random.RandomState(42)
+        adata_time = adata_time[adata_time.obs["time"].isin((0, 1))].copy()
+        n0 = adata_time[adata_time.obs["time"] == 0].n_obs
+        n1 = adata_time[adata_time.obs["time"] == 1].n_obs
+        tmap = rng.uniform(1e-6, 1, size=(n0, n1))
+        tmap /= tmap.sum().sum()
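+        # mock coupling: strictly positive random entries normalized to total mass 1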
+        problem = CompoundProblemWithMixin(adata_time)
+        problem = problem.prepare(key="time", xy_callback="local-pca", policy="sequential")
+        problem[0, 1]._solution = MockSolverOutput(tmap)
+
+        adata_time.obs[key_added] = np.hstack((np.zeros(n0), problem.pull(source=0, target=1).squeeze()))
+
+        res = problem.compute_feature_correlation(
+            obs_key=key_added, corr_method=corr_method, significance_method=significance_method
+        )
+
+        assert isinstance(res, pd.DataFrame)
+        assert res.isnull().values.sum() == 0
+
+        assert np.all(res[f"{key_added}_corr"] >= -1.0)
+        assert np.all(res[f"{key_added}_corr"] <= 1.0)
+
+        assert np.all(res[f"{key_added}_qval"] >= 0)
+        assert np.all(res[f"{key_added}_qval"] <= 1.0)
+
+    @pytest.mark.parametrize("corr_method", ["pearson", "spearman"])
+    @pytest.mark.parametrize("features", [10, None])
+    @pytest.mark.parametrize("method", ["fisher", "perm_test"])
+    def test_compute_feature_correlation_subset(
+        self,
+        adata_time: AnnData,
+        features: Optional[int],
+        corr_method: Literal["pearson", "spearman"],
+        method: Literal["fisher", "perm_test"],
+    ):
+        key_added = "test"
+        rng = np.random.RandomState(42)
+        adata_time = adata_time[adata_time.obs["time"].isin((0, 1))].copy()
+        n0 = adata_time[adata_time.obs["time"] == 0].n_obs
+        n1 = adata_time[adata_time.obs["time"] == 1].n_obs
+        tmap = rng.uniform(1e-6, 1, size=(n0, n1))
+        tmap /= tmap.sum().sum()
+        problem = CompoundProblemWithMixin(adata_time)
+        problem = problem.prepare(key="time", xy_callback="local-pca", policy="sequential")
+        problem[0, 1]._solution = MockSolverOutput(tmap)
+
+        adata_time.obs[key_added] = np.hstack((np.zeros(n0), problem.pull(source=0, target=1).squeeze()))
+
+        if isinstance(features, int):
+            features = list(adata_time.var_names)[:features]
+            features_validation = features
+        else:
+            features_validation = list(adata_time.var_names)
+        res = problem.compute_feature_correlation(
+            obs_key=key_added,
+            annotation={"celltype": ["A"]},
+            corr_method=corr_method,
+            significance_method=method,
+            features=features,
+        )
+        assert isinstance(res, pd.DataFrame)
+        assert res.isnull().values.sum() == 0
+        assert set(res.index) == set(features_validation)
+
+    @pytest.mark.parametrize(
+        "features",
+        [
+            ("human", ["KLF12", "ZNF143"]),
+            ("mouse", ["Zic5"]),
+            ("drosophila", ["Cf2", "Dlip3", "Dref"]),
+            ("error", [None]),
+        ],
+    )
+    def test_compute_feature_correlation_transcription_factors(
+        self,
+        adata_time: AnnData,
+        features: Tuple[str, List[str]],
+    ):
+        key_added = "test"
+        rng = np.random.RandomState(42)
+        adata_time = adata_time[adata_time.obs["time"].isin((0, 1))].copy()
+        n0 = adata_time[adata_time.obs["time"] == 0].n_obs
+        n1 = adata_time[adata_time.obs["time"] == 1].n_obs
+        tmap = rng.uniform(1e-6, 1, size=(n0, n1))
+        tmap /= tmap.sum().sum()
+        problem = CompoundProblemWithMixin(adata_time)
+        problem = problem.prepare(key="time", xy_callback="local-pca", policy="sequential")
+        problem[0, 1]._solution = MockSolverOutput(tmap)
+
+        adata_time.obs[key_added] = np.hstack((np.zeros(n0), problem.pull(source=0, target=1).squeeze()))
+
+        if features[0] == "error":
+            with np.testing.assert_raises(NotImplementedError):
+                res = problem.compute_feature_correlation(
+                    obs_key=key_added, annotation={"celltype": ["A"]}, features=features[0]
+                )
+        else:
+            res = problem.compute_feature_correlation(
+                obs_key=key_added, annotation={"celltype": ["A"]}, features=features[0]
+            )
+            assert res.isnull().values.sum() == 0
+            assert isinstance(res, pd.DataFrame)
+            assert set(res.index) == set(features[1])
+
+    @pytest.mark.parametrize("forward", [True, False])
+    @pytest.mark.parametrize("key_added", [None, "test"])
+    @pytest.mark.parametrize("batch_size", [None, 2])
+    @pytest.mark.parametrize("c", [0.0, 0.1])
+    def test_compute_entropy_pipeline(
+        self, adata_time: AnnData, forward: bool, key_added: Optional[str], batch_size: Optional[int], c: float
+    ):
+        rng = np.random.RandomState(42)
+        adata_time = adata_time[adata_time.obs["time"].isin((0, 1))].copy()
+        n0 = adata_time[adata_time.obs["time"] == 0].n_obs
+        n1 = adata_time[adata_time.obs["time"] == 1].n_obs
+
+        tmap = rng.uniform(1e-6, 1, size=(n0, n1))
+        tmap /= tmap.sum().sum()
+        problem = CompoundProblemWithMixin(adata_time)
+        problem = problem.prepare(key="time", xy_callback="local-pca", policy="sequential")
+        problem[0, 1]._solution = MockSolverOutput(tmap)
+
+        out = problem.compute_entropy(
+            source=0, target=1, forward=forward, key_added=key_added, batch_size=batch_size, c=c
+        )
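+        # without key_added the entropies are returned; otherwise they are written into adata.obs in place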
+        if key_added is None:
+            assert isinstance(out, pd.DataFrame)
+            assert len(out) == (n0 if forward else n1)
+        else:
+            assert out is None
+            assert key_added in adata_time.obs
+            assert np.sum(adata_time[adata_time.obs["time"] == int(1 - forward)].obs[key_added].isna()) == 0
+            assert np.sum(adata_time[adata_time.obs["time"] == int(forward)].obs[key_added].isna()) == (
+                n1 if forward else n0
+            )
+
+    @pytest.mark.parametrize("forward", [True, False])
+    @pytest.mark.parametrize("batch_size", [None, 2, 15])
+    def test_compute_entropy_regression(self, adata_time: AnnData, forward: bool, batch_size: Optional[int]):
+        from scipy import stats
+
+        rng = np.random.RandomState(42)
+        adata_time = adata_time[adata_time.obs["time"].isin((0, 1))].copy()
+        n0 = adata_time[adata_time.obs["time"] == 0].n_obs
+        n1 = adata_time[adata_time.obs["time"] == 1].n_obs
+
+        tmap = rng.uniform(1e-6, 1, size=(n0, n1))
+        tmap /= tmap.sum().sum()
+        problem = CompoundProblemWithMixin(adata_time)
+        problem = problem.prepare(key="time", xy_callback="local-pca", policy="sequential")
+        problem[0, 1]._solution = MockSolverOutput(tmap)
+
+        moscot_out = problem.compute_entropy(source=0, target=1, forward=forward, batch_size=batch_size, key_added=None)
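+        # reference: scipy entropy over rows (per source cell) when forward, over columns otherwise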
+        gt_out = stats.entropy(tmap + 1e-10, axis=1 if forward else 0)
+        gt_out = np.expand_dims(gt_out, axis=1)
+
+        np.testing.assert_allclose(
+            np.array(moscot_out, dtype=float), np.array(gt_out, dtype=float), rtol=RTOL, atol=ATOL
+        )
+
+    def test_seed_reproducible(self, adata_time: AnnData):
+        key_added = "test"
+        rng = np.random.RandomState(42)
+        adata_time = adata_time[adata_time.obs["time"].isin((0, 1))].copy()
+        n0 = adata_time[adata_time.obs["time"] == 0].n_obs
+        n1 = adata_time[adata_time.obs["time"] == 1].n_obs
+        tmap = rng.uniform(1e-6, 1, size=(n0, n1))
+        tmap /= tmap.sum().sum()
+        problem = CompoundProblemWithMixin(adata_time)
+        problem = problem.prepare(key="time", xy_callback="local-pca", policy="sequential")
+        problem[0, 1]._solution = MockSolverOutput(tmap)
+
+        adata_time.obs[key_added] = np.hstack((np.zeros(n0), problem.pull(source=0, target=1).squeeze()))
+
+        res_a = problem.compute_feature_correlation(
+            obs_key=key_added, n_perms=10, n_jobs=1, seed=0, significance_method="perm_test"
+        )
+        res_b = problem.compute_feature_correlation(
+            obs_key=key_added, n_perms=10, n_jobs=1, seed=0, significance_method="perm_test"
+        )
+        res_c = problem.compute_feature_correlation(
+            obs_key=key_added, n_perms=10, n_jobs=1, seed=2, significance_method="perm_test"
+        )
+
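+        # same seed must reproduce the permutation test exactly; a different seed should not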
+        assert res_a is not res_b
+        np.testing.assert_array_equal(res_a.index, res_b.index)
+        np.testing.assert_array_equal(res_a.columns, res_b.columns)
+        np.testing.assert_allclose(res_a.values, res_b.values)
+
+        assert not np.allclose(res_a.values, res_c.values)
+
+    def test_seed_reproducible_parallelized(self, adata_time: AnnData):
+        key_added = "test"
+        rng = np.random.RandomState(42)
+        adata_time = adata_time[adata_time.obs["time"].isin((0, 1))].copy()
+        n0 = adata_time[adata_time.obs["time"] == 0].n_obs
+        n1 = adata_time[adata_time.obs["time"] == 1].n_obs
+        tmap = rng.uniform(1e-6, 1, size=(n0, n1))
+        tmap /= tmap.sum().sum()
+        problem = CompoundProblemWithMixin(adata_time)
+        problem = problem.prepare(key="time", xy_callback="local-pca", policy="sequential")
+        problem[0, 1]._solution = MockSolverOutput(tmap)
+
+        adata_time.obs[key_added] = np.hstack((np.zeros(n0), problem.pull(source=0, target=1).squeeze()))
+
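+        # the threading backend must not break seed-based reproducibility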
+        res_a = problem.compute_feature_correlation(
+            obs_key=key_added, n_perms=10, n_jobs=2, backend="threading", seed=0, method="perm_test"
+        )
+        res_b = problem.compute_feature_correlation(
+            obs_key=key_added, n_perms=10, n_jobs=2, backend="threading", seed=0, method="perm_test"
+        )
+
+        assert res_a is not res_b
+        np.testing.assert_array_equal(res_a.index, res_b.index)
+        np.testing.assert_array_equal(res_a.columns, res_b.columns)
+        np.testing.assert_allclose(res_a.values, res_b.values)
+
+    @pytest.mark.parametrize("corr_method", ["pearson", "spearman"])
+    def test_confidence_level(self, adata_time: AnnData, corr_method: Literal["pearson", "spearman"]):
+        key_added = "test"
+        rng = np.random.RandomState(42)
+        adata_time = adata_time[adata_time.obs["time"].isin((0, 1))].copy()
+        n0 = adata_time[adata_time.obs["time"] == 0].n_obs
+        n1 = adata_time[adata_time.obs["time"] == 1].n_obs
+        tmap = rng.uniform(1e-6, 1, size=(n0, n1))
+        tmap /= tmap.sum().sum()
+        problem = CompoundProblemWithMixin(adata_time)
+        problem = problem.prepare(key="time", xy_callback="local-pca", policy="sequential")
+        problem[0, 1]._solution = MockSolverOutput(tmap)
+
+        adata_time.obs[key_added] = np.hstack((np.zeros(n0), problem.pull(source=0, target=1).squeeze()))
+
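+        # a higher confidence level must give wider (or equal) confidence intervals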
+        res_narrow = problem.compute_feature_correlation(
+            obs_key=key_added, corr_method=corr_method, confidence_level=0.95
+        )
+        res_wide = problem.compute_feature_correlation(
+            obs_key=key_added, corr_method=corr_method, confidence_level=0.99
+        )
+
+        assert np.all(res_narrow[f"{key_added}_ci_low"] >= res_wide[f"{key_added}_ci_low"])
+        assert np.all(res_narrow[f"{key_added}_ci_high"] <= res_wide[f"{key_added}_ci_high"])