[fd9ef4]: / opengait / modeling / losses / triplet.py

Download this file

69 lines (56 with data), 2.4 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import torch
import torch.nn.functional as F
from .base import BaseLoss, gather_and_scale_wrapper
class TripletLoss(BaseLoss):
    """Margin-based triplet loss evaluated independently for each part.

    Every (anchor, positive, negative) combination that can be formed from
    the batch labels contributes one hinge term
    ``relu(d(anchor, positive) - d(anchor, negative) + margin)``; the terms
    are then averaged per part over the non-zero entries only.
    """

    def __init__(self, margin, loss_term_weight=1.0):
        """Store the hinge margin; ``loss_term_weight`` is forwarded to the base class."""
        super().__init__(loss_term_weight)
        self.margin = margin

    # NOTE(review): ``gather_and_scale_wrapper`` is defined in ``.base`` and is
    # not visible here -- confirm what it does to the inputs / returned loss.
    @gather_and_scale_wrapper
    def forward(self, embeddings, labels):
        """Compute the per-part triplet loss for one batch.

        Args:
            embeddings: tensor laid out as [n, c, p].
            labels: tensor of size [n] with one identity label per sample.

        Returns:
            A tuple ``(loss_avg, self.info)`` where ``loss_avg`` holds one
            value per part ([p]) and ``self.info`` has been updated with
            detached copies of 'loss', 'hard_loss', 'loss_num' and
            'mean_dist'.
        """
        # [n, c, p] -> [p, n, c], so that each part is handled as its own batch.
        part_major = embeddings.permute(2, 0, 1).contiguous().float()

        # The batch is compared against itself: rows and columns share labels.
        pairwise = self.ComputeDistance(part_major, part_major)  # [p, n, n]
        mean_dist = pairwise.mean((1, 2))  # [p]

        pos_dist, neg_dist = self.Convert2Triplets(labels, labels, pairwise)
        num_parts = pairwise.size(0)
        # Broadcasting [p, n, k, 1] against [p, n, 1, m] enumerates every
        # positive/negative combination per anchor; flatten them per part.
        gap = (pos_dist - neg_dist).view(num_parts, -1)
        loss = F.relu(gap + self.margin)

        hard_loss = loss.max(-1).values  # largest hinge term of each part
        loss_avg, loss_num = self.AvgNonZeroReducer(loss)

        self.info.update({
            'loss': loss_avg.detach().clone(),
            'hard_loss': hard_loss.detach().clone(),
            'loss_num': loss_num.detach().clone(),
            'mean_dist': mean_dist.detach().clone()})
        return loss_avg, self.info

    def AvgNonZeroReducer(self, loss):
        """Average ``loss`` along its last dimension over non-zero entries only.

        Returns:
            ``(loss_avg, loss_num)``: the per-row mean of the non-zero
            entries and the (float) count of non-zero entries per row.
        """
        eps = 1.0e-9  # keeps the division finite when a row has no non-zero entry
        active_count = (loss != 0).sum(-1).float()
        loss_avg = loss.sum(-1) / (active_count + eps)
        # Rows with no active entry are pinned to exactly zero.
        loss_avg[active_count == 0] = 0
        return loss_avg, active_count

    def ComputeDistance(self, x, y):
        """Return the pairwise Euclidean distance between the rows of x and y.

        x: [p, n_x, c]
        y: [p, n_y, c]
        Result: [p, n_x, n_y]
        """
        # Expansion ||x - y||^2 = ||x||^2 + ||y||^2 - 2 * <x, y>.
        x_sq = (x ** 2).sum(-1).unsqueeze(2)  # [p, n_x, 1]
        y_sq = (y ** 2).sum(-1).unsqueeze(1)  # [p, 1, n_y]
        cross = x.matmul(y.transpose(1, 2))  # [p, n_x, n_y]
        sq_dist = x_sq + y_sq - 2 * cross
        # relu clamps negative values before the square root is taken.
        return torch.sqrt(F.relu(sq_dist))  # [p, n_x, n_y]

    def Convert2Triplets(self, row_labels, clo_label, dist):
        """Split a distance matrix into anchor-positive and anchor-negative parts.

        row_labels: tensor with size [n_r]
        clo_label : tensor with size [n_c]
        dist      : tensor whose size unpacks as [p, n, _]

        Returns:
            ``(ap_dist, an_dist)`` viewed as [p, n, k, 1] and [p, n, 1, m],
            so that subtracting them broadcasts over all combinations.

        The reshape relies on every row having the same number of matching
        (and of non-matching) columns; this is not checked here.
        """
        same_label = row_labels.unsqueeze(1) == clo_label.unsqueeze(0)  # [n_r, n_c]
        other_label = ~same_label  # [n_r, n_c]
        num_parts, num_rows, _ = dist.size()
        ap_dist = dist[:, same_label].view(num_parts, num_rows, -1, 1)
        an_dist = dist[:, other_label].view(num_parts, num_rows, 1, -1)
        return ap_dist, an_dist