[c3444c]: / pretrain / trans.py

Download this file

15 lines (11 with data), 433 Bytes

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
import torch
from torch import nn
import torch.nn.functional as F
class TransE(nn.Module):
    """Margin-based ranking loss over translated embeddings.

    For each row the module compares the L2 length of
    ``cui_0 + re - cui_1`` (the "positive" residual) against the L2 length
    of ``cui_0 + re - cui_2`` (the "negative" residual) and applies a hinge
    with the configured margin.

    Args:
        margin: Constant added inside the hinge before clamping at zero.
            Defaults to ``1.0``.
    """

    def __init__(self, margin=1.0):
        super(TransE, self).__init__()
        self.margin = margin

    def forward(self, cui_0, cui_1, cui_2, re):
        """Return the mean hinge loss for a batch.

        Computes ``mean(relu(margin + ||cui_0 + re - cui_1||_2
        - ||cui_0 + re - cui_2||_2))`` with both norms taken along
        dimension 1.

        Args:
            cui_0: Tensor added to ``re`` to form the translated point.
            cui_1: Tensor subtracted to form the positive residual.
            cui_2: Tensor subtracted to form the negative residual.
            re: Translation tensor added to ``cui_0``.

        Returns:
            A scalar tensor holding the mean loss.

        Note:
            Inputs must have at least two dimensions because the norm is
            reduced over ``dim=1``; presumably they are
            ``(batch, embedding_dim)`` -- TODO confirm against the caller.
        """
        # The translated point is shared by the positive and negative terms.
        translated = cui_0 + re
        pos = translated - cui_1
        neg = translated - cui_2

        pos_dist = torch.norm(pos, p=2, dim=1)
        neg_dist = torch.norm(neg, p=2, dim=1)

        # Hinge: zero loss once the negative residual is longer than the
        # positive one by at least `margin`.
        hinge = F.relu(self.margin + pos_dist - neg_dist)
        return torch.mean(hinge)