|
a |
|
b/tests/problems/conftest.py |
|
|
1 |
import pytest |
|
|
2 |
|
|
|
3 |
import numpy as np |
|
|
4 |
import pandas as pd |
|
|
5 |
from sklearn.metrics import pairwise_distances |
|
|
6 |
|
|
|
7 |
import anndata as ad |
|
|
8 |
from anndata import AnnData |
|
|
9 |
|
|
|
10 |
from tests._utils import Geom_t |
|
|
11 |
|
|
|
12 |
|
|
|
13 |
@pytest.fixture |
|
|
14 |
def adata_with_cost_matrix(adata_x: Geom_t, adata_y: Geom_t) -> AnnData: |
|
|
15 |
adata = ad.concat([adata_x, adata_y], label="batch", index_unique="-") |
|
|
16 |
C = pairwise_distances(adata_x.obsm["X_pca"], adata_y.obsm["X_pca"]) ** 2 |
|
|
17 |
adata.obs["batch"] = pd.to_numeric(adata.obs["batch"]) |
|
|
18 |
adata.uns[0] = C / C.mean() # TODO(@MUCDK) make a callback function and replace this part |
|
|
19 |
return adata |
|
|
20 |
|
|
|
21 |
|
|
|
22 |
@pytest.fixture |
|
|
23 |
def adata_time_with_tmap(adata_time: AnnData) -> AnnData: |
|
|
24 |
adata = adata_time[adata_time.obs["time"].isin([0, 1])].copy() |
|
|
25 |
rng = np.random.RandomState(42) |
|
|
26 |
cell_types = ["cell_A", "cell_B", "cell_C", "cell_D"] |
|
|
27 |
|
|
|
28 |
cell_d1 = rng.multinomial(len(adata[adata.obs["time"] == 0]), [1 / len(cell_types)] * len(cell_types)) |
|
|
29 |
cell_d2 = rng.multinomial(len(adata[adata.obs["time"] == 0]), [1 / len(cell_types)] * len(cell_types)) |
|
|
30 |
a1 = np.concatenate( |
|
|
31 |
[["cell_A"] * cell_d1[0], ["cell_B"] * cell_d1[1], ["cell_C"] * cell_d1[2], ["cell_D"] * cell_d1[3]] |
|
|
32 |
).flatten() |
|
|
33 |
a2 = np.concatenate( |
|
|
34 |
[["cell_A"] * cell_d2[0], ["cell_B"] * cell_d2[1], ["cell_C"] * cell_d2[2], ["cell_D"] * cell_d2[3]] |
|
|
35 |
).flatten() |
|
|
36 |
|
|
|
37 |
adata.obs["cell_type"] = np.concatenate([a1, a2]) |
|
|
38 |
adata.obs["cell_type"] = adata.obs["cell_type"].astype("category") |
|
|
39 |
cell_numbers_source = dict(adata[adata.obs["time"] == 0].obs["cell_type"].value_counts()) |
|
|
40 |
cell_numbers_target = dict(adata[adata.obs["time"] == 1].obs["cell_type"].value_counts()) |
|
|
41 |
trans_matrix = np.abs(rng.randn(len(cell_types), len(cell_types))) |
|
|
42 |
trans_matrix = trans_matrix / trans_matrix.sum(axis=1, keepdims=1) |
|
|
43 |
|
|
|
44 |
cell_transition_gt = pd.DataFrame(data=trans_matrix, index=cell_types, columns=cell_types) |
|
|
45 |
|
|
|
46 |
blocks = [] |
|
|
47 |
for cell_row in cell_types: |
|
|
48 |
block_row = [] |
|
|
49 |
for cell_col in cell_types: |
|
|
50 |
sub_trans_matrix = np.abs(rng.randn(cell_numbers_source[cell_row], cell_numbers_target[cell_col])) |
|
|
51 |
sub_trans_matrix /= sub_trans_matrix.sum() * (1 / cell_transition_gt.loc[cell_row, cell_col]) |
|
|
52 |
block_row.append(sub_trans_matrix) |
|
|
53 |
blocks.append(block_row) |
|
|
54 |
transport_matrix = np.block(blocks) |
|
|
55 |
adata.uns["transport_matrix"] = transport_matrix |
|
|
56 |
adata.uns["cell_transition_gt"] = cell_transition_gt |
|
|
57 |
|
|
|
58 |
return adata |
|
|
59 |
|
|
|
60 |
|
|
|
61 |
# keys for marginals |
|
|
62 |
@pytest.fixture( |
|
|
63 |
params=[ |
|
|
64 |
(None, None), |
|
|
65 |
("left_marginals_balanced", "right_marginals_balanced"), |
|
|
66 |
], |
|
|
67 |
ids=["default", "balanced"], |
|
|
68 |
) |
|
|
69 |
def marginal_keys(request): |
|
|
70 |
return request.param |
|
|
71 |
|
|
|
72 |
|
|
|
73 |
sinkhorn_args_1 = { |
|
|
74 |
"epsilon": 0.7, |
|
|
75 |
"tau_a": 1.0, |
|
|
76 |
"tau_b": 1.0, |
|
|
77 |
"rank": 7, |
|
|
78 |
"initializer": "rank2", |
|
|
79 |
"initializer_kwargs": {}, |
|
|
80 |
"jit": False, |
|
|
81 |
"threshold": 2e-3, |
|
|
82 |
"lse_mode": True, |
|
|
83 |
"norm_error": 2, |
|
|
84 |
"inner_iterations": 3, |
|
|
85 |
"min_iterations": 4, |
|
|
86 |
"max_iterations": 9, |
|
|
87 |
"gamma": 9.4, |
|
|
88 |
"gamma_rescale": False, |
|
|
89 |
"batch_size": None, # in to_LRC() `batch_size` cannot be passed so we expect None. |
|
|
90 |
"scale_cost": "max_cost", |
|
|
91 |
} |
|
|
92 |
|
|
|
93 |
|
|
|
94 |
sinkhorn_args_2 = { # no gamma/gamma_rescale as these are LR-specific |
|
|
95 |
"epsilon": 0.8, |
|
|
96 |
"tau_a": 0.9, |
|
|
97 |
"tau_b": 0.8, |
|
|
98 |
"rank": -1, |
|
|
99 |
"batch_size": 125, |
|
|
100 |
"initializer": "gaussian", |
|
|
101 |
"initializer_kwargs": {}, |
|
|
102 |
"jit": True, |
|
|
103 |
"threshold": 3e-3, |
|
|
104 |
"lse_mode": False, |
|
|
105 |
"norm_error": 3, |
|
|
106 |
"inner_iterations": 4, |
|
|
107 |
"min_iterations": 1, |
|
|
108 |
"max_iterations": 2, |
|
|
109 |
"scale_cost": "mean", |
|
|
110 |
} |
|
|
111 |
|
|
|
112 |
linear_solver_kwargs1 = { |
|
|
113 |
"inner_iterations": 1, |
|
|
114 |
"min_iterations": 5, |
|
|
115 |
"max_iterations": 7, |
|
|
116 |
"lse_mode": False, |
|
|
117 |
"threshold": 5e-2, |
|
|
118 |
"norm_error": 4, |
|
|
119 |
} |
|
|
120 |
|
|
|
121 |
gw_args_1 = { # no gamma/gamma_rescale/tolerances/ranks as these are LR-specific |
|
|
122 |
"epsilon": 0.5, |
|
|
123 |
"tau_a": 0.7, |
|
|
124 |
"tau_b": 0.8, |
|
|
125 |
"scale_cost": "max_cost", |
|
|
126 |
"rank": -1, |
|
|
127 |
"batch_size": 122, |
|
|
128 |
"initializer": None, |
|
|
129 |
"initializer_kwargs": {}, |
|
|
130 |
"jit": True, |
|
|
131 |
"threshold": 3e-2, |
|
|
132 |
"min_iterations": 3, |
|
|
133 |
"max_iterations": 4, |
|
|
134 |
"gw_unbalanced_correction": True, |
|
|
135 |
"ranks": 4, |
|
|
136 |
"tolerances": 2e-2, |
|
|
137 |
"warm_start": False, |
|
|
138 |
"linear_solver_kwargs": linear_solver_kwargs1, |
|
|
139 |
} |
|
|
140 |
|
|
|
141 |
linear_solver_kwargs2 = { |
|
|
142 |
"inner_iterations": 3, |
|
|
143 |
"min_iterations": 7, |
|
|
144 |
"max_iterations": 8, |
|
|
145 |
"lse_mode": True, |
|
|
146 |
"threshold": 4e-2, |
|
|
147 |
"norm_error": 3, |
|
|
148 |
"gamma": 9.4, |
|
|
149 |
"gamma_rescale": False, |
|
|
150 |
} |
|
|
151 |
|
|
|
152 |
gw_args_2 = { |
|
|
153 |
"alpha": 0.4, |
|
|
154 |
"epsilon": 0.7, |
|
|
155 |
"tau_a": 1.0, |
|
|
156 |
"tau_b": 1.0, |
|
|
157 |
"scale_cost": "max_cost", |
|
|
158 |
"rank": 7, |
|
|
159 |
"batch_size": 123, |
|
|
160 |
"initializer": "rank2", |
|
|
161 |
"initializer_kwargs": {}, |
|
|
162 |
"jit": False, |
|
|
163 |
"threshold": 2e-3, |
|
|
164 |
"min_iterations": 2, |
|
|
165 |
"max_iterations": 3, |
|
|
166 |
"gw_unbalanced_correction": False, |
|
|
167 |
"ranks": 3, |
|
|
168 |
"tolerances": 3e-2, |
|
|
169 |
# "linear_solver_kwargs": linear_solver_kwargs2, |
|
|
170 |
} |
|
|
171 |
|
|
|
172 |
gw_args_2 = {**gw_args_2, **linear_solver_kwargs2} |
|
|
173 |
|
|
|
174 |
fgw_args_1 = gw_args_1.copy() |
|
|
175 |
fgw_args_1["alpha"] = 0.6 |
|
|
176 |
|
|
|
177 |
fgw_args_2 = gw_args_2.copy() |
|
|
178 |
fgw_args_2["alpha"] = 0.4 |
|
|
179 |
|
|
|
180 |
gw_solver_args = { |
|
|
181 |
"epsilon": "epsilon", |
|
|
182 |
"rank": "rank", |
|
|
183 |
"threshold": "threshold", |
|
|
184 |
"min_iterations": "min_iterations", |
|
|
185 |
"max_iterations": "max_iterations", |
|
|
186 |
"warm_start": "warm_start", |
|
|
187 |
"initializer": "initializer", |
|
|
188 |
} |
|
|
189 |
|
|
|
190 |
gw_lr_solver_args = { |
|
|
191 |
"epsilon": "epsilon", |
|
|
192 |
"rank": "rank", |
|
|
193 |
"threshold": "threshold", |
|
|
194 |
"min_iterations": "min_iterations", |
|
|
195 |
"max_iterations": "max_iterations", |
|
|
196 |
"initializer": "initializer", |
|
|
197 |
} |
|
|
198 |
|
|
|
199 |
gw_linear_solver_args = { |
|
|
200 |
"lse_mode": "lse_mode", |
|
|
201 |
"inner_iterations": "inner_iterations", |
|
|
202 |
"threshold": "threshold", |
|
|
203 |
"norm_error": "norm_error", |
|
|
204 |
"max_iterations": "max_iterations", |
|
|
205 |
"min_iterations": "min_iterations", |
|
|
206 |
} |
|
|
207 |
|
|
|
208 |
gw_lr_linear_solver_args = { |
|
|
209 |
"lse_mode": "lse_mode", |
|
|
210 |
"inner_iterations": "inner_iterations", |
|
|
211 |
"threshold": "threshold", |
|
|
212 |
"norm_error": "norm_error", |
|
|
213 |
"max_iterations": "max_iterations", |
|
|
214 |
"min_iterations": "min_iterations", |
|
|
215 |
"gamma": "gamma", |
|
|
216 |
"gamma_rescale": "gamma_rescale", |
|
|
217 |
} |
|
|
218 |
|
|
|
219 |
quad_prob_args = { |
|
|
220 |
"tau_a": "tau_a", |
|
|
221 |
"tau_b": "tau_b", |
|
|
222 |
"gw_unbalanced_correction": "gw_unbalanced_correction", |
|
|
223 |
"ranks": "ranks", |
|
|
224 |
"tolerances": "tolerances", |
|
|
225 |
} |
|
|
226 |
|
|
|
227 |
geometry_args = {"epsilon": "_epsilon_init", "scale_cost": "_scale_cost"} |
|
|
228 |
|
|
|
229 |
pointcloud_args = { |
|
|
230 |
"batch_size": "_batch_size", |
|
|
231 |
"scale_cost": "_scale_cost", |
|
|
232 |
} |
|
|
233 |
|
|
|
234 |
lr_pointcloud_args = { |
|
|
235 |
"batch_size": "batch_size", |
|
|
236 |
"scale_cost": "_scale_cost", |
|
|
237 |
} |
|
|
238 |
|
|
|
239 |
sinkhorn_solver_args = { # dictionary with key = moscot arg name, value = ott-jax attribute |
|
|
240 |
"lse_mode": "lse_mode", |
|
|
241 |
"threshold": "threshold", |
|
|
242 |
"norm_error": "norm_error", |
|
|
243 |
"inner_iterations": "inner_iterations", |
|
|
244 |
"min_iterations": "min_iterations", |
|
|
245 |
"max_iterations": "max_iterations", |
|
|
246 |
"initializer": "initializer", |
|
|
247 |
"initializer_kwargs": "initializer_kwargs", |
|
|
248 |
} |
|
|
249 |
|
|
|
250 |
lr_sinkhorn_solver_args = sinkhorn_solver_args.copy() |
|
|
251 |
lr_sinkhorn_solver_args["gamma"] = "gamma" |
|
|
252 |
lr_sinkhorn_solver_args["gamma_rescale"] = "gamma_rescale" |
|
|
253 |
|
|
|
254 |
lin_prob_args = { |
|
|
255 |
"tau_a": "tau_a", |
|
|
256 |
"tau_b": "tau_b", |
|
|
257 |
} |
|
|
258 |
|
|
|
259 |
neurallin_cond_args_1 = { |
|
|
260 |
"batch_size": 8, |
|
|
261 |
"seed": 0, |
|
|
262 |
"iterations": 2, |
|
|
263 |
"valid_freq": 4, |
|
|
264 |
} |