Diff of /src/moscot/base/output.py [000000] .. [6ff4a8]

Switch to side-by-side view

--- a
+++ b/src/moscot/base/output.py
@@ -0,0 +1,414 @@
+from __future__ import annotations
+
+import abc
+import copy
+import functools
+from abc import abstractmethod
+from typing import Any, Callable, Iterable, Literal, Optional, Union
+
+import numpy as np
+import scipy.sparse as sp
+from scipy.sparse.linalg import LinearOperator
+
+from moscot._logging import logger
+from moscot._types import ArrayLike, Device_t, DTypeLike
+
+__all__ = ["BaseDiscreteSolverOutput", "MatrixSolverOutput", "BaseNeuralOutput"]
+
+
class BaseSolverOutput(abc.ABC):
    """Abstract interface shared by every solver output (discrete and neural)."""

    @abc.abstractmethod
    def push(self, x: ArrayLike, **kwargs) -> ArrayLike:
        """Push the solution based on a condition."""

    @abc.abstractmethod
    def _apply_forward(self, x: ArrayLike) -> ArrayLike:
        """Apply the transport matrix in the forward direction."""

    @property
    @abc.abstractmethod
    def shape(self) -> tuple[int, int]:
        """Shape of the problem."""

    @property
    @abc.abstractmethod
    def converged(self) -> bool:
        """Whether the solver converged."""

    @abc.abstractmethod
    def to(self: BaseSolverOutput, device: Optional[Device_t] = None) -> BaseSolverOutput:
        """Transfer self to another compute device.

        Parameters
        ----------
        device
            Device where to transfer the solver output. If :obj:`None`, use the default device.

        Returns
        -------
        Self transferred to the ``device``.
        """

    def _format_params(self, fmt: Callable[[Any], str]) -> str:
        # subclasses override this to expose more fields in repr/str
        pieces = [f"shape={fmt(self.shape)}"]
        return ", ".join(pieces)

    def __repr__(self) -> str:
        return f"{type(self).__name__}[{self._format_params(repr)}]"

    def __str__(self) -> str:
        return f"{type(self).__name__}[{self._format_params(str)}]"
+
+
class BaseDiscreteSolverOutput(BaseSolverOutput, abc.ABC):
    """Base class for all discrete solver outputs."""

    @abc.abstractmethod
    def _apply(self, x: ArrayLike, *, forward: bool) -> ArrayLike:
        """Apply :attr:`transport_matrix` to an array of shape ``[n, d]`` or ``[m, d]``."""

    @property
    @abc.abstractmethod
    def transport_matrix(self) -> ArrayLike:
        """Transport matrix of shape ``[n, m]``."""

    @property
    @abc.abstractmethod
    def cost(self) -> float:
        """Regularized :term:`OT` cost."""

    @property
    @abc.abstractmethod
    def potentials(self) -> Optional[tuple[ArrayLike, ArrayLike]]:
        """:term:`Dual potentials` :math:`f` and :math:`g`.

        Only valid for the :term:`Sinkhorn` algorithm.
        """

    @property
    @abc.abstractmethod
    def shape(self) -> tuple[int, int]:
        """Shape of the :attr:`transport_matrix`."""

    @property
    @abc.abstractmethod
    def is_linear(self) -> bool:
        """Whether the output is a solution to a :term:`linear problem`."""

    @property
    def rank(self) -> int:
        """Rank of the :attr:`transport_matrix`; ``-1`` if the output is not low-rank."""
        return -1

    @property
    def is_low_rank(self) -> bool:
        """Whether the :attr:`transport_matrix` is :term:`low-rank`."""
        return self.rank > -1

    @abc.abstractmethod
    def _ones(self, n: int) -> ArrayLike:
        """Generate vector of 1s of shape ``[n,]``."""

    def push(self, x: ArrayLike, scale_by_marginals: bool = False) -> ArrayLike:
        """Push mass through the :attr:`transport_matrix`.

        It is equivalent to :math:`T^T x` but without instantiating the transport matrix :math:`T`, if possible.

        Parameters
        ----------
        x
            Array of shape ``[n,]`` or ``[n, d]`` to push.
        scale_by_marginals
            Whether to scale by the source marginals :attr:`a`.

        Returns
        -------
        Array of shape ``[m,]`` or ``[m, d]``, depending on the shape of ``x``.

        Raises
        ------
        ValueError
            If ``x`` is not 1- or 2-dimensional, or its leading dimension does not match ``shape[0]``.
        """
        if x.ndim not in (1, 2):
            raise ValueError(f"Expected 1D or 2D array, found `{x.ndim}`.")
        if x.shape[0] != self.shape[0]:
            raise ValueError(f"Expected array to have shape `({self.shape[0]}, ...)`, found `{x.shape}`.")
        if scale_by_marginals:
            x = self._scale_by_marginals(x, forward=True)
        return self._apply(x, forward=True)

    def pull(self, x: ArrayLike, scale_by_marginals: bool = False) -> ArrayLike:
        """Pull mass through the :attr:`transport_matrix`.

        It is equivalent to :math:`T x` but without instantiating the transport matrix :math:`T`, if possible.

        Parameters
        ----------
        x
            Array of shape ``[m,]`` or ``[m, d]`` to pull.
        scale_by_marginals
            Whether to scale by the target marginals :attr:`b`.

        Returns
        -------
        Array of shape ``[n,]`` or ``[n, d]``, depending on the shape of ``x``.

        Raises
        ------
        ValueError
            If ``x`` is not 1- or 2-dimensional, or its leading dimension does not match ``shape[1]``.
        """
        if x.ndim not in (1, 2):
            raise ValueError(f"Expected 1D or 2D array, found `{x.ndim}`.")
        if x.shape[0] != self.shape[1]:
            raise ValueError(f"Expected array to have shape `({self.shape[1]}, ...)`, found `{x.shape}`.")
        if scale_by_marginals:
            x = self._scale_by_marginals(x, forward=False)
        return self._apply(x, forward=False)

    def as_linear_operator(self, scale_by_marginals: bool = False) -> LinearOperator:
        """Transform :attr:`transport_matrix` into a linear operator.

        Parameters
        ----------
        scale_by_marginals
            Whether to scale by :term:`marginals`.

        Returns
        -------
        The :attr:`transport_matrix` as a linear operator.
        """
        push = functools.partial(self.push, scale_by_marginals=scale_by_marginals)
        pull = functools.partial(self.pull, scale_by_marginals=scale_by_marginals)
        # push: a @ X (rmatvec)
        # pull: X @ a (matvec)
        return LinearOperator(shape=self.shape, dtype=self.dtype, matvec=pull, rmatvec=push)

    def chain(self, outputs: Iterable[BaseDiscreteSolverOutput], scale_by_marginals: bool = False) -> LinearOperator:
        """Chain subsequent applications of :attr:`transport_matrix`.

        Parameters
        ----------
        outputs
            Sequence of transport matrices to chain.
        scale_by_marginals
            Whether to scale by :term:`marginals`.

        Returns
        -------
        The chained transport matrices as a linear operator.
        """
        # `*` on LinearOperator composes the operators (matrix product)
        op = self.as_linear_operator(scale_by_marginals)
        for out in outputs:
            op *= out.as_linear_operator(scale_by_marginals)

        return op

    def sparsify(
        self,
        mode: Literal["threshold", "percentile", "min_row"],
        value: Optional[float] = None,
        batch_size: int = 1024,
        n_samples: Optional[int] = None,
        seed: Optional[int] = None,
    ) -> MatrixSolverOutput:
        """Sparsify the :attr:`transport_matrix`.

        This function sets all entries of the transport matrix below a certain threshold to :math:`0` and
        returns a :class:`~moscot.base.output.MatrixSolverOutput` with sparsified transport matrix stored
        as a :class:`~scipy.sparse.csr_matrix`.

        .. warning::
            This function only serves for interfacing software which has to instantiate the transport matrix,
            :mod:`moscot` never uses the sparsified transport matrix.

        Parameters
        ----------
        mode
            How to determine the value below which entries are set to :math:`0`. Valid options are:

            - `'threshold'` - ``value`` is the threshold below which entries are set to :math:`0`.
            - `'percentile'` - ``value`` is the percentile in :math:`[0, 100]` of the :attr:`transport_matrix`.
              below which entries are set to :math:`0`.
            - `'min_row'` - ``value`` is not used, it is chosen such that each row has at least 1 non-zero entry.
        value
            Value to use for sparsification.
        batch_size
            How many rows to materialize when sparsifying the :attr:`transport_matrix`.
        n_samples
            If ``mode = 'percentile'``, determine the number of samples based on which the percentile is computed
            stochastically. Note this means that a matrix of shape `[n_samples, min(transport_matrix.shape)]`
            has to be instantiated. If `None`, ``n_samples`` is set to ``batch_size``.
        seed
            Random seed needed for sampling if ``mode = 'percentile'``.

        Returns
        -------
        Output with sparsified transport matrix.

        Raises
        ------
        ValueError
            If ``value`` is :obj:`None` for modes that require it.
        NotImplementedError
            If ``mode`` is not one of the valid options.
        """
        n, m = self.shape
        if mode == "threshold":
            if value is None:
                # the threshold comes from the `value` argument
                raise ValueError("If `mode = 'threshold'`, `value` cannot be `None`.")
            thr = value
        elif mode == "percentile":
            if value is None:
                raise ValueError("If `mode = 'percentile'`, `value` cannot be `None`.")
            rng = np.random.RandomState(seed=seed)
            n_samples = n_samples if n_samples is not None else batch_size
            # NOTE(review): `k` is capped by `n` while columns are sampled from `m`;
            # looks intentional per the docstring (`min(transport_matrix.shape)`) — confirm.
            k = min(n_samples, n)
            x = np.zeros((m, k))
            rows = rng.choice(m, size=k)
            x[rows, np.arange(k)] = 1.0
            res = self.pull(x, scale_by_marginals=False)  # tmap @ indicator_vectors
            thr = np.percentile(res, value)
        elif mode == "min_row":
            # smallest per-row maximum over all columns guarantees >=1 non-zero per row
            thr = np.inf
            for batch in range(0, m, batch_size):
                x = np.eye(m, min(batch_size, m - batch), -(min(batch, m)))
                res = self.pull(x, scale_by_marginals=False)  # tmap @ indicator_vectors
                thr = min(thr, float(res.max(axis=1).min()))
        else:
            raise NotImplementedError(f"Mode `{mode}` is not yet implemented.")

        # materialize along the smaller dimension to keep batches cheap;
        # `push` yields columns of T^T (rows of T), `pull` yields columns of T
        k, func, fn_stack = (n, self.push, sp.vstack) if n < m else (m, self.pull, sp.hstack)
        tmaps_sparse: list[sp.csr_matrix] = []

        for batch in range(0, k, batch_size):
            x = np.eye(k, min(batch_size, k - batch), -(min(batch, k)), dtype=float)
            res = np.array(func(x, scale_by_marginals=False))
            res[res < thr] = 0.0
            tmaps_sparse.append(sp.csr_matrix(res.T if n < m else res))

        return MatrixSolverOutput(
            transport_matrix=fn_stack(tmaps_sparse), cost=self.cost, converged=self.converged, is_linear=self.is_linear
        )

    @property
    def a(self) -> ArrayLike:
        """:term:`Marginals` of the source distribution.

        If the output of an :term:`unbalanced OT problem`, these are the posterior marginals.
        """
        return self.pull(self._ones(self.shape[1]))

    @property
    def b(self) -> ArrayLike:
        """:term:`Marginals` of the target distribution.

        If the output of an :term:`unbalanced OT problem`, these are the posterior marginals.
        """
        return self.push(self._ones(self.shape[0]))

    @property
    def dtype(self) -> DTypeLike:
        """Underlying data type."""
        return self.a.dtype

    def _format_params(self, fmt: Callable[[Any], str]) -> str:
        params = {"shape": self.shape, "cost": round(self.cost, 4), "converged": self.converged}
        return ", ".join(f"{name}={fmt(val)}" for name, val in params.items())

    def _scale_by_marginals(self, x: ArrayLike, *, forward: bool, eps: float = 1e-12) -> ArrayLike:
        # alt. we could use the public push/pull
        marginals = self.a if forward else self.b
        if x.ndim == 2:
            marginals = marginals[:, None]
        # eps guards against division by zero for empty marginals
        return x / (marginals + eps)

    def __bool__(self) -> bool:
        return self.converged
+
+
class MatrixSolverOutput(BaseDiscreteSolverOutput):
    """:term:`OT` solution with a materialized transport matrix.

    Parameters
    ----------
    transport_matrix
        Transport matrix of shape ``[n, m]``.
    cost
        Cost of an :term:`OT` problem.
    converged
        Whether the solution converged.
    is_linear
        Whether this is a solution to a :term:`linear problem`.
    """

    # TODO(michalk8): don't provide defaults?
    def __init__(
        self,
        transport_matrix: Union[ArrayLike, sp.spmatrix],
        *,
        cost: float = np.nan,
        converged: bool = True,
        is_linear: bool = True,
    ):
        super().__init__()
        self._transport_matrix = transport_matrix
        self._cost = cost
        self._converged = converged
        self._is_linear = is_linear

    def _apply(self, x: ArrayLike, *, forward: bool) -> ArrayLike:
        # forward: T^T @ x, backward: T @ x
        matrix = self.transport_matrix.T if forward else self.transport_matrix
        return matrix @ x

    def _apply_forward(self, x: ArrayLike) -> ArrayLike:
        return self._apply(x, forward=True)

    @property
    def transport_matrix(self) -> ArrayLike:
        """The materialized transport matrix."""
        return self._transport_matrix

    @property
    def shape(self) -> tuple[int, ...]:
        """Shape of the :attr:`transport_matrix`."""
        return self.transport_matrix.shape

    def to(
        self, device: Optional[Device_t] = None, dtype: Optional[DTypeLike] = None
    ) -> BaseDiscreteSolverOutput:
        """Return a copy cast to ``dtype``; the ``device`` argument is ignored."""
        if device is not None:
            logger.warning(f"`{self!r}` does not support the `device` argument, ignoring.")
        if dtype is None:
            return self

        # shallow copy so the original output keeps its matrix untouched
        clone = copy.copy(self)
        clone._transport_matrix = clone.transport_matrix.astype(dtype)
        return clone

    @property
    def cost(self) -> float:
        """Cost of the :term:`OT` problem."""
        return self._cost

    @property
    def converged(self) -> bool:
        """Whether the solution converged."""
        return self._converged

    @property
    def potentials(self) -> Optional[tuple[ArrayLike, ArrayLike]]:
        """Dual potentials; not available for a materialized matrix, always :obj:`None`."""
        return None

    @property
    def is_linear(self) -> bool:
        """Whether this is a solution to a :term:`linear problem`."""
        return self._is_linear

    def _ones(self, n: int) -> ArrayLike:
        matrix = self.transport_matrix
        if isinstance(matrix, np.ndarray):
            return np.ones((n,), dtype=matrix.dtype)

        # non-NumPy backing storage is handled via JAX
        import jax.numpy as jnp

        return jnp.ones((n,), dtype=matrix.dtype)
+
+
class BaseNeuralOutput(BaseSolverOutput, abc.ABC):
    """Base class for outputs of neural :term:`OT` solvers."""

    @abstractmethod
    def project_to_transport_matrix(
        self,
        source: Optional[ArrayLike] = None,
        target: Optional[ArrayLike] = None,
        condition: Optional[ArrayLike] = None,
        save_transport_matrix: bool = False,
        batch_size: int = 1024,
        k: int = 30,
        length_scale: Optional[float] = None,
        seed: int = 42,
    ) -> sp.csr_matrix:
        """Project the learned (continuous) solution onto a discrete transport matrix.

        Parameters
        ----------
        source
            Source data; semantics defined by the implementing subclass.
        target
            Target data; semantics defined by the implementing subclass.
        condition
            Optional conditioning data passed to conditional solvers.
        save_transport_matrix
            Whether the implementation should cache the projected matrix.
        batch_size
            Number of rows to materialize per batch during projection.
        k
            Number of neighbors/entries used per row; exact meaning is
            implementation-specific — see the concrete subclass.
        length_scale
            Optional kernel length scale; if :obj:`None`, implementations choose one.
        seed
            Random seed used by stochastic projection steps.

        Returns
        -------
        The projected transport matrix as a :class:`~scipy.sparse.csr_matrix`.
        """