# tests/problems/conftest.py
1
import pytest
2
3
import numpy as np
4
import pandas as pd
5
from sklearn.metrics import pairwise_distances
6
7
import anndata as ad
8
from anndata import AnnData
9
10
from tests._utils import Geom_t
11
12
13
@pytest.fixture
def adata_with_cost_matrix(adata_x: Geom_t, adata_y: Geom_t) -> AnnData:
    """Concatenate ``adata_x`` and ``adata_y`` and attach a normalized cost matrix.

    The two inputs are concatenated with a numeric ``"batch"`` column in
    ``.obs``. The squared pairwise distances between their ``"X_pca"``
    embeddings (``sklearn.metrics.pairwise_distances`` with its default
    metric), divided by their mean, are stored under the key ``0`` in
    ``adata.uns``.
    """
    adata = ad.concat([adata_x, adata_y], label="batch", index_unique="-")
    C = pairwise_distances(adata_x.obsm["X_pca"], adata_y.obsm["X_pca"]) ** 2
    # `ad.concat(..., label="batch")` writes the batch labels into `.obs`;
    # they are converted to numbers here.
    adata.obs["batch"] = pd.to_numeric(adata.obs["batch"])
    adata.uns[0] = C / C.mean()  # TODO(@MUCDK) make a callback function and replace this part
    return adata
20
21
22
@pytest.fixture
def adata_time_with_tmap(adata_time: AnnData) -> AnnData:
    """Return a two-time-point copy of ``adata_time`` with a synthetic transport map.

    Observations whose ``time`` is 0 or 1 are kept and given random
    ``cell_type`` labels. Two entries are added to ``adata.uns``:

    - ``"cell_transition_gt"``: a random cell-type transition matrix whose
      rows sum to 1 (a :class:`pandas.DataFrame` indexed by cell type).
    - ``"transport_matrix"``: a cell-level matrix assembled from one block
      per (source cell type, target cell type) pair, each block scaled so
      that it sums to the matching entry of ``"cell_transition_gt"``.

    All randomness comes from ``np.random.RandomState(42)``.
    """
    adata = adata_time[adata_time.obs["time"].isin([0, 1])].copy()
    rng = np.random.RandomState(42)
    cell_types = ["cell_A", "cell_B", "cell_C", "cell_D"]
    uniform = [1 / len(cell_types)] * len(cell_types)

    # Draw the label counts of each time point from that time point's own
    # number of cells. Using the `time == 0` count for both only yields a
    # label array of the right length when the two time points are equally
    # large (in which case the draws are unchanged).
    n_source = len(adata[adata.obs["time"] == 0])
    n_target = len(adata[adata.obs["time"] == 1])
    cell_d1 = rng.multinomial(n_source, uniform)
    cell_d2 = rng.multinomial(n_target, uniform)
    a1 = np.repeat(cell_types, cell_d1)
    a2 = np.repeat(cell_types, cell_d2)

    # NOTE(review): labels are assigned positionally, which presumably relies
    # on the `time == 0` observations preceding the `time == 1` ones - confirm.
    adata.obs["cell_type"] = np.concatenate([a1, a2])
    adata.obs["cell_type"] = adata.obs["cell_type"].astype("category")
    cell_numbers_source = dict(adata[adata.obs["time"] == 0].obs["cell_type"].value_counts())
    cell_numbers_target = dict(adata[adata.obs["time"] == 1].obs["cell_type"].value_counts())
    trans_matrix = np.abs(rng.randn(len(cell_types), len(cell_types)))
    trans_matrix = trans_matrix / trans_matrix.sum(axis=1, keepdims=True)

    cell_transition_gt = pd.DataFrame(data=trans_matrix, index=cell_types, columns=cell_types)

    blocks = []
    for cell_row in cell_types:
        block_row = []
        for cell_col in cell_types:
            sub_trans_matrix = np.abs(rng.randn(cell_numbers_source[cell_row], cell_numbers_target[cell_col]))
            # Rescale so the block's total mass equals the ground-truth
            # transition value for this (row, col) pair of cell types.
            sub_trans_matrix /= sub_trans_matrix.sum() * (1 / cell_transition_gt.loc[cell_row, cell_col])
            block_row.append(sub_trans_matrix)
        blocks.append(block_row)
    transport_matrix = np.block(blocks)
    adata.uns["transport_matrix"] = transport_matrix
    adata.uns["cell_transition_gt"] = cell_transition_gt

    return adata
59
60
61
# keys for marginals
@pytest.fixture(
    params=[
        (None, None),
        ("left_marginals_balanced", "right_marginals_balanced"),
    ],
    ids=["default", "balanced"],
)
def marginal_keys(request):
    """Return the current pair of marginal keys.

    Parametrized over ``(None, None)`` (id ``"default"``) and
    ``("left_marginals_balanced", "right_marginals_balanced")``
    (id ``"balanced"``).
    """
    return request.param
71
72
73
# Sinkhorn keyword arguments, variant 1: positive `rank` together with the
# `gamma`/`gamma_rescale` entries.
sinkhorn_args_1 = {
    "epsilon": 0.7,
    "tau_a": 1.0,
    "tau_b": 1.0,
    "rank": 7,
    "initializer": "rank2",
    "initializer_kwargs": {},
    "jit": False,
    "threshold": 2e-3,
    "lse_mode": True,
    "norm_error": 2,
    "inner_iterations": 3,
    "min_iterations": 4,
    "max_iterations": 9,
    "gamma": 9.4,
    "gamma_rescale": False,
    "batch_size": None,  # in to_LRC() `batch_size` cannot be passed so we expect None.
    "scale_cost": "max_cost",
}
92
93
94
# Sinkhorn keyword arguments, variant 2: `rank` of -1, explicit `batch_size`.
sinkhorn_args_2 = {  # no gamma/gamma_rescale as these are LR-specific
    "epsilon": 0.8,
    "tau_a": 0.9,
    "tau_b": 0.8,
    "rank": -1,
    "batch_size": 125,
    "initializer": "gaussian",
    "initializer_kwargs": {},
    "jit": True,
    "threshold": 3e-3,
    "lse_mode": False,
    "norm_error": 3,
    "inner_iterations": 4,
    "min_iterations": 1,
    "max_iterations": 2,
    "scale_cost": "mean",
}
111
112
# Keyword arguments nested under `gw_args_1["linear_solver_kwargs"]`.
linear_solver_kwargs1 = {
    "inner_iterations": 1,
    "min_iterations": 5,
    "max_iterations": 7,
    "lse_mode": False,
    "threshold": 5e-2,
    "norm_error": 4,
}

# NOTE(review): the earlier comment on this dict also listed tolerances/ranks
# as absent, but both keys are set below - confirm which is intended.
gw_args_1 = {  # no gamma/gamma_rescale as these are LR-specific
    "epsilon": 0.5,
    "tau_a": 0.7,
    "tau_b": 0.8,
    "scale_cost": "max_cost",
    "rank": -1,
    "batch_size": 122,
    "initializer": None,
    "initializer_kwargs": {},
    "jit": True,
    "threshold": 3e-2,
    "min_iterations": 3,
    "max_iterations": 4,
    "gw_unbalanced_correction": True,
    "ranks": 4,
    "tolerances": 2e-2,
    "warm_start": False,
    "linear_solver_kwargs": linear_solver_kwargs1,
}
140
141
# Linear-solver keyword arguments that are merged flat into `gw_args_2` below.
linear_solver_kwargs2 = {
    "inner_iterations": 3,
    "min_iterations": 7,
    "max_iterations": 8,
    "lse_mode": True,
    "threshold": 4e-2,
    "norm_error": 3,
    "gamma": 9.4,
    "gamma_rescale": False,
}

gw_args_2 = {
    "alpha": 0.4,
    "epsilon": 0.7,
    "tau_a": 1.0,
    "tau_b": 1.0,
    "scale_cost": "max_cost",
    "rank": 7,
    "batch_size": 123,
    "initializer": "rank2",
    "initializer_kwargs": {},
    "jit": False,
    "threshold": 2e-3,
    "min_iterations": 2,
    "max_iterations": 3,
    "gw_unbalanced_correction": False,
    "ranks": 3,
    "tolerances": 3e-2,
    # `linear_solver_kwargs2` is not nested under "linear_solver_kwargs" here;
    # it is merged into the top level right after this literal.
}

# Flat merge: keys present in both dicts ("threshold", "min_iterations",
# "max_iterations") take the value from `linear_solver_kwargs2`.
gw_args_2 = {**gw_args_2, **linear_solver_kwargs2}
173
174
# Variants of the GW argument sets: shallow copies with "alpha" set.
fgw_args_1 = {**gw_args_1, "alpha": 0.6}

fgw_args_2 = {**gw_args_2, "alpha": 0.4}
179
180
# Name mapping for the GW solver; presumably key = moscot arg name and
# value = ott-jax attribute, as stated for `sinkhorn_solver_args` - confirm.
gw_solver_args = {
    "epsilon": "epsilon",
    "rank": "rank",
    "threshold": "threshold",
    "min_iterations": "min_iterations",
    "max_iterations": "max_iterations",
    "warm_start": "warm_start",
    "initializer": "initializer",
}

# Same mapping as `gw_solver_args` but without "warm_start".
gw_lr_solver_args = {
    "epsilon": "epsilon",
    "rank": "rank",
    "threshold": "threshold",
    "min_iterations": "min_iterations",
    "max_iterations": "max_iterations",
    "initializer": "initializer",
}
198
199
# Name mapping for the linear solver used inside GW; presumably
# key = moscot arg name, value = ott-jax attribute - confirm against callers.
gw_linear_solver_args = {
    "lse_mode": "lse_mode",
    "inner_iterations": "inner_iterations",
    "threshold": "threshold",
    "norm_error": "norm_error",
    "max_iterations": "max_iterations",
    "min_iterations": "min_iterations",
}

# Same mapping as `gw_linear_solver_args` plus "gamma" and "gamma_rescale".
gw_lr_linear_solver_args = {
    "lse_mode": "lse_mode",
    "inner_iterations": "inner_iterations",
    "threshold": "threshold",
    "norm_error": "norm_error",
    "max_iterations": "max_iterations",
    "min_iterations": "min_iterations",
    "gamma": "gamma",
    "gamma_rescale": "gamma_rescale",
}
218
219
# Name mapping for quadratic-problem arguments; every key maps to itself.
quad_prob_args = {
    "tau_a": "tau_a",
    "tau_b": "tau_b",
    "gw_unbalanced_correction": "gw_unbalanced_correction",
    "ranks": "ranks",
    "tolerances": "tolerances",
}
226
227
# Name mappings for geometry-related arguments. Values starting with "_" look
# like private attribute names on the geometry object - confirm against callers.
geometry_args = {"epsilon": "_epsilon_init", "scale_cost": "_scale_cost"}

pointcloud_args = {
    "batch_size": "_batch_size",
    "scale_cost": "_scale_cost",
}

# Differs from `pointcloud_args` only in "batch_size" (no leading underscore).
lr_pointcloud_args = {
    "batch_size": "batch_size",
    "scale_cost": "_scale_cost",
}
238
239
sinkhorn_solver_args = {  # dictionary with key = moscot arg name, value = ott-jax attribute
    "lse_mode": "lse_mode",
    "threshold": "threshold",
    "norm_error": "norm_error",
    "inner_iterations": "inner_iterations",
    "min_iterations": "min_iterations",
    "max_iterations": "max_iterations",
    "initializer": "initializer",
    "initializer_kwargs": "initializer_kwargs",
}

# Low-rank variant: a copy of the mapping above extended with the two
# gamma-related entries; `sinkhorn_solver_args` itself is left unchanged.
lr_sinkhorn_solver_args = sinkhorn_solver_args.copy()
lr_sinkhorn_solver_args["gamma"] = "gamma"
lr_sinkhorn_solver_args["gamma_rescale"] = "gamma_rescale"
253
254
# Name mapping for linear-problem arguments; both keys map to themselves.
lin_prob_args = {
    "tau_a": "tau_a",
    "tau_b": "tau_b",
}
258
259
neurallin_cond_args_1 = {
260
    "batch_size": 8,
261
    "seed": 0,
262
    "iterations": 2,
263
    "valid_freq": 4,
264
}