Graph Neural Network Layers
===========================

Modern graph neural networks encode graph structures with message passing layers
and readout layers. In some cases, graph-to-node broadcast may also be needed. All
these operations can be easily implemented with TorchDrug.

+------------------------+-----------------------+-------------------------+
| |fig. message passing| | |fig. readout|        | |fig. broadcast|        |
| Message passing        | Node-to-graph readout | Graph-to-node broadcast |
+------------------------+-----------------------+-------------------------+

.. |fig. message passing| image:: ../../../asset/graph/message_passing.png
.. |fig. readout| image:: ../../../asset/graph/readout.png
.. |fig. broadcast| image:: ../../../asset/graph/broadcast.png

Message Passing Layers
----------------------

A message passing layer can be described in three steps: a message generation step,
an aggregation step, and a combination step. The :math:`t`-th message passing layer
performs the following computation

.. math::

    m_{i,j}^{(t+1)} &= Message^{(t)}(h_i^{(t)}, h_j^{(t)}) \\
    u_i^{(t+1)} &= Aggregate^{(t)}(\{m_{i,j}^{(t+1)} \mid j \in N(i)\}) \\
    h_i^{(t+1)} &= Combine^{(t)}(h_i^{(t)}, u_i^{(t+1)})

where :math:`h_i^{(t)}` denotes the node representation, :math:`m_{i,j}^{(t)}` denotes
the message from node :math:`j` to node :math:`i`, and :math:`u_i^{(t)}` denotes the
aggregated message, i.e., the update.

In TorchDrug, these steps are abstracted as three methods, namely
:meth:`message(graph, input) <>`, :meth:`aggregate(graph, message) <>` and
:meth:`combine(input, update) <>`.
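
As a rough mental model, the forward pass of a message passing layer simply chains
these three methods. The sketch below is a simplification for illustration only; it
is not the exact TorchDrug implementation.

.. code:: python

    def forward(self, graph, input):
        # simplified sketch of how the three steps fit together
        message = self.message(graph, input)
        update = self.aggregate(graph, message)
        output = self.combine(input, update)
        return output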

Here we show an example of a custom message passing layer for the `PageRank`_
algorithm. Representing the PageRank value as :math:`h_i^{(t)}`, one PageRank
iteration is equivalent to the following functions.

.. _PageRank: https://en.wikipedia.org/wiki/PageRank

.. math::

    Message^{(t)}(h_i^{(t)}, h_j^{(t)}) &= \frac{h_j^{(t)}}{degree\_in_j} \\
    Aggregate^{(t)}(\{m_{i,j}^{(t+1)} \mid j \in N(i)\}) &= \sum_{j \in N(i)} m_{i,j}^{(t+1)} \\
    Combine^{(t)}(h_i^{(t)}, u_i^{(t+1)}) &= u_i^{(t+1)}

We use the convention that :math:`degree\_in_j` is the number of edges that take node
:math:`j` as their source node. The corresponding implementation is

.. code:: python

    from torch_scatter import scatter_add
    from torchdrug import layers

    class PageRankIteration(layers.MessagePassingBase):

        def message(self, graph, input):
            # take the source node of each edge and normalize its PageRank
            # value by its source degree
            node_in = graph.edge_list[:, 0]
            message = input[node_in] / graph.degree_in[node_in].unsqueeze(-1)
            return message

        def aggregate(self, graph, message):
            # sum incoming messages over the target node of each edge
            node_out = graph.edge_list[:, 1]
            update = scatter_add(message, node_out, dim=0, dim_size=graph.num_node)
            return update

        def combine(self, input, update):
            # the aggregated update directly becomes the new PageRank value
            output = update
            return output

Let's go through the functions one by one. In :meth:`message`, we pick the source
nodes of all edges, and compute the messages by dividing the source nodes' hidden
states by their source degrees.

In :meth:`aggregate`, we collect the messages by their target nodes. This is
implemented with the :func:`scatter_add` operation from `PyTorch Scatter`_. We specify
:attr:`dim_size` to be ``graph.num_node``, since there might be isolated nodes in
the graph and :func:`scatter_add` cannot infer the number of nodes from ``node_out``.
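
To make the role of :attr:`dim_size` concrete, here is a small standalone snippet
(a toy example, not part of TorchDrug) that mimics the aggregation step for a graph
with 4 nodes where only nodes 0 and 2 receive messages.

.. code:: python

    import torch
    from torch_scatter import scatter_add

    message = torch.tensor([[1.0], [2.0], [3.0]])
    node_out = torch.tensor([0, 2, 2])

    scatter_add(message, node_out, dim=0)              # shape (3, 1): the row for node 3 is missing
    scatter_add(message, node_out, dim=0, dim_size=4)  # shape (4, 1): isolated nodes get zero rows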

The :meth:`combine` function trivially returns node updates as new node hidden
states.
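
As a quick usage sketch (assuming the ``PageRankIteration`` class defined above and
the standard :class:`torchdrug.data.Graph` constructor), the layer can be iterated on
a toy graph starting from a uniform distribution. Note that this simplified iteration
has no damping factor.

.. code:: python

    import torch
    from torchdrug import data

    # a toy directed graph; every node appears as a source at least once,
    # so message() never divides by zero
    graph = data.Graph([[0, 1], [0, 2], [1, 2], [2, 0]], num_node=3)

    layer = PageRankIteration()
    h = torch.full((graph.num_node, 1), 1 / graph.num_node)
    for _ in range(20):
        h = layer(graph, h)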

.. _PyTorch Scatter: https://pytorch-scatter.readthedocs.io

Readout and Broadcast Layers
----------------------------

A readout layer collects all node representations in a graph to form a graph
representation. Conversely, a broadcast layer sends the graph representation to every
node in the graph. For a batch of graphs, these operations can be viewed as message
passing on a bipartite graph, where one side holds the original nodes and the other
side holds the "graph" nodes.

TorchDrug provides efficient primitives to support this kind of message passing.
Specifically, :attr:`node2graph <torchdrug.data.PackedGraph.node2graph>` maps
node IDs to graph IDs, and :attr:`edge2graph <torchdrug.data.PackedGraph.edge2graph>`
maps edge IDs to graph IDs.
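
For instance, packing two small graphs (a made-up example) yields the following
mappings.

.. code:: python

    from torchdrug import data

    g1 = data.Graph([[0, 1]], num_node=2)
    g2 = data.Graph([[0, 1], [1, 2]], num_node=3)
    batch = data.Graph.pack([g1, g2])

    batch.batch_size   # 2
    batch.node2graph   # tensor([0, 0, 1, 1, 1])
    batch.edge2graph   # tensor([0, 1, 1])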

In this example, we will use the above primitives to compute the variance of node
representations as a graph representation. First, we read out the mean of the node
representations. Second, we broadcast the mean back to each node to compute the
difference. Finally, we read out the mean of the squared differences as the variance.

.. code:: python

    from torch import nn
    from torch_scatter import scatter_mean

    class Variance(nn.Module):

        def forward(self, graph, input):
            # readout: per-graph mean of node representations
            mean = scatter_mean(input, graph.node2graph, dim=0, dim_size=graph.batch_size)
            # broadcast: send each graph's mean back to its own nodes
            diff = input - mean[graph.node2graph]
            # readout: per-graph mean of squared differences
            var = scatter_mean(diff * diff, graph.node2graph, dim=0, dim_size=graph.batch_size)
            return var

Notice that :attr:`node2graph <torchdrug.data.PackedGraph.node2graph>` is used
for both readout and broadcast. When used as the index of a scatter function, it
serves as readout. When used for conventional indexing, it is equivalent to broadcast.
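
As a usage sketch (assuming the ``Variance`` module defined above), the layer can be
applied to a batch of graphs like any other readout layer.

.. code:: python

    import torch
    from torchdrug import data

    g1 = data.Graph([[0, 1], [1, 2]], num_node=3)
    g2 = data.Graph([[0, 1]], num_node=2)
    batch = data.Graph.pack([g1, g2])

    readout = Variance()
    node_feature = torch.randn(batch.num_node, 16)
    graph_feature = readout(batch, node_feature)   # shape (2, 16), one row per graph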