--- /dev/null
+++ b/tests/problems/conftest.py
@@ -0,0 +1,264 @@
+import pytest
+
+import numpy as np
+import pandas as pd
+from sklearn.metrics import pairwise_distances
+
+import anndata as ad
+from anndata import AnnData
+
+from tests._utils import Geom_t
+
+
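+# Concatenates the two point-cloud AnnDatas and attaches the pairwise squared
+# Euclidean PCA distances, normalized by their mean, as a precomputed cost matrix.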
+@pytest.fixture
+def adata_with_cost_matrix(adata_x: Geom_t, adata_y: Geom_t) -> AnnData:
+    adata = ad.concat([adata_x, adata_y], label="batch", index_unique="-")
+    C = pairwise_distances(adata_x.obsm["X_pca"], adata_y.obsm["X_pca"]) ** 2
+    adata.obs["batch"] = pd.to_numeric(adata.obs["batch"])
+    adata.uns[0] = C / C.mean()  # TODO(@MUCDK) make a callback function and replace this part
+    return adata
+
+
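+# Builds a two-time-point AnnData with randomly assigned cell types and a
+# block-structured ground-truth transport matrix whose per-block mass matches
+# the transition matrix stored in `.uns["cell_transition_gt"]`.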
+@pytest.fixture
+def adata_time_with_tmap(adata_time: AnnData) -> AnnData:
+    adata = adata_time[adata_time.obs["time"].isin([0, 1])].copy()
+    rng = np.random.RandomState(42)
+    cell_types = ["cell_A", "cell_B", "cell_C", "cell_D"]
+
+    cell_d1 = rng.multinomial(len(adata[adata.obs["time"] == 0]), [1 / len(cell_types)] * len(cell_types))
+    cell_d2 = rng.multinomial(len(adata[adata.obs["time"] == 1]), [1 / len(cell_types)] * len(cell_types))
+    a1 = np.concatenate(
+        [["cell_A"] * cell_d1[0], ["cell_B"] * cell_d1[1], ["cell_C"] * cell_d1[2], ["cell_D"] * cell_d1[3]]
+    ).flatten()
+    a2 = np.concatenate(
+        [["cell_A"] * cell_d2[0], ["cell_B"] * cell_d2[1], ["cell_C"] * cell_d2[2], ["cell_D"] * cell_d2[3]]
+    ).flatten()
+
+    adata.obs["cell_type"] = np.concatenate([a1, a2])
+    adata.obs["cell_type"] = adata.obs["cell_type"].astype("category")
+    cell_numbers_source = dict(adata[adata.obs["time"] == 0].obs["cell_type"].value_counts())
+    cell_numbers_target = dict(adata[adata.obs["time"] == 1].obs["cell_type"].value_counts())
+    trans_matrix = np.abs(rng.randn(len(cell_types), len(cell_types)))
+    trans_matrix = trans_matrix / trans_matrix.sum(axis=1, keepdims=True)
+
+    cell_transition_gt = pd.DataFrame(data=trans_matrix, index=cell_types, columns=cell_types)
+
+    blocks = []
+    for cell_row in cell_types:
+        block_row = []
+        for cell_col in cell_types:
+            sub_trans_matrix = np.abs(rng.randn(cell_numbers_source[cell_row], cell_numbers_target[cell_col]))
+            # scale the block so its entries sum to the ground-truth transition mass
+            sub_trans_matrix = sub_trans_matrix / sub_trans_matrix.sum() * cell_transition_gt.loc[cell_row, cell_col]
+            block_row.append(sub_trans_matrix)
+        blocks.append(block_row)
+    transport_matrix = np.block(blocks)
+    adata.uns["transport_matrix"] = transport_matrix
+    adata.uns["cell_transition_gt"] = cell_transition_gt
+
+    return adata
+
+
+# keys for marginals
+@pytest.fixture(
+    params=[
+        (None, None),
+        ("left_marginals_balanced", "right_marginals_balanced"),
+    ],
+    ids=["default", "balanced"],
+)
+def marginal_keys(request):
+    return request.param
+
+
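+# Keyword-argument sets for the Sinkhorn solver used across the problem tests:
+# `sinkhorn_args_1` is a low-rank configuration (rank=7), `sinkhorn_args_2` is full-rank (rank=-1).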
+sinkhorn_args_1 = {
+    "epsilon": 0.7,
+    "tau_a": 1.0,
+    "tau_b": 1.0,
+    "rank": 7,
+    "initializer": "rank2",
+    "initializer_kwargs": {},
+    "jit": False,
+    "threshold": 2e-3,
+    "lse_mode": True,
+    "norm_error": 2,
+    "inner_iterations": 3,
+    "min_iterations": 4,
+    "max_iterations": 9,
+    "gamma": 9.4,
+    "gamma_rescale": False,
+    "batch_size": None,  # in to_LRC() `batch_size` cannot be passed so we expect None.
+    "scale_cost": "max_cost",
+}
+
+
+sinkhorn_args_2 = {  # no gamma/gamma_rescale as these are LR-specific
+    "epsilon": 0.8,
+    "tau_a": 0.9,
+    "tau_b": 0.8,
+    "rank": -1,
+    "batch_size": 125,
+    "initializer": "gaussian",
+    "initializer_kwargs": {},
+    "jit": True,
+    "threshold": 3e-3,
+    "lse_mode": False,
+    "norm_error": 3,
+    "inner_iterations": 4,
+    "min_iterations": 1,
+    "max_iterations": 2,
+    "scale_cost": "mean",
+}
+
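+# inner linear-solver kwargs nested into `gw_args_1` via `linear_solver_kwargs`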
+linear_solver_kwargs1 = {
+    "inner_iterations": 1,
+    "min_iterations": 5,
+    "max_iterations": 7,
+    "lse_mode": False,
+    "threshold": 5e-2,
+    "norm_error": 4,
+}
+
+gw_args_1 = {  # no gamma/gamma_rescale as these are LR-specific
+    "epsilon": 0.5,
+    "tau_a": 0.7,
+    "tau_b": 0.8,
+    "scale_cost": "max_cost",
+    "rank": -1,
+    "batch_size": 122,
+    "initializer": None,
+    "initializer_kwargs": {},
+    "jit": True,
+    "threshold": 3e-2,
+    "min_iterations": 3,
+    "max_iterations": 4,
+    "gw_unbalanced_correction": True,
+    "ranks": 4,
+    "tolerances": 2e-2,
+    "warm_start": False,
+    "linear_solver_kwargs": linear_solver_kwargs1,
+}
+
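+# inner linear-solver kwargs for the low-rank GW configuration below; includes
+# `gamma`/`gamma_rescale`, which only the low-rank linear solver accepts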
+linear_solver_kwargs2 = {
+    "inner_iterations": 3,
+    "min_iterations": 7,
+    "max_iterations": 8,
+    "lse_mode": True,
+    "threshold": 4e-2,
+    "norm_error": 3,
+    "gamma": 9.4,
+    "gamma_rescale": False,
+}
+
+gw_args_2 = {
+    "alpha": 0.4,
+    "epsilon": 0.7,
+    "tau_a": 1.0,
+    "tau_b": 1.0,
+    "scale_cost": "max_cost",
+    "rank": 7,
+    "batch_size": 123,
+    "initializer": "rank2",
+    "initializer_kwargs": {},
+    "jit": False,
+    "threshold": 2e-3,
+    "min_iterations": 2,
+    "max_iterations": 3,
+    "gw_unbalanced_correction": False,
+    "ranks": 3,
+    "tolerances": 3e-2,
+    # "linear_solver_kwargs": linear_solver_kwargs2,
+}
+
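+# for the low-rank case the linear-solver kwargs are merged in flat instead of
+# being nested under `linear_solver_kwargs` (see the commented-out entry above)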
+gw_args_2 = {**gw_args_2, **linear_solver_kwargs2}
+
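+# fused GW reuses the GW argument sets and only adds the `alpha` weight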
+fgw_args_1 = gw_args_1.copy()
+fgw_args_1["alpha"] = 0.6
+
+fgw_args_2 = gw_args_2.copy()
+fgw_args_2["alpha"] = 0.4
+
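+# The dictionaries below map moscot argument names to the corresponding
+# attribute names on the ott-jax solver, problem and geometry objects.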
+gw_solver_args = {
+    "epsilon": "epsilon",
+    "rank": "rank",
+    "threshold": "threshold",
+    "min_iterations": "min_iterations",
+    "max_iterations": "max_iterations",
+    "warm_start": "warm_start",
+    "initializer": "initializer",
+}
+
+gw_lr_solver_args = {
+    "epsilon": "epsilon",
+    "rank": "rank",
+    "threshold": "threshold",
+    "min_iterations": "min_iterations",
+    "max_iterations": "max_iterations",
+    "initializer": "initializer",
+}
+
+gw_linear_solver_args = {
+    "lse_mode": "lse_mode",
+    "inner_iterations": "inner_iterations",
+    "threshold": "threshold",
+    "norm_error": "norm_error",
+    "max_iterations": "max_iterations",
+    "min_iterations": "min_iterations",
+}
+
+gw_lr_linear_solver_args = {
+    "lse_mode": "lse_mode",
+    "inner_iterations": "inner_iterations",
+    "threshold": "threshold",
+    "norm_error": "norm_error",
+    "max_iterations": "max_iterations",
+    "min_iterations": "min_iterations",
+    "gamma": "gamma",
+    "gamma_rescale": "gamma_rescale",
+}
+
+quad_prob_args = {
+    "tau_a": "tau_a",
+    "tau_b": "tau_b",
+    "gw_unbalanced_correction": "gw_unbalanced_correction",
+    "ranks": "ranks",
+    "tolerances": "tolerances",
+}
+
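+# values with a leading underscore refer to private attributes of the ott-jax objects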
+geometry_args = {"epsilon": "_epsilon_init", "scale_cost": "_scale_cost"}
+
+pointcloud_args = {
+    "batch_size": "_batch_size",
+    "scale_cost": "_scale_cost",
+}
+
+lr_pointcloud_args = {
+    "batch_size": "batch_size",
+    "scale_cost": "_scale_cost",
+}
+
+sinkhorn_solver_args = {  # maps moscot argument name -> ott-jax attribute name
+    "lse_mode": "lse_mode",
+    "threshold": "threshold",
+    "norm_error": "norm_error",
+    "inner_iterations": "inner_iterations",
+    "min_iterations": "min_iterations",
+    "max_iterations": "max_iterations",
+    "initializer": "initializer",
+    "initializer_kwargs": "initializer_kwargs",
+}
+
+lr_sinkhorn_solver_args = sinkhorn_solver_args.copy()
+lr_sinkhorn_solver_args["gamma"] = "gamma"
+lr_sinkhorn_solver_args["gamma_rescale"] = "gamma_rescale"
+
+lin_prob_args = {
+    "tau_a": "tau_a",
+    "tau_b": "tau_b",
+}
+
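+# training kwargs for the conditional neural linear tests; small values keep the runs short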
+neurallin_cond_args_1 = {
+    "batch_size": 8,
+    "seed": 0,
+    "iterations": 2,
+    "valid_freq": 4,
+}