a | b/shepherd/decoders.py | ||
---|---|---|---|
1 | import torch |
||
2 | import torch.nn.functional as F |
||
3 | |||
4 | |||
5 | def bilinear(s, r, t): |
||
6 | return torch.sum(s * r * t, dim = 1) |
||
7 | |||
8 | def trans(s, r, t): |
||
9 | return -torch.norm(s + r - t, dim = 1) |
||
10 | |||
11 | def dot(s, t): |
||
12 | return torch.sum(s * t, dim = 1) |