# equivariant_diffusion/conditional_model.py
import math

import numpy as np
import torch
import torch.nn.functional as F
from torch_scatter import scatter_add, scatter_mean

import utils
from equivariant_diffusion.en_diffusion import EnVariationalDiffusion


class ConditionalDDPM(EnVariationalDiffusion):
    """
    Conditional Diffusion Module.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        assert not self.dynamics.update_pocket_coords

    def kl_prior(self, xh_lig, mask_lig, num_nodes):
        """Computes the KL between q(z_T | x) and the prior p(z_T) = Normal(0, 1).

        In practice this term is negligible in the loss, but computing it makes
        mistakes in the noise schedule visible.
        """
        batch_size = len(num_nodes)

        # Compute the last alpha value, alpha_T.
        ones = torch.ones((batch_size, 1), device=xh_lig.device)
        gamma_T = self.gamma(ones)
        alpha_T = self.alpha(gamma_T, xh_lig)

        # Compute means.
        mu_T_lig = alpha_T[mask_lig] * xh_lig
        mu_T_lig_x, mu_T_lig_h = \
            mu_T_lig[:, :self.n_dims], mu_T_lig[:, self.n_dims:]

        # Compute standard deviations (keep only the batch axis).
        sigma_T_x = self.sigma(gamma_T, mu_T_lig_x).squeeze()
        sigma_T_h = self.sigma(gamma_T, mu_T_lig_h).squeeze()

        # Compute KL for h-part.
        zeros = torch.zeros_like(mu_T_lig_h)
        ones = torch.ones_like(sigma_T_h)
        mu_norm2 = self.sum_except_batch((mu_T_lig_h - zeros) ** 2, mask_lig)
        kl_distance_h = self.gaussian_KL(mu_norm2, sigma_T_h, ones, d=1)

        # Compute KL for x-part.
        zeros = torch.zeros_like(mu_T_lig_x)
        ones = torch.ones_like(sigma_T_x)
        mu_norm2 = self.sum_except_batch((mu_T_lig_x - zeros) ** 2, mask_lig)
        subspace_d = self.subspace_dimensionality(num_nodes)
        kl_distance_x = self.gaussian_KL(mu_norm2, sigma_T_x, ones, subspace_d)

        return kl_distance_x + kl_distance_h
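
    # For reference, a sketch of the closed form presumably implemented by
    # self.gaussian_KL (as in EDM-style codebases; this is an assumption about
    # the parent class, not code from this file):
    #   KL(N(mu, sigma^2 I_d) || N(0, p_sigma^2 I_d))
    #       = d * log(p_sigma / sigma)
    #         + 0.5 * (d * sigma^2 + ||mu||^2) / p_sigma^2
    #         - 0.5 * d
    # With p_sigma = 1 this vanishes as sigma -> 1 and mu -> 0, which is why a
    # well-behaved noise schedule drives this term towards zero.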

    def log_pxh_given_z0_without_constants(self, ligand, z_0_lig, eps_lig,
                                           net_out_lig, gamma_0, epsilon=1e-10):

        # Discrete properties are predicted directly from z_t.
        z_h_lig = z_0_lig[:, self.n_dims:]

        # Take only the x-part.
        eps_lig_x = eps_lig[:, :self.n_dims]
        net_lig_x = net_out_lig[:, :self.n_dims]

        # Compute sigma_0 and rescale to the integer scale of the data.
        sigma_0 = self.sigma(gamma_0, target_tensor=z_0_lig)
        sigma_0_cat = sigma_0 * self.norm_values[1]

        # Compute the error for the distribution
        # N(x | 1/alpha_0 z_0 + sigma_0/alpha_0 eps_0, sigma_0/alpha_0);
        # in the epsilon parametrization the weighting is exactly 1.
        squared_error = (eps_lig_x - net_lig_x) ** 2
        if self.vnode_idx is not None:
            # Coordinates of virtual atoms should not contribute to the error.
            squared_error[ligand['one_hot'][:, self.vnode_idx].bool(), :self.n_dims] = 0
        log_p_x_given_z0_without_constants_ligand = -0.5 * (
            self.sum_except_batch(squared_error, ligand['mask'])
        )

        # Un-normalize the categorical features to the integer scale of the data.
        ligand_onehot = ligand['one_hot'] * self.norm_values[1] + self.norm_biases[1]

        estimated_ligand_onehot = z_h_lig * self.norm_values[1] + self.norm_biases[1]

        # Center h_cat around 1, since it is one-hot encoded.
        centered_ligand_onehot = estimated_ligand_onehot - 1

        # Compute integrals from 0.5 to 1.5 of the normal distribution
        # N(mean=z_h_cat, stdev=sigma_0_cat).
        log_ph_cat_proportional_ligand = torch.log(
            self.cdf_standard_gaussian((centered_ligand_onehot + 0.5) / sigma_0_cat[ligand['mask']])
            - self.cdf_standard_gaussian((centered_ligand_onehot - 0.5) / sigma_0_cat[ligand['mask']])
            + epsilon
        )

        # Normalize the distribution over the categories.
        log_Z = torch.logsumexp(log_ph_cat_proportional_ligand, dim=1,
                                keepdim=True)
        log_probabilities_ligand = log_ph_cat_proportional_ligand - log_Z

        # Select the log_prob of the current category using the one-hot
        # representation.
        log_ph_given_z0_ligand = self.sum_except_batch(
            log_probabilities_ligand * ligand_onehot, ligand['mask'])

        return log_p_x_given_z0_without_constants_ligand, log_ph_given_z0_ligand
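
    # Illustrative numbers (hypothetical, not from the codebase): for the
    # correct category the un-normalized one-hot entry is 1, so the centered
    # value is 0 and the integral is Phi(0.5/sigma) - Phi(-0.5/sigma); with
    # sigma_0_cat = 0.1 this is ~1.0. A wrong category (entry 0, centered at -1)
    # gets Phi(-5) - Phi(-15) ~ 2.9e-7, i.e. at low noise nearly all probability
    # mass falls on the category that is actually present.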

    def sample_p_xh_given_z0(self, z0_lig, xh0_pocket, lig_mask, pocket_mask,
                             batch_size, fix_noise=False):
        """Samples x ~ p(x | z_0)."""
        t_zeros = torch.zeros(size=(batch_size, 1), device=z0_lig.device)
        gamma_0 = self.gamma(t_zeros)
        # Computes sqrt(sigma_0^2 / alpha_0^2)
        sigma_x = self.SNR(-0.5 * gamma_0)
        net_out_lig, _ = self.dynamics(
            z0_lig, xh0_pocket, t_zeros, lig_mask, pocket_mask)

        # Compute mu for p(x | z_0).
        mu_x_lig = self.compute_x_pred(net_out_lig, z0_lig, gamma_0, lig_mask)
        xh_lig, xh0_pocket = self.sample_normal_zero_com(
            mu_x_lig, xh0_pocket, sigma_x, lig_mask, pocket_mask, fix_noise)

        # Note: the h-part is taken from z0_lig because discrete properties are
        # predicted directly from z_t (see log_pxh_given_z0_without_constants).
        x_lig, h_lig = self.unnormalize(
            xh_lig[:, :self.n_dims], z0_lig[:, self.n_dims:])
        x_pocket, h_pocket = self.unnormalize(
            xh0_pocket[:, :self.n_dims], xh0_pocket[:, self.n_dims:])

        h_lig = F.one_hot(torch.argmax(h_lig, dim=1), self.atom_nf)
        # h_pocket = F.one_hot(torch.argmax(h_pocket, dim=1), self.residue_nf)

        return x_lig, h_lig, x_pocket, h_pocket
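
    # Note on sigma_x above (assuming the usual EDM-style parametrization
    # alpha^2 = sigmoid(-gamma), sigma^2 = sigmoid(gamma), so SNR(gamma) =
    # exp(-gamma)): SNR(-0.5 * gamma_0) = exp(0.5 * gamma_0) = sigma_0 / alpha_0,
    # which is exactly the standard deviation of p(x | z_0) in the epsilon
    # parametrization, matching the comment next to its definition.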

    def sample_normal(self, *args):
        raise NotImplementedError("Has been replaced by sample_normal_zero_com()")

    def sample_normal_zero_com(self, mu_lig, xh0_pocket, sigma, lig_mask,
                               pocket_mask, fix_noise=False):
        """Samples from a normal distribution and projects the coordinates onto
        the COM-free subspace."""
        if fix_noise:
            raise NotImplementedError("fix_noise option isn't implemented yet")

        eps_lig = self.sample_gaussian(
            size=(len(lig_mask), self.n_dims + self.atom_nf),
            device=lig_mask.device)

        out_lig = mu_lig + sigma[lig_mask] * eps_lig

        # Project to the COM-free subspace.
        xh_pocket = xh0_pocket.detach().clone()
        out_lig[:, :self.n_dims], xh_pocket[:, :self.n_dims] = \
            self.remove_mean_batch(out_lig[:, :self.n_dims],
                                   xh0_pocket[:, :self.n_dims],
                                   lig_mask, pocket_mask)

        return out_lig, xh_pocket
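
    # Shape sketch (hypothetical batch): with two ligands of 3 and 5 atoms,
    #   lig_mask = tensor([0, 0, 0, 1, 1, 1, 1, 1])   # one entry per atom
    #   mu_lig.shape == (8, n_dims + atom_nf)
    #   sigma.shape  == (2, 1)
    # so sigma[lig_mask] has shape (8, 1) and broadcasts the per-graph noise
    # scale to every atom of the corresponding ligand.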

    def noised_representation(self, xh_lig, xh0_pocket, lig_mask, pocket_mask,
                              gamma_t):
        # Compute alpha_t and sigma_t from gamma.
        alpha_t = self.alpha(gamma_t, xh_lig)
        sigma_t = self.sigma(gamma_t, xh_lig)

        # Sample z_t ~ q(z_t | x, h) = Normal(alpha_t [x, h], sigma_t).
        eps_lig = self.sample_gaussian(
            size=(len(lig_mask), self.n_dims + self.atom_nf),
            device=lig_mask.device)
        z_t_lig = alpha_t[lig_mask] * xh_lig + sigma_t[lig_mask] * eps_lig

        # Project to the COM-free subspace.
        xh_pocket = xh0_pocket.detach().clone()
        z_t_lig[:, :self.n_dims], xh_pocket[:, :self.n_dims] = \
            self.remove_mean_batch(z_t_lig[:, :self.n_dims],
                                   xh_pocket[:, :self.n_dims],
                                   lig_mask, pocket_mask)

        return z_t_lig, xh_pocket, eps_lig
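
    # In short: z_t = alpha_t * xh + sigma_t * eps with eps ~ N(0, I), i.e. a
    # draw from q(z_t | x, h); the returned eps_lig is the regression target
    # for the denoising network in the epsilon parametrization.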

    def log_pN(self, N_lig, N_pocket):
        """
        Prior on the sample size for computing
        log p(x,h,N) = log p(x,h|N) + log p(N), where log p(x,h|N) is the
        model's output.
        Args:
            N_lig: array of ligand sizes
            N_pocket: array of pocket sizes
        Returns:
            log p(N_lig | N_pocket)
        """
        log_pN = self.size_distribution.log_prob_n1_given_n2(N_lig, N_pocket)
        return log_pN

    def delta_log_px(self, num_nodes):
        return -self.subspace_dimensionality(num_nodes) * \
               np.log(self.norm_values[0])
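
    # Change of variables: if x is rescaled by 1 / norm_values[0], the density
    # picks up a factor norm_values[0]^(-d) on the d-dimensional (COM-free)
    # subspace, hence log p changes by -d * log(norm_values[0]) as computed here.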

    def forward(self, ligand, pocket, return_info=False):
        """
        Computes the loss and NLL terms.
        """
        # Normalize data, take into account volume change in x.
        ligand, pocket = self.normalize(ligand, pocket)

        # Likelihood change due to normalization
        # if self.vnode_idx is not None:
        #     delta_log_px = self.delta_log_px(ligand['size'] - ligand['num_virtual_atoms'] + pocket['size'])
        # else:
        delta_log_px = self.delta_log_px(ligand['size'])

        # Sample a timestep t for each example in the batch.
        # At evaluation time, loss_0 is computed separately to decrease
        # variance in the estimator (costs two forward passes).
        lowest_t = 0 if self.training else 1
        t_int = torch.randint(
            lowest_t, self.T + 1, size=(ligand['size'].size(0), 1),
            device=ligand['x'].device).float()
        s_int = t_int - 1  # previous timestep

        # Masks: important to compute log p(x | z0).
        t_is_zero = (t_int == 0).float()
        t_is_not_zero = 1 - t_is_zero

        # Normalize t to [0, 1]. Note that the negative value of s (at t = 0)
        # is never used, because in that case p(x | z0) is computed instead.
        s = s_int / self.T
        t = t_int / self.T

        # Compute gamma_s and gamma_t via the network.
        gamma_s = self.inflate_batch_array(self.gamma(s), ligand['x'])
        gamma_t = self.inflate_batch_array(self.gamma(t), ligand['x'])

        # Concatenate x and h[categorical].
        xh0_lig = torch.cat([ligand['x'], ligand['one_hot']], dim=1)
        xh0_pocket = torch.cat([pocket['x'], pocket['one_hot']], dim=1)

        # Center the input nodes.
        xh0_lig[:, :self.n_dims], xh0_pocket[:, :self.n_dims] = \
            self.remove_mean_batch(xh0_lig[:, :self.n_dims],
                                   xh0_pocket[:, :self.n_dims],
                                   ligand['mask'], pocket['mask'])

        # Find the noised representation.
        z_t_lig, xh_pocket, eps_t_lig = \
            self.noised_representation(xh0_lig, xh0_pocket, ligand['mask'],
                                       pocket['mask'], gamma_t)

        # Neural net prediction.
        net_out_lig, _ = self.dynamics(
            z_t_lig, xh_pocket, t, ligand['mask'], pocket['mask'])

        # For the LJ loss term:
        # xh_lig_hat does not need to be zero-centered as it is only used for
        # computing relative distances.
        xh_lig_hat = self.xh_given_zt_and_epsilon(z_t_lig, net_out_lig, gamma_t,
                                                  ligand['mask'])

        # Compute the L2 error.
        squared_error = (eps_t_lig - net_out_lig) ** 2
        if self.vnode_idx is not None:
            # Coordinates of virtual atoms should not contribute to the error.
            squared_error[ligand['one_hot'][:, self.vnode_idx].bool(), :self.n_dims] = 0
        error_t_lig = self.sum_except_batch(squared_error, ligand['mask'])

        # Compute weighting with SNR: 1 - SNR(gamma_s - gamma_t) for the
        # epsilon parametrization.
        SNR_weight = (1 - self.SNR(gamma_s - gamma_t)).squeeze(1)
        assert error_t_lig.size() == SNR_weight.size()

        # The _constants_ depending on sigma_0 from the
        # cross entropy term E_q(z0 | x) [log p(x | z0)].
        neg_log_constants = -self.log_constants_p_x_given_z0(
            n_nodes=ligand['size'], device=error_t_lig.device)

        # The KL between q(z_T | x) and p(z_T) = Normal(0, 1).
        # Should be close to zero.
        kl_prior = self.kl_prior(xh0_lig, ligand['mask'], ligand['size'])

        if self.training:
            # Compute the L_0 term (even if gamma_t is not actually gamma_0);
            # the result is later selected via masking.
            log_p_x_given_z0_without_constants_ligand, log_ph_given_z0 = \
                self.log_pxh_given_z0_without_constants(
                    ligand, z_t_lig, eps_t_lig, net_out_lig, gamma_t)

            loss_0_x_ligand = -log_p_x_given_z0_without_constants_ligand * \
                              t_is_zero.squeeze()
            loss_0_h = -log_ph_given_z0 * t_is_zero.squeeze()

            # Apply the t_is_zero mask.
            error_t_lig = error_t_lig * t_is_not_zero.squeeze()

        else:
            # Compute noise values for t = 0.
            t_zeros = torch.zeros_like(s)
            gamma_0 = self.inflate_batch_array(self.gamma(t_zeros), ligand['x'])

            # Sample z_0 given x, h for timestep t = 0, from q(z_0 | x, h).
            z_0_lig, xh_pocket, eps_0_lig = \
                self.noised_representation(xh0_lig, xh0_pocket, ligand['mask'],
                                           pocket['mask'], gamma_0)

            net_out_0_lig, _ = self.dynamics(
                z_0_lig, xh_pocket, t_zeros, ligand['mask'], pocket['mask'])

            log_p_x_given_z0_without_constants_ligand, log_ph_given_z0 = \
                self.log_pxh_given_z0_without_constants(
                    ligand, z_0_lig, eps_0_lig, net_out_0_lig, gamma_0)
            loss_0_x_ligand = -log_p_x_given_z0_without_constants_ligand
            loss_0_h = -log_ph_given_z0

        # Sample size prior.
        log_pN = self.log_pN(ligand['size'], pocket['size'])

        info = {
            'eps_hat_lig_x': scatter_mean(
                net_out_lig[:, :self.n_dims].abs().mean(1), ligand['mask'],
                dim=0).mean(),
            'eps_hat_lig_h': scatter_mean(
                net_out_lig[:, self.n_dims:].abs().mean(1), ligand['mask'],
                dim=0).mean(),
        }
        loss_terms = (delta_log_px, error_t_lig, torch.tensor(0.0), SNR_weight,
                      loss_0_x_ligand, torch.tensor(0.0), loss_0_h,
                      neg_log_constants, kl_prior, log_pN,
                      t_int.squeeze(), xh_lig_hat)
        return (*loss_terms, info) if return_info else loss_terms
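
    # Minimal usage sketch (hypothetical tensors; both dicts need the keys used
    # above, i.e. 'x', 'one_hot', 'mask' and 'size'):
    #   ligand = {'x': torch.randn(8, 3),
    #             'one_hot': F.one_hot(types, model.atom_nf).float(),
    #             'mask': torch.tensor([0, 0, 0, 1, 1, 1, 1, 1]),
    #             'size': torch.tensor([3, 5])}
    #   pocket = {...}  # same keys for the pocket nodes
    #   delta_log_px, error_t_lig, *rest = model(ligand, pocket)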

    def partially_noised_ligand(self, ligand, pocket, noising_steps):
        """
        Partially noises a ligand to be denoised later.
        """
        # Inflate the timestep into an array.
        t_int = torch.ones(size=(ligand['size'].size(0), 1),
                           device=ligand['x'].device).float() * noising_steps

        # Normalize t to [0, 1].
        t = t_int / self.T

        # Compute gamma_t via the network.
        gamma_t = self.inflate_batch_array(self.gamma(t), ligand['x'])

        # Concatenate x and h[categorical].
        xh0_lig = torch.cat([ligand['x'], ligand['one_hot']], dim=1)
        xh0_pocket = torch.cat([pocket['x'], pocket['one_hot']], dim=1)

        # Center the input nodes.
        xh0_lig[:, :self.n_dims], xh0_pocket[:, :self.n_dims] = \
            self.remove_mean_batch(xh0_lig[:, :self.n_dims],
                                   xh0_pocket[:, :self.n_dims],
                                   ligand['mask'], pocket['mask'])

        # Find the noised representation.
        z_t_lig, xh_pocket, eps_t_lig = \
            self.noised_representation(xh0_lig, xh0_pocket, ligand['mask'],
                                       pocket['mask'], gamma_t)

        return z_t_lig, xh_pocket, eps_t_lig

    def diversify(self, ligand, pocket, noising_steps):
        """
        Diversifies a set of ligands by noising and then denoising them.
        """
        # Normalize data, take into account volume change in x.
        ligand, pocket = self.normalize(ligand, pocket)

        z_lig, xh_pocket, _ = self.partially_noised_ligand(ligand, pocket,
                                                           noising_steps)

        timesteps = self.T
        n_samples = len(pocket['size'])
        lig_mask = ligand['mask']

        self.assert_mean_zero_with_mask(z_lig[:, :self.n_dims], lig_mask)

        # Iteratively sample p(z_s | z_t) for t = noising_steps, ..., 1,
        # with s = t - 1.
        for s in reversed(range(0, noising_steps)):
            s_array = torch.full((n_samples, 1), fill_value=s,
                                 device=z_lig.device)
            t_array = s_array + 1
            s_array = s_array / timesteps
            t_array = t_array / timesteps

            z_lig, xh_pocket = self.sample_p_zs_given_zt(
                s_array, t_array, z_lig.detach(), xh_pocket.detach(), lig_mask,
                pocket['mask'])

        # Finally sample p(x, h | z_0).
        x_lig, h_lig, x_pocket, h_pocket = self.sample_p_xh_given_z0(
            z_lig, xh_pocket, lig_mask, pocket['mask'], n_samples)

        self.assert_mean_zero_with_mask(x_lig, lig_mask)

        # Combine the resulting x and h.
        out_lig = torch.cat([x_lig, h_lig], dim=1)
        out_pocket = torch.cat([x_pocket, h_pocket], dim=1)

        return out_lig, out_pocket, lig_mask, pocket['mask']

    def xh_given_zt_and_epsilon(self, z_t, epsilon, gamma_t, batch_mask):
        """ Equation (7) in the EDM paper """
        alpha_t = self.alpha(gamma_t, z_t)
        sigma_t = self.sigma(gamma_t, z_t)
        xh = z_t / alpha_t[batch_mask] - epsilon * sigma_t[batch_mask] / \
             alpha_t[batch_mask]
        return xh
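
    # Rearranging z_t = alpha_t * xh + sigma_t * epsilon gives the estimate
    # computed above: xh_hat = z_t / alpha_t - (sigma_t / alpha_t) * epsilon,
    # i.e. Eq. (7) of the EDM paper (Hoogeboom et al.) with the predicted noise
    # plugged in for epsilon.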

    def sample_p_zt_given_zs(self, zs_lig, xh0_pocket, ligand_mask, pocket_mask,
                             gamma_t, gamma_s, fix_noise=False):
        sigma2_t_given_s, sigma_t_given_s, alpha_t_given_s = \
            self.sigma_and_alpha_t_given_s(gamma_t, gamma_s, zs_lig)

        mu_lig = alpha_t_given_s[ligand_mask] * zs_lig
        zt_lig, xh0_pocket = self.sample_normal_zero_com(
            mu_lig, xh0_pocket, sigma_t_given_s, ligand_mask, pocket_mask,
            fix_noise)

        return zt_lig, xh0_pocket

    def sample_p_zs_given_zt(self, s, t, zt_lig, xh0_pocket, ligand_mask,
                             pocket_mask, fix_noise=False):
        """Samples from zs ~ p(zs | zt). Only used during sampling."""
        gamma_s = self.gamma(s)
        gamma_t = self.gamma(t)

        sigma2_t_given_s, sigma_t_given_s, alpha_t_given_s = \
            self.sigma_and_alpha_t_given_s(gamma_t, gamma_s, zt_lig)

        sigma_s = self.sigma(gamma_s, target_tensor=zt_lig)
        sigma_t = self.sigma(gamma_t, target_tensor=zt_lig)

        # Neural net prediction.
        eps_t_lig, _ = self.dynamics(
            zt_lig, xh0_pocket, t, ligand_mask, pocket_mask)

        # Compute mu for p(zs | zt).
        # Note: mu_{t->s} = 1 / alpha_{t|s} z_t - sigma_{t|s}^2 / sigma_t / alpha_{t|s} epsilon
        # follows from the definition of mu_{t->s} and Eq. (7) in the EDM paper.
        mu_lig = zt_lig / alpha_t_given_s[ligand_mask] - \
                 (sigma2_t_given_s / alpha_t_given_s / sigma_t)[ligand_mask] * \
                 eps_t_lig

        # Compute sigma for p(zs | zt).
        sigma = sigma_t_given_s * sigma_s / sigma_t

        # Sample zs given the parameters derived from zt.
        zs_lig, xh0_pocket = self.sample_normal_zero_com(
            mu_lig, xh0_pocket, sigma, ligand_mask, pocket_mask, fix_noise)

        self.assert_mean_zero_with_mask(zt_lig[:, :self.n_dims], ligand_mask)

        return zs_lig, xh0_pocket
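
    # The transition parameters follow the standard Gaussian posterior
    # identities (assuming sigma_and_alpha_t_given_s is defined in the parent
    # class as usual for EDM-style models):
    #   alpha_{t|s}   = alpha_t / alpha_s
    #   sigma_{t|s}^2 = sigma_t^2 - alpha_{t|s}^2 * sigma_s^2
    # so the scale used above, sigma_{t|s} * sigma_s / sigma_t, is the standard
    # deviation of the posterior q(z_s | z_t, x).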

    def sample_combined_position_feature_noise(self, lig_indices, xh0_pocket,
                                               pocket_indices):
        """
        Samples mean-centered normal noise for z_x and standard normal noise
        for z_h.
        """
        raise NotImplementedError("Use sample_normal_zero_com() instead.")

    def sample(self, *args):
        raise NotImplementedError("Conditional model does not support sampling "
                                  "without a given pocket.")

    @torch.no_grad()
    def sample_given_pocket(self, pocket, num_nodes_lig, return_frames=1,
                            timesteps=None):
        """
        Draw samples from the generative model. Optionally, return intermediate
        states for visualization purposes.
        """
        timesteps = self.T if timesteps is None else timesteps
        assert 0 < return_frames <= timesteps
        assert timesteps % return_frames == 0

        n_samples = len(pocket['size'])
        device = pocket['x'].device

        _, pocket = self.normalize(pocket=pocket)

        # xh0_pocket is the original pocket while xh_pocket might be a
        # translated version of it.
        xh0_pocket = torch.cat([pocket['x'], pocket['one_hot']], dim=1)

        lig_mask = utils.num_nodes_to_batch_mask(
            n_samples, num_nodes_lig, device)

        # Sample from a normal distribution centered at the pocket center.
        mu_lig_x = scatter_mean(pocket['x'], pocket['mask'], dim=0)
        mu_lig_h = torch.zeros((n_samples, self.atom_nf), device=device)
        mu_lig = torch.cat((mu_lig_x, mu_lig_h), dim=1)[lig_mask]
        sigma = torch.ones_like(pocket['size']).unsqueeze(1)

        z_lig, xh_pocket = self.sample_normal_zero_com(
            mu_lig, xh0_pocket, sigma, lig_mask, pocket['mask'])

        self.assert_mean_zero_with_mask(z_lig[:, :self.n_dims], lig_mask)

        out_lig = torch.zeros((return_frames,) + z_lig.size(),
                              device=z_lig.device)
        out_pocket = torch.zeros((return_frames,) + xh_pocket.size(),
                                 device=device)

        # Iteratively sample p(z_s | z_t) for t = 1, ..., T, with s = t - 1.
        for s in reversed(range(0, timesteps)):
            s_array = torch.full((n_samples, 1), fill_value=s,
                                 device=z_lig.device)
            t_array = s_array + 1
            s_array = s_array / timesteps
            t_array = t_array / timesteps

            z_lig, xh_pocket = self.sample_p_zs_given_zt(
                s_array, t_array, z_lig, xh_pocket, lig_mask, pocket['mask'])

            # Save frame.
            if (s * return_frames) % timesteps == 0:
                idx = (s * return_frames) // timesteps
                out_lig[idx], out_pocket[idx] = \
                    self.unnormalize_z(z_lig, xh_pocket)

        # Finally sample p(x, h | z_0).
        x_lig, h_lig, x_pocket, h_pocket = self.sample_p_xh_given_z0(
            z_lig, xh_pocket, lig_mask, pocket['mask'], n_samples)

        self.assert_mean_zero_with_mask(x_lig, lig_mask)

        # Correct CoM drift for examples without intermediate states.
        if return_frames == 1:
            max_cog = scatter_add(x_lig, lig_mask, dim=0).abs().max().item()
            if max_cog > 5e-2:
                print(f'Warning: CoM drift with error {max_cog:.3f}. Projecting '
                      f'the positions down.')
                x_lig, x_pocket = self.remove_mean_batch(
                    x_lig, x_pocket, lig_mask, pocket['mask'])

        # Overwrite the last frame with the resulting x and h.
        out_lig[0] = torch.cat([x_lig, h_lig], dim=1)
        out_pocket[0] = torch.cat([x_pocket, h_pocket], dim=1)

        # Remove the frame dimension if only the final molecule is returned.
        return out_lig.squeeze(0), out_pocket.squeeze(0), lig_mask, \
               pocket['mask']
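
    # Hypothetical call (pocket dict with the same keys as in forward();
    # num_nodes_lig holds one ligand size per sample):
    #   lig, poc, lig_mask, poc_mask = model.sample_given_pocket(
    #       pocket, num_nodes_lig=torch.tensor([12, 18]))
    # With return_frames=1 the outputs are per-atom [x | h] tensors; with
    # return_frames > 1 a leading frame dimension is kept for visualization.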

    @torch.no_grad()
    def inpaint(self, ligand, pocket, lig_fixed, resamplings=1, return_frames=1,
                timesteps=None, center='ligand'):
        """
        Draw samples from the generative model while fixing parts of the input.
        Optionally, return intermediate states for visualization purposes.
        Inspired by Algorithm 1 in:
        Lugmayr, Andreas, et al.
        "RePaint: Inpainting using denoising diffusion probabilistic models."
        Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern
        Recognition. 2022.
        """
        timesteps = self.T if timesteps is None else timesteps
        assert 0 < return_frames <= timesteps
        assert timesteps % return_frames == 0

        if len(lig_fixed.size()) == 1:
            lig_fixed = lig_fixed.unsqueeze(1)

        n_samples = len(ligand['size'])
        device = pocket['x'].device

        # Normalize
        ligand, pocket = self.normalize(ligand, pocket)

        # xh0_pocket is the original pocket while xh_pocket might be a
        # translated version of it.
        xh0_pocket = torch.cat([pocket['x'], pocket['one_hot']], dim=1)
        com_pocket_0 = scatter_mean(pocket['x'], pocket['mask'], dim=0)
        xh0_ligand = torch.cat([ligand['x'], ligand['one_hot']], dim=1)
        xh_ligand = xh0_ligand.clone()

        # Center the initial system by subtracting the COM of the known parts.
        if center == 'ligand':
            mean_known = scatter_mean(ligand['x'][lig_fixed.bool().view(-1)],
                                      ligand['mask'][lig_fixed.bool().view(-1)],
                                      dim=0)
        elif center == 'pocket':
            mean_known = scatter_mean(pocket['x'], pocket['mask'], dim=0)
        else:
            raise NotImplementedError(
                f"Centering option {center} not implemented")

        # Sample from a normal distribution centered at the known parts.
        mu_lig_x = mean_known
        mu_lig_h = torch.zeros((n_samples, self.atom_nf), device=device)
        mu_lig = torch.cat((mu_lig_x, mu_lig_h), dim=1)[ligand['mask']]
        sigma = torch.ones_like(pocket['size']).unsqueeze(1)

        z_lig, xh_pocket = self.sample_normal_zero_com(
            mu_lig, xh0_pocket, sigma, ligand['mask'], pocket['mask'])

        # Output tensors
        out_lig = torch.zeros((return_frames,) + z_lig.size(),
                              device=z_lig.device)
        out_pocket = torch.zeros((return_frames,) + xh_pocket.size(),
                                 device=device)

        # Iteratively sample with resampling iterations.
        for s in reversed(range(0, timesteps)):

            # Resampling iterations.
            for u in range(resamplings):

                # Denoise one time step: t -> s.
                s_array = torch.full((n_samples, 1), fill_value=s,
                                     device=device)
                t_array = s_array + 1
                s_array = s_array / timesteps
                t_array = t_array / timesteps

                gamma_t = self.gamma(t_array)
                gamma_s = self.gamma(s_array)

                # Sample the inpainted part.
                z_lig_unknown, xh_pocket = self.sample_p_zs_given_zt(
                    s_array, t_array, z_lig, xh_pocket, ligand['mask'],
                    pocket['mask'])

                # Sample the known nodes from the input.
                com_pocket = scatter_mean(xh_pocket[:, :self.n_dims],
                                          pocket['mask'], dim=0)
                xh_ligand[:, :self.n_dims] = \
                    ligand['x'] + (com_pocket - com_pocket_0)[ligand['mask']]
                z_lig_known, xh_pocket, _ = self.noised_representation(
                    xh_ligand, xh_pocket, ligand['mask'], pocket['mask'],
                    gamma_s)

                # Move the center of mass of the noised part to the center of
                # mass of the corresponding denoised part before combining them
                # -> the resulting system should be COM-free.
                com_noised = scatter_mean(
                    z_lig_known[lig_fixed.bool().view(-1)][:, :self.n_dims],
                    ligand['mask'][lig_fixed.bool().view(-1)], dim=0)
                com_denoised = scatter_mean(
                    z_lig_unknown[lig_fixed.bool().view(-1)][:, :self.n_dims],
                    ligand['mask'][lig_fixed.bool().view(-1)], dim=0)
                dx = com_denoised - com_noised
                z_lig_known[:, :self.n_dims] = z_lig_known[:, :self.n_dims] + dx[ligand['mask']]
                xh_pocket[:, :self.n_dims] = xh_pocket[:, :self.n_dims] + dx[pocket['mask']]

                # Combine.
                z_lig = z_lig_known * lig_fixed + z_lig_unknown * (
                            1 - lig_fixed)

                if u < resamplings - 1:
                    # Noise the sample.
                    z_lig, xh_pocket = self.sample_p_zt_given_zs(
                        z_lig, xh_pocket, ligand['mask'], pocket['mask'],
                        gamma_t, gamma_s)

                # Save a frame at the end of a resampling cycle.
                if u == resamplings - 1:
                    if (s * return_frames) % timesteps == 0:
                        idx = (s * return_frames) // timesteps

                        out_lig[idx], out_pocket[idx] = \
                            self.unnormalize_z(z_lig, xh_pocket)

        # Finally sample p(x, h | z_0).
        x_lig, h_lig, x_pocket, h_pocket = self.sample_p_xh_given_z0(
            z_lig, xh_pocket, ligand['mask'], pocket['mask'], n_samples)

        # Overwrite the last frame with the resulting x and h.
        out_lig[0] = torch.cat([x_lig, h_lig], dim=1)
        out_pocket[0] = torch.cat([x_pocket, h_pocket], dim=1)

        # Remove the frame dimension if only the final molecule is returned.
        return out_lig.squeeze(0), out_pocket.squeeze(0), ligand['mask'], \
               pocket['mask']
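
    # lig_fixed is interpreted per atom: entries equal to 1 are kept (re-noised
    # from the input at every step), entries equal to 0 are generated, as in
    # the combination step
    #   z_lig = z_lig_known * lig_fixed + z_lig_unknown * (1 - lig_fixed).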

    @classmethod
    def remove_mean_batch(cls, x_lig, x_pocket, lig_indices, pocket_indices):

        # Subtract the center of mass of the sampled (ligand) part from both
        # the ligand and the pocket.
        mean = scatter_mean(x_lig, lig_indices, dim=0)

        x_lig = x_lig - mean[lig_indices]
        x_pocket = x_pocket - mean[pocket_indices]
        return x_lig, x_pocket
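
    # Example: if the ligand atoms of one sample sit at (1, 1, 1) and
    # (3, 3, 3), their mean (2, 2, 2) is subtracted from those atoms *and* from
    # the pocket atoms of the same sample, translating the whole complex so the
    # ligand is COM-free while the relative geometry is preserved.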


# ------------------------------------------------------------------------------
# The same model without the subspace trick
# ------------------------------------------------------------------------------
class SimpleConditionalDDPM(ConditionalDDPM):
    """
    Simpler conditional diffusion module without the subspace trick.
    - Rotational equivariance is guaranteed by construction.
    - A translationally equivariant likelihood is achieved by first mapping
      samples to a space where the context is COM-free and evaluating the
      likelihood there.
    - Molecule generation is equivariant because we can first sample in the
      space where the context is COM-free and translate the whole system back
      to the original position of the context later.
    """
    def subspace_dimensionality(self, input_size):
        """ Override because we don't use the linear subspace anymore. """
        return input_size * self.n_dims

    @classmethod
    def remove_mean_batch(cls, x_lig, x_pocket, lig_indices, pocket_indices):
        """ Hacky way of removing the centering steps without changing too much
        code. """
        return x_lig, x_pocket

    @staticmethod
    def assert_mean_zero_with_mask(x, node_mask, eps=1e-10):
        return

    def forward(self, ligand, pocket, return_info=False):

        # Subtract the pocket center of mass.
        pocket_com = scatter_mean(pocket['x'], pocket['mask'], dim=0)
        ligand['x'] = ligand['x'] - pocket_com[ligand['mask']]
        pocket['x'] = pocket['x'] - pocket_com[pocket['mask']]

        return super(SimpleConditionalDDPM, self).forward(
            ligand, pocket, return_info)

    @torch.no_grad()
    def sample_given_pocket(self, pocket, num_nodes_lig, return_frames=1,
                            timesteps=None):

        # Subtract the pocket center of mass.
        pocket_com = scatter_mean(pocket['x'], pocket['mask'], dim=0)
        pocket['x'] = pocket['x'] - pocket_com[pocket['mask']]

        return super(SimpleConditionalDDPM, self).sample_given_pocket(
            pocket, num_nodes_lig, return_frames, timesteps)