|
a |
|
b/equivariant_diffusion/dynamics.py |
|
|
1 |
import torch |
|
|
2 |
import torch.nn as nn |
|
|
3 |
import torch.nn.functional as F |
|
|
4 |
from equivariant_diffusion.egnn_new import EGNN, GNN |
|
|
5 |
from equivariant_diffusion.en_diffusion import EnVariationalDiffusion |
|
|
6 |
remove_mean_batch = EnVariationalDiffusion.remove_mean_batch |
|
|
7 |
import numpy as np |
|
|
8 |
|
|
|
9 |
|
|
|
10 |
class EGNNDynamics(nn.Module):
    """Dynamics network operating jointly on ligand atoms and pocket residues.

    Atom and residue features are first embedded into a shared feature space
    of size ``joint_nf``. Both node sets are then concatenated into one graph
    and processed by either an ``EGNN`` (``mode='egnn_dynamics'``) or a plain
    ``GNN`` (``mode='gnn_dynamics'``). Finally the node features are decoded
    back to the atom / residue feature spaces.
    """

    def __init__(self, atom_nf, residue_nf,
                 n_dims, joint_nf=16, hidden_nf=64, device='cpu',
                 act_fn=torch.nn.SiLU(), n_layers=4, attention=False,
                 condition_time=True, tanh=False, mode='egnn_dynamics',
                 norm_constant=0, inv_sublayers=2, sin_embedding=False,
                 normalization_factor=100, aggregation_method='sum',
                 update_pocket_coords=True, edge_cutoff_ligand=None,
                 edge_cutoff_pocket=None, edge_cutoff_interaction=None,
                 reflection_equivariant=True, edge_embedding_dim=None):
        """Build encoders, decoders and the underlying graph network.

        Args:
            atom_nf: Number of input features per (ligand) atom.
            residue_nf: Number of input features per (pocket) residue.
            n_dims: Number of coordinate dimensions at the front of each
                ``xh`` row passed to :meth:`forward`.
            joint_nf: Size of the shared embedding space.
            hidden_nf: Hidden size forwarded to the EGNN / GNN.
            device: Device forwarded to the EGNN / GNN and stored on self.
            act_fn: Activation module used in encoders/decoders and
                forwarded to the EGNN / GNN.
            n_layers, attention, tanh, norm_constant, inv_sublayers,
            sin_embedding, normalization_factor, aggregation_method,
            reflection_equivariant: Forwarded unchanged to the EGNN / GNN
                constructor (see those classes for their meaning).
            condition_time: If True, one extra node feature holding ``t`` is
                appended before the network and stripped again afterwards.
            mode: ``'egnn_dynamics'`` or ``'gnn_dynamics'``. Other values are
                only rejected when :meth:`forward` is called.
            update_pocket_coords: If False (EGNN mode), a coordinate-update
                mask is passed to the EGNN that is 1 for atoms and 0 for
                residues. If True, the output velocities are additionally
                passed through ``remove_mean_batch``.
            edge_cutoff_ligand, edge_cutoff_pocket, edge_cutoff_interaction:
                Optional distance cutoffs for ligand-ligand, pocket-pocket
                and ligand-pocket edges; ``None`` keeps all same-sample pairs.
            edge_embedding_dim: If not None, size of a learnable embedding
                of the three edge types used as edge attributes.
        """
        super().__init__()
        self.mode = mode
        self.edge_cutoff_l = edge_cutoff_ligand
        self.edge_cutoff_p = edge_cutoff_pocket
        self.edge_cutoff_i = edge_cutoff_interaction
        self.edge_nf = edge_embedding_dim

        self.atom_encoder = nn.Sequential(
            nn.Linear(atom_nf, 2 * atom_nf),
            act_fn,
            nn.Linear(2 * atom_nf, joint_nf)
        )

        self.atom_decoder = nn.Sequential(
            nn.Linear(joint_nf, 2 * atom_nf),
            act_fn,
            nn.Linear(2 * atom_nf, atom_nf)
        )

        self.residue_encoder = nn.Sequential(
            nn.Linear(residue_nf, 2 * residue_nf),
            act_fn,
            nn.Linear(2 * residue_nf, joint_nf)
        )

        self.residue_decoder = nn.Sequential(
            nn.Linear(joint_nf, 2 * residue_nf),
            act_fn,
            nn.Linear(2 * residue_nf, residue_nf)
        )

        # Optional learnable embedding of the 3 edge types; afterwards
        # self.edge_nf is always an int (0 means "no edge attributes").
        self.edge_embedding = nn.Embedding(3, self.edge_nf) \
            if self.edge_nf is not None else None
        self.edge_nf = 0 if self.edge_nf is None else self.edge_nf

        if condition_time:
            # one extra input channel carries the time step
            dynamics_node_nf = joint_nf + 1
        else:
            print('Warning: dynamics model is _not_ conditioned on time.')
            dynamics_node_nf = joint_nf

        # These attributes are read by forward() in every mode, so they must
        # be set regardless of which network is built below (previously they
        # were only set for 'egnn_dynamics', which made 'gnn_dynamics' fail
        # with an AttributeError in forward()).
        self.node_nf = dynamics_node_nf
        self.update_pocket_coords = update_pocket_coords

        if mode == 'egnn_dynamics':
            self.egnn = EGNN(
                in_node_nf=dynamics_node_nf, in_edge_nf=self.edge_nf,
                hidden_nf=hidden_nf, device=device, act_fn=act_fn,
                n_layers=n_layers, attention=attention, tanh=tanh,
                norm_constant=norm_constant,
                inv_sublayers=inv_sublayers, sin_embedding=sin_embedding,
                normalization_factor=normalization_factor,
                aggregation_method=aggregation_method,
                reflection_equiv=reflection_equivariant
            )

        elif mode == 'gnn_dynamics':
            # The plain GNN treats coordinates as ordinary node features.
            self.gnn = GNN(
                in_node_nf=dynamics_node_nf + n_dims, in_edge_nf=self.edge_nf,
                hidden_nf=hidden_nf, out_node_nf=n_dims + dynamics_node_nf,
                device=device, act_fn=act_fn, n_layers=n_layers,
                attention=attention, normalization_factor=normalization_factor,
                aggregation_method=aggregation_method)

        self.device = device
        self.n_dims = n_dims
        self.condition_time = condition_time

    def forward(self, xh_atoms, xh_residues, t, mask_atoms, mask_residues):
        """Run the dynamics network on a batch of ligand/pocket graphs.

        Args:
            xh_atoms: 2-D tensor, one row per atom; the first ``n_dims``
                columns are coordinates, the rest are atom features.
            xh_residues: Same layout for residues.
            t: Time step. Either a single-element tensor (shared by the
                whole batch) or a tensor indexed by sample id
                (presumably shape ``(batch_size, 1)`` - TODO confirm).
            mask_atoms: 1-D tensor with the sample index of every atom.
            mask_residues: 1-D tensor with the sample index of every residue.

        Returns:
            Tuple ``(out_atoms, out_residues)``; each is the concatenation of
            the predicted coordinate displacement (``n_dims`` columns) and
            the decoded node features.

        Raises:
            Exception: If ``self.mode`` is not a supported mode.
            ValueError: If NaNs appear in the displacement outside training.
        """
        n_atoms = len(mask_atoms)

        x_atoms = xh_atoms[:, :self.n_dims].clone()
        h_atoms = xh_atoms[:, self.n_dims:].clone()

        x_residues = xh_residues[:, :self.n_dims].clone()
        h_residues = xh_residues[:, self.n_dims:].clone()

        # embed atom features and residue features in a shared space
        h_atoms = self.atom_encoder(h_atoms)
        h_residues = self.residue_encoder(h_residues)

        # combine the two node types (atoms first, then residues)
        x = torch.cat((x_atoms, x_residues), dim=0)
        h = torch.cat((h_atoms, h_residues), dim=0)
        mask = torch.cat([mask_atoms, mask_residues])

        if self.condition_time:
            if t.numel() == 1:
                # t is the same for all elements in batch.
                h_time = torch.empty_like(h[:, 0:1]).fill_(t.item())
            else:
                # t is different over the batch dimension.
                h_time = t[mask]
            h = torch.cat([h, h_time], dim=1)

        # get edges between nodes of the same sample (subject to cutoffs)
        edges = self.get_edges(mask_atoms, mask_residues, x_atoms, x_residues)
        assert torch.all(mask[edges[0]] == mask[edges[1]])

        # Get edge types
        if self.edge_nf > 0:
            # 0: ligand-pocket, 1: ligand-ligand, 2: pocket-pocket
            edge_types = torch.zeros(edges.size(1), dtype=torch.long,
                                     device=edges.device)
            edge_types[(edges[0] < n_atoms) & (edges[1] < n_atoms)] = 1
            edge_types[(edges[0] >= n_atoms) & (edges[1] >= n_atoms)] = 2

            # Learnable embedding
            edge_types = self.edge_embedding(edge_types)
        else:
            edge_types = None

        if self.mode == 'egnn_dynamics':
            # When pocket coordinates are frozen, only atom rows get a 1.
            update_coords_mask = None if self.update_pocket_coords \
                else torch.cat((torch.ones_like(mask_atoms),
                                torch.zeros_like(mask_residues))).unsqueeze(1)
            h_final, x_final = self.egnn(h, x, edges,
                                         update_coords_mask=update_coords_mask,
                                         batch_mask=mask, edge_attr=edge_types)
            vel = (x_final - x)

        elif self.mode == 'gnn_dynamics':
            xh = torch.cat([x, h], dim=1)
            output = self.gnn(xh, edges, node_mask=None, edge_attr=edge_types)
            # Split at n_dims (not a hard-coded 3) to mirror how the GNN
            # input/output widths were defined in __init__.
            vel = output[:, :self.n_dims]
            h_final = output[:, self.n_dims:]

        else:
            raise Exception("Wrong mode %s" % self.mode)

        if self.condition_time:
            # Slice off last dimension which represented time.
            h_final = h_final[:, :-1]

        # decode atom and residue features
        h_final_atoms = self.atom_decoder(h_final[:n_atoms])
        h_final_residues = self.residue_decoder(h_final[n_atoms:])

        if torch.any(torch.isnan(vel)):
            if self.training:
                # best effort during training: zero out NaN entries
                vel[torch.isnan(vel)] = 0.0
            else:
                raise ValueError("NaN detected in EGNN output")

        if self.update_pocket_coords:
            # in case of unconditional joint distribution, include this as in
            # the original code
            vel = remove_mean_batch(vel, mask)

        return torch.cat([vel[:n_atoms], h_final_atoms], dim=-1), \
            torch.cat([vel[n_atoms:], h_final_residues], dim=-1)

    def get_edges(self, batch_mask_ligand, batch_mask_pocket, x_ligand, x_pocket):
        """Build the edge index of the joint ligand/pocket graph.

        Two nodes are connected when they belong to the same sample and, if
        the corresponding cutoff is set, their distance does not exceed it.
        Self-loops are included (a node is always in its own sample at
        distance 0). Ligand nodes are indexed first, pocket nodes are offset
        by the number of ligand nodes.

        Args:
            batch_mask_ligand: 1-D tensor of sample indices for ligand nodes.
            batch_mask_pocket: 1-D tensor of sample indices for pocket nodes.
            x_ligand: Ligand coordinates, one row per node.
            x_pocket: Pocket coordinates, one row per node.

        Returns:
            Tensor of shape ``(2, n_edges)`` with source and target indices.
        """
        adj_ligand = batch_mask_ligand[:, None] == batch_mask_ligand[None, :]
        adj_pocket = batch_mask_pocket[:, None] == batch_mask_pocket[None, :]
        adj_cross = batch_mask_ligand[:, None] == batch_mask_pocket[None, :]

        if self.edge_cutoff_l is not None:
            adj_ligand = adj_ligand & (torch.cdist(x_ligand, x_ligand) <= self.edge_cutoff_l)

        if self.edge_cutoff_p is not None:
            adj_pocket = adj_pocket & (torch.cdist(x_pocket, x_pocket) <= self.edge_cutoff_p)

        if self.edge_cutoff_i is not None:
            adj_cross = adj_cross & (torch.cdist(x_ligand, x_pocket) <= self.edge_cutoff_i)

        # Assemble the block adjacency matrix
        #   [ ligand-ligand | ligand-pocket ]
        #   [ pocket-ligand | pocket-pocket ]
        adj = torch.cat((torch.cat((adj_ligand, adj_cross), dim=1),
                         torch.cat((adj_cross.T, adj_pocket), dim=1)), dim=0)
        edges = torch.stack(torch.where(adj), dim=0)

        return edges