--- a +++ b/tests/plotting/conftest.py @@ -0,0 +1,136 @@ +from abc import ABC, ABCMeta +from functools import wraps +from pathlib import Path +from typing import Callable, Optional + +import pytest + +import numpy as np +import pandas as pd + +import matplotlib.pyplot as plt +from matplotlib.testing.compare import compare_images + +from anndata import AnnData + +from moscot import _constants +from moscot.plotting._utils import set_plotting_vars + +HERE: Path = Path(__file__).parent + +EXPECTED = HERE / "expected_figures" +ACTUAL = HERE / "actual_figures" +TOL = 60 +DPI = 40 + + +@pytest.fixture +def adata_pl_cell_transition(gt_temporal_adata: AnnData) -> AnnData: + plot_vars = { + "transition_matrix": gt_temporal_adata.uns["cell_transition_10_105_forward"], + "source_groups": "cell_type", + "target_groups": "cell_type", + "source": 0, + "target": 1, + } + set_plotting_vars(gt_temporal_adata, _constants.CELL_TRANSITION, key=_constants.CELL_TRANSITION, value=plot_vars) + + return gt_temporal_adata + + +@pytest.fixture +def adata_pl_push(adata_time: AnnData) -> AnnData: + rng = np.random.RandomState(0) + plot_vars = {"key": "time", "data": "celltype", "subset": "A", "source": 0, "target": 1} + adata_time.uns["celltype_colors"] = ["#cc1b1b", "#2ccc1b", "#cc1bcc"] + adata_time.obs["celltype"] = adata_time.obs["celltype"].astype("category") + set_plotting_vars(adata_time, _constants.PUSH, key=_constants.PUSH, value=plot_vars) + push_initial_dist = np.zeros( + shape=(len(adata_time[adata_time.obs["time"] == 0]),) + ) # we need this for a cat. distr. in plots + push_initial_dist[0:10] = 0.1 + nan2 = np.empty(len(adata_time[adata_time.obs["time"] == 2])) + nan2[:] = np.nan + adata_time.obs[_constants.PUSH] = np.hstack( + (push_initial_dist, np.abs(rng.randn(len(adata_time[adata_time.obs["time"] == 1]))), nan2) + ) + return adata_time + + +@pytest.fixture +def adata_pl_pull(adata_time: AnnData) -> AnnData: + rng = np.random.RandomState(0) + plot_vars = {"key": "time", "data": "celltype", "subset": "A", "source": 0, "target": 1} + adata_time.uns["celltype_colors"] = ["#cc1b1b", "#2ccc1b", "#cc1bcc"] + adata_time.obs["celltype"] = adata_time.obs["celltype"].astype("category") + set_plotting_vars(adata_time, _constants.PULL, key=_constants.PULL, value=plot_vars) + pull_initial_dist = np.zeros( + shape=(len(adata_time[adata_time.obs["time"] == 1]),) + ) # we need this for a cat. distr. in plots + pull_initial_dist[0:10] = 0.1 + rand0 = np.abs(rng.randn(len(adata_time[adata_time.obs["time"] == 0]))) + nan2 = np.empty(len(adata_time[adata_time.obs["time"] == 2])) + nan2[:] = np.nan + adata_time.obs[_constants.PULL] = np.hstack((rand0, pull_initial_dist, nan2)) + return adata_time + + +@pytest.fixture +def adata_pl_sankey(adata_time: AnnData) -> AnnData: + rng = np.random.RandomState(0) + celltypes = ["A", "B", "C", "D", "E"] + adata_time.obs["celltype"] = rng.choice(celltypes, size=len(adata_time)) + adata_time.obs["celltype"] = adata_time.obs["celltype"].astype("category") + data1 = np.abs(rng.randn(5, 5)) + data2 = np.abs(rng.randn(5, 5)) + tm1 = pd.DataFrame(data=data1, index=celltypes, columns=celltypes) + tm2 = pd.DataFrame(data=data2, index=celltypes, columns=celltypes) + plot_vars = {"transition_matrices": [tm1, tm2], "captions": ["0", "1"], "key": "celltype"} + set_plotting_vars(adata_time, _constants.SANKEY, key=_constants.SANKEY, value=plot_vars) + + return adata_time + + +def _decorate(fn: Callable, clsname: str, name: Optional[str] = None) -> Callable: + @wraps(fn) + def save_and_compare(self, *args, **kwargs): + fn(self, *args, **kwargs) + self.compare(fig_name) + + if not callable(fn): + raise TypeError(f"Expected a `callable` for class `{clsname}`, found `{type(fn).__name__}`.") + + name = fn.__name__ if name is None else name + + if not name.startswith("test_plot_") or not clsname.startswith("Test"): + return fn + + fig_name = f"{clsname[4:]}_{name[10:]}" + + return save_and_compare + + +class PlotTesterMeta(ABCMeta): + def __new__(cls, clsname, superclasses, attributedict): + for key, value in attributedict.items(): + if callable(value): + attributedict[key] = _decorate(value, clsname, name=key) + return super().__new__(cls, clsname, superclasses, attributedict) + + +# ideally, we would you metaclass=PlotTesterMeta and all plotting tests just subclass this +# but for some reason, pytest erases the metaclass info +class PlotTester(ABC): # noqa: B024 + @classmethod + def compare(cls, basename: str, tolerance: Optional[float] = None): + ACTUAL.mkdir(parents=True, exist_ok=True) + out_path = ACTUAL / f"{basename}.png" + + plt.savefig(out_path, dpi=DPI) + plt.close() + + tolerance = TOL if tolerance is None else tolerance + + res = compare_images(str(EXPECTED / f"{basename}.png"), str(out_path), tolerance) + + assert res is None, res