Download this file
1 2 3 4 5 6 7 8 9 10 11 12
import torch import torch.nn.functional as F def bilinear(s, r, t): return torch.sum(s * r * t, dim = 1) def trans(s, r, t): return -torch.norm(s + r - t, dim = 1) def dot(s, t): return torch.sum(s * t, dim = 1)