# equivariant_diffusion/egnn_new.py
import math

import torch
from torch import nn


class GCL(nn.Module):
    """Message-passing layer that updates invariant node features from aggregated edge messages."""

    def __init__(self, input_nf, output_nf, hidden_nf, normalization_factor, aggregation_method,
                 edges_in_d=0, nodes_att_dim=0, act_fn=nn.SiLU(), attention=False):
        super(GCL, self).__init__()
        input_edge = input_nf * 2
        self.normalization_factor = normalization_factor
        self.aggregation_method = aggregation_method
        self.attention = attention

        self.edge_mlp = nn.Sequential(
            nn.Linear(input_edge + edges_in_d, hidden_nf),
            act_fn,
            nn.Linear(hidden_nf, hidden_nf),
            act_fn)

        self.node_mlp = nn.Sequential(
            nn.Linear(hidden_nf + input_nf + nodes_att_dim, hidden_nf),
            act_fn,
            nn.Linear(hidden_nf, output_nf))

        if self.attention:
            self.att_mlp = nn.Sequential(
                nn.Linear(hidden_nf, 1),
                nn.Sigmoid())

    def edge_model(self, source, target, edge_attr, edge_mask):
        if edge_attr is None:  # Unused.
            out = torch.cat([source, target], dim=1)
        else:
            out = torch.cat([source, target, edge_attr], dim=1)
        mij = self.edge_mlp(out)

        if self.attention:
            att_val = self.att_mlp(mij)
            out = mij * att_val
        else:
            out = mij

        if edge_mask is not None:
            out = out * edge_mask
        return out, mij

    def node_model(self, x, edge_index, edge_attr, node_attr):
        row, col = edge_index
        agg = unsorted_segment_sum(edge_attr, row, num_segments=x.size(0),
                                   normalization_factor=self.normalization_factor,
                                   aggregation_method=self.aggregation_method)
        if node_attr is not None:
            agg = torch.cat([x, agg, node_attr], dim=1)
        else:
            agg = torch.cat([x, agg], dim=1)
        out = x + self.node_mlp(agg)
        return out, agg

    def forward(self, h, edge_index, edge_attr=None, node_attr=None, node_mask=None, edge_mask=None):
        row, col = edge_index
        edge_feat, mij = self.edge_model(h[row], h[col], edge_attr, edge_mask)
        h, agg = self.node_model(h, edge_index, edge_feat, node_attr)
        if node_mask is not None:
            h = h * node_mask
        return h, mij


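# Illustrative usage sketch for GCL (added note, not part of the original training code).
# Expected shapes, assuming N nodes, E directed edges and the given feature widths:
#   h:          (N, input_nf)
#   edge_index: pair (row, col) of LongTensors of shape (E,)
#   edge_attr:  (E, edges_in_d) or None; node_mask: (N, 1); edge_mask: (E, 1)
# Note that the residual in node_model (out = x + node_mlp(agg)) assumes
# output_nf == input_nf.
#
#   gcl = GCL(input_nf=8, output_nf=8, hidden_nf=8,
#             normalization_factor=1, aggregation_method='sum')
#   h = torch.randn(4, 8)
#   edge_index = (torch.tensor([0, 1, 2]), torch.tensor([1, 2, 3]))
#   h_new, mij = gcl(h, edge_index)   # h_new: (4, 8), mij: (3, 8)

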
class EquivariantUpdate(nn.Module):
    """Equivariant coordinate update driven by invariant edge features.

    If reflection_equiv is False, an additional cross-product term is added, which
    keeps rotation/translation equivariance but removes equivariance under
    reflections.
    """

    def __init__(self, hidden_nf, normalization_factor, aggregation_method,
                 edges_in_d=1, act_fn=nn.SiLU(), tanh=False, coords_range=10.0,
                 reflection_equiv=True):
        super(EquivariantUpdate, self).__init__()
        self.tanh = tanh
        self.coords_range = coords_range
        self.reflection_equiv = reflection_equiv
        input_edge = hidden_nf * 2 + edges_in_d
        layer = nn.Linear(hidden_nf, 1, bias=False)
        torch.nn.init.xavier_uniform_(layer.weight, gain=0.001)
        self.coord_mlp = nn.Sequential(
            nn.Linear(input_edge, hidden_nf),
            act_fn,
            nn.Linear(hidden_nf, hidden_nf),
            act_fn,
            layer)
        self.cross_product_mlp = nn.Sequential(
            nn.Linear(input_edge, hidden_nf),
            act_fn,
            nn.Linear(hidden_nf, hidden_nf),
            act_fn,
            layer  # note: the same final projection layer object is reused from coord_mlp
        ) if not self.reflection_equiv else None
        self.normalization_factor = normalization_factor
        self.aggregation_method = aggregation_method

    def coord_model(self, h, coord, edge_index, coord_diff, coord_cross,
                    edge_attr, edge_mask, update_coords_mask=None):
        row, col = edge_index
        input_tensor = torch.cat([h[row], h[col], edge_attr], dim=1)
        if self.tanh:
            trans = coord_diff * torch.tanh(self.coord_mlp(input_tensor)) * self.coords_range
        else:
            trans = coord_diff * self.coord_mlp(input_tensor)

        if not self.reflection_equiv:
            phi_cross = self.cross_product_mlp(input_tensor)
            if self.tanh:
                phi_cross = torch.tanh(phi_cross) * self.coords_range
            trans = trans + coord_cross * phi_cross

        if edge_mask is not None:
            trans = trans * edge_mask

        agg = unsorted_segment_sum(trans, row, num_segments=coord.size(0),
                                   normalization_factor=self.normalization_factor,
                                   aggregation_method=self.aggregation_method)

        if update_coords_mask is not None:
            agg = update_coords_mask * agg

        coord = coord + agg
        return coord

    def forward(self, h, coord, edge_index, coord_diff, coord_cross,
                edge_attr=None, node_mask=None, edge_mask=None,
                update_coords_mask=None):
        coord = self.coord_model(h, coord, edge_index, coord_diff, coord_cross,
                                 edge_attr, edge_mask,
                                 update_coords_mask=update_coords_mask)
        if node_mask is not None:
            coord = coord * node_mask
        return coord


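# Descriptive note on the coordinate update above (added comment): for each edge
# (i, j) with normalized difference d_ij = (x_i - x_j) / (||x_i - x_j|| + norm_constant),
#     x_i <- x_i + sum_j d_ij * phi_x(h_i, h_j, a_ij)   [+ cross-product term when
#                                                         reflection_equiv=False]
# where phi_x is coord_mlp and a_ij is edge_attr; the sum over incident edges is
# normalized according to aggregation_method ('sum' divides by normalization_factor,
# 'mean' divides by the number of edges per node).

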
class EquivariantBlock(nn.Module):
    """A stack of `n_layers` GCL feature updates followed by one EquivariantUpdate
    coordinate update; edge distances are (re)computed inside the block.
    """

    def __init__(self, hidden_nf, edge_feat_nf=2, device='cpu', act_fn=nn.SiLU(), n_layers=2, attention=True,
                 norm_diff=True, tanh=False, coords_range=15, norm_constant=1, sin_embedding=None,
                 normalization_factor=100, aggregation_method='sum', reflection_equiv=True):
        super(EquivariantBlock, self).__init__()
        self.hidden_nf = hidden_nf
        self.device = device
        self.n_layers = n_layers
        self.coords_range_layer = float(coords_range)
        self.norm_diff = norm_diff
        self.norm_constant = norm_constant
        self.sin_embedding = sin_embedding
        self.normalization_factor = normalization_factor
        self.aggregation_method = aggregation_method
        self.reflection_equiv = reflection_equiv

        for i in range(0, n_layers):
            self.add_module("gcl_%d" % i, GCL(self.hidden_nf, self.hidden_nf, self.hidden_nf, edges_in_d=edge_feat_nf,
                                              act_fn=act_fn, attention=attention,
                                              normalization_factor=self.normalization_factor,
                                              aggregation_method=self.aggregation_method))
        self.add_module("gcl_equiv", EquivariantUpdate(hidden_nf, edges_in_d=edge_feat_nf, act_fn=nn.SiLU(), tanh=tanh,
                                                       coords_range=self.coords_range_layer,
                                                       normalization_factor=self.normalization_factor,
                                                       aggregation_method=self.aggregation_method,
                                                       reflection_equiv=self.reflection_equiv))
        self.to(self.device)

    def forward(self, h, x, edge_index, node_mask=None, edge_mask=None,
                edge_attr=None, update_coords_mask=None, batch_mask=None):
        # Edit Emiel: Remove velocity as input
        distances, coord_diff = coord2diff(x, edge_index, self.norm_constant)
        if self.reflection_equiv:
            coord_cross = None
        else:
            coord_cross = coord2cross(x, edge_index, batch_mask,
                                      self.norm_constant)
        if self.sin_embedding is not None:
            distances = self.sin_embedding(distances)
        # edge_attr is expected to be provided here (EGNN always passes its distance features).
        edge_attr = torch.cat([distances, edge_attr], dim=1)
        for i in range(0, self.n_layers):
            h, _ = self._modules["gcl_%d" % i](h, edge_index, edge_attr=edge_attr,
                                               node_mask=node_mask, edge_mask=edge_mask)
        x = self._modules["gcl_equiv"](h, x, edge_index, coord_diff, coord_cross, edge_attr,
                                       node_mask, edge_mask, update_coords_mask=update_coords_mask)

        # Important, the bias of the last linear might be non-zero
        if node_mask is not None:
            h = h * node_mask
        return h, x


class EGNN(nn.Module):
    """E(n)-equivariant graph network: embeds node features, applies `n_layers`
    EquivariantBlocks that update features and coordinates, then projects features out.
    """

    def __init__(self, in_node_nf, in_edge_nf, hidden_nf, device='cpu', act_fn=nn.SiLU(), n_layers=3, attention=False,
                 norm_diff=True, out_node_nf=None, tanh=False, coords_range=15, norm_constant=1, inv_sublayers=2,
                 sin_embedding=False, normalization_factor=100, aggregation_method='sum', reflection_equiv=True):
        super(EGNN, self).__init__()
        if out_node_nf is None:
            out_node_nf = in_node_nf
        self.hidden_nf = hidden_nf
        self.device = device
        self.n_layers = n_layers
        self.coords_range_layer = float(coords_range/n_layers)
        self.norm_diff = norm_diff
        self.normalization_factor = normalization_factor
        self.aggregation_method = aggregation_method
        self.reflection_equiv = reflection_equiv

        if sin_embedding:
            self.sin_embedding = SinusoidsEmbeddingNew()
            edge_feat_nf = self.sin_embedding.dim * 2
        else:
            self.sin_embedding = None
            edge_feat_nf = 2

        edge_feat_nf = edge_feat_nf + in_edge_nf

        self.embedding = nn.Linear(in_node_nf, self.hidden_nf)
        self.embedding_out = nn.Linear(self.hidden_nf, out_node_nf)
        for i in range(0, n_layers):
            self.add_module("e_block_%d" % i, EquivariantBlock(hidden_nf, edge_feat_nf=edge_feat_nf, device=device,
                                                               act_fn=act_fn, n_layers=inv_sublayers,
                                                               attention=attention, norm_diff=norm_diff, tanh=tanh,
                                                               coords_range=coords_range, norm_constant=norm_constant,
                                                               sin_embedding=self.sin_embedding,
                                                               normalization_factor=self.normalization_factor,
                                                               aggregation_method=self.aggregation_method,
                                                               reflection_equiv=self.reflection_equiv))
        self.to(self.device)

    def forward(self, h, x, edge_index, node_mask=None, edge_mask=None, update_coords_mask=None,
                batch_mask=None, edge_attr=None):
        # Edit Emiel: Remove velocity as input
        edge_feat, _ = coord2diff(x, edge_index)
        if self.sin_embedding is not None:
            edge_feat = self.sin_embedding(edge_feat)
        if edge_attr is not None:
            edge_feat = torch.cat([edge_feat, edge_attr], dim=1)
        h = self.embedding(h)
        for i in range(0, self.n_layers):
            h, x = self._modules["e_block_%d" % i](
                h, x, edge_index, node_mask=node_mask, edge_mask=edge_mask,
                edge_attr=edge_feat, update_coords_mask=update_coords_mask,
                batch_mask=batch_mask)

        # Important, the bias of the last linear might be non-zero
        h = self.embedding_out(h)
        if node_mask is not None:
            h = h * node_mask
        return h, x


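# Shape conventions for EGNN.forward (added note, inferred from the code above):
#   h:          (N, in_node_nf)  invariant node features, N = nodes across the batch
#   x:          (N, 3)           coordinates; updated equivariantly and returned
#   edge_index: (row, col)       two LongTensors of shape (E,) with directed edges
#   node_mask / edge_mask / update_coords_mask: optional 0-1 masks of shape
#               (N, 1) / (E, 1) / (N, 1)
#   batch_mask: (N,) graph index per node; only required when reflection_equiv=False,
#               since coord2cross needs per-graph centroids
#   edge_attr:  (E, in_edge_nf)  optional extra edge features
# A runnable smoke-test sketch is given at the bottom of this file.

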
class GNN(nn.Module):
    """Plain (non-equivariant) message-passing network over node features only."""

    def __init__(self, in_node_nf, in_edge_nf, hidden_nf, aggregation_method='sum', device='cpu',
                 act_fn=nn.SiLU(), n_layers=4, attention=False,
                 normalization_factor=1, out_node_nf=None):
        super(GNN, self).__init__()
        if out_node_nf is None:
            out_node_nf = in_node_nf
        self.hidden_nf = hidden_nf
        self.device = device
        self.n_layers = n_layers
        ### Encoder
        self.embedding = nn.Linear(in_node_nf, self.hidden_nf)
        self.embedding_out = nn.Linear(self.hidden_nf, out_node_nf)
        for i in range(0, n_layers):
            self.add_module("gcl_%d" % i, GCL(
                self.hidden_nf, self.hidden_nf, self.hidden_nf,
                normalization_factor=normalization_factor,
                aggregation_method=aggregation_method,
                edges_in_d=in_edge_nf, act_fn=act_fn,
                attention=attention))
        self.to(self.device)

    def forward(self, h, edges, edge_attr=None, node_mask=None, edge_mask=None):
        # Edit Emiel: Remove velocity as input
        h = self.embedding(h)
        for i in range(0, self.n_layers):
            h, _ = self._modules["gcl_%d" % i](h, edges, edge_attr=edge_attr, node_mask=node_mask, edge_mask=edge_mask)
        h = self.embedding_out(h)

        # Important, the bias of the last linear might be non-zero
        if node_mask is not None:
            h = h * node_mask
        return h


class SinusoidsEmbeddingNew(nn.Module):
    """Sinusoidal embedding of squared distances over a geometric range of frequencies."""

    def __init__(self, max_res=15., min_res=15. / 2000., div_factor=4):
        super().__init__()
        self.n_frequencies = int(math.log(max_res / min_res, div_factor)) + 1
        self.frequencies = 2 * math.pi * div_factor ** torch.arange(self.n_frequencies)/max_res
        self.dim = len(self.frequencies) * 2

    def forward(self, x):
        x = torch.sqrt(x + 1e-8)
        emb = x * self.frequencies[None, :].to(x.device)
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb.detach()


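# With the default arguments (max_res=15, min_res=15/2000, div_factor=4) there are
# int(log_4(2000)) + 1 = 6 frequencies, i.e. self.dim = 12 output channels per edge.
# The module expects the *squared* distance as input (as returned by coord2diff);
# forward() takes the square root before applying the sinusoids.

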
def coord2diff(x, edge_index, norm_constant=1):
    """Return squared edge lengths and softly normalized difference vectors per edge."""
    row, col = edge_index
    coord_diff = x[row] - x[col]
    radial = torch.sum((coord_diff) ** 2, 1).unsqueeze(1)
    norm = torch.sqrt(radial + 1e-8)
    coord_diff = coord_diff/(norm + norm_constant)
    return radial, coord_diff


def coord2cross(x, edge_index, batch_mask, norm_constant=1):
    """Return normalized cross products of the endpoint positions relative to the
    per-graph centroid (used only when reflection equivariance is disabled)."""
    mean = unsorted_segment_sum(x, batch_mask,
                                num_segments=batch_mask.max() + 1,
                                normalization_factor=None,
                                aggregation_method='mean')
    row, col = edge_index
    cross = torch.cross(x[row]-mean[batch_mask[row]],
                        x[col]-mean[batch_mask[col]], dim=1)
    norm = torch.linalg.norm(cross, dim=1, keepdim=True)
    cross = cross / (norm + norm_constant)
    return cross


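# Added note: the cross product is a pseudovector, so the coord_cross term in
# EquivariantUpdate keeps the coordinate update equivariant to rotations and
# translations but makes it sensitive to reflections; this is what the
# reflection_equiv=False option is meant to achieve. Centering on the per-graph
# mean (via batch_mask) keeps the construction translation-invariant.

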
def unsorted_segment_sum(data, segment_ids, num_segments, normalization_factor, aggregation_method: str):
    """Custom PyTorch op to replicate TensorFlow's `unsorted_segment_sum`.

    Aggregation: 'sum' divides the per-segment sums by `normalization_factor`,
    'mean' divides by the number of entries per segment.
    """
    result_shape = (num_segments, data.size(1))
    result = data.new_full(result_shape, 0)  # Init empty result tensor.
    segment_ids = segment_ids.unsqueeze(-1).expand(-1, data.size(1))
    result.scatter_add_(0, segment_ids, data)
    if aggregation_method == 'sum':
        result = result / normalization_factor

    if aggregation_method == 'mean':
        norm = data.new_zeros(result.shape)
        norm.scatter_add_(0, segment_ids, data.new_ones(data.shape))
        norm[norm == 0] = 1
        result = result / norm
    return result
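

if __name__ == "__main__":
    # Minimal smoke test (illustrative sketch added for documentation purposes; it is
    # not part of the original training pipeline). It builds one fully connected
    # 5-node graph and checks the output shapes of EGNN and unsorted_segment_sum.
    torch.manual_seed(0)

    n_nodes, in_node_nf = 5, 3
    rows, cols = [], []
    for i in range(n_nodes):
        for j in range(n_nodes):
            if i != j:
                rows.append(i)
                cols.append(j)
    edge_index = (torch.tensor(rows), torch.tensor(cols))

    model = EGNN(in_node_nf=in_node_nf, in_edge_nf=0, hidden_nf=16,
                 n_layers=2, inv_sublayers=1)
    h = torch.randn(n_nodes, in_node_nf)
    x = torch.randn(n_nodes, 3)
    h_out, x_out = model(h, x, edge_index)
    print(h_out.shape, x_out.shape)  # torch.Size([5, 3]) torch.Size([5, 3])

    # unsorted_segment_sum on a toy input: rows 0 and 1 go to segment 0, row 2 to segment 1.
    data = torch.tensor([[1.], [2.], [4.]])
    segment_ids = torch.tensor([0, 0, 1])
    print(unsorted_segment_sum(data, segment_ids, num_segments=2,
                               normalization_factor=1,
                               aggregation_method='sum'))  # tensor([[3.], [4.]])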