[db6163]: / shepherd / decoders.py

Download this file

13 lines (8 with data), 228 Bytes

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
import torch
import torch.nn.functional as F
def bilinear(s: torch.Tensor, r: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    """Return the sum over dim 1 of the elementwise product ``s * r * t``.

    The three inputs are multiplied elementwise (normal torch broadcasting
    applies) and the product is reduced by summation along dimension 1, so
    the result has that dimension removed.

    NOTE(review): presumably s/r/t are (batch, dim) source, relation and
    target embeddings, yielding one score per row -- confirm against callers.
    """
    return torch.sum(s * r * t, dim=1)
def trans(s: torch.Tensor, r: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    """Return the negated norm of ``s + r - t`` taken along dim 1.

    ``torch.norm`` is called with its default norm type, which for a single
    reduction dimension is the Euclidean (L2) norm. The result is negated,
    so it is always <= 0 and equals 0 exactly where ``s + r == t``.

    NOTE(review): presumably s/r/t are (batch, dim) source, relation and
    target embeddings, yielding one score per row -- confirm against callers.
    """
    return -torch.norm(s + r - t, dim=1)
def dot(s: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    """Return the sum over dim 1 of the elementwise product ``s * t``.

    For 2-D inputs this is the row-wise dot product of ``s`` and ``t``;
    dimension 1 is removed from the result.

    NOTE(review): presumably s/t are (batch, dim) source and target
    embeddings, yielding one score per row -- confirm against callers.
    """
    return torch.sum(s * t, dim=1)