Diff of /model/network/triplet.py [000000] .. [40f229]

Switch to unified view

a b/model/network/triplet.py
1
import torch
2
import torch.nn as nn
3
import torch.nn.functional as F
4
5
6
class TripletLoss(nn.Module):
7
    def __init__(self, batch_size, hard_or_full, margin):
8
        super(TripletLoss, self).__init__()
9
        self.batch_size = batch_size
10
        self.margin = margin
11
12
    def forward(self, feature, label):
13
        # feature: [n, m, d], label: [n, m]
14
        n, m, d = feature.size()
15
        hp_mask = (label.unsqueeze(1) == label.unsqueeze(2)).byte().view(-1)
16
        hn_mask = (label.unsqueeze(1) != label.unsqueeze(2)).byte().view(-1)
17
18
        dist = self.batch_dist(feature)
19
        mean_dist = dist.mean(1).mean(1)
20
        dist = dist.view(-1)
21
        # hard
22
        hard_hp_dist = torch.max(torch.masked_select(dist, hp_mask).view(n, m, -1), 2)[0]
23
        hard_hn_dist = torch.min(torch.masked_select(dist, hn_mask).view(n, m, -1), 2)[0]
24
        hard_loss_metric = F.relu(self.margin + hard_hp_dist - hard_hn_dist).view(n, -1)
25
26
        hard_loss_metric_mean = torch.mean(hard_loss_metric, 1)
27
28
        # non-zero full
29
        full_hp_dist = torch.masked_select(dist, hp_mask).view(n, m, -1, 1)
30
        full_hn_dist = torch.masked_select(dist, hn_mask).view(n, m, 1, -1)
31
        full_loss_metric = F.relu(self.margin + full_hp_dist - full_hn_dist).view(n, -1)
32
33
        full_loss_metric_sum = full_loss_metric.sum(1)
34
        full_loss_num = (full_loss_metric != 0).sum(1).float()
35
36
        full_loss_metric_mean = full_loss_metric_sum / full_loss_num
37
        full_loss_metric_mean[full_loss_num == 0] = 0
38
39
        return full_loss_metric_mean, hard_loss_metric_mean, mean_dist, full_loss_num
40
41
    def batch_dist(self, x):
42
        x2 = torch.sum(x ** 2, 2)
43
        dist = x2.unsqueeze(2) + x2.unsqueeze(2).transpose(1, 2) - 2 * torch.matmul(x, x.transpose(1, 2))
44
        dist = torch.sqrt(F.relu(dist))
45
        return dist