Graph Data Structures
=====================

At the core of TorchDrug, we provide several data structures to enable common
operations in graph representation learning.

Create a Graph
--------------

To begin with, let's create a graph.

.. code:: python

    import torch
    from torchdrug import data

    edge_list = [[0, 1], [1, 2], [2, 3], [3, 4], [4, 5], [5, 0]]
    graph = data.Graph(edge_list, num_node=6)
    graph.visualize()

This will plot a ring graph like the following.

.. image:: ../../../asset/graph/graph.png
    :align: center
    :width: 33%

Internally, the graph is stored as a sparse edge list to reduce its memory footprint.
For an intuitive comparison, a `scale-free graph`_ may have 1 million nodes and
10 million edges. The dense version takes about 4TB, while the sparse version only
requires 120MB.

.. _scale-free graph:
    https://en.wikipedia.org/wiki/Scale-free_network

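As a rough check of these numbers, the following sketch redoes the arithmetic in plain
Python. It assumes 4 bytes per stored value and 3 stored values per sparse edge (two
node indexes and a weight). Neither assumption is stated in the text, so treat the
result as an estimate rather than the exact layout used by TorchDrug.

.. code:: python

    # Back-of-the-envelope memory estimate for a graph with
    # 1 million nodes and 10 million edges.
    num_node = 1_000_000
    num_edge = 10_000_000
    bytes_per_value = 4     # assumption: 32-bit values
    values_per_edge = 3     # assumption: source index, target index, weight

    # Dense storage keeps one value for every (source, target) pair.
    dense_bytes = num_node * num_node * bytes_per_value
    # Sparse storage keeps a fixed number of values per existing edge.
    sparse_bytes = num_edge * values_per_edge * bytes_per_value

    print(dense_bytes / 1e12, "TB")   # 4.0 TB
    print(sparse_bytes / 1e6, "MB")   # 120.0 MB
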
Here are some commonly used properties of the graph.

.. code:: python

    print(graph.num_node)
    print(graph.num_edge)
    print(graph.edge_list)
    print(graph.edge_weight)

In some scenarios, the graph may also have type information on its edges. For example,
molecules have bond types like ``single bond``, while knowledge graphs have relations
like ``consists of``. To construct such a relational graph, we can pass the edge type
as a third element of each entry in the edge list.

.. code:: python

    triplet_list = [[0, 1, 0], [1, 2, 1], [2, 3, 0], [3, 4, 1], [4, 5, 0], [5, 0, 1]]
    graph = data.Graph(triplet_list, num_node=6, num_relation=2)
    graph.visualize()

.. image:: ../../../asset/graph/relational_graph.png
    :align: center
    :width: 33%

Alternatively, we can also use adjacency matrices to create the above graphs.

The normal graph uses a 2D adjacency matrix :math:`A`, where non-zero :math:`A_{i,j}`
corresponds to an edge from node :math:`i` to node :math:`j`. The relational graph
uses a 3D adjacency matrix :math:`A`, where non-zero :math:`A_{i,j,k}` denotes an
edge from node :math:`i` to node :math:`j` with edge type :math:`k`.

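.. code:: python

    # Index with one tensor per dimension, so that each row of the list
    # addresses a single entry of the adjacency matrix.
    adjacency = torch.zeros(6, 6)
    adjacency[tuple(torch.tensor(edge_list).t())] = 1
    graph = data.Graph.from_dense(adjacency)

    adjacency = torch.zeros(6, 6, 2)
    adjacency[tuple(torch.tensor(triplet_list).t())] = 1
    graph = data.Graph.from_dense(adjacency)
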
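To make the correspondence between the two representations concrete, here is a
plain-Python sketch that fills a dense 2D adjacency matrix from the ring edge list and
then reads the edges back from its non-zero entries. It uses nested lists instead of
tensors and does not call any TorchDrug function.

.. code:: python

    edge_list = [[0, 1], [1, 2], [2, 3], [3, 4], [4, 5], [5, 0]]
    num_node = 6

    # A[i][j] is non-zero iff there is an edge from node i to node j.
    adjacency = [[0] * num_node for _ in range(num_node)]
    for i, j in edge_list:
        adjacency[i][j] = 1

    # Reading the non-zero entries row by row recovers the edge list.
    recovered = [
        [i, j]
        for i in range(num_node)
        for j in range(num_node)
        if adjacency[i][j] != 0
    ]
    print(recovered == edge_list)  # True
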
For molecule graphs, TorchDrug supports creating instances from `SMILES`_ strings.
For example, the following code creates a benzene molecule.

.. _SMILES:
    https://en.wikipedia.org/wiki/Simplified_molecular-input_line-entry_system

.. code:: python

    mol = data.Molecule.from_smiles("C1=CC=CC=C1")
    mol.visualize()

.. image:: ../../../asset/graph/benzene.png
    :align: center
    :width: 33%

Once the graph is created, we can transfer it between CPU and GPUs, just like
:class:`torch.Tensor`.

.. code:: python

    graph = graph.cuda()
    print(graph.device)

    graph = graph.cpu()
    print(graph.device)

Graph Attributes
----------------

A common practice in graph representation learning is to add some graph features as
the input of neural networks. Typically, there are three types of features: node-level,
edge-level and graph-level features. In TorchDrug, these features are stored as
node/edge/graph attributes in the data structure, and are automatically processed
during any graph operation.

Here we specify some features during the construction of the molecule graph.

.. code:: python

    mol = data.Molecule.from_smiles("C1=CC=CC=C1", atom_feature="default",
                                    bond_feature="default", mol_feature="ecfp")
    print(mol.node_feature.shape)
    print(mol.edge_feature.shape)
    print(mol.graph_feature.shape)

There are a bunch of popular feature functions provided in :mod:`torchdrug.data.feature`.
We may also want to define our own attributes. This only requires wrapping the
assignment lines with a context manager. The following example defines edge importance
as the reciprocal of node degrees.

.. code:: python

    node_in, node_out = mol.edge_list.t()[:2]
    with mol.edge():
        # use the degrees of mol itself, so the indexes match its own nodes
        mol.edge_importance = 1 / mol.degree_in[node_in] + 1 / mol.degree_out[node_out]

We can use ``mol.node()`` and ``mol.graph()`` for node- and graph-level attributes
respectively. Attributes may also be a reference to node/edge/graph indexes. See
:doc:`reference` for more details.

Note that, in order to support batching and masking, attributes should always have the
same length as their corresponding components. This means the size of the first
dimension of the tensor should be ``num_node`` for node attributes, ``num_edge`` for
edge attributes, or ``1`` for graph attributes.

Batch Graph
-----------

Modern deep learning frameworks employ batched operations to accelerate computation.
In TorchDrug, we can easily batch graphs of the same kind with **arbitrary sizes**.
Here is an example of creating a batch of 4 graphs.

.. code:: python

    graphs = [graph, graph, graph, graph]
    batch = data.Graph.pack(graphs)
    batch.visualize(num_row=1)

.. image:: ../../../asset/graph/batch.png

This returns a :class:`PackedGraph <torchdrug.data.PackedGraph>` instance with
all attributes automatically batched. The essential trick behind this operation is
based on a property of graphs. A batch of :math:`n` graphs is equivalent to a large
graph with :math:`n` connected components. The equivalent adjacency matrix for a
batch is

.. math::

    A =
    \begin{bmatrix}
        A_1    & \cdots & 0      \\
        \vdots & \ddots & \vdots \\
        0      & \cdots & A_n
    \end{bmatrix}

where :math:`A_i` is the adjacency matrix of the :math:`i`-th graph.

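The block-diagonal structure can be illustrated with a small plain-Python sketch.
``block_diagonal`` below is a hypothetical helper written for this example, not a
TorchDrug function, and it works on dense nested lists purely for readability.

.. code:: python

    def block_diagonal(adjacencies):
        """Place square adjacency matrices (lists of lists) along the diagonal."""
        total = sum(len(a) for a in adjacencies)
        packed = [[0] * total for _ in range(total)]
        offset = 0
        for a in adjacencies:
            for i, row in enumerate(a):
                for j, value in enumerate(row):
                    packed[offset + i][offset + j] = value
            # the next graph starts after all nodes seen so far
            offset += len(a)
        return packed

    # A 2-node graph with one edge 0 -> 1, and a directed ring of 3 nodes.
    a1 = [[0, 1],
          [0, 0]]
    a2 = [[0, 1, 0],
          [0, 0, 1],
          [1, 0, 0]]

    packed = block_diagonal([a1, a2])
    for row in packed:
        print(row)
    # [0, 1, 0, 0, 0]
    # [0, 0, 0, 0, 0]
    # [0, 0, 0, 1, 0]
    # [0, 0, 0, 0, 1]
    # [0, 0, 1, 0, 0]

No entry outside the two diagonal blocks is non-zero, so the two graphs remain
separate connected components of the packed graph.
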
To get a single graph from the batch, use the conventional index or
:meth:`PackedGraph.unpack <torchdrug.data.PackedGraph.unpack>`.

.. code:: python

    graph = batch[1]
    graphs = batch.unpack()

One advantage of this batching mechanism is that it does not distinguish between a
single graph and a batch. In other words, we only need to implement single graph
operations, and they can be directly applied as batched operations. This reduces the
pain of writing batched operations.

Subgraph and Masking
--------------------

The graph data structure also provides a bunch of slicing operations to create subgraphs
or masked graphs in a sparse manner. Some typical operations include

.. code:: python

    g1 = graph.subgraph([1, 2, 3, 4])
    g1.visualize()

    g2 = graph.node_mask([1, 2, 3, 4])
    g2.visualize()

    g3 = graph.edge_mask([0, 1, 5])
    g3.visualize()

    g4 = g3.compact()
    g4.visualize()

.. image:: ../../../asset/graph/subgraph.png
    :width: 24%
.. image:: ../../../asset/graph/node_mask.png
    :width: 24%
.. image:: ../../../asset/graph/edge_mask.png
    :width: 24%
.. image:: ../../../asset/graph/compact.png
    :width: 24%

The indexing operations above accept either integer indexes or binary masks.
:meth:`subgraph() <torchdrug.data.Graph.subgraph>` extracts a subgraph based on
the given nodes. The node ids are re-mapped to produce a compact index.
:meth:`node_mask() <torchdrug.data.Graph.node_mask>` keeps edges among the given
nodes. :meth:`edge_mask() <torchdrug.data.Graph.edge_mask>` keeps edges of the
given edge indexes. :meth:`compact() <torchdrug.data.Graph.compact>` removes all
isolated nodes.

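The difference between these four operations is easiest to see on a plain edge list.
The sketch below re-implements them as hypothetical stand-alone functions on Python
lists. It is not the TorchDrug implementation, and it assumes that ids are re-mapped
in ascending order, which the text does not specify.

.. code:: python

    edge_list = [(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 0)]

    def node_mask(edges, nodes):
        """Keep edges whose endpoints are both in `nodes`; node ids are unchanged."""
        keep = set(nodes)
        return [(u, v) for u, v in edges if u in keep and v in keep]

    def edge_mask(edges, edge_ids):
        """Keep the edges at the given positions; node ids are unchanged."""
        return [edges[i] for i in edge_ids]

    def compact(edges):
        """Drop nodes without edges by re-mapping the used ids to 0..k-1."""
        used = sorted({n for edge in edges for n in edge})
        new_id = {old: new for new, old in enumerate(used)}
        return [(new_id[u], new_id[v]) for u, v in edges]

    def subgraph(edges, nodes):
        """Like node_mask, but re-map `nodes` to 0..len(nodes)-1 afterwards."""
        new_id = {old: new for new, old in enumerate(sorted(nodes))}
        return [(new_id[u], new_id[v]) for u, v in node_mask(edges, nodes)]

    print(subgraph(edge_list, [1, 2, 3, 4]))   # [(0, 1), (1, 2), (2, 3)]
    print(node_mask(edge_list, [1, 2, 3, 4]))  # [(1, 2), (2, 3), (3, 4)]
    print(edge_mask(edge_list, [0, 1, 5]))     # [(0, 1), (1, 2), (5, 0)]
    print(compact(edge_mask(edge_list, [0, 1, 5])))  # [(0, 1), (1, 2), (3, 0)]
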
The same operations can also be applied to batches. In this case, we need to convert
the index of a single graph into the index in a batch.

.. code:: python

    graph_ids = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1, 1, 1])
    node_ids = torch.tensor([1, 2, 3, 4, 0, 1, 2, 3, 4, 5])
    node_ids += batch.num_cum_nodes[graph_ids] - batch.num_nodes[graph_ids]
    batch = batch.node_mask(node_ids)
    batch.visualize(num_row=1)

.. image:: ../../../asset/graph/batch_node_mask.png

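The index conversion used in this example can be traced by hand. The sketch below
repeats it with plain Python lists for a batch of 4 graphs with 6 nodes each. The
names ``num_nodes`` and ``num_cum_nodes`` mirror the attributes used in the example,
but here they are ordinary lists standing in for the batch attributes.

.. code:: python

    from itertools import accumulate

    num_nodes = [6, 6, 6, 6]                      # nodes per graph in the batch
    num_cum_nodes = list(accumulate(num_nodes))   # running total: [6, 12, 18, 24]

    graph_ids = [0, 0, 0, 0, 1, 1, 1, 1, 1, 1]
    node_ids = [1, 2, 3, 4, 0, 1, 2, 3, 4, 5]

    # The first node of graph g sits at num_cum_nodes[g] - num_nodes[g],
    # so adding that offset turns a per-graph id into a batch-wide id.
    batch_node_ids = [
        node + num_cum_nodes[g] - num_nodes[g]
        for g, node in zip(graph_ids, node_ids)
    ]
    print(batch_node_ids)  # [1, 2, 3, 4, 6, 7, 8, 9, 10, 11]
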
We can also pick a subset of graphs in a batch.

.. code:: python

    batch = batch[[0, 1]]
    batch.visualize()

.. image:: ../../../asset/graph/subbatch.png
    :align: center
    :width: 66%