a b/opengait/modeling/losses/triplet.py
1
import torch
2
import torch.nn.functional as F
3
4
from .base import BaseLoss, gather_and_scale_wrapper
5
6
7
class TripletLoss(BaseLoss):
    """Margin-based triplet loss evaluated separately for every part.

    All (anchor, positive, negative) combinations inside the batch are
    formed from the pairwise Euclidean distance matrix; each part's loss is
    the mean over the triplets whose hinge term is non-zero.
    """

    def __init__(self, margin, loss_term_weight=1.0):
        """Store the hinge margin; `loss_term_weight` is passed to BaseLoss."""
        super().__init__(loss_term_weight)
        self.margin = margin

    # NOTE(review): the decorator comes from `.base` and is not visible here;
    # what callers finally receive may differ from the raw return below.
    @gather_and_scale_wrapper
    def forward(self, embeddings, labels):
        """Compute the per-part triplet loss for one batch.

        Args:
            embeddings: features laid out as [n, c, p].
            labels: identity labels, [n].

        Returns:
            (loss_avg, self.info): `loss_avg` has one entry per part, and
            `self.info` is updated with detached copies of 'loss',
            'hard_loss', 'loss_num' and 'mean_dist'.
        """
        # [n, c, p] -> [p, n, c], forced to float32 for the distance maths.
        part_major = embeddings.permute(2, 0, 1).contiguous().float()

        # The batch is compared against itself: queries and references are
        # the same embeddings and the same labels.
        pairwise = self.ComputeDistance(part_major, part_major)  # [p, n, n]
        mean_dist = pairwise.mean((1, 2))  # [p]

        pos_dist, neg_dist = self.Convert2Triplets(labels, labels, pairwise)
        num_parts = pairwise.size(0)
        # Broadcasting [p, n, k_pos, 1] - [p, n, 1, k_neg] enumerates every
        # positive/negative pairing per anchor; flatten them per part.
        gap = (pos_dist - neg_dist).view(num_parts, -1)
        hinge = F.relu(gap + self.margin)

        hard_loss = hinge.max(-1)[0]  # largest hinge value of each part
        loss_avg, loss_num = self.AvgNonZeroReducer(hinge)

        self.info.update({
            'loss': loss_avg.detach().clone(),
            'hard_loss': hard_loss.detach().clone(),
            'loss_num': loss_num.detach().clone(),
            'mean_dist': mean_dist.detach().clone()})

        return loss_avg, self.info

    def AvgNonZeroReducer(self, loss):
        """Average `loss` over its non-zero entries along the last dim.

        Returns:
            (loss_avg, loss_num): the mean of the non-zero entries and how
            many there were (as float). Where no entry is non-zero the mean
            is set to exactly 0.
        """
        eps = 1.0e-9  # keeps the division finite when nothing is active
        summed = loss.sum(-1)
        active = loss.ne(0).sum(-1).float()

        averaged = summed / (active + eps)
        # Overwrite in place so parts without active triplets report 0.
        averaged[active == 0] = 0
        return averaged, active

    def ComputeDistance(self, x, y):
        """
            Pairwise Euclidean distance via ||x||^2 + ||y||^2 - 2 * x.y

            x: [p, n_x, c]
            y: [p, n_y, c]
            returns: [p, n_x, n_y]
        """
        x_sq = x.pow(2).sum(-1).unsqueeze(2)  # [p, n_x, 1]
        y_sq = y.pow(2).sum(-1).unsqueeze(1)  # [p, 1, n_y]
        cross = torch.matmul(x, y.transpose(1, 2))  # [p, n_x, n_y]
        sq_dist = x_sq + y_sq - 2 * cross
        # Clamp negatives to zero before taking the square root.
        return F.relu(sq_dist).sqrt()  # [p, n_x, n_y]

    def Convert2Triplets(self, row_labels, clo_label, dist):
        """
            Split `dist` into anchor-positive and anchor-negative distances.

            row_labels: tensor with size [n_r]
            clo_label : tensor with size [n_c]
            dist      : tensor whose size unpacks as (p, n, _)

            returns: ap_dist viewed as [p, n, -1, 1] and an_dist viewed as
            [p, n, 1, -1], ready to be broadcast against each other.

            The masked values are flattened before the view, so entries only
            stay grouped by row when every row has the same number of
            matching (and of non-matching) columns.
        """
        same_id = (row_labels.unsqueeze(1) ==
                   clo_label.unsqueeze(0)).bool()  # [n_r, n_c]
        diff_id = ~same_id  # [n_r, n_c]
        num_parts, num_rows, _ = dist.shape
        ap_dist = dist[:, same_id].view(num_parts, num_rows, -1, 1)
        an_dist = dist[:, diff_id].view(num_parts, num_rows, 1, -1)
        return ap_dist, an_dist