# tests/plotting/conftest.py
1
from abc import ABC, ABCMeta
2
from functools import wraps
3
from pathlib import Path
4
from typing import Callable, Optional
5
6
import pytest
7
8
import numpy as np
9
import pandas as pd
10
11
import matplotlib.pyplot as plt
12
from matplotlib.testing.compare import compare_images
13
14
from anndata import AnnData
15
16
from moscot import _constants
17
from moscot.plotting._utils import set_plotting_vars
18
19
# Directory containing this conftest; both image directories below live inside it.
HERE: Path = Path(__file__).parent

# Reference ("golden") images that freshly rendered figures are compared against.
EXPECTED = HERE.joinpath("expected_figures")
# Output directory for figures rendered during the test run.
ACTUAL = HERE.joinpath("actual_figures")

TOL = 60  # default tolerance handed to ``compare_images``
DPI = 40  # resolution used when saving rendered figures
25
26
27
@pytest.fixture
def adata_pl_cell_transition(gt_temporal_adata: AnnData) -> AnnData:
    """Register cell-transition plotting variables on ``gt_temporal_adata``.

    The transition matrix is read from ``uns["cell_transition_10_105_forward"]`` and
    stored, together with group/time settings, via ``set_plotting_vars`` under
    ``_constants.CELL_TRANSITION``. The same (mutated) object is returned.
    """
    adata = gt_temporal_adata
    transition = adata.uns["cell_transition_10_105_forward"]
    set_plotting_vars(
        adata,
        _constants.CELL_TRANSITION,
        key=_constants.CELL_TRANSITION,
        value={
            "transition_matrix": transition,
            "source_groups": "cell_type",
            "target_groups": "cell_type",
            "source": 0,
            "target": 1,
        },
    )
    return adata
39
40
41
@pytest.fixture
def adata_pl_push(adata_time: AnnData) -> AnnData:
    """Prepare ``adata_time`` for push-distribution plotting tests.

    Sets fixed ``celltype`` colours, makes ``obs["celltype"]`` categorical, registers
    push plotting variables under ``_constants.PUSH`` and writes a synthetic push
    distribution into ``obs[_constants.PUSH]``:

    - time-0 cells: zeros, except the first ten entries which are 0.1;
    - time-1 cells: absolute standard-normal draws (seeded with 0);
    - time-2 cells: NaN.
    """
    rng = np.random.RandomState(0)

    def n_cells(t: int) -> int:
        # Number of observations whose ``time`` equals ``t``.
        return adata_time[adata_time.obs["time"] == t].n_obs

    adata_time.uns["celltype_colors"] = ["#cc1b1b", "#2ccc1b", "#cc1bcc"]
    adata_time.obs["celltype"] = adata_time.obs["celltype"].astype("category")
    set_plotting_vars(
        adata_time,
        _constants.PUSH,
        key=_constants.PUSH,
        value={"key": "time", "data": "celltype", "subset": "A", "source": 0, "target": 1},
    )

    # Mostly-zero time-0 block; needed for a categorical distribution in the plots.
    source_block = np.zeros(n_cells(0))
    source_block[:10] = 0.1
    target_block = np.abs(rng.randn(n_cells(1)))
    later_block = np.full(n_cells(2), np.nan)
    # NOTE(review): positional concatenation assumes obs is sorted by time (0, 1, 2)
    # and contains no other time points; TODO confirm against the ``adata_time`` fixture.
    adata_time.obs[_constants.PUSH] = np.concatenate([source_block, target_block, later_block])
    return adata_time
58
59
60
@pytest.fixture
def adata_pl_pull(adata_time: AnnData) -> AnnData:
    """Prepare ``adata_time`` for pull-distribution plotting tests.

    Sets fixed ``celltype`` colours, makes ``obs["celltype"]`` categorical, registers
    pull plotting variables under ``_constants.PULL`` and writes a synthetic pull
    distribution into ``obs[_constants.PULL]``:

    - time-0 cells: absolute standard-normal draws (seeded with 0);
    - time-1 cells: zeros, except the first ten entries which are 0.1;
    - time-2 cells: NaN.
    """
    rng = np.random.RandomState(0)

    def n_cells(t: int) -> int:
        # Number of observations whose ``time`` equals ``t``.
        return adata_time[adata_time.obs["time"] == t].n_obs

    adata_time.uns["celltype_colors"] = ["#cc1b1b", "#2ccc1b", "#cc1bcc"]
    adata_time.obs["celltype"] = adata_time.obs["celltype"].astype("category")
    set_plotting_vars(
        adata_time,
        _constants.PULL,
        key=_constants.PULL,
        value={"key": "time", "data": "celltype", "subset": "A", "source": 0, "target": 1},
    )

    # Mostly-zero time-1 block; needed for a categorical distribution in the plots.
    target_block = np.zeros(n_cells(1))
    target_block[:10] = 0.1
    source_block = np.abs(rng.randn(n_cells(0)))
    later_block = np.full(n_cells(2), np.nan)
    # NOTE(review): positional concatenation assumes obs is sorted by time (0, 1, 2)
    # and contains no other time points; TODO confirm against the ``adata_time`` fixture.
    adata_time.obs[_constants.PULL] = np.concatenate([source_block, target_block, later_block])
    return adata_time
76
77
78
@pytest.fixture
def adata_pl_sankey(adata_time: AnnData) -> AnnData:
    """Prepare ``adata_time`` for Sankey plotting tests.

    Overwrites ``obs["celltype"]`` with seeded random labels from ``A``-``E`` (stored as
    a categorical) and registers, under ``_constants.SANKEY``, two random 5x5
    non-negative transition matrices labelled by those cell types, captioned ``"0"``
    and ``"1"``.
    """
    rng = np.random.RandomState(0)
    labels = ["A", "B", "C", "D", "E"]
    n_labels = len(labels)

    adata_time.obs["celltype"] = rng.choice(labels, size=len(adata_time))
    adata_time.obs["celltype"] = adata_time.obs["celltype"].astype("category")

    # Drawn in order (first matrix, then second) so the seeded values stay reproducible.
    matrices = [
        pd.DataFrame(data=np.abs(rng.randn(n_labels, n_labels)), index=labels, columns=labels)
        for _ in range(2)
    ]
    set_plotting_vars(
        adata_time,
        _constants.SANKEY,
        key=_constants.SANKEY,
        value={"transition_matrices": matrices, "captions": ["0", "1"], "key": "celltype"},
    )
    return adata_time
92
93
94
def _decorate(fn: Callable, clsname: str, name: Optional[str] = None) -> Callable:
95
    @wraps(fn)
96
    def save_and_compare(self, *args, **kwargs):
97
        fn(self, *args, **kwargs)
98
        self.compare(fig_name)
99
100
    if not callable(fn):
101
        raise TypeError(f"Expected a `callable` for class `{clsname}`, found `{type(fn).__name__}`.")
102
103
    name = fn.__name__ if name is None else name
104
105
    if not name.startswith("test_plot_") or not clsname.startswith("Test"):
106
        return fn
107
108
    fig_name = f"{clsname[4:]}_{name[10:]}"
109
110
    return save_and_compare
111
112
113
class PlotTesterMeta(ABCMeta):
    """Metaclass that passes every callable class attribute through ``_decorate``.

    ``_decorate`` returns most attributes unchanged. Only ``test_plot_*`` methods of
    ``Test*`` classes are wrapped so that their figure is saved and compared.
    """

    def __new__(cls, clsname, superclasses, attributedict):
        # Collect the replacements first, then merge them back. Existing keys keep
        # their position, so the namespace order is unchanged.
        decorated = {
            attr: _decorate(member, clsname, name=attr)
            for attr, member in attributedict.items()
            if callable(member)
        }
        attributedict.update(decorated)
        return super().__new__(cls, clsname, superclasses, attributedict)
119
120
121
# Ideally, we would use metaclass=PlotTesterMeta and have all plotting tests just subclass this,
122
# but for some reason, pytest erases the metaclass info
123
class PlotTester(ABC):  # noqa: B024
    """Base class for plotting tests that compare rendered figures with reference images."""

    @classmethod
    def compare(cls, basename: str, tolerance: Optional[float] = None) -> None:
        """Save the current matplotlib figure and compare it with its reference image.

        The figure is written to ``ACTUAL / f"{basename}.png"`` at ``DPI`` and then
        compared with ``EXPECTED / f"{basename}.png"``.

        Parameters
        ----------
        basename
            File stem shared by the rendered and the reference image.
        tolerance
            Tolerance passed to ``compare_images``; defaults to ``TOL`` when ``None``.

        Raises
        ------
        AssertionError
            If ``compare_images`` reports a difference; its message is the assertion message.
        """
        ACTUAL.mkdir(parents=True, exist_ok=True)
        out_path = ACTUAL / f"{basename}.png"

        try:
            plt.savefig(out_path, dpi=DPI)
        finally:
            # Close even if saving fails so the figure does not leak into later tests.
            plt.close()

        tolerance = TOL if tolerance is None else tolerance

        # ``compare_images`` returns None when the images match within ``tolerance``.
        res = compare_images(str(EXPECTED / f"{basename}.png"), str(out_path), tolerance)

        assert res is None, res