--- a
+++ b/doc/source/notes/variadic.rst
@@ -0,0 +1,200 @@
+Batch Irregular Structures
+==========================
+
+Unlike images, text and audio, graphs usually have irregular structures, which
+makes them hard to batch in tensor frameworks. Many existing implementations
+resort to padding, which converts graphs into dense grid structures at the cost
+of substantial unnecessary computation and memory.
+
+In TorchDrug, we develop a more intuitive and efficient solution based on
+variadic functions. Variadic functions operate directly on sparse, irregular
+inputs and outputs.
+
+Variadic Input
+--------------
+
+Here we show how to apply functions to variadic inputs.
+
+Generally, a batch of :math:`n` variadic tensors can be represented by a value
+tensor and a size tensor. The value tensor is the concatenation of all variadic
+tensors along the variadic axis, while the size tensor stores the length of
+each variadic tensor.
+
+Let's first create a batch of 1D variadic samples.
+
+.. code:: python
+
+    import torch
+
+    samples = []
+    for size in range(2, 6):
+        samples.append(torch.randint(6, (size,)))
+    value = torch.cat(samples)
+    size = torch.tensor([len(s) for s in samples])
+
+.. image:: ../../../asset/tensor/variadic_tensor.png
+    :align: center
+    :width: 60%
+
+We apply variadic functions to compute the sum, max and top-k values for each
+sample.
+
+.. code:: python
+
+    from torchdrug.layers import functional
+
+    sum = functional.variadic_sum(value, size)
+    max = functional.variadic_max(value, size)[0]
+    top3_value, top3_index = functional.variadic_topk(value, size, k=3)
+
+Note that :meth:`variadic_topk <torchdrug.layers.functional.variadic_topk>` also
+accepts samples with fewer than :math:`k` elements. In this case, it pads the
+output with the smallest element of that sample.
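+
+For instance, the following sketch (with hypothetical toy values) shows how a
+sample of size 2 is padded up to :math:`k=3`.
+
+.. code:: python
+
+    toy_value = torch.tensor([3, 1, 5, 2, 4, 0])
+    toy_size = torch.tensor([2, 4])
+    # the first sample [3, 1] has fewer than 3 elements, so its top-3
+    # values come out as [3, 1, 1], padded with its smallest element
+    topk_value, topk_index = functional.variadic_topk(toy_value, toy_size, k=3)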
+
+.. image:: ../../../asset/tensor/variadic_func_result.png
+    :align: center
+    :width: 88%
+
+Semantically, these functions are equivalent to applying the operation to each
+sample in a for loop. For example, the variadic sum is equivalent to the
+following logic.
+
+.. code::
+
+    sums = []
+    for sample in samples:
+        sums.append(sample.sum())
+    sum = torch.stack(sums)
+
+.. note::
+
+    Despite implementing the same logic, variadic functions are much faster than
+    for loops on GPUs (typically about :math:`\text{batch size}\times` faster).
+    Use variadic functions instead of for loops whenever possible.
+
+Many operations in graph representation learning can be implemented by variadic
+functions. For example,
+
+1. Infer graph-level representations from node-/edge-level representations (see the sketch below).
+2. Perform classification over nodes/edges.
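+
+For example, a mean readout over nodes implements item 1. Below is a minimal
+sketch, assuming ``graph`` is a packed graph and ``node_feature`` holds its
+node-level representations.
+
+.. code:: python
+
+    # node_feature: assumed per-node representations of shape (num_node, dim)
+    # average them within each graph to get graph-level representations
+    graph_feature = functional.variadic_mean(node_feature, graph.num_nodes)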
+
+Here we demonstrate how to perform classification over nodes. We create a toy
+task, where the model needs to predict the heaviest atom of each molecule. Note
+that the node attributes of a batch form a variadic tensor with sizes ``num_nodes``.
+Therefore, we can use :meth:`variadic_max <torchdrug.layers.functional.variadic_max>`
+to get our ground truth.
+
+.. code:: python
+
+    from torchdrug import data, models, metrics
+
+    smiles_list = ["CC(=C)C#N", "CCNC(=S)NCC", "BrC1=CC=C(Br)C=C1"]
+    graph = data.PackedMolecule.from_smiles(smiles_list)
+    # the target is the index of the heaviest atom in each molecule
+    target = functional.variadic_max(graph.atom_type, graph.num_nodes)[1]
+
+Naturally, the predictions over nodes also form a variadic tensor with sizes
+``num_nodes``.
+
+.. code:: python
+
+    model = models.GCN(input_dim=graph.node_feature.shape[-1], hidden_dims=[128, 128, 1])
+    feature = model(graph, graph.node_feature.float())
+    pred = feature["node_feature"].squeeze(-1)
+
+    pred_prob, pred_index = functional.variadic_max(pred, graph.num_nodes)
+    loss = functional.variadic_cross_entropy(pred, target, graph.num_nodes)
+    accuracy = metrics.variadic_accuracy(pred, target, graph.num_nodes)
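+
+Conceptually, :meth:`variadic_cross_entropy <torchdrug.layers.functional.variadic_cross_entropy>`
+treats the nodes of each sample as the classes of a softmax. A for-loop sketch
+of the same logic (assuming a mean reduction over samples) looks like
+
+.. code::
+
+    losses = []
+    offset = 0
+    for num_node, index in zip(graph.num_nodes.tolist(), target.tolist()):
+        # softmax over the nodes of one graph; pick the target node's log prob
+        log_prob = torch.log_softmax(pred[offset: offset + num_node], dim=0)
+        losses.append(-log_prob[index])
+        offset += num_node
+    loss = torch.stack(losses).mean()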
+
+.. seealso::
+    :func:`variadic_sum <torchdrug.layers.functional.variadic_sum>`,
+    :func:`variadic_mean <torchdrug.layers.functional.variadic_mean>`,
+    :func:`variadic_max <torchdrug.layers.functional.variadic_max>`,
+    :func:`variadic_arange <torchdrug.layers.functional.variadic_arange>`,
+    :func:`variadic_sort <torchdrug.layers.functional.variadic_sort>`,
+    :func:`variadic_topk <torchdrug.layers.functional.variadic_topk>`,
+    :func:`variadic_randperm <torchdrug.layers.functional.variadic_randperm>`,
+    :func:`variadic_sample <torchdrug.layers.functional.variadic_sample>`,
+    :func:`variadic_meshgrid <torchdrug.layers.functional.variadic_meshgrid>`,
+    :func:`variadic_softmax <torchdrug.layers.functional.variadic_softmax>`,
+    :func:`variadic_log_softmax <torchdrug.layers.functional.variadic_log_softmax>`,
+    :func:`variadic_cross_entropy <torchdrug.layers.functional.variadic_cross_entropy>`,
+    :func:`variadic_accuracy <torchdrug.metrics.variadic_accuracy>`
+
+Variadic Output
+---------------
+
+In some cases, we also need to write functions that produce variadic outputs. A
+typical example is autoregressive generation, where we need to generate all
+node/edge prefixes of a graph. When this operation is batched, the number of
+output graphs varies across input graphs.
+
+Here we show how to generate edge prefixes for a batch of graphs in TorchDrug.
+First, let's prepare a batch of two graphs.
+
+.. code:: python
+
+    edge_list = [[0, 1], [1, 2], [2, 3], [3, 4], [4, 5]]
+    graph1 = data.Graph(edge_list, num_node=6)
+    edge_list = [[0, 1], [1, 2], [2, 3], [3, 0], [0, 2], [1, 3]]
+    graph2 = data.Graph(edge_list, num_node=4)
+    graph = data.Graph.pack([graph1, graph2])
+    with graph.graph():
+        # register an id as a graph-level attribute to track the source of
+        # each generated prefix
+        graph.id = torch.arange(2)
+
+.. image:: ../../../asset/graph/autoregressive_input.png
+    :align: center
+    :width: 66%
+
+The generation of edge prefixes consists of 3 steps.
+
+1. Construct an extended batch with enough copies for each graph.
+2. Apply an edge mask over the batch.
+3. Remove excess or invalid graphs.
+
+The first step can be implemented through
+:meth:`Graph.repeat <torchdrug.data.Graph.repeat>`. For the following steps, we
+define an auxiliary function ``all_prefix_slice``. This function takes in a size
+tensor and desired prefix lengths, and outputs :math:`n \times l` prefix slices for the
+extended batch, where :math:`n` is the batch size and :math:`l` is the number of
+prefix lengths.
+
+.. code:: python
+
+    def all_prefix_slice(sizes, lengths=None):
+        # start of each sample in the value tensor
+        cum_sizes = sizes.cumsum(0)
+        starts = cum_sizes - sizes
+        if lengths is None:
+            max_size = sizes.max().item()
+            lengths = torch.arange(0, max_size, 1, device=sizes.device)
+
+        # offset of each copy in the extended batch
+        pack_offsets = torch.arange(len(lengths), device=sizes.device) * cum_sizes[-1]
+        # starts, lengths, valid: (num_length, batch_size)
+        starts = starts.unsqueeze(0) + pack_offsets.unsqueeze(-1)
+        valid = lengths.unsqueeze(-1) <= sizes.unsqueeze(0)
+        lengths = torch.min(lengths.unsqueeze(-1), sizes.unsqueeze(0)).clamp(0)
+        ends = starts + lengths
+
+        starts = starts.flatten()
+        ends = ends.flatten()
+        valid = valid.flatten()
+
+        return starts, ends, valid
+
+    lengths = torch.arange(1, graph.num_edges.max() + 1)
+    num_length = len(lengths)
+    starts, ends, valid = all_prefix_slice(graph.num_edges, lengths)
+
+The slices are visualized as follows. The two colors correspond to the two input graphs.
+
+.. image:: ../../../asset/tensor/autoregressive_slice.png
+    :align: center
+    :width: 55%
+
+.. code:: python
+
+    graph = graph.repeat(num_length) # step 1
+    mask = functional.multi_slice_mask(starts, ends, graph.num_edge)
+    graph = graph.edge_mask(mask) # step 2
+    graph = graph[valid] # step 3
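+
+As a sanity check, with the prefix lengths chosen above, each input graph yields
+one edge prefix per edge, so the output batch should contain :math:`5 + 6 = 11`
+graphs.
+
+.. code:: python
+
+    # 5 + 6 edge prefixes from the two input graphs
+    assert graph.batch_size == 11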
+
+The output batch is visualized as follows.
+
+.. image:: ../../../asset/graph/autoregressive_output.png
\ No newline at end of file