Switch to unified view

a b/torchdrug/tasks/reasoning.py
1
import torch
2
from torch import nn
3
from torch.nn import functional as F
4
from torch.utils import data as torch_data
5
6
from torchdrug import core, tasks
7
from torchdrug.layers import functional
8
from torchdrug.core import Registry as R
9
10
11
@R.register("tasks.KnowledgeGraphCompletion")
12
class KnowledgeGraphCompletion(tasks.Task, core.Configurable):
13
    """
14
    Knowledge graph completion task.
15
16
    This class provides routines for the family of knowledge graph embedding models.
17
18
    Parameters:
19
        model (nn.Module): knowledge graph completion model
20
        criterion (str, list or dict, optional): training criterion(s). For dict, the keys are criterions and the values
21
            are the corresponding weights. Available criterions are ``bce``, ``ce`` and ``ranking``.
22
        metric (str or list of str, optional): metric(s). Available metrics are ``mr``, ``mrr`` and ``hits@K``.
23
        num_negative (int, optional): number of negative samples per positive sample
24
        margin (float, optional): margin in ranking criterion
25
        adversarial_temperature (float, optional): temperature for self-adversarial negative sampling.
26
            Set ``0`` to disable self-adversarial negative sampling.
27
        strict_negative (bool, optional): use strict negative sampling or not
28
        fact_ratio (float, optional): split the training set into facts and labels.
29
            Set ``None`` to use the whole training set as both facts and labels.
30
        sample_weight (bool, optional): whether to down-weight triplets from entities of large degrees
31
        filtered_ranking (bool, optional): use filtered or unfiltered ranking for evaluation
32
        full_batch_eval (bool, optional): whether to feed test negative samples by full batch or mini batch.
33
            Full batch speeds up evaluation significantly, but may cause OOM problems for some models and datasets.
34
    """
35
    _option_members = {"criterion", "metric"}
36
37
    def __init__(self, model, criterion="bce", metric=("mr", "mrr", "hits@1", "hits@3", "hits@10"),
                 num_negative=128, margin=6, adversarial_temperature=0, strict_negative=True, fact_ratio=None,
                 sample_weight=True, filtered_ranking=True, full_batch_eval=False):
        """Store the model and the task hyperparameters (see the class docstring for their meaning)."""
        super().__init__()
        self.model = model

        # training objective(s) and evaluation metric(s)
        self.criterion = criterion
        self.metric = metric

        # negative sampling and loss hyperparameters
        self.num_negative = num_negative
        self.strict_negative = strict_negative
        self.adversarial_temperature = adversarial_temperature
        self.margin = margin
        self.sample_weight = sample_weight

        # data split and evaluation behavior
        self.fact_ratio = fact_ratio
        self.filtered_ranking = filtered_ranking
        self.full_batch_eval = full_batch_eval
52
53
    def preprocess(self, train_set, valid_set, test_set):
        """
        Collect dataset statistics and register the graphs used by this task.

        Registers the buffers ``graph`` (the full dataset graph) and ``fact_graph`` (the graph without
        validation / test edges, and without the label part of the training set when ``fact_ratio`` is set).
        When ``sample_weight`` is enabled, also registers the ``degree_hr`` and ``degree_tr`` count tables.

        Parameters:
            train_set: training split; its ``indices`` attribute is read when ``fact_ratio`` is set
            valid_set: validation split; its ``indices`` attribute is read
            test_set: test split; its ``indices`` attribute is read

        Returns:
            tuple: ``(train_set, valid_set, test_set)``, where ``train_set`` is replaced by the label subset
            when ``fact_ratio`` is set
        """
        full_set = train_set.dataset if isinstance(train_set, torch_data.Subset) else train_set
        self.num_entity = full_set.num_entity
        self.num_relation = full_set.num_relation
        self.register_buffer("graph", full_set.graph)

        # start from "every edge is a fact" and knock out the held-out edges
        is_fact = torch.ones(len(full_set), dtype=torch.bool)
        for held_out in (valid_set, test_set):
            is_fact[held_out.indices] = 0
        if self.fact_ratio:
            # a random ``fact_ratio`` share of the training triplets stays as facts,
            # the remainder becomes the labels to train on
            num_fact = int(len(train_set) * self.fact_ratio)
            label_index = torch.randperm(len(train_set))[num_fact:]
            train_edge_index = torch.tensor(train_set.indices)
            is_fact[train_edge_index[label_index]] = 0
            train_set = torch_data.Subset(train_set, label_index)
        self.register_buffer("fact_graph", full_set.graph.edge_mask(is_fact))

        if self.sample_weight:
            # count how often each (head, relation) and (tail, relation) pair occurs in the training triplets
            table_shape = (self.num_entity, self.num_relation)
            degree_hr = torch.zeros(*table_shape, dtype=torch.long)
            degree_tr = torch.zeros(*table_shape, dtype=torch.long)
            for head, tail, relation in train_set:
                degree_hr[head, relation] += 1
                degree_tr[tail, relation] += 1
            self.register_buffer("degree_hr", degree_hr)
            self.register_buffer("degree_tr", degree_tr)

        return train_set, valid_set, test_set
82
83
    def forward(self, batch, all_loss=None, metric=None):
        """
        Compute the training loss for a batch of positive triplets.

        Parameters:
            batch (Tensor): positive triplets; each row is unpacked as ``(head, tail, relation)``
            all_loss (Tensor, optional): ignored, a fresh zero accumulator is always created
            metric (dict, optional): ignored, a fresh dict is always created

        Returns:
            (Tensor, dict): the weighted sum of all criterion losses, and a dict with the value of
            each individual criterion

        Raises:
            ValueError: if an unknown criterion name is configured
        """
        all_loss = torch.tensor(0, dtype=torch.float32, device=self.device)
        metric = {}

        # column 0 of ``pred`` is the positive sample, the remaining columns are negative samples
        pred = self.predict(batch, all_loss, metric)
        pos_h_index, pos_t_index, pos_r_index = batch.t()

        for criterion, weight in self.criterion.items():
            if criterion == "bce":
                target = torch.zeros_like(pred)
                target[:, 0] = 1
                loss = F.binary_cross_entropy_with_logits(pred, target, reduction="none")

                neg_weight = torch.ones_like(pred)
                if self.adversarial_temperature > 0:
                    # self-adversarial sampling: harder negatives get larger weights (no gradient through them)
                    with torch.no_grad():
                        neg_weight[:, 1:] = F.softmax(pred[:, 1:] / self.adversarial_temperature, dim=-1)
                else:
                    neg_weight[:, 1:] = 1 / self.num_negative
                loss = (loss * neg_weight).sum(dim=-1) / neg_weight.sum(dim=-1)
            elif criterion == "ce":
                # the positive sample is always class 0
                target = torch.zeros(len(pred), dtype=torch.long, device=self.device)
                loss = F.cross_entropy(pred, target, reduction="none")
            elif criterion == "ranking":
                positive = pred[:, :1]
                negative = pred[:, 1:]
                target = torch.ones_like(negative)
                # Keep one loss value per positive sample, like the other criteria, so that the
                # sample weights below actually take effect. With the default "mean" reduction the
                # loss is a scalar and the weighted average degenerates to the plain mean.
                loss = F.margin_ranking_loss(positive, negative, target, margin=self.margin, reduction="none")
                loss = loss.mean(dim=-1)
            else:
                raise ValueError("Unknown criterion `%s`" % criterion)

            if self.sample_weight:
                # down-weight triplets whose (head, relation) / (tail, relation) pairs are frequent
                sample_weight = self.degree_hr[pos_h_index, pos_r_index] * self.degree_tr[pos_t_index, pos_r_index]
                sample_weight = 1 / sample_weight.float().sqrt()
                loss = (loss * sample_weight).sum() / sample_weight.sum()
            else:
                loss = loss.mean()

            name = tasks._get_criterion_name(criterion)
            metric[name] = loss
            all_loss += loss * weight

        return all_loss, metric
127
128
    def predict(self, batch, all_loss=None, metric=None):
        """
        Score candidate triplets for a batch of positive triplets.

        When ``all_loss`` is ``None`` (test mode), every entity is scored as a candidate tail and as a
        candidate head for each positive triplet. The two results are stacked along dimension 1 and
        moved to CPU. Otherwise (train mode), each positive triplet is scored together with
        ``num_negative`` negative samples: tails are corrupted in the first half of the batch and heads
        in the second half, and column 0 holds the positive triplet.

        Parameters:
            batch (Tensor): positive triplets; each row is unpacked as ``(head, tail, relation)``
            all_loss (Tensor, optional): loss accumulator passed through to the model; ``None`` selects test mode
            metric (dict, optional): metric dict passed through to the model

        Returns:
            Tensor: the scores returned by the model (stacked and moved to CPU in test mode)
        """
        pos_h_index, pos_t_index, pos_r_index = batch.t()
        batch_size = len(batch)

        if all_loss is None:
            # test
            all_index = torch.arange(self.num_entity, device=self.device)
            t_preds = []
            h_preds = []
            # entities are scored in chunks to bound memory unless full batch evaluation is requested
            num_negative = self.num_entity if self.full_batch_eval else self.num_negative
            for neg_index in all_index.split(num_negative):
                num_candidate = len(neg_index)
                # Explicit broadcasting instead of torch.meshgrid: same (batch_size, num_candidate)
                # views, but no dependence on the ``indexing`` argument that differs across versions.
                r_index = pos_r_index.unsqueeze(-1).expand(-1, num_candidate)
                h_index = pos_h_index.unsqueeze(-1).expand(-1, num_candidate)
                t_index = neg_index.unsqueeze(0).expand(batch_size, -1)
                t_pred = self.model(self.fact_graph, h_index, t_index, r_index, all_loss=all_loss, metric=metric)
                t_preds.append(t_pred)
            t_pred = torch.cat(t_preds, dim=-1)
            for neg_index in all_index.split(num_negative):
                num_candidate = len(neg_index)
                r_index = pos_r_index.unsqueeze(-1).expand(-1, num_candidate)
                t_index = pos_t_index.unsqueeze(-1).expand(-1, num_candidate)
                h_index = neg_index.unsqueeze(0).expand(batch_size, -1)
                h_pred = self.model(self.fact_graph, h_index, t_index, r_index, all_loss=all_loss, metric=metric)
                h_preds.append(h_pred)
            h_pred = torch.cat(h_preds, dim=-1)
            pred = torch.stack([t_pred, h_pred], dim=1)
            # in case of GPU OOM
            pred = pred.cpu()
        else:
            # train
            if self.strict_negative:
                neg_index = self._strict_negative(pos_h_index, pos_t_index, pos_r_index)
            else:
                neg_index = torch.randint(self.num_entity, (batch_size, self.num_negative), device=self.device)
            # column 0 keeps the positive triplet, columns 1.. are overwritten with negative samples
            h_index = pos_h_index.unsqueeze(-1).repeat(1, self.num_negative + 1)
            t_index = pos_t_index.unsqueeze(-1).repeat(1, self.num_negative + 1)
            r_index = pos_r_index.unsqueeze(-1).repeat(1, self.num_negative + 1)
            # corrupt tails in the first half of the batch and heads in the second half
            t_index[:batch_size // 2, 1:] = neg_index[:batch_size // 2]
            h_index[batch_size // 2:, 1:] = neg_index[batch_size // 2:]
            pred = self.model(self.fact_graph, h_index, t_index, r_index, all_loss=all_loss, metric=metric)

        return pred
167
168
    def target(self, batch):
        """
        Build the evaluation target for a batch of positive triplets.

        Parameters:
            batch (Tensor): positive triplets; each row is unpacked as ``(head, tail, relation)``

        Returns:
            (Tensor, Tensor): a boolean mask stacked as ``[tail mask, head mask]`` along dimension 1,
            which is ``False`` for every entity matched in ``self.graph`` for the corresponding query,
            and the true indices stacked as ``[tail, head]`` along dimension 1. Both are on CPU.
        """
        # test target
        num_sample = len(batch)
        pos_h_index, pos_t_index, pos_r_index = batch.t()
        # -1 marks the free slot of a query pattern (presumably a wildcard for ``graph.match``)
        wildcard = -torch.ones_like(pos_h_index)

        def unknown_mask(pattern, column):
            # entities found in column ``column`` of the matched edges are cleared from the mask
            edge_index, num_truth = self.graph.match(pattern)
            truth_index = self.graph.edge_list[edge_index, column]
            row_index = torch.repeat_interleave(num_truth)
            keep = torch.ones(num_sample, self.num_entity, dtype=torch.bool, device=self.device)
            keep[row_index, truth_index] = 0
            return keep

        t_mask = unknown_mask(torch.stack([pos_h_index, wildcard, pos_r_index], dim=-1), 1)
        h_mask = unknown_mask(torch.stack([wildcard, pos_t_index, pos_r_index], dim=-1), 0)

        mask = torch.stack([t_mask, h_mask], dim=1)
        truth = torch.stack([pos_t_index, pos_h_index], dim=1)

        # in case of GPU OOM
        return mask.cpu(), truth.cpu()
193
194
    def evaluate(self, pred, target):
        """
        Compute ranking metrics from predictions and targets.

        The rank of a true entity is one plus the number of competing entities whose score is greater
        than or equal to its score, so ties count against the true entity. With filtered ranking the
        competitors are the entities selected by the mask in ``target``. With unfiltered ranking they
        are all entities except the true entity itself.

        Parameters:
            pred (Tensor): scores over all entities in the last dimension
            target (tuple): ``(mask, index)`` pair; ``index`` holds the position of the true entity along
                the last dimension of ``pred``, and ``mask`` is a boolean tensor broadcastable to ``pred``

        Returns:
            dict: metric name to scalar score

        Raises:
            ValueError: if an unknown metric name is configured
        """
        mask, target = target

        pos_pred = pred.gather(-1, target.unsqueeze(-1))
        if not self.filtered_ranking:
            # The true entity must not compete with itself. Without this, its own score satisfies
            # ``pos_pred <= pred`` and every rank is shifted by one (the best rank becomes 2).
            entity_index = torch.arange(pred.shape[-1], device=pred.device)
            mask = entity_index != target.unsqueeze(-1)
        ranking = torch.sum((pos_pred <= pred) & mask, dim=-1) + 1

        metric = {}
        for _metric in self.metric:
            if _metric == "mr":
                score = ranking.float().mean()
            elif _metric == "mrr":
                score = (1 / ranking.float()).mean()
            elif _metric.startswith("hits@"):
                # the text after "hits@" is the rank threshold K
                threshold = int(_metric[5:])
                score = (ranking <= threshold).float().mean()
            else:
                raise ValueError("Unknown metric `%s`" % _metric)

            name = tasks._get_metric_name(_metric)
            metric[name] = score

        return metric
219
220
    def visualize(self, batch):
        """
        Pass the triplets in ``batch`` to the model's ``visualize`` method, evaluated on the fact graph.

        Parameters:
            batch (Tensor): triplets; each row is unpacked as ``(head, tail, relation)``
        """
        pos_h_index, pos_t_index, pos_r_index = batch.t()
        result = self.model.visualize(self.fact_graph, pos_h_index, pos_t_index, pos_r_index)
        return result
223
224
    @torch.no_grad()
    def _strict_negative(self, pos_h_index, pos_t_index, pos_r_index):
        """
        Sample negative entities that do not form a matched edge in the fact graph.

        Negative tails are sampled for the first half of the batch and negative heads for the second
        half. Entities returned by ``self.fact_graph.match`` for the corresponding query are excluded
        from the candidates.

        Parameters:
            pos_h_index (Tensor): head indices of the positive triplets
            pos_t_index (Tensor): tail indices of the positive triplets
            pos_r_index (Tensor): relation indices of the positive triplets

        Returns:
            Tensor: the sampled negative tails followed by the sampled negative heads, concatenated
            along the first dimension
        """
        half = len(pos_h_index) // 2
        # -1 marks the free slot of a query pattern (presumably a wildcard for ``match``)
        wildcard = -torch.ones_like(pos_h_index)

        def sample_negative(pattern, column):
            # clear every entity found in column ``column`` of the matched edges from the candidates
            edge_index, num_truth = self.fact_graph.match(pattern)
            truth_index = self.fact_graph.edge_list[edge_index, column]
            row_index = torch.repeat_interleave(num_truth)
            allowed = torch.ones(len(pattern), self.num_entity, dtype=torch.bool, device=self.device)
            allowed[row_index, truth_index] = 0
            # flatten the remaining candidates row by row, with the number of candidates per row
            candidate = allowed.nonzero()[:, 1]
            num_candidate = allowed.sum(dim=-1)
            return functional.variadic_sample(candidate, num_candidate, self.num_negative)

        # tails first, then heads, so the random draws happen in a fixed order
        t_pattern = torch.stack([pos_h_index, wildcard, pos_r_index], dim=-1)[:half]
        neg_t_index = sample_negative(t_pattern, 1)

        h_pattern = torch.stack([wildcard, pos_t_index, pos_r_index], dim=-1)[half:]
        neg_h_index = sample_negative(h_pattern, 0)

        return torch.cat([neg_t_index, neg_h_index])