# mowgli/utils.py

from typing import Iterable, List

import numpy as np
import torch
from scipy.spatial.distance import cdist

def reference_dataset(
    X, dtype: torch.dtype, device: torch.device, keep_idx: Iterable
) -> torch.Tensor:
    """Select features, transpose the dataset and convert it to a Tensor.

    Args:
        X (array-like): The input data.
        dtype (torch.dtype): The dtype of the output Tensor.
        device (torch.device): The device to create the Tensor on.
        keep_idx (Iterable): The indices of the variables to keep.

    Returns:
        torch.Tensor: The reference dataset A.
    """

    # Keep only the highly variable features.
    A = X[:, keep_idx].T

    # Check that the dataset is nonnegative.
    assert (A >= 0).all()

    # If the dataset is sparse, make it dense.
    try:
        A = A.todense()
    except AttributeError:
        # `A` is already dense.
        pass

    # Send the matrix `A` to PyTorch.
    return torch.from_numpy(A).to(device=device, dtype=dtype).contiguous()

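# A minimal usage sketch for `reference_dataset` (the names `counts` and
# `keep_idx` below are illustrative assumptions, not part of this module).
# A dense 5-cells-by-4-features matrix is reduced to 2 features and transposed:
#
#     counts = np.abs(np.random.randn(5, 4))
#     keep_idx = np.array([0, 2])
#     A = reference_dataset(
#         counts, dtype=torch.float64, device=torch.device("cpu"),
#         keep_idx=keep_idx,
#     )
#     assert A.shape == (2, 5)  # (features kept, cells)
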
def compute_ground_cost(
    features,
    cost: str,
    eps: float,
    force_recompute: bool,
    cost_path: str,
    dtype: torch.dtype,
    device: torch.device,
) -> torch.Tensor:
    """Compute the ground cost (not lazily!).

    Args:
        features (array-like): An array with the features to compute the cost on.
        cost (str): The function to compute the cost. SciPy distances are allowed.
        eps (float): The entropic regularization, used to rescale the cost.
        force_recompute (bool): Recompute even if there is already a cost matrix saved at the provided path.
        cost_path (str): Where to look for, or where to save, the cost.
        dtype (torch.dtype): The dtype for the output.
        device (torch.device): The device for the output.

    Returns:
        torch.Tensor: The ground cost.
    """

    # Initialize the `recomputed` variable.
    recomputed = False

    # If we force recomputing, then compute the ground cost.
    if force_recompute:
        K = cdist(features, features, metric=cost)
        recomputed = True

    # If the cost is not yet computed, try to load it or compute it.
    if not recomputed:
        try:
            K = np.load(cost_path)
        except (FileNotFoundError, ValueError, TypeError):
            if cost == "ones":
                K = 1 - np.eye(features.shape[0])
            else:
                K = cdist(features, features, metric=cost)
            recomputed = True

    # If we did recompute the cost, save it.
    if recomputed and cost_path:
        np.save(cost_path, K)

    # Rescale the cost by its maximum and the entropic regularization.
    K = torch.from_numpy(K).to(device=device, dtype=dtype)
    K /= eps * K.max()

    # Compute the kernel K.
    K = torch.exp(-K).to(device=device, dtype=dtype)

    return K

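# A hedged sketch of calling `compute_ground_cost` (the file name "cost.npy"
# is an illustrative assumption). With `force_recompute=True`, the SciPy
# "correlation" metric is evaluated on the features and the kernel
# exp(-C / (eps * max C)) is returned:
#
#     features = np.random.rand(100, 16)
#     K = compute_ground_cost(
#         features, cost="correlation", eps=0.1, force_recompute=True,
#         cost_path="cost.npy", dtype=torch.float64,
#         device=torch.device("cpu"),
#     )
#     assert K.shape == (100, 100)
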
def normalize_tensor(X: torch.Tensor) -> torch.Tensor:
    """Normalize a tensor along its columns, so that each column sums to one.

    Args:
        X (torch.Tensor): The tensor to normalize.

    Returns:
        torch.Tensor: The normalized tensor.
    """
    return X / X.sum(0)

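# For instance (illustrative values), a 2-by-2 tensor normalized so that each
# column sums to one:
#
#     X = torch.tensor([[1.0, 3.0], [1.0, 1.0]])
#     normalize_tensor(X)  # tensor([[0.5, 0.75], [0.5, 0.25]])
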
def entropy(
    X: torch.Tensor, min_one: bool = False, rescale: bool = False
) -> torch.Tensor:
    r"""Entropy function, :math:`E(X) = \langle X, \log X - 1 \rangle`.

    Args:
        X (torch.Tensor):
            The parameter to compute the entropy of.
        min_one (bool, optional):
            Whether to include the :math:`-1` in the formula. Defaults to False.
        rescale (bool, optional):
            Rescale so that the value is between 0 and 1 (when min_one=False).
            Defaults to False.

    Returns:
        torch.Tensor: The entropy of X.
    """
    offset = 1 if min_one else 0
    scale = X.shape[1] * np.log(X.shape[0]) if rescale else 1
    return -torch.sum(X * (torch.nan_to_num(X.log()) - offset)) / scale

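# A quick sanity check (illustrative): for a column-normalized uniform tensor,
# -<X, log X> equals n_cols * log(n_rows), so `rescale=True` yields 1.
#
#     X = torch.full((4, 3), 0.25)  # each of the 3 columns sums to one
#     entropy(X, rescale=True)      # tensor(1.)
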
def entropy_dual_loss(Y: torch.Tensor) -> torch.Tensor:
    """Compute the Legendre dual of the entropy.

    Args:
        Y (torch.Tensor): The input parameter.

    Returns:
        torch.Tensor: The loss.
    """
    return -torch.logsumexp(Y, dim=0).sum()

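# For instance (illustrative): with Y = 0, logsumexp over the rows is
# log(n_rows) in every column, so the loss is -n_cols * log(n_rows).
#
#     Y = torch.zeros(4, 3)
#     entropy_dual_loss(Y)  # tensor(-4.1589) == -3 * log(4)
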
def ot_dual_loss(
    A: torch.Tensor,
    G: torch.Tensor,
    K: torch.Tensor,
    eps: float,
    mod_weights: torch.Tensor,
    dim=(0, 1),
) -> torch.Tensor:
    """Compute the Legendre dual of the entropic OT loss.

    Args:
        A (torch.Tensor): The input data.
        G (torch.Tensor): The dual variable.
        K (torch.Tensor): The kernel.
        eps (float): The entropic regularization.
        mod_weights (torch.Tensor): The weights per cell and modality.
        dim (tuple, optional): How to sum the loss. Defaults to (0, 1).

    Returns:
        torch.Tensor: The loss.
    """

    log_fG = G / eps

    # Compute log(K @ exp(G / eps)) in a numerically stable way, shifting by
    # the columnwise maximum before exponentiating.
    scale = log_fG.max(0).values
    prod = torch.log(K @ torch.exp(log_fG - scale)) + scale

    # Compute the dot product with A.
    return eps * torch.sum(mod_weights * A * prod, dim=dim)

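# A small consistency check (illustrative shapes): with K = I the stabilized
# product reduces to G / eps, so the loss is simply sum(mod_weights * A * G).
#
#     n, m = 6, 5
#     A = torch.rand(n, m)
#     G = torch.randn(n, m)
#     w = torch.ones(n, m)
#     loss = ot_dual_loss(A, G, torch.eye(n), eps=0.5, mod_weights=w)
#     assert torch.allclose(loss, (w * A * G).sum())
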
def early_stop(history: List, tol: float, nonincreasing: bool = False) -> bool:
    """Based on a history and a tolerance, decide whether to stop early or not.

    Args:
        history (List):
            The loss history.
        tol (float):
            The tolerance before early stopping.
        nonincreasing (bool, optional):
            When True, also stop if the loss increases by more than `tol`.
            Defaults to False.

    Raises:
        ValueError: When the loss is not finite.

    Returns:
        bool: Whether to stop early.
    """
    # If we have a NaN or infinite loss, die.
    if len(history) > 0 and not torch.isfinite(history[-1]):
        raise ValueError("Error: Loss is not finite!")

    # If the history is too short, continue.
    if len(history) < 3:
        return False

    # If the loss went up by more than the tolerance, stop (not normal!).
    if nonincreasing and (history[-1] - history[-3]) > tol:
        return True

    # If the last two values are close enough, stop.
    if abs(history[-1] - history[-2]) < tol:
        return True

    # Otherwise, keep on going.
    return False
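
# A brief illustration of the stopping logic (hypothetical loss values):
#
#     h = [torch.tensor(1.0), torch.tensor(0.5), torch.tensor(0.4999)]
#     early_stop(h, tol=1e-3)                             # True: last two are close
#     early_stop([torch.tensor(float("nan"))], tol=1e-3)  # raises ValueError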