|
a |
|
b/torchdrug/tasks/reasoning.py |
|
|
1 |
import torch |
|
|
2 |
from torch import nn |
|
|
3 |
from torch.nn import functional as F |
|
|
4 |
from torch.utils import data as torch_data |
|
|
5 |
|
|
|
6 |
from torchdrug import core, tasks |
|
|
7 |
from torchdrug.layers import functional |
|
|
8 |
from torchdrug.core import Registry as R |
|
|
9 |
|
|
|
10 |
|
|
|
11 |
@R.register("tasks.KnowledgeGraphCompletion") |
|
|
12 |
class KnowledgeGraphCompletion(tasks.Task, core.Configurable): |
|
|
13 |
""" |
|
|
14 |
Knowledge graph completion task. |
|
|
15 |
|
|
|
16 |
This class provides routines for the family of knowledge graph embedding models. |
|
|
17 |
|
|
|
18 |
Parameters: |
|
|
19 |
model (nn.Module): knowledge graph completion model |
|
|
20 |
criterion (str, list or dict, optional): training criterion(s). For dict, the keys are criterions and the values |
|
|
21 |
are the corresponding weights. Available criterions are ``bce``, ``ce`` and ``ranking``. |
|
|
22 |
metric (str or list of str, optional): metric(s). Available metrics are ``mr``, ``mrr`` and ``hits@K``. |
|
|
23 |
num_negative (int, optional): number of negative samples per positive sample |
|
|
24 |
margin (float, optional): margin in ranking criterion |
|
|
25 |
adversarial_temperature (float, optional): temperature for self-adversarial negative sampling. |
|
|
26 |
Set ``0`` to disable self-adversarial negative sampling. |
|
|
27 |
strict_negative (bool, optional): use strict negative sampling or not |
|
|
28 |
fact_ratio (float, optional): split the training set into facts and labels. |
|
|
29 |
Set ``None`` to use the whole training set as both facts and labels. |
|
|
30 |
sample_weight (bool, optional): whether to down-weight triplets from entities of large degrees |
|
|
31 |
filtered_ranking (bool, optional): use filtered or unfiltered ranking for evaluation |
|
|
32 |
full_batch_eval (bool, optional): whether to feed test negative samples by full batch or mini batch. |
|
|
33 |
Full batch speeds up evaluation significantly, but may cause OOM problems for some models and datasets. |
|
|
34 |
""" |
|
|
35 |
_option_members = {"criterion", "metric"} |
|
|
36 |
|
|
|
37 |
def __init__(self, model, criterion="bce", metric=("mr", "mrr", "hits@1", "hits@3", "hits@10"), |
|
|
38 |
num_negative=128, margin=6, adversarial_temperature=0, strict_negative=True, fact_ratio=None, |
|
|
39 |
sample_weight=True, filtered_ranking=True, full_batch_eval=False): |
|
|
40 |
super(KnowledgeGraphCompletion, self).__init__() |
|
|
41 |
self.model = model |
|
|
42 |
self.criterion = criterion |
|
|
43 |
self.metric = metric |
|
|
44 |
self.num_negative = num_negative |
|
|
45 |
self.margin = margin |
|
|
46 |
self.adversarial_temperature = adversarial_temperature |
|
|
47 |
self.strict_negative = strict_negative |
|
|
48 |
self.fact_ratio = fact_ratio |
|
|
49 |
self.sample_weight = sample_weight |
|
|
50 |
self.filtered_ranking = filtered_ranking |
|
|
51 |
self.full_batch_eval = full_batch_eval |
|
|
52 |
|
|
|
53 |
def preprocess(self, train_set, valid_set, test_set): |
|
|
54 |
if isinstance(train_set, torch_data.Subset): |
|
|
55 |
dataset = train_set.dataset |
|
|
56 |
else: |
|
|
57 |
dataset = train_set |
|
|
58 |
self.num_entity = dataset.num_entity |
|
|
59 |
self.num_relation = dataset.num_relation |
|
|
60 |
self.register_buffer("graph", dataset.graph) |
|
|
61 |
fact_mask = torch.ones(len(dataset), dtype=torch.bool) |
|
|
62 |
fact_mask[valid_set.indices] = 0 |
|
|
63 |
fact_mask[test_set.indices] = 0 |
|
|
64 |
if self.fact_ratio: |
|
|
65 |
length = int(len(train_set) * self.fact_ratio) |
|
|
66 |
index = torch.randperm(len(train_set))[length:] |
|
|
67 |
train_indices = torch.tensor(train_set.indices) |
|
|
68 |
fact_mask[train_indices[index]] = 0 |
|
|
69 |
train_set = torch_data.Subset(train_set, index) |
|
|
70 |
self.register_buffer("fact_graph", dataset.graph.edge_mask(fact_mask)) |
|
|
71 |
|
|
|
72 |
if self.sample_weight: |
|
|
73 |
degree_hr = torch.zeros(self.num_entity, self.num_relation, dtype=torch.long) |
|
|
74 |
degree_tr = torch.zeros(self.num_entity, self.num_relation, dtype=torch.long) |
|
|
75 |
for h, t, r in train_set: |
|
|
76 |
degree_hr[h, r] += 1 |
|
|
77 |
degree_tr[t, r] += 1 |
|
|
78 |
self.register_buffer("degree_hr", degree_hr) |
|
|
79 |
self.register_buffer("degree_tr", degree_tr) |
|
|
80 |
|
|
|
81 |
return train_set, valid_set, test_set |
|
|
82 |
|
|
|
83 |
def forward(self, batch, all_loss=None, metric=None): |
|
|
84 |
"""""" |
|
|
85 |
all_loss = torch.tensor(0, dtype=torch.float32, device=self.device) |
|
|
86 |
metric = {} |
|
|
87 |
|
|
|
88 |
pred = self.predict(batch, all_loss, metric) |
|
|
89 |
pos_h_index, pos_t_index, pos_r_index = batch.t() |
|
|
90 |
|
|
|
91 |
for criterion, weight in self.criterion.items(): |
|
|
92 |
if criterion == "bce": |
|
|
93 |
target = torch.zeros_like(pred) |
|
|
94 |
target[:, 0] = 1 |
|
|
95 |
loss = F.binary_cross_entropy_with_logits(pred, target, reduction="none") |
|
|
96 |
|
|
|
97 |
neg_weight = torch.ones_like(pred) |
|
|
98 |
if self.adversarial_temperature > 0: |
|
|
99 |
with torch.no_grad(): |
|
|
100 |
neg_weight[:, 1:] = F.softmax(pred[:, 1:] / self.adversarial_temperature, dim=-1) |
|
|
101 |
else: |
|
|
102 |
neg_weight[:, 1:] = 1 / self.num_negative |
|
|
103 |
loss = (loss * neg_weight).sum(dim=-1) / neg_weight.sum(dim=-1) |
|
|
104 |
elif criterion == "ce": |
|
|
105 |
target = torch.zeros(len(pred), dtype=torch.long, device=self.device) |
|
|
106 |
loss = F.cross_entropy(pred, target, reduction="none") |
|
|
107 |
elif criterion == "ranking": |
|
|
108 |
positive = pred[:, :1] |
|
|
109 |
negative = pred[:, 1:] |
|
|
110 |
target = torch.ones_like(negative) |
|
|
111 |
loss = F.margin_ranking_loss(positive, negative, target, margin=self.margin) |
|
|
112 |
else: |
|
|
113 |
raise ValueError("Unknown criterion `%s`" % criterion) |
|
|
114 |
|
|
|
115 |
if self.sample_weight: |
|
|
116 |
sample_weight = self.degree_hr[pos_h_index, pos_r_index] * self.degree_tr[pos_t_index, pos_r_index] |
|
|
117 |
sample_weight = 1 / sample_weight.float().sqrt() |
|
|
118 |
loss = (loss * sample_weight).sum() / sample_weight.sum() |
|
|
119 |
else: |
|
|
120 |
loss = loss.mean() |
|
|
121 |
|
|
|
122 |
name = tasks._get_criterion_name(criterion) |
|
|
123 |
metric[name] = loss |
|
|
124 |
all_loss += loss * weight |
|
|
125 |
|
|
|
126 |
return all_loss, metric |
|
|
127 |
|
|
|
128 |
def predict(self, batch, all_loss=None, metric=None): |
|
|
129 |
pos_h_index, pos_t_index, pos_r_index = batch.t() |
|
|
130 |
batch_size = len(batch) |
|
|
131 |
|
|
|
132 |
if all_loss is None: |
|
|
133 |
# test |
|
|
134 |
all_index = torch.arange(self.num_entity, device=self.device) |
|
|
135 |
t_preds = [] |
|
|
136 |
h_preds = [] |
|
|
137 |
num_negative = self.num_entity if self.full_batch_eval else self.num_negative |
|
|
138 |
for neg_index in all_index.split(num_negative): |
|
|
139 |
r_index = pos_r_index.unsqueeze(-1).expand(-1, len(neg_index)) |
|
|
140 |
h_index, t_index = torch.meshgrid(pos_h_index, neg_index) |
|
|
141 |
t_pred = self.model(self.fact_graph, h_index, t_index, r_index, all_loss=all_loss, metric=metric) |
|
|
142 |
t_preds.append(t_pred) |
|
|
143 |
t_pred = torch.cat(t_preds, dim=-1) |
|
|
144 |
for neg_index in all_index.split(num_negative): |
|
|
145 |
r_index = pos_r_index.unsqueeze(-1).expand(-1, len(neg_index)) |
|
|
146 |
t_index, h_index = torch.meshgrid(pos_t_index, neg_index) |
|
|
147 |
h_pred = self.model(self.fact_graph, h_index, t_index, r_index, all_loss=all_loss, metric=metric) |
|
|
148 |
h_preds.append(h_pred) |
|
|
149 |
h_pred = torch.cat(h_preds, dim=-1) |
|
|
150 |
pred = torch.stack([t_pred, h_pred], dim=1) |
|
|
151 |
# in case of GPU OOM |
|
|
152 |
pred = pred.cpu() |
|
|
153 |
else: |
|
|
154 |
# train |
|
|
155 |
if self.strict_negative: |
|
|
156 |
neg_index = self._strict_negative(pos_h_index, pos_t_index, pos_r_index) |
|
|
157 |
else: |
|
|
158 |
neg_index = torch.randint(self.num_entity, (batch_size, self.num_negative), device=self.device) |
|
|
159 |
h_index = pos_h_index.unsqueeze(-1).repeat(1, self.num_negative + 1) |
|
|
160 |
t_index = pos_t_index.unsqueeze(-1).repeat(1, self.num_negative + 1) |
|
|
161 |
r_index = pos_r_index.unsqueeze(-1).repeat(1, self.num_negative + 1) |
|
|
162 |
t_index[:batch_size // 2, 1:] = neg_index[:batch_size // 2] |
|
|
163 |
h_index[batch_size // 2:, 1:] = neg_index[batch_size // 2:] |
|
|
164 |
pred = self.model(self.fact_graph, h_index, t_index, r_index, all_loss=all_loss, metric=metric) |
|
|
165 |
|
|
|
166 |
return pred |
|
|
167 |
|
|
|
168 |
def target(self, batch): |
|
|
169 |
# test target |
|
|
170 |
batch_size = len(batch) |
|
|
171 |
pos_h_index, pos_t_index, pos_r_index = batch.t() |
|
|
172 |
any = -torch.ones_like(pos_h_index) |
|
|
173 |
|
|
|
174 |
pattern = torch.stack([pos_h_index, any, pos_r_index], dim=-1) |
|
|
175 |
edge_index, num_t_truth = self.graph.match(pattern) |
|
|
176 |
t_truth_index = self.graph.edge_list[edge_index, 1] |
|
|
177 |
pos_index = torch.repeat_interleave(num_t_truth) |
|
|
178 |
t_mask = torch.ones(batch_size, self.num_entity, dtype=torch.bool, device=self.device) |
|
|
179 |
t_mask[pos_index, t_truth_index] = 0 |
|
|
180 |
|
|
|
181 |
pattern = torch.stack([any, pos_t_index, pos_r_index], dim=-1) |
|
|
182 |
edge_index, num_h_truth = self.graph.match(pattern) |
|
|
183 |
h_truth_index = self.graph.edge_list[edge_index, 0] |
|
|
184 |
pos_index = torch.repeat_interleave(num_h_truth) |
|
|
185 |
h_mask = torch.ones(batch_size, self.num_entity, dtype=torch.bool, device=self.device) |
|
|
186 |
h_mask[pos_index, h_truth_index] = 0 |
|
|
187 |
|
|
|
188 |
mask = torch.stack([t_mask, h_mask], dim=1) |
|
|
189 |
target = torch.stack([pos_t_index, pos_h_index], dim=1) |
|
|
190 |
|
|
|
191 |
# in case of GPU OOM |
|
|
192 |
return mask.cpu(), target.cpu() |
|
|
193 |
|
|
|
194 |
def evaluate(self, pred, target): |
|
|
195 |
mask, target = target |
|
|
196 |
|
|
|
197 |
pos_pred = pred.gather(-1, target.unsqueeze(-1)) |
|
|
198 |
if self.filtered_ranking: |
|
|
199 |
ranking = torch.sum((pos_pred <= pred) & mask, dim=-1) + 1 |
|
|
200 |
else: |
|
|
201 |
ranking = torch.sum(pos_pred <= pred, dim=-1) + 1 |
|
|
202 |
|
|
|
203 |
metric = {} |
|
|
204 |
for _metric in self.metric: |
|
|
205 |
if _metric == "mr": |
|
|
206 |
score = ranking.float().mean() |
|
|
207 |
elif _metric == "mrr": |
|
|
208 |
score = (1 / ranking.float()).mean() |
|
|
209 |
elif _metric.startswith("hits@"): |
|
|
210 |
threshold = int(_metric[5:]) |
|
|
211 |
score = (ranking <= threshold).float().mean() |
|
|
212 |
else: |
|
|
213 |
raise ValueError("Unknown metric `%s`" % _metric) |
|
|
214 |
|
|
|
215 |
name = tasks._get_metric_name(_metric) |
|
|
216 |
metric[name] = score |
|
|
217 |
|
|
|
218 |
return metric |
|
|
219 |
|
|
|
220 |
def visualize(self, batch): |
|
|
221 |
h_index, t_index, r_index = batch.t() |
|
|
222 |
return self.model.visualize(self.fact_graph, h_index, t_index, r_index) |
|
|
223 |
|
|
|
224 |
@torch.no_grad() |
|
|
225 |
def _strict_negative(self, pos_h_index, pos_t_index, pos_r_index): |
|
|
226 |
batch_size = len(pos_h_index) |
|
|
227 |
any = -torch.ones_like(pos_h_index) |
|
|
228 |
|
|
|
229 |
pattern = torch.stack([pos_h_index, any, pos_r_index], dim=-1) |
|
|
230 |
pattern = pattern[:batch_size // 2] |
|
|
231 |
edge_index, num_t_truth = self.fact_graph.match(pattern) |
|
|
232 |
t_truth_index = self.fact_graph.edge_list[edge_index, 1] |
|
|
233 |
pos_index = torch.repeat_interleave(num_t_truth) |
|
|
234 |
t_mask = torch.ones(len(pattern), self.num_entity, dtype=torch.bool, device=self.device) |
|
|
235 |
t_mask[pos_index, t_truth_index] = 0 |
|
|
236 |
neg_t_candidate = t_mask.nonzero()[:, 1] |
|
|
237 |
num_t_candidate = t_mask.sum(dim=-1) |
|
|
238 |
neg_t_index = functional.variadic_sample(neg_t_candidate, num_t_candidate, self.num_negative) |
|
|
239 |
|
|
|
240 |
pattern = torch.stack([any, pos_t_index, pos_r_index], dim=-1) |
|
|
241 |
pattern = pattern[batch_size // 2:] |
|
|
242 |
edge_index, num_h_truth = self.fact_graph.match(pattern) |
|
|
243 |
h_truth_index = self.fact_graph.edge_list[edge_index, 0] |
|
|
244 |
pos_index = torch.repeat_interleave(num_h_truth) |
|
|
245 |
h_mask = torch.ones(len(pattern), self.num_entity, dtype=torch.bool, device=self.device) |
|
|
246 |
h_mask[pos_index, h_truth_index] = 0 |
|
|
247 |
neg_h_candidate = h_mask.nonzero()[:, 1] |
|
|
248 |
num_h_candidate = h_mask.sum(dim=-1) |
|
|
249 |
neg_h_index = functional.variadic_sample(neg_h_candidate, num_h_candidate, self.num_negative) |
|
|
250 |
|
|
|
251 |
neg_index = torch.cat([neg_t_index, neg_h_index]) |
|
|
252 |
|
|
|
253 |
return neg_index |