|
a |
|
b/model/network/triplet.py |
|
|
1 |
import torch |
|
|
2 |
import torch.nn as nn |
|
|
3 |
import torch.nn.functional as F |
|
|
4 |
|
|
|
5 |
|
|
|
6 |
class TripletLoss(nn.Module): |
|
|
7 |
def __init__(self, batch_size, hard_or_full, margin): |
|
|
8 |
super(TripletLoss, self).__init__() |
|
|
9 |
self.batch_size = batch_size |
|
|
10 |
self.margin = margin |
|
|
11 |
|
|
|
12 |
def forward(self, feature, label): |
|
|
13 |
# feature: [n, m, d], label: [n, m] |
|
|
14 |
n, m, d = feature.size() |
|
|
15 |
hp_mask = (label.unsqueeze(1) == label.unsqueeze(2)).byte().view(-1) |
|
|
16 |
hn_mask = (label.unsqueeze(1) != label.unsqueeze(2)).byte().view(-1) |
|
|
17 |
|
|
|
18 |
dist = self.batch_dist(feature) |
|
|
19 |
mean_dist = dist.mean(1).mean(1) |
|
|
20 |
dist = dist.view(-1) |
|
|
21 |
# hard |
|
|
22 |
hard_hp_dist = torch.max(torch.masked_select(dist, hp_mask).view(n, m, -1), 2)[0] |
|
|
23 |
hard_hn_dist = torch.min(torch.masked_select(dist, hn_mask).view(n, m, -1), 2)[0] |
|
|
24 |
hard_loss_metric = F.relu(self.margin + hard_hp_dist - hard_hn_dist).view(n, -1) |
|
|
25 |
|
|
|
26 |
hard_loss_metric_mean = torch.mean(hard_loss_metric, 1) |
|
|
27 |
|
|
|
28 |
# non-zero full |
|
|
29 |
full_hp_dist = torch.masked_select(dist, hp_mask).view(n, m, -1, 1) |
|
|
30 |
full_hn_dist = torch.masked_select(dist, hn_mask).view(n, m, 1, -1) |
|
|
31 |
full_loss_metric = F.relu(self.margin + full_hp_dist - full_hn_dist).view(n, -1) |
|
|
32 |
|
|
|
33 |
full_loss_metric_sum = full_loss_metric.sum(1) |
|
|
34 |
full_loss_num = (full_loss_metric != 0).sum(1).float() |
|
|
35 |
|
|
|
36 |
full_loss_metric_mean = full_loss_metric_sum / full_loss_num |
|
|
37 |
full_loss_metric_mean[full_loss_num == 0] = 0 |
|
|
38 |
|
|
|
39 |
return full_loss_metric_mean, hard_loss_metric_mean, mean_dist, full_loss_num |
|
|
40 |
|
|
|
41 |
def batch_dist(self, x): |
|
|
42 |
x2 = torch.sum(x ** 2, 2) |
|
|
43 |
dist = x2.unsqueeze(2) + x2.unsqueeze(2).transpose(1, 2) - 2 * torch.matmul(x, x.transpose(1, 2)) |
|
|
44 |
dist = torch.sqrt(F.relu(dist)) |
|
|
45 |
return dist |