Diff of /shepherd/decoders.py [000000] .. [db6163]

Switch to side-by-side view

--- a
+++ b/shepherd/decoders.py
@@ -0,0 +1,12 @@
+import torch
+import torch.nn.functional as F
+
+
+def bilinear(s, r, t):
+    return torch.sum(s * r * t, dim = 1)
+
+def trans(s, r, t):
+    return -torch.norm(s + r - t, dim = 1)
+
+def dot(s, t):
+    return torch.sum(s * t, dim = 1)