[36b44b]: / torchdrug / layers / geometry / graph.py

Download this file

195 lines (158 with data), 8.2 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
import math
import torch
from torch import nn
from torchdrug import core, data
from torchdrug.layers import functional
from torchdrug.core import Registry as R
@R.register("layers.GraphConstruction")
class GraphConstruction(nn.Module, core.Configurable):
    """
    Build a new graph on top of an existing graph.

    The node layers are applied first, one after another. Afterwards every edge layer is
    called on the resulting graph and the edges they return are merged into one edge list.
    See `torchdrug.layers.geometry` for a full list of available node and edge layers.

    Parameters:
        node_layers (list of nn.Module, optional): modules to construct nodes of the new graph
        edge_layers (list of nn.Module, optional): modules to construct edges of the new graph
        edge_feature (str, optional): edge features in the new graph.
            Available features are ``residue_type``, ``gearnet``.

            1. For ``residue_type``, the feature of the edge :math:`e_{ij}` between residue :math:`i` and residue
               :math:`j` is the concatenation ``[residue_type(i), residue_type(j)]``.
            2. For ``gearnet``, the feature of the edge :math:`e_{ij}` between residue :math:`i` and residue :math:`j`
               is the concatenation ``[residue_type(i), residue_type(j), edge_type(e_ij),
               sequential_distance(i,j), spatial_distance(i,j)]``.

    .. note::
        You may customize your own edge features by inheriting this class and define a member function
        for your features. Use ``edge_feature="my_feature"`` to call the following feature function.

        .. code:: python

            def edge_my_feature(self, graph, edge_list, num_relation):
                ...
                return feature # the first dimension must be ``graph.num_edge``
    """

    # sequential distances are clamped to this value before being one-hot encoded
    max_seq_dist = 10

    def __init__(self, node_layers=None, edge_layers=None, edge_feature="residue_type"):
        super().__init__()
        self.node_layers = nn.ModuleList() if node_layers is None else nn.ModuleList(node_layers)
        self.edge_layers = nn.ModuleList() if edge_layers is None else nn.ModuleList(edge_layers)
        self.edge_feature = edge_feature

    def edge_residue_type(self, graph, edge_list, num_relation):
        """
        Edge feature ``residue_type``: one-hot residue types of both end points, concatenated.

        Parameters:
            graph (Graph): graph providing ``atom2residue`` and ``residue_type``
            edge_list (Tensor): edges to featurize; its transpose is unpacked into three rows
                and the third one is not used here
            num_relation: unused by this feature function

        Returns:
            Tensor: concatenation of the two one-hot encodings along the last dimension
        """
        u, v, _ = edge_list.t()
        residue_u = graph.atom2residue[u]
        residue_v = graph.atom2residue[v]
        type_u = graph.residue_type[residue_u]
        type_v = graph.residue_type[residue_v]
        onehot_u = functional.one_hot(type_u, len(data.Protein.residue2id))
        onehot_v = functional.one_hot(type_v, len(data.Protein.residue2id))
        return torch.cat([onehot_u, onehot_v], dim=-1)

    def edge_gearnet(self, graph, edge_list, num_relation):
        """
        Edge feature ``gearnet``: residue types of both end points, the relation type,
        the clamped sequential distance and the spatial distance, concatenated.

        Parameters:
            graph (Graph): graph providing ``atom2residue``, ``residue_type`` and ``node_position``
            edge_list (Tensor): edges to featurize; its transpose is unpacked into
                two node index rows and one relation row
            num_relation: number of classes used to one-hot encode the relation row

        Returns:
            Tensor: concatenation of all parts along the last dimension
        """
        u, v, relation = edge_list.t()
        residue_u = graph.atom2residue[u]
        residue_v = graph.atom2residue[v]
        type_u = graph.residue_type[residue_u]
        type_v = graph.residue_type[residue_v]
        # distance along the sequence, measured in residue indices
        seq_gap = (residue_u - residue_v).abs()
        displacement = graph.node_position[u] - graph.node_position[v]
        euclidean = displacement.norm(dim=-1)

        parts = [
            functional.one_hot(type_u, len(data.Protein.residue2id)),
            functional.one_hot(type_v, len(data.Protein.residue2id)),
            functional.one_hot(relation, num_relation),
            # every gap of at least ``max_seq_dist`` shares the last class
            functional.one_hot(seq_gap.clamp(max=self.max_seq_dist), self.max_seq_dist + 1),
            euclidean.unsqueeze(-1),
        ]
        return torch.cat(parts, dim=-1)

    def apply_node_layer(self, graph):
        """
        Feed the graph through every node layer in order and return the final result.
        """
        for node_layer in self.node_layers:
            graph = node_layer(graph)
        return graph

    def apply_edge_layer(self, graph):
        """
        Call every edge layer on the graph and assemble a new graph from the returned edges.

        The input graph is returned unchanged when there are no edge layers.

        Parameters:
            graph (Graph): graph after the node layers have been applied

        Returns:
            Graph: a graph of the same class as the input, carrying the merged edges

        Raises:
            ValueError: if ``edge_feature`` is not ``None`` and no ``edge_<edge_feature>``
                attribute exists on this module
        """
        if len(self.edge_layers) == 0:
            return graph

        collected_edges = []
        layer_sizes = []
        layer_relations = []
        for edge_layer in self.edge_layers:
            edges, relation_count = edge_layer(graph)
            collected_edges.append(edges)
            layer_sizes.append(len(edges))
            layer_relations.append(relation_count)

        edge_list = torch.cat(collected_edges)
        layer_sizes = torch.tensor(layer_sizes, device=graph.device)
        layer_relations = torch.tensor(layer_relations, device=graph.device)
        num_relation = layer_relations.sum()
        # shift the relation column so that each layer gets its own range of relation ids
        relation_start = layer_relations.cumsum(0) - layer_relations
        edge_list[:, 2] += relation_start.repeat_interleave(layer_sizes)

        # reorder edges into a valid PackedGraph
        edge2graph = graph.node2graph[edge_list[:, 0]]
        edge_list = edge_list[edge2graph.argsort()]
        num_edges = edge2graph.bincount(minlength=graph.batch_size)
        node_start = graph.num_cum_nodes - graph.num_nodes
        offsets = node_start.repeat_interleave(num_edges)

        feature_name = "edge_%s" % self.edge_feature
        if not hasattr(self, feature_name):
            if self.edge_feature is not None:
                raise ValueError("Unknown edge feature `%s`" % self.edge_feature)
            edge_feature = None
        else:
            edge_feature = getattr(self, feature_name)(graph, edge_list, num_relation)

        # edge-level attributes are not carried over, since the edges have been replaced
        kept = ("node", "residue", "node reference", "residue reference", "graph")
        data_dict, meta_dict = graph.data_by_meta(include=kept)

        if isinstance(graph, data.PackedProtein):
            data_dict["num_residues"] = graph.num_residues
        if isinstance(graph, data.PackedMolecule):
            data_dict["bond_type"] = torch.zeros_like(edge_list[:, 2])

        graph_class = type(graph)
        return graph_class(edge_list, num_nodes=graph.num_nodes, num_edges=num_edges,
                           num_relation=num_relation, view=graph.view, offsets=offsets,
                           edge_feature=edge_feature, meta_dict=meta_dict, **data_dict)

    def forward(self, graph):
        """
        Generate a new graph based on the input graph and pre-defined node and edge layers.

        Parameters:
            graph (Graph): :math:`n` graph(s)

        Returns:
            graph (Graph): new graph(s)
        """
        return self.apply_edge_layer(self.apply_node_layer(graph))
@R.register("layers.SpatialLineGraph")
class SpatialLineGraph(nn.Module, core.Configurable):
    """
    Spatial line graph construction module from `Protein Representation Learning by Geometric Structure Pretraining`_.

    .. _Protein Representation Learning by Geometric Structure Pretraining:
        https://arxiv.org/pdf/2203.06125.pdf

    Parameters:
        num_angle_bin (int, optional): number of bins to discretize angles between edges
    """

    def __init__(self, num_angle_bin=8):
        super(SpatialLineGraph, self).__init__()
        self.num_angle_bin = num_angle_bin

    def forward(self, graph):
        """
        Generate the spatial line graph of the input graph.
        The edge types are decided by the angles between two adjacent edges in the input graph.

        Parameters:
            graph (PackedGraph): :math:`n` graph(s)

        Returns:
            graph (PackedGraph): the spatial line graph
        """
        line_graph = graph.line_graph()
        node_in, node_out = graph.edge_list[:, :2].t()
        # entries of the line graph edge list are used below as indices into the input graph's edges
        edge_in, edge_out = line_graph.edge_list.t()

        # compute the angle ijk
        node_i = node_out[edge_out]
        node_j = node_in[edge_out]
        node_k = node_in[edge_in]
        vector1 = graph.node_position[node_i] - graph.node_position[node_j]
        vector2 = graph.node_position[node_k] - graph.node_position[node_j]
        # atan2(|v1 x v2|, v1 . v2) yields the angle between the two vectors in [0, pi]
        x = (vector1 * vector2).sum(dim=-1)
        # The coordinate axis is given explicitly. Without ``dim`` torch.cross picks the first
        # dimension of size 3, which is the wrong axis when there are exactly three line graph
        # edges, and the implicit form is deprecated.
        y = torch.cross(vector1, vector2, dim=-1).norm(dim=-1)
        angle = torch.atan2(y, x)
        # discretize into ``num_angle_bin`` bins; the clamp puts angle == pi into the last bin
        relation = (angle / math.pi * self.num_angle_bin).long().clamp(max=self.num_angle_bin - 1)
        edge_list = torch.cat([line_graph.edge_list, relation.unsqueeze(-1)], dim=-1)

        return type(line_graph)(edge_list, num_nodes=line_graph.num_nodes, offsets=line_graph._offsets,
                                num_edges=line_graph.num_edges, num_relation=self.num_angle_bin,
                                meta_dict=line_graph.meta_dict, **line_graph.data_dict)