--- a
+++ b/tests/backends/ott/test_backend.py
@@ -0,0 +1,408 @@
+from typing import Optional, Tuple, Type, Union
+
+import pytest
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+from ott.geometry import costs
+from ott.geometry.geometry import Geometry
+from ott.geometry.low_rank import LRCGeometry
+from ott.geometry.pointcloud import PointCloud
+from ott.problems.linear.linear_problem import LinearProblem
+from ott.problems.quadratic import quadratic_problem
+from ott.problems.quadratic.quadratic_problem import QuadraticProblem
+from ott.solvers.linear import solve as sinkhorn
+from ott.solvers.linear.sinkhorn import Sinkhorn
+from ott.solvers.linear.sinkhorn_lr import LRSinkhorn
+from ott.solvers.quadratic.gromov_wasserstein import GromovWasserstein
+from ott.solvers.quadratic.gromov_wasserstein_lr import LRGromovWasserstein
+
+from moscot._types import ArrayLike, Device_t
+from moscot.backends.ott import GWSolver, SinkhornSolver
+from moscot.backends.ott._utils import InitializerResolver, alpha_to_fused_penalty
+from moscot.base.output import BaseDiscreteSolverOutput
+from moscot.base.solver import O, OTSolver
+from moscot.utils.tagged_array import Tag, TaggedArray
+from tests._utils import ATOL, RTOL, Geom_t
+from tests.plotting.conftest import PlotTester, PlotTesterMeta
+
+
+class TestSinkhorn:
+    """Compare :class:`SinkhornSolver` with linear OT problems solved directly through OTT."""
+
+    @pytest.mark.fast
+    @pytest.mark.parametrize("jit", [False, True])
+    @pytest.mark.parametrize("eps", [None, 1e-2, 1e-1])
+    def test_matches_ott(self, x: Geom_t, eps: Optional[float], jit: bool):
+        """The full-rank transport matrix equals the one from OTT's linear ``solve``."""
+        reference_fn = jax.jit(sinkhorn) if jit else sinkhorn
+        expected = reference_fn(PointCloud(x, epsilon=eps))
+        uniform = jnp.ones(len(x)) / len(x)
+
+        solver = SinkhornSolver(jit=jit)
+        assert solver.xy is None  # the solver has not been called yet
+        assert isinstance(solver.solver, Sinkhorn)
+        result = solver(a=uniform, b=uniform, xy=(x, x), epsilon=eps)
+
+        assert solver.rank == -1
+        assert not solver.is_low_rank
+        assert isinstance(solver.xy, Geometry)
+        assert result.rank == -1
+        np.testing.assert_allclose(expected.matrix, result.transport_matrix, rtol=RTOL, atol=ATOL)
+
+    @pytest.mark.parametrize("rank", [5, 10])
+    @pytest.mark.parametrize("initializer", ["random", "rank2", "k-means"])
+    def test_solver_rank(self, y: Geom_t, rank: Optional[int], initializer: str):
+        """A low-rank solve agrees with OTT's :class:`LRSinkhorn` when both share one initializer."""
+        eps, gamma = 1e-2, 500  # ``gamma`` is passed to the reference solver only
+        lr_init = InitializerResolver.lr_from_str(initializer, rank=rank)
+        problem = LinearProblem(PointCloud(y, epsilon=eps))
+        expected = LRSinkhorn(rank=rank, initializer=lr_init, gamma=gamma)(problem)
+        uniform = jnp.ones(len(y)) / len(y)
+
+        solver = SinkhornSolver(rank=rank, initializer=lr_init)
+        assert solver.rank == rank
+        assert solver.is_low_rank
+        assert solver.xy is None
+        assert isinstance(solver.solver, LRSinkhorn)
+        result = solver(a=uniform, b=uniform, xy=(y, y), epsilon=eps)
+
+        assert isinstance(solver.xy, PointCloud)
+        assert result.rank == rank
+        np.testing.assert_allclose(solver._problem.geom.cost_matrix, problem.geom.cost_matrix, rtol=RTOL, atol=ATOL)
+        np.testing.assert_allclose(expected.matrix, result.transport_matrix, rtol=RTOL, atol=ATOL)
+
+    @pytest.mark.parametrize(("rank", "cost_fn"), [(2, costs.Euclidean()), (3, costs.SqPNorm(p=1.5))])
+    def test_geometry_rank(self, x: Geom_t, rank: int, cost_fn: costs.CostFn):
+        """``cost_matrix_rank`` makes the solver store an :class:`LRCGeometry` and match OTT on it."""
+        eps = 0.05
+        low_rank_geom = PointCloud(x, epsilon=eps, cost_fn=cost_fn).to_LRCGeometry(rank=rank)
+        problem = LinearProblem(low_rank_geom)
+        expected = Sinkhorn()(problem)
+        uniform = jnp.ones(len(x)) / len(x)
+
+        # no ``rank`` is given to the solver; only ``cost_matrix_rank`` is passed at call time
+        solver = SinkhornSolver()
+        assert not solver.is_low_rank
+        points = TaggedArray(x, cost=cost_fn)
+        result = solver(a=uniform, b=uniform, xy=points, epsilon=eps, cost_matrix_rank=rank)
+
+        assert isinstance(solver.xy, LRCGeometry)
+        np.testing.assert_allclose(solver._problem.geom.cost_matrix, problem.geom.cost_matrix, rtol=RTOL, atol=ATOL)
+        np.testing.assert_allclose(expected.matrix, result.transport_matrix, rtol=RTOL, atol=ATOL)
+
+
+class TestGW:
+    """Compare pure Gromov-Wasserstein solved through :class:`GWSolver` with OTT run directly."""
+
+    @pytest.mark.parametrize("jit", [False, True])
+    @pytest.mark.parametrize("eps", [5e-2, 1e-2, 1e-1])
+    def test_matches_ott(self, x: Geom_t, y: Geom_t, eps: Optional[float], jit: bool):
+        """The coupling between two point clouds equals OTT's, with and without ``jax.jit``."""
+        # reference: OTT solving the same quadratic problem
+        thresh = 1e-2
+        problem = quadratic_problem.QuadraticProblem(PointCloud(x, epsilon=eps), PointCloud(y, epsilon=eps))
+        reference = GromovWasserstein(epsilon=eps, threshold=thresh, linear_solver=Sinkhorn())
+        if jit:
+            reference = jax.jit(reference, static_argnames=["threshold", "epsilon"])
+        expected = reference(problem)
+
+        # moscot: raw arrays tagged as point clouds, ``alpha=1.0``
+        solver = GWSolver(jit=jit, epsilon=eps, threshold=thresh)
+        assert isinstance(solver.solver, GromovWasserstein)
+        assert solver.x is None
+        assert solver.y is None
+
+        tags = {"x": "point_cloud", "y": "point_cloud"}
+        a, b = jnp.ones(len(x)) / len(x), jnp.ones(len(y)) / len(y)
+        result = solver(a=a, b=b, x=x, y=y, tags=tags, alpha=1.0)
+
+        assert solver.is_fused is False
+        assert solver.rank == -1
+        assert not solver.is_low_rank
+        assert isinstance(solver.x, PointCloud)
+        assert isinstance(solver.y, PointCloud)
+        assert result.rank == -1
+        np.testing.assert_allclose(expected.matrix, result.transport_matrix, rtol=RTOL, atol=ATOL)
+
+    @pytest.mark.parametrize("eps", [5e-1, 1])
+    def test_epsilon(self, x_cost: jnp.ndarray, y_cost: jnp.ndarray, eps: Optional[float]) -> None:
+        """The coupling between two precomputed cost matrices equals OTT's for each ``eps``."""
+        # reference: OTT on plain geometries wrapping the cost matrices
+        thresh = 1e-3
+        geom_x = Geometry(cost_matrix=x_cost, epsilon=eps)
+        geom_y = Geometry(cost_matrix=y_cost, epsilon=eps)
+        problem = QuadraticProblem(geom_xx=geom_x, geom_yy=geom_y)
+        expected = GromovWasserstein(epsilon=eps, threshold=thresh, linear_solver=Sinkhorn())(problem)
+
+        # moscot: the same matrices, tagged as cost matrices
+        solver = GWSolver(epsilon=eps, threshold=thresh)
+        tags = {"x": Tag.COST_MATRIX, "y": Tag.COST_MATRIX}
+        a = jnp.ones(len(x_cost)) / len(x_cost)
+        b = jnp.ones(len(y_cost)) / len(y_cost)
+        result = solver(a=a, b=b, x=x_cost, y=y_cost, tags=tags, alpha=1.0)
+
+        # a non-fused solve with rank ``-1`` on plain geometries is expected
+        assert solver.is_fused is False
+        assert result.rank == -1
+        assert solver.rank == -1
+        assert isinstance(solver.x, Geometry)
+        assert isinstance(solver.y, Geometry)
+        np.testing.assert_allclose(expected.matrix, result.transport_matrix, rtol=RTOL, atol=ATOL)
+
+    @pytest.mark.parametrize("rank", [-1, 7])
+    def test_solver_rank(self, x: Geom_t, y: Geom_t, rank: int) -> None:
+        """``rank`` selects full-rank (``-1``) or low-rank GW; either must match its OTT counterpart."""
+        thresh, eps = 1e-2, 1e-2
+        problem = QuadraticProblem(PointCloud(x, epsilon=eps), PointCloud(y, epsilon=eps))
+        if rank > -1:
+            # low-rank reference, built with the "random" initializer
+            initializer = InitializerResolver.lr_from_str("random", rank=rank)
+            reference = LRGromovWasserstein(epsilon=eps, rank=rank, threshold=thresh, initializer=initializer)
+        else:
+            # full-rank reference; its inner Sinkhorn is given the same threshold
+            reference = GromovWasserstein(epsilon=eps, threshold=thresh, linear_solver=Sinkhorn(threshold=thresh))
+        expected = reference(problem)
+
+        solver = GWSolver(rank=rank, epsilon=eps, threshold=thresh)
+        result = solver(
+            a=jnp.ones(len(x)) / len(x),
+            b=jnp.ones(len(y)) / len(y),
+            x=x,
+            y=y,
+            tags={"x": "point_cloud", "y": "point_cloud"},
+            alpha=1.0,
+        )
+
+        assert solver.is_fused is False
+        assert solver.rank == rank
+        assert result.rank == rank
+        np.testing.assert_allclose(expected.matrix, result.transport_matrix, rtol=RTOL, atol=ATOL)
+
+
+class TestFGW:
+    """Compare fused Gromov-Wasserstein solved through :class:`GWSolver` with OTT run directly."""
+
+    @pytest.mark.parametrize("alpha", [0.25, 0.75])
+    @pytest.mark.parametrize("eps", [1e-2, 1e-1, 5e-1])
+    def test_matches_ott(self, x: Geom_t, y: Geom_t, xy: Geom_t, eps: Optional[float], alpha: float) -> None:
+        """Fused GW on point clouds yields the same coupling as OTT.
+
+        OTT is given ``alpha_to_fused_penalty(alpha)`` as fused penalty, moscot the raw ``alpha``.
+        """
+        # reference: OTT on the quadratic problem including ``geom_xy`` and a fused penalty
+        thresh = 1e-2
+        joint_x, joint_y = xy
+        problem = quadratic_problem.QuadraticProblem(
+            geom_xx=PointCloud(x, epsilon=eps),
+            geom_yy=PointCloud(y, epsilon=eps),
+            geom_xy=PointCloud(joint_x, joint_y, epsilon=eps),
+            fused_penalty=alpha_to_fused_penalty(alpha),
+        )
+        expected = GromovWasserstein(epsilon=eps, threshold=thresh, linear_solver=Sinkhorn())(problem)
+
+        # moscot: same data, every input tagged as a point cloud
+        solver = GWSolver(epsilon=eps, threshold=thresh)
+        assert isinstance(solver.solver, GromovWasserstein)
+        assert solver.xy is None
+
+        tags = {"x": "point_cloud", "y": "point_cloud", "xy": "point_cloud"}
+        a = jnp.ones(len(x)) / len(x)
+        b = jnp.ones(len(y)) / len(y)
+        result = solver(a=a, b=b, x=x, y=y, xy=xy, alpha=alpha, tags=tags)
+
+        assert solver.is_fused is True
+        assert solver.rank == -1
+        assert result.rank == -1
+        assert isinstance(solver.xy, PointCloud)
+        np.testing.assert_allclose(expected.matrix, result.transport_matrix, rtol=RTOL, atol=ATOL)
+
+    @pytest.mark.fast
+    @pytest.mark.parametrize("alpha", [0.1, 0.9])
+    def test_alpha(self, x: Geom_t, y: Geom_t, xy: Geom_t, alpha: float) -> None:
+        """Fused GW matches OTT for a small and a large ``alpha``.
+
+        ``eps`` and the threshold are fixed; solver and output must both report not being low-rank.
+        """
+        # reference: OTT with ``alpha`` converted to a fused penalty
+        thresh, eps = 5e-2, 1e-1
+        joint_x, joint_y = xy
+        problem = quadratic_problem.QuadraticProblem(
+            geom_xx=PointCloud(x, epsilon=eps),
+            geom_yy=PointCloud(y, epsilon=eps),
+            geom_xy=PointCloud(joint_x, joint_y, epsilon=eps),
+            fused_penalty=alpha_to_fused_penalty(alpha),
+        )
+        expected = GromovWasserstein(epsilon=eps, threshold=thresh, linear_solver=Sinkhorn())(problem)
+
+        # moscot: ``alpha`` is passed unchanged
+        solver = GWSolver(epsilon=eps, threshold=thresh)
+        tags = {"x": "point_cloud", "y": "point_cloud", "xy": "point_cloud"}
+        a = jnp.ones(len(x)) / len(x)
+        b = jnp.ones(len(y)) / len(y)
+        result = solver(a=a, b=b, x=x, y=y, xy=xy, alpha=alpha, tags=tags)
+
+        assert solver.is_fused is True
+        assert not solver.is_low_rank
+        assert result.rank == -1
+        assert not result.is_low_rank
+        np.testing.assert_allclose(expected.matrix, result.transport_matrix, rtol=RTOL, atol=ATOL)
+
+    @pytest.mark.parametrize("eps", [1e-3, 5e-2])
+    def test_epsilon(
+        self, x_cost: jnp.ndarray, y_cost: jnp.ndarray, xy_cost: jnp.ndarray, eps: Optional[float]
+    ) -> None:
+        """Fused GW on precomputed cost matrices matches OTT for each ``eps``.
+
+        ``alpha`` and the threshold are fixed.
+        """
+        # reference: OTT on plain geometries wrapping the three cost matrices
+        thresh, alpha = 5e-1, 0.66
+        problem = QuadraticProblem(
+            geom_xx=Geometry(cost_matrix=x_cost, epsilon=eps),
+            geom_yy=Geometry(cost_matrix=y_cost, epsilon=eps),
+            geom_xy=Geometry(cost_matrix=xy_cost, epsilon=eps),
+            fused_penalty=alpha_to_fused_penalty(alpha),
+        )
+        expected = GromovWasserstein(epsilon=eps, threshold=thresh, linear_solver=Sinkhorn())(problem)
+
+        # moscot: the same matrices, tagged as cost matrices
+        solver = GWSolver(epsilon=eps, threshold=thresh)
+        tags = {"x": Tag.COST_MATRIX, "y": Tag.COST_MATRIX, "xy": Tag.COST_MATRIX}
+        a = jnp.ones(len(x_cost)) / len(x_cost)
+        b = jnp.ones(len(y_cost)) / len(y_cost)
+        result = solver(a=a, b=b, x=x_cost, y=y_cost, xy=xy_cost, alpha=alpha, tags=tags)
+
+        assert solver.is_fused is True
+        assert result.rank == -1
+        np.testing.assert_allclose(expected.matrix, result.transport_matrix, rtol=RTOL, atol=ATOL)
+
+
+class TestScaleCost:
+    """Check ``scale_cost`` passed to :class:`SinkhornSolver` against the same option in OTT."""
+
+    @pytest.mark.parametrize("scale_cost", [1.0, 0.5, "mean", "max_cost", "max_norm", "max_bound"])
+    def test_scale(self, x: Geom_t, scale_cost: Union[float, str]) -> None:
+        """Numeric and named cost scalings give the same transport matrix as OTT."""
+        eps = 1e-2
+        options = {"epsilon": eps, "scale_cost": scale_cost}
+        uniform = jnp.ones(len(x)) / len(x)
+        expected = sinkhorn(PointCloud(x, **options))
+        result = SinkhornSolver()(a=uniform, b=uniform, xy=(x, x), **options)
+        np.testing.assert_allclose(expected.matrix, result.transport_matrix, rtol=RTOL, atol=ATOL)
+
+
+class TestSolverOutput:
+    """Exercise the object returned by the OTT-backed solvers: properties, push/pull and devices."""
+
+    def test_properties(self, x: ArrayLike, y: ArrayLike) -> None:
+        """Marginals, convergence flag, cost and shape of the output have the expected types and sizes."""
+        n, m = len(x), len(y)
+        out = SinkhornSolver()(a=jnp.ones(n) / n, b=jnp.ones(m) / m, xy=(x, y), epsilon=1e-1)
+
+        # each marginal is a 1-D array whose length is the matching side of ``out.shape``
+        for axis, marginal in enumerate((out.a, out.b)):
+            assert isinstance(marginal, jnp.ndarray)
+            assert marginal.shape == (out.shape[axis],)
+
+        assert isinstance(out.converged, bool)
+        assert isinstance(out.cost, float)
+        assert out.cost >= 0
+        assert out.shape == (x.shape[0], y.shape[0])
+
+    @pytest.mark.parametrize("batched", [False, True])
+    @pytest.mark.parametrize("rank", [-1, 5])
+    def test_push(self, x: Geom_t, y: Geom_t, ab: Tuple[ArrayLike, ArrayLike], rank: int, batched: bool) -> None:
+        """``push`` returns an array sized by ``out.shape[1]`` for 1-D and 2-D inputs, at both ranks."""
+        source, _ = ab
+        if not batched:
+            source = source[:, 0]  # first column only, i.e. a 1-D input
+        solver = SinkhornSolver(rank=rank)
+
+        out = solver(a=jnp.ones(len(x)) / len(x), b=jnp.ones(len(y)) / len(y), xy=(x, y))
+        pushed = out.push(source, scale_by_marginals=False)
+
+        assert isinstance(out, BaseDiscreteSolverOutput)
+        assert isinstance(pushed, jnp.ndarray)
+        expected_shape = (out.shape[1], source.shape[1]) if batched else (out.shape[1],)
+        assert pushed.shape == expected_shape
+
+    @pytest.mark.parametrize("batched", [False, True])
+    @pytest.mark.parametrize("solver_t", [GWSolver])
+    def test_pull(
+        self,
+        x: ArrayLike,
+        y: ArrayLike,
+        xy: ArrayLike,
+        ab: Tuple[ArrayLike, ArrayLike],
+        solver_t: Type[OTSolver[O]],
+        batched: bool,
+    ) -> None:
+        """``pull`` returns an array sized by ``out.shape[0]`` for 1-D and 2-D inputs."""
+        _, target = ab
+        if not batched:
+            target = target[:, 0]  # first column only, i.e. a 1-D input
+        joint_x, joint_y = xy
+        a, b = jnp.ones(len(x)) / len(x), jnp.ones(len(y)) / len(y)
+
+        out = solver_t()(a=a, b=b, x=x, y=y, xy=(joint_x, joint_y), alpha=0.5)
+        pulled = out.pull(target, scale_by_marginals=False)
+
+        assert isinstance(out, BaseDiscreteSolverOutput)
+        assert isinstance(pulled, jnp.ndarray)
+        expected_shape = (out.shape[0], target.shape[1]) if batched else (out.shape[0],)
+        assert pulled.shape == expected_shape
+
+    @pytest.mark.parametrize("batched", [False, True])
+    @pytest.mark.parametrize("forward", [False, True])
+    def test_scale_by_marginals(self, x: Geom_t, ab: Tuple[ArrayLike, ArrayLike], forward: bool, batched: bool) -> None:
+        """With ``scale_by_marginals=True`` the result keeps the (column-wise) sum of its input."""
+        solver = SinkhornSolver()
+        source, _ = ab
+        z = source if batched else source[:, 0]
+
+        out = solver(a=jnp.ones(len(x)) / len(x), b=jnp.ones(len(x)) / len(x), xy=(x, x))
+        transfer = out.push if forward else out.pull
+        moved = transfer(z, scale_by_marginals=True)
+
+        axis = 0 if batched else None  # per-column sums for 2-D input, total sum otherwise
+        np.testing.assert_allclose(moved.sum(axis=axis), z.sum(axis=axis))
+
+    @pytest.mark.parametrize("device", [None, "cpu", "cpu:0", "cpu:1", "explicit"])
+    def test_to_device(self, x: Geom_t, device: Optional[Device_t]) -> None:
+        """Solve with every tested ``device`` spelling (simple integration test).
+
+        ``"cpu:1"`` is expected to raise only if fewer than two CPU devices are visible. NOTE(review):
+        assumes ``"cpu:1"`` selects the second CPU device -- confirm against the solver's device handling.
+        """
+        # decide before ``device`` is possibly swapped for a device object below
+        missing = device == "cpu:1" and len(jax.devices("cpu")) < 2
+        if device == "explicit":
+            device = jax.devices()[0]
+        solver, uniform = SinkhornSolver(), jnp.ones(len(x)) / len(x)
+        if missing:
+            with pytest.raises(IndexError, match=r"Unable to fetch the device with `id=1`."):
+                _ = solver(a=uniform, b=uniform, xy=(x, x), device=device)
+        else:
+            _ = solver(a=uniform, b=uniform, xy=(x, x), device=device)
+
+
+class TestOutputPlotting(PlotTester, metaclass=PlotTesterMeta):
+    """Call ``plot_costs`` and ``plot_errors`` on outputs of the GW and Sinkhorn solvers."""
+
+    def test_plot_costs(self, x: Geom_t, y: Geom_t):
+        a, b = jnp.ones(len(x)) / len(x), jnp.ones(len(y)) / len(y)
+        GWSolver()(a=a, b=b, x=x, y=y, alpha=1.0).plot_costs()
+
+    def test_plot_costs_last(self, x: Geom_t, y: Geom_t):
+        a, b = jnp.ones(len(x)) / len(x), jnp.ones(len(y)) / len(y)
+        GWSolver(rank=2)(a=a, b=b, x=x, y=y, alpha=1.0).plot_costs(last=3)
+
+    def test_plot_errors_sink(self, x: Geom_t, y: Geom_t):
+        a, b = jnp.ones(len(x)) / len(x), jnp.ones(len(y)) / len(y)
+        SinkhornSolver()(a=a, b=b, xy=(x, y)).plot_errors()
+
+    def test_plot_errors_gw(self, x: Geom_t, y: Geom_t):
+        a, b = jnp.ones(len(x)) / len(x), jnp.ones(len(y)) / len(y)
+        GWSolver(store_inner_errors=True)(a=a, b=b, x=x, y=y, alpha=1.0).plot_errors()