
--- a
+++ b/doc/source/notes/layer.rst
@@ -0,0 +1,123 @@
+Graph Neural Network Layers
+===========================
+
+Modern graph neural networks encode graph structures with message passing layers
+and readout layers. In some cases, graph-to-node broadcast may also be needed. All
+these operations can be easily implemented with TorchDrug.
+
++------------------------+-----------------------+-------------------------+
+| |fig. message passing| | |fig. readout|        | |fig. broadcast|        |
+| Message passing        | Node-to-graph readout | Graph-to-node broadcast |
++------------------------+-----------------------+-------------------------+
+
+.. |fig. message passing| image:: ../../../asset/graph/message_passing.png
+.. |fig. readout| image:: ../../../asset/graph/readout.png
+.. |fig. broadcast| image:: ../../../asset/graph/broadcast.png
+
+Message Passing Layers
+----------------------
+
+A message passing layer can be described in three steps: a message generation step,
+an aggregation step and a combination step. The :math:`t`-th message passing layer
+performs the following computation
+
+.. math::
+
+    m_{i,j}^{(t+1)} &= Message^{(t)}(h_i^{(t)}, h_j^{(t)}) \\
+    u_i^{(t+1)} &= Aggregate^{(t)}(\{m_{i,j}^{(t+1)} \mid j \in N(i)\}) \\
+    h_i^{(t+1)} &= Combine^{(t)}(h_i^{(t)}, u_i^{(t+1)})
+
+where :math:`h_i^{(t)}` denotes the representation of node :math:`i`,
+:math:`m_{i,j}^{(t)}` denotes the message from node :math:`j` to node :math:`i`, and
+:math:`u_i^{(t)}` is the aggregated message, i.e., the update.
+
+In TorchDrug, these steps are abstracted as three methods, namely
+:meth:`message(graph, input) <>`, :meth:`aggregate(graph, message) <>` and
+:meth:`combine(input, update) <>`.
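+
+Conceptually, the base class :class:`MessagePassingBase <torchdrug.layers.MessagePassingBase>`
+chains these methods in its forward pass. Below is a simplified sketch of the control
+flow (not the exact library code):
+
+.. code:: python
+
+    from torch import nn
+
+    class MessagePassingBase(nn.Module):
+
+        def forward(self, graph, input):
+            # generate messages on edges, aggregate them per node, then combine
+            message = self.message(graph, input)
+            update = self.aggregate(graph, message)
+            output = self.combine(input, update)
+            return output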
+
+Here we show an example of a custom message passing layer for the `PageRank`_
+algorithm. Representing the PageRank value of node :math:`i` as :math:`h_i^{(t)}`,
+one PageRank iteration is equivalent to the following functions.
+
+.. _PageRank: https://en.wikipedia.org/wiki/PageRank
+
+.. math::
+
+    Message^{(t)}(h_i^{(t)}, h_j^{(t)}) &= \frac{h_j^{(t)}}{degree\_in_j} \\
+    Aggregate^{(t)}(\{m_{i,j}^{(t+1)} \mid j \in N(i)\}) &= \sum_{j \in N(i)} m_{i,j}^{(t+1)} \\
+    Combine^{(t)}(h_i^{(t)}, u_i^{(t+1)}) &= u_i^{(t+1)}
+
+We use the convention that :math:`degree\_in_j` denotes the number of edges that take
+node :math:`j` as their source node. The corresponding implementation is
+
+.. code:: python
+
+    from torch_scatter import scatter_add
+    from torchdrug import layers
+
+    class PageRankIteration(layers.MessagePassingBase):
+
+        def message(self, graph, input):
+            # each edge carries the PageRank value of its source node,
+            # divided by the source node's degree
+            node_in = graph.edge_list[:, 0]
+            message = input[node_in] / graph.degree_in[node_in].unsqueeze(-1)
+            return message
+
+        def aggregate(self, graph, message):
+            # sum up incoming messages for every target node
+            node_out = graph.edge_list[:, 1]
+            update = scatter_add(message, node_out, dim=0, dim_size=graph.num_node)
+            return update
+
+        def combine(self, input, update):
+            # the new PageRank value is just the aggregated update
+            output = update
+            return output
+
+Let's go through the functions one by one. In :meth:`message`, we pick the source
+node of every edge, and compute the messages by dividing the source nodes' hidden
+states by their source degrees.
+
+In :meth:`aggregate`, we sum up the messages by their target nodes. This is
+implemented with the :func:`scatter_add` operation from `PyTorch Scatter`_. We specify
+:attr:`dim_size` to be ``graph.num_node``, since there might be isolated nodes in
+the graph and :func:`scatter_add` cannot infer the number of nodes from ``node_out``.
+
+The :meth:`combine` function trivially returns node updates as new node hidden
+states.
+
+.. _PyTorch Scatter:
+    https://pytorch-scatter.readthedocs.io
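+
+As a quick sanity check, we can run a few PageRank iterations with the layer above on
+a toy graph (a minimal sketch; the edge list below is made up for illustration):
+
+.. code:: python
+
+    import torch
+    from torchdrug import data
+
+    # a toy directed graph with 4 nodes, where every node has at least one out-edge
+    graph = data.Graph([[0, 1], [1, 2], [2, 0], [2, 3], [3, 0]], num_node=4)
+    layer = PageRankIteration()
+
+    # start from a uniform distribution and iterate towards the stationary values
+    x = torch.ones(graph.num_node, 1) / graph.num_node
+    for _ in range(20):
+        x = layer(graph, x)
+    print(x.squeeze(-1))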
+
+Readout and Broadcast Layers
+----------------------------
+
+A readout layer collects all node representations in a graph to form a graph
+representation. Conversely, a broadcast layer sends the graph representation to every
+node in the graph. For a batch of graphs, these operations can be viewed as message
+passing on a bipartite graph -- one side contains the original nodes, and the other
+side contains the "graph" nodes.
+
+TorchDrug provides efficient primitives to support this kind of message passing.
+Specifically, :attr:`node2graph <torchdrug.data.PackedGraph.node2graph>` maps
+node IDs to graph IDs, and :attr:`edge2graph <torchdrug.data.PackedGraph.edge2graph>`
+maps edge IDs to graph IDs.
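+
+For intuition, the mappings look as follows for a small packed batch (a toy example;
+the graphs below are made up for illustration):
+
+.. code:: python
+
+    from torchdrug import data
+
+    graph1 = data.Graph([[0, 1], [1, 2]], num_node=3)
+    graph2 = data.Graph([[0, 1]], num_node=2)
+    batch = data.Graph.pack([graph1, graph2])
+
+    print(batch.node2graph)  # [0, 0, 0, 1, 1]: nodes 0-2 belong to graph 0, nodes 3-4 to graph 1
+    print(batch.edge2graph)  # [0, 0, 1]: the first two edges belong to graph 0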
+
+In the following example, we use the above primitives to compute the variance of node
+representations as a graph representation. First, we read out the mean of the node
+representations. Second, we broadcast the mean back to each node to compute the
+differences. Finally, we read out the mean of the squared differences as the variance.
+
+.. code:: python
+
+    from torch import nn
+    from torch_scatter import scatter_mean
+
+    class Variance(nn.Module):
+
+        def forward(self, graph, input):
+            # readout: mean of node representations per graph
+            mean = scatter_mean(input, graph.node2graph, dim=0, dim_size=graph.batch_size)
+            # broadcast: send each graph's mean back to its nodes
+            diff = input - mean[graph.node2graph]
+            # readout: mean of the squared differences, i.e. the variance
+            var = scatter_mean(diff * diff, graph.node2graph, dim=0, dim_size=graph.batch_size)
+            return var
+
+Notice that :attr:`node2graph <torchdrug.data.PackedGraph.node2graph>` is used
+for both readout and broadcast. When used in a scatter function, it serves as a
+readout. When used in conventional tensor indexing, it is equivalent to a broadcast.
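+
+As a usage example, the layer can be applied to a packed batch of graphs (a minimal
+sketch; the toy graphs and feature dimension below are made up):
+
+.. code:: python
+
+    import torch
+    from torchdrug import data
+
+    graph1 = data.Graph([[0, 1], [1, 2]], num_node=3)
+    graph2 = data.Graph([[0, 1]], num_node=2)
+    batch = data.Graph.pack([graph1, graph2])
+
+    readout = Variance()
+    node_feature = torch.randn(batch.num_node, 16)
+    graph_feature = readout(batch, node_feature)
+    print(graph_feature.shape)  # (batch_size, 16), i.e. one variance vector per graph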