|
a |
|
b/tests/conftest.py |
|
|
1 |
from math import cos, sin |
|
|
2 |
from typing import Literal, Optional, Tuple, Union |
|
|
3 |
|
|
|
4 |
import pytest |
|
|
5 |
|
|
|
6 |
import jax.numpy as jnp |
|
|
7 |
import numpy as np |
|
|
8 |
import pandas as pd |
|
|
9 |
from jax import config |
|
|
10 |
from scipy.sparse import csr_matrix |
|
|
11 |
|
|
|
12 |
import matplotlib.pyplot as plt |
|
|
13 |
|
|
|
14 |
import anndata as ad |
|
|
15 |
import scanpy as sc |
|
|
16 |
from anndata import AnnData |
|
|
17 |
|
|
|
18 |
from tests._utils import Geom_t, _make_adata, _make_grid |
|
|
19 |
|
|
|
20 |
ANGLES = (0, 30, 60) |
|
|
21 |
|
|
|
22 |
|
|
|
23 |
# TODO(michalk8): consider passing this via env |
|
|
24 |
config.update("jax_enable_x64", True) |
|
|
25 |
|
|
|
26 |
|
|
|
27 |
_gt_temporal_adata = sc.read("tests/data/moscot_temporal_tests.h5ad") |
|
|
28 |
|
|
|
29 |
|
|
|
30 |
def pytest_sessionstart() -> None: |
|
|
31 |
sc.pl.set_rcParams_defaults() |
|
|
32 |
sc.set_figure_params(dpi=40, color_map="viridis") |
|
|
33 |
|
|
|
34 |
|
|
|
35 |
@pytest.fixture(autouse=True) |
|
|
36 |
def _close_figure(): |
|
|
37 |
# prevent `RuntimeWarning: More than 20 figures have been opened.` |
|
|
38 |
yield |
|
|
39 |
plt.close() |
|
|
40 |
|
|
|
41 |
|
|
|
42 |
@pytest.fixture |
|
|
43 |
def x() -> Geom_t: |
|
|
44 |
rng = np.random.RandomState(0) |
|
|
45 |
n = 20 # number of points in the first distribution |
|
|
46 |
sig = 1 # std of first distribution |
|
|
47 |
|
|
|
48 |
phi = np.arange(n)[:, None] |
|
|
49 |
xs = phi + sig * rng.randn(n, 1) |
|
|
50 |
|
|
|
51 |
return jnp.asarray(xs) |
|
|
52 |
|
|
|
53 |
|
|
|
54 |
@pytest.fixture |
|
|
55 |
def y() -> Geom_t: |
|
|
56 |
rng = np.random.RandomState(1) |
|
|
57 |
n2 = 30 # number of points in the second distribution |
|
|
58 |
sig = 1 # std of first distribution |
|
|
59 |
|
|
|
60 |
phi2 = np.arange(n2)[:, None] |
|
|
61 |
xt = phi2 + sig * rng.randn(n2, 1) |
|
|
62 |
|
|
|
63 |
return jnp.asarray(xt) |
|
|
64 |
|
|
|
65 |
|
|
|
66 |
@pytest.fixture |
|
|
67 |
def xy() -> Tuple[Geom_t, Geom_t]: |
|
|
68 |
rng = np.random.RandomState(2) |
|
|
69 |
n = 20 # number of points in the first distribution |
|
|
70 |
n2 = 30 # number of points in the second distribution |
|
|
71 |
sig = 1 # std of first distribution |
|
|
72 |
sig2 = 0.1 # std of second distribution |
|
|
73 |
|
|
|
74 |
phi = np.arange(n)[:, None] |
|
|
75 |
phi + sig * rng.randn(n, 1) |
|
|
76 |
ys = np.vstack((np.ones((n // 2, 1)), 0 * np.ones((n // 2, 1)))) + sig2 * rng.randn(n, 1) |
|
|
77 |
|
|
|
78 |
phi2 = np.arange(n2)[:, None] |
|
|
79 |
phi2 + sig * rng.randn(n2, 1) |
|
|
80 |
yt = np.vstack((np.ones((n2 // 2, 1)), 0 * np.ones((n2 // 2, 1)))) + sig2 * rng.randn(n2, 1) |
|
|
81 |
yt = yt[::-1, :] |
|
|
82 |
|
|
|
83 |
return jnp.asarray(ys), jnp.asarray(yt) |
|
|
84 |
|
|
|
85 |
|
|
|
86 |
@pytest.fixture |
|
|
87 |
def ab() -> Tuple[np.ndarray, np.ndarray]: |
|
|
88 |
rng = np.random.RandomState(42) |
|
|
89 |
return rng.normal(size=(20, 2)), rng.normal(size=(30, 4)) |
|
|
90 |
|
|
|
91 |
|
|
|
92 |
@pytest.fixture |
|
|
93 |
def x_cost(x: Geom_t) -> jnp.ndarray: |
|
|
94 |
return ((x[:, None, :] - x[None, ...]) ** 2).sum(-1) |
|
|
95 |
|
|
|
96 |
|
|
|
97 |
@pytest.fixture |
|
|
98 |
def y_cost(y: Geom_t) -> jnp.ndarray: |
|
|
99 |
return ((y[:, None, :] - y[None, ...]) ** 2).sum(-1) |
|
|
100 |
|
|
|
101 |
|
|
|
102 |
@pytest.fixture |
|
|
103 |
def xy_cost(xy: Geom_t) -> jnp.ndarray: |
|
|
104 |
x, y = xy |
|
|
105 |
return ((x[:, None, :] - y[None, ...]) ** 2).sum(-1) |
|
|
106 |
|
|
|
107 |
|
|
|
108 |
@pytest.fixture |
|
|
109 |
def adata_x(x: Geom_t) -> AnnData: |
|
|
110 |
rng = np.random.RandomState(43) |
|
|
111 |
pc = rng.normal(size=(len(x), 4)) |
|
|
112 |
return AnnData(X=np.asarray(x, dtype=float), obsm={"X_pca": pc}) |
|
|
113 |
|
|
|
114 |
|
|
|
115 |
@pytest.fixture |
|
|
116 |
def adata_y(y: Geom_t) -> AnnData: |
|
|
117 |
rng = np.random.RandomState(44) |
|
|
118 |
pc = rng.normal(size=(len(y), 4)) |
|
|
119 |
return AnnData(X=np.asarray(y, dtype=float), obsm={"X_pca": pc}) |
|
|
120 |
|
|
|
121 |
|
|
|
122 |
def creat_prob(n: int, *, uniform: bool = False, seed: Optional[int] = None) -> Geom_t: |
|
|
123 |
rng = np.random.RandomState(seed) |
|
|
124 |
a = np.ones((n,)) if uniform else np.abs(rng.normal(size=(n,))) |
|
|
125 |
a /= np.sum(a) |
|
|
126 |
return jnp.asarray(a) |
|
|
127 |
|
|
|
128 |
|
|
|
129 |
@pytest.fixture |
|
|
130 |
def adata_time() -> AnnData: |
|
|
131 |
rng = np.random.RandomState(42) |
|
|
132 |
|
|
|
133 |
adatas = [ |
|
|
134 |
AnnData( |
|
|
135 |
X=csr_matrix(rng.normal(size=(96, 60))), |
|
|
136 |
obs={ |
|
|
137 |
"left_marginals_balanced": creat_prob(96, seed=42), |
|
|
138 |
"right_marginals_balanced": creat_prob(96, seed=42), |
|
|
139 |
}, |
|
|
140 |
) |
|
|
141 |
for _ in range(3) |
|
|
142 |
] |
|
|
143 |
adata = ad.concat(adatas, label="time", index_unique="-") |
|
|
144 |
adata.obs["time"] = pd.to_numeric(adata.obs["time"]).astype("category") |
|
|
145 |
adata.obs["batch"] = rng.choice((0, 1, 2), len(adata)) |
|
|
146 |
adata.obs["left_marginals_unbalanced"] = np.ones(len(adata)) |
|
|
147 |
adata.obs["right_marginals_unbalanced"] = np.ones(len(adata)) |
|
|
148 |
adata.obs["celltype"] = rng.choice(["A", "B", "C"], size=len(adata)) |
|
|
149 |
# genes from mouse/human proliferation/apoptosis |
|
|
150 |
genes = ["ANLN", "ANP32E", "ATAD2", "Mcm4", "Smc4", "Gtse1", "ADD1", "AIFM3", "ANKH", "Ercc5", "Serpinb5", "Inhbb"] |
|
|
151 |
# genes which are transcription factors, 3 from drosophila, 2 from human, 1 from mouse |
|
|
152 |
genes += ["Cf2", "Dlip3", "Dref", "KLF12", "ZNF143", "Zic5"] |
|
|
153 |
adata.var.index = ["gene_" + el if i > len(genes) - 1 else genes[i] for i, el in enumerate(adata.var.index)] |
|
|
154 |
adata.obsm["X_umap"] = rng.randn(len(adata), 2) |
|
|
155 |
sc.pp.pca(adata) |
|
|
156 |
return adata |
|
|
157 |
|
|
|
158 |
|
|
|
159 |
@pytest.fixture |
|
|
160 |
def gt_temporal_adata() -> AnnData: |
|
|
161 |
adata = _gt_temporal_adata.copy() |
|
|
162 |
# TODO(michalk8): remove both lines once data has been regenerated |
|
|
163 |
adata.obs["day"] = pd.to_numeric(adata.obs["day"]).astype("category") |
|
|
164 |
adata.obs_names_make_unique() |
|
|
165 |
return adata |
|
|
166 |
|
|
|
167 |
|
|
|
168 |
@pytest.fixture |
|
|
169 |
def adata_space_rotate() -> AnnData: |
|
|
170 |
rng = np.random.RandomState(31) |
|
|
171 |
grid = _make_grid(10) |
|
|
172 |
adatas = _make_adata(grid, n=len(ANGLES), seed=32) |
|
|
173 |
for adata, angle in zip(adatas, ANGLES): |
|
|
174 |
theta = np.deg2rad(angle) |
|
|
175 |
rot = np.array([[cos(theta), -sin(theta)], [sin(theta), cos(theta)]]) |
|
|
176 |
adata.obsm["spatial"] = np.dot(adata.obsm["spatial"], rot) |
|
|
177 |
|
|
|
178 |
adata = ad.concat(adatas, label="batch", index_unique="-") |
|
|
179 |
adata.obs["celltype"] = rng.choice(["A", "B", "C"], size=len(adata)) |
|
|
180 |
adata.uns["spatial"] = {} |
|
|
181 |
sc.pp.pca(adata) |
|
|
182 |
return adata |
|
|
183 |
|
|
|
184 |
|
|
|
185 |
@pytest.fixture |
|
|
186 |
def adata_mapping() -> AnnData: |
|
|
187 |
grid = _make_grid(10) |
|
|
188 |
adataref, adata1, adata2 = _make_adata(grid, n=3, seed=17, cat_key="covariate", num_categories=3) |
|
|
189 |
sc.pp.pca(adataref, n_comps=30) |
|
|
190 |
return ad.concat([adataref, adata1, adata2], label="batch", join="outer", index_unique="-") |
|
|
191 |
|
|
|
192 |
|
|
|
193 |
@pytest.fixture |
|
|
194 |
def adata_translation() -> AnnData: |
|
|
195 |
rng = np.random.RandomState(31) |
|
|
196 |
adatas = [AnnData(X=csr_matrix(rng.normal(size=(100, 60)))) for _ in range(3)] |
|
|
197 |
adata = ad.concat(adatas, label="batch", index_unique="-") |
|
|
198 |
adata.obs["celltype"] = rng.choice(["A", "B", "C"], size=len(adata)) |
|
|
199 |
adata.obs["celltype"] = adata.obs["celltype"].astype("category") |
|
|
200 |
adata.layers["counts"] = adata.X.toarray() |
|
|
201 |
sc.pp.pca(adata) |
|
|
202 |
return adata |
|
|
203 |
|
|
|
204 |
|
|
|
205 |
@pytest.fixture |
|
|
206 |
def adata_translation_split(adata_translation) -> Tuple[AnnData, AnnData]: |
|
|
207 |
rng = np.random.RandomState(15) |
|
|
208 |
adata_src = adata_translation[adata_translation.obs.batch != "0"].copy() |
|
|
209 |
adata_tgt = adata_translation[adata_translation.obs.batch == "0"].copy() |
|
|
210 |
adata_src.obsm["emb_src"] = rng.normal(size=(adata_src.shape[0], 5)) |
|
|
211 |
adata_tgt.obsm["emb_tgt"] = rng.normal(size=(adata_tgt.shape[0], 15)) |
|
|
212 |
return adata_src, adata_tgt |
|
|
213 |
|
|
|
214 |
|
|
|
215 |
@pytest.fixture |
|
|
216 |
def adata_anno( |
|
|
217 |
problem_kind: Literal["temporal", "cross_modality", "alignment", "mapping"], |
|
|
218 |
) -> Union[AnnData, Tuple[AnnData, AnnData]]: |
|
|
219 |
rng = np.random.RandomState(31) |
|
|
220 |
adata_src = AnnData(X=csr_matrix(rng.normal(size=(10, 60)))) |
|
|
221 |
rng_src = rng.choice(["A", "B", "C"], size=5).tolist() |
|
|
222 |
adata_src.obs["celltype1"] = ["C", "C", "A", "B", "B"] + rng_src |
|
|
223 |
adata_src.obs["celltype1"] = adata_src.obs["celltype1"].astype("category") |
|
|
224 |
adata_src.uns["expected_max1"] = ["C", "C", "A", "B", "B"] + rng_src + rng_src |
|
|
225 |
adata_src.uns["expected_sum1"] = ["C", "C", "B", "B", "B"] + rng_src + rng_src |
|
|
226 |
|
|
|
227 |
adata_tgt = AnnData(X=csr_matrix(rng.normal(size=(15, 60)))) |
|
|
228 |
rng_tgt = rng.choice(["A", "B", "C"], size=5).tolist() |
|
|
229 |
adata_tgt.obs["celltype2"] = ["C", "C", "A", "B", "B"] + rng_tgt + rng_tgt |
|
|
230 |
adata_tgt.obs["celltype2"] = adata_tgt.obs["celltype2"].astype("category") |
|
|
231 |
adata_tgt.uns["expected_max2"] = ["C", "C", "A", "B", "B"] + rng_tgt |
|
|
232 |
adata_tgt.uns["expected_sum2"] = ["C", "C", "B", "B", "B"] + rng_tgt |
|
|
233 |
|
|
|
234 |
if problem_kind == "cross_modality": |
|
|
235 |
adata_src.obs["batch"] = "0" |
|
|
236 |
adata_tgt.obs["batch"] = "1" |
|
|
237 |
adata_src.obsm["emb_src"] = rng.normal(size=(adata_src.shape[0], 5)) |
|
|
238 |
adata_tgt.obsm["emb_tgt"] = rng.normal(size=(adata_tgt.shape[0], 15)) |
|
|
239 |
sc.pp.pca(adata_src) |
|
|
240 |
sc.pp.pca(adata_tgt) |
|
|
241 |
return adata_src, adata_tgt |
|
|
242 |
if problem_kind == "mapping": |
|
|
243 |
adata_src.obs["batch"] = "0" |
|
|
244 |
adata_tgt.obs["batch"] = "1" |
|
|
245 |
sc.pp.pca(adata_src) |
|
|
246 |
sc.pp.pca(adata_tgt) |
|
|
247 |
adata_tgt.obsm["spatial"] = rng.normal(size=(adata_tgt.n_obs, 2)) |
|
|
248 |
return adata_src, adata_tgt |
|
|
249 |
if problem_kind == "alignment": |
|
|
250 |
adata_src.obsm["spatial"] = rng.normal(size=(adata_src.n_obs, 2)) |
|
|
251 |
adata_tgt.obsm["spatial"] = rng.normal(size=(adata_tgt.n_obs, 2)) |
|
|
252 |
key = "day" if problem_kind == "temporal" else "batch" |
|
|
253 |
adatas = [adata_src, adata_tgt] |
|
|
254 |
adata = ad.concat(adatas, join="outer", label=key, index_unique="-", uns_merge="unique") |
|
|
255 |
adata.obs[key] = (pd.to_numeric(adata.obs[key]) if key == "day" else adata.obs[key]).astype("category") |
|
|
256 |
adata.layers["counts"] = adata.X.toarray() |
|
|
257 |
sc.pp.pca(adata) |
|
|
258 |
return adata |
|
|
259 |
|
|
|
260 |
|
|
|
261 |
@pytest.fixture |
|
|
262 |
def gt_tm_annotation() -> np.ndarray: |
|
|
263 |
tm = np.zeros((10, 15)) |
|
|
264 |
for i in range(10): |
|
|
265 |
tm[i][i] = 1 |
|
|
266 |
for i in range(10, 15): |
|
|
267 |
tm[i - 5][i] = 1 |
|
|
268 |
for j in range(2, 5): |
|
|
269 |
for i in range(2, 5): |
|
|
270 |
tm[i][j] = 0.3 if i != j else 0.4 |
|
|
271 |
return tm |