Batch Irregular Structures
==========================

Unlike images, text and audio, graphs usually have irregular structures, which
makes them hard to batch in tensor frameworks. Many existing implementations use
padding to convert graphs into dense grid structures, which incurs unnecessary
computation and memory costs.

In TorchDrug, we develop a more intuitive and efficient solution based on
variadic functions. Variadic functions can directly operate on sparse,
irregular inputs or outputs.

Variadic Input
--------------

Here we show how to apply functions to variadic inputs.

Generally, a batch of :math:`n` variadic tensors can be represented by a value
tensor and a size tensor. The value tensor is the concatenation of all variadic
tensors along the variadic axis, while the size tensor records the length of
each variadic tensor.

Let's first create a batch of 1D variadic samples.

.. code:: python

    import torch

    samples = []
    for size in range(2, 6):
        samples.append(torch.randint(6, (size,)))
    value = torch.cat(samples)
    size = torch.tensor([len(s) for s in samples])

.. image:: ../../../asset/tensor/variadic_tensor.png
    :align: center
    :width: 60%

We apply variadic functions to compute the sum, max and top-k values for each
sample.

.. code:: python

    from torchdrug.layers import functional

    sum = functional.variadic_sum(value, size)
    max = functional.variadic_max(value, size)[0]
    top3_value, top3_index = functional.variadic_topk(value, size, k=3)

Note that :meth:`variadic_topk <torchdrug.layers.functional.variadic_topk>` also
accepts samples smaller than :math:`k`. In that case, it pads the output with
the smallest element of that sample.

.. image:: ../../../asset/tensor/variadic_func_result.png
    :align: center
    :width: 88%

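The padding behavior can be checked on a tiny batch. The following is a minimal
sketch; the input tensors and the expected outputs in the comments are our own
illustration of the behavior described above.

.. code:: python

    # sample 1 is [5, 2], sample 2 is [4, 1, 3]
    value = torch.tensor([5, 2, 4, 1, 3])
    size = torch.tensor([2, 3])
    top3_value, top3_index = functional.variadic_topk(value, size, k=3)
    # sample 1 has only 2 elements, so its top-3 is padded with its
    # smallest element: top3_value should be [[5, 2, 2], [4, 3, 1]]
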
Mathematically, these functions can be viewed as performing the operation over
each sample with a for loop. For example, the variadic sum is equivalent to the
following logic.

.. code:: python

    sums = []
    for sample in samples:
        sums.append(sample.sum())
    sum = torch.stack(sums)

.. note::

    Despite sharing the same logic, variadic functions are much faster than for
    loops on GPUs (typically :math:`\text{batch size}\times` faster). Use
    variadic functions instead of for loops whenever possible.

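As a rough illustration of this gap, here is a minimal timing sketch. It assumes
a CUDA device is available; the synthetic sizes and the ``split``-based baseline
are our own choices rather than part of TorchDrug.

.. code:: python

    import time

    sizes = torch.randint(1, 64, (10000,), device="cuda")
    values = torch.randn(int(sizes.sum()), device="cuda")

    # batched variadic sum over all samples at once
    torch.cuda.synchronize()
    start = time.time()
    batched = functional.variadic_sum(values, sizes)
    torch.cuda.synchronize()
    print("variadic_sum: %.4f s" % (time.time() - start))

    # equivalent for loop over individual samples
    start = time.time()
    looped = torch.stack([s.sum() for s in values.split(sizes.tolist())])
    torch.cuda.synchronize()
    print("for loop: %.4f s" % (time.time() - start))
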
Many operations in graph representation learning can be implemented by variadic
functions. For example,

1. Infer graph-level representations from node-/edge-level representations (see
   the sketch below).
2. Perform classification over nodes/edges.

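The first operation is just a variadic reduction over node representations, with
``num_nodes`` as the size tensor. Below is a minimal sketch of mean readout; the
sizes and features are random placeholders of our own.

.. code:: python

    # 3 graphs with 3, 4 and 5 nodes; 16-dimensional node features
    num_nodes = torch.tensor([3, 4, 5])
    node_feature = torch.randn(int(num_nodes.sum()), 16)
    # mean readout: node-level -> graph-level, output shape (3, 16)
    graph_feature = functional.variadic_mean(node_feature, num_nodes)
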
Here we demonstrate how to perform classification over nodes. We create a toy
task, where the model needs to predict the heaviest atom of each molecule. Note
that the node attributes of a batch form a variadic tensor, with ``num_nodes``
being its size tensor. Therefore, we can use
:meth:`variadic_max <torchdrug.layers.functional.variadic_max>`
to get our ground truth.

.. code:: python

    from torchdrug import data, models, metrics

    smiles_list = ["CC(=C)C#N", "CCNC(=S)NCC", "BrC1=CC=C(Br)C=C1"]
    graph = data.PackedMolecule.from_smiles(smiles_list)
    target = functional.variadic_max(graph.atom_type, graph.num_nodes)[1]

Naturally, the prediction over nodes also forms a variadic tensor with size
tensor ``num_nodes``.

.. code:: python

    model = models.GCN(input_dim=graph.node_feature.shape[-1], hidden_dims=[128, 128, 1])
    feature = model(graph, graph.node_feature.float())
    pred = feature["node_feature"].squeeze(-1)

    pred_prob, pred_index = functional.variadic_max(pred, graph.num_nodes)
    loss = functional.variadic_cross_entropy(pred, target, graph.num_nodes)
    accuracy = metrics.variadic_accuracy(pred, target, graph.num_nodes)

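To actually fit this toy task, the loss can be plugged into a standard PyTorch
training loop. The following is a minimal sketch; the optimizer and
hyperparameters are arbitrary assumptions, not part of the original example.

.. code:: python

    from torch import optim

    # arbitrary optimizer and learning rate for illustration
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    for epoch in range(100):
        feature = model(graph, graph.node_feature.float())
        pred = feature["node_feature"].squeeze(-1)
        # per-graph cross entropy over variadic node predictions
        loss = functional.variadic_cross_entropy(pred, target, graph.num_nodes)
        loss = loss.mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
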
.. seealso::
    :func:`variadic_sum <torchdrug.layers.functional.variadic_sum>`,
    :func:`variadic_mean <torchdrug.layers.functional.variadic_mean>`,
    :func:`variadic_max <torchdrug.layers.functional.variadic_max>`,
    :func:`variadic_arange <torchdrug.layers.functional.variadic_arange>`,
    :func:`variadic_sort <torchdrug.layers.functional.variadic_sort>`,
    :func:`variadic_topk <torchdrug.layers.functional.variadic_topk>`,
    :func:`variadic_randperm <torchdrug.layers.functional.variadic_randperm>`,
    :func:`variadic_sample <torchdrug.layers.functional.variadic_sample>`,
    :func:`variadic_meshgrid <torchdrug.layers.functional.variadic_meshgrid>`,
    :func:`variadic_softmax <torchdrug.layers.functional.variadic_softmax>`,
    :func:`variadic_log_softmax <torchdrug.layers.functional.variadic_log_softmax>`,
    :func:`variadic_cross_entropy <torchdrug.layers.functional.variadic_cross_entropy>`,
    :func:`variadic_accuracy <torchdrug.metrics.variadic_accuracy>`

Variadic Output
---------------

In some cases, we also need to write functions that produce variadic outputs. A
typical example is autoregressive generation, where we need to generate all
node/edge prefixes of a graph. When this operation is batched, different input
graphs produce variadic numbers of output graphs.

Here we show how to generate edge prefixes for a batch of graphs in TorchDrug.
First, let's prepare a batch of two graphs.

.. code:: python

    edge_list = [[0, 1], [1, 2], [2, 3], [3, 4], [4, 5]]
    graph1 = data.Graph(edge_list, num_node=6)
    edge_list = [[0, 1], [1, 2], [2, 3], [3, 0], [0, 2], [1, 3]]
    graph2 = data.Graph(edge_list, num_node=4)
    graph = data.Graph.pack([graph1, graph2])
    with graph.graph():
        graph.id = torch.arange(2)

.. image:: ../../../asset/graph/autoregressive_input.png
    :align: center
    :width: 66%

The generation of edge prefixes consists of 3 steps.

1. Construct an extended batch with enough copies for each graph.
2. Apply an edge mask over the batch.
3. Remove excess or invalid graphs.

The first step can be implemented through
:meth:`Graph.repeat <torchdrug.data.Graph.repeat>`. For the following steps, we
define an auxiliary function ``all_prefix_slice``. This function takes in a size
tensor and desired prefix lengths, and outputs :math:`n \times l` prefix slices
for the extended batch, where :math:`n` is the batch size and :math:`l` is the
number of prefix lengths.

.. code:: python

    def all_prefix_slice(sizes, lengths=None):
        # compute the start of each sample in the value tensor
        cum_sizes = sizes.cumsum(0)
        starts = cum_sizes - sizes
        if lengths is None:
            max_size = sizes.max().item()
            lengths = torch.arange(0, max_size, 1, device=sizes.device)

        # offset the starts for each repeated copy of the batch
        pack_offsets = torch.arange(len(lengths), device=sizes.device) * cum_sizes[-1]
        starts = starts.unsqueeze(0) + pack_offsets.unsqueeze(-1)
        # a prefix is valid if its length does not exceed the sample size
        valid = lengths.unsqueeze(-1) <= sizes.unsqueeze(0)
        lengths = torch.min(lengths.unsqueeze(-1), sizes.unsqueeze(0)).clamp(0)
        ends = starts + lengths

        starts = starts.flatten()
        ends = ends.flatten()
        valid = valid.flatten()

        return starts, ends, valid

    lengths = torch.arange(1, graph.num_edges.max() + 1)
    num_length = len(lengths)
    starts, ends, valid = all_prefix_slice(graph.num_edges, lengths)

The slices are visualized as follows. The two colors correspond to the two input
graphs.

.. image:: ../../../asset/tensor/autoregressive_slice.png
    :align: center
    :width: 55%

.. code:: python

    graph = graph.repeat(num_length)  # step 1
    mask = functional.multi_slice_mask(starts, ends, graph.num_edge)
    graph = graph.edge_mask(mask)  # step 2
    graph = graph[valid]  # step 3

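As a quick sanity check (ours, not part of the original example), each graph
contributes one prefix per edge, so the two input graphs with 5 and 6 edges
should yield 11 output graphs.

.. code:: python

    assert graph.batch_size == 5 + 6
    print(graph.num_edges)  # number of edges in each prefix graph
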
The output batch is

.. image:: ../../../asset/graph/autoregressive_output.png