# mowgli/models.py
from typing import Callable, List, Union

import mudata as md
import numpy as np
import torch
import torch.nn.functional as F
from mowgli import utils
from sklearn.decomposition import PCA
from torch import optim
from tqdm import tqdm


class MowgliModel:
    """The Mowgli model, which performs integrative NMF with an Optimal Transport loss.

    Args:
        latent_dim (int, optional):
            The latent dimension of the model. Defaults to 50.
        highly_variable (bool, optional):
            Whether to use highly variable features. Defaults to True.
            For now, only True is supported.
        use_mod_weight (bool, optional):
            Whether to use a different weight for each modality and each
            cell. If `True`, the weights are expected in the `mod_weight`
            obs field of each modality. Defaults to False.
        h_regularization (float or dict, optional):
            The entropy parameter for the dictionary. Defaults to 0.01 for RNA
            and ADT and 0.1 for ATAC. If needed, other modalities should be
            specified by the user. We advise setting values between 0.001
            (biological signal driven by very few features) and 1.0 (very
            diffuse biological signals).
        w_regularization (float, optional):
            The entropy parameter for the embedding. As with `h_regularization`,
            small values mean sparse vectors. Defaults to 1e-3.
        eps (float, optional):
            The entropy parameter for the optimal transport. Large values
            decrease the importance of individual features. Defaults to 5e-2.
        cost (str, optional):
            The function used to compute an empirical ground cost. All
            metrics from Scipy's `cdist` are allowed. Defaults to 'cosine'.
        pca_cost (bool, optional):
            If True, the empirical ground cost will be computed on PCA
            embeddings rather than raw data. Defaults to False.
        cost_path (dict, optional):
            Will look for an existing cost as a `.npy` file at this
            path. If not found, the cost will be computed then saved
            there. Defaults to None.
    """

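    # A minimal construction sketch (hedged example, not part of the class API):
    # the per-modality entropy parameters can be passed as a dict keyed by the
    # MuData modality names. The modality names below ("rna", "atac") are
    # assumptions and should match `mdata.mod`.
    #
    #     model = MowgliModel(
    #         latent_dim=50,
    #         h_regularization={"rna": 1e-2, "atac": 1e-1},
    #         w_regularization=1e-3,
    #     )
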
    def __init__(
        self,
        latent_dim: int = 50,
        highly_variable: bool = True,
        use_mod_weight: bool = False,
        h_regularization: Union[float, dict] = {
            "rna": 1e-2,
            "adt": 1e-2,
            "prot": 1e-2,
            "atac": 1e-1,
        },
        w_regularization: float = 1e-3,
        eps: float = 5e-2,
        cost: str = "cosine",
        pca_cost: bool = False,
        cost_path: dict = None,
    ):

        # Check that the user-defined parameters are valid.
        assert latent_dim > 0
        assert w_regularization > 0
        assert eps > 0
        assert highly_variable is True
        # TODO: Actually implement the use of highly_variable

        if isinstance(h_regularization, dict):
            for mod in h_regularization:
                assert h_regularization[mod] > 0
        else:
            assert h_regularization > 0

        # Save arguments as attributes.
        self.latent_dim = latent_dim
        self.h_regularization = h_regularization
        self.w_regularization = w_regularization
        self.eps = eps
        self.use_mod_weight = use_mod_weight
        self.cost = cost
        self.cost_path = cost_path
        self.pca_cost = pca_cost

        # Create new attributes.
        self.mod_weight = {}

        # Initialize the loss and statistics histories.
        self.losses_w, self.losses_h, self.losses = [], [], []

        # Initialize the dictionaries containing matrices for each omic.
        self.A, self.H, self.G, self.K = {}, {}, {}, {}

    def init_parameters(
        self,
        mdata: md.MuData,
        dtype: torch.dtype,
        device: torch.device,
        force_recompute: bool = False,
        normalize_rows: bool = False,
    ) -> None:
        """Initialize parameters based on input data.

        Args:
            mdata (md.MuData):
                The input MuData object.
            dtype (torch.dtype):
                The dtype to work with.
            device (torch.device):
                The device to work on.
            force_recompute (bool, optional):
                Whether to recompute the ground cost. Defaults to False.
            normalize_rows (bool, optional):
                Whether to normalize the rows of each reference dataset
                before the column normalization. Defaults to False.
        """

        # Set some attributes.
        self.mod = mdata.mod
        self.n_mod = mdata.n_mod
        self.n_obs = mdata.n_obs
        self.n_var = {}

        if not isinstance(self.h_regularization, dict):
            self.h_regularization = {mod: self.h_regularization for mod in self.mod}

        # For each modality,
        for mod in self.mod:

            # Define the modality weights.
            if self.use_mod_weight:
                mod_weight = mdata.obs[mod + ":mod_weight"].to_numpy()
                mod_weight = torch.Tensor(mod_weight).reshape(1, -1)
                mod_weight = mod_weight.to(dtype=dtype, device=device)
                self.mod_weight[mod] = mod_weight
            else:
                self.mod_weight[mod] = torch.ones(
                    1, self.n_obs, dtype=dtype, device=device
                )

            # Select the highly variable features.
            keep_idx = mdata[mod].var["highly_variable"].to_numpy()

            # Make the reference dataset.
            self.A[mod] = utils.reference_dataset(mdata[mod].X, dtype, device, keep_idx)
            self.n_var[mod] = self.A[mod].shape[0]

            # Add a small value for numerical stability, then normalize
            # the reference dataset.
            self.A[mod] += 1e-6
            if normalize_rows:
                mean_row_sum = self.A[mod].sum(1).mean()
                self.A[mod] /= self.A[mod].sum(1).reshape(-1, 1) * mean_row_sum
            self.A[mod] /= self.A[mod].sum(0)

            # Determine which cost function to use.
            cost = self.cost if isinstance(self.cost, str) else self.cost[mod]
            try:
                cost_path = self.cost_path[mod]
            except Exception:
                cost_path = None

            # Define the features that the ground cost will be computed on.
            features = 1e-6 + self.A[mod].cpu().numpy()
            if self.pca_cost:
                pca = PCA(n_components=self.latent_dim)
                features = pca.fit_transform(features)

            # Compute the ground cost, using the specified cost function.
            self.K[mod] = utils.compute_ground_cost(
                features, cost, self.eps, force_recompute, cost_path, dtype, device
            )

            # Initialize the matrices `H`, which should be normalized.
            self.H[mod] = torch.rand(
                self.n_var[mod], self.latent_dim, device=device, dtype=dtype
            )
            self.H[mod] = utils.normalize_tensor(self.H[mod])

            # Initialize the dual variable `G`.
            self.G[mod] = torch.zeros_like(self.A[mod], requires_grad=True)

        # Initialize the shared factor `W`, which should be normalized.
        self.W = torch.rand(self.latent_dim, self.n_obs, device=device, dtype=dtype)
        self.W = utils.normalize_tensor(self.W)

        # Clean up.
        del keep_idx, features

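    # After `init_parameters`, each reference matrix `A[mod]` has shape
    # (n_var[mod], n_obs) and is column-normalized, i.e. each cell's feature
    # profile sums to one. A quick sanity-check sketch (the "rna" key is an
    # assumption and should match one of `mdata.mod`):
    #
    #     col_sums = model.A["rna"].sum(0)
    #     assert torch.allclose(col_sums, torch.ones_like(col_sums))
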
    def train(
        self,
        mdata: md.MuData,
        max_iter_inner: int = 1_000,
        max_iter: int = 100,
        device: torch.device = "cpu",
        dtype: torch.dtype = torch.double,
        lr: float = 1,
        optim_name: str = "lbfgs",
        tol_inner: float = 1e-12,
        tol_outer: float = 1e-4,
        normalize_rows: bool = False,
    ) -> None:
        """Train the Mowgli model on an input MuData object.

        Args:
            mdata (md.MuData):
                The input MuData object.
            max_iter_inner (int, optional):
                How many iterations for the inner optimization loop
                (optimizing H or W). Defaults to 1_000.
            max_iter (int, optional):
                How many iterations for the outer optimization loop (how
                many successive optimizations of H and W). Defaults to 100.
            device (torch.device, optional):
                The device to work on. Defaults to 'cpu'.
            dtype (torch.dtype, optional):
                The dtype to work with. Defaults to torch.double.
            lr (float, optional):
                The learning rate for the optimizer. The default is set
                for LBFGS and should be changed otherwise. Defaults to 1.
            optim_name (str, optional):
                The optimizer to use (`lbfgs`, `sgd` or `adam`). LBFGS
                is advised, but requires more memory. Defaults to "lbfgs".
            tol_inner (float, optional):
                The tolerance for the inner iterations before early stopping.
                Defaults to 1e-12.
            tol_outer (float, optional):
                The tolerance for the outer iterations before early stopping.
                Defaults to 1e-4.
            normalize_rows (bool, optional):
                Whether to normalize the rows of each reference dataset
                during initialization. Defaults to False.
        """

        # First, initialize the different parameters.
        self.init_parameters(
            mdata,
            dtype=dtype,
            device=device,
            normalize_rows=normalize_rows,
        )

        # This is needed to save things in uns if it doesn't exist.
        if mdata.uns is None:
            mdata.uns = {}

        self.lr = lr
        self.optim_name = optim_name

        # Initialize the loss histories.
        self.losses_w, self.losses_h, self.losses = [], [], []

        # Set up the progress bar.
        pbar = tqdm(total=2 * max_iter, position=0, leave=True)

        # This is the main loop, with at most `max_iter` iterations.
        try:
            for _ in range(max_iter):

                # Perform the `W` optimization step.
                self.optimize(
                    loss_fn=self.loss_fn_w,
                    max_iter=max_iter_inner,
                    tol=tol_inner,
                    history=self.losses_w,
                    pbar=pbar,
                    device=device,
                )

                # Update the shared factor `W`.
                htgw = 0
                for mod in self.mod:
                    htgw += self.H[mod].T @ (self.mod_weight[mod] * self.G[mod])
                coef = np.log(self.latent_dim) / (self.n_mod * self.w_regularization)

                self.W = F.softmin(coef * htgw.detach(), dim=0)
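                # Note: the softmin above is the closed-form solution of the
                # entropic update for `W`: for each cell j, it minimizes
                # <htgw[:, j], w> plus a negative-entropy term (scaled by
                # 1 / coef) over the probability simplex, so every column of
                # `W` is nonnegative and sums to one.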

                # Clean up.
                del htgw

                # Update the progress bar.
                pbar.update(1)

                # Save the total dual loss and statistics.
                self.losses.append(self.total_dual_loss().cpu().detach())

                # Perform the `H` optimization step.
                self.optimize(
                    loss_fn=self.loss_fn_h,
                    device=device,
                    max_iter=max_iter_inner,
                    tol=tol_inner,
                    history=self.losses_h,
                    pbar=pbar,
                )

                # Update the omic-specific factors `H[mod]`.
                for mod in self.mod:
                    coef = self.latent_dim * np.log(self.n_var[mod])
                    coef /= self.n_obs * self.h_regularization[mod]

                    self.H[mod] = self.mod_weight[mod] * self.G[mod].detach()
                    self.H[mod] = self.H[mod] @ self.W.T
                    self.H[mod] = F.softmin(coef * self.H[mod], dim=0)
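
                    # As for `W`, the softmin is the closed-form entropic
                    # update: each column of `H[mod]` lies on the probability
                    # simplex over that modality's features.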

                # Update the progress bar.
                pbar.update(1)

                # Save the total dual loss and statistics.
                self.losses.append(self.total_dual_loss().cpu().detach())

                # Early stopping.
                if utils.early_stop(self.losses, tol_outer, nonincreasing=True):
                    break

        except KeyboardInterrupt:
            print("Training interrupted.")

        # Add H and W to the MuData object.
        for mod in self.mod:
            mdata[mod].uns["H_OT"] = self.H[mod].cpu().numpy()
        mdata.obsm["W_OT"] = self.W.T.cpu().numpy()

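    # A minimal training sketch (hedged example): assuming `mdata` is a
    # preprocessed MuData object with a boolean `highly_variable` column in
    # each modality's `.var`, training and retrieving the results could look
    # like the following. The modality name "rna" is an assumption.
    #
    #     model = MowgliModel(latent_dim=50)
    #     model.train(mdata, device="cpu", dtype=torch.double)
    #     cell_embedding = mdata.obsm["W_OT"]        # shape (n_obs, latent_dim)
    #     rna_dictionary = mdata["rna"].uns["H_OT"]  # shape (n_var_rna, latent_dim)
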
    def build_optimizer(
        self, params, lr: float, optim_name: str
    ) -> torch.optim.Optimizer:
        """Generates the optimizer. The PyTorch LBFGS implementation is
        parametrized following the discussion in https://discuss.pytorch.org/
        t/unclear-purpose-of-max-iter-kwarg-in-the-lbfgs-optimizer/65695.

        Args:
            params (Iterable of Tensors):
                The parameters to be optimized.
            lr (float):
                Learning rate of the optimizer.
            optim_name (str):
                Name of the optimizer, among `'lbfgs'`, `'sgd'`, `'adam'`.

        Returns:
            torch.optim.Optimizer: The optimizer.
        """
        if optim_name == "lbfgs":
            return optim.LBFGS(
                params,
                lr=lr,
                history_size=5,
                max_iter=1,
                line_search_fn="strong_wolfe",
            )
        elif optim_name == "sgd":
            return optim.SGD(params, lr=lr)
        elif optim_name == "adam":
            return optim.Adam(params, lr=lr)
        else:
            raise ValueError(f"Unknown optimizer name: {optim_name}")

    def optimize(
        self,
        loss_fn: Callable,
        max_iter: int,
        history: List,
        tol: float,
        pbar,
        device: str,
    ) -> None:
        """Optimize a given function.

        Args:
            loss_fn (Callable): The function to optimize.
            max_iter (int): The maximum number of iterations.
            history (List): A list to append the losses to.
            tol (float): The tolerance before early stopping.
            pbar (A tqdm progress bar): The progress bar.
            device (str): The device to work on.
        """

        # Build the optimizer.
        optimizer = self.build_optimizer(
            [self.G[mod] for mod in self.G], lr=self.lr, optim_name=self.optim_name
        )

        # This value will initially be displayed in the progress bar.
        if len(self.losses) > 0:
            total_loss = self.losses[-1].cpu().numpy()
        else:
            total_loss = "?"

        # This is the main optimization loop.
        for i in range(max_iter):

            # Define the closure function required by the optimizer.
            def closure():
                optimizer.zero_grad()
                loss = loss_fn()
                loss.backward()
                return loss.detach()

            # Perform an optimization step.
            optimizer.step(closure)
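
            # Note: with LBFGS, the closure may be evaluated several times
            # within a single `step` call (strong Wolfe line search), whereas
            # SGD and Adam evaluate it exactly once per step.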

            # Every 10 steps, update the progress bar.
            if i % 10 == 0:

                # Add a value to the loss history.
                history.append(loss_fn().cpu().detach())
                gpu_mem_alloc = torch.cuda.memory_allocated(device=device)

                # Populate the progress bar.
                pbar.set_postfix(
                    {
                        "loss": total_loss,
                        "loss_inner": history[-1].cpu().numpy(),
                        "inner_steps": i,
                        "gpu_memory_allocated": gpu_mem_alloc,
                    }
                )

                # Attempt early stopping.
                if utils.early_stop(history, tol):
                    break

    @torch.no_grad()
    def total_dual_loss(self) -> torch.Tensor:
        """Compute the total dual loss. This is only used by the user and for
        early stopping, not by the optimization algorithm.

        Returns:
            torch.Tensor: The loss.
        """

        # Initialize the loss to zero.
        loss = 0

        # Recover the modalities (omics).
        modalities = self.mod

        # For each modality,
        for mod in modalities:

            # Add the OT dual loss.
            loss -= (
                utils.ot_dual_loss(
                    self.A[mod],
                    self.G[mod],
                    self.K[mod],
                    self.eps,
                    self.mod_weight[mod],
                )
                / self.n_obs
            )

            # Add the Lagrange multiplier term.
            lagrange = self.H[mod] @ self.W
            lagrange *= self.mod_weight[mod] * self.G[mod]
            lagrange = lagrange.sum()
            loss += lagrange / self.n_obs

            # Add the `H[mod]` entropy term.
            coef = self.h_regularization[mod] / (
                self.latent_dim * np.log(self.n_var[mod])
            )
            loss -= coef * utils.entropy(self.H[mod], min_one=True)

        # Add the `W` entropy term.
        coef = (
            self.n_mod * self.w_regularization / (self.n_obs * np.log(self.latent_dim))
        )
        loss -= coef * utils.entropy(self.W, min_one=True)

        # Return the full loss.
        return loss

    def loss_fn_h(self) -> torch.Tensor:
        """Computes the loss for the update of `H`.

        Returns:
            torch.Tensor: The loss.
        """
        loss_h = 0
        for mod in self.mod:

            # OT dual loss term.
            loss_h += (
                utils.ot_dual_loss(
                    self.A[mod],
                    self.G[mod],
                    self.K[mod],
                    self.eps,
                    self.mod_weight[mod],
                )
                / self.n_obs
            )

            # Entropy dual loss term.
            coef = self.h_regularization[mod] / (
                self.latent_dim * np.log(self.n_var[mod])
            )
            gwt = self.mod_weight[mod] * self.G[mod] @ self.W.T
            gwt /= self.n_obs * coef
            loss_h -= coef * utils.entropy_dual_loss(-gwt)

            # Clean up.
            del gwt

        # Return the loss.
        return loss_h

    def loss_fn_w(self) -> torch.Tensor:
        """Return the loss for the optimization of `W`.

        Returns:
            torch.Tensor: The loss.
        """
        loss_w, htgw = 0, 0

        for mod in self.mod:

            # For the entropy dual loss term.
            htgw += self.H[mod].T @ (self.mod_weight[mod] * self.G[mod])

            # OT dual loss term.
            loss_w += (
                utils.ot_dual_loss(
                    self.A[mod],
                    self.G[mod],
                    self.K[mod],
                    self.eps,
                    self.mod_weight[mod],
                )
                / self.n_obs
            )

        # Entropy dual loss term.
        coef = self.n_mod * self.w_regularization
        coef /= self.n_obs * np.log(self.latent_dim)
        htgw /= coef * self.n_obs
        loss_w -= coef * utils.entropy_dual_loss(-htgw)

        # Clean up.
        del htgw

        # Return the loss.
        return loss_w