Diff of /mowgli/utils.py [000000] .. [061d85]

--- a
+++ b/mowgli/utils.py
@@ -0,0 +1,199 @@
+from typing import Iterable, List
+import torch
+from scipy.spatial.distance import cdist
+import numpy as np
+
+
+def reference_dataset(
+    X, dtype: torch.dtype, device: torch.device, keep_idx: Iterable
+) -> torch.Tensor:
+    """Select features, transpose dataset and convert to Tensor.
+
+    Args:
+        X (array-like): The input data
+        dtype (torch.dtype): The dtype to create
+        device (torch.device): The device to create on
+        keep_idx (Iterable): The variables to keep.
+
+    Returns:
+        torch.Tensor: The reference dataset A.
+    """
+
+    # Keep only the highly variable features.
+    A = X[:, keep_idx].T
+
+    # Check that the dataset is nonnegative.
+    assert (A >= 0).all()
+
+    # If the dataset is sparse, make it dense.
+    try:
+        A = A.todense()
+    except AttributeError:
+        # `A` is already dense.
+        pass
+
+    # Send the matrix `A` to PyTorch.
+    return torch.from_numpy(A).to(device=device, dtype=dtype).contiguous()
+
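+# Usage sketch for `reference_dataset` (hypothetical example, assuming a dense
+# NumPy input; sparse inputs are densified inside the function):
+#
+#     X = np.random.rand(100, 50)  # cells x features
+#     keep_idx = np.arange(20)     # keep the first 20 features
+#     A = reference_dataset(X, torch.float32, torch.device("cpu"), keep_idx)
+#     A.shape                      # torch.Size([20, 100])
+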
+
+def compute_ground_cost(
+    features,
+    cost: str,
+    eps: float,
+    force_recompute: bool,
+    cost_path: str,
+    dtype: torch.dtype,
+    device: torch.device,
+) -> torch.Tensor:
+    """Compute the ground cost (not lazily!)
+
+    Args:
+        features (array-like): A array with the features to compute the cost on.
+        cost (str): The function to compute the cost. Scipy distances are allowed.
+        force_recompute (bool): Recompute even is there is already a cost matrix saved at the provided path.
+        cost_path (str): Where to look for or where to save the cost.
+        dtype (torch.dtype): The dtype for the output.
+        device (torch.device): The device for the ouput.
+
+    Returns:
+        torch.Tensor: The ground cost
+    """
+
+    # Initialize the `recomputed` variable.
+    recomputed = False
+
+    # If we force recomputing, then compute the ground cost.
+    if force_recompute:
+        if cost == "ones":
+            K = 1 - np.eye(features.shape[0])
+        else:
+            K = cdist(features, features, metric=cost)
+        recomputed = True
+
+    # If the cost is not yet computed, try to load it or compute it.
+    if not recomputed:
+        try:
+            K = np.load(cost_path)
+        except (OSError, TypeError, ValueError):
+            # The cost could not be loaded, so compute it.
+            if cost == "ones":
+                K = 1 - np.eye(features.shape[0])
+            else:
+                K = cdist(features, features, metric=cost)
+            recomputed = True
+
+    # If we did recompute the cost, save it.
+    if recomputed and cost_path:
+        np.save(cost_path, K)
+
+    # Send the cost to PyTorch and normalize it by `eps` times its maximum.
+    K = torch.from_numpy(K).to(device=device, dtype=dtype)
+    K /= eps * K.max()
+
+    # Compute the kernel K = exp(-C / (eps * max(C))).
+    K = torch.exp(-K)
+
+    return K
+
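+# Usage sketch for `compute_ground_cost` (hypothetical example; an empty
+# `cost_path` skips loading and saving):
+#
+#     feats = np.random.rand(50, 10)
+#     K = compute_ground_cost(
+#         feats, cost="euclidean", eps=0.1, force_recompute=True,
+#         cost_path="", dtype=torch.float64, device=torch.device("cpu"),
+#     )
+#     K.shape  # torch.Size([50, 50]), entries in (0, 1]
+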
+
+def normalize_tensor(X: torch.Tensor) -> torch.Tensor:
+    """Normalize a tensor along columns
+
+    Args:
+        X (torch.Tensor): The tensor to normalize.
+
+    Returns:
+        torch.Tensor: The normalized tensor.
+    """
+    return X / X.sum(0)
+
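+# Usage sketch for `normalize_tensor`: each column of the result sums to one.
+#
+#     X = torch.rand(5, 3)
+#     normalize_tensor(X).sum(0)  # tensor([1., 1., 1.])
+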
+
+def entropy(
+    X: torch.Tensor, min_one: bool = False, rescale: bool = False
+) -> torch.Tensor:
+    """Entropy function, :math:`E(X) = \langle X, \log X - 1 \rangle`.
+
+    Args:
+        X (torch.Tensor):
+            The parameter to compute the entropy of.
+        min_one (bool, optional):
+            Whether to include the :math:`-1` in the formula. Defaults to False.
+        rescale (bool, optional):
+            Rescale so that the value is between 0 and 1 (when min_one=False). Defaults to False.
+
+    Returns:
+        torch.Tensor: The entropy of X.
+    """
+    offset = 1 if min_one else 0
+    scale = X.shape[1] * np.log(X.shape[0]) if rescale else 1
+    return -torch.sum(X * (torch.nan_to_num(X.log()) - offset)) / scale
+
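+# Usage sketch for `entropy`: with `rescale=True`, uniform columns give the
+# maximal value 1 (here, 4 uniform columns over 10 rows).
+#
+#     X = torch.full((10, 4), 0.1)
+#     entropy(X, min_one=False, rescale=True)  # tensor(1.), up to precision
+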
+
+def entropy_dual_loss(Y: torch.Tensor) -> torch.Tensor:
+    """Compute the Legendre dual of the entropy.
+
+    Args:
+        Y (torch.Tensor): The input parameter.
+
+    Returns:
+        torch.Tensor: The loss.
+    """
+    return -torch.logsumexp(Y, dim=0).sum()
+
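+# Usage sketch for `entropy_dual_loss`: when the columns of exp(Y) sum to one,
+# each log-sum-exp term is zero, so the loss vanishes.
+#
+#     p = torch.softmax(torch.randn(10, 3), dim=0)
+#     entropy_dual_loss(p.log())  # ~ tensor(0.)
+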
+
+def ot_dual_loss(
+    A: torch.Tensor,
+    G: torch.Tensor,
+    K: torch.Tensor,
+    eps: float,
+    mod_weights: torch.Tensor,
+    dim=(0, 1),
+) -> torch.Tensor:
+    """Compute the Legendre dual of the entropic OT loss.
+
+    Args:
+        A (torch.Tensor): The input data.
+        G (torch.Tensor): The dual variable.
+        K (torch.Tensor): The kernel.
+        eps (float): The entropic regularization parameter.
+        mod_weights (torch.Tensor): The weights per cell and modality.
+        dim (tuple, optional): The dimensions to sum the loss over. Defaults to (0, 1).
+
+    Returns:
+        torch.Tensor: The loss
+    """
+
+    # The scaling variable in log-domain: log f(G) = G / eps.
+    log_fG = G / eps
+
+    # Compute log(K @ exp(G / eps)), stabilized with the log-sum-exp trick.
+    scale = log_fG.max(0).values
+    prod = torch.log(K @ torch.exp(log_fG - scale)) + scale
+
+    # Compute the dot product with A.
+    return eps * torch.sum(mod_weights * A * prod, dim=dim)
+
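+# Usage sketch for `ot_dual_loss` (hypothetical shapes: n features, m cells):
+#
+#     n, m, eps = 10, 4, 0.1
+#     A = torch.softmax(torch.randn(n, m), dim=0)  # columns sum to one
+#     G = torch.randn(n, m)                        # dual variable
+#     C = torch.rand(n, n)
+#     K = torch.exp(-C / (eps * C.max()))          # kernel, as above
+#     w = torch.ones(m) / m                        # per-cell weights
+#     ot_dual_loss(A, G, K, eps, w)                # scalar tensor
+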
+
+def early_stop(history: List, tol: float, nonincreasing: bool = False) -> bool:
+    """Based on a history and a tolerance, whether to stop early or not.
+
+    Args:
+        history (List):
+            The loss history.
+        tol (float):
+            The tolerance before early stopping.
+        nonincreasing (bool, optional):
+            When True, stop early if the loss increases by more than `tol`. Defaults to False.
+
+    Raises:
+        ValueError: When the loss is not finite.
+
+    Returns:
+        bool: Whether to stop early.
+    """
+    # If the latest loss is NaN or infinite, raise an error.
+    if len(history) > 0 and not torch.isfinite(history[-1]):
+        raise ValueError("Error: Loss is not finite!")
+
+    # If the history is too short, continue.
+    if len(history) < 3:
+        return False
+
+    # If the loss went up by more than `tol` compared to two steps ago, stop.
+    if nonincreasing and (history[-1] - history[-3]) > tol:
+        return True
+
+    # If the last two values are within `tol` of each other, stop.
+    if abs(history[-1] - history[-2]) < tol:
+        return True
+
+    # Otherwise, keep on going.
+    return False
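+
+
+# Usage sketch for `early_stop`: stop once two successive losses are within
+# `tol` of each other.
+#
+#     history = [torch.tensor(3.0), torch.tensor(2.0), torch.tensor(1.999)]
+#     early_stop(history, tol=1e-2)  # True, since |1.999 - 2.0| < 1e-2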