b/mowgli/models.py

from typing import Callable, List

import mudata as md
import numpy as np
import torch
import torch.nn.functional as F
from mowgli import utils
from sklearn.decomposition import PCA
from torch import optim
from tqdm import tqdm


class MowgliModel:
    """The Mowgli model, which performs integrative NMF with an Optimal Transport loss.

    Args:
        latent_dim (int, optional):
            The latent dimension of the model. Defaults to 50.
        highly_variable (bool, optional):
            Whether to use highly variable features. Defaults to True.
            For now, only True is supported.
        use_mod_weight (bool, optional):
            Whether to use a different weight for each modality and each
            cell. If `True`, the weights are expected in the `mod_weight`
            obs field of each modality. Defaults to False.
        h_regularization (float or dict, optional):
            The entropy parameter for the dictionary. Defaults to 0.01 for RNA
            and ADT and 0.1 for ATAC. If needed, other modalities should be
            specified by the user. We advise setting values between 0.001
            (biological signal driven by very few features) and 1.0 (very
            diffuse biological signals).
        w_regularization (float, optional):
            The entropy parameter for the embedding. As with `h_regularization`,
            small values mean sparse vectors. Defaults to 1e-3.
        eps (float, optional):
            The entropic regularization parameter for the optimal transport
            loss. Large values decrease the importance of individual genes.
            Defaults to 5e-2.
        cost (str, optional):
            The function used to compute an empirical ground cost. All
            metrics from Scipy's `cdist` are allowed. Defaults to 'cosine'.
        pca_cost (bool, optional):
            If True, the empirical ground cost will be computed on PCA
            embeddings rather than raw data. Defaults to False.
        cost_path (dict, optional):
            Will look for an existing cost as a `.npy` file at this
            path. If not found, the cost will be computed then saved
            there. Defaults to None.
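
    Example:
        A minimal usage sketch, assuming a preprocessed MuData object whose
        modalities each have a `highly_variable` column in `.var` (the file
        name below is only illustrative):

        >>> import mudata as md
        >>> from mowgli.models import MowgliModel
        >>> mdata = md.read_h5mu("my_data.h5mu")
        >>> model = MowgliModel(latent_dim=30)
        >>> model.train(mdata)
        >>> embedding = mdata.obsm["W_OT"]  # cells x latent_dim embedding
        >>> dictionaries = {m: mdata[m].uns["H_OT"] for m in mdata.mod}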
    """

    def __init__(
        self,
        latent_dim: int = 50,
        highly_variable: bool = True,
        use_mod_weight: bool = False,
        h_regularization: float = {
            "rna": 1e-2,
            "adt": 1e-2,
            "prot": 1e-2,
            "atac": 1e-1,
        },
        w_regularization: float = 1e-3,
        eps: float = 5e-2,
        cost: str = "cosine",
        pca_cost: bool = False,
        cost_path: dict = None,
    ):

        # Check that the user-defined parameters are valid.
        assert latent_dim > 0
        assert w_regularization > 0
        assert eps > 0
        assert highly_variable is True
        # TODO: Actually implement the use of highly_variable

        if isinstance(h_regularization, dict):
            for mod in h_regularization:
                assert h_regularization[mod] > 0
        else:
            assert h_regularization > 0

        # Save arguments as attributes.
        self.latent_dim = latent_dim
        self.h_regularization = h_regularization
        self.w_regularization = w_regularization
        self.eps = eps
        self.use_mod_weight = use_mod_weight
        self.cost = cost
        self.cost_path = cost_path
        self.pca_cost = pca_cost

        # Create new attributes.
        self.mod_weight = {}

        # Initialize the loss and statistics histories.
        self.losses_w, self.losses_h, self.losses = [], [], []

        # Initialize the dictionaries containing matrices for each omics.
        self.A, self.H, self.G, self.K = {}, {}, {}, {}

    def init_parameters(
        self,
        mdata: md.MuData,
        dtype: torch.dtype,
        device: torch.device,
        force_recompute: bool = False,
        normalize_rows: bool = False,
    ) -> None:
        """Initialize parameters based on input data.

        Args:
            mdata (md.MuData):
                The input MuData object.
            dtype (torch.dtype):
                The dtype to work with.
            device (torch.device):
                The device to work on.
            force_recompute (bool, optional):
                Whether to recompute the ground cost. Defaults to False.
            normalize_rows (bool, optional):
                Whether to normalize the rows of each reference dataset.
                Defaults to False.
        """

        # Set some attributes.
        self.mod = mdata.mod
        self.n_mod = mdata.n_mod
        self.n_obs = mdata.n_obs
        self.n_var = {}

        if not isinstance(self.h_regularization, dict):
            self.h_regularization = {mod: self.h_regularization for mod in self.mod}

        # For each modality,
        for mod in self.mod:

            # Define the modality weights.
            if self.use_mod_weight:
                mod_weight = mdata.obs[mod + ":mod_weight"].to_numpy()
                mod_weight = torch.Tensor(mod_weight).reshape(1, -1)
                mod_weight = mod_weight.to(dtype=dtype, device=device)
                self.mod_weight[mod] = mod_weight
            else:
                self.mod_weight[mod] = torch.ones(
                    1, self.n_obs, dtype=dtype, device=device
                )

            # Select the highly variable features.
            keep_idx = mdata[mod].var["highly_variable"].to_numpy()

            # Make the reference dataset.
            self.A[mod] = utils.reference_dataset(mdata[mod].X, dtype, device, keep_idx)
            self.n_var[mod] = self.A[mod].shape[0]

            # Normalize the reference dataset, and add a small value
            # for numerical stability.
            self.A[mod] += 1e-6
            if normalize_rows:
                mean_row_sum = self.A[mod].sum(1).mean()
                self.A[mod] /= self.A[mod].sum(1).reshape(-1, 1) * mean_row_sum
            self.A[mod] /= self.A[mod].sum(0)

            # Determine which cost function to use.
            cost = self.cost if isinstance(self.cost, str) else self.cost[mod]
            try:
                cost_path = self.cost_path[mod]
            except Exception:
                cost_path = None

            # Define the features that the ground cost will be computed on.
            features = 1e-6 + self.A[mod].cpu().numpy()
            if self.pca_cost:
                pca = PCA(n_components=self.latent_dim)
                features = pca.fit_transform(features)

            # Compute ground cost, using the specified cost function.
            self.K[mod] = utils.compute_ground_cost(
                features, cost, self.eps, force_recompute, cost_path, dtype, device
            )

            # Initialize the matrices `H`, which should be normalized.
            self.H[mod] = torch.rand(
                self.n_var[mod], self.latent_dim, device=device, dtype=dtype
            )
            self.H[mod] = utils.normalize_tensor(self.H[mod])

            # Initialize the dual variable `G`.
            self.G[mod] = torch.zeros_like(self.A[mod], requires_grad=True)

        # Initialize the shared factor `W`, which should be normalized.
        self.W = torch.rand(self.latent_dim, self.n_obs, device=device, dtype=dtype)
        self.W = utils.normalize_tensor(self.W)

        # Clean up.
        del keep_idx, features

    def train(
        self,
        mdata: md.MuData,
        max_iter_inner: int = 1_000,
        max_iter: int = 100,
        device: torch.device = "cpu",
        dtype: torch.dtype = torch.double,
        lr: float = 1,
        optim_name: str = "lbfgs",
        tol_inner: float = 1e-12,
        tol_outer: float = 1e-4,
        normalize_rows: bool = False,
    ) -> None:
        """Train the Mowgli model on an input MuData object.

        Args:
            mdata (md.MuData):
                The input MuData object.
            max_iter_inner (int, optional):
                How many iterations for the inner optimization loop
                (optimizing H, or W). Defaults to 1_000.
            max_iter (int, optional):
                How many iterations for the outer optimization loop (how
                many successive optimizations of H and W). Defaults to 100.
            device (torch.device, optional):
                The device to work on. Defaults to 'cpu'.
            dtype (torch.dtype, optional):
                The dtype to work with. Defaults to torch.double.
            lr (float, optional):
                The learning rate for the optimizer. The default is set
                for LBFGS and should be changed otherwise. Defaults to 1.
            optim_name (str, optional):
                The optimizer to use (`lbfgs`, `sgd` or `adam`). LBFGS
                is advised, but requires more memory. Defaults to "lbfgs".
            tol_inner (float, optional):
                The tolerance for the inner iterations before early stopping.
                Defaults to 1e-12.
            tol_outer (float, optional):
                The tolerance for the outer iterations before early stopping.
                Defaults to 1e-4.
            normalize_rows (bool, optional):
                Whether to normalize the rows of each reference dataset.
                Defaults to False.
        """

        # First, initialize the different parameters.
        self.init_parameters(
            mdata,
            dtype=dtype,
            device=device,
            normalize_rows=normalize_rows,
        )

        # This is needed to save things in uns if it doesn't exist.
        if mdata.uns is None:
            mdata.uns = {}

        self.lr = lr
        self.optim_name = optim_name

        # Initialize the loss histories.
        self.losses_w, self.losses_h, self.losses = [], [], []

        # Set up the progress bar.
        pbar = tqdm(total=2 * max_iter, position=0, leave=True)

        # This is the main loop, with at most `max_iter` iterations.
        try:
            for _ in range(max_iter):

                # Perform the `W` optimization step.
                self.optimize(
                    loss_fn=self.loss_fn_w,
                    max_iter=max_iter_inner,
                    tol=tol_inner,
                    history=self.losses_w,
                    pbar=pbar,
                    device=device,
                )

                # Update the shared factor `W`.
                htgw = 0
                for mod in self.mod:
                    htgw += self.H[mod].T @ (self.mod_weight[mod] * self.G[mod])
                coef = np.log(self.latent_dim) / (self.n_mod * self.w_regularization)

                self.W = F.softmin(coef * htgw.detach(), dim=0)

                # Clean up.
                del htgw

                # Update the progress bar.
                pbar.update(1)

                # Save the total dual loss and statistics.
                self.losses.append(self.total_dual_loss().cpu().detach())

                # Perform the `H` optimization step.
                self.optimize(
                    loss_fn=self.loss_fn_h,
                    device=device,
                    max_iter=max_iter_inner,
                    tol=tol_inner,
                    history=self.losses_h,
                    pbar=pbar,
                )

                # Update the omic-specific factors `H[mod]`.
                for mod in self.mod:
                    coef = self.latent_dim * np.log(self.n_var[mod])
                    coef /= self.n_obs * self.h_regularization[mod]

                    self.H[mod] = self.mod_weight[mod] * self.G[mod].detach()
                    self.H[mod] = self.H[mod] @ self.W.T
                    self.H[mod] = F.softmin(coef * self.H[mod], dim=0)

                # Update the progress bar.
                pbar.update(1)

                # Save the total dual loss and statistics.
                self.losses.append(self.total_dual_loss().cpu().detach())

                # Early stopping.
                if utils.early_stop(self.losses, tol_outer, nonincreasing=True):
                    break

        except KeyboardInterrupt:
            print("Training interrupted.")

        # Add H and W to the MuData object.
        for mod in self.mod:
            mdata[mod].uns["H_OT"] = self.H[mod].cpu().numpy()
        mdata.obsm["W_OT"] = self.W.T.cpu().numpy()

    def build_optimizer(
        self, params, lr: float, optim_name: str
    ) -> torch.optim.Optimizer:
        """Generates the optimizer. The PyTorch LBFGS implementation is
        parametrized following the discussion in https://discuss.pytorch.org/
        t/unclear-purpose-of-max-iter-kwarg-in-the-lbfgs-optimizer/65695.

        Args:
            params (Iterable of Tensors):
                The parameters to be optimized.
            lr (float):
                Learning rate of the optimizer.
            optim_name (str):
                Name of the optimizer, among `'lbfgs'`, `'sgd'`, `'adam'`.

        Returns:
            torch.optim.Optimizer: The optimizer.
        """
        if optim_name == "lbfgs":
            return optim.LBFGS(
                params,
                lr=lr,
                history_size=5,
                max_iter=1,
                line_search_fn="strong_wolfe",
            )
        elif optim_name == "sgd":
            return optim.SGD(params, lr=lr)
        elif optim_name == "adam":
            return optim.Adam(params, lr=lr)

    def optimize(
        self,
        loss_fn: Callable,
        max_iter: int,
        history: List,
        tol: float,
        pbar,
        device: str,
    ) -> None:
        """Optimize a given function.

        Args:
            loss_fn (Callable): The function to optimize.
            max_iter (int): The maximum number of iterations.
            history (List): A list to append the losses to.
            tol (float): The tolerance before early stopping.
            pbar (A tqdm progress bar): The progress bar.
            device (str): The device to work on.
        """

        # Build the optimizer.
        optimizer = self.build_optimizer(
            [self.G[mod] for mod in self.G], lr=self.lr, optim_name=self.optim_name
        )

        # This value will initially be displayed in the progress bar.
        if len(self.losses) > 0:
            total_loss = self.losses[-1].cpu().numpy()
        else:
            total_loss = "?"

        # This is the main optimization loop.
        for i in range(max_iter):

            # Define the closure function required by the optimizer.
            def closure():
                optimizer.zero_grad()
                loss = loss_fn()
                loss.backward()
                return loss.detach()

            # Perform an optimization step.
            optimizer.step(closure)

            # Every 10 steps, update the progress bar.
            if i % 10 == 0:

                # Add a value to the loss history.
                history.append(loss_fn().cpu().detach())
                gpu_mem_alloc = torch.cuda.memory_allocated(device=device)

                # Populate the progress bar.
                pbar.set_postfix(
                    {
                        "loss": total_loss,
                        "loss_inner": history[-1].cpu().numpy(),
                        "inner_steps": i,
                        "gpu_memory_allocated": gpu_mem_alloc,
                    }
                )

                # Attempt early stopping.
                if utils.early_stop(history, tol):
                    break

    @torch.no_grad()
    def total_dual_loss(self) -> torch.Tensor:
        """Compute the total dual loss. This is only used by the user and for
        early stopping, not by the optimization algorithm.

        Returns:
            torch.Tensor: The loss.
        """

        # Initialize the loss to zero.
        loss = 0

        # Recover the modalities (omics).
        modalities = self.mod

        # For each modality,
        for mod in modalities:

            # Add the OT dual loss.
            loss -= (
                utils.ot_dual_loss(
                    self.A[mod],
                    self.G[mod],
                    self.K[mod],
                    self.eps,
                    self.mod_weight[mod],
                )
                / self.n_obs
            )

            # Add the Lagrange multiplier term.
            lagrange = self.H[mod] @ self.W
            lagrange *= self.mod_weight[mod] * self.G[mod]
            lagrange = lagrange.sum()
            loss += lagrange / self.n_obs

            # Add the `H[mod]` entropy term.
            coef = self.h_regularization[mod] / (
                self.latent_dim * np.log(self.n_var[mod])
            )
            loss -= coef * utils.entropy(self.H[mod], min_one=True)

        # Add the `W` entropy term.
        coef = (
            self.n_mod * self.w_regularization / (self.n_obs * np.log(self.latent_dim))
        )
        loss -= coef * utils.entropy(self.W, min_one=True)

        # Return the full loss.
        return loss

    def loss_fn_h(self) -> torch.Tensor:
        """Computes the loss for the update of `H`.

        Returns:
            torch.Tensor: The loss.
        """
        loss_h = 0
        for mod in self.mod:

            # OT dual loss term.
            loss_h += (
                utils.ot_dual_loss(
                    self.A[mod],
                    self.G[mod],
                    self.K[mod],
                    self.eps,
                    self.mod_weight[mod],
                )
                / self.n_obs
            )

            # Entropy dual loss term.
            coef = self.h_regularization[mod] / (
                self.latent_dim * np.log(self.n_var[mod])
            )
            gwt = self.mod_weight[mod] * self.G[mod] @ self.W.T
            gwt /= self.n_obs * coef
            loss_h -= coef * utils.entropy_dual_loss(-gwt)

            # Clean up.
            del gwt

        # Return the loss.
        return loss_h

    def loss_fn_w(self) -> torch.Tensor:
        """Return the loss for the optimization of `W`.

        Returns:
            torch.Tensor: The loss.
        """
        loss_w, htgw = 0, 0

        for mod in self.mod:

            # For the entropy dual loss term.
            htgw += self.H[mod].T @ (self.mod_weight[mod] * self.G[mod])

            # OT dual loss term.
            loss_w += (
                utils.ot_dual_loss(
                    self.A[mod],
                    self.G[mod],
                    self.K[mod],
                    self.eps,
                    self.mod_weight[mod],
                )
                / self.n_obs
            )

        # Entropy dual loss term.
        coef = self.n_mod * self.w_regularization
        coef /= self.n_obs * np.log(self.latent_dim)
        htgw /= coef * self.n_obs
        loss_w -= coef * utils.entropy_dual_loss(-htgw)

        # Clean up.
        del htgw

        # Return the loss.
        return loss_w