--- a +++ b/torchdrug/data/graph.py @@ -0,0 +1,1851 @@ +import math +import warnings +from functools import reduce +from collections import defaultdict + +import networkx as nx + +from matplotlib import pyplot as plt +import torch +from torch_scatter import scatter_add, scatter_min + +from torchdrug import core, utils +from torchdrug.data import Dictionary +from torchdrug.utils import pretty + +plt.switch_backend("agg") + + +class Graph(core._MetaContainer): + r""" + Basic container for sparse graphs. + + To batch graphs with variadic sizes, use :meth:`data.Graph.pack <torchdrug.data.Graph.pack>`. + This will return a PackedGraph object with the following block diagonal adjacency matrix. + + .. math:: + + \begin{bmatrix} + A_1 & \cdots & 0 \\ + \vdots & \ddots & \vdots \\ + 0 & \cdots & A_n + \end{bmatrix} + + where :math:`A_i` is the adjacency of :math:`i`-th graph. + + You may register dynamic attributes for each graph. + The registered attributes will be automatically processed during packing. + + .. warning:: + + This class doesn't enforce any order on the edges. + + Example:: + + >>> graph = data.Graph(torch.randint(10, (30, 2))) + >>> with graph.node(): + >>> graph.my_node_attr = torch.rand(10, 5, 5) + + Parameters: + edge_list (array_like, optional): list of edges of shape :math:`(|E|, 2)` or :math:`(|E|, 3)`. + Each tuple is (node_in, node_out) or (node_in, node_out, relation). + edge_weight (array_like, optional): edge weights of shape :math:`(|E|,)` + num_node (int, optional): number of nodes. + By default, it will be inferred from the largest id in `edge_list` + num_relation (int, optional): number of relations + node_feature (array_like, optional): node features of shape :math:`(|V|, ...)` + edge_feature (array_like, optional): edge features of shape :math:`(|E|, ...)` + graph_feature (array_like, optional): graph feature of any shape + """ + + _meta_types = {"node", "edge", "graph", "node reference", "edge reference", "graph reference"} + + def __init__(self, edge_list=None, edge_weight=None, num_node=None, num_relation=None, + node_feature=None, edge_feature=None, graph_feature=None, **kwargs): + super(Graph, self).__init__(**kwargs) + # edge_list: N * [h, t] or N * [h, t, r] + edge_list, num_edge = self._standarize_edge_list(edge_list, num_relation) + edge_weight = self._standarize_edge_weight(edge_weight, edge_list) + + num_node = self._standarize_num_node(num_node, edge_list) + num_relation = self._standarize_num_relation(num_relation, edge_list) + + self._edge_list = edge_list + self._edge_weight = edge_weight + self.num_node = num_node + self.num_edge = num_edge + self.num_relation = num_relation + + if node_feature is not None: + with self.node(): + self.node_feature = torch.as_tensor(node_feature, device=self.device) + if edge_feature is not None: + with self.edge(): + self.edge_feature = torch.as_tensor(edge_feature, device=self.device) + if graph_feature is not None: + with self.graph(): + self.graph_feature = torch.as_tensor(graph_feature, device=self.device) + + def node(self): + """ + Context manager for node attributes. + """ + return self.context("node") + + def edge(self): + """ + Context manager for edge attributes. + """ + return self.context("edge") + + def graph(self): + """ + Context manager for graph attributes. + """ + return self.context("graph") + + def node_reference(self): + """ + Context manager for node references. + """ + return self.context("node reference") + + def edge_reference(self): + """ + Context manager for edge references. + """ + return self.context("edge reference") + + def graph_reference(self): + """ + Context manager for graph references. + """ + return self.context("graph reference") + + def _check_attribute(self, key, value): + for type in self._meta_contexts: + if "reference" in type: + if value.dtype != torch.long: + raise TypeError("Tensors used as reference must be long tensors") + if type == "node": + if len(value) != self.num_node: + raise ValueError("Expect node attribute `%s` to have shape (%d, *), but found %s" % + (key, self.num_node, value.shape)) + elif type == "edge": + if len(value) != self.num_edge: + raise ValueError("Expect edge attribute `%s` to have shape (%d, *), but found %s" % + (key, self.num_edge, value.shape)) + elif type == "node reference": + is_valid = (value >= -1) & (value < self.num_node) + if not is_valid.all(): + error_value = value[~is_valid] + raise ValueError("Expect node reference in [-1, %d), but found %d" % + (self.num_node, error_value[0])) + elif type == "edge reference": + is_valid = (value >= -1) & (value < self.num_edge) + if not is_valid.all(): + error_value = value[~is_valid] + raise ValueError("Expect edge reference in [-1, %d), but found %d" % + (self.num_edge, error_value[0])) + elif type == "graph reference": + is_valid = (value >= -1) & (value < self.batch_size) + if not is_valid.all(): + error_value = value[~is_valid] + raise ValueError("Expect graph reference in [-1, %d), but found %d" % + (self.batch_size, error_value[0])) + + def __setattr__(self, key, value): + if hasattr(self, "meta_dict"): + self._check_attribute(key, value) + super(Graph, self).__setattr__(key, value) + + def _standarize_edge_list(self, edge_list, num_relation): + if edge_list is not None and len(edge_list): + if isinstance(edge_list, torch.Tensor) and edge_list.dtype != torch.long: + try: + edge_list = torch.LongTensor(edge_list) + except TypeError: + raise TypeError("Can't convert `edge_list` to torch.long") + else: + edge_list = torch.as_tensor(edge_list, dtype=torch.long) + else: + num_element = 2 if num_relation is None else 3 + if isinstance(edge_list, torch.Tensor): + device = edge_list.device + else: + device = "cpu" + edge_list = torch.zeros(0, num_element, dtype=torch.long, device=device) + if (edge_list < 0).any(): + raise ValueError("`edge_list` should only contain non-negative indexes") + num_edge = torch.tensor(len(edge_list), device=edge_list.device) + return edge_list, num_edge + + def _standarize_edge_weight(self, edge_weight, edge_list): + if edge_weight is not None: + edge_weight = torch.as_tensor(edge_weight, dtype=torch.float, device=edge_list.device) + if len(edge_list) != len(edge_weight): + raise ValueError("`edge_list` and `edge_weight` should be the same size, but found %d and %d" + % (len(edge_list), len(edge_weight))) + else: + edge_weight = torch.ones(len(edge_list), device=edge_list.device) + return edge_weight + + def _standarize_num_node(self, num_node, edge_list): + if num_node is None: + num_node = self._maybe_num_node(edge_list) + num_node = torch.as_tensor(num_node, device=edge_list.device) + if (edge_list[:, :2] >= num_node).any(): + raise ValueError("`num_node` is %d, but found node %d in `edge_list`" % (num_node, edge_list[:, :2].max())) + return num_node + + def _standarize_num_relation(self, num_relation, edge_list): + if num_relation is None and edge_list.shape[1] > 2: + num_relation = self._maybe_num_relation(edge_list) + if num_relation is not None: + num_relation = torch.as_tensor(num_relation, device=edge_list.device) + if edge_list.shape[1] <= 2: + raise ValueError("`num_relation` is provided, but the number of dims of `edge_list` is less than 3.") + elif (edge_list[:, 2] >= num_relation).any(): + raise ValueError("`num_relation` is %d, but found relation %d in `edge_list`" % (num_relation, edge_list[:, 2].max())) + return num_relation + + def _maybe_num_node(self, edge_list): + warnings.warn("_maybe_num_node() is used to determine the number of nodes. " + "This may underestimate the count if there are isolated nodes.") + if len(edge_list): + return edge_list[:, :2].max().item() + 1 + else: + return 0 + + def _maybe_num_relation(self, edge_list): + warnings.warn("_maybe_num_relation() is used to determine the number of relations. " + "This may underestimate the count if there are unseen relations.") + return edge_list[:, 2].max().item() + 1 + + def _standarize_index(self, index, count): + if isinstance(index, slice): + start = index.start or 0 + if start < 0: + start += count + stop = index.stop or count + if stop < 0: + stop += count + step = index.step or 1 + index = torch.arange(start, stop, step, device=self.device) + else: + index = torch.as_tensor(index, device=self.device) + if index.ndim == 0: + index = index.unsqueeze(0) + if index.dtype == torch.bool: + if index.shape != (count,): + raise IndexError("Invalid mask. Expect mask to have shape %s, but found %s" % + ((int(count),), tuple(index.shape))) + index = index.nonzero().squeeze(-1) + else: + index = index.long() + max_index = -1 if len(index) == 0 else index.max().item() + if max_index >= count: + raise IndexError("Invalid index. Expect index smaller than %d, but found %d" % (count, max_index)) + return index + + def _get_mapping(self, index, count): + index = self._standarize_index(index, count) + if (index.bincount() > 1).any(): + raise ValueError("Can't create mapping for duplicate index") + mapping = -torch.ones(count + 1, dtype=torch.long, device=self.device) + mapping[index] = torch.arange(len(index), device=self.device) + return mapping + + def _get_repeat_pack_offsets(self, num_xs, repeats): + new_num_xs = num_xs.repeat_interleave(repeats) + cum_repeats_shifted = repeats.cumsum(0) - repeats + new_num_xs[cum_repeats_shifted] -= num_xs + offsets = new_num_xs.cumsum(0) + return offsets + + @classmethod + def from_dense(cls, adjacency, node_feature=None, edge_feature=None): + """ + Create a sparse graph from a dense adjacency matrix. + For zero entries in the adjacency matrix, their edge features will be ignored. + + Parameters: + adjacency (array_like): adjacency matrix of shape :math:`(|V|, |V|)` or :math:`(|V|, |V|, |R|)` + node_feature (array_like): node features of shape :math:`(|V|, ...)` + edge_feature (array_like): edge features of shape :math:`(|V|, |V|, ...)` or :math:`(|V|, |V|, |R|, ...)` + """ + adjacency = torch.as_tensor(adjacency) + if adjacency.shape[0] != adjacency.shape[1]: + raise ValueError("`adjacency` should be a square matrix, but found %d and %d" % adjacency.shape[:2]) + + edge_list = adjacency.nonzero() + edge_weight = adjacency[tuple(edge_list.t())] + num_node = adjacency.shape[0] + num_relation = adjacency.shape[2] if adjacency.ndim > 2 else None + if edge_feature is not None: + edge_feature = torch.as_tensor(edge_feature) + edge_feature = edge_feature[tuple(edge_list.t())] + + return cls(edge_list, edge_weight, num_node, num_relation, node_feature, edge_feature) + + def connected_components(self): + """ + Split this graph into connected components. + + Returns: + (PackedGraph, LongTensor): connected components, number of connected components per graph + """ + node_in, node_out = self.edge_list.t()[:2] + range = torch.arange(self.num_node, device=self.device) + node_in, node_out = torch.cat([node_in, node_out, range]), torch.cat([node_out, node_in, range]) + + # find connected component + # O(|E|d), d is the diameter of the graph + min_neighbor = torch.arange(self.num_node, device=self.device) + last = torch.zeros_like(min_neighbor) + while not torch.equal(min_neighbor, last): + last = min_neighbor + min_neighbor = scatter_min(min_neighbor[node_out], node_in, dim_size=self.num_node)[0] + anchor = torch.unique(min_neighbor) + num_cc = self.node2graph[anchor].bincount(minlength=self.batch_size) + return self.split(min_neighbor), num_cc + + def split(self, node2graph): + """ + Split a graph into multiple disconnected graphs. + + Parameters: + node2graph (array_like): ID of the graph each node belongs to + + Returns: + PackedGraph + """ + node2graph = torch.as_tensor(node2graph, dtype=torch.long, device=self.device) + # coalesce arbitrary graph IDs to [0, n) + _, node2graph = torch.unique(node2graph, return_inverse=True) + num_graph = node2graph.max() + 1 + index = node2graph.argsort() + mapping = torch.zeros_like(index) + mapping[index] = torch.arange(len(index), device=self.device) + + node_in, node_out = self.edge_list.t()[:2] + edge_mask = node2graph[node_in] == node2graph[node_out] + edge2graph = node2graph[node_in] + edge_index = edge2graph.argsort() + edge_index = edge_index[edge_mask[edge_index]] + + prepend = torch.tensor([-1], device=self.device) + is_first_node = torch.diff(node2graph[index], prepend=prepend) > 0 + graph_index = self.node2graph[index[is_first_node]] + + edge_list = self.edge_list.clone() + edge_list[:, :2] = mapping[edge_list[:, :2]] + + num_nodes = node2graph.bincount(minlength=num_graph) + num_edges = edge2graph[edge_index].bincount(minlength=num_graph) + + num_cum_nodes = num_nodes.cumsum(0) + offsets = (num_cum_nodes - num_nodes)[edge2graph[edge_index]] + + data_dict, meta_dict = self.data_mask(index, edge_index, graph_index=graph_index, exclude="graph reference") + + return self.packed_type(edge_list[edge_index], edge_weight=self.edge_weight[edge_index], num_nodes=num_nodes, + num_edges=num_edges, num_relation=self.num_relation, offsets=offsets, + meta_dict=meta_dict, **data_dict) + + @classmethod + def pack(cls, graphs): + """ + Pack a list of graphs into a PackedGraph object. + + Parameters: + graphs (list of Graph): list of graphs + + Returns: + PackedGraph + """ + edge_list = [] + edge_weight = [] + num_nodes = [] + num_edges = [] + num_relation = -1 + num_cum_node = 0 + num_cum_edge = 0 + num_graph = 0 + data_dict = defaultdict(list) + meta_dict = graphs[0].meta_dict + for graph in graphs: + edge_list.append(graph.edge_list) + edge_weight.append(graph.edge_weight) + num_nodes.append(graph.num_node) + num_edges.append(graph.num_edge) + for k, v in graph.data_dict.items(): + for type in meta_dict[k]: + if type == "graph": + v = v.unsqueeze(0) + elif type == "node reference": + v = v + num_cum_node + elif type == "edge reference": + v = v + num_cum_edge + elif type == "graph reference": + v = v + num_graph + data_dict[k].append(v) + if num_relation == -1: + num_relation = graph.num_relation + elif num_relation != graph.num_relation: + raise ValueError("Inconsistent `num_relation` in graphs. Expect %d but got %d." + % (num_relation, graph.num_relation)) + num_cum_node += graph.num_node + num_cum_edge += graph.num_edge + num_graph += 1 + + edge_list = torch.cat(edge_list) + edge_weight = torch.cat(edge_weight) + data_dict = {k: torch.cat(v) for k, v in data_dict.items()} + + return cls.packed_type(edge_list, edge_weight=edge_weight, num_nodes=num_nodes, num_edges=num_edges, + num_relation=num_relation, meta_dict=meta_dict, **data_dict) + + def repeat(self, count): + """ + Repeat this graph. + + Parameters: + count (int): number of repetitions + + Returns: + PackedGraph + """ + edge_list = self.edge_list.repeat(count, 1) + edge_weight = self.edge_weight.repeat(count) + num_nodes = [self.num_node] * count + num_edges = [self.num_edge] * count + num_relation = self.num_relation + + data_dict = {} + for k, v in self.data_dict.items(): + if "graph" in self.meta_dict[k]: + v = v.unsqueeze(0) + shape = [1] * v.ndim + shape[0] = count + length = len(v) + v = v.repeat(shape) + for type in self.meta_dict[k]: + if type == "node reference": + offsets = torch.arange(count, device=self.device) * self.num_node + v = v + offsets.repeat_interleave(length) + elif type == "edge reference": + offsets = torch.arange(count, device=self.device) * self.num_edge + v = v + offsets.repeat_interleave(length) + elif type == "graph reference": + offsets = torch.arange(count, device=self.device) + v = v + offsets.repeat_interleave(length) + data_dict[k] = v + + return self.packed_type(edge_list, edge_weight=edge_weight, num_nodes=num_nodes, num_edges=num_edges, + num_relation=num_relation, meta_dict=self.meta_dict, **data_dict) + + def get_edge(self, edge): + """ + Get the weight of of an edge. + + Parameters: + edge (array_like): index of shape :math:`(2,)` or :math:`(3,)` + + Returns: + Tensor: weight of the edge + """ + if len(edge) != self.edge_list.shape[1]: + raise ValueError("Incorrect edge index. Expect %d axes but got %d axes" + % (self.edge_list.shape[1], len(edge))) + + edge_index, num_match = self.match(edge) + return self.edge_weight[edge_index].sum() + + def match(self, pattern): + """ + Return all matched indexes for each pattern. Support patterns with ``-1`` as the wildcard. + + Parameters: + pattern (array_like): index of shape :math:`(N, 2)` or :math:`(N, 3)` + + Returns: + (LongTensor, LongTensor): matched indexes, number of matches per edge + + Examples:: + + >>> graph = data.Graph([[0, 1], [1, 0], [1, 2], [2, 1], [2, 0], [0, 2]]) + >>> index, num_match = graph.match([[0, -1], [1, 2]]) + >>> assert (index == torch.tensor([0, 5, 2])).all() + >>> assert (num_match == torch.tensor([2, 1])).all() + + """ + if len(pattern) == 0: + index = num_match = torch.zeros(0, dtype=torch.long, device=self.device) + return index, num_match + + if not hasattr(self, "edge_inverted_index"): + self.edge_inverted_index = {} + pattern = torch.as_tensor(pattern, dtype=torch.long, device=self.device) + if pattern.ndim == 1: + pattern = pattern.unsqueeze(0) + mask = pattern != -1 + scale = 2 ** torch.arange(pattern.shape[-1], device=self.device) + query_type = (mask * scale).sum(dim=-1) + query_index = query_type.argsort() + num_query = query_type.unique(return_counts=True)[1] + query_ends = num_query.cumsum(0) + query_starts = query_ends - num_query + mask_set = mask[query_index[query_starts]].tolist() + + type_ranges = [] + type_orders = [] + # get matched range for each query type + for i, mask in enumerate(mask_set): + query_type = tuple(mask) + type_index = query_index[query_starts[i]: query_ends[i]] + type_edge = pattern[type_index][:, mask] + if query_type not in self.edge_inverted_index: + self.edge_inverted_index[query_type] = self._build_edge_inverted_index(mask) + inverted_range, order = self.edge_inverted_index[query_type] + ranges = inverted_range.get(type_edge, default=0) + type_ranges.append(ranges) + type_orders.append(order) + ranges = torch.cat(type_ranges) + orders = torch.stack(type_orders) + types = torch.arange(len(mask_set), device=self.device) + types = types.repeat_interleave(num_query) + + # reorder matched ranges according to the query order + ranges = scatter_add(ranges, query_index, dim=0, dim_size=len(pattern)) + types = scatter_add(types, query_index, dim_size=len(pattern)) + # convert range to indexes + starts, ends = ranges.t() + num_match = ends - starts + offsets = num_match.cumsum(0) - num_match + types = types.repeat_interleave(num_match) + ranges = torch.arange(num_match.sum(), device=self.device) + ranges = ranges + (starts - offsets).repeat_interleave(num_match) + index = orders[types, ranges] + + return index, num_match + + def _build_edge_inverted_index(self, mask): + keys = self.edge_list[:, mask] + base = torch.tensor(self.shape, device=self.device) + base = base[mask] + max = reduce(int.__mul__, base.tolist()) + if max > torch.iinfo(torch.int64).max: + raise ValueError("Fail to build an inverted index table based on sorting. " + "The graph is too large.") + scale = base.cumprod(0) + scale = torch.div(scale[-1], scale, rounding_mode="floor") + key = (keys * scale).sum(dim=-1) + order = key.argsort() + num_keys = key.unique(return_counts=True)[1] + ends = num_keys.cumsum(0) + starts = ends - num_keys + ranges = torch.stack([starts, ends], dim=-1) + keys_set = keys[order[starts]] + inverted_range = Dictionary(keys_set, ranges) + return inverted_range, order + + def __getitem__(self, index): + # why do we check tuple? + # case 1: x[0, 1] is parsed as (0, 1) + # case 2: x[[0, 1]] is parsed as [0, 1] + if not isinstance(index, tuple): + index = (index,) + index = list(index) + + while len(index) < 2: + index.append(slice(None)) + if len(index) > 2: + raise ValueError("Graph has only 2 axis, but %d axis is indexed" % len(index)) + + if all([isinstance(axis_index, int) for axis_index in index]): + return self.get_edge(index) + + edge_list = self.edge_list.clone() + for i, axis_index in enumerate(index): + axis_index = self._standarize_index(axis_index, self.num_node) + mapping = -torch.ones(self.num_node, dtype=torch.long, device=self.device) + mapping[axis_index] = axis_index + edge_list[:, i] = mapping[edge_list[:, i]] + edge_index = (edge_list >= 0).all(dim=-1) + + return self.edge_mask(edge_index) + + def __len__(self): + return 1 + + @property + def batch_size(self): + """Batch size.""" + return 1 + + def subgraph(self, index): + """ + Return a subgraph based on the specified nodes. + Equivalent to :meth:`node_mask(index, compact=True) <node_mask>`. + + Parameters: + index (array_like): node index + + Returns: + Graph + + See also: + :meth:`Graph.node_mask` + """ + return self.node_mask(index, compact=True) + + def data_mask(self, node_index=None, edge_index=None, graph_index=None, include=None, exclude=None): + data_dict, meta_dict = self.data_by_meta(include, exclude) + node_mapping = None + edge_mapping = None + graph_mapping = None + for k, v in data_dict.items(): + for type in meta_dict[k]: + if type == "node" and node_index is not None: + v = v[node_index] + elif type == "edge" and edge_index is not None: + v = v[edge_index] + elif type == "graph" and graph_index is not None: + v = v.unsqueeze(0)[graph_index] + elif type == "node reference" and node_index is not None: + if node_mapping is None: + node_mapping = self._get_mapping(node_index, self.num_node) + v = node_mapping[v] + elif type == "edge reference" and edge_index is not None: + if edge_mapping is None: + edge_mapping = self._get_mapping(edge_index, self.num_edge) + v = edge_mapping[v] + elif type == "graph reference" and graph_index is not None: + if graph_mapping is None: + graph_mapping = self._get_mapping(graph_index, self.batch_size) + v = graph_mapping[v] + data_dict[k] = v + + return data_dict, meta_dict + + def node_mask(self, index, compact=False): + """ + Return a masked graph based on the specified nodes. + + This function can also be used to re-order the nodes. + + Parameters: + index (array_like): node index + compact (bool, optional): compact node ids or not + + Returns: + Graph + + Examples:: + + >>> graph = data.Graph.from_dense(torch.eye(3)) + >>> assert graph.node_mask([1, 2]).adjacency.shape == (3, 3) + >>> assert graph.node_mask([1, 2], compact=True).adjacency.shape == (2, 2) + + """ + index = self._standarize_index(index, self.num_node) + mapping = -torch.ones(self.num_node, dtype=torch.long, device=self.device) + if compact: + mapping[index] = torch.arange(len(index), device=self.device) + num_node = len(index) + else: + mapping[index] = index + num_node = self.num_node + + edge_list = self.edge_list.clone() + edge_list[:, :2] = mapping[edge_list[:, :2]] + edge_index = (edge_list[:, :2] >= 0).all(dim=-1) + + if compact: + data_dict, meta_dict = self.data_mask(index, edge_index) + else: + data_dict, meta_dict = self.data_mask(edge_index=edge_index) + + return type(self)(edge_list[edge_index], edge_weight=self.edge_weight[edge_index], num_node=num_node, + num_relation=self.num_relation, meta_dict=meta_dict, **data_dict) + + def compact(self): + """ + Remove isolated nodes and compact node ids. + + Returns: + Graph + """ + index = self.degree_out + self.degree_in > 0 + return self.subgraph(index) + + def edge_mask(self, index): + """ + Return a masked graph based on the specified edges. + + This function can also be used to re-order the edges. + + Parameters: + index (array_like): edge index + + Returns: + Graph + """ + index = self._standarize_index(index, self.num_edge) + data_dict, meta_dict = self.data_mask(edge_index=index) + + return type(self)(self.edge_list[index], edge_weight=self.edge_weight[index], num_node=self.num_node, + num_relation=self.num_relation, meta_dict=meta_dict, **data_dict) + + def line_graph(self): + """ + Construct a line graph of this graph. + The node feature of the line graph is inherited from the edge feature of the original graph. + + In the line graph, each node corresponds to an edge in the original graph. + For a pair of edges (a, b) and (b, c) that share the same intermediate node in the original graph, + there is a directed edge (a, b) -> (b, c) in the line graph. + + Returns: + Graph + """ + node_in, node_out = self.edge_list.t()[:2] + edge_index = torch.arange(self.num_edge, device=self.device) + edge_in = edge_index[node_out.argsort()] + edge_out = edge_index[node_in.argsort()] + + degree_in = node_in.bincount(minlength=self.num_node) + degree_out = node_out.bincount(minlength=self.num_node) + size = degree_out * degree_in + starts = (size.cumsum(0) - size).repeat_interleave(size) + range = torch.arange(size.sum(), device=self.device) + # each node u has degree_out[u] * degree_in[u] local edges + local_index = range - starts + local_inner_size = degree_in.repeat_interleave(size) + edge_in_offset = (degree_out.cumsum(0) - degree_out).repeat_interleave(size) + edge_out_offset = (degree_in.cumsum(0) - degree_in).repeat_interleave(size) + edge_in_index = torch.div(local_index, local_inner_size, rounding_mode="floor") + edge_in_offset + edge_out_index = local_index % local_inner_size + edge_out_offset + + edge_in = edge_in[edge_in_index] + edge_out = edge_out[edge_out_index] + edge_list = torch.stack([edge_in, edge_out], dim=-1) + node_feature = getattr(self, "edge_feature", None) + num_node = self.num_edge + num_edge = size.sum() + + return Graph(edge_list, num_node=num_node, num_edge=num_edge, node_feature=node_feature) + + def full(self): + """ + Return a fully connected graph over the nodes. + + Returns: + Graph + """ + index = torch.arange(self.num_node, device=self.device) + if self.num_relation: + edge_list = torch.meshgrid(index, index, torch.arange(self.num_relation, device=self.device)) + else: + edge_list = torch.meshgrid(index, index) + edge_list = torch.stack(edge_list).flatten(1) + edge_weight = torch.ones(len(edge_list)) + + data_dict, meta_dict = self.data_by_meta(exclude="edge") + + return type(self)(edge_list, edge_weight=edge_weight, num_node=self.num_node, num_relation=self.num_relation, + meta_dict=meta_dict, **data_dict) + + def directed(self, order=None): + """ + Mask the edges to create a directed graph. + Edges that go from a node index to a larger or equal node index will be kept. + + Parameters: + order (Tensor, optional): topological order of the nodes + """ + node_in, node_out = self.edge_list.t()[:2] + if order is not None: + edge_index = order[node_in] <= order[node_out] + else: + edge_index = node_in <= node_out + + return self.edge_mask(edge_index) + + def undirected(self, add_inverse=False): + """ + Flip all the edges to create an undirected graph. + + For knowledge graphs, the flipped edges can either have the original relation or an inverse relation. + The inverse relation for relation :math:`r` is defined as :math:`|R| + r`. + + Parameters: + add_inverse (bool, optional): whether to use inverse relations for flipped edges + """ + edge_list = self.edge_list.clone() + edge_list[:, :2] = edge_list[:, :2].flip(1) + num_relation = self.num_relation + if num_relation and add_inverse: + edge_list[:, 2] += num_relation + num_relation = num_relation * 2 + edge_list = torch.stack([self.edge_list, edge_list], dim=1).flatten(0, 1) + + index = torch.arange(self.num_edge, device=self.device).unsqueeze(-1).expand(-1, 2).flatten() + data_dict, meta_dict = self.data_mask(edge_index=index) + + return type(self)(edge_list, edge_weight=self.edge_weight[index], num_node=self.num_node, + num_relation=num_relation, meta_dict=meta_dict, **data_dict) + + @utils.cached_property + def adjacency(self): + """ + Adjacency matrix of this graph. + + If :attr:`num_relation` is specified, a sparse tensor of shape :math:`(|V|, |V|, num\_relation)` will be + returned. + Otherwise, a sparse tensor of shape :math:`(|V|, |V|)` will be returned. + """ + return utils.sparse_coo_tensor(self.edge_list.t(), self.edge_weight, self.shape) + + _tensor_names = ["edge_list", "edge_weight", "num_node", "num_relation", "edge_feature"] + + def to_tensors(self): + edge_feature = getattr(self, "edge_feature", torch.tensor(0, device=self.device)) + return self.edge_list, self.edge_weight, self.num_node, self.num_relation, edge_feature + + @classmethod + def from_tensors(cls, tensors): + edge_list, edge_weight, num_node, num_relation, edge_feature = tensors + if edge_feature.ndim == 0: + edge_feature = None + return cls(edge_list, edge_weight, num_node, num_relation, edge_feature=edge_feature) + + @property + def node2graph(self): + """Node id to graph id mapping.""" + return torch.zeros(self.num_node, dtype=torch.long, device=self.device) + + @property + def edge2graph(self): + """Edge id to graph id mapping.""" + return torch.zeros(self.num_edge, dtype=torch.long, device=self.device) + + @utils.cached_property + def degree_out(self): + """ + Weighted number of edges containing each node as output. + + Note this is the **in-degree** in graph theory. + """ + return scatter_add(self.edge_weight, self.edge_list[:, 1], dim_size=self.num_node) + + @utils.cached_property + def degree_in(self): + """ + Weighted number of edges containing each node as input. + + Note this is the **out-degree** in graph theory. + """ + return scatter_add(self.edge_weight, self.edge_list[:, 0], dim_size=self.num_node) + + @property + def edge_list(self): + """List of edges.""" + return self._edge_list + + @property + def edge_weight(self): + """Edge weights.""" + return self._edge_weight + + @property + def device(self): + """Device.""" + return self.edge_list.device + + @property + def requires_grad(self): + return self.edge_weight.requires_grad + + @property + def grad(self): + return self.edge_weight.grad + + @property + def data(self): + return self + + def requires_grad_(self): + self.edge_weight.requires_grad_() + return self + + def size(self, dim=None): + if self.num_relation: + size = torch.Size((self.num_node, self.num_node, self.num_relation)) + else: + size = torch.Size((self.num_node, self.num_node)) + if dim is None: + return size + return size[dim] + + @property + def shape(self): + return self.size() + + def copy_(self, src): + """ + Copy data from ``src`` into ``self`` and return ``self``. + + The ``src`` graph must have the same set of attributes as ``self``. + """ + self.edge_list.copy_(src.edge_list) + self.edge_weight.copy_(src.edge_weight) + self.num_node.copy_(src.num_node) + self.num_edge.copy_(src.num_edge) + if self.num_relation is not None: + self.num_relation.copy_(src.num_relation) + + keys = set(self.data_dict.keys()) + src_keys = set(src.data_dict.keys()) + if keys != src_keys: + raise RuntimeError("Attributes mismatch. Trying to assign attributes %s, " + "but current graph has attributes %s" % (src_keys, keys)) + for k, v in self.data_dict.items(): + v.copy_(src.data_dict[k]) + + return self + + def detach(self): + """ + Detach this graph. + """ + return type(self)(self.edge_list.detach(), edge_weight=self.edge_weight.detach(), + num_node=self.num_node, num_relation=self.num_relation, + meta_dict=self.meta_dict, **utils.detach(self.data_dict)) + + def clone(self): + """ + Clone this graph. + """ + return type(self)(self.edge_list.clone(), edge_weight=self.edge_weight.clone(), + num_node=self.num_node, num_relation=self.num_relation, + meta_dict=self.meta_dict, **utils.clone(self.data_dict)) + + def cuda(self, *args, **kwargs): + """ + Return a copy of this graph in CUDA memory. + + This is a non-op if the graph is already on the correct device. + """ + edge_list = self.edge_list.cuda(*args, **kwargs) + + if edge_list is self.edge_list: + return self + else: + return type(self)(edge_list, edge_weight=self.edge_weight, + num_node=self.num_node, num_relation=self.num_relation, + meta_dict=self.meta_dict, **utils.cuda(self.data_dict, *args, **kwargs)) + + def cpu(self): + """ + Return a copy of this graph in CPU memory. + + This is a non-op if the graph is already in CPU memory. + """ + edge_list = self.edge_list.cpu() + + if edge_list is self.edge_list: + return self + else: + return type(self)(edge_list, edge_weight=self.edge_weight, num_node=self.num_node, + num_relation=self.num_relation, meta_dict=self.meta_dict, **utils.cpu(self.data_dict)) + + def to(self, device, *args, **kwargs): + """ + Return a copy of this graph on the given device. + """ + device = torch.device(device) + if device.type == "cpu": + return self.cpu(*args, **kwargs) + else: + return self.cuda(device, *args, **kwargs) + + def __repr__(self): + fields = ["num_node=%d" % self.num_node, "num_edge=%d" % self.num_edge] + if self.num_relation is not None: + fields.append("num_relation=%d" % self.num_relation) + if self.device.type != "cpu": + fields.append("device='%s'" % self.device) + return "%s(%s)" % (self.__class__.__name__, ", ".join(fields)) + + def visualize(self, title=None, save_file=None, figure_size=(3, 3), ax=None, layout="spring"): + """ + Visualize this graph with matplotlib. + + Parameters: + title (str, optional): title for this graph + save_file (str, optional): ``png`` or ``pdf`` file to save visualization. + If not provided, show the figure in window. + figure_size (tuple of int, optional): width and height of the figure + ax (matplotlib.axes.Axes, optional): axis to plot the figure + layout (str, optional): graph layout + + See also: + `NetworkX graph layout`_ + + .. _NetworkX graph layout: + https://networkx.github.io/documentation/stable/reference/drawing.html#module-networkx.drawing.layout + """ + is_root = ax is None + if ax is None: + fig = plt.figure(figsize=figure_size) + if title is not None: + ax = plt.gca() + else: + ax = fig.add_axes([0, 0, 1, 1]) + if title is not None: + ax.set_title(title) + + edge_list = self.edge_list[:, :2].tolist() + G = nx.DiGraph(edge_list) + G.add_nodes_from(range(self.num_node)) + if hasattr(nx, "%s_layout" % layout): + func = getattr(nx, "%s_layout" % layout) + else: + raise ValueError("Unknown networkx layout `%s`" % layout) + if layout == "spring" or layout == "random": + pos = func(G, seed=0) + else: + pos = func(G) + nx.draw_networkx(G, pos, ax=ax) + if self.num_relation: + edge_labels = self.edge_list[:, 2].tolist() + edge_labels = {tuple(e): l for e, l in zip(edge_list, edge_labels)} + nx.draw_networkx_edge_labels(G, pos, edge_labels, ax=ax) + ax.set_frame_on(False) + + if is_root: + if save_file: + fig.savefig(save_file) + else: + fig.show() + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + return NotImplemented + + def __getstate__(self): + state = {} + cls = self.__class__ + for k, v in self.__dict__.items(): + # do not pickle property / cached property + if hasattr(cls, k) and isinstance(getattr(cls, k), property): + continue + state[k] = v + return state + + +class PackedGraph(Graph): + """ + Container for sparse graphs with variadic sizes. + + To create a PackedGraph from Graph objects + + >>> batch = data.Graph.pack(graphs) + + To retrieve Graph objects from a PackedGraph + + >>> graphs = batch.unpack() + + .. warning:: + + Edges of the same graph are guaranteed to be consecutive in the edge list. + However, this class doesn't enforce any order on the edges. + + Parameters: + edge_list (array_like, optional): list of edges of shape :math:`(|E|, 2)` or :math:`(|E|, 3)`. + Each tuple is (node_in, node_out) or (node_in, node_out, relation). + edge_weight (array_like, optional): edge weights of shape :math:`(|E|,)` + num_nodes (array_like, optional): number of nodes in each graph + By default, it will be inferred from the largest id in `edge_list` + num_edges (array_like, optional): number of edges in each graph + num_relation (int, optional): number of relations + node_feature (array_like, optional): node features of shape :math:`(|V|, ...)` + edge_feature (array_like, optional): edge features of shape :math:`(|E|, ...)` + offsets (array_like, optional): node id offsets of shape :math:`(|E|,)`. + If not provided, nodes in `edge_list` should be relative index, i.e., the index in each graph. + If provided, nodes in `edge_list` should be absolute index, i.e., the index in the packed graph. + """ + + unpacked_type = Graph + + def __init__(self, edge_list=None, edge_weight=None, num_nodes=None, num_edges=None, num_relation=None, + offsets=None, **kwargs): + edge_list, num_nodes, num_edges, num_cum_nodes, num_cum_edges, offsets = \ + self._get_cumulative(edge_list, num_nodes, num_edges, offsets) + + if offsets is None: + offsets = self._get_offsets(num_nodes, num_edges, num_cum_nodes) + edge_list = edge_list.clone() + edge_list[:, :2] += offsets.unsqueeze(-1) + + num_node = num_nodes.sum() + if (edge_list[:, :2] >= num_node).any(): + raise ValueError("Sum of `num_nodes` is %d, but found %d in `edge_list`" % + (num_node, edge_list[:, :2].max())) + + self._offsets = offsets + self.num_nodes = num_nodes + self.num_edges = num_edges + self.num_cum_nodes = num_cum_nodes + self.num_cum_edges = num_cum_edges + + super(PackedGraph, self).__init__(edge_list, edge_weight=edge_weight, num_node=num_node, + num_relation=num_relation, **kwargs) + + def _get_offsets(self, num_nodes=None, num_edges=None, num_cum_nodes=None, num_cum_edges=None): + if num_nodes is None: + prepend = torch.tensor([0], device=self.device) + num_nodes = torch.diff(num_cum_nodes, prepend=prepend) + if num_edges is None: + prepend = torch.tensor([0], device=self.device) + num_edges = torch.diff(num_cum_edges, prepend=prepend) + if num_cum_nodes is None: + num_cum_nodes = num_nodes.cumsum(0) + return (num_cum_nodes - num_nodes).repeat_interleave(num_edges) + + def merge(self, graph2graph): + """ + Merge multiple graphs into a single graph. + + Parameters: + graph2graph (array_like): ID of the new graph each graph belongs to + """ + graph2graph = torch.as_tensor(graph2graph, dtype=torch.long, device=self.device) + # coalesce arbitrary graph IDs to [0, n) + _, graph2graph = torch.unique(graph2graph, return_inverse=True) + + graph_key = graph2graph * self.batch_size + torch.arange(self.batch_size, device=self.device) + graph_index = graph_key.argsort() + graph = self.subbatch(graph_index) + graph2graph = graph2graph[graph_index] + + num_graph = graph2graph[-1] + 1 + num_nodes = scatter_add(graph.num_nodes, graph2graph, dim_size=num_graph) + num_edges = scatter_add(graph.num_edges, graph2graph, dim_size=num_graph) + offsets = self._get_offsets(num_nodes, num_edges) + + data_dict, meta_dict = graph.data_mask(exclude="graph") + + return type(self)(graph.edge_list, edge_weight=graph.edge_weight, num_nodes=num_nodes, + num_edges=num_edges, num_relation=graph.num_relation, offsets=offsets, + meta_dict=meta_dict, **data_dict) + + def unpack(self): + """ + Unpack this packed graph into a list of graphs. + + Returns: + list of Graph + """ + graphs = [] + for i in range(self.batch_size): + graphs.append(self.get_item(i)) + return graphs + + def __iter__(self): + self._iter_index = 0 + return self + + def __next__(self): + if self._iter_index < self.batch_size: + item = self[self._iter_index] + self._iter_index += 1 + return item + raise StopIteration + + def _check_attribute(self, key, value): + for type in self._meta_contexts: + if "reference" in type: + if value.dtype != torch.long: + raise TypeError("Tensors used as reference must be long tensors") + if type == "node": + if len(value) != self.num_node: + raise ValueError("Expect node attribute `%s` to have shape (%d, *), but found %s" % + (key, self.num_node, value.shape)) + elif type == "edge": + if len(value) != self.num_edge: + raise ValueError("Expect edge attribute `%s` to have shape (%d, *), but found %s" % + (key, self.num_edge, value.shape)) + elif type == "graph": + if len(value) != self.batch_size: + raise ValueError("Expect graph attribute `%s` to have shape (%d, *), but found %s" % + (key, self.batch_size, value.shape)) + elif type == "node reference": + is_valid = (value >= -1) & (value < self.num_node) + if not is_valid.all(): + error_value = value[~is_valid] + raise ValueError("Expect node reference in [-1, %d), but found %d" % + (self.num_node, error_value[0])) + elif type == "edge reference": + is_valid = (value >= -1) & (value < self.num_edge) + if not is_valid.all(): + error_value = value[~is_valid] + raise ValueError("Expect edge reference in [-1, %d), but found %d" % + (self.num_edge, error_value[0])) + elif type == "graph reference": + is_valid = (value >= -1) & (value < self.batch_size) + if not is_valid.all(): + error_value = value[~is_valid] + raise ValueError("Expect graph reference in [-1, %d), but found %d" % + (self.batch_size, error_value[0])) + + def unpack_data(self, data, type="auto"): + """ + Unpack node or edge data according to the packed graph. + + Parameters: + data (Tensor): data to unpack + type (str, optional): data type. Can be ``auto``, ``node``, or ``edge``. + + Returns: + list of Tensor + """ + if type == "auto": + if self.num_node == self.num_edge: + raise ValueError("Ambiguous type. Please specify either `node` or `edge`") + if len(data) == self.num_node: + type = "node" + elif len(data) == self.num_edge: + type = "edge" + else: + raise ValueError("Graph has %d nodes and %d edges, but data has %d entries" % + (self.num_node, self.num_edge, len(data))) + data_list = [] + if type == "node": + for i in range(self.batch_size): + data_list.append(data[self.num_cum_nodes[i] - self.num_nodes[i]: self.num_cum_nodes[i]]) + elif type == "edge": + for i in range(self.batch_size): + data_list.append(data[self.num_cum_edges[i] - self.num_edges[i]: self.num_cum_edges[i]]) + + return data_list + + def repeat(self, count): + """ + Repeat this packed graph. This function behaves similarly to `torch.Tensor.repeat`_. + + .. _torch.Tensor.repeat: + https://pytorch.org/docs/stable/generated/torch.Tensor.repeat.html + + Parameters: + count (int): number of repetitions + + Returns: + PackedGraph + """ + num_nodes = self.num_nodes.repeat(count) + num_edges = self.num_edges.repeat(count) + offsets = self._get_offsets(num_nodes, num_edges) + edge_list = self.edge_list.repeat(count, 1) + edge_list[:, :2] += (offsets - self._offsets.repeat(count)).unsqueeze(-1) + + data_dict = {} + for k, v in self.data_dict.items(): + shape = [1] * v.ndim + shape[0] = count + length = len(v) + v = v.repeat(shape) + for _type in self.meta_dict[k]: + if _type == "node reference": + pack_offsets = torch.arange(count, device=self.device) * self.num_node + v = v + pack_offsets.repeat_interleave(length) + elif _type == "edge reference": + pack_offsets = torch.arange(count, device=self.device) * self.num_edge + v = v + pack_offsets.repeat_interleave(length) + elif _type == "graph reference": + pack_offsets = torch.arange(count, device=self.device) * self.batch_size + v = v + pack_offsets.repeat_interleave(length) + data_dict[k] = v + + return type(self)(edge_list, edge_weight=self.edge_weight.repeat(count), + num_nodes=num_nodes, num_edges=num_edges, num_relation=self.num_relation, + offsets=offsets, meta_dict=self.meta_dict, **data_dict) + + def repeat_interleave(self, repeats): + """ + Repeat this packed graph. This function behaves similarly to `torch.repeat_interleave`_. + + .. _torch.repeat_interleave: + https://pytorch.org/docs/stable/generated/torch.repeat_interleave.html + + Parameters: + repeats (Tensor or int): number of repetitions for each graph + + Returns: + PackedGraph + """ + repeats = torch.as_tensor(repeats, dtype=torch.long, device=self.device) + if repeats.numel() == 1: + repeats = repeats * torch.ones(self.batch_size, dtype=torch.long, device=self.device) + num_nodes = self.num_nodes.repeat_interleave(repeats) + num_edges = self.num_edges.repeat_interleave(repeats) + num_cum_nodes = num_nodes.cumsum(0) + num_cum_edges = num_edges.cumsum(0) + num_node = num_nodes.sum() + num_edge = num_edges.sum() + batch_size = repeats.sum() + num_graphs = torch.ones(batch_size, device=self.device) + + # special case 1: graphs[i] may have no node or no edge + # special case 2: repeats[i] may be 0 + cum_repeats_shifted = repeats.cumsum(0) - repeats + graph_mask = cum_repeats_shifted < batch_size + cum_repeats_shifted = cum_repeats_shifted[graph_mask] + + index = num_cum_nodes - num_nodes + index = torch.cat([index, index[cum_repeats_shifted]]) + value = torch.cat([-num_nodes, self.num_nodes[graph_mask]]) + mask = index < num_node + node_index = scatter_add(value[mask], index[mask], dim_size=num_node) + node_index = (node_index + 1).cumsum(0) - 1 + + index = num_cum_edges - num_edges + index = torch.cat([index, index[cum_repeats_shifted]]) + value = torch.cat([-num_edges, self.num_edges[graph_mask]]) + mask = index < num_edge + edge_index = scatter_add(value[mask], index[mask], dim_size=num_edge) + edge_index = (edge_index + 1).cumsum(0) - 1 + + graph_index = torch.repeat_interleave(repeats) + + offsets = self._get_offsets(num_nodes, num_edges) + edge_list = self.edge_list[edge_index] + edge_list[:, :2] += (offsets - self._offsets[edge_index]).unsqueeze(-1) + + node_offsets = None + edge_offsets = None + graph_offsets = None + data_dict = {} + for k, v in self.data_dict.items(): + num_xs = None + pack_offsets = None + for _type in self.meta_dict[k]: + if _type == "node": + v = v[node_index] + num_xs = num_nodes + elif _type == "edge": + v = v[edge_index] + num_xs = num_edges + elif _type == "graph": + v = v[graph_index] + num_xs = num_graphs + elif _type == "node reference": + if node_offsets is None: + node_offsets = self._get_repeat_pack_offsets(self.num_nodes, repeats) + pack_offsets = node_offsets + elif _type == "edge reference": + if edge_offsets is None: + edge_offsets = self._get_repeat_pack_offsets(self.num_edges, repeats) + pack_offsets = edge_offsets + elif _type == "graph reference": + if graph_offsets is None: + graph_offsets = self._get_repeat_pack_offsets(num_graphs, repeats) + pack_offsets = graph_offsets + # add offsets to make references point to indexes in their own graph + if num_xs is not None and pack_offsets is not None: + v = v + pack_offsets.repeat_interleave(num_xs) + data_dict[k] = v + + return type(self)(edge_list, edge_weight=self.edge_weight[edge_index], + num_nodes=num_nodes, num_edges=num_edges, num_relation=self.num_relation, + offsets=offsets, meta_dict=self.meta_dict, **data_dict) + + def get_item(self, index): + """ + Get the i-th graph from this packed graph. + + Parameters: + index (int): graph index + + Returns: + Graph + """ + node_index = torch.arange(self.num_cum_nodes[index] - self.num_nodes[index], self.num_cum_nodes[index], + device=self.device) + edge_index = torch.arange(self.num_cum_edges[index] - self.num_edges[index], self.num_cum_edges[index], + device=self.device) + graph_index = index + edge_list = self.edge_list[edge_index].clone() + edge_list[:, :2] -= self._offsets[edge_index].unsqueeze(-1) + data_dict, meta_dict = self.data_mask(node_index, edge_index, graph_index=graph_index) + + return self.unpacked_type(edge_list, edge_weight=self.edge_weight[edge_index], num_node=self.num_nodes[index], + num_relation=self.num_relation, meta_dict=meta_dict, **data_dict) + + def _get_cumulative(self, edge_list, num_nodes, num_edges, offsets): + if edge_list is None: + raise ValueError("`edge_list` should be provided") + if num_edges is None: + raise ValueError("`num_edges` should be provided") + + edge_list = torch.as_tensor(edge_list) + num_edges = torch.as_tensor(num_edges, device=edge_list.device) + num_edge = num_edges.sum() + if num_edge != len(edge_list): + raise ValueError("Sum of `num_edges` is %d, but found %d edges in `edge_list`" % (num_edge, len(edge_list))) + num_cum_edges = num_edges.cumsum(0) + + if offsets is None: + _edge_list = edge_list + else: + offsets = torch.as_tensor(offsets, device=edge_list.device) + _edge_list = edge_list.clone() + _edge_list[:, :2] -= offsets.unsqueeze(-1) + if num_nodes is None: + num_nodes = [] + for num_edge, num_cum_edge in zip(num_edges, num_cum_edges): + num_nodes.append(self._maybe_num_node(_edge_list[num_cum_edge - num_edge: num_cum_edge])) + num_nodes = torch.as_tensor(num_nodes, device=edge_list.device) + num_cum_nodes = num_nodes.cumsum(0) + + return edge_list, num_nodes, num_edges, num_cum_nodes, num_cum_edges, offsets + + def _get_num_xs(self, index, num_cum_xs): + x = torch.zeros(num_cum_xs[-1], dtype=torch.long, device=self.device) + x[index] = 1 + num_cum_indexes = x.cumsum(0) + num_cum_indexes = torch.cat([torch.zeros(1, dtype=torch.long, device=self.device), num_cum_indexes]) + new_num_cum_xs = num_cum_indexes[num_cum_xs] + prepend = torch.zeros(1, dtype=torch.long, device=self.device) + new_num_xs = torch.diff(new_num_cum_xs, prepend=prepend) + return new_num_xs + + def data_mask(self, node_index=None, edge_index=None, graph_index=None, include=None, exclude=None): + data_dict, meta_dict = self.data_by_meta(include, exclude) + node_mapping = None + edge_mapping = None + graph_mapping = None + for k, v in data_dict.items(): + for type in meta_dict[k]: + if type == "node" and node_index is not None: + v = v[node_index] + elif type == "edge" and edge_index is not None: + v = v[edge_index] + elif type == "graph" and graph_index is not None: + v = v[graph_index] + elif type == "node reference" and node_index is not None: + if node_mapping is None: + node_mapping = self._get_mapping(node_index, self.num_node) + v = node_mapping[v] + elif type == "edge reference" and edge_index is not None: + if edge_mapping is None: + edge_mapping = self._get_mapping(edge_index, self.num_edge) + v = edge_mapping[v] + elif type == "graph reference" and graph_index is not None: + if graph_mapping is None: + graph_mapping = self._get_mapping(graph_index, self.batch_size) + v = graph_mapping[v] + data_dict[k] = v + + return data_dict, meta_dict + + def __getitem__(self, index): + # why do we check tuple? + # case 1: x[0, 1] is parsed as (0, 1) + # case 2: x[[0, 1]] is parsed as [0, 1] + if not isinstance(index, tuple): + index = (index,) + + if isinstance(index[0], int): + item = self.get_item(index[0]) + if len(index) > 1: + item = item[index[1:]] + return item + if len(index) > 1: + raise ValueError("Complex indexing is not supported for PackedGraph") + + index = self._standarize_index(index[0], self.batch_size) + count = index.bincount(minlength=self.batch_size) + if self.batch_size > 0 and count.max() > 1: + graph = self.repeat_interleave(count) + index_order = index.argsort() + order = torch.zeros_like(index) + order[index_order] = torch.arange(len(index), dtype=torch.long, device=self.device) + return graph.subbatch(order) + + return self.subbatch(index) + + def __len__(self): + return len(self.num_nodes) + + def full(self): + """ + Return a pack of fully connected graphs. + + This is useful for computing node-pair-wise features. + The computation can be implemented as message passing over a fully connected graph. + + Returns: + PackedGraph + """ + # TODO: more efficient implementation? + graphs = self.unpack() + graphs = [graph.full() for graph in graphs] + return graphs[0].pack(graphs) + + @utils.cached_property + def node2graph(self): + """Node id to graph id mapping.""" + node2graph = torch.repeat_interleave(self.num_nodes) + return node2graph + + @utils.cached_property + def edge2graph(self): + """Edge id to graph id mapping.""" + edge2graph = torch.repeat_interleave(self.num_edges) + return edge2graph + + @property + def batch_size(self): + """Batch size.""" + return len(self.num_nodes) + + def node_mask(self, index, compact=False): + """ + Return a masked packed graph based on the specified nodes. + + Note the compact option is only applied to node ids but not graph ids. + To generate compact graph ids, use :meth:`subbatch`. + + Parameters: + index (array_like): node index + compact (bool, optional): compact node ids or not + + Returns: + PackedGraph + """ + index = self._standarize_index(index, self.num_node) + mapping = -torch.ones(self.num_node, dtype=torch.long, device=self.device) + if compact: + mapping[index] = torch.arange(len(index), device=self.device) + num_nodes = self._get_num_xs(index, self.num_cum_nodes) + offsets = self._get_offsets(num_nodes, self.num_edges) + else: + mapping[index] = index + num_nodes = self.num_nodes + offsets = self._offsets + + edge_list = self.edge_list.clone() + edge_list[:, :2] = mapping[edge_list[:, :2]] + edge_index = (edge_list[:, :2] >= 0).all(dim=-1) + num_edges = self._get_num_xs(edge_index, self.num_cum_edges) + + if compact: + data_dict, meta_dict = self.data_mask(index, edge_index) + else: + data_dict, meta_dict = self.data_mask(edge_index=edge_index) + + return type(self)(edge_list[edge_index], edge_weight=self.edge_weight[edge_index], num_nodes=num_nodes, + num_edges=num_edges, num_relation=self.num_relation, offsets=offsets[edge_index], + meta_dict=meta_dict, **data_dict) + + def edge_mask(self, index): + """ + Return a masked packed graph based on the specified edges. + + Parameters: + index (array_like): edge index + + Returns: + PackedGraph + """ + index = self._standarize_index(index, self.num_edge) + data_dict, meta_dict = self.data_mask(edge_index=index) + num_edges = self._get_num_xs(index, self.num_cum_edges) + + return type(self)(self.edge_list[index], edge_weight=self.edge_weight[index], num_nodes=self.num_nodes, + num_edges=num_edges, num_relation=self.num_relation, offsets=self._offsets[index], + meta_dict=meta_dict, **data_dict) + + def graph_mask(self, index, compact=False): + """ + Return a masked packed graph based on the specified graphs. + + This function can also be used to re-order the graphs. + + Parameters: + index (array_like): graph index + compact (bool, optional): compact graph ids or not + + Returns: + PackedGraph + """ + index = self._standarize_index(index, self.batch_size) + graph_mapping = -torch.ones(self.batch_size, dtype=torch.long, device=self.device) + graph_mapping[index] = torch.arange(len(index), device=self.device) + + node_index = graph_mapping[self.node2graph] >= 0 + node_index = self._standarize_index(node_index, self.num_node) + mapping = -torch.ones(self.num_node, dtype=torch.long, device=self.device) + if compact: + key = graph_mapping[self.node2graph[node_index]] * self.num_node + node_index + order = key.argsort() + node_index = node_index[order] + mapping[node_index] = torch.arange(len(node_index), device=self.device) + num_nodes = self.num_nodes[index] + else: + mapping[node_index] = node_index + num_nodes = torch.zeros_like(self.num_nodes) + num_nodes[index] = self.num_nodes[index] + + edge_list = self.edge_list.clone() + edge_list[:, :2] = mapping[edge_list[:, :2]] + edge_index = (edge_list[:, :2] >= 0).all(dim=-1) + edge_index = self._standarize_index(edge_index, self.num_edge) + if compact: + key = graph_mapping[self.edge2graph[edge_index]] * self.num_edge + edge_index + order = key.argsort() + edge_index = edge_index[order] + num_edges = self.num_edges[index] + else: + num_edges = torch.zeros_like(self.num_edges) + num_edges[index] = self.num_edges[index] + offsets = self._get_offsets(num_nodes, num_edges) + + if compact: + data_dict, meta_dict = self.data_mask(node_index, edge_index, graph_index=index) + else: + data_dict, meta_dict = self.data_mask(edge_index=edge_index) + + return type(self)(edge_list[edge_index], edge_weight=self.edge_weight[edge_index], num_nodes=num_nodes, + num_edges=num_edges, num_relation=self.num_relation, offsets=offsets, + meta_dict=meta_dict, **data_dict) + + def subbatch(self, index): + """ + Return a subbatch based on the specified graphs. + Equivalent to :meth:`graph_mask(index, compact=True) <graph_mask>`. + + Parameters: + index (array_like): graph index + + Returns: + PackedGraph + + See also: + :meth:`PackedGraph.graph_mask` + """ + return self.graph_mask(index, compact=True) + + def line_graph(self): + """ + Construct a packed line graph of this packed graph. + The node features of the line graphs are inherited from the edge features of the original graphs. + + In the line graph, each node corresponds to an edge in the original graph. + For a pair of edges (a, b) and (b, c) that share the same intermediate node in the original graph, + there is a directed edge (a, b) -> (b, c) in the line graph. + + Returns: + PackedGraph + """ + node_in, node_out = self.edge_list.t()[:2] + edge_index = torch.arange(self.num_edge, device=self.device) + edge_in = edge_index[node_out.argsort()] + edge_out = edge_index[node_in.argsort()] + + degree_in = node_in.bincount(minlength=self.num_node) + degree_out = node_out.bincount(minlength=self.num_node) + size = degree_out * degree_in + starts = (size.cumsum(0) - size).repeat_interleave(size) + range = torch.arange(size.sum(), device=self.device) + # each node u has degree_out[u] * degree_in[u] local edges + local_index = range - starts + local_inner_size = degree_in.repeat_interleave(size) + edge_in_offset = (degree_out.cumsum(0) - degree_out).repeat_interleave(size) + edge_out_offset = (degree_in.cumsum(0) - degree_in).repeat_interleave(size) + edge_in_index = torch.div(local_index, local_inner_size, rounding_mode="floor") + edge_in_offset + edge_out_index = local_index % local_inner_size + edge_out_offset + + edge_in = edge_in[edge_in_index] + edge_out = edge_out[edge_out_index] + edge_list = torch.stack([edge_in, edge_out], dim=-1) + node_feature = getattr(self, "edge_feature", None) + num_nodes = self.num_edges + num_edges = scatter_add(size, self.node2graph, dim=0, dim_size=self.batch_size) + offsets = self._get_offsets(num_nodes, num_edges) + + return PackedGraph(edge_list, num_nodes=num_nodes, num_edges=num_edges, offsets=offsets, + node_feature=node_feature) + + def undirected(self, add_inverse=False): + """ + Flip all the edges to create undirected graphs. + + For knowledge graphs, the flipped edges can either have the original relation or an inverse relation. + The inverse relation for relation :math:`r` is defined as :math:`|R| + r`. + + Parameters: + add_inverse (bool, optional): whether to use inverse relations for flipped edges + """ + edge_list = self.edge_list.clone() + edge_list[:, :2] = edge_list[:, :2].flip(1) + num_relation = self.num_relation + if num_relation and add_inverse: + edge_list[:, 2] += num_relation + num_relation = num_relation * 2 + edge_list = torch.stack([self.edge_list, edge_list], dim=1).flatten(0, 1) + offsets = self._offsets.unsqueeze(-1).expand(-1, 2).flatten() + + index = torch.arange(self.num_edge, device=self.device).unsqueeze(-1).expand(-1, 2).flatten() + data_dict, meta_dict = self.data_mask(edge_index=index, exclude="edge reference") + + return type(self)(edge_list, edge_weight=self.edge_weight[index], num_nodes=self.num_nodes, + num_edges=self.num_edges * 2, num_relation=num_relation, offsets=offsets, + meta_dict=meta_dict, **data_dict) + + def detach(self): + """ + Detach this packed graph. + """ + return type(self)(self.edge_list.detach(), edge_weight=self.edge_weight.detach(), + num_nodes=self.num_nodes, num_edges=self.num_edges, num_relation=self.num_relation, + offsets=self._offsets, meta_dict=self.meta_dict, **utils.detach(self.data_dict)) + + def clone(self): + """ + Clone this packed graph. + """ + return type(self)(self.edge_list.clone(), edge_weight=self.edge_weight.clone(), + num_nodes=self.num_nodes, num_edges=self.num_edges, num_relation=self.num_relation, + offsets=self._offsets, meta_dict=self.meta_dict, **utils.clone(self.data_dict)) + + def cuda(self, *args, **kwargs): + """ + Return a copy of this packed graph in CUDA memory. + + This is a non-op if the graph is already on the correct device. + """ + edge_list = self.edge_list.cuda(*args, **kwargs) + + if edge_list is self.edge_list: + return self + else: + return type(self)(edge_list, edge_weight=self.edge_weight, + num_nodes=self.num_nodes, num_edges=self.num_edges, num_relation=self.num_relation, + offsets=self._offsets, meta_dict=self.meta_dict, + **utils.cuda(self.data_dict, *args, **kwargs)) + + def cpu(self): + """ + Return a copy of this packed graph in CPU memory. + + This is a non-op if the graph is already in CPU memory. + """ + edge_list = self.edge_list.cpu() + + if edge_list is self.edge_list: + return self + else: + return type(self)(edge_list, edge_weight=self.edge_weight, + num_nodes=self.num_nodes, num_edges=self.num_edges, num_relation=self.num_relation, + offsets=self._offsets, meta_dict=self.meta_dict, **utils.cpu(self.data_dict)) + + def __repr__(self): + fields = ["batch_size=%d" % self.batch_size, + "num_nodes=%s" % pretty.long_array(self.num_nodes.tolist()), + "num_edges=%s" % pretty.long_array(self.num_edges.tolist())] + if self.num_relation is not None: + fields.append("num_relation=%d" % self.num_relation) + if self.device.type != "cpu": + fields.append("device='%s'" % self.device) + return "%s(%s)" % (self.__class__.__name__, ", ".join(fields)) + + def visualize(self, titles=None, save_file=None, figure_size=(3, 3), layout="spring", num_row=None, num_col=None): + """ + Visualize the packed graphs with matplotlib. + + Parameters: + titles (list of str, optional): title for each graph. Default is the ID of each graph. + save_file (str, optional): ``png`` or ``pdf`` file to save visualization. + If not provided, show the figure in window. + figure_size (tuple of int, optional): width and height of the figure + layout (str, optional): graph layout + num_row (int, optional): number of rows in the figure + num_col (int, optional): number of columns in the figure + + See also: + `NetworkX graph layout`_ + + .. _NetworkX graph layout: + https://networkx.github.io/documentation/stable/reference/drawing.html#module-networkx.drawing.layout + """ + if titles is None: + graph = self.get_item(0) + titles = ["%s %d" % (type(graph).__name__, i) for i in range(self.batch_size)] + if num_col is None: + if num_row is None: + num_col = math.ceil(self.batch_size ** 0.5) + else: + num_col = math.ceil(self.batch_size / num_row) + if num_row is None: + num_row = math.ceil(self.batch_size / num_col) + + figure_size = (num_col * figure_size[0], num_row * figure_size[1]) + fig = plt.figure(figsize=figure_size) + + for i in range(self.batch_size): + graph = self.get_item(i) + ax = fig.add_subplot(num_row, num_col, i + 1) + graph.visualize(title=titles[i], ax=ax, layout=layout) + # remove the space of axis labels + fig.tight_layout() + + if save_file: + fig.savefig(save_file) + else: + fig.show() + + +Graph.packed_type = PackedGraph + + +def cat(graphs): + for i, graph in enumerate(graphs): + if not isinstance(graph, PackedGraph): + graphs[i] = graph.pack([graph]) + + edge_list = torch.cat([graph.edge_list for graph in graphs]) + pack_num_nodes = torch.stack([graph.num_node for graph in graphs]) + pack_num_edges = torch.stack([graph.num_edge for graph in graphs]) + pack_num_cum_edges = pack_num_edges.cumsum(0) + graph_index = pack_num_cum_edges < len(edge_list) + pack_offsets = scatter_add(pack_num_nodes[graph_index], pack_num_cum_edges[graph_index], + dim_size=len(edge_list)) + pack_offsets = pack_offsets.cumsum(0) + + edge_list[:, :2] += pack_offsets.unsqueeze(-1) + offsets = torch.cat([graph._offsets for graph in graphs]) + pack_offsets + + edge_weight = torch.cat([graph.edge_weight for graph in graphs]) + num_nodes = torch.cat([graph.num_nodes for graph in graphs]) + num_edges = torch.cat([graph.num_edges for graph in graphs]) + num_relation = graphs[0].num_relation + assert all(graph.num_relation == num_relation for graph in graphs) + + # only keep attributes that exist in all graphs + # TODO: this interface is not safe. re-design the interface + keys = set(graphs[0].meta_dict.keys()) + for graph in graphs: + keys = keys.intersection(graph.meta_dict.keys()) + + meta_dict = {k: graphs[0].meta_dict[k] for k in keys} + data_dict = {} + for k in keys: + data_dict[k] = torch.cat([graph.data_dict[k] for graph in graphs]) + + return type(graphs[0])(edge_list, edge_weight=edge_weight, + num_nodes=num_nodes, num_edges=num_edges, num_relation=num_relation, offsets=offsets, + meta_dict=meta_dict, **data_dict) \ No newline at end of file