Batch Irregular Structures
==========================

Unlike images, text and audio, graphs usually have irregular structures, which
makes them hard to batch in tensor frameworks. Many existing implementations use
padding to convert graphs into dense grid structures, which wastes considerable
computation and memory on the padded entries.

In TorchDrug, we develop a more intuitive and efficient solution based on
variadic functions. Variadic functions operate directly on sparse, irregular
inputs or outputs.

Variadic Input
--------------

Here we show how to apply functions to variadic inputs.

Generally, a batch of :math:`n` variadic tensors can be represented by a value
tensor and a size tensor. The value tensor is the concatenation of all variadic
tensors along the variadic axis, while the size tensor records the length of
each variadic tensor.

Let's first create a batch of 1D variadic samples.

.. code:: python

    import torch

    samples = []
    for size in range(2, 6):
        samples.append(torch.randint(6, (size,)))
    value = torch.cat(samples)
    size = torch.tensor([len(s) for s in samples])

.. image:: ../../../asset/tensor/variadic_tensor.png
    :align: center
    :width: 60%

We apply variadic functions to compute the sum, max and top-k values for each
sample.

.. code:: python

    from torchdrug.layers import functional

    sum = functional.variadic_sum(value, size)
    max = functional.variadic_max(value, size)[0]
    top3_value, top3_index = functional.variadic_topk(value, size, k=3)

Note that :meth:`variadic_topk <torchdrug.layers.functional.variadic_topk>` also
accepts samples smaller than :math:`k`. In this case, it pads the output with the
smallest element of that sample.
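
As a quick illustration of this padding behavior, here is a sketch with two
made-up samples of sizes 2 and 4 (``demo_value`` and ``demo_size`` are not part
of the example above):

.. code:: python

    demo_value = torch.tensor([5, 2, 4, 1, 3, 0])
    demo_size = torch.tensor([2, 4])
    topk_value, topk_index = functional.variadic_topk(demo_value, demo_size, k=3)
    # topk_value: [[5, 2, 2], [4, 3, 1]]
    # the first sample only has 2 elements, so its top-3 output is padded
    # with its smallest element (2)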

.. image:: ../../../asset/tensor/variadic_func_result.png
    :align: center
    :width: 88%

Mathematically, these functions can be viewed as performing the operation over
each sample with a for loop. For example, the variadic sum is equivalent to the
following logic.

.. code:: python

    sums = []
    for sample in samples:
        sums.append(sample.sum())
    sum = torch.stack(sums)

.. note::

    Despite implementing the same logic, variadic functions are much faster than
    for loops on GPUs (typically :math:`\text{batch size}\times` faster). Use
    variadic functions instead of for loops whenever possible.
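
To see the gap on your own hardware, here is a rough timing sketch. The ``bench``
helper below is purely illustrative and assumes a CUDA device is available; a
careful benchmark would use ``torch.cuda.Event`` and warm-up runs.

.. code:: python

    import time

    def bench(fn, repeat=100):
        # crude wall-clock timing; synchronize so GPU kernels are included
        torch.cuda.synchronize()
        start = time.time()
        for _ in range(repeat):
            fn()
        torch.cuda.synchronize()
        return (time.time() - start) / repeat

    value_gpu, size_gpu = value.cuda(), size.cuda()
    samples_gpu = value_gpu.split(size_gpu.tolist())
    t_variadic = bench(lambda: functional.variadic_sum(value_gpu, size_gpu))
    t_loop = bench(lambda: torch.stack([s.sum() for s in samples_gpu]))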

Many operations in graph representation learning can be implemented by variadic
functions. For example,

1. Infer graph-level representations from node-/edge-level representations (see
   the sketch below).
2. Perform classification over nodes/edges.
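
For the first case, a graph-level readout can be a single variadic call. Here is
a minimal sketch, where ``node_feature`` stands for any per-node feature tensor
of a packed batch ``graph`` (both are placeholder names):

.. code:: python

    # sum the node features within each graph, yielding one vector per graph
    graph_feature = functional.variadic_sum(node_feature, graph.num_nodes)
    # functional.variadic_mean gives a mean readout instead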

Here we demonstrate how to perform classification over nodes. We create a toy
task, where the model needs to predict the heaviest atom of each molecule. Note
that the node attributes of each graph form a variadic tensor whose size is
``num_nodes``. Therefore, we can use
:meth:`variadic_max <torchdrug.layers.functional.variadic_max>`
to get our ground truth.

.. code:: python

    from torchdrug import data, models, metrics

    smiles_list = ["CC(=C)C#N", "CCNC(=S)NCC", "BrC1=CC=C(Br)C=C1"]
    graph = data.PackedMolecule.from_smiles(smiles_list)
    target = functional.variadic_max(graph.atom_type, graph.num_nodes)[1]

Naturally, the prediction over nodes also forms a variadic tensor of size
``num_nodes``.

.. code:: python

    model = models.GCN(input_dim=graph.node_feature.shape[-1], hidden_dims=[128, 128, 1])
    feature = model(graph, graph.node_feature.float())
    pred = feature["node_feature"].squeeze(-1)

    pred_prob, pred_index = functional.variadic_max(pred, graph.num_nodes)
    loss = functional.variadic_cross_entropy(pred, target, graph.num_nodes)
    accuracy = metrics.variadic_accuracy(pred, target, graph.num_nodes)
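
Here :func:`variadic_cross_entropy <torchdrug.layers.functional.variadic_cross_entropy>`
treats the nodes of each graph as the classes of a separate classification
problem. As a sketch (up to the reduction over samples, reusing ``pred`` and
``target`` from above), it is equivalent to:

.. code:: python

    import torch.nn.functional as F

    losses = []
    for sample_pred, sample_target in zip(pred.split(graph.num_nodes.tolist()), target):
        # softmax over the nodes of one graph, then pick out the target node
        losses.append(F.cross_entropy(sample_pred.unsqueeze(0), sample_target.unsqueeze(0)))
    loss = torch.stack(losses).mean()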

.. seealso::
    :func:`variadic_sum <torchdrug.layers.functional.variadic_sum>`,
    :func:`variadic_mean <torchdrug.layers.functional.variadic_mean>`,
    :func:`variadic_max <torchdrug.layers.functional.variadic_max>`,
    :func:`variadic_arange <torchdrug.layers.functional.variadic_arange>`,
    :func:`variadic_sort <torchdrug.layers.functional.variadic_sort>`,
    :func:`variadic_topk <torchdrug.layers.functional.variadic_topk>`,
    :func:`variadic_randperm <torchdrug.layers.functional.variadic_randperm>`,
    :func:`variadic_sample <torchdrug.layers.functional.variadic_sample>`,
    :func:`variadic_meshgrid <torchdrug.layers.functional.variadic_meshgrid>`,
    :func:`variadic_softmax <torchdrug.layers.functional.variadic_softmax>`,
    :func:`variadic_log_softmax <torchdrug.layers.functional.variadic_log_softmax>`,
    :func:`variadic_cross_entropy <torchdrug.layers.functional.variadic_cross_entropy>`,
    :func:`variadic_accuracy <torchdrug.metrics.variadic_accuracy>`

Variadic Output
---------------

In some cases, we also need to write functions that produce variadic outputs. A
typical example is autoregressive generation, where we need to generate all
node/edge prefixes of a graph. When this operation is batched, we need to output
a variadic number of graphs for each input graph.

Here we show how to generate edge prefixes for a batch of graphs in TorchDrug.
First, let's prepare a batch of two graphs.

.. code:: python

    edge_list = [[0, 1], [1, 2], [2, 3], [3, 4], [4, 5]]
    graph1 = data.Graph(edge_list, num_node=6)
    edge_list = [[0, 1], [1, 2], [2, 3], [3, 0], [0, 2], [1, 3]]
    graph2 = data.Graph(edge_list, num_node=4)
    graph = data.Graph.pack([graph1, graph2])
    with graph.graph():
        graph.id = torch.arange(2)

.. image:: ../../../asset/graph/autoregressive_input.png
    :align: center
    :width: 66%

The generation of edge prefixes consists of 3 steps.

1. Construct an extended batch with enough copies of each graph.
2. Apply an edge mask over the batch.
3. Remove excess or invalid graphs.

The first step can be implemented through
:meth:`Graph.repeat <torchdrug.data.Graph.repeat>`. For the following steps, we
define an auxiliary function ``all_prefix_slice``. This function takes in a size
tensor and the desired prefix lengths, and outputs :math:`n \times l` prefix
slices for the extended batch, where :math:`n` is the batch size and :math:`l`
is the number of prefix lengths.

.. code:: python

    def all_prefix_slice(sizes, lengths=None):
        # start of each sample in the concatenated value tensor
        cum_sizes = sizes.cumsum(0)
        starts = cum_sizes - sizes
        if lengths is None:
            max_size = sizes.max().item()
            lengths = torch.arange(0, max_size, 1, device=sizes.device)

        # offset of each repeated copy in the extended batch
        pack_offsets = torch.arange(len(lengths), device=sizes.device) * cum_sizes[-1]
        starts = starts.unsqueeze(0) + pack_offsets.unsqueeze(-1)
        valid = lengths.unsqueeze(-1) <= sizes.unsqueeze(0)
        lengths = torch.min(lengths.unsqueeze(-1), sizes.unsqueeze(0)).clamp(0)
        ends = starts + lengths

        starts = starts.flatten()
        ends = ends.flatten()
        valid = valid.flatten()

        return starts, ends, valid

    lengths = torch.arange(1, graph.num_edges.max() + 1)
    num_length = len(lengths)
    starts, ends, valid = all_prefix_slice(graph.num_edges, lengths)
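
For the two graphs above, ``graph.num_edges`` is ``[5, 6]``, so the slices can be
checked by hand (the expected values below are worked out for this example):

.. code:: python

    print(starts[:4].tolist())  # [0, 5, 11, 16]
    print(ends[:4].tolist())    # [1, 6, 13, 18]
    print(valid[-2:].tolist())  # [False, True]: length 6 exceeds graph1's 5 edges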

The slices are visualized as follows. The two colors correspond to the two input
graphs.

.. image:: ../../../asset/tensor/autoregressive_slice.png
    :align: center
    :width: 55%

.. code:: python

    graph = graph.repeat(num_length) # step 1
    mask = functional.multi_slice_mask(starts, ends, graph.num_edge)
    graph = graph.edge_mask(mask) # step 2
    graph = graph[valid] # step 3
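
Since graph1 has 5 edges and graph2 has 6, the extended batch holds
:math:`2 \times 6 = 12` candidate graphs, of which :math:`5 + 6 = 11` prefixes
survive. A quick sanity check (expected values worked out for this example):

.. code:: python

    print(valid.sum().item())  # 11
    print(graph.batch_size)    # 11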

The output batch is visualized as follows.

.. image:: ../../../asset/graph/autoregressive_output.png