Switch to side-by-side view

--- a
+++ b/tests/plotting/conftest.py
@@ -0,0 +1,136 @@
+from abc import ABC, ABCMeta
+from functools import wraps
+from pathlib import Path
+from typing import Callable, Optional
+
+import pytest
+
+import numpy as np
+import pandas as pd
+
+import matplotlib.pyplot as plt
+from matplotlib.testing.compare import compare_images
+
+from anndata import AnnData
+
+from moscot import _constants
+from moscot.plotting._utils import set_plotting_vars
+
+HERE: Path = Path(__file__).parent
+
+EXPECTED = HERE / "expected_figures"
+ACTUAL = HERE / "actual_figures"
+TOL = 60
+DPI = 40
+
+
+@pytest.fixture
+def adata_pl_cell_transition(gt_temporal_adata: AnnData) -> AnnData:
+    plot_vars = {
+        "transition_matrix": gt_temporal_adata.uns["cell_transition_10_105_forward"],
+        "source_groups": "cell_type",
+        "target_groups": "cell_type",
+        "source": 0,
+        "target": 1,
+    }
+    set_plotting_vars(gt_temporal_adata, _constants.CELL_TRANSITION, key=_constants.CELL_TRANSITION, value=plot_vars)
+
+    return gt_temporal_adata
+
+
+@pytest.fixture
+def adata_pl_push(adata_time: AnnData) -> AnnData:
+    rng = np.random.RandomState(0)
+    plot_vars = {"key": "time", "data": "celltype", "subset": "A", "source": 0, "target": 1}
+    adata_time.uns["celltype_colors"] = ["#cc1b1b", "#2ccc1b", "#cc1bcc"]
+    adata_time.obs["celltype"] = adata_time.obs["celltype"].astype("category")
+    set_plotting_vars(adata_time, _constants.PUSH, key=_constants.PUSH, value=plot_vars)
+    push_initial_dist = np.zeros(
+        shape=(len(adata_time[adata_time.obs["time"] == 0]),)
+    )  # we need this for a cat. distr. in plots
+    push_initial_dist[0:10] = 0.1
+    nan2 = np.empty(len(adata_time[adata_time.obs["time"] == 2]))
+    nan2[:] = np.nan
+    adata_time.obs[_constants.PUSH] = np.hstack(
+        (push_initial_dist, np.abs(rng.randn(len(adata_time[adata_time.obs["time"] == 1]))), nan2)
+    )
+    return adata_time
+
+
+@pytest.fixture
+def adata_pl_pull(adata_time: AnnData) -> AnnData:
+    rng = np.random.RandomState(0)
+    plot_vars = {"key": "time", "data": "celltype", "subset": "A", "source": 0, "target": 1}
+    adata_time.uns["celltype_colors"] = ["#cc1b1b", "#2ccc1b", "#cc1bcc"]
+    adata_time.obs["celltype"] = adata_time.obs["celltype"].astype("category")
+    set_plotting_vars(adata_time, _constants.PULL, key=_constants.PULL, value=plot_vars)
+    pull_initial_dist = np.zeros(
+        shape=(len(adata_time[adata_time.obs["time"] == 1]),)
+    )  # we need this for a cat. distr. in plots
+    pull_initial_dist[0:10] = 0.1
+    rand0 = np.abs(rng.randn(len(adata_time[adata_time.obs["time"] == 0])))
+    nan2 = np.empty(len(adata_time[adata_time.obs["time"] == 2]))
+    nan2[:] = np.nan
+    adata_time.obs[_constants.PULL] = np.hstack((rand0, pull_initial_dist, nan2))
+    return adata_time
+
+
+@pytest.fixture
+def adata_pl_sankey(adata_time: AnnData) -> AnnData:
+    rng = np.random.RandomState(0)
+    celltypes = ["A", "B", "C", "D", "E"]
+    adata_time.obs["celltype"] = rng.choice(celltypes, size=len(adata_time))
+    adata_time.obs["celltype"] = adata_time.obs["celltype"].astype("category")
+    data1 = np.abs(rng.randn(5, 5))
+    data2 = np.abs(rng.randn(5, 5))
+    tm1 = pd.DataFrame(data=data1, index=celltypes, columns=celltypes)
+    tm2 = pd.DataFrame(data=data2, index=celltypes, columns=celltypes)
+    plot_vars = {"transition_matrices": [tm1, tm2], "captions": ["0", "1"], "key": "celltype"}
+    set_plotting_vars(adata_time, _constants.SANKEY, key=_constants.SANKEY, value=plot_vars)
+
+    return adata_time
+
+
+def _decorate(fn: Callable, clsname: str, name: Optional[str] = None) -> Callable:
+    @wraps(fn)
+    def save_and_compare(self, *args, **kwargs):
+        fn(self, *args, **kwargs)
+        self.compare(fig_name)
+
+    if not callable(fn):
+        raise TypeError(f"Expected a `callable` for class `{clsname}`, found `{type(fn).__name__}`.")
+
+    name = fn.__name__ if name is None else name
+
+    if not name.startswith("test_plot_") or not clsname.startswith("Test"):
+        return fn
+
+    fig_name = f"{clsname[4:]}_{name[10:]}"
+
+    return save_and_compare
+
+
+class PlotTesterMeta(ABCMeta):
+    def __new__(cls, clsname, superclasses, attributedict):
+        for key, value in attributedict.items():
+            if callable(value):
+                attributedict[key] = _decorate(value, clsname, name=key)
+        return super().__new__(cls, clsname, superclasses, attributedict)
+
+
+# ideally, we would you metaclass=PlotTesterMeta and all plotting tests just subclass this
+# but for some reason, pytest erases the metaclass info
+class PlotTester(ABC):  # noqa: B024
+    @classmethod
+    def compare(cls, basename: str, tolerance: Optional[float] = None):
+        ACTUAL.mkdir(parents=True, exist_ok=True)
+        out_path = ACTUAL / f"{basename}.png"
+
+        plt.savefig(out_path, dpi=DPI)
+        plt.close()
+
+        tolerance = TOL if tolerance is None else tolerance
+
+        res = compare_images(str(EXPECTED / f"{basename}.png"), str(out_path), tolerance)
+
+        assert res is None, res