Graph Data Structures
=====================

At the core of TorchDrug, we provide several data structures to enable common
operations in graph representation learning.

Create a Graph
--------------

To begin with, let's create a graph.

.. code:: python

    import torch
    from torchdrug import data

    edge_list = [[0, 1], [1, 2], [2, 3], [3, 4], [4, 5], [5, 0]]
    graph = data.Graph(edge_list, num_node=6)
    graph.visualize()

This will plot a ring graph like the following.

.. image:: ../../../asset/graph/graph.png
    :align: center
    :width: 33%

Internally, the graph is stored as a sparse edge list to reduce the memory footprint.
For an intuitive comparison, a `scale-free graph`_ may have 1 million nodes and
10 million edges. A dense adjacency matrix would take about 4TB, while the sparse
edge list only requires about 120MB.

.. _scale-free graph:
    https://en.wikipedia.org/wiki/Scale-free_network

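These numbers follow from a quick back-of-the-envelope calculation. The sketch below is
not TorchDrug code; it simply assumes 4-byte floats for the dense adjacency matrix, and
two 4-byte endpoints plus a 4-byte weight per edge for the sparse edge list.

.. code:: python

    num_node = 1_000_000
    num_edge = 10_000_000

    # dense adjacency matrix: one float32 entry per ordered node pair
    dense_bytes = num_node * num_node * 4        # 4e12 bytes, i.e. ~4TB

    # sparse edge list: two int32 endpoints plus one float32 weight per edge
    sparse_bytes = num_edge * (2 * 4 + 4)        # 1.2e8 bytes, i.e. ~120MB

    print(dense_bytes / 1e12, "TB")              # 4.0 TB
    print(sparse_bytes / 1e6, "MB")              # 120.0 MB
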
Here are some commonly used properties of the graph.

.. code:: python

    print(graph.num_node)
    print(graph.num_edge)
    print(graph.edge_list)
    print(graph.edge_weight)

In some scenarios, the graph may also have type information on its edges. For example,
molecules have bond types like ``single bond``, while knowledge graphs have relations
like ``consists of``. To construct such a relational graph, we can pass the edge type
as the third element of each edge in the edge list.

.. code:: python

    triplet_list = [[0, 1, 0], [1, 2, 1], [2, 3, 0], [3, 4, 1], [4, 5, 0], [5, 0, 1]]
    graph = data.Graph(triplet_list, num_node=6, num_relation=2)
    graph.visualize()

.. image:: ../../../asset/graph/relational_graph.png
    :align: center
    :width: 33%

Alternatively, we can create the above graphs from adjacency matrices.

A normal graph uses a 2D adjacency matrix :math:`A`, where a non-zero :math:`A_{i,j}`
corresponds to an edge from node :math:`i` to node :math:`j`. A relational graph
uses a 3D adjacency matrix :math:`A`, where a non-zero :math:`A_{i,j,k}` denotes an
edge from node :math:`i` to node :math:`j` with edge type :math:`k`.

.. code:: python

    # index the matrix with separate row and column tensors,
    # so that each edge sets exactly one entry
    edge_index = torch.tensor(edge_list)
    adjacency = torch.zeros(6, 6)
    adjacency[edge_index[:, 0], edge_index[:, 1]] = 1
    graph = data.Graph.from_dense(adjacency)

    triplet_index = torch.tensor(triplet_list)
    adjacency = torch.zeros(6, 6, 2)
    adjacency[triplet_index[:, 0], triplet_index[:, 1], triplet_index[:, 2]] = 1
    graph = data.Graph.from_dense(adjacency)

For molecule graphs, TorchDrug supports creating instances from `SMILES`_ strings.
For example, the following code creates a benzene molecule.

.. _SMILES:
    https://en.wikipedia.org/wiki/Simplified_molecular-input_line-entry_system

.. code:: python

    mol = data.Molecule.from_smiles("C1=CC=CC=C1")
    mol.visualize()

.. image:: ../../../asset/graph/benzene.png
    :align: center
    :width: 33%

Once the graph is created, we can transfer it between CPUs and GPUs, just like a
:class:`torch.Tensor`.

.. code:: python

    graph = graph.cuda()
    print(graph.device)

    graph = graph.cpu()
    print(graph.device)

Graph Attributes
----------------

A common practice in graph representation learning is to add graph features as the
input of neural networks. Typically, there are three types of features: node-level,
edge-level and graph-level features. In TorchDrug, these features are stored as
node/edge/graph attributes in the data structure, and are automatically processed
during any graph operation.

Here we specify some features during the construction of the molecule graph.

.. code:: python

    mol = data.Molecule.from_smiles("C1=CC=CC=C1", atom_feature="default",
                                    bond_feature="default", mol_feature="ecfp")
    print(mol.node_feature.shape)
    print(mol.edge_feature.shape)
    print(mol.graph_feature.shape)

There are a bunch of popular feature functions provided in :mod:`torchdrug.data.feature`.
We may also want to define our own attributes. This only requires wrapping the
assignment lines with a context manager. The following example defines the importance
of each edge as the sum of the reciprocal degrees of its endpoints.

.. code:: python

    node_in, node_out = mol.edge_list.t()[:2]
    with mol.edge():
        mol.edge_importance = 1 / mol.degree_in[node_in] + 1 / mol.degree_out[node_out]

We can use ``mol.node()`` and ``mol.graph()`` for node- and graph-level attributes
respectively. Attributes may also be a reference to node/edge/graph indexes. See
:doc:`reference` for more details.

Note that in order to support batching and masking, attributes should always have the
same length as their corresponding components. This means the size of the first
dimension of the tensor should be either ``num_node``, ``num_edge`` or ``1``.

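As a minimal illustration of this rule, the sketch below registers one node-level and
one graph-level attribute on the molecule; the attribute names here are made up for
this example and are not part of TorchDrug.

.. code:: python

    with mol.node():
        # one value per atom, so the first dimension equals num_node
        mol.node_importance = torch.ones_like(mol.node_feature[:, 0])

    with mol.graph():
        # a single row for the whole molecule, so the first dimension equals 1
        mol.graph_label = torch.tensor([0])

    print(mol.node_importance.shape)   # (num_node,)
    print(mol.graph_label.shape)       # (1,)
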
Batch Graph
-----------

Modern deep learning frameworks employ batched operations to accelerate computation.
In TorchDrug, we can easily batch graphs of the same kind with **arbitrary sizes**.
Here is an example of creating a batch of 4 graphs.

.. code:: python

    graphs = [graph, graph, graph, graph]
    batch = data.Graph.pack(graphs)
    batch.visualize(num_row=1)

.. image:: ../../../asset/graph/batch.png

This returns a :class:`PackedGraph <torchdrug.data.PackedGraph>` instance with
all attributes automatically batched. The essential trick behind this operation is
based on a simple property of graphs: a batch of :math:`n` graphs is equivalent to
a large graph with :math:`n` connected components. The equivalent adjacency matrix
for a batch is

.. math::

    A =
    \begin{bmatrix}
        A_1    & \cdots & 0      \\
        \vdots & \ddots & \vdots \\
        0      & \cdots & A_n
    \end{bmatrix}

where :math:`A_i` is the adjacency matrix of the :math:`i`-th graph.

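Since the packed graph is literally this block-diagonal graph, its node and edge counts
are simply the members' counts concatenated. As a quick check (``num_nodes`` and
``num_cum_nodes`` are the per-member and cumulative node counts that also appear in the
batched masking example in the next section):

.. code:: python

    print(batch.num_node)        # 24, total nodes over all members
    print(batch.num_edge)        # 24, total edges over all members
    print(batch.num_nodes)       # tensor([6, 6, 6, 6]), nodes per member
    print(batch.num_cum_nodes)   # tensor([ 6, 12, 18, 24]), cumulative node counts
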
To get a single graph from the batch, use conventional indexing or
:meth:`PackedGraph.unpack <torchdrug.data.PackedGraph.unpack>`.

.. code:: python

    graph = batch[1]
    graphs = batch.unpack()

One advantage of such a batching mechanism is that it does not distinguish between a
single graph and a batch. In other words, we only need to implement single-graph
operations, and they can be directly applied as batched operations. This reduces the
pain of writing batched operations.

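For example, a property defined for a single graph, such as the ``degree_in`` property
used earlier, can be accessed on a packed batch in exactly the same way; it simply acts
on the big block-diagonal graph. A small sketch of this, assuming ``degree_in`` is also
available on the packed batch:

.. code:: python

    # the same attribute access works on a single graph and on a batch
    print(graph.degree_in.shape)   # torch.Size([6]), one entry per node
    print(batch.degree_in.shape)   # torch.Size([24]), one entry per node of the batch
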
Subgraph and Masking
--------------------

The graph data structure also provides a bunch of slicing operations to create
subgraphs or masked graphs in a sparse manner. Some typical operations include

.. code:: python

    g1 = graph.subgraph([1, 2, 3, 4])
    g1.visualize()

    g2 = graph.node_mask([1, 2, 3, 4])
    g2.visualize()

    g3 = graph.edge_mask([0, 1, 5])
    g3.visualize()

    g4 = g3.compact()
    g4.visualize()

.. image:: ../../../asset/graph/subgraph.png
    :width: 24%
.. image:: ../../../asset/graph/node_mask.png
    :width: 24%
.. image:: ../../../asset/graph/edge_mask.png
    :width: 24%
.. image:: ../../../asset/graph/compact.png
    :width: 24%

All the above operations accept either integer indexes or binary masks.
:meth:`subgraph() <torchdrug.data.Graph.subgraph>` extracts the subgraph induced by
the given nodes, with node ids re-mapped to produce a compact index.
:meth:`node_mask() <torchdrug.data.Graph.node_mask>` keeps only the edges among the
given nodes, without re-mapping node ids.
:meth:`edge_mask() <torchdrug.data.Graph.edge_mask>` keeps only the edges with the
given edge indexes. :meth:`compact() <torchdrug.data.Graph.compact>` removes all
isolated nodes.

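The examples above use integer indexes. The equivalent binary-mask form builds a
boolean tensor instead; here is a sketch for ``node_mask`` (the other operations work
the same way):

.. code:: python

    # a boolean mask with one entry per node of the 6-node ring
    mask = torch.tensor([False, True, True, True, True, False])
    g2 = graph.node_mask(mask)   # same result as graph.node_mask([1, 2, 3, 4])
    g2.visualize()
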
The same operations can also be applied to batches. In this case, we need to convert
node indexes within each graph into node indexes in the batch.

.. code:: python

    graph_ids = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1, 1, 1])
    node_ids = torch.tensor([1, 2, 3, 4, 0, 1, 2, 3, 4, 5])
    # offset each node index by the number of nodes packed before its graph
    node_ids += batch.num_cum_nodes[graph_ids] - batch.num_nodes[graph_ids]
    batch = batch.node_mask(node_ids)
    batch.visualize(num_row=1)

.. image:: ../../../asset/graph/batch_node_mask.png

We can also pick a subset of graphs in a batch.

.. code:: python

    batch = batch[[0, 1]]
    batch.visualize()

.. image:: ../../../asset/graph/subbatch.png
    :align: center
    :width: 66%