import math
import numpy as np
import torch
import torch.nn.functional as F
from torch_scatter import scatter_add, scatter_mean
import utils
from equivariant_diffusion.en_diffusion import EnVariationalDiffusion
class ConditionalDDPM(EnVariationalDiffusion):
"""
Conditional Diffusion Module.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
assert not self.dynamics.update_pocket_coords
def kl_prior(self, xh_lig, mask_lig, num_nodes):
"""Computes the KL between q(z1 | x) and the prior p(z1) = Normal(0, 1).
This is essentially a lot of work for something that is in practice
negligible in the loss. However, you compute it so that you see it when
you've made a mistake in your noise schedule.
"""
batch_size = len(num_nodes)
# Compute the last alpha value, alpha_T.
ones = torch.ones((batch_size, 1), device=xh_lig.device)
gamma_T = self.gamma(ones)
alpha_T = self.alpha(gamma_T, xh_lig)
# Compute means.
mu_T_lig = alpha_T[mask_lig] * xh_lig
mu_T_lig_x, mu_T_lig_h = \
mu_T_lig[:, :self.n_dims], mu_T_lig[:, self.n_dims:]
        # Compute standard deviations (one value per batch element).
sigma_T_x = self.sigma(gamma_T, mu_T_lig_x).squeeze()
sigma_T_h = self.sigma(gamma_T, mu_T_lig_h).squeeze()
# Compute KL for h-part.
zeros = torch.zeros_like(mu_T_lig_h)
ones = torch.ones_like(sigma_T_h)
mu_norm2 = self.sum_except_batch((mu_T_lig_h - zeros) ** 2, mask_lig)
kl_distance_h = self.gaussian_KL(mu_norm2, sigma_T_h, ones, d=1)
# Compute KL for x-part.
zeros = torch.zeros_like(mu_T_lig_x)
ones = torch.ones_like(sigma_T_x)
mu_norm2 = self.sum_except_batch((mu_T_lig_x - zeros) ** 2, mask_lig)
subspace_d = self.subspace_dimensionality(num_nodes)
kl_distance_x = self.gaussian_KL(mu_norm2, sigma_T_x, ones, subspace_d)
return kl_distance_x + kl_distance_h
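    # Illustrative note (not part of the model): for diagonal Gaussians the KL
    # used above has the closed form
    #   KL(N(mu, sigma^2 I_d) || N(0, I_d))
    #       = d * (log(1 / sigma) + 0.5 * sigma^2 - 0.5) + 0.5 * ||mu||^2,
    # so with a well-behaved noise schedule (alpha_T close to 0, sigma_T close
    # to 1) both the x- and h-contributions computed above are close to zero.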
def log_pxh_given_z0_without_constants(self, ligand, z_0_lig, eps_lig,
net_out_lig, gamma_0, epsilon=1e-10):
# Discrete properties are predicted directly from z_t.
z_h_lig = z_0_lig[:, self.n_dims:]
# Take only part over x.
eps_lig_x = eps_lig[:, :self.n_dims]
net_lig_x = net_out_lig[:, :self.n_dims]
# Compute sigma_0 and rescale to the integer scale of the data.
sigma_0 = self.sigma(gamma_0, target_tensor=z_0_lig)
sigma_0_cat = sigma_0 * self.norm_values[1]
        # Compute the error for the distribution
        # N(x | z_0 / alpha_0 + (sigma_0 / alpha_0) * eps_0, sigma_0 / alpha_0).
        # In the epsilon parametrization, the weighting is exactly 1.
squared_error = (eps_lig_x - net_lig_x) ** 2
if self.vnode_idx is not None:
# coordinates of virtual atoms should not contribute to the error
squared_error[ligand['one_hot'][:, self.vnode_idx].bool(), :self.n_dims] = 0
log_p_x_given_z0_without_constants_ligand = -0.5 * (
self.sum_except_batch(squared_error, ligand['mask'])
)
# Compute delta indicator masks.
# un-normalize
ligand_onehot = ligand['one_hot'] * self.norm_values[1] + self.norm_biases[1]
estimated_ligand_onehot = z_h_lig * self.norm_values[1] + self.norm_biases[1]
        # Center h_cat around 1, since it is one-hot encoded.
centered_ligand_onehot = estimated_ligand_onehot - 1
# Compute integrals from 0.5 to 1.5 of the normal distribution
# N(mean=z_h_cat, stdev=sigma_0_cat)
log_ph_cat_proportional_ligand = torch.log(
self.cdf_standard_gaussian((centered_ligand_onehot + 0.5) / sigma_0_cat[ligand['mask']])
- self.cdf_standard_gaussian((centered_ligand_onehot - 0.5) / sigma_0_cat[ligand['mask']])
+ epsilon
)
# Normalize the distribution over the categories.
log_Z = torch.logsumexp(log_ph_cat_proportional_ligand, dim=1,
keepdim=True)
log_probabilities_ligand = log_ph_cat_proportional_ligand - log_Z
# Select the log_prob of the current category using the onehot
# representation.
log_ph_given_z0_ligand = self.sum_except_batch(
log_probabilities_ligand * ligand_onehot, ligand['mask'])
return log_p_x_given_z0_without_constants_ligand, log_ph_given_z0_ligand
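    # Illustrative note (not part of the model): the categorical term above
    # integrates a Gaussian centered at the un-normalized prediction over a
    # unit-width bin around the one-hot value, i.e. per category k the
    # probability is proportional to
    #   Phi((z_k - 1 + 0.5) / sigma_0) - Phi((z_k - 1 - 0.5) / sigma_0),
    # where Phi is the standard normal CDF; the log-probabilities are then
    # normalized over categories with a logsumexp.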
def sample_p_xh_given_z0(self, z0_lig, xh0_pocket, lig_mask, pocket_mask,
batch_size, fix_noise=False):
"""Samples x ~ p(x|z0)."""
t_zeros = torch.zeros(size=(batch_size, 1), device=z0_lig.device)
gamma_0 = self.gamma(t_zeros)
# Computes sqrt(sigma_0^2 / alpha_0^2)
sigma_x = self.SNR(-0.5 * gamma_0)
net_out_lig, _ = self.dynamics(
z0_lig, xh0_pocket, t_zeros, lig_mask, pocket_mask)
        # Compute the predicted mean of p(x | z_0) from the network output.
mu_x_lig = self.compute_x_pred(net_out_lig, z0_lig, gamma_0, lig_mask)
xh_lig, xh0_pocket = self.sample_normal_zero_com(
mu_x_lig, xh0_pocket, sigma_x, lig_mask, pocket_mask, fix_noise)
x_lig, h_lig = self.unnormalize(
xh_lig[:, :self.n_dims], z0_lig[:, self.n_dims:])
x_pocket, h_pocket = self.unnormalize(
xh0_pocket[:, :self.n_dims], xh0_pocket[:, self.n_dims:])
h_lig = F.one_hot(torch.argmax(h_lig, dim=1), self.atom_nf)
# h_pocket = F.one_hot(torch.argmax(h_pocket, dim=1), self.residue_nf)
return x_lig, h_lig, x_pocket, h_pocket
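    # Illustrative note (assuming the parent class uses the standard EDM
    # parametrization sigma^2 = sigmoid(gamma), alpha^2 = sigmoid(-gamma) and
    # SNR(gamma) = exp(-gamma)): the quantity computed above is
    #   sigma_x = sqrt(sigma_0^2 / alpha_0^2) = exp(gamma_0 / 2)
    #           = SNR(-0.5 * gamma_0),
    # i.e. the standard deviation of p(x | z_0).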
def sample_normal(self, *args):
raise NotImplementedError("Has been replaced by sample_normal_zero_com()")
def sample_normal_zero_com(self, mu_lig, xh0_pocket, sigma, lig_mask,
pocket_mask, fix_noise=False):
"""Samples from a Normal distribution."""
if fix_noise:
# bs = 1 if fix_noise else mu.size(0)
raise NotImplementedError("fix_noise option isn't implemented yet")
eps_lig = self.sample_gaussian(
size=(len(lig_mask), self.n_dims + self.atom_nf),
device=lig_mask.device)
out_lig = mu_lig + sigma[lig_mask] * eps_lig
# project to COM-free subspace
xh_pocket = xh0_pocket.detach().clone()
out_lig[:, :self.n_dims], xh_pocket[:, :self.n_dims] = \
self.remove_mean_batch(out_lig[:, :self.n_dims],
xh0_pocket[:, :self.n_dims],
lig_mask, pocket_mask)
return out_lig, xh_pocket
def noised_representation(self, xh_lig, xh0_pocket, lig_mask, pocket_mask,
gamma_t):
# Compute alpha_t and sigma_t from gamma.
alpha_t = self.alpha(gamma_t, xh_lig)
sigma_t = self.sigma(gamma_t, xh_lig)
        # Draw standard normal noise eps ~ N(0, I).
eps_lig = self.sample_gaussian(
size=(len(lig_mask), self.n_dims + self.atom_nf),
device=lig_mask.device)
# Sample z_t given x, h for timestep t, from q(z_t | x, h)
z_t_lig = alpha_t[lig_mask] * xh_lig + sigma_t[lig_mask] * eps_lig
# project to COM-free subspace
xh_pocket = xh0_pocket.detach().clone()
z_t_lig[:, :self.n_dims], xh_pocket[:, :self.n_dims] = \
self.remove_mean_batch(z_t_lig[:, :self.n_dims],
xh_pocket[:, :self.n_dims],
lig_mask, pocket_mask)
return z_t_lig, xh_pocket, eps_lig
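    # Illustrative note: the sampled eps_lig is returned together with z_t_lig
    # because it is the regression target of the denoising network, i.e. the
    # training loss compares eps_lig against the prediction eps_hat(z_t, t).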
def log_pN(self, N_lig, N_pocket):
"""
Prior on the sample size for computing
log p(x,h,N) = log p(x,h|N) + log p(N), where log p(x,h|N) is the
model's output
Args:
N: array of sample sizes
Returns:
log p(N)
"""
log_pN = self.size_distribution.log_prob_n1_given_n2(N_lig, N_pocket)
return log_pN
def delta_log_px(self, num_nodes):
return -self.subspace_dimensionality(num_nodes) * \
np.log(self.norm_values[0])
def forward(self, ligand, pocket, return_info=False):
"""
Computes the loss and NLL terms
"""
# Normalize data, take into account volume change in x.
ligand, pocket = self.normalize(ligand, pocket)
# Likelihood change due to normalization
# if self.vnode_idx is not None:
# delta_log_px = self.delta_log_px(ligand['size'] - ligand['num_virtual_atoms'] + pocket['size'])
# else:
delta_log_px = self.delta_log_px(ligand['size'])
# Sample a timestep t for each example in batch
# At evaluation time, loss_0 will be computed separately to decrease
# variance in the estimator (costs two forward passes)
lowest_t = 0 if self.training else 1
t_int = torch.randint(
lowest_t, self.T + 1, size=(ligand['size'].size(0), 1),
device=ligand['x'].device).float()
s_int = t_int - 1 # previous timestep
# Masks: important to compute log p(x | z0).
t_is_zero = (t_int == 0).float()
t_is_not_zero = 1 - t_is_zero
        # Normalize t to [0, 1]. Note that the negative value s = -1/T will
        # never be used, because in that case p(x | z0) is computed instead.
s = s_int / self.T
t = t_int / self.T
# Compute gamma_s and gamma_t via the network.
gamma_s = self.inflate_batch_array(self.gamma(s), ligand['x'])
gamma_t = self.inflate_batch_array(self.gamma(t), ligand['x'])
# Concatenate x, and h[categorical].
xh0_lig = torch.cat([ligand['x'], ligand['one_hot']], dim=1)
xh0_pocket = torch.cat([pocket['x'], pocket['one_hot']], dim=1)
# Center the input nodes
xh0_lig[:, :self.n_dims], xh0_pocket[:, :self.n_dims] = \
self.remove_mean_batch(xh0_lig[:, :self.n_dims],
xh0_pocket[:, :self.n_dims],
ligand['mask'], pocket['mask'])
# Find noised representation
z_t_lig, xh_pocket, eps_t_lig = \
self.noised_representation(xh0_lig, xh0_pocket, ligand['mask'],
pocket['mask'], gamma_t)
# Neural net prediction.
net_out_lig, _ = self.dynamics(
z_t_lig, xh_pocket, t, ligand['mask'], pocket['mask'])
# For LJ loss term
# xh_lig_hat does not need to be zero-centered as it is only used for
# computing relative distances
xh_lig_hat = self.xh_given_zt_and_epsilon(z_t_lig, net_out_lig, gamma_t,
ligand['mask'])
# Compute the L2 error.
squared_error = (eps_t_lig - net_out_lig) ** 2
if self.vnode_idx is not None:
# coordinates of virtual atoms should not contribute to the error
squared_error[ligand['one_hot'][:, self.vnode_idx].bool(), :self.n_dims] = 0
error_t_lig = self.sum_except_batch(squared_error, ligand['mask'])
        # Compute weighting with SNR: (1 - SNR(gamma_s - gamma_t)) for the
        # epsilon parametrization.
SNR_weight = (1 - self.SNR(gamma_s - gamma_t)).squeeze(1)
assert error_t_lig.size() == SNR_weight.size()
# The _constants_ depending on sigma_0 from the
# cross entropy term E_q(z0 | x) [log p(x | z0)].
neg_log_constants = -self.log_constants_p_x_given_z0(
n_nodes=ligand['size'], device=error_t_lig.device)
# The KL between q(zT | x) and p(zT) = Normal(0, 1).
# Should be close to zero.
kl_prior = self.kl_prior(xh0_lig, ligand['mask'], ligand['size'])
if self.training:
            # Compute the L_0 term (even if gamma_t is not actually gamma_0);
            # it will later be selected via masking.
log_p_x_given_z0_without_constants_ligand, log_ph_given_z0 = \
self.log_pxh_given_z0_without_constants(
ligand, z_t_lig, eps_t_lig, net_out_lig, gamma_t)
loss_0_x_ligand = -log_p_x_given_z0_without_constants_ligand * \
t_is_zero.squeeze()
loss_0_h = -log_ph_given_z0 * t_is_zero.squeeze()
# apply t_is_zero mask
error_t_lig = error_t_lig * t_is_not_zero.squeeze()
else:
# Compute noise values for t = 0.
t_zeros = torch.zeros_like(s)
gamma_0 = self.inflate_batch_array(self.gamma(t_zeros), ligand['x'])
            # Sample z_0 given x, h from q(z_0 | x, h)
z_0_lig, xh_pocket, eps_0_lig = \
self.noised_representation(xh0_lig, xh0_pocket, ligand['mask'],
pocket['mask'], gamma_0)
net_out_0_lig, _ = self.dynamics(
z_0_lig, xh_pocket, t_zeros, ligand['mask'], pocket['mask'])
log_p_x_given_z0_without_constants_ligand, log_ph_given_z0 = \
self.log_pxh_given_z0_without_constants(
ligand, z_0_lig, eps_0_lig, net_out_0_lig, gamma_0)
loss_0_x_ligand = -log_p_x_given_z0_without_constants_ligand
loss_0_h = -log_ph_given_z0
# sample size prior
log_pN = self.log_pN(ligand['size'], pocket['size'])
info = {
'eps_hat_lig_x': scatter_mean(
net_out_lig[:, :self.n_dims].abs().mean(1), ligand['mask'],
dim=0).mean(),
'eps_hat_lig_h': scatter_mean(
net_out_lig[:, self.n_dims:].abs().mean(1), ligand['mask'],
dim=0).mean(),
}
loss_terms = (delta_log_px, error_t_lig, torch.tensor(0.0), SNR_weight,
loss_0_x_ligand, torch.tensor(0.0), loss_0_h,
neg_log_constants, kl_prior, log_pN,
t_int.squeeze(), xh_lig_hat)
return (*loss_terms, info) if return_info else loss_terms
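    # Note on the returned terms (descriptive; the combination itself is left
    # to the caller): error_t_lig is the per-sample squared epsilon-prediction
    # error and SNR_weight its VLB weight; loss_0_x_ligand and loss_0_h are the
    # t = 0 reconstruction terms; neg_log_constants and kl_prior are the
    # remaining VLB constants; log_pN is the size prior and delta_log_px the
    # change of variables due to normalization. The caller is expected to
    # combine these into the final (weighted VLB or simplified l2) objective.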
def partially_noised_ligand(self, ligand, pocket, noising_steps):
"""
Partially noises a ligand to be later denoised.
"""
# Inflate timestep into an array
t_int = torch.ones(size=(ligand['size'].size(0), 1),
device=ligand['x'].device).float() * noising_steps
# Normalize t to [0, 1].
t = t_int / self.T
        # Compute gamma_t via the network.
gamma_t = self.inflate_batch_array(self.gamma(t), ligand['x'])
# Concatenate x, and h[categorical].
xh0_lig = torch.cat([ligand['x'], ligand['one_hot']], dim=1)
xh0_pocket = torch.cat([pocket['x'], pocket['one_hot']], dim=1)
# Center the input nodes
xh0_lig[:, :self.n_dims], xh0_pocket[:, :self.n_dims] = \
self.remove_mean_batch(xh0_lig[:, :self.n_dims],
xh0_pocket[:, :self.n_dims],
ligand['mask'], pocket['mask'])
# Find noised representation
z_t_lig, xh_pocket, eps_t_lig = \
self.noised_representation(xh0_lig, xh0_pocket, ligand['mask'],
pocket['mask'], gamma_t)
return z_t_lig, xh_pocket, eps_t_lig
def diversify(self, ligand, pocket, noising_steps):
"""
        Diversifies a set of ligands by partially noising and then denoising them.
"""
# Normalize data, take into account volume change in x.
ligand, pocket = self.normalize(ligand, pocket)
z_lig, xh_pocket, _ = self.partially_noised_ligand(ligand, pocket, noising_steps)
timesteps = self.T
n_samples = len(pocket['size'])
device = pocket['x'].device
# xh0_pocket is the original pocket while xh_pocket might be a
# translated version of it
xh0_pocket = torch.cat([pocket['x'], pocket['one_hot']], dim=1)
lig_mask = ligand['mask']
self.assert_mean_zero_with_mask(z_lig[:, :self.n_dims], lig_mask)
        # Iteratively sample p(z_s | z_t) for t = noising_steps, ..., 1, with s = t - 1.
for s in reversed(range(0, noising_steps)):
s_array = torch.full((n_samples, 1), fill_value=s,
device=z_lig.device)
t_array = s_array + 1
s_array = s_array / timesteps
t_array = t_array / timesteps
z_lig, xh_pocket = self.sample_p_zs_given_zt(
s_array, t_array, z_lig.detach(), xh_pocket.detach(), lig_mask, pocket['mask'])
# Finally sample p(x, h | z_0).
x_lig, h_lig, x_pocket, h_pocket = self.sample_p_xh_given_z0(
z_lig, xh_pocket, lig_mask, pocket['mask'], n_samples)
self.assert_mean_zero_with_mask(x_lig, lig_mask)
        # Concatenate the resulting x and h.
        out_lig = torch.cat([x_lig, h_lig], dim=1)
        out_pocket = torch.cat([x_pocket, h_pocket], dim=1)
        return out_lig, out_pocket, lig_mask, pocket['mask']
def xh_given_zt_and_epsilon(self, z_t, epsilon, gamma_t, batch_mask):
""" Equation (7) in the EDM paper """
alpha_t = self.alpha(gamma_t, z_t)
sigma_t = self.sigma(gamma_t, z_t)
xh = z_t / alpha_t[batch_mask] - epsilon * sigma_t[batch_mask] / \
alpha_t[batch_mask]
return xh
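    # Worked form of Eq. (7) used above (illustrative): given
    #   z_t = alpha_t * xh + sigma_t * eps,
    # solving for the clean sample yields
    #   xh_hat = z_t / alpha_t - (sigma_t / alpha_t) * eps_hat,
    # where eps_hat is the network's noise prediction.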
def sample_p_zt_given_zs(self, zs_lig, xh0_pocket, ligand_mask, pocket_mask,
gamma_t, gamma_s, fix_noise=False):
sigma2_t_given_s, sigma_t_given_s, alpha_t_given_s = \
self.sigma_and_alpha_t_given_s(gamma_t, gamma_s, zs_lig)
mu_lig = alpha_t_given_s[ligand_mask] * zs_lig
zt_lig, xh0_pocket = self.sample_normal_zero_com(
mu_lig, xh0_pocket, sigma_t_given_s, ligand_mask, pocket_mask,
fix_noise)
return zt_lig, xh0_pocket
def sample_p_zs_given_zt(self, s, t, zt_lig, xh0_pocket, ligand_mask,
pocket_mask, fix_noise=False):
"""Samples from zs ~ p(zs | zt). Only used during sampling."""
gamma_s = self.gamma(s)
gamma_t = self.gamma(t)
sigma2_t_given_s, sigma_t_given_s, alpha_t_given_s = \
self.sigma_and_alpha_t_given_s(gamma_t, gamma_s, zt_lig)
sigma_s = self.sigma(gamma_s, target_tensor=zt_lig)
sigma_t = self.sigma(gamma_t, target_tensor=zt_lig)
# Neural net prediction.
eps_t_lig, _ = self.dynamics(
zt_lig, xh0_pocket, t, ligand_mask, pocket_mask)
# Compute mu for p(zs | zt).
        # Note: mu_{t->s} = z_t / alpha_{t|s}
        #                   - (sigma_{t|s}^2 / (alpha_{t|s} * sigma_t)) * epsilon
        # follows from the definition of mu_{t->s} and Eq. (7) in the EDM paper.
mu_lig = zt_lig / alpha_t_given_s[ligand_mask] - \
(sigma2_t_given_s / alpha_t_given_s / sigma_t)[ligand_mask] * \
eps_t_lig
# Compute sigma for p(zs | zt).
sigma = sigma_t_given_s * sigma_s / sigma_t
# Sample zs given the parameters derived from zt.
zs_lig, xh0_pocket = self.sample_normal_zero_com(
mu_lig, xh0_pocket, sigma, ligand_mask, pocket_mask, fix_noise)
self.assert_mean_zero_with_mask(zt_lig[:, :self.n_dims], ligand_mask)
return zs_lig, xh0_pocket
def sample_combined_position_feature_noise(self, lig_indices, xh0_pocket,
pocket_indices):
"""
Samples mean-centered normal noise for z_x, and standard normal noise
for z_h.
"""
raise NotImplementedError("Use sample_normal_zero_com() instead.")
def sample(self, *args):
raise NotImplementedError("Conditional model does not support sampling "
"without given pocket.")
@torch.no_grad()
def sample_given_pocket(self, pocket, num_nodes_lig, return_frames=1,
timesteps=None):
"""
Draw samples from the generative model. Optionally, return intermediate
states for visualization purposes.
"""
timesteps = self.T if timesteps is None else timesteps
assert 0 < return_frames <= timesteps
assert timesteps % return_frames == 0
n_samples = len(pocket['size'])
device = pocket['x'].device
_, pocket = self.normalize(pocket=pocket)
# xh0_pocket is the original pocket while xh_pocket might be a
# translated version of it
xh0_pocket = torch.cat([pocket['x'], pocket['one_hot']], dim=1)
lig_mask = utils.num_nodes_to_batch_mask(
n_samples, num_nodes_lig, device)
        # Sample initial ligand noise from a normal distribution centered at
        # the pocket's center of mass.
mu_lig_x = scatter_mean(pocket['x'], pocket['mask'], dim=0)
mu_lig_h = torch.zeros((n_samples, self.atom_nf), device=device)
mu_lig = torch.cat((mu_lig_x, mu_lig_h), dim=1)[lig_mask]
sigma = torch.ones_like(pocket['size']).unsqueeze(1)
z_lig, xh_pocket = self.sample_normal_zero_com(
mu_lig, xh0_pocket, sigma, lig_mask, pocket['mask'])
self.assert_mean_zero_with_mask(z_lig[:, :self.n_dims], lig_mask)
out_lig = torch.zeros((return_frames,) + z_lig.size(),
device=z_lig.device)
out_pocket = torch.zeros((return_frames,) + xh_pocket.size(),
device=device)
        # Iteratively sample p(z_s | z_t) for t = timesteps, ..., 1, with s = t - 1.
for s in reversed(range(0, timesteps)):
s_array = torch.full((n_samples, 1), fill_value=s,
device=z_lig.device)
t_array = s_array + 1
s_array = s_array / timesteps
t_array = t_array / timesteps
z_lig, xh_pocket = self.sample_p_zs_given_zt(
s_array, t_array, z_lig, xh_pocket, lig_mask, pocket['mask'])
# save frame
if (s * return_frames) % timesteps == 0:
idx = (s * return_frames) // timesteps
out_lig[idx], out_pocket[idx] = \
self.unnormalize_z(z_lig, xh_pocket)
# Finally sample p(x, h | z_0).
x_lig, h_lig, x_pocket, h_pocket = self.sample_p_xh_given_z0(
z_lig, xh_pocket, lig_mask, pocket['mask'], n_samples)
self.assert_mean_zero_with_mask(x_lig, lig_mask)
# Correct CoM drift for examples without intermediate states
if return_frames == 1:
max_cog = scatter_add(x_lig, lig_mask, dim=0).abs().max().item()
if max_cog > 5e-2:
print(f'Warning CoG drift with error {max_cog:.3f}. Projecting '
f'the positions down.')
x_lig, x_pocket = self.remove_mean_batch(
x_lig, x_pocket, lig_mask, pocket['mask'])
# Overwrite last frame with the resulting x and h.
out_lig[0] = torch.cat([x_lig, h_lig], dim=1)
out_pocket[0] = torch.cat([x_pocket, h_pocket], dim=1)
# remove frame dimension if only the final molecule is returned
return out_lig.squeeze(0), out_pocket.squeeze(0), lig_mask, \
pocket['mask']
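    # Hedged usage sketch (hypothetical tensors, for illustration only):
    # `pocket` is a dict of per-node tensors plus per-sample sizes and a batch
    # index, and `num_nodes_lig` holds the desired ligand sizes, e.g.
    #   pocket = {'x': coords, 'one_hot': types, 'mask': batch_index,
    #             'size': torch.tensor([n1, n2])}
    #   num_nodes_lig = torch.tensor([12, 18])
    #   xh_lig, xh_pocket, lig_mask, pocket_mask = \
    #       model.sample_given_pocket(pocket, num_nodes_lig, return_frames=1)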
@torch.no_grad()
def inpaint(self, ligand, pocket, lig_fixed, resamplings=1, return_frames=1,
timesteps=None, center='ligand'):
"""
Draw samples from the generative model while fixing parts of the input.
Optionally, return intermediate states for visualization purposes.
Inspired by Algorithm 1 in:
Lugmayr, Andreas, et al.
"Repaint: Inpainting using denoising diffusion probabilistic models."
Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern
Recognition. 2022.
"""
timesteps = self.T if timesteps is None else timesteps
assert 0 < return_frames <= timesteps
assert timesteps % return_frames == 0
if len(lig_fixed.size()) == 1:
lig_fixed = lig_fixed.unsqueeze(1)
n_samples = len(ligand['size'])
device = pocket['x'].device
# Normalize
ligand, pocket = self.normalize(ligand, pocket)
# xh0_pocket is the original pocket while xh_pocket might be a
# translated version of it
xh0_pocket = torch.cat([pocket['x'], pocket['one_hot']], dim=1)
com_pocket_0 = scatter_mean(pocket['x'], pocket['mask'], dim=0)
xh0_ligand = torch.cat([ligand['x'], ligand['one_hot']], dim=1)
xh_ligand = xh0_ligand.clone()
# Center initial system, subtract COM of known parts
if center == 'ligand':
mean_known = scatter_mean(ligand['x'][lig_fixed.bool().view(-1)],
ligand['mask'][lig_fixed.bool().view(-1)],
dim=0)
elif center == 'pocket':
mean_known = scatter_mean(pocket['x'], pocket['mask'], dim=0)
else:
raise NotImplementedError(
f"Centering option {center} not implemented")
        # Sample initial ligand noise from a normal distribution centered at
        # the known part's center of mass.
mu_lig_x = mean_known
mu_lig_h = torch.zeros((n_samples, self.atom_nf), device=device)
mu_lig = torch.cat((mu_lig_x, mu_lig_h), dim=1)[ligand['mask']]
sigma = torch.ones_like(pocket['size']).unsqueeze(1)
z_lig, xh_pocket = self.sample_normal_zero_com(
mu_lig, xh0_pocket, sigma, ligand['mask'], pocket['mask'])
# Output tensors
out_lig = torch.zeros((return_frames,) + z_lig.size(),
device=z_lig.device)
out_pocket = torch.zeros((return_frames,) + xh_pocket.size(),
device=device)
        # Iteratively sample p(z_s | z_t), with `resamplings` resampling
        # iterations per timestep.
for s in reversed(range(0, timesteps)):
# resampling iterations
for u in range(resamplings):
# Denoise one time step: t -> s
s_array = torch.full((n_samples, 1), fill_value=s,
device=device)
t_array = s_array + 1
s_array = s_array / timesteps
t_array = t_array / timesteps
gamma_t = self.gamma(t_array)
gamma_s = self.gamma(s_array)
# sample inpainted part
z_lig_unknown, xh_pocket = self.sample_p_zs_given_zt(
s_array, t_array, z_lig, xh_pocket, ligand['mask'],
pocket['mask'])
# sample known nodes from the input
com_pocket = scatter_mean(xh_pocket[:, :self.n_dims],
pocket['mask'], dim=0)
xh_ligand[:, :self.n_dims] = \
ligand['x'] + (com_pocket - com_pocket_0)[ligand['mask']]
z_lig_known, xh_pocket, _ = self.noised_representation(
xh_ligand, xh_pocket, ligand['mask'], pocket['mask'],
gamma_s)
# move center of mass of the noised part to the center of mass
# of the corresponding denoised part before combining them
# -> the resulting system should be COM-free
com_noised = scatter_mean(
z_lig_known[lig_fixed.bool().view(-1)][:, :self.n_dims],
ligand['mask'][lig_fixed.bool().view(-1)], dim=0)
com_denoised = scatter_mean(
z_lig_unknown[lig_fixed.bool().view(-1)][:, :self.n_dims],
ligand['mask'][lig_fixed.bool().view(-1)], dim=0)
dx = com_denoised - com_noised
z_lig_known[:, :self.n_dims] = z_lig_known[:, :self.n_dims] + dx[ligand['mask']]
xh_pocket[:, :self.n_dims] = xh_pocket[:, :self.n_dims] + dx[pocket['mask']]
# combine
z_lig = z_lig_known * lig_fixed + z_lig_unknown * (
1 - lig_fixed)
if u < resamplings - 1:
# Noise the sample
z_lig, xh_pocket = self.sample_p_zt_given_zs(
z_lig, xh_pocket, ligand['mask'], pocket['mask'],
gamma_t, gamma_s)
# save frame at the end of a resampling cycle
if u == resamplings - 1:
if (s * return_frames) % timesteps == 0:
idx = (s * return_frames) // timesteps
out_lig[idx], out_pocket[idx] = \
self.unnormalize_z(z_lig, xh_pocket)
# Finally sample p(x, h | z_0).
x_lig, h_lig, x_pocket, h_pocket = self.sample_p_xh_given_z0(
z_lig, xh_pocket, ligand['mask'], pocket['mask'], n_samples)
# Overwrite last frame with the resulting x and h.
out_lig[0] = torch.cat([x_lig, h_lig], dim=1)
out_pocket[0] = torch.cat([x_pocket, h_pocket], dim=1)
# remove frame dimension if only the final molecule is returned
return out_lig.squeeze(0), out_pocket.squeeze(0), ligand['mask'], \
pocket['mask']
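    # Hedged usage sketch (hypothetical tensors, for illustration only): to
    # keep a scaffold fixed while re-sampling the rest of the ligand,
    # `lig_fixed` is a per-atom indicator aligned with ligand['mask'], e.g.
    #   lig_fixed = torch.tensor([1, 1, 1, 0, 0])  # first three atoms frozen
    #   xh_lig, xh_pocket, lig_mask, pocket_mask = model.inpaint(
    #       ligand, pocket, lig_fixed, resamplings=5, center='pocket')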
@classmethod
def remove_mean_batch(cls, x_lig, x_pocket, lig_indices, pocket_indices):
# Just subtract the center of mass of the sampled part
mean = scatter_mean(x_lig, lig_indices, dim=0)
x_lig = x_lig - mean[lig_indices]
x_pocket = x_pocket - mean[pocket_indices]
return x_lig, x_pocket
# ------------------------------------------------------------------------------
# The same model without the subspace trick
# ------------------------------------------------------------------------------
class SimpleConditionalDDPM(ConditionalDDPM):
"""
    Simpler conditional diffusion module without the subspace trick.
- rotational equivariance is guaranteed by construction
- translationally equivariant likelihood is achieved by first mapping
samples to a space where the context is COM-free and evaluating the
likelihood there
- molecule generation is equivariant because we can first sample in the
space where the context is COM-free and translate the whole system back to
the original position of the context later
"""
def subspace_dimensionality(self, input_size):
""" Override because we don't use the linear subspace anymore. """
return input_size * self.n_dims
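    # Illustrative note (assuming the parent follows the usual EDM convention
    # of (N - 1) * n_dims): fixing the center of mass removes n_dims degrees of
    # freedom, hence the subspace trick; without it, all N * n_dims coordinates
    # remain free, which is what this override reflects.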
@classmethod
def remove_mean_batch(cls, x_lig, x_pocket, lig_indices, pocket_indices):
""" Hacky way of removing the centering steps without changing too much
code. """
return x_lig, x_pocket
@staticmethod
def assert_mean_zero_with_mask(x, node_mask, eps=1e-10):
return
def forward(self, ligand, pocket, return_info=False):
# Subtract pocket center of mass
pocket_com = scatter_mean(pocket['x'], pocket['mask'], dim=0)
ligand['x'] = ligand['x'] - pocket_com[ligand['mask']]
pocket['x'] = pocket['x'] - pocket_com[pocket['mask']]
return super(SimpleConditionalDDPM, self).forward(
ligand, pocket, return_info)
@torch.no_grad()
def sample_given_pocket(self, pocket, num_nodes_lig, return_frames=1,
timesteps=None):
# Subtract pocket center of mass
pocket_com = scatter_mean(pocket['x'], pocket['mask'], dim=0)
pocket['x'] = pocket['x'] - pocket_com[pocket['mask']]
return super(SimpleConditionalDDPM, self).sample_given_pocket(
pocket, num_nodes_lig, return_frames, timesteps)
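# Hedged, self-contained demo (not used by the model): checks that the
# ligand-COM projection in ConditionalDDPM.remove_mean_batch preserves relative
# ligand-pocket geometry while zero-centering the ligand. All tensors below are
# made up for illustration.
if __name__ == "__main__":
    x_lig = torch.randn(5, 3)
    x_pocket = torch.randn(7, 3)
    lig_idx = torch.zeros(5, dtype=torch.long)
    pocket_idx = torch.zeros(7, dtype=torch.long)
    d_before = torch.cdist(x_lig, x_pocket)
    x_lig_c, x_pocket_c = ConditionalDDPM.remove_mean_batch(
        x_lig, x_pocket, lig_idx, pocket_idx)
    d_after = torch.cdist(x_lig_c, x_pocket_c)
    # Relative distances are unchanged and the ligand is now COM-free.
    assert torch.allclose(d_before, d_after, atol=1e-5)
    assert torch.allclose(scatter_mean(x_lig_c, lig_idx, dim=0),
                          torch.zeros(1, 3), atol=1e-5)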