--- a +++ b/doc/source/notes/model.rst @@ -0,0 +1,229 @@ +Customize Models & Tasks +======================== + +TorchDrug provides many popular model architectures for graph representation +learning. However, you may still find yourself in need of some more customized +architectures. + +Here we illustrate the steps for writing customized models based on the example +of `variational graph auto encoder`_ (VGAE). VGAE learns latent node representations +with a graph convolutional network (GCN) encoder and an inner product decoder. +They are jointly trained with a reconstruction loss and evaluated on the link +prediction task. + +.. _variational graph auto encoder: https://arxiv.org/pdf/1611.07308.pdf + +As a convention, we separate representation models and task-specific designs for +better reusability. + +Node Representation Model +------------------------- + +In VGAE, the node representation model is a variational graph convolutional network +(VGCN). This can be implemented via standard graph convolution layers, plus a +variational regularization loss. We define our model as a subclass of `nn.Module` +and :class:`core.Configurable <torchdrug.core.Configurable>`. + +.. 
code:: python

    import torch
    from torch import nn
    from torch.nn import functional as F
    from torch.utils import data as torch_data

    from torchdrug import core, layers, datasets, tasks, metrics
    from torchdrug.core import Registry as R

    @R.register("models.VGCN")
    class VariationalGraphConvolutionalNetwork(nn.Module, core.Configurable):

        def __init__(self, input_dim, hidden_dims, beta=0, batch_norm=False,
                     activation="relu"):
            super(VariationalGraphConvolutionalNetwork, self).__init__()
            self.input_dim = input_dim
            self.output_dim = hidden_dims[-1]
            self.dims = [input_dim] + list(hidden_dims)
            self.beta = beta

            self.layers = nn.ModuleList()
            for i in range(len(self.dims) - 2):
                self.layers.append(
                    layers.GraphConv(self.dims[i], self.dims[i + 1], None,
                                     batch_norm, activation)
                )
            self.layers.append(
                layers.GraphConv(self.dims[-2], self.dims[-1] * 2, None, False, None)
            )

The definition is similar to most other ``torch`` models, except two points.
First, the decoration line ``@R.register("models.VGCN")`` registers the model in
the library with the name ``models.VGCN``. This enables the model to be dumped
into string format and reconstructed later. Second, ``self.input_dim`` and
``self.output_dim`` are set to inform other models that connect to it.

Then we implement the forward function. The forward function takes 4 arguments,
graph(s), node input feature(s), the global loss and the global metric. The advantage
of these global variables is that they enable implementation of losses in a
distributed, module-centric manner.

We compute the variational regularization loss, and add it to the global loss and the
global metric.

.. 
code:: python

    def reparameterize(self, mu, log_sigma):
        if self.training:
            z = mu + torch.randn_like(mu) * log_sigma.exp()
        else:
            z = mu
        return z

    def forward(self, graph, input, all_loss=None, metric=None):
        x = input
        for layer in self.layers:
            x = layer(graph, x)
        mu, log_sigma = x.chunk(2, dim=-1)
        node_feature = self.reparameterize(mu, log_sigma)

        if all_loss is not None and self.beta > 0:
            loss = 0.5 * (mu ** 2 + log_sigma.exp() ** 2 - 2 * log_sigma - 1)
            loss = loss.sum(dim=-1).mean()
            all_loss += loss * self.beta
            metric["variational regularization loss"] = loss

        return {
            "node_feature": node_feature
        }

Here we explicitly return a dict to indicate the type of our representations. The
dict may also contain other representations, such as edge representations or graph
representations.

Link Prediction Task
--------------------

Here we show how to implement the link prediction task for VGAE.

Generally, a task in TorchDrug contains 4 functions, ``predict()``, ``target()``,
``forward()`` and ``evaluate()``. Such interfaces empower us to seamlessly switch
between different devices, such as CPUs, GPUs or even the distributed setting.

Among the above functions, ``predict()`` and ``target()`` compute the prediction and
the ground truth for a batch respectively. ``forward()`` computes the training loss,
while ``evaluate()`` computes the evaluation metrics.

Optionally, one can also implement the ``preprocess()`` function, which performs
arbitrary operations based on the dataset.

In the case of VGAE, we first compute the undirected training graph in
``preprocess()``. In ``predict()``, we perform negative sampling, and predict
the logits for both positive and negative edges. In ``target()``, we return
the ground truth label for edges. ``evaluate()`` computes the area under the ROC curve
for the predictions.

.. 
code:: python + + @R.register("tasks.LinkPrediction") + class LinkPrediction(tasks.Task, core.Configurable): + + def __init__(self, model): + super(LinkPrediction, self).__init__() + self.model = model + + def preprocess(self, train_set, valid_set, test_set): + dataset = train_set.dataset + graph = dataset.graph + train_graph = dataset.graph.edge_mask(train_set.indices) + + # flip the edges to make the graph undirected + edge_list = train_graph.edge_list.repeat(2, 1) + edge_list[train_graph.num_edge:, :2] = edge_list[train_graph.num_edge:, :2] \ + .flip(1) + index = torch.arange(train_graph.num_edge, device=self.device) \ + .repeat(2, 1).t().flatten() + data_dict, meta_dict = train_graph.data_mask(edge_index=index) + train_graph = type(train_graph)( + edge_list, edge_weight=train_graph.edge_weight[index], + num_node=train_graph.num_node, num_edge=train_graph.num_edge * 2, + meta_dict=meta_dict, **data_dict + ) + + self.register_buffer("train_graph", train_graph) + self.num_node = dataset.num_node + + def forward(self, batch): + all_loss = torch.tensor(0, dtype=torch.float32, device=self.device) + metric = {} + + pred = self.predict(batch, all_loss, metric) + target = self.target(batch) + metric.update(self.evaluate(pred, target)) + + loss = F.binary_cross_entropy_with_logits(pred, target) + metric["bce loss"] = loss + + all_loss += loss + + return all_loss, metric + + def predict(self, batch, all_loss=None, metric=None): + neg_batch = torch.randint(self.num_node, batch.shape, device=self.device) + batch = torch.cat([batch, neg_batch]) + node_in, node_out = batch.t() + + output = self.model(self.train_graph, self.train_graph.node_feature.float(), + all_loss, metric) + node_feature = output["node_feature"] + pred = torch.einsum("bd, bd -> b", + node_feature[node_in], node_feature[node_out]) + return pred + + def target(self, batch): + batch_size = len(batch) + target = torch.zeros(batch_size * 2, device=self.device) + target[:batch_size] = 1 + return target + + def 
evaluate(self, pred, target): + roc = metrics.area_under_roc(pred, target) + return { + "AUROC": roc + } + +Put Them Together +----------------- + +Let's put all the ingredients together. Since the original Cora is a node +classification dataset, we apply a wrapper to make it compatible with link +prediction. + +.. code:: python + + class CoraLinkPrediction(datasets.Cora): + + def __getitem__(self, index): + return self.graph.edge_list[index] + + def __len__(self): + return self.graph.num_edge + + dataset = CoraLinkPrediction("~/node-datasets/") + lengths = [int(0.8 * len(dataset)), int(0.1 * len(dataset))] + lengths += [len(dataset) - sum(lengths)] + train_set, valid_set, test_set = torch_data.random_split(dataset, lengths) + + model = VariationalGraphConvolutionalNetwork(dataset.node_feature_dim, [128, 16], + beta=1e-3, batch_norm=True) + task = LinkPrediction(model) + + optimizer = torch.optim.Adam(task.parameters(), lr=1e-2) + solver = core.Engine(task, train_set, valid_set, test_set, optimizer, gpus=[0], + batch_size=len(train_set)) + solver.train(num_epoch=200) + solver.evaluate("valid") + +The result may look like + +.. code:: bash + + AUROC: 0.898589 \ No newline at end of file