"""EGNN building blocks: graph convolution layers, equivariant coordinate
updates, and full (equivariant and plain) graph networks."""
import math

import torch
from torch import nn
class GCL(nn.Module):
    """Graph Convolution Layer: computes a message m_ij for every edge from
    the endpoint features (and optional edge attributes), optionally gates it
    with attention, and updates node features with a residual MLP over the
    aggregated messages."""
def __init__(self, input_nf, output_nf, hidden_nf, normalization_factor, aggregation_method,
edges_in_d=0, nodes_att_dim=0, act_fn=nn.SiLU(), attention=False):
super(GCL, self).__init__()
        input_edge = input_nf * 2  # Sender and receiver features are concatenated.
self.normalization_factor = normalization_factor
self.aggregation_method = aggregation_method
self.attention = attention
self.edge_mlp = nn.Sequential(
nn.Linear(input_edge + edges_in_d, hidden_nf),
act_fn,
nn.Linear(hidden_nf, hidden_nf),
act_fn)
self.node_mlp = nn.Sequential(
nn.Linear(hidden_nf + input_nf + nodes_att_dim, hidden_nf),
act_fn,
nn.Linear(hidden_nf, output_nf))
if self.attention:
self.att_mlp = nn.Sequential(
nn.Linear(hidden_nf, 1),
nn.Sigmoid())
def edge_model(self, source, target, edge_attr, edge_mask):
        if edge_attr is None:  # No edge attributes; use endpoint features only.
            out = torch.cat([source, target], dim=1)
else:
out = torch.cat([source, target, edge_attr], dim=1)
mij = self.edge_mlp(out)
        if self.attention:
            att_val = self.att_mlp(mij)  # Per-edge gate in (0, 1).
            out = mij * att_val
else:
out = mij
if edge_mask is not None:
out = out * edge_mask
return out, mij
def node_model(self, x, edge_index, edge_attr, node_attr):
row, col = edge_index
agg = unsorted_segment_sum(edge_attr, row, num_segments=x.size(0),
normalization_factor=self.normalization_factor,
aggregation_method=self.aggregation_method)
if node_attr is not None:
agg = torch.cat([x, agg, node_attr], dim=1)
else:
agg = torch.cat([x, agg], dim=1)
        out = x + self.node_mlp(agg)  # Residual update (requires input_nf == output_nf).
return out, agg
def forward(self, h, edge_index, edge_attr=None, node_attr=None, node_mask=None, edge_mask=None):
row, col = edge_index
edge_feat, mij = self.edge_model(h[row], h[col], edge_attr, edge_mask)
h, agg = self.node_model(h, edge_index, edge_feat, node_attr)
if node_mask is not None:
h = h * node_mask
return h, mij
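
# Hedged usage sketch (not part of the original module): one GCL message-
# passing step on a small fully connected graph. All sizes below are
# illustrative assumptions.
def _demo_gcl():
    n_nodes, nf = 4, 8
    h = torch.randn(n_nodes, nf)
    rows, cols = zip(*[(i, j) for i in range(n_nodes)
                       for j in range(n_nodes) if i != j])
    edge_index = [torch.tensor(rows), torch.tensor(cols)]
    gcl = GCL(nf, nf, nf, normalization_factor=1, aggregation_method='sum')
    h_out, mij = gcl(h, edge_index)
    assert h_out.shape == (n_nodes, nf)  # Residual update keeps the node dim.
    assert mij.shape == (len(rows), nf)  # One message per directed edge.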
class EquivariantUpdate(nn.Module):
    """Equivariant coordinate update: every edge contributes a scalar-gated
    displacement along its normalized difference vector, plus an optional
    cross-product term when reflection equivariance is disabled."""
def __init__(self, hidden_nf, normalization_factor, aggregation_method,
edges_in_d=1, act_fn=nn.SiLU(), tanh=False, coords_range=10.0,
reflection_equiv=True):
super(EquivariantUpdate, self).__init__()
self.tanh = tanh
self.coords_range = coords_range
self.reflection_equiv = reflection_equiv
input_edge = hidden_nf * 2 + edges_in_d
        layer = nn.Linear(hidden_nf, 1, bias=False)
        torch.nn.init.xavier_uniform_(layer.weight, gain=0.001)  # Near-zero init keeps early coordinate updates small.
self.coord_mlp = nn.Sequential(
nn.Linear(input_edge, hidden_nf),
act_fn,
nn.Linear(hidden_nf, hidden_nf),
act_fn,
layer)
        # Cross-product head, built only when reflection equivariance is
        # disabled. Note that it reuses the same `layer` instance as
        # coord_mlp, so the final projection weights are tied.
        self.cross_product_mlp = nn.Sequential(
            nn.Linear(input_edge, hidden_nf),
            act_fn,
            nn.Linear(hidden_nf, hidden_nf),
            act_fn,
            layer
        ) if not self.reflection_equiv else None
self.normalization_factor = normalization_factor
self.aggregation_method = aggregation_method
def coord_model(self, h, coord, edge_index, coord_diff, coord_cross,
edge_attr, edge_mask, update_coords_mask=None):
row, col = edge_index
input_tensor = torch.cat([h[row], h[col], edge_attr], dim=1)
        if self.tanh:
            # Bound each edge's displacement to [-coords_range, coords_range].
            trans = coord_diff * torch.tanh(self.coord_mlp(input_tensor)) * self.coords_range
else:
trans = coord_diff * self.coord_mlp(input_tensor)
if not self.reflection_equiv:
phi_cross = self.cross_product_mlp(input_tensor)
if self.tanh:
phi_cross = torch.tanh(phi_cross) * self.coords_range
trans = trans + coord_cross * phi_cross
if edge_mask is not None:
trans = trans * edge_mask
agg = unsorted_segment_sum(trans, row, num_segments=coord.size(0),
normalization_factor=self.normalization_factor,
aggregation_method=self.aggregation_method)
if update_coords_mask is not None:
agg = update_coords_mask * agg
coord = coord + agg
return coord
def forward(self, h, coord, edge_index, coord_diff, coord_cross,
edge_attr=None, node_mask=None, edge_mask=None,
update_coords_mask=None):
coord = self.coord_model(h, coord, edge_index, coord_diff, coord_cross,
edge_attr, edge_mask,
update_coords_mask=update_coords_mask)
if node_mask is not None:
coord = coord * node_mask
return coord
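
# Hedged usage sketch (illustrative; relies on the coord2diff helper defined
# later in this file): a single equivariant coordinate update driven by the
# squared edge distances as the 1-dimensional edge attribute.
def _demo_equivariant_update():
    n_nodes, nf = 4, 8
    h = torch.randn(n_nodes, nf)
    x = torch.randn(n_nodes, 3)
    rows, cols = zip(*[(i, j) for i in range(n_nodes)
                       for j in range(n_nodes) if i != j])
    edge_index = [torch.tensor(rows), torch.tensor(cols)]
    radial, coord_diff = coord2diff(x, edge_index)
    update = EquivariantUpdate(nf, normalization_factor=1,
                               aggregation_method='sum')
    x_out = update(h, x, edge_index, coord_diff, coord_cross=None,
                   edge_attr=radial)
    assert x_out.shape == x.shape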
class EquivariantBlock(nn.Module):
    """A stack of `n_layers` GCLs over node features followed by a single
    equivariant coordinate update, all sharing the same edge attributes."""
def __init__(self, hidden_nf, edge_feat_nf=2, device='cpu', act_fn=nn.SiLU(), n_layers=2, attention=True,
norm_diff=True, tanh=False, coords_range=15, norm_constant=1, sin_embedding=None,
normalization_factor=100, aggregation_method='sum', reflection_equiv=True):
super(EquivariantBlock, self).__init__()
self.hidden_nf = hidden_nf
self.device = device
self.n_layers = n_layers
self.coords_range_layer = float(coords_range)
self.norm_diff = norm_diff
self.norm_constant = norm_constant
self.sin_embedding = sin_embedding
self.normalization_factor = normalization_factor
self.aggregation_method = aggregation_method
self.reflection_equiv = reflection_equiv
for i in range(0, n_layers):
self.add_module("gcl_%d" % i, GCL(self.hidden_nf, self.hidden_nf, self.hidden_nf, edges_in_d=edge_feat_nf,
act_fn=act_fn, attention=attention,
normalization_factor=self.normalization_factor,
aggregation_method=self.aggregation_method))
self.add_module("gcl_equiv", EquivariantUpdate(hidden_nf, edges_in_d=edge_feat_nf, act_fn=nn.SiLU(), tanh=tanh,
coords_range=self.coords_range_layer,
normalization_factor=self.normalization_factor,
aggregation_method=self.aggregation_method,
reflection_equiv=self.reflection_equiv))
self.to(self.device)
def forward(self, h, x, edge_index, node_mask=None, edge_mask=None,
edge_attr=None, update_coords_mask=None, batch_mask=None):
        distances, coord_diff = coord2diff(x, edge_index, self.norm_constant)  # Note: `distances` holds squared distances.
if self.reflection_equiv:
coord_cross = None
else:
coord_cross = coord2cross(x, edge_index, batch_mask,
self.norm_constant)
if self.sin_embedding is not None:
distances = self.sin_embedding(distances)
        edge_attr = torch.cat([distances, edge_attr], dim=1)  # Current distances + block-level edge features.
for i in range(0, self.n_layers):
h, _ = self._modules["gcl_%d" % i](h, edge_index, edge_attr=edge_attr,
node_mask=node_mask, edge_mask=edge_mask)
x = self._modules["gcl_equiv"](h, x, edge_index, coord_diff, coord_cross, edge_attr,
node_mask, edge_mask, update_coords_mask=update_coords_mask)
        # Important: the bias of the final linear layer may be non-zero, so
        # apply the node mask after the update.
if node_mask is not None:
h = h * node_mask
return h, x
class EGNN(nn.Module):
    """E(n)-equivariant graph network: embeds node features, applies
    `n_layers` EquivariantBlocks (each with `inv_sublayers` GCLs plus one
    coordinate update), and projects the features back out."""
def __init__(self, in_node_nf, in_edge_nf, hidden_nf, device='cpu', act_fn=nn.SiLU(), n_layers=3, attention=False,
norm_diff=True, out_node_nf=None, tanh=False, coords_range=15, norm_constant=1, inv_sublayers=2,
sin_embedding=False, normalization_factor=100, aggregation_method='sum', reflection_equiv=True):
super(EGNN, self).__init__()
if out_node_nf is None:
out_node_nf = in_node_nf
self.hidden_nf = hidden_nf
self.device = device
self.n_layers = n_layers
        self.coords_range_layer = float(coords_range / n_layers)  # Split the tanh range budget across blocks.
self.norm_diff = norm_diff
self.normalization_factor = normalization_factor
self.aggregation_method = aggregation_method
self.reflection_equiv = reflection_equiv
if sin_embedding:
self.sin_embedding = SinusoidsEmbeddingNew()
edge_feat_nf = self.sin_embedding.dim * 2
else:
self.sin_embedding = None
edge_feat_nf = 2
edge_feat_nf = edge_feat_nf + in_edge_nf
self.embedding = nn.Linear(in_node_nf, self.hidden_nf)
self.embedding_out = nn.Linear(self.hidden_nf, out_node_nf)
for i in range(0, n_layers):
self.add_module("e_block_%d" % i, EquivariantBlock(hidden_nf, edge_feat_nf=edge_feat_nf, device=device,
act_fn=act_fn, n_layers=inv_sublayers,
attention=attention, norm_diff=norm_diff, tanh=tanh,
coords_range=coords_range, norm_constant=norm_constant,
sin_embedding=self.sin_embedding,
normalization_factor=self.normalization_factor,
aggregation_method=self.aggregation_method,
reflection_equiv=self.reflection_equiv))
self.to(self.device)
def forward(self, h, x, edge_index, node_mask=None, edge_mask=None, update_coords_mask=None,
batch_mask=None, edge_attr=None):
        edge_feat, _ = coord2diff(x, edge_index)  # Squared distances as the initial edge features.
if self.sin_embedding is not None:
edge_feat = self.sin_embedding(edge_feat)
if edge_attr is not None:
edge_feat = torch.cat([edge_feat, edge_attr], dim=1)
h = self.embedding(h)
for i in range(0, self.n_layers):
h, x = self._modules["e_block_%d" % i](
h, x, edge_index, node_mask=node_mask, edge_mask=edge_mask,
edge_attr=edge_feat, update_coords_mask=update_coords_mask,
batch_mask=batch_mask)
        # Important: the bias of the final linear layer may be non-zero, so
        # apply the node mask after the output projection.
h = self.embedding_out(h)
if node_mask is not None:
h = h * node_mask
return h, x
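
# Hedged smoke test (an illustrative assumption, not from the original repo):
# run the full EGNN on a random fully connected graph and check that node
# features and coordinates keep their shapes.
def _demo_egnn():
    n_nodes, in_nf = 5, 6
    h = torch.randn(n_nodes, in_nf)
    x = torch.randn(n_nodes, 3)
    rows, cols = zip(*[(i, j) for i in range(n_nodes)
                       for j in range(n_nodes) if i != j])
    edge_index = [torch.tensor(rows), torch.tensor(cols)]
    egnn = EGNN(in_node_nf=in_nf, in_edge_nf=0, hidden_nf=16, n_layers=2)
    h_out, x_out = egnn(h, x, edge_index)
    assert h_out.shape == (n_nodes, in_nf) and x_out.shape == x.shape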
class GNN(nn.Module):
    """Plain (non-equivariant) message-passing network over node features
    only; coordinates are neither used nor updated."""
def __init__(self, in_node_nf, in_edge_nf, hidden_nf, aggregation_method='sum', device='cpu',
act_fn=nn.SiLU(), n_layers=4, attention=False,
normalization_factor=1, out_node_nf=None):
super(GNN, self).__init__()
if out_node_nf is None:
out_node_nf = in_node_nf
self.hidden_nf = hidden_nf
self.device = device
self.n_layers = n_layers
        # Node feature embedding in and out.
self.embedding = nn.Linear(in_node_nf, self.hidden_nf)
self.embedding_out = nn.Linear(self.hidden_nf, out_node_nf)
for i in range(0, n_layers):
self.add_module("gcl_%d" % i, GCL(
self.hidden_nf, self.hidden_nf, self.hidden_nf,
normalization_factor=normalization_factor,
aggregation_method=aggregation_method,
edges_in_d=in_edge_nf, act_fn=act_fn,
attention=attention))
self.to(self.device)
def forward(self, h, edges, edge_attr=None, node_mask=None, edge_mask=None):
h = self.embedding(h)
for i in range(0, self.n_layers):
h, _ = self._modules["gcl_%d" % i](h, edges, edge_attr=edge_attr, node_mask=node_mask, edge_mask=edge_mask)
h = self.embedding_out(h)
        # Important: the bias of the final linear layer may be non-zero, so
        # apply the node mask after the output projection.
if node_mask is not None:
h = h * node_mask
return h
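
# Hedged usage sketch (illustrative): the plain GNN consumes node features
# only; no coordinates are involved.
def _demo_gnn():
    n_nodes, in_nf = 4, 6
    h = torch.randn(n_nodes, in_nf)
    rows, cols = zip(*[(i, j) for i in range(n_nodes)
                       for j in range(n_nodes) if i != j])
    edge_index = [torch.tensor(rows), torch.tensor(cols)]
    gnn = GNN(in_node_nf=in_nf, in_edge_nf=0, hidden_nf=16, n_layers=2)
    h_out = gnn(h, edge_index)
    assert h_out.shape == (n_nodes, in_nf)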
class SinusoidsEmbeddingNew(nn.Module):
    """Sinusoidal embedding of squared distances over geometrically spaced
    frequencies between `min_res` and `max_res`."""
    def __init__(self, max_res=15., min_res=15. / 2000., div_factor=4):
        super().__init__()
        self.n_frequencies = int(math.log(max_res / min_res, div_factor)) + 1
        self.frequencies = 2 * math.pi * (div_factor ** torch.arange(self.n_frequencies)) / max_res
        self.dim = len(self.frequencies) * 2  # sin and cos per frequency.
    def forward(self, x):
        x = torch.sqrt(x + 1e-8)  # Input is a squared distance; recover the distance.
        emb = x * self.frequencies[None, :].to(x.device)
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb.detach()
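
# Hedged check (illustrative): with the defaults above, max_res / min_res =
# 2000 and div_factor = 4 give int(log_4(2000)) + 1 = 6 frequencies, hence a
# 12-dimensional embedding (sin and cos halves).
def _demo_sin_embedding():
    emb = SinusoidsEmbeddingNew()
    assert emb.dim == 12
    radial = torch.rand(7, 1)  # Squared distances, as produced by coord2diff.
    assert emb(radial).shape == (7, 12)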
def coord2diff(x, edge_index, norm_constant=1):
    """Return squared edge lengths and soft-normalized difference vectors."""
    row, col = edge_index
    coord_diff = x[row] - x[col]
    radial = torch.sum(coord_diff ** 2, dim=1, keepdim=True)
    norm = torch.sqrt(radial + 1e-8)
    coord_diff = coord_diff / (norm + norm_constant)
    return radial, coord_diff
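
# Hedged sanity check (illustrative): for a 3-4-5 right-triangle edge, the
# returned `radial` is the squared length, 25.
def _demo_coord2diff():
    x = torch.tensor([[0., 0., 0.], [3., 4., 0.]])
    edge_index = [torch.tensor([0]), torch.tensor([1])]
    radial, coord_diff = coord2diff(x, edge_index)
    assert torch.isclose(radial[0, 0], torch.tensor(25.))
    assert coord_diff.shape == (1, 3)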
def coord2cross(x, edge_index, batch_mask, norm_constant=1):
    """Return soft-normalized cross products of the endpoint coordinates,
    centered on each graph's mean so the result is translation-invariant."""
    mean = unsorted_segment_sum(x, batch_mask,
                                num_segments=batch_mask.max() + 1,
                                normalization_factor=None,
                                aggregation_method='mean')
    row, col = edge_index
    cross = torch.cross(x[row] - mean[batch_mask[row]],
                        x[col] - mean[batch_mask[col]], dim=1)
    norm = torch.linalg.norm(cross, dim=1, keepdim=True)
    cross = cross / (norm + norm_constant)
    return cross
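
# Hedged usage sketch (illustrative): `batch_mask` assigns each node to a
# graph, so the cross products are taken relative to each graph's centroid.
def _demo_coord2cross():
    x = torch.randn(4, 3)
    batch_mask = torch.tensor([0, 0, 1, 1])  # Two graphs of two nodes each.
    edge_index = [torch.tensor([0, 2]), torch.tensor([1, 3])]
    cross = coord2cross(x, edge_index, batch_mask)
    assert cross.shape == (2, 3)  # One vector per edge.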
def unsorted_segment_sum(data, segment_ids, num_segments, normalization_factor, aggregation_method: str):
    """Custom PyTorch op to replicate TensorFlow's `unsorted_segment_sum`.

    aggregation_method: 'sum' divides the scatter-sum by `normalization_factor`;
    'mean' divides by the number of entries in each segment.
    """
    result_shape = (num_segments, data.size(1))
    result = data.new_full(result_shape, 0)  # Zero-initialized result tensor.
    segment_ids = segment_ids.unsqueeze(-1).expand(-1, data.size(1))
    result.scatter_add_(0, segment_ids, data)
    if aggregation_method == 'sum':
        result = result / normalization_factor
    elif aggregation_method == 'mean':
        norm = data.new_zeros(result.shape)
        norm.scatter_add_(0, segment_ids, data.new_ones(data.shape))
        norm[norm == 0] = 1  # Avoid division by zero for empty segments.
        result = result / norm
    return result
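
# Hedged example (illustrative): scatter-add three rows into two segments,
# then run all the sketches above as a quick smoke test.
def _demo_segment_sum():
    data = torch.tensor([[1.], [2.], [4.]])
    seg = torch.tensor([0, 0, 1])
    out = unsorted_segment_sum(data, seg, num_segments=2,
                               normalization_factor=1,
                               aggregation_method='sum')
    assert torch.equal(out, torch.tensor([[3.], [4.]]))


if __name__ == "__main__":
    _demo_gcl()
    _demo_equivariant_update()
    _demo_egnn()
    _demo_gnn()
    _demo_sin_embedding()
    _demo_coord2diff()
    _demo_coord2cross()
    _demo_segment_sum()
    print("All demo sketches passed.")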