|
a |
|
b/tests/plotting/conftest.py |
|
|
1 |
from abc import ABC, ABCMeta |
|
|
2 |
from functools import wraps |
|
|
3 |
from pathlib import Path |
|
|
4 |
from typing import Callable, Optional |
|
|
5 |
|
|
|
6 |
import pytest |
|
|
7 |
|
|
|
8 |
import numpy as np |
|
|
9 |
import pandas as pd |
|
|
10 |
|
|
|
11 |
import matplotlib.pyplot as plt |
|
|
12 |
from matplotlib.testing.compare import compare_images |
|
|
13 |
|
|
|
14 |
from anndata import AnnData |
|
|
15 |
|
|
|
16 |
from moscot import _constants |
|
|
17 |
from moscot.plotting._utils import set_plotting_vars |
|
|
18 |
|
|
|
19 |
HERE: Path = Path(__file__).parent |
|
|
20 |
|
|
|
21 |
EXPECTED = HERE / "expected_figures" |
|
|
22 |
ACTUAL = HERE / "actual_figures" |
|
|
23 |
TOL = 60 |
|
|
24 |
DPI = 40 |
|
|
25 |
|
|
|
26 |
|
|
|
27 |
@pytest.fixture |
|
|
28 |
def adata_pl_cell_transition(gt_temporal_adata: AnnData) -> AnnData: |
|
|
29 |
plot_vars = { |
|
|
30 |
"transition_matrix": gt_temporal_adata.uns["cell_transition_10_105_forward"], |
|
|
31 |
"source_groups": "cell_type", |
|
|
32 |
"target_groups": "cell_type", |
|
|
33 |
"source": 0, |
|
|
34 |
"target": 1, |
|
|
35 |
} |
|
|
36 |
set_plotting_vars(gt_temporal_adata, _constants.CELL_TRANSITION, key=_constants.CELL_TRANSITION, value=plot_vars) |
|
|
37 |
|
|
|
38 |
return gt_temporal_adata |
|
|
39 |
|
|
|
40 |
|
|
|
41 |
@pytest.fixture |
|
|
42 |
def adata_pl_push(adata_time: AnnData) -> AnnData: |
|
|
43 |
rng = np.random.RandomState(0) |
|
|
44 |
plot_vars = {"key": "time", "data": "celltype", "subset": "A", "source": 0, "target": 1} |
|
|
45 |
adata_time.uns["celltype_colors"] = ["#cc1b1b", "#2ccc1b", "#cc1bcc"] |
|
|
46 |
adata_time.obs["celltype"] = adata_time.obs["celltype"].astype("category") |
|
|
47 |
set_plotting_vars(adata_time, _constants.PUSH, key=_constants.PUSH, value=plot_vars) |
|
|
48 |
push_initial_dist = np.zeros( |
|
|
49 |
shape=(len(adata_time[adata_time.obs["time"] == 0]),) |
|
|
50 |
) # we need this for a cat. distr. in plots |
|
|
51 |
push_initial_dist[0:10] = 0.1 |
|
|
52 |
nan2 = np.empty(len(adata_time[adata_time.obs["time"] == 2])) |
|
|
53 |
nan2[:] = np.nan |
|
|
54 |
adata_time.obs[_constants.PUSH] = np.hstack( |
|
|
55 |
(push_initial_dist, np.abs(rng.randn(len(adata_time[adata_time.obs["time"] == 1]))), nan2) |
|
|
56 |
) |
|
|
57 |
return adata_time |
|
|
58 |
|
|
|
59 |
|
|
|
60 |
@pytest.fixture |
|
|
61 |
def adata_pl_pull(adata_time: AnnData) -> AnnData: |
|
|
62 |
rng = np.random.RandomState(0) |
|
|
63 |
plot_vars = {"key": "time", "data": "celltype", "subset": "A", "source": 0, "target": 1} |
|
|
64 |
adata_time.uns["celltype_colors"] = ["#cc1b1b", "#2ccc1b", "#cc1bcc"] |
|
|
65 |
adata_time.obs["celltype"] = adata_time.obs["celltype"].astype("category") |
|
|
66 |
set_plotting_vars(adata_time, _constants.PULL, key=_constants.PULL, value=plot_vars) |
|
|
67 |
pull_initial_dist = np.zeros( |
|
|
68 |
shape=(len(adata_time[adata_time.obs["time"] == 1]),) |
|
|
69 |
) # we need this for a cat. distr. in plots |
|
|
70 |
pull_initial_dist[0:10] = 0.1 |
|
|
71 |
rand0 = np.abs(rng.randn(len(adata_time[adata_time.obs["time"] == 0]))) |
|
|
72 |
nan2 = np.empty(len(adata_time[adata_time.obs["time"] == 2])) |
|
|
73 |
nan2[:] = np.nan |
|
|
74 |
adata_time.obs[_constants.PULL] = np.hstack((rand0, pull_initial_dist, nan2)) |
|
|
75 |
return adata_time |
|
|
76 |
|
|
|
77 |
|
|
|
78 |
@pytest.fixture |
|
|
79 |
def adata_pl_sankey(adata_time: AnnData) -> AnnData: |
|
|
80 |
rng = np.random.RandomState(0) |
|
|
81 |
celltypes = ["A", "B", "C", "D", "E"] |
|
|
82 |
adata_time.obs["celltype"] = rng.choice(celltypes, size=len(adata_time)) |
|
|
83 |
adata_time.obs["celltype"] = adata_time.obs["celltype"].astype("category") |
|
|
84 |
data1 = np.abs(rng.randn(5, 5)) |
|
|
85 |
data2 = np.abs(rng.randn(5, 5)) |
|
|
86 |
tm1 = pd.DataFrame(data=data1, index=celltypes, columns=celltypes) |
|
|
87 |
tm2 = pd.DataFrame(data=data2, index=celltypes, columns=celltypes) |
|
|
88 |
plot_vars = {"transition_matrices": [tm1, tm2], "captions": ["0", "1"], "key": "celltype"} |
|
|
89 |
set_plotting_vars(adata_time, _constants.SANKEY, key=_constants.SANKEY, value=plot_vars) |
|
|
90 |
|
|
|
91 |
return adata_time |
|
|
92 |
|
|
|
93 |
|
|
|
94 |
def _decorate(fn: Callable, clsname: str, name: Optional[str] = None) -> Callable: |
|
|
95 |
@wraps(fn) |
|
|
96 |
def save_and_compare(self, *args, **kwargs): |
|
|
97 |
fn(self, *args, **kwargs) |
|
|
98 |
self.compare(fig_name) |
|
|
99 |
|
|
|
100 |
if not callable(fn): |
|
|
101 |
raise TypeError(f"Expected a `callable` for class `{clsname}`, found `{type(fn).__name__}`.") |
|
|
102 |
|
|
|
103 |
name = fn.__name__ if name is None else name |
|
|
104 |
|
|
|
105 |
if not name.startswith("test_plot_") or not clsname.startswith("Test"): |
|
|
106 |
return fn |
|
|
107 |
|
|
|
108 |
fig_name = f"{clsname[4:]}_{name[10:]}" |
|
|
109 |
|
|
|
110 |
return save_and_compare |
|
|
111 |
|
|
|
112 |
|
|
|
113 |
class PlotTesterMeta(ABCMeta): |
|
|
114 |
def __new__(cls, clsname, superclasses, attributedict): |
|
|
115 |
for key, value in attributedict.items(): |
|
|
116 |
if callable(value): |
|
|
117 |
attributedict[key] = _decorate(value, clsname, name=key) |
|
|
118 |
return super().__new__(cls, clsname, superclasses, attributedict) |
|
|
119 |
|
|
|
120 |
|
|
|
121 |
# ideally, we would you metaclass=PlotTesterMeta and all plotting tests just subclass this |
|
|
122 |
# but for some reason, pytest erases the metaclass info |
|
|
123 |
class PlotTester(ABC): # noqa: B024 |
|
|
124 |
@classmethod |
|
|
125 |
def compare(cls, basename: str, tolerance: Optional[float] = None): |
|
|
126 |
ACTUAL.mkdir(parents=True, exist_ok=True) |
|
|
127 |
out_path = ACTUAL / f"{basename}.png" |
|
|
128 |
|
|
|
129 |
plt.savefig(out_path, dpi=DPI) |
|
|
130 |
plt.close() |
|
|
131 |
|
|
|
132 |
tolerance = TOL if tolerance is None else tolerance |
|
|
133 |
|
|
|
134 |
res = compare_images(str(EXPECTED / f"{basename}.png"), str(out_path), tolerance) |
|
|
135 |
|
|
|
136 |
assert res is None, res |