|
# equivariant_diffusion/egnn_new.py
|
from torch import nn
import torch
import math


class GCL(nn.Module):
    def __init__(self, input_nf, output_nf, hidden_nf, normalization_factor, aggregation_method,
                 edges_in_d=0, nodes_att_dim=0, act_fn=nn.SiLU(), attention=False):
        super(GCL, self).__init__()
        input_edge = input_nf * 2
        self.normalization_factor = normalization_factor
        self.aggregation_method = aggregation_method
        self.attention = attention

        self.edge_mlp = nn.Sequential(
            nn.Linear(input_edge + edges_in_d, hidden_nf),
            act_fn,
            nn.Linear(hidden_nf, hidden_nf),
            act_fn)

        self.node_mlp = nn.Sequential(
            nn.Linear(hidden_nf + input_nf + nodes_att_dim, hidden_nf),
            act_fn,
            nn.Linear(hidden_nf, output_nf))

        if self.attention:
            self.att_mlp = nn.Sequential(
                nn.Linear(hidden_nf, 1),
                nn.Sigmoid())

    def edge_model(self, source, target, edge_attr, edge_mask):
        if edge_attr is None:  # Unused.
            out = torch.cat([source, target], dim=1)
        else:
            out = torch.cat([source, target, edge_attr], dim=1)
        mij = self.edge_mlp(out)

        if self.attention:
            att_val = self.att_mlp(mij)
            out = mij * att_val
        else:
            out = mij

        if edge_mask is not None:
            out = out * edge_mask
        return out, mij

    def node_model(self, x, edge_index, edge_attr, node_attr):
        row, col = edge_index
        agg = unsorted_segment_sum(edge_attr, row, num_segments=x.size(0),
                                   normalization_factor=self.normalization_factor,
                                   aggregation_method=self.aggregation_method)
        if node_attr is not None:
            agg = torch.cat([x, agg, node_attr], dim=1)
        else:
            agg = torch.cat([x, agg], dim=1)
        out = x + self.node_mlp(agg)
        return out, agg

    def forward(self, h, edge_index, edge_attr=None, node_attr=None, node_mask=None, edge_mask=None):
        row, col = edge_index
        edge_feat, mij = self.edge_model(h[row], h[col], edge_attr, edge_mask)
        h, agg = self.node_model(h, edge_index, edge_feat, node_attr)
        if node_mask is not None:
            h = h * node_mask
        return h, mij
|
|
class EquivariantUpdate(nn.Module):
    def __init__(self, hidden_nf, normalization_factor, aggregation_method,
                 edges_in_d=1, act_fn=nn.SiLU(), tanh=False, coords_range=10.0,
                 reflection_equiv=True):
        super(EquivariantUpdate, self).__init__()
        self.tanh = tanh
        self.coords_range = coords_range
        self.reflection_equiv = reflection_equiv
        input_edge = hidden_nf * 2 + edges_in_d
        layer = nn.Linear(hidden_nf, 1, bias=False)
        torch.nn.init.xavier_uniform_(layer.weight, gain=0.001)
        self.coord_mlp = nn.Sequential(
            nn.Linear(input_edge, hidden_nf),
            act_fn,
            nn.Linear(hidden_nf, hidden_nf),
            act_fn,
            layer)
        self.cross_product_mlp = nn.Sequential(
            nn.Linear(input_edge, hidden_nf),
            act_fn,
            nn.Linear(hidden_nf, hidden_nf),
            act_fn,
            layer
        ) if not self.reflection_equiv else None
        self.normalization_factor = normalization_factor
        self.aggregation_method = aggregation_method

    def coord_model(self, h, coord, edge_index, coord_diff, coord_cross,
                    edge_attr, edge_mask, update_coords_mask=None):
        row, col = edge_index
        input_tensor = torch.cat([h[row], h[col], edge_attr], dim=1)
        if self.tanh:
            trans = coord_diff * torch.tanh(self.coord_mlp(input_tensor)) * self.coords_range
        else:
            trans = coord_diff * self.coord_mlp(input_tensor)

        if not self.reflection_equiv:
            phi_cross = self.cross_product_mlp(input_tensor)
            if self.tanh:
                phi_cross = torch.tanh(phi_cross) * self.coords_range
            trans = trans + coord_cross * phi_cross

        if edge_mask is not None:
            trans = trans * edge_mask

        agg = unsorted_segment_sum(trans, row, num_segments=coord.size(0),
                                   normalization_factor=self.normalization_factor,
                                   aggregation_method=self.aggregation_method)

        if update_coords_mask is not None:
            agg = update_coords_mask * agg

        coord = coord + agg
        return coord

    def forward(self, h, coord, edge_index, coord_diff, coord_cross,
                edge_attr=None, node_mask=None, edge_mask=None,
                update_coords_mask=None):
        coord = self.coord_model(h, coord, edge_index, coord_diff, coord_cross,
                                 edge_attr, edge_mask,
                                 update_coords_mask=update_coords_mask)
        if node_mask is not None:
            coord = coord * node_mask
        return coord
|
|
class EquivariantBlock(nn.Module):
    def __init__(self, hidden_nf, edge_feat_nf=2, device='cpu', act_fn=nn.SiLU(), n_layers=2, attention=True,
                 norm_diff=True, tanh=False, coords_range=15, norm_constant=1, sin_embedding=None,
                 normalization_factor=100, aggregation_method='sum', reflection_equiv=True):
        super(EquivariantBlock, self).__init__()
        self.hidden_nf = hidden_nf
        self.device = device
        self.n_layers = n_layers
        self.coords_range_layer = float(coords_range)
        self.norm_diff = norm_diff
        self.norm_constant = norm_constant
        self.sin_embedding = sin_embedding
        self.normalization_factor = normalization_factor
        self.aggregation_method = aggregation_method
        self.reflection_equiv = reflection_equiv

        for i in range(0, n_layers):
            self.add_module("gcl_%d" % i, GCL(self.hidden_nf, self.hidden_nf, self.hidden_nf, edges_in_d=edge_feat_nf,
                                              act_fn=act_fn, attention=attention,
                                              normalization_factor=self.normalization_factor,
                                              aggregation_method=self.aggregation_method))
        self.add_module("gcl_equiv", EquivariantUpdate(hidden_nf, edges_in_d=edge_feat_nf, act_fn=nn.SiLU(), tanh=tanh,
                                                       coords_range=self.coords_range_layer,
                                                       normalization_factor=self.normalization_factor,
                                                       aggregation_method=self.aggregation_method,
                                                       reflection_equiv=self.reflection_equiv))
        self.to(self.device)

    def forward(self, h, x, edge_index, node_mask=None, edge_mask=None,
                edge_attr=None, update_coords_mask=None, batch_mask=None):
        # Edit Emiel: Remove velocity as input
        distances, coord_diff = coord2diff(x, edge_index, self.norm_constant)
        if self.reflection_equiv:
            coord_cross = None
        else:
            coord_cross = coord2cross(x, edge_index, batch_mask,
                                      self.norm_constant)
        if self.sin_embedding is not None:
            distances = self.sin_embedding(distances)
        edge_attr = torch.cat([distances, edge_attr], dim=1)
        for i in range(0, self.n_layers):
            h, _ = self._modules["gcl_%d" % i](h, edge_index, edge_attr=edge_attr,
                                               node_mask=node_mask, edge_mask=edge_mask)
        x = self._modules["gcl_equiv"](h, x, edge_index, coord_diff, coord_cross, edge_attr,
                                       node_mask, edge_mask, update_coords_mask=update_coords_mask)

        # Important, the bias of the last linear might be non-zero
        if node_mask is not None:
            h = h * node_mask
        return h, x
|
|
class EGNN(nn.Module):
    def __init__(self, in_node_nf, in_edge_nf, hidden_nf, device='cpu', act_fn=nn.SiLU(), n_layers=3, attention=False,
                 norm_diff=True, out_node_nf=None, tanh=False, coords_range=15, norm_constant=1, inv_sublayers=2,
                 sin_embedding=False, normalization_factor=100, aggregation_method='sum', reflection_equiv=True):
        super(EGNN, self).__init__()
        if out_node_nf is None:
            out_node_nf = in_node_nf
        self.hidden_nf = hidden_nf
        self.device = device
        self.n_layers = n_layers
        self.coords_range_layer = float(coords_range / n_layers)
        self.norm_diff = norm_diff
        self.normalization_factor = normalization_factor
        self.aggregation_method = aggregation_method
        self.reflection_equiv = reflection_equiv

        if sin_embedding:
            self.sin_embedding = SinusoidsEmbeddingNew()
            edge_feat_nf = self.sin_embedding.dim * 2
        else:
            self.sin_embedding = None
            edge_feat_nf = 2

        edge_feat_nf = edge_feat_nf + in_edge_nf

        self.embedding = nn.Linear(in_node_nf, self.hidden_nf)
        self.embedding_out = nn.Linear(self.hidden_nf, out_node_nf)
        for i in range(0, n_layers):
            self.add_module("e_block_%d" % i, EquivariantBlock(hidden_nf, edge_feat_nf=edge_feat_nf, device=device,
                                                               act_fn=act_fn, n_layers=inv_sublayers,
                                                               attention=attention, norm_diff=norm_diff, tanh=tanh,
                                                               coords_range=coords_range, norm_constant=norm_constant,
                                                               sin_embedding=self.sin_embedding,
                                                               normalization_factor=self.normalization_factor,
                                                               aggregation_method=self.aggregation_method,
                                                               reflection_equiv=self.reflection_equiv))
        self.to(self.device)

    def forward(self, h, x, edge_index, node_mask=None, edge_mask=None, update_coords_mask=None,
                batch_mask=None, edge_attr=None):
        # Edit Emiel: Remove velocity as input
        edge_feat, _ = coord2diff(x, edge_index)
        if self.sin_embedding is not None:
            edge_feat = self.sin_embedding(edge_feat)
        if edge_attr is not None:
            edge_feat = torch.cat([edge_feat, edge_attr], dim=1)
        h = self.embedding(h)
        for i in range(0, self.n_layers):
            h, x = self._modules["e_block_%d" % i](
                h, x, edge_index, node_mask=node_mask, edge_mask=edge_mask,
                edge_attr=edge_feat, update_coords_mask=update_coords_mask,
                batch_mask=batch_mask)

        # Important, the bias of the last linear might be non-zero
        h = self.embedding_out(h)
        if node_mask is not None:
            h = h * node_mask
        return h, x
|
|
class GNN(nn.Module):
    def __init__(self, in_node_nf, in_edge_nf, hidden_nf, aggregation_method='sum', device='cpu',
                 act_fn=nn.SiLU(), n_layers=4, attention=False,
                 normalization_factor=1, out_node_nf=None):
        super(GNN, self).__init__()
        if out_node_nf is None:
            out_node_nf = in_node_nf
        self.hidden_nf = hidden_nf
        self.device = device
        self.n_layers = n_layers
        ### Encoder
        self.embedding = nn.Linear(in_node_nf, self.hidden_nf)
        self.embedding_out = nn.Linear(self.hidden_nf, out_node_nf)
        for i in range(0, n_layers):
            self.add_module("gcl_%d" % i, GCL(
                self.hidden_nf, self.hidden_nf, self.hidden_nf,
                normalization_factor=normalization_factor,
                aggregation_method=aggregation_method,
                edges_in_d=in_edge_nf, act_fn=act_fn,
                attention=attention))
        self.to(self.device)

    def forward(self, h, edges, edge_attr=None, node_mask=None, edge_mask=None):
        # Edit Emiel: Remove velocity as input
        h = self.embedding(h)
        for i in range(0, self.n_layers):
            h, _ = self._modules["gcl_%d" % i](h, edges, edge_attr=edge_attr, node_mask=node_mask, edge_mask=edge_mask)
        h = self.embedding_out(h)

        # Important, the bias of the last linear might be non-zero
        if node_mask is not None:
            h = h * node_mask
        return h
|
|
class SinusoidsEmbeddingNew(nn.Module):
    def __init__(self, max_res=15., min_res=15. / 2000., div_factor=4):
        super().__init__()
        self.n_frequencies = int(math.log(max_res / min_res, div_factor)) + 1
        self.frequencies = 2 * math.pi * div_factor ** torch.arange(self.n_frequencies) / max_res
        self.dim = len(self.frequencies) * 2

    def forward(self, x):
        x = torch.sqrt(x + 1e-8)
        emb = x * self.frequencies[None, :].to(x.device)
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb.detach()
|
|
def coord2diff(x, edge_index, norm_constant=1):
    # Returns squared pairwise distances and normalized coordinate differences.
    row, col = edge_index
    coord_diff = x[row] - x[col]
    radial = torch.sum(coord_diff ** 2, 1).unsqueeze(1)
    norm = torch.sqrt(radial + 1e-8)
    coord_diff = coord_diff / (norm + norm_constant)
    return radial, coord_diff


def coord2cross(x, edge_index, batch_mask, norm_constant=1):
    # Cross product of the two edge endpoints' positions, taken relative to the
    # mean position of their graph (used when reflection equivariance is disabled).
    mean = unsorted_segment_sum(x, batch_mask,
                                num_segments=batch_mask.max() + 1,
                                normalization_factor=None,
                                aggregation_method='mean')
    row, col = edge_index
    cross = torch.cross(x[row] - mean[batch_mask[row]],
                        x[col] - mean[batch_mask[col]], dim=1)
    norm = torch.linalg.norm(cross, dim=1, keepdim=True)
    cross = cross / (norm + norm_constant)
    return cross


def unsorted_segment_sum(data, segment_ids, num_segments, normalization_factor, aggregation_method: str):
    """Custom PyTorch op to replicate TensorFlow's `unsorted_segment_sum`.

    `aggregation_method` is 'sum' (divide the segment sums by
    `normalization_factor`) or 'mean' (divide by the segment sizes).
    """
    result_shape = (num_segments, data.size(1))
    result = data.new_full(result_shape, 0)  # Init empty result tensor.
    segment_ids = segment_ids.unsqueeze(-1).expand(-1, data.size(1))
    result.scatter_add_(0, segment_ids, data)
    if aggregation_method == 'sum':
        result = result / normalization_factor

    if aggregation_method == 'mean':
        norm = data.new_zeros(result.shape)
        norm.scatter_add_(0, segment_ids, data.new_ones(data.shape))
        norm[norm == 0] = 1
        result = result / norm
    return result
return result |