pretrain/loss.py
import torch
import torch.nn as nn


class AMSoftmax(nn.Module):
    """Additive Margin Softmax (AM-Softmax) loss.

    Computes cosine similarities between L2-normalized features and
    L2-normalized class weights, subtracts a fixed margin m from the
    target-class cosine, and scales all logits by s before cross-entropy.
    """

    def __init__(self,
                 in_feats,
                 n_classes=10,
                 m=0.35,
                 s=30):
        super(AMSoftmax, self).__init__()
        self.m = m
        self.s = s
        self.in_feats = in_feats
        # Class-weight matrix; its columns are L2-normalized in forward().
        self.W = torch.nn.Parameter(torch.randn(in_feats, n_classes), requires_grad=True)
        self.ce = nn.CrossEntropyLoss()
        nn.init.xavier_normal_(self.W, gain=1)

    def forward(self, x, label):
        assert x.size(0) == label.size(0)
        assert x.size(1) == self.in_feats
        # Normalize feature rows and weight columns so their product is a
        # matrix of cosine similarities.
        x_norm = torch.norm(x, p=2, dim=1, keepdim=True).clamp(min=1e-12)
        x_norm = torch.div(x, x_norm)
        w_norm = torch.norm(self.W, p=2, dim=0, keepdim=True).clamp(min=1e-12)
        w_norm = torch.div(self.W, w_norm)
        costh = torch.mm(x_norm, w_norm)
        # Subtract the margin m from the target-class cosine only.
        lb_view = label.view(-1, 1).to(x.device)
        delt_costh = torch.zeros_like(costh).scatter_(1, lb_view, self.m)
        costh_m = costh - delt_costh
        costh_m_s = self.s * costh_m
        loss = self.ce(costh_m_s, label)
        return loss, costh_m_s

    def predict(self, x):
        # Inference path: plain cosine similarities, without margin or scale.
        x_norm = torch.norm(x, p=2, dim=1, keepdim=True).clamp(min=1e-12)
        x_norm = torch.div(x, x_norm)
        w_norm = torch.norm(self.W, p=2, dim=0, keepdim=True).clamp(min=1e-12)
        w_norm = torch.div(self.W, w_norm)
        costh = torch.mm(x_norm, w_norm)
        return costh

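# A minimal, hypothetical usage sketch for AMSoftmax (not part of the original
# file; the 192-dim embeddings and 600 classes below are illustrative
# assumptions):
#
#     head = AMSoftmax(in_feats=192, n_classes=600)
#     loss, scaled_logits = head(embeddings, labels)  # training
#     cosine_scores = head.predict(embeddings)        # inference
#
# For the target class y the scaled logit is s * (cos(theta_y) - m); every
# other class keeps s * cos(theta_j), which is what enforces the margin.
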
class MultiSimilarityLoss(nn.Module):
    """Multi-Similarity loss over the cosine-similarity matrix of a batch.

    For each anchor, hard pairs are mined relative to the hardest opposing
    pair, then positives and negatives are aggregated with separate
    log-sum-exp weights.
    """

    def __init__(self):
        super(MultiSimilarityLoss, self).__init__()
        self.thresh = 0.5
        self.margin = 0.1

        self.scale_pos = 2.0
        self.scale_neg = 50.0

    def forward(self, feats, labels):
        assert feats.size(0) == labels.size(0), \
            f"feats.size(0): {feats.size(0)} is not equal to labels.size(0): {labels.size(0)}"
        batch_size = feats.size(0)

        # L2-normalize features so the Gram matrix holds cosine similarities.
        x_norm = torch.norm(feats, p=2, dim=1, keepdim=True).clamp(min=1e-12)
        x_norm = torch.div(feats, x_norm)

        sim_mat = torch.matmul(x_norm, torch.t(x_norm))

        epsilon = 1e-5
        loss = []

        for i in range(batch_size):
            # Positives share the anchor's label; drop the self-similarity
            # (cosine ~ 1). Negatives carry a different label.
            pos_pair_ = sim_mat[i][labels == labels[i]]
            pos_pair_ = pos_pair_[pos_pair_ < 1 - epsilon]
            neg_pair_ = sim_mat[i][labels != labels[i]]

            if len(neg_pair_) >= 1:
                # Mine positives that are harder than the hardest negative.
                pos_pair = pos_pair_[pos_pair_ - self.margin < max(neg_pair_)]
                if len(pos_pair) >= 1:
                    pos_loss = 1.0 / self.scale_pos * torch.log(
                        1 + torch.sum(torch.exp(-self.scale_pos * (pos_pair - self.thresh))))
                    loss.append(pos_loss)

            if len(pos_pair_) >= 1:
                # Mine negatives that are harder than the easiest positive.
                neg_pair = neg_pair_[neg_pair_ + self.margin > min(pos_pair_)]
                if len(neg_pair) >= 1:
                    neg_loss = 1.0 / self.scale_neg * torch.log(
                        1 + torch.sum(torch.exp(self.scale_neg * (neg_pair - self.thresh))))
                    loss.append(neg_loss)

        if len(loss) == 0:
            # No informative pairs in this batch: return a differentiable zero.
            return torch.zeros([], requires_grad=True).to(feats.device)

        loss = sum(loss) / batch_size
        return loss

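# For reference, the two terms above follow the Multi-Similarity loss of
# Wang et al. (CVPR 2019), with alpha = scale_pos, beta = scale_neg, and
# lambda = thresh applied after pair mining:
#
#   L_pos_i = (1 / alpha) * log(1 + sum_{k in P_i} exp(-alpha * (S_ik - lambda)))
#   L_neg_i = (1 / beta)  * log(1 + sum_{k in N_i} exp( beta * (S_ik - lambda)))
#
# where S_ik is the cosine similarity between samples i and k, and P_i / N_i
# are the mined positive / negative sets for anchor i.
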
if __name__ == '__main__':
    criteria = AMSoftmax(20, 5)
    a = torch.randn(10, 20)
    lb = torch.randint(0, 5, (10, ), dtype=torch.long)
    # forward() returns (loss, scaled logits); unpack before calling backward().
    loss, costh_m_s = criteria(a, lb)
    loss.backward()

    print(loss.detach().numpy())
    print(list(criteria.parameters())[0].shape)
    print(type(next(criteria.parameters())))
    print(lb)
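
    # Hedged additions to the smoke test (not in the original file): exercise
    # the inference path and MultiSimilarityLoss. The 64-dim features and
    # 3 classes below are illustrative assumptions.
    with torch.no_grad():
        scores = criteria.predict(a)
    print(scores.shape)  # expected: torch.Size([10, 5])

    ms_criteria = MultiSimilarityLoss()
    feats = torch.randn(10, 64, requires_grad=True)
    ms_labels = torch.randint(0, 3, (10,), dtype=torch.long)
    ms_loss = ms_criteria(feats, ms_labels)
    ms_loss.backward()
    print(ms_loss.detach().numpy())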