Graph Neural Network Layers
===========================

Modern graph neural networks encode graph structures with message passing layers
and readout layers. In some cases, graph-to-node broadcast may also be needed. All
these operations can be easily implemented with TorchDrug.

+------------------------+-----------------------+-------------------------+
| |fig. message passing| | |fig. readout|        | |fig. broadcast|        |
| Message passing        | Node-to-graph readout | Graph-to-node broadcast |
+------------------------+-----------------------+-------------------------+

.. |fig. message passing| image:: ../../../asset/graph/message_passing.png
.. |fig. readout| image:: ../../../asset/graph/readout.png
.. |fig. broadcast| image:: ../../../asset/graph/broadcast.png

Message Passing Layers
----------------------

A message passing layer can be described in three steps: a message generation step,
an aggregation step and a combination step. The :math:`t`-th message passing layer
performs the following computation

.. math::

    m_{i,j}^{(t+1)} &= Message^{(t)}(h_i^{(t)}, h_j^{(t)}) \\
    u_i^{(t+1)} &= Aggregate^{(t)}(\{m_{i,j}^{(t+1)} \mid j \in N(i)\}) \\
    h_i^{(t+1)} &= Combine^{(t)}(h_i^{(t)}, u_i^{(t+1)})

where :math:`h_i^{(t)}` denotes the representation of node :math:`i`,
:math:`m_{i,j}^{(t)}` denotes the message from node :math:`j` to node :math:`i`,
and :math:`u_i^{(t)}` denotes the aggregated messages, i.e., the update of node
:math:`i`.

In TorchDrug, these steps are abstracted as three methods, namely
:meth:`message(graph, input) <>`, :meth:`aggregate(graph, message) <>` and
:meth:`combine(input, update) <>`.

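To make the control flow concrete, here is a small illustrative helper (hypothetical,
not part of TorchDrug) that composes the three methods into one layer update; the
base class's forward pass performs an equivalent chain.

.. code:: python

    def message_passing_step(layer, graph, input):
        """Illustrative composition of the three methods of a message passing layer."""
        message = layer.message(graph, input)      # one message per edge
        update = layer.aggregate(graph, message)   # aggregate messages per target node
        output = layer.combine(input, update)      # combine with previous node states
        return output
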
Here we show an example of a custom message passing layer for the `PageRank`_
algorithm. Representing the PageRank value as :math:`h_i^{(t)}`, one PageRank
iteration is equivalent to the following functions.

.. _PageRank: https://en.wikipedia.org/wiki/PageRank

.. math::

    Message^{(t)}(h_i^{(t)}, h_j^{(t)}) &= \frac{h_j^{(t)}}{degree\_in_j} \\
    Aggregate^{(t)}(\{m_{i,j} \mid j \in N(i)\}) &= \sum_{j \in N(i)} m_{i,j} \\
    Combine^{(t)}(h_i^{(t)}, u_i^{(t+1)}) &= u_i^{(t+1)}

We use the convention that :math:`degree\_in_j` denotes the number of edges that
have node :math:`j` as their source node. The corresponding implementation is

.. code:: python

    from torch_scatter import scatter_add
    from torchdrug import layers

    class PageRankIteration(layers.MessagePassingBase):

        def message(self, graph, input):
            # source node of each edge
            node_in = graph.edge_list[:, 0]
            # divide the source value by the number of edges leaving the source node
            message = input[node_in] / graph.degree_in[node_in].unsqueeze(-1)
            return message

        def aggregate(self, graph, message):
            # target node of each edge
            node_out = graph.edge_list[:, 1]
            # sum the incoming messages for every node in the graph
            update = scatter_add(message, node_out, dim=0, dim_size=graph.num_node)
            return update

        def combine(self, input, update):
            # the new PageRank value is simply the aggregated update
            output = update
            return output

Let's go through the functions one by one. In :meth:`message`, we pick the source
nodes of all edges, and compute the messages by dividing the source nodes' hidden
states by their source degrees.

In :meth:`aggregate`, we sum the messages over their target nodes. This is
implemented by the :func:`scatter_add` operation from `PyTorch Scatter`_. We specify
:attr:`dim_size` to be ``graph.num_node``, since there might be isolated nodes in
the graph and :func:`scatter_add` cannot infer the number of nodes from ``node_out``.

The :meth:`combine` function trivially returns the node updates as the new node
hidden states.

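Putting everything together, the layer can be used like any other TorchDrug layer.
Below is a minimal usage sketch with a made-up toy graph and iteration count; it
assumes that calling the module runs the :meth:`message`, :meth:`aggregate` and
:meth:`combine` chain defined above.

.. code:: python

    import torch
    from torchdrug import data

    # a toy directed graph with 4 nodes; edge (i, j) means node i links to node j
    num_node = 4
    graph = data.Graph([[0, 1], [0, 2], [1, 2], [2, 3], [3, 0]], num_node=num_node)
    layer = PageRankIteration()

    # start from a uniform distribution and run a few PageRank iterations
    h = torch.ones(num_node, 1) / num_node
    for _ in range(20):
        h = layer(graph, h)
    print(h.squeeze(-1))  # approximate PageRank values (without damping)
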
.. _PyTorch Scatter:
    https://pytorch-scatter.readthedocs.io

Readout and Broadcast Layers
----------------------------

A readout layer collects all node representations in a graph to form a graph
representation. Conversely, a broadcast layer sends the graph representation to every
node in the graph. For a batch of graphs, these operations can be viewed as message
passing on a bipartite graph -- one side consists of the original nodes, and the
other side consists of "graph" nodes.

TorchDrug provides efficient primitives to support this kind of message passing.
Specifically, :attr:`node2graph <torchdrug.data.PackedGraph.node2graph>` maps
node IDs to graph IDs, and :attr:`edge2graph <torchdrug.data.PackedGraph.edge2graph>`
maps edge IDs to graph IDs.

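As a quick illustration, consider packing two toy graphs (made up for this example)
into a batch. ``node2graph`` and ``edge2graph`` then map every node and edge back to
the graph it belongs to.

.. code:: python

    from torchdrug import data

    g1 = data.Graph([[0, 1]], num_node=2)           # 2 nodes, 1 edge
    g2 = data.Graph([[0, 1], [1, 2]], num_node=3)   # 3 nodes, 2 edges
    batch = data.Graph.pack([g1, g2])               # a PackedGraph with batch_size 2

    print(batch.node2graph)   # expected: tensor([0, 0, 1, 1, 1])
    print(batch.edge2graph)   # expected: tensor([0, 1, 1])
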
In the following example, we use the above primitives to compute the variance of node
representations as a graph representation. First, we read out the mean of the node
representations. Second, we broadcast the mean back to each node to compute the
difference. Finally, we read out the mean of the squared differences as the variance.

.. code:: python

    from torch import nn
    from torch_scatter import scatter_mean

    class Variance(nn.Module):

        def forward(self, graph, input):
            # readout: per-graph mean of the node representations
            mean = scatter_mean(input, graph.node2graph, dim=0, dim_size=graph.batch_size)
            # broadcast: send each graph's mean back to its own nodes
            diff = input - mean[graph.node2graph]
            # readout: per-graph mean of the squared differences, i.e. the variance
            var = scatter_mean(diff * diff, graph.node2graph, dim=0, dim_size=graph.batch_size)
            return var

Notice that :attr:`node2graph <torchdrug.data.PackedGraph.node2graph>` is used
for both readout and broadcast. When passed to a scatter function, it performs a
readout. When used for conventional indexing, it performs a broadcast.

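To see the module in action, here is a minimal usage sketch with a made-up batch of
two graphs and random node features (the feature dimension is chosen arbitrarily).

.. code:: python

    import torch
    from torchdrug import data

    g1 = data.Graph([[0, 1], [1, 2]], num_node=3)
    g2 = data.Graph([[0, 1]], num_node=2)
    batch = data.Graph.pack([g1, g2])   # PackedGraph with batch_size = 2

    node_feature = torch.randn(5, 16)   # 3 + 2 nodes in total, 16-dim features
    readout = Variance()
    graph_feature = readout(batch, node_feature)
    print(graph_feature.shape)          # expected: torch.Size([2, 16])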