Customize Models & Tasks
========================

TorchDrug provides many popular model architectures for graph representation
learning. However, you may still find yourself in need of a more customized
architecture.

Here we illustrate the steps for writing customized models based on the example
of `variational graph auto encoder`_ (VGAE). VGAE learns latent node representations
with a graph convolutional network (GCN) encoder and an inner product decoder.
They are jointly trained with a reconstruction loss and evaluated on the link
prediction task.

.. _variational graph auto encoder: https://arxiv.org/pdf/1611.07308.pdf

As a convention, we separate representation models and task-specific designs for
better reusability.

Node Representation Model
-------------------------

In VGAE, the node representation model is a variational graph convolutional network
(VGCN). This can be implemented via standard graph convolution layers, plus a
variational regularization loss. We define our model as a subclass of ``nn.Module``
and :class:`core.Configurable <torchdrug.core.Configurable>`.

.. code:: python

    import torch
    from torch import nn
    from torch.nn import functional as F
    from torch.utils import data as torch_data

    # tasks is needed later for the LinkPrediction task
    from torchdrug import core, layers, tasks, datasets, metrics
    from torchdrug.core import Registry as R

    @R.register("models.VGCN")
    class VariationalGraphConvolutionalNetwork(nn.Module, core.Configurable):

        def __init__(self, input_dim, hidden_dims, beta=0, batch_norm=False,
                     activation="relu"):
            super(VariationalGraphConvolutionalNetwork, self).__init__()
            self.input_dim = input_dim
            self.output_dim = hidden_dims[-1]
            self.dims = [input_dim] + list(hidden_dims)
            self.beta = beta

            self.layers = nn.ModuleList()
            # hidden layers
            for i in range(len(self.dims) - 2):
                self.layers.append(
                    layers.GraphConv(self.dims[i], self.dims[i + 1], None,
                                     batch_norm, activation)
                )
            # the last layer outputs both mu and log_sigma, hence twice the width,
            # and uses neither batch norm nor activation
            self.layers.append(
                layers.GraphConv(self.dims[-2], self.dims[-1] * 2, None, False, None)
            )

The definition is similar to most other ``torch`` models, except for two points.
First, the decorator ``@R.register("models.VGCN")`` registers the model in
the library under the name ``models.VGCN``. This enables the model to be dumped
into string format and reconstructed later. Second, ``self.input_dim`` and
``self.output_dim`` are set to inform other models that connect to it of its
input and output sizes.

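To make the role of the decorator concrete, here is a minimal sketch of a
name-based registry in plain Python. It is not TorchDrug's ``Registry``
implementation: ``ToyRegistry``, ``ToyModel``, ``to_config`` and ``from_config``
are hypothetical names, used only to illustrate how a registered class can be
dumped to a string and rebuilt from it.

.. code:: python

    import json


    class ToyRegistry:
        """Hypothetical stand-in for a model registry (not TorchDrug's Registry)."""

        table = {}

        @classmethod
        def register(cls, name):
            def wrapper(obj):
                cls.table[name] = obj
                obj._registry_name = name
                return obj
            return wrapper

        @classmethod
        def from_config(cls, config):
            config = dict(config)  # copy so the caller's dict is not modified
            name = config.pop("class")
            return cls.table[name](**config)


    @ToyRegistry.register("models.Toy")
    class ToyModel:

        def __init__(self, input_dim, hidden_dims):
            self.input_dim = input_dim
            self.hidden_dims = list(hidden_dims)
            self.output_dim = self.hidden_dims[-1]

        def to_config(self):
            return {"class": self._registry_name,
                    "input_dim": self.input_dim,
                    "hidden_dims": self.hidden_dims}


    model = ToyModel(8, [128, 16])
    dumped = json.dumps(model.to_config())  # string form of the model
    restored = ToyRegistry.from_config(json.loads(dumped))
    print(type(restored).__name__, restored.output_dim)

The registered name is what ties the dumped string back to a Python class, which
is why it has to be unique.
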
Then we implement the forward function. The forward function takes 4 arguments:
graph(s), node input feature(s), the global loss and the global metric. The advantage
of these global variables is that they enable implementation of losses in a
distributed, module-centric manner: each module adds its own loss terms to the
shared total.

We compute the variational regularization loss, and add it to the global loss and the
global metric.

.. code:: python

        def reparameterize(self, mu, log_sigma):
            if self.training:
                # sample with standard normal noise (randn, not uniform rand)
                z = mu + torch.randn_like(mu) * log_sigma.exp()
            else:
                # use the posterior mean at evaluation time
                z = mu
            return z

        def forward(self, graph, input, all_loss=None, metric=None):
            x = input
            for layer in self.layers:
                x = layer(graph, x)
            # split the doubled output into mean and log standard deviation
            mu, log_sigma = x.chunk(2, dim=-1)
            node_feature = self.reparameterize(mu, log_sigma)

            if all_loss is not None and self.beta > 0:
                # KL divergence to a standard normal prior
                loss = 0.5 * (mu ** 2 + log_sigma.exp() ** 2 - 2 * log_sigma - 1)
                loss = loss.sum(dim=-1).mean()
                all_loss += loss * self.beta
                metric["variational regularization loss"] = loss

            return {
                "node_feature": node_feature
            }

Here we explicitly return a dict to indicate the type of our representations. The
dict may also contain other representations, such as edge representations or graph
representations.

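For reference, the regularization term computed in ``forward()`` is the closed-form
KL divergence between the per-node posterior :math:`\mathcal{N}(\mu, \sigma^2)` and
a standard normal prior, summed over the latent dimensions :math:`d` and then
averaged over nodes. Here :math:`\sigma_d` corresponds to ``log_sigma.exp()``:

.. math::

    D_{\mathrm{KL}}\left(\mathcal{N}(\mu, \sigma^2) \,\|\, \mathcal{N}(0, 1)\right)
    = \frac{1}{2} \sum_{d} \left(\mu_d^2 + \sigma_d^2 - 2 \log \sigma_d - 1\right)

The term vanishes exactly when :math:`\mu_d = 0` and :math:`\sigma_d = 1` for
every dimension, and ``beta`` controls how strongly it is weighted in the
global loss.
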
Link Prediction Task
--------------------

Here we show how to implement the link prediction task for VGAE.

Generally, a task in TorchDrug contains 4 functions: ``predict()``, ``target()``,
``forward()`` and ``evaluate()``. Such interfaces let us seamlessly switch
between different devices, such as CPUs, GPUs or even the distributed setting.

Among the above functions, ``predict()`` and ``target()`` compute the prediction and
the ground truth for a batch respectively. ``forward()`` computes the training loss,
while ``evaluate()`` computes the evaluation metrics.

Optionally, one can also implement a ``preprocess()`` function, which performs
arbitrary operations based on the dataset.

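The division of labour between the four functions can be sketched without
TorchDrug at all. The following plain-Python toy is an illustration under
stated assumptions, not the library's ``tasks.Task``: ``ToyTask``, ``ToyModel``
and ``Scalar`` are hypothetical names, and ``Scalar`` merely imitates the
in-place ``+=`` of a tensor so that the model can add its own term to the
shared loss.

.. code:: python

    import math


    class Scalar:
        """Mutable number, imitating in-place accumulation on a tensor."""

        def __init__(self, value=0.0):
            self.value = value

        def __iadd__(self, other):
            self.value += other
            return self


    class ToyModel:
        """Scores each input and contributes its own regularization term."""

        def __init__(self, weight, beta):
            self.weight = weight
            self.beta = beta

        def __call__(self, xs, all_loss=None, metric=None):
            if all_loss is not None and self.beta > 0:
                reg = self.weight ** 2
                all_loss += self.beta * reg
                metric["regularization loss"] = reg
            return [self.weight * x for x in xs]


    class ToyTask:

        def __init__(self, model):
            self.model = model

        def predict(self, batch, all_loss=None, metric=None):
            xs = [x for x, _ in batch]
            return self.model(xs, all_loss, metric)

        def target(self, batch):
            return [y for _, y in batch]

        def evaluate(self, pred, target):
            correct = sum((p > 0) == (t == 1) for p, t in zip(pred, target))
            return {"accuracy": correct / len(target)}

        def forward(self, batch):
            all_loss = Scalar()
            metric = {}
            pred = self.predict(batch, all_loss, metric)
            target = self.target(batch)
            metric.update(self.evaluate(pred, target))
            # binary cross entropy on logits: log(1 + exp(p)) - t * p
            bce = sum(math.log1p(math.exp(p)) - t * p
                      for p, t in zip(pred, target)) / len(target)
            metric["bce loss"] = bce
            all_loss += bce
            return all_loss.value, metric


    batch = [(1.0, 1), (-2.0, 0), (0.5, 1), (-0.5, 0)]
    loss, metric = ToyTask(ToyModel(weight=2.0, beta=0.1)).forward(batch)
    print(sorted(metric))

Note that ``forward()`` never needs to know which loss terms the model adds.
It only creates the shared containers and passes them down.
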
In the case of VGAE, we first compute the undirected training graph in
``preprocess()``. In ``predict()``, we perform negative sampling, and predict
the logits for both positive and negative edges. In ``target()``, we return
the ground truth labels for the edges. ``evaluate()`` computes the area under
the ROC curve for the predictions.

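These data handling steps can be tried on a toy edge list first. The sketch
below is plain Python and independent of TorchDrug; the helper names
``make_undirected``, ``sample_negatives`` and ``make_target`` are hypothetical.
Negatives are drawn uniformly at random here, so a sampled pair may
occasionally coincide with a true edge.

.. code:: python

    import random


    def make_undirected(edge_list):
        """Append the reverse of every edge, keeping the originals first."""
        return edge_list + [(v, u) for u, v in edge_list]


    def sample_negatives(num_node, num_sample, rng):
        """Draw node pairs uniformly at random as negative edges."""
        return [(rng.randrange(num_node), rng.randrange(num_node))
                for _ in range(num_sample)]


    def make_target(batch_size):
        """Label the first half (positives) 1 and the second half (negatives) 0."""
        return [1.0] * batch_size + [0.0] * batch_size


    rng = random.Random(0)
    num_node = 5
    positives = [(0, 1), (1, 2), (3, 4)]

    train_edges = make_undirected(positives)
    negatives = sample_negatives(num_node, len(positives), rng)
    batch = positives + negatives
    target = make_target(len(positives))

    print(train_edges)
    print(len(batch), target)
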
.. code:: python

    @R.register("tasks.LinkPrediction")
    class LinkPrediction(tasks.Task, core.Configurable):

        def __init__(self, model):
            super(LinkPrediction, self).__init__()
            self.model = model

        def preprocess(self, train_set, valid_set, test_set):
            dataset = train_set.dataset
            graph = dataset.graph
            # keep only the edges that belong to the training split
            train_graph = graph.edge_mask(train_set.indices)

            # flip the edges to make the graph undirected
            edge_list = train_graph.edge_list.repeat(2, 1)
            edge_list[train_graph.num_edge:, :2] = edge_list[train_graph.num_edge:, :2] \
                                                   .flip(1)
            # edge i and its flipped copy at position num_edge + i share attributes,
            # so the index follows the same tiled order as edge_list
            index = torch.arange(train_graph.num_edge, device=self.device).repeat(2)
            data_dict, meta_dict = train_graph.data_mask(edge_index=index)
            train_graph = type(train_graph)(
                edge_list, edge_weight=train_graph.edge_weight[index],
                num_node=train_graph.num_node, num_edge=train_graph.num_edge * 2,
                meta_dict=meta_dict, **data_dict
            )

            self.register_buffer("train_graph", train_graph)
            self.num_node = dataset.num_node

        def forward(self, batch):
            # shared containers that the model can add its own terms to
            all_loss = torch.tensor(0, dtype=torch.float32, device=self.device)
            metric = {}

            pred = self.predict(batch, all_loss, metric)
            target = self.target(batch)
            metric.update(self.evaluate(pred, target))

            loss = F.binary_cross_entropy_with_logits(pred, target)
            metric["bce loss"] = loss

            all_loss += loss

            return all_loss, metric

        def predict(self, batch, all_loss=None, metric=None):
            # negative sampling: random node pairs, as many as positive edges
            neg_batch = torch.randint(self.num_node, batch.shape, device=self.device)
            batch = torch.cat([batch, neg_batch])
            node_in, node_out = batch.t()

            output = self.model(self.train_graph, self.train_graph.node_feature.float(),
                                all_loss, metric)
            node_feature = output["node_feature"]
            # inner product decoder: one logit per candidate edge
            pred = torch.einsum("bd, bd -> b",
                                node_feature[node_in], node_feature[node_out])
            return pred

        def target(self, batch):
            # positives come first in predict(), followed by the negatives
            batch_size = len(batch)
            target = torch.zeros(batch_size * 2, device=self.device)
            target[:batch_size] = 1
            return target

        def evaluate(self, pred, target):
            roc = metrics.area_under_roc(pred, target)
            return {
                "AUROC": roc
            }

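The ``AUROC`` reported by ``evaluate()`` can be read as the probability that a
randomly chosen positive edge receives a higher score than a randomly chosen
negative one. A minimal plain-Python sketch of that definition follows. The
helper ``auroc`` is hypothetical and is not how ``metrics.area_under_roc`` is
implemented; it compares all pairs, which is only practical for small inputs.

.. code:: python

    def auroc(pred, target):
        """Fraction of (positive, negative) pairs ranked correctly; ties count 0.5."""
        positives = [p for p, t in zip(pred, target) if t == 1]
        negatives = [p for p, t in zip(pred, target) if t == 0]
        if not positives or not negatives:
            raise ValueError("AUROC needs at least one positive and one negative")
        score = 0.0
        for p in positives:
            for n in negatives:
                if p > n:
                    score += 1.0
                elif p == n:
                    score += 0.5
        return score / (len(positives) * len(negatives))


    pred = [2.1, 0.3, -0.5, 0.9, -1.2, 0.1]
    target = [1, 1, 1, 0, 0, 0]
    print(auroc(pred, target))

A value of 0.5 corresponds to random ranking and 1.0 to a perfect separation of
positive and negative edges.
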
Put Them Together
-----------------

Let's put all the ingredients together. Since the original Cora is a node
classification dataset, we apply a wrapper to make it compatible with link
prediction.

.. code:: python

    class CoraLinkPrediction(datasets.Cora):
        # treat every edge of the Cora graph as one sample

        def __getitem__(self, index):
            return self.graph.edge_list[index]

        def __len__(self):
            return self.graph.num_edge

    dataset = CoraLinkPrediction("~/node-datasets/")
    # 80% / 10% / 10% split over edges
    lengths = [int(0.8 * len(dataset)), int(0.1 * len(dataset))]
    lengths += [len(dataset) - sum(lengths)]
    train_set, valid_set, test_set = torch_data.random_split(dataset, lengths)

    model = VariationalGraphConvolutionalNetwork(dataset.node_feature_dim, [128, 16],
                                                 beta=1e-3, batch_norm=True)
    task = LinkPrediction(model)

    optimizer = torch.optim.Adam(task.parameters(), lr=1e-2)
    # full-batch training: a single batch holds all training edges
    solver = core.Engine(task, train_set, valid_set, test_set, optimizer, gpus=[0],
                         batch_size=len(train_set))
    solver.train(num_epoch=200)
    solver.evaluate("valid")

The result may look like:

.. code:: bash

    AUROC: 0.898589