Switch to side-by-side view

--- a
+++ b/tests/problems/space/test_mixins.py
@@ -0,0 +1,242 @@
+import pickle
+from math import acos
+from pathlib import Path
+from typing import Dict, List, Optional
+
+import pytest
+
+import numpy as np
+import pandas as pd
+
+from anndata import AnnData
+
+from moscot.problems.space import AlignmentProblem, MappingProblem
+from tests._utils import MockSolverOutput, _adata_spatial_split
+from tests.conftest import ANGLES
+
+# TODO(giovp): refactor as fixture
+# Pickled reference solutions loaded by the ``test_regression_testing`` methods in this module.
+# NOTE(review): for ``tests/problems/space/test_mixins.py`` three ``.parent`` hops resolve to ``tests/``
+# (not the repository root), so these point at ``tests/data/*.pkl`` -- confirm the data location.
+SOLUTIONS_PATH_ALIGNMENT = Path(__file__).parent.parent.parent / "data/alignment_solutions.pkl"
+SOLUTIONS_PATH_MAPPING = Path(__file__).parent.parent.parent / "data/mapping_solutions.pkl"
+
+
+class TestSpatialAlignmentAnalysisMixin:
+    def test_analysis(self, adata_space_rotate: AnnData):
+        import scanpy as sc
+
+        adata_ref = adata_space_rotate.copy()
+        sc.pp.subsample(adata_ref, fraction=0.9)
+        problem = AlignmentProblem(adata=adata_ref).prepare(batch_key="batch").solve(epsilon=1e-1)
+        categories = adata_space_rotate.obs.batch.cat.categories
+
+        for ref in categories:
+            problem.align(reference=ref, mode="affine", key_added="spatial_affine")
+            problem.align(reference=ref, mode="warp", key_added="spatial_warp")
+            tgts = set(categories) - set(ref)
+            for c in zip(tgts):
+                assert (
+                    adata_ref[adata_ref.obs.batch == c].obsm["spatial_warp"].shape
+                    == adata_ref[adata_ref.obs.batch == c].obsm["spatial_affine"].shape
+                )
+            angles = sorted(
+                round(np.rad2deg(acos(arr[0, 0])), 3)
+                for arr in adata_ref.uns["spatial_affine"]["alignment_metadata"].values()
+                if isinstance(arr, np.ndarray)
+            )
+            assert np.sum(angles) <= np.sum(ANGLES) + 2
+
+            problem.align(reference=ref, mode="affine", spatial_key="spatial")
+            for c in zip(tgts):
+                assert (
+                    adata_ref[adata_ref.obs.batch == c].obsm["spatial_affine"].shape
+                    == adata_ref[adata_ref.obs.batch == c].obsm["spatial"].shape
+                )
+
+    def test_regression_testing(self, adata_space_rotate: AnnData):
+        ap = AlignmentProblem(adata=adata_space_rotate).prepare(batch_key="batch").solve(alpha=0.5, epsilon=1)
+        # TODO(giovp): unnecessary assert
+        assert SOLUTIONS_PATH_ALIGNMENT.exists()
+        with open(SOLUTIONS_PATH_ALIGNMENT, "rb") as fname:
+            sol = pickle.load(fname)
+
+        assert sol.keys() == ap.solutions.keys()
+        for k in sol:
+            np.testing.assert_almost_equal(sol[k].cost, ap.solutions[k].cost, decimal=1)
+            np.testing.assert_almost_equal(
+                np.array(sol[k].transport_matrix), np.array(ap.solutions[k].transport_matrix), decimal=3
+            )
+
+    @pytest.mark.fast
+    @pytest.mark.parametrize("forward", [True, False])
+    @pytest.mark.parametrize("normalize", [True, False])
+    def test_cell_transition_pipeline(self, adata_space_rotate: AnnData, forward: bool, normalize: bool):
+        rng = np.random.RandomState(0)
+        adata_space_rotate.obs["celltype"] = rng.choice(["a", "b", "c"], len(adata_space_rotate))
+        adata_space_rotate.obs["celltype"] = adata_space_rotate.obs["celltype"].astype("category")
+        # TODO(@MUCDK) use MockSolverOutput if no regression test
+        ap = AlignmentProblem(adata=adata_space_rotate)
+        ap = ap.prepare(batch_key="batch")
+        mock_tmap = np.abs(
+            rng.randn(
+                len(adata_space_rotate[adata_space_rotate.obs["batch"] == "1"]),
+                len(adata_space_rotate[adata_space_rotate.obs["batch"] == "2"]),
+            )
+        )
+        ap[("1", "2")]._solution = MockSolverOutput(mock_tmap / mock_tmap.sum().sum())
+        result = ap.cell_transition(
+            source="1",
+            target="2",
+            source_groups="celltype",
+            target_groups="celltype",
+            forward=forward,
+            normalize=normalize,
+        )
+        assert isinstance(result, pd.DataFrame)
+        assert result.shape == (3, 3)
+
+    @pytest.mark.fast
+    @pytest.mark.parametrize("forward", [True, False])
+    @pytest.mark.parametrize("mapping_mode", ["max", "sum"])
+    @pytest.mark.parametrize("batch_size", [3, 7, None])
+    @pytest.mark.parametrize("problem_kind", ["alignment"])
+    def test_annotation_mapping(self, adata_anno: AnnData, forward: bool, mapping_mode, batch_size, gt_tm_annotation):
+        ap = AlignmentProblem(adata=adata_anno)
+        ap = ap.prepare(batch_key="batch", joint_attr={"attr": "X"})
+        problem_keys = ("0", "1")
+        assert set(ap.problems.keys()) == {problem_keys}
+        ap[problem_keys].set_solution(MockSolverOutput(gt_tm_annotation))
+        annotation_label = "celltype1" if forward else "celltype2"
+        result = ap.annotation_mapping(
+            mapping_mode=mapping_mode,
+            annotation_label=annotation_label,
+            source="0",
+            target="1",
+            forward=forward,
+            batch_size=batch_size,
+        )
+        if forward:
+            expected_result = (
+                adata_anno.uns["expected_max1"] if mapping_mode == "max" else adata_anno.uns["expected_sum1"]
+            )
+        else:
+            expected_result = (
+                adata_anno.uns["expected_max2"] if mapping_mode == "max" else adata_anno.uns["expected_sum2"]
+            )
+        assert (result[annotation_label] == expected_result).all()
+
+
+class TestSpatialMappingAnalysisMixin:
+    @pytest.mark.parametrize("sc_attr", [{"attr": "X"}, {"attr": "obsm", "key": "X_pca"}])
+    @pytest.mark.parametrize("var_names", ["0", [str(i) for i in range(20)]])
+    @pytest.mark.parametrize("groupby", [None, "covariate"])
+    @pytest.mark.parametrize("batch_size", [None, 7, 10, 100])
+    def test_analysis(
+        self,
+        adata_mapping: AnnData,
+        sc_attr: Dict[str, str],
+        var_names: Optional[List[Optional[str]]],
+        groupby: Optional[str],
+        batch_size: Optional[int],
+    ):
+        adataref, adatasp = _adata_spatial_split(adata_mapping)
+        mp = MappingProblem(adataref, adatasp).prepare(batch_key="batch", sc_attr=sc_attr).solve()
+
+        corr = mp.correlate(var_names, groupby=groupby, batch_size=batch_size)
+        imp = mp.impute(batch_size=batch_size)
+
+        if groupby:
+            for key in adata_mapping.obs[groupby].cat.categories:
+                pd.testing.assert_series_equal(*[corr[problem][key] for problem in corr])
+        else:
+            pd.testing.assert_series_equal(*list(corr.values()))
+        assert imp.shape == adatasp.shape
+
+    def test_correspondence(
+        self,
+        adata_mapping: AnnData,
+    ):
+        adataref, adatasp = _adata_spatial_split(adata_mapping)
+        df = (
+            MappingProblem(adataref, adatasp)
+            .prepare(batch_key="batch", sc_attr={"attr": "X"})
+            .spatial_correspondence(interval=[3, 4])
+        )
+        assert "batch" in df.columns
+        np.testing.assert_array_equal(df["batch"].cat.categories, adatasp.obs["batch"].cat.categories)
+        df2 = (
+            MappingProblem(adataref, adatasp)
+            .prepare(batch_key="batch", sc_attr={"attr": "X"})
+            .spatial_correspondence(attr={"attr": "obsm", "key": "spatial"}, interval=[3, 4])
+        )
+        np.testing.assert_array_equal(df.index_interval.cat.categories, df2.index_interval.cat.categories)
+        df3 = MappingProblem(adataref, adatasp).prepare(sc_attr={"attr": "X"}).spatial_correspondence(interval=[2, 3])
+        np.testing.assert_array_equal(df3.value_interval.unique(), (2, 3))
+
+    def test_regression_testing(self, adata_mapping: AnnData):
+        adataref, adatasp = _adata_spatial_split(adata_mapping)
+        mp = MappingProblem(adataref, adatasp)
+        mp = mp.prepare(batch_key="batch", sc_attr={"attr": "X"})
+        mp = mp.solve()
+        assert SOLUTIONS_PATH_MAPPING.exists()
+        with open(SOLUTIONS_PATH_MAPPING, "rb") as fname:
+            sol = pickle.load(fname)
+
+        assert sol.keys() == mp.solutions.keys()
+        for k in sol:
+            np.testing.assert_almost_equal(sol[k].cost, mp.solutions[k].cost, decimal=1)
+            np.testing.assert_almost_equal(
+                np.array(sol[k].transport_matrix), np.array(mp.solutions[k].transport_matrix), decimal=3
+            )
+
+    @pytest.mark.fast
+    @pytest.mark.parametrize("forward", [True, False])
+    @pytest.mark.parametrize("normalize", [True, False])
+    def test_cell_transition_pipeline(self, adata_mapping: AnnData, forward: bool, normalize: bool):
+        rng = np.random.RandomState(0)
+        adataref, adatasp = _adata_spatial_split(adata_mapping)
+        adatasp.obs["celltype"] = rng.choice(["a", "b", "c"], len(adatasp))
+        adatasp.obs["celltype"] = adatasp.obs["celltype"].astype("category")
+        adataref.obs["celltype"] = rng.choice(["d", "e", "f", "g"], len(adataref))
+        adataref.obs["celltype"] = adataref.obs["celltype"].astype("category")
+        # TODO(@MUCDK) use MockSolverOutput if no regression test
+        mp = MappingProblem(adataref, adatasp)
+        mp = mp.prepare(batch_key="batch", sc_attr={"attr": "obsm", "key": "X_pca"})
+        # mp = mp.solve()
+        mock_tmap = np.abs(rng.randn(len(adatasp[adatasp.obs["batch"] == "1"]), len(adataref)))
+        mp[("1", "ref")]._solution = MockSolverOutput(mock_tmap / np.sum(mock_tmap))
+
+        result = mp.cell_transition(
+            source="1",
+            source_groups="celltype",
+            target_groups="celltype",
+            forward=forward,
+            normalize=normalize,
+        )
+
+        assert isinstance(result, pd.DataFrame)
+        assert result.shape == (3, 4)
+
+    @pytest.mark.fast
+    @pytest.mark.parametrize("forward", [True, False])
+    @pytest.mark.parametrize("mapping_mode", ["max", "sum"])
+    @pytest.mark.parametrize("batch_size", [3, 7, None])
+    @pytest.mark.parametrize("problem_kind", ["mapping"])
+    def test_annotation_mapping(self, adata_anno: AnnData, forward: bool, mapping_mode, batch_size, gt_tm_annotation):
+        adataref, adatasp = adata_anno
+        mp = MappingProblem(adataref, adatasp)
+        mp = mp.prepare(sc_attr={"attr": "obsm", "key": "X_pca"}, joint_attr={"attr": "X"})
+        problem_keys = ("src", "tgt")
+        assert set(mp.problems.keys()) == {problem_keys}
+        mp[problem_keys].set_solution(MockSolverOutput(gt_tm_annotation.T))
+        annotation_label = "celltype1" if not forward else "celltype2"
+        result = mp.annotation_mapping(
+            mapping_mode=mapping_mode,
+            annotation_label=annotation_label,
+            source="src",
+            forward=forward,
+            batch_size=batch_size,
+        )
+        if not forward:
+            expected_result = adataref.uns["expected_max1"] if mapping_mode == "max" else adataref.uns["expected_sum1"]
+        else:
+            expected_result = adatasp.uns["expected_max2"] if mapping_mode == "max" else adatasp.uns["expected_sum2"]
+        assert (result[annotation_label] == expected_result).all()