"""Tests for sparsifying solver outputs (``tests/solvers/test_base_solver.py``)."""
from typing import Tuple

import pytest

import jax.numpy as jnp
import jax.random
import numpy as np
import scipy.sparse as sp

from moscot.base.output import MatrixSolverOutput
from tests._utils import ATOL, RTOL, MockSolverOutput


def _checked_transport_matrix(sparsified: MatrixSolverOutput, expected_shape: Tuple[int, int]) -> sp.csr_matrix:
    """Assert the structural invariants shared by all sparsify modes.

    Checks that ``sparsified`` is a :class:`MatrixSolverOutput` whose transport
    matrix is a CSR matrix of ``expected_shape`` with non-negative stored
    entries, then returns that matrix.
    """
    assert isinstance(sparsified, MatrixSolverOutput)
    matrix = sparsified.transport_matrix
    assert isinstance(matrix, sp.csr_matrix)
    assert matrix.shape == expected_shape
    np.testing.assert_array_equal(matrix.data >= 0.0, True)
    return matrix


def _pull_through_both(
    rng: np.random.RandomState,
    sparsified: MatrixSolverOutput,
    dense: MockSolverOutput,
    n_cols: int,
) -> Tuple[np.ndarray, np.ndarray]:
    """Pull one random non-negative column vector through both outputs.

    The vector has shape ``(n_cols, 1)`` and is drawn from ``rng``. Returns
    ``(sparse_pull, dense_pull)`` after asserting that the sparse result is a
    NumPy array.
    """
    vec = np.abs(rng.randn(n_cols, 1))
    sparse_pull = sparsified.pull(vec)
    dense_pull = dense.pull(vec)
    assert isinstance(sparse_pull, np.ndarray)
    return sparse_pull, dense_pull


def _assert_correlated(sparse_pull: np.ndarray, dense_pull: np.ndarray) -> None:
    """Assert that the Pearson correlation of the two pulls exceeds 0.5."""
    correlation = np.corrcoef(sparse_pull.squeeze(), dense_pull.squeeze())[0, 1]
    np.testing.assert_array_less(0.5, correlation)


class TestBaseDiscreteSolverOutput:
    """Exercise ``sparsify`` on a mocked solver output for each supported mode."""

    @pytest.mark.parametrize("batch_size", [1, 4])
    @pytest.mark.parametrize("threshold", [0.0, 1e-1, 1.0])
    @pytest.mark.parametrize("shape", [(7, 2), (91, 103)])
    def test_sparsify_threshold(self, batch_size: int, threshold: float, shape: Tuple[int, int]) -> None:
        """Check ``mode="threshold"`` for the cut-offs 0.0, 0.1 and 1.0.

        * 0.0 -- the sparse matrix must match the dense one within tolerance.
        * 0.1 -- every stored entry is either above the cut-off or exactly 0.
        * 1.0 -- nothing is stored.

        For the first two cut-offs the pulls through the sparse and the dense
        output must additionally be correlated.
        """
        n_rows, n_cols = shape
        rng = np.random.RandomState(42)
        dense_map = np.abs(rng.rand(n_rows, n_cols))
        dense_output = MockSolverOutput(dense_map)

        sparsified = dense_output.sparsify(mode="threshold", value=threshold, batch_size=batch_size)
        matrix = _checked_transport_matrix(sparsified, shape)
        sparse_pull, dense_pull = _pull_through_both(rng, sparsified, dense_output, n_cols)

        if threshold == 1.0:
            # No correlation check here: the test only requires an empty matrix.
            assert matrix.nnz == 0
            return

        if threshold == 0.0:
            np.testing.assert_allclose(matrix.toarray(), dense_map, rtol=RTOL, atol=ATOL)
        elif threshold == 1e-1:
            stored = matrix.data
            survives_or_zero = (stored > threshold) | (stored == 0)
            np.testing.assert_equal(np.sum(survives_or_zero), len(stored))
        else:
            raise ValueError(f"Threshold {threshold} not expected.")

        _assert_correlated(sparse_pull, dense_pull)

    @pytest.mark.parametrize("batch_size", [1, 4])
    @pytest.mark.parametrize("shape", [(7, 2), (91, 103)])
    def test_sparsify_minrow(self, batch_size: int, shape: Tuple[int, int]) -> None:
        """Check ``mode="min_row"``: every row keeps a positive sum and pulls stay correlated."""
        n_rows, n_cols = shape
        rng = np.random.RandomState(42)
        # The offset keeps every entry of the dense map strictly positive.
        dense_map = np.abs(rng.rand(n_rows, n_cols)) + 1e-3
        dense_output = MockSolverOutput(dense_map)

        sparsified = dense_output.sparsify(mode="min_row", batch_size=batch_size)
        matrix = _checked_transport_matrix(sparsified, shape)

        row_sums = np.sum(matrix.toarray(), axis=1)
        np.testing.assert_array_equal(row_sums > 0.0, True)

        sparse_pull, dense_pull = _pull_through_both(rng, sparsified, dense_output, n_cols)
        _assert_correlated(sparse_pull, dense_pull)

    @pytest.mark.parametrize("batch_size", [1, 4])
    @pytest.mark.parametrize("threshold", [0, 10, 100])
    @pytest.mark.parametrize("shape", [(7, 2), (91, 103)])
    def test_sparsify_percentile(self, batch_size: int, threshold: float, shape: Tuple[int, int]) -> None:
        """Check ``mode="percentile"`` on a JAX-generated map for percentiles 0, 10 and 100.

        The bounds asserted for percentiles 0 and 100 are statistical: the
        original author notes that they hold only with probability below 1.
        The pull-correlation check is skipped for percentile 100.
        """
        n_rows, n_cols = shape
        n_entries = n_rows * n_cols
        dense_map = jnp.abs(jax.random.normal(jax.random.PRNGKey(0), shape=shape)) + 1e-3
        dense_output = MockSolverOutput(dense_map)

        sparsified = dense_output.sparsify(
            mode="percentile",
            value=threshold,
            batch_size=batch_size,
            n_samples=n_rows,
            seed=42,
        )
        matrix = _checked_transport_matrix(sparsified, shape)

        if threshold == 0:
            # Fewer than 10% of the entries may differ (holds with probability < 1).
            n_changed = np.sum(dense_map != matrix.toarray())
            assert n_changed < n_entries * 0.1
        if threshold == 100:
            # Fewer than 90% of the entries may be stored (holds with probability < 1).
            assert matrix.nnz < n_entries * 0.9

        rng = np.random.RandomState(42)
        sparse_pull, dense_pull = _pull_through_both(rng, sparsified, dense_output, n_cols)
        if threshold < 100:
            _assert_correlated(sparse_pull, dense_pull)