Diff of /pretrain/trans.py [000000] .. [c3444c]

Switch to unified view

a b/pretrain/trans.py
1
import torch
2
from torch import nn
3
import torch.nn.functional as F
4
5
6
class TransE(nn.Module):
7
    def __init__(self, margin=1.0):
8
        super(TransE, self).__init__()
9
        self.margin = margin
10
11
    def forward(self, cui_0, cui_1, cui_2, re):
12
        pos = cui_0 + re - cui_1
13
        neg = cui_0 + re - cui_2
14
        return torch.mean(F.relu(self.margin + torch.norm(pos, p=2, dim=1) - torch.norm(neg, p=2, dim=1)))