Diff of /tests/conftest.py [000000] .. [6ff4a8]

Switch to unified view

a b/tests/conftest.py
1
from math import cos, sin
2
from typing import Literal, Optional, Tuple, Union
3
4
import pytest
5
6
import jax.numpy as jnp
7
import numpy as np
8
import pandas as pd
9
from jax import config
10
from scipy.sparse import csr_matrix
11
12
import matplotlib.pyplot as plt
13
14
import anndata as ad
15
import scanpy as sc
16
from anndata import AnnData
17
18
from tests._utils import Geom_t, _make_adata, _make_grid
19
20
ANGLES = (0, 30, 60)
21
22
23
# TODO(michalk8): consider passing this via env
24
config.update("jax_enable_x64", True)
25
26
27
_gt_temporal_adata = sc.read("tests/data/moscot_temporal_tests.h5ad")
28
29
30
def pytest_sessionstart() -> None:
31
    sc.pl.set_rcParams_defaults()
32
    sc.set_figure_params(dpi=40, color_map="viridis")
33
34
35
@pytest.fixture(autouse=True)
36
def _close_figure():
37
    # prevent `RuntimeWarning: More than 20 figures have been opened.`
38
    yield
39
    plt.close()
40
41
42
@pytest.fixture
43
def x() -> Geom_t:
44
    rng = np.random.RandomState(0)
45
    n = 20  # number of points in the first distribution
46
    sig = 1  # std of first distribution
47
48
    phi = np.arange(n)[:, None]
49
    xs = phi + sig * rng.randn(n, 1)
50
51
    return jnp.asarray(xs)
52
53
54
@pytest.fixture
55
def y() -> Geom_t:
56
    rng = np.random.RandomState(1)
57
    n2 = 30  # number of points in the second distribution
58
    sig = 1  # std of first distribution
59
60
    phi2 = np.arange(n2)[:, None]
61
    xt = phi2 + sig * rng.randn(n2, 1)
62
63
    return jnp.asarray(xt)
64
65
66
@pytest.fixture
67
def xy() -> Tuple[Geom_t, Geom_t]:
68
    rng = np.random.RandomState(2)
69
    n = 20  # number of points in the first distribution
70
    n2 = 30  # number of points in the second distribution
71
    sig = 1  # std of first distribution
72
    sig2 = 0.1  # std of second distribution
73
74
    phi = np.arange(n)[:, None]
75
    phi + sig * rng.randn(n, 1)
76
    ys = np.vstack((np.ones((n // 2, 1)), 0 * np.ones((n // 2, 1)))) + sig2 * rng.randn(n, 1)
77
78
    phi2 = np.arange(n2)[:, None]
79
    phi2 + sig * rng.randn(n2, 1)
80
    yt = np.vstack((np.ones((n2 // 2, 1)), 0 * np.ones((n2 // 2, 1)))) + sig2 * rng.randn(n2, 1)
81
    yt = yt[::-1, :]
82
83
    return jnp.asarray(ys), jnp.asarray(yt)
84
85
86
@pytest.fixture
87
def ab() -> Tuple[np.ndarray, np.ndarray]:
88
    rng = np.random.RandomState(42)
89
    return rng.normal(size=(20, 2)), rng.normal(size=(30, 4))
90
91
92
@pytest.fixture
93
def x_cost(x: Geom_t) -> jnp.ndarray:
94
    return ((x[:, None, :] - x[None, ...]) ** 2).sum(-1)
95
96
97
@pytest.fixture
98
def y_cost(y: Geom_t) -> jnp.ndarray:
99
    return ((y[:, None, :] - y[None, ...]) ** 2).sum(-1)
100
101
102
@pytest.fixture
103
def xy_cost(xy: Geom_t) -> jnp.ndarray:
104
    x, y = xy
105
    return ((x[:, None, :] - y[None, ...]) ** 2).sum(-1)
106
107
108
@pytest.fixture
109
def adata_x(x: Geom_t) -> AnnData:
110
    rng = np.random.RandomState(43)
111
    pc = rng.normal(size=(len(x), 4))
112
    return AnnData(X=np.asarray(x, dtype=float), obsm={"X_pca": pc})
113
114
115
@pytest.fixture
116
def adata_y(y: Geom_t) -> AnnData:
117
    rng = np.random.RandomState(44)
118
    pc = rng.normal(size=(len(y), 4))
119
    return AnnData(X=np.asarray(y, dtype=float), obsm={"X_pca": pc})
120
121
122
def creat_prob(n: int, *, uniform: bool = False, seed: Optional[int] = None) -> Geom_t:
123
    rng = np.random.RandomState(seed)
124
    a = np.ones((n,)) if uniform else np.abs(rng.normal(size=(n,)))
125
    a /= np.sum(a)
126
    return jnp.asarray(a)
127
128
129
@pytest.fixture
130
def adata_time() -> AnnData:
131
    rng = np.random.RandomState(42)
132
133
    adatas = [
134
        AnnData(
135
            X=csr_matrix(rng.normal(size=(96, 60))),
136
            obs={
137
                "left_marginals_balanced": creat_prob(96, seed=42),
138
                "right_marginals_balanced": creat_prob(96, seed=42),
139
            },
140
        )
141
        for _ in range(3)
142
    ]
143
    adata = ad.concat(adatas, label="time", index_unique="-")
144
    adata.obs["time"] = pd.to_numeric(adata.obs["time"]).astype("category")
145
    adata.obs["batch"] = rng.choice((0, 1, 2), len(adata))
146
    adata.obs["left_marginals_unbalanced"] = np.ones(len(adata))
147
    adata.obs["right_marginals_unbalanced"] = np.ones(len(adata))
148
    adata.obs["celltype"] = rng.choice(["A", "B", "C"], size=len(adata))
149
    # genes from mouse/human proliferation/apoptosis
150
    genes = ["ANLN", "ANP32E", "ATAD2", "Mcm4", "Smc4", "Gtse1", "ADD1", "AIFM3", "ANKH", "Ercc5", "Serpinb5", "Inhbb"]
151
    # genes which are transcription factors, 3 from drosophila, 2 from human, 1 from mouse
152
    genes += ["Cf2", "Dlip3", "Dref", "KLF12", "ZNF143", "Zic5"]
153
    adata.var.index = ["gene_" + el if i > len(genes) - 1 else genes[i] for i, el in enumerate(adata.var.index)]
154
    adata.obsm["X_umap"] = rng.randn(len(adata), 2)
155
    sc.pp.pca(adata)
156
    return adata
157
158
159
@pytest.fixture
160
def gt_temporal_adata() -> AnnData:
161
    adata = _gt_temporal_adata.copy()
162
    # TODO(michalk8): remove both lines once data has been regenerated
163
    adata.obs["day"] = pd.to_numeric(adata.obs["day"]).astype("category")
164
    adata.obs_names_make_unique()
165
    return adata
166
167
168
@pytest.fixture
169
def adata_space_rotate() -> AnnData:
170
    rng = np.random.RandomState(31)
171
    grid = _make_grid(10)
172
    adatas = _make_adata(grid, n=len(ANGLES), seed=32)
173
    for adata, angle in zip(adatas, ANGLES):
174
        theta = np.deg2rad(angle)
175
        rot = np.array([[cos(theta), -sin(theta)], [sin(theta), cos(theta)]])
176
        adata.obsm["spatial"] = np.dot(adata.obsm["spatial"], rot)
177
178
    adata = ad.concat(adatas, label="batch", index_unique="-")
179
    adata.obs["celltype"] = rng.choice(["A", "B", "C"], size=len(adata))
180
    adata.uns["spatial"] = {}
181
    sc.pp.pca(adata)
182
    return adata
183
184
185
@pytest.fixture
186
def adata_mapping() -> AnnData:
187
    grid = _make_grid(10)
188
    adataref, adata1, adata2 = _make_adata(grid, n=3, seed=17, cat_key="covariate", num_categories=3)
189
    sc.pp.pca(adataref, n_comps=30)
190
    return ad.concat([adataref, adata1, adata2], label="batch", join="outer", index_unique="-")
191
192
193
@pytest.fixture
194
def adata_translation() -> AnnData:
195
    rng = np.random.RandomState(31)
196
    adatas = [AnnData(X=csr_matrix(rng.normal(size=(100, 60)))) for _ in range(3)]
197
    adata = ad.concat(adatas, label="batch", index_unique="-")
198
    adata.obs["celltype"] = rng.choice(["A", "B", "C"], size=len(adata))
199
    adata.obs["celltype"] = adata.obs["celltype"].astype("category")
200
    adata.layers["counts"] = adata.X.toarray()
201
    sc.pp.pca(adata)
202
    return adata
203
204
205
@pytest.fixture
206
def adata_translation_split(adata_translation) -> Tuple[AnnData, AnnData]:
207
    rng = np.random.RandomState(15)
208
    adata_src = adata_translation[adata_translation.obs.batch != "0"].copy()
209
    adata_tgt = adata_translation[adata_translation.obs.batch == "0"].copy()
210
    adata_src.obsm["emb_src"] = rng.normal(size=(adata_src.shape[0], 5))
211
    adata_tgt.obsm["emb_tgt"] = rng.normal(size=(adata_tgt.shape[0], 15))
212
    return adata_src, adata_tgt
213
214
215
@pytest.fixture
216
def adata_anno(
217
    problem_kind: Literal["temporal", "cross_modality", "alignment", "mapping"],
218
) -> Union[AnnData, Tuple[AnnData, AnnData]]:
219
    rng = np.random.RandomState(31)
220
    adata_src = AnnData(X=csr_matrix(rng.normal(size=(10, 60))))
221
    rng_src = rng.choice(["A", "B", "C"], size=5).tolist()
222
    adata_src.obs["celltype1"] = ["C", "C", "A", "B", "B"] + rng_src
223
    adata_src.obs["celltype1"] = adata_src.obs["celltype1"].astype("category")
224
    adata_src.uns["expected_max1"] = ["C", "C", "A", "B", "B"] + rng_src + rng_src
225
    adata_src.uns["expected_sum1"] = ["C", "C", "B", "B", "B"] + rng_src + rng_src
226
227
    adata_tgt = AnnData(X=csr_matrix(rng.normal(size=(15, 60))))
228
    rng_tgt = rng.choice(["A", "B", "C"], size=5).tolist()
229
    adata_tgt.obs["celltype2"] = ["C", "C", "A", "B", "B"] + rng_tgt + rng_tgt
230
    adata_tgt.obs["celltype2"] = adata_tgt.obs["celltype2"].astype("category")
231
    adata_tgt.uns["expected_max2"] = ["C", "C", "A", "B", "B"] + rng_tgt
232
    adata_tgt.uns["expected_sum2"] = ["C", "C", "B", "B", "B"] + rng_tgt
233
234
    if problem_kind == "cross_modality":
235
        adata_src.obs["batch"] = "0"
236
        adata_tgt.obs["batch"] = "1"
237
        adata_src.obsm["emb_src"] = rng.normal(size=(adata_src.shape[0], 5))
238
        adata_tgt.obsm["emb_tgt"] = rng.normal(size=(adata_tgt.shape[0], 15))
239
        sc.pp.pca(adata_src)
240
        sc.pp.pca(adata_tgt)
241
        return adata_src, adata_tgt
242
    if problem_kind == "mapping":
243
        adata_src.obs["batch"] = "0"
244
        adata_tgt.obs["batch"] = "1"
245
        sc.pp.pca(adata_src)
246
        sc.pp.pca(adata_tgt)
247
        adata_tgt.obsm["spatial"] = rng.normal(size=(adata_tgt.n_obs, 2))
248
        return adata_src, adata_tgt
249
    if problem_kind == "alignment":
250
        adata_src.obsm["spatial"] = rng.normal(size=(adata_src.n_obs, 2))
251
        adata_tgt.obsm["spatial"] = rng.normal(size=(adata_tgt.n_obs, 2))
252
    key = "day" if problem_kind == "temporal" else "batch"
253
    adatas = [adata_src, adata_tgt]
254
    adata = ad.concat(adatas, join="outer", label=key, index_unique="-", uns_merge="unique")
255
    adata.obs[key] = (pd.to_numeric(adata.obs[key]) if key == "day" else adata.obs[key]).astype("category")
256
    adata.layers["counts"] = adata.X.toarray()
257
    sc.pp.pca(adata)
258
    return adata
259
260
261
@pytest.fixture
262
def gt_tm_annotation() -> np.ndarray:
263
    tm = np.zeros((10, 15))
264
    for i in range(10):
265
        tm[i][i] = 1
266
    for i in range(10, 15):
267
        tm[i - 5][i] = 1
268
    for j in range(2, 5):
269
        for i in range(2, 5):
270
            tm[i][j] = 0.3 if i != j else 0.4
271
    return tm