--- a +++ b/pretrain/trans.py @@ -0,0 +1,14 @@ +import torch +from torch import nn +import torch.nn.functional as F + + +class TransE(nn.Module): + def __init__(self, margin=1.0): + super(TransE, self).__init__() + self.margin = margin + + def forward(self, cui_0, cui_1, cui_2, re): + pos = cui_0 + re - cui_1 + neg = cui_0 + re - cui_2 + return torch.mean(F.relu(self.margin + torch.norm(pos, p=2, dim=1) - torch.norm(neg, p=2, dim=1)))