a | b/pretrain/trans.py | ||
---|---|---|---|
1 | import torch |
||
2 | from torch import nn |
||
3 | import torch.nn.functional as F |
||
4 | |||
5 | |||
6 | class TransE(nn.Module): |
||
7 | def __init__(self, margin=1.0): |
||
8 | super(TransE, self).__init__() |
||
9 | self.margin = margin |
||
10 | |||
11 | def forward(self, cui_0, cui_1, cui_2, re): |
||
12 | pos = cui_0 + re - cui_1 |
||
13 | neg = cui_0 + re - cui_2 |
||
14 | return torch.mean(F.relu(self.margin + torch.norm(pos, p=2, dim=1) - torch.norm(neg, p=2, dim=1))) |