--- a
+++ b/model/network/triplet.py
@@ -0,0 +1,94 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class TripletLoss(nn.Module):
+    """Batch triplet margin loss evaluated separately for each of ``n`` groups.
+
+    Within every group the pairwise Euclidean distances between its ``m``
+    features are computed; pairs sharing a label are positives and all other
+    pairs are negatives. Two variants of ``relu(margin + d_pos - d_neg)`` are
+    produced: "hard" (largest positive vs. smallest negative per anchor) and
+    "full" (every positive/negative combination per anchor).
+
+    Args:
+        batch_size: Stored on the module; not read by ``forward``.
+        hard_or_full: Stored on the module; not read by ``forward``, which
+            always returns both variants.
+        margin: Margin added to ``d_pos - d_neg`` before the ReLU.
+    """
+
+    def __init__(self, batch_size, hard_or_full, margin):
+        super(TripletLoss, self).__init__()
+        self.batch_size = batch_size
+        # Previously dropped; kept so callers can query the configured mode.
+        self.hard_or_full = hard_or_full
+        self.margin = margin
+
+    def forward(self, feature, label):
+        """Compute the full and hard triplet losses for each group.
+
+        Args:
+            feature: Tensor of shape ``[n, m, d]``.
+            label: Tensor of shape ``[n, m]``. The reshapes in this method
+                assume every anchor has the same number of positives (and
+                therefore negatives) and at least one negative.
+
+        Returns:
+            Tuple ``(full_loss_metric_mean, hard_loss_metric_mean, mean_dist,
+            full_loss_num)``, each of shape ``[n]``: mean of the non-zero full
+            triplet terms (0 when there are none), mean hard triplet term, mean
+            pairwise distance, and the count of non-zero full triplet terms.
+        """
+        n, m, _ = feature.size()
+        # Use the comparison results as masks directly: torch.masked_select
+        # documents a BoolTensor mask, and uint8 (.byte()) masks are deprecated.
+        hp_mask = (label.unsqueeze(1) == label.unsqueeze(2)).view(-1)
+        hn_mask = (label.unsqueeze(1) != label.unsqueeze(2)).view(-1)
+
+        dist = self.batch_dist(feature)
+        mean_dist = dist.mean(1).mean(1)
+        dist = dist.view(-1)
+
+        # Select once and share between the hard and full variants.
+        hp_dist = torch.masked_select(dist, hp_mask).view(n, m, -1)
+        hn_dist = torch.masked_select(dist, hn_mask).view(n, m, -1)
+
+        # hard: farthest positive against closest negative for each anchor
+        hard_hp_dist = torch.max(hp_dist, 2)[0]
+        hard_hn_dist = torch.min(hn_dist, 2)[0]
+        hard_loss_metric = F.relu(self.margin + hard_hp_dist - hard_hn_dist).view(n, -1)
+        hard_loss_metric_mean = torch.mean(hard_loss_metric, 1)
+
+        # non-zero full: broadcast every positive against every negative
+        full_hp_dist = hp_dist.unsqueeze(3)
+        full_hn_dist = hn_dist.unsqueeze(2)
+        full_loss_metric = F.relu(self.margin + full_hp_dist - full_hn_dist).view(n, -1)
+
+        full_loss_metric_sum = full_loss_metric.sum(1)
+        full_loss_num = (full_loss_metric != 0).sum(1).float()
+
+        # When no term is active the sum is 0 too, so clamping the divisor
+        # yields 0 directly instead of computing 0/0 = NaN and patching it
+        # in place afterwards.
+        full_loss_metric_mean = full_loss_metric_sum / full_loss_num.clamp(min=1)
+
+        return full_loss_metric_mean, hard_loss_metric_mean, mean_dist, full_loss_num
+
+    def batch_dist(self, x):
+        """Return pairwise Euclidean distances within each group.
+
+        Args:
+            x: Tensor of shape ``[n, m, d]``.
+
+        Returns:
+            Tensor of shape ``[n, m, m]`` whose ``[i, j, k]`` entry is the
+            distance between ``x[i, j]`` and ``x[i, k]``.
+        """
+        x2 = torch.sum(x ** 2, 2)
+        # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b
+        dist = x2.unsqueeze(2) + x2.unsqueeze(2).transpose(1, 2) - 2 * torch.matmul(x, x.transpose(1, 2))
+        # relu keeps the sqrt argument non-negative if the expansion dips below 0
+        dist = torch.sqrt(F.relu(dist))
+        return dist