--- a
+++ b/pretrain/trans.py
@@ -0,0 +1,31 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+
+class TransE(nn.Module):
+    """Margin-based ranking loss over translation residuals (TransE-style).
+
+    The module registers no parameters of its own; it only stores ``margin``.
+    """
+
+    def __init__(self, margin=1.0):
+        """Store the margin added to the distance gap before the hinge."""
+        super().__init__()
+        self.margin = margin
+
+    def forward(self, cui_0, cui_1, cui_2, re):
+        """Return the mean hinge loss between two translated distances.
+
+        Both residuals share the translated term ``cui_0 + re``; the first
+        subtracts ``cui_1`` and the second subtracts ``cui_2``. Each residual
+        is reduced with an L2 norm over dim 1, so inputs need at least two
+        dimensions (presumably ``(batch, embedding_dim)`` -- TODO confirm
+        against the caller). The loss is zero for a row once the ``cui_1``
+        distance is smaller than the ``cui_2`` distance by ``margin`` or more.
+        """
+        translated = cui_0 + re
+        dist_pos = (translated - cui_1).norm(p=2, dim=1)
+        dist_neg = (translated - cui_2).norm(p=2, dim=1)
+        hinge = F.relu(self.margin + dist_pos - dist_neg)
+        return torch.mean(hinge)