Customize Models & Tasks
========================

TorchDrug provides many popular model architectures for graph representation
learning. However, you may still find yourself in need of some more customized
architectures.

Here we illustrate the steps for writing customized models based on the example
of `variational graph auto encoder`_ (VGAE). VGAE learns latent node representations
with a graph convolutional network (GCN) encoder and an inner product decoder.
They are jointly trained with a reconstruction loss and evaluated on the link
prediction task.

.. _variational graph auto encoder: https://arxiv.org/pdf/1611.07308.pdf

As a convention, we separate representation models and task-specific designs for
better reusability.

Node Representation Model
-------------------------

In VGAE, the node representation model is a variational graph convolutional network
(VGCN). This can be implemented via standard graph convolution layers, plus a
variational regularization loss. We define our model as a subclass of ``nn.Module``
and :class:`core.Configurable <torchdrug.core.Configurable>`.

.. code:: python

    import torch
    from torch import nn
    from torch.nn import functional as F
    from torch.utils import data as torch_data

    # ``tasks`` is needed later for the link prediction task
    from torchdrug import core, layers, datasets, metrics, tasks
    from torchdrug.core import Registry as R

    @R.register("models.VGCN")
    class VariationalGraphConvolutionalNetwork(nn.Module, core.Configurable):

        def __init__(self, input_dim, hidden_dims, beta=0, batch_norm=False,
                     activation="relu"):
            super(VariationalGraphConvolutionalNetwork, self).__init__()
            self.input_dim = input_dim
            self.output_dim = hidden_dims[-1]
            self.dims = [input_dim] + list(hidden_dims)
            self.beta = beta

            self.layers = nn.ModuleList()
            for i in range(len(self.dims) - 2):
                self.layers.append(
                    layers.GraphConv(self.dims[i], self.dims[i + 1], None,
                                     batch_norm, activation)
                )
            # the last layer outputs both mu and log_sigma, hence twice the width
            self.layers.append(
                layers.GraphConv(self.dims[-2], self.dims[-1] * 2, None, False, None)
            )

The definition is similar to most other ``torch`` models, except for two points.
First, the decorator ``@R.register("models.VGCN")`` registers the model in
the library under the name ``models.VGCN``. This enables the model to be dumped
into string format and reconstructed later. Second, ``self.input_dim`` and
``self.output_dim`` are set to inform other models that connect to it.
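
To make the role of registration more concrete, here is a minimal sketch of how
a name-based registry could support dumping a model and rebuilding it later. It
is an illustration only and not TorchDrug's actual implementation; ``register``,
``build``, ``ToyModel`` and ``to_config`` are hypothetical names introduced for
this example.

.. code:: python

    # Hypothetical registry: maps a string name to a class.
    _REGISTRY = {}

    def register(name):
        """Return a class decorator that stores the class under ``name``."""
        def decorator(cls):
            _REGISTRY[name] = cls
            cls.registered_name = name
            return cls
        return decorator

    def build(config):
        """Rebuild an object from a plain dict produced by ``to_config()``."""
        config = dict(config)  # copy, so the caller's dict is not modified
        cls = _REGISTRY[config.pop("class")]
        return cls(**config)

    @register("models.Toy")
    class ToyModel:

        def __init__(self, input_dim, hidden_dims):
            self.input_dim = input_dim
            self.hidden_dims = list(hidden_dims)
            # exposed so that downstream modules know the output width
            self.output_dim = self.hidden_dims[-1]

        def to_config(self):
            return {
                "class": self.registered_name,
                "input_dim": self.input_dim,
                "hidden_dims": self.hidden_dims,
            }

    model = ToyModel(input_dim=8, hidden_dims=[16, 4])
    config = model.to_config()   # a plain dict, easy to serialize
    restored = build(config)     # same class and arguments, found by name

Because the config holds only a name and constructor arguments, it can be
written to text and read back without pickling the class itself.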

Then we implement the forward function. The forward function takes 4 arguments:
graph(s), node input feature(s), the global loss and the global metric. The advantage
of these global variables is that each module can add its own loss terms and metrics,
so losses are implemented in a distributed, module-centric manner.

We compute the variational regularization loss, and add it to the global loss and the
global metric.
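
For reference, this regularizer is the KL divergence between the approximate
posterior and a standard normal prior. Assuming the encoder outputs the mean
:math:`\mu` and the log standard deviation :math:`\log \sigma` for each latent
dimension, the per-dimension term is

.. math::

    \mathrm{KL}\big(\mathcal{N}(\mu, \sigma^2) \,\|\, \mathcal{N}(0, 1)\big)
    = \frac{1}{2} \left( \mu^2 + \sigma^2 - 2 \log \sigma - 1 \right)

This term is summed over latent dimensions, averaged over nodes, and scaled by
``beta`` before it is added to the global loss.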

.. code:: python

        def reparameterize(self, mu, log_sigma):
            if self.training:
                # sample z = mu + eps * sigma with Gaussian noise eps ~ N(0, I)
                z = mu + torch.randn_like(mu) * log_sigma.exp()
            else:
                z = mu
            return z

        def forward(self, graph, input, all_loss=None, metric=None):
            x = input
            for layer in self.layers:
                x = layer(graph, x)
            mu, log_sigma = x.chunk(2, dim=-1)
            node_feature = self.reparameterize(mu, log_sigma)

            if all_loss is not None and self.beta > 0:
                # KL divergence between N(mu, sigma^2) and N(0, I)
                loss = 0.5 * (mu ** 2 + log_sigma.exp() ** 2 - 2 * log_sigma - 1)
                loss = loss.sum(dim=-1).mean()
                all_loss += loss * self.beta
                metric["variational regularization loss"] = loss

            return {
                "node_feature": node_feature
            }

Here we explicitly return a dict to indicate the type of our representations. The
dict may also contain other representations, such as edge representations or graph
representations.

Link Prediction Task
--------------------

Here we show how to implement the link prediction task for VGAE.

Generally, a task in TorchDrug contains 4 functions: ``predict()``, ``target()``,
``forward()`` and ``evaluate()``. Such interfaces empower us to seamlessly switch
between different devices, such as CPUs, GPUs or even the distributed setting.

Among the above functions, ``predict()`` and ``target()`` compute the prediction and
the ground truth for a batch respectively. ``forward()`` computes the training loss,
while ``evaluate()`` computes the evaluation metrics.

Optionally, one can also implement a ``preprocess()`` function, which performs
arbitrary operations based on the dataset.

In the case of VGAE, we first compute the undirected training graph in
``preprocess()``. In ``predict()``, we perform negative sampling, and predict
the logits for both positive and negative edges. In ``target()``, we return
the ground truth labels for edges. ``evaluate()`` computes the area under the ROC curve
for the predictions.
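
The layout of predictions and labels can be sketched in isolation with plain
Python lists. This is a toy illustration under stated assumptions, not the
tensor-based implementation used by the task: ``sample_negative_edges`` and
``inner_product`` are hypothetical helpers, and the node features are made up.
As with any uniform sampler, a sampled "negative" pair may happen to be a true
edge.

.. code:: python

    import random

    def sample_negative_edges(num_node, num_edge, seed=0):
        """Draw ``num_edge`` node pairs uniformly at random."""
        rng = random.Random(seed)
        return [(rng.randrange(num_node), rng.randrange(num_node))
                for _ in range(num_edge)]

    def inner_product(u, v):
        """Inner product decoder: the logit of an edge is <z_u, z_v>."""
        return sum(a * b for a, b in zip(u, v))

    # made-up latent node features for a graph with 4 nodes
    node_feature = [[1.0, 0.0], [0.5, 0.5], [0.0, 1.0], [-1.0, 0.0]]

    positive_edges = [(0, 1), (1, 2)]
    negative_edges = sample_negative_edges(num_node=4,
                                           num_edge=len(positive_edges))

    # positives come first and negatives second, and the labels follow suit
    edges = positive_edges + negative_edges
    labels = [1.0] * len(positive_edges) + [0.0] * len(negative_edges)
    logits = [inner_product(node_feature[i], node_feature[j]) for i, j in edges]

Keeping positives in the first half and negatives in the second half means the
labels never depend on which negatives were drawn, only on the batch size.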

.. code:: python

    @R.register("tasks.LinkPrediction")
    class LinkPrediction(tasks.Task, core.Configurable):

        def __init__(self, model):
            super(LinkPrediction, self).__init__()
            self.model = model

        def preprocess(self, train_set, valid_set, test_set):
            dataset = train_set.dataset
            graph = dataset.graph
            train_graph = graph.edge_mask(train_set.indices)

            # flip the edges to make the graph undirected
            edge_list = train_graph.edge_list.repeat(2, 1)
            edge_list[train_graph.num_edge:, :2] = edge_list[train_graph.num_edge:, :2] \
                .flip(1)
            # edge_list.repeat(2, 1) stacks two full copies of the edges,
            # so the edge index is tiled in the same order
            index = torch.arange(train_graph.num_edge, device=self.device).repeat(2)
            data_dict, meta_dict = train_graph.data_mask(edge_index=index)
            train_graph = type(train_graph)(
                edge_list, edge_weight=train_graph.edge_weight[index],
                num_node=train_graph.num_node, num_edge=train_graph.num_edge * 2,
                meta_dict=meta_dict, **data_dict
            )

            self.register_buffer("train_graph", train_graph)
            self.num_node = dataset.num_node

        def forward(self, batch):
            all_loss = torch.tensor(0, dtype=torch.float32, device=self.device)
            metric = {}

            # the model adds its regularization loss to all_loss in place
            pred = self.predict(batch, all_loss, metric)
            target = self.target(batch)
            metric.update(self.evaluate(pred, target))

            loss = F.binary_cross_entropy_with_logits(pred, target)
            metric["bce loss"] = loss

            all_loss += loss

            return all_loss, metric

        def predict(self, batch, all_loss=None, metric=None):
            # negative sampling: one random node pair per positive edge
            neg_batch = torch.randint(self.num_node, batch.shape, device=self.device)
            batch = torch.cat([batch, neg_batch])
            node_in, node_out = batch.t()

            output = self.model(self.train_graph, self.train_graph.node_feature.float(),
                                all_loss, metric)
            node_feature = output["node_feature"]
            # inner product decoder
            pred = torch.einsum("bd, bd -> b",
                                node_feature[node_in], node_feature[node_out])
            return pred

        def target(self, batch):
            # first half: positive edges, second half: sampled negatives
            batch_size = len(batch)
            target = torch.zeros(batch_size * 2, device=self.device)
            target[:batch_size] = 1
            return target

        def evaluate(self, pred, target):
            roc = metrics.area_under_roc(pred, target)
            return {
                "AUROC": roc
            }
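
As a reference for what ``evaluate()`` reports, AUROC can be read as the
probability that a randomly chosen positive edge receives a higher score than a
randomly chosen negative edge. The following sketch computes it by brute force
in plain Python. ``pairwise_auroc`` is a hypothetical helper for illustration
and is not how ``metrics.area_under_roc`` is implemented; in particular, giving
half credit to ties is an assumption of this sketch.

.. code:: python

    def pairwise_auroc(pred, target):
        """Fraction of (positive, negative) pairs that are ranked correctly."""
        positives = [p for p, t in zip(pred, target) if t == 1]
        negatives = [p for p, t in zip(pred, target) if t == 0]
        if not positives or not negatives:
            raise ValueError("AUROC needs at least one positive and one negative")

        correct = 0.0
        for pos in positives:
            for neg in negatives:
                if pos > neg:
                    correct += 1.0
                elif pos == neg:
                    correct += 0.5  # assumption: ties count as half
        return correct / (len(positives) * len(negatives))

    # 3 of the 4 (positive, negative) pairs are ordered correctly
    score = pairwise_auroc([0.9, 0.8, 0.3, 0.1], [1, 0, 1, 0])

A value of 0.5 corresponds to random ranking and 1.0 to a perfect separation of
positive and negative edges. The double loop is quadratic in the batch size, so
it is only suitable for small sanity checks.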

Put Them Together
-----------------

Let's put all the ingredients together. Since the original Cora is a node
classification dataset, we apply a wrapper to make it compatible with link
prediction.

.. code:: python

    class CoraLinkPrediction(datasets.Cora):

        def __getitem__(self, index):
            return self.graph.edge_list[index]

        def __len__(self):
            return self.graph.num_edge

    dataset = CoraLinkPrediction("~/node-datasets/")
    lengths = [int(0.8 * len(dataset)), int(0.1 * len(dataset))]
    lengths += [len(dataset) - sum(lengths)]
    train_set, valid_set, test_set = torch_data.random_split(dataset, lengths)

    model = VariationalGraphConvolutionalNetwork(dataset.node_feature_dim, [128, 16],
                                                 beta=1e-3, batch_norm=True)
    task = LinkPrediction(model)

    optimizer = torch.optim.Adam(task.parameters(), lr=1e-2)
    solver = core.Engine(task, train_set, valid_set, test_set, optimizer, gpus=[0],
                         batch_size=len(train_set))
    solver.train(num_epoch=200)
    solver.evaluate("valid")

The result may look like

.. code:: bash

    AUROC: 0.898589