Diff of /shepherd/decoders.py [000000] .. [db6163]

Switch to unified view

a b/shepherd/decoders.py
1
import torch
2
import torch.nn.functional as F
3
4
5
def bilinear(s, r, t):
6
    return torch.sum(s * r * t, dim = 1)
7
8
def trans(s, r, t):
9
    return -torch.norm(s + r - t, dim = 1)
10
11
def dot(s, t):
12
    return torch.sum(s * t, dim = 1)