[6ff4a8]: tests/problems/cross_modality/test_translation_problem.py

from contextlib import nullcontext
from typing import Any, Callable, Literal, Mapping, Optional, Tuple

import pytest

import numpy as np

from anndata import AnnData

from moscot.backends.ott._utils import alpha_to_fused_penalty
from moscot.base.output import BaseDiscreteSolverOutput
from moscot.problems.cross_modality import TranslationProblem

from tests.problems.conftest import (
    fgw_args_1,
    fgw_args_2,
    geometry_args,
    gw_linear_solver_args,
    gw_lr_linear_solver_args,
    gw_lr_solver_args,
    gw_solver_args,
    pointcloud_args,
    quad_prob_args,
)
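

# Tests for the cross-modality TranslationProblem: preparation under the dummy
# and external-star policies, balanced fused Gromov-Wasserstein solving, and
# forwarding of solver arguments to the OTT backend.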
class TestTranslationProblem:
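    # Without a `batch_key`, `prepare` uses the dummy policy: a single
    # ("src", "tgt") problem coupling all source observations (2 * n_obs here,
    # twice the target size) to all target observations.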
    @pytest.mark.fast
    @pytest.mark.parametrize("src_attr", ["emb_src", {"attr": "obsm", "key": "emb_src"}])
    @pytest.mark.parametrize("tgt_attr", ["emb_tgt", {"attr": "obsm", "key": "emb_tgt"}])
    @pytest.mark.parametrize("joint_attr", [None, "X_pca", {"attr": "obsm", "key": "X_pca"}])
    def test_prepare_dummy_policy(
        self,
        adata_translation_split: Tuple[AnnData, AnnData],
        src_attr: Mapping[str, str],
        tgt_attr: Mapping[str, str],
        joint_attr: Optional[Mapping[str, str]],
    ):
        adata_src, adata_tgt = adata_translation_split
        n_obs = adata_tgt.shape[0]

        tp = TranslationProblem(adata_src, adata_tgt)
        assert tp.problems == {}
        assert tp.solutions == {}

        prob_key = ("src", "tgt")
        tp = tp.prepare(src_attr=src_attr, tgt_attr=tgt_attr, joint_attr=joint_attr)

        assert len(tp) == 1
        assert isinstance(tp[prob_key], tp._base_problem_type)
        assert tp[prob_key].shape == (2 * n_obs, n_obs)
        np.testing.assert_array_equal(tp._policy._cat, prob_key)
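
    # With `batch_key="batch"` an external-star policy is used: one subproblem per
    # source batch, each mapped against the single target ("ref"). Every subproblem
    # exposes the source/target embeddings via `.x`/`.y` and, when a joint attribute
    # is given, the shared space via `.xy`.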
    @pytest.mark.fast
    @pytest.mark.parametrize("src_attr", ["emb_src", {"attr": "obsm", "key": "emb_src"}])
    @pytest.mark.parametrize("tgt_attr", ["emb_tgt", {"attr": "obsm", "key": "emb_tgt"}])
    @pytest.mark.parametrize("joint_attr", [None, "X_pca", {"attr": "obsm", "key": "X_pca"}])
    def test_prepare_external_star_policy(
        self,
        adata_translation_split: Tuple[AnnData, AnnData],
        src_attr: Mapping[str, str],
        tgt_attr: Mapping[str, str],
        joint_attr: Optional[Mapping[str, str]],
    ):
        adata_src, adata_tgt = adata_translation_split
        expected_keys = {(i, "ref") for i in adata_src.obs["batch"]}
        n_obs = adata_tgt.shape[0]
        x_n_var = adata_src.obsm["emb_src"].shape[1]
        y_n_var = adata_tgt.obsm["emb_tgt"].shape[1]
        xy_n_vars = adata_src.X.shape[1] if joint_attr == "default" else adata_src.obsm["X_pca"].shape[1]

        tp = TranslationProblem(adata_src, adata_tgt)
        assert tp.problems == {}
        assert tp.solutions == {}

        tp = tp.prepare(batch_key="batch", src_attr=src_attr, tgt_attr=tgt_attr, joint_attr=joint_attr)

        assert len(tp) == len(expected_keys)
        for prob_key in expected_keys:
            assert isinstance(tp[prob_key], tp._base_problem_type)
            assert tp[prob_key].shape == (n_obs, n_obs)
            assert tp[prob_key].x.data_src.shape == (n_obs, x_n_var)
            assert tp[prob_key].y.data_src.shape == (n_obs, y_n_var)
            if joint_attr is not None:
                assert tp[prob_key].xy.data_src.shape == tp[prob_key].xy.data_tgt.shape == (n_obs, xy_n_vars)
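
    # Balanced (fused) GW solving over all batch pairs; the combinations flagged with
    # `expect_fail=True` (alpha < 1 without a joint term, or alpha == 1 with one)
    # are expected to raise a ValueError matching "alpha".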
    @pytest.mark.parametrize(
        ("epsilon", "alpha", "rank", "initializer", "joint_attr", "expect_fail"),
        [
            (1e-2, 0.9, -1, None, {"attr": "obsm", "key": "X_pca"}, False),
            (2, 0.5, -1, "random", None, True),
            (2, 0.5, -1, "random", {"attr": "X"}, False),
            (2, 1.0, -1, "rank2", None, False),
            (2, 1.0, -1, "rank2", {"attr": "obsm", "key": "X_pca"}, True),
            (2, 0.1, -1, None, {"attr": "obsm", "key": "X_pca"}, False),
            (2, 1.0, -1, None, None, False),
            (1.3, 1.0, -1, "random", None, False),
        ],
    )
    @pytest.mark.parametrize("src_attr", ["emb_src", {"attr": "obsm", "key": "emb_src"}])
    @pytest.mark.parametrize("tgt_attr", ["emb_tgt", {"attr": "obsm", "key": "emb_tgt"}])
    def test_solve_balanced(
        self,
        adata_translation_split: Tuple[AnnData, AnnData],
        epsilon: float,
        alpha: float,
        rank: int,
        src_attr: Mapping[str, str],
        tgt_attr: Mapping[str, str],
        initializer: Optional[Literal["random", "rank2"]],
        joint_attr: Optional[Mapping[str, str]],
        expect_fail: bool,
    ):
        adata_src, adata_tgt = adata_translation_split
        kwargs = {}
        expected_keys = {(i, "ref") for i in adata_src.obs["batch"]}
        if rank > -1:
            kwargs["initializer"] = initializer

        tp = TranslationProblem(adata_src, adata_tgt)
        tp = tp.prepare(batch_key="batch", src_attr=src_attr, tgt_attr=tgt_attr, joint_attr=joint_attr)

        context = pytest.raises(ValueError, match="alpha") if expect_fail else nullcontext()
        with context:
            tp = tp.solve(epsilon=epsilon, alpha=alpha, rank=rank, **kwargs)

            for key, subsol in tp.solutions.items():
                assert isinstance(subsol, BaseDiscreteSolverOutput)
                assert key in expected_keys
                assert tp[key].solution.rank == rank

            for key, sol in tp.solutions.items():
                np.testing.assert_array_equal(np.isfinite(sol.transport_matrix), True, err_msg=f"{key}")
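
    # Solver arguments from the FGW fixtures (`fgw_args_1`/`fgw_args_2` in conftest)
    # should be forwarded to the OTT GW solver, its inner linear (Sinkhorn) solver,
    # the quadratic problem, and the geometries.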
    @pytest.mark.parametrize("args_to_check", [fgw_args_1, fgw_args_2])
    def test_pass_arguments(self, adata_translation_split: Tuple[AnnData, AnnData], args_to_check: Mapping[str, Any]):
        adata_src, adata_tgt = adata_translation_split
        tp = TranslationProblem(adata_src, adata_tgt)

        key = ("1", "ref")
        tp = tp.prepare(
            batch_key="batch",
            src_attr={"attr": "obsm", "key": "emb_src"},
            tgt_attr={"attr": "obsm", "key": "emb_tgt"},
            joint_attr="X_pca",
        )
        tp = tp.solve(**args_to_check)

        solver = tp[key].solver.solver
        args = gw_solver_args if args_to_check["rank"] == -1 else gw_lr_solver_args
        for arg, val in args.items():
            if arg == "initializer":
                assert isinstance(getattr(solver, val), Callable)

        sinkhorn_solver = solver.linear_solver if args_to_check["rank"] == -1 else solver
        lin_solver_args = gw_linear_solver_args if args_to_check["rank"] == -1 else gw_lr_linear_solver_args
        tmp_dict = args_to_check["linear_solver_kwargs"] if args_to_check["rank"] == -1 else args_to_check
        for arg, val in lin_solver_args.items():
            el = (
                getattr(sinkhorn_solver, val)[0]
                if isinstance(getattr(sinkhorn_solver, val), tuple)
                else getattr(sinkhorn_solver, val)
            )
            assert el == tmp_dict[arg], arg

        quad_prob = tp[key]._solver._problem
        for arg, val in quad_prob_args.items():
            assert getattr(quad_prob, val) == args_to_check[arg], arg
        assert quad_prob.fused_penalty == alpha_to_fused_penalty(args_to_check["alpha"])

        geom = quad_prob.geom_xx
        for arg, val in geometry_args.items():
            assert hasattr(geom, val)
            el = getattr(geom, val)[0] if isinstance(getattr(geom, val), tuple) else getattr(geom, val)
            if arg == "epsilon":
                eps_processed = getattr(geom, val)
                assert eps_processed == args_to_check[arg], arg
            else:
                assert getattr(geom, val) == args_to_check[arg], arg
                assert el == args_to_check[arg]

        geom = quad_prob.geom_xy
        for arg, val in pointcloud_args.items():
            assert getattr(geom, val) == args_to_check[arg], arg