--- /dev/null
+++ b/mowgli/utils.py
@@ -0,0 +1,199 @@
+from typing import Iterable, List
+
+import numpy as np
+import torch
+from scipy.spatial.distance import cdist
+
+
+def reference_dataset(
+    X, dtype: torch.dtype, device: torch.device, keep_idx: Iterable
+) -> torch.Tensor:
+    """Select features, transpose the dataset, and convert it to a Tensor.
+
+    Args:
+        X (array-like): The input data.
+        dtype (torch.dtype): The dtype of the output tensor.
+        device (torch.device): The device to create the tensor on.
+        keep_idx (Iterable): The variables to keep.
+
+    Returns:
+        torch.Tensor: The reference dataset A.
+    """
+
+    # Keep only the highly variable features.
+    A = X[:, keep_idx].T
+
+    # If the dataset is sparse, make it dense. Densify first, since
+    # sparse matrices do not support the `.all()` check below.
+    try:
+        A = A.todense()
+    except AttributeError:
+        pass
+
+    # Check that the dataset is nonnegative.
+    assert (A >= 0).all()
+
+    # Send the matrix `A` to PyTorch.
+    return torch.from_numpy(A).to(device=device, dtype=dtype).contiguous()
+
+
+def compute_ground_cost(
+    features,
+    cost: str,
+    eps: float,
+    force_recompute: bool,
+    cost_path: str,
+    dtype: torch.dtype,
+    device: torch.device,
+) -> torch.Tensor:
+    """Compute the ground cost (not lazily!).
+
+    Args:
+        features (array-like): An array with the features to compute the cost on.
+        cost (str): The function to compute the cost. SciPy distances are
+            allowed, as well as "ones" for a uniform off-diagonal cost.
+        eps (float): The entropic regularization parameter, used to scale the cost.
+        force_recompute (bool): Recompute even if there is already a cost
+            matrix saved at the provided path.
+        cost_path (str): Where to look for, or where to save, the cost.
+        dtype (torch.dtype): The dtype for the output.
+        device (torch.device): The device for the output.
+
+    Returns:
+        torch.Tensor: The ground cost.
+    """
+
+    # Initialize the `recomputed` variable.
+    recomputed = False
+
+    # If we force recomputing, then compute the ground cost.
+    if force_recompute:
+        K = cdist(features, features, metric=cost)
+        recomputed = True
+
+    # If the cost is not yet computed, try to load it or compute it.
+    if not recomputed:
+        try:
+            K = np.load(cost_path)
+        except (FileNotFoundError, TypeError):
+            if cost == "ones":
+                K = 1 - np.eye(features.shape[0])
+            else:
+                K = cdist(features, features, metric=cost)
+            recomputed = True
+
+    # If we did recompute the cost, save it.
+    if recomputed and cost_path:
+        np.save(cost_path, K)
+
+    # Scale the cost by `eps` times its maximum.
+    K = torch.from_numpy(K).to(device=device, dtype=dtype)
+    K /= eps * K.max()
+
+    # Compute the kernel K = exp(-C / (eps * max(C))).
+    K = torch.exp(-K).to(device=device, dtype=dtype)
+
+    return K
+
+
+def normalize_tensor(X: torch.Tensor) -> torch.Tensor:
+    """Normalize a tensor so that each column sums to one.
+
+    Args:
+        X (torch.Tensor): The tensor to normalize.
+
+    Returns:
+        torch.Tensor: The normalized tensor.
+    """
+    return X / X.sum(0)
+
+
+def entropy(
+    X: torch.Tensor, min_one: bool = False, rescale: bool = False
+) -> torch.Tensor:
+    r"""Entropy function, :math:`E(X) = -\langle X, \log X - 1 \rangle`.
+
+    Args:
+        X (torch.Tensor):
+            The parameter to compute the entropy of.
+        min_one (bool, optional):
+            Whether to include the :math:`-1` in the formula. Defaults to False.
+        rescale (bool, optional):
+            Rescale so that the value is between 0 and 1 (when min_one=False).
+            Defaults to False.
+
+    Returns:
+        torch.Tensor: The entropy of X.
+    """
+    offset = 1 if min_one else 0
+    scale = X.shape[1] * np.log(X.shape[0]) if rescale else 1
+    return -torch.sum(X * (torch.nan_to_num(X.log()) - offset)) / scale
+
+
+def entropy_dual_loss(Y: torch.Tensor) -> torch.Tensor:
+    """Compute the Legendre dual of the entropy.
+
+    Args:
+        Y (torch.Tensor): The input parameter.
+
+    Returns:
+        torch.Tensor: The loss.
+    """
+    return -torch.logsumexp(Y, dim=0).sum()
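+
+
+# A minimal usage sketch for the helpers above (illustrative only: the
+# shapes, the `eps` value, and the `cost_path=None` call are assumptions,
+# not part of the Mowgli pipeline). `compute_ground_cost` returns the
+# Gibbs kernel K = exp(-C / (eps * max(C))) of the pairwise cost C:
+#
+#   >>> features = np.random.default_rng(0).random((50, 10))
+#   >>> K = compute_ground_cost(
+#   ...     features, cost="cosine", eps=5e-2, force_recompute=True,
+#   ...     cost_path=None, dtype=torch.double, device="cpu",
+#   ... )
+#   >>> X = normalize_tensor(torch.rand(50, 20, dtype=torch.double))
+#   >>> entropy(X, min_one=True)  # a scalar tensor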
+ """ + return -torch.logsumexp(Y, dim=0).sum() + + +def ot_dual_loss( + A: dict, G: dict, K: dict, eps: float, mod_weights: torch.Tensor, dim=(0, 1) +) -> torch.Tensor: + """Compute the Legendre dual of the entropic OT loss. + + Args: + A (dict): The input data. + G (dict): The dual variable. + K (dict): The kernel. + eps (float): The entropic regularization. + mod_weights (torch.Tensor): The weights per cell and modality. + dim (tuple, optional): How to sum the loss. Defaults to (0, 1). + + Returns: + torch.Tensor: The loss + """ + + log_fG = G / eps + + # Compute the non stabilized product. + scale = log_fG.max(0).values + prod = torch.log(K @ torch.exp(log_fG - scale)) + scale + + # Compute the dot product with A. + return eps * torch.sum(mod_weights * A * prod, dim=dim) + + +def early_stop(history: List, tol: float, nonincreasing: bool = False) -> bool: + """Based on a history and a tolerance, whether to stop early or not. + + Args: + history (List): + The loss history. + tol (float): + The tolerance before early stopping. + nonincreasing (bool, optional): + When False, throws an error if the loss goes up. Defaults to False. + + Raises: + ValueError: When the loss goes up. + + Returns: + bool: Whether to stop early. + """ + # If we have a nan or infinite, die. + if len(history) > 0 and not torch.isfinite(history[-1]): + raise ValueError("Error: Loss is not finite!") + + # If the history is too short, continue. + if len(history) < 3: + return False + + # If the next value is worse, stop (not normal!). + if nonincreasing and (history[-1] - history[-3]) > tol: + return True + + # If the next value is close enough, stop. + if abs(history[-1] - history[-2]) < tol: + return True + + # Otherwise, keep on going. + return False