a b/coderpp/train/loss.py
1
import torch
2
import torch.nn as nn
3
4
5
class AMSoftmax(nn.Module):
    """Additive-margin softmax (AM-Softmax) classification head and loss.

    Logits are ``s * (cos_j - m * [j == label])``, where ``cos_j`` is the cosine
    between an L2-normalized input row and the L2-normalized ``j``-th column of
    ``W``. The loss is cross-entropy over those logits.

    Args:
        in_feats: dimensionality of each input row.
        n_classes: number of classes (columns of ``W``).
        m: additive margin subtracted from the target-class cosine.
        s: scale applied to the margined cosines.
    """

    def __init__(self,
                 in_feats,
                 n_classes=10,
                 m=0.35,
                 s=30):
        super(AMSoftmax, self).__init__()
        self.m = m
        self.s = s
        self.in_feats = in_feats
        # Shape (in_feats, n_classes): one weight column per class.
        self.W = nn.Parameter(torch.randn(in_feats, n_classes))
        self.ce = nn.CrossEntropyLoss()
        nn.init.xavier_normal_(self.W, gain=1)

    def _cosine(self, x):
        """Return the (batch, n_classes) cosine matrix between rows of ``x`` and columns of ``W``."""
        # normalize() divides by max(||v||, eps), so all-zero rows/columns yield 0, not NaN.
        x_unit = nn.functional.normalize(x, p=2, dim=1, eps=1e-12)
        w_unit = nn.functional.normalize(self.W, p=2, dim=0, eps=1e-12)
        return torch.mm(x_unit, w_unit)

    def forward(self, x, label):
        """Compute the AM-Softmax loss.

        Args:
            x: 2-D tensor of shape (batch, in_feats).
            label: integer class ids, one per row of ``x``. It is moved to ``x``'s device.

        Returns:
            Tuple ``(loss, logits)``. ``loss`` is the scalar cross-entropy and
            ``logits`` is the (batch, n_classes) tensor ``s * (cos - m * one_hot)``.
        """
        costh = self._cosine(x)
        # Move once so both the scatter and the cross-entropy see the same device.
        label = label.to(costh.device)
        # Build the one-hot margin directly on costh's device/dtype (no CPU allocation + copy).
        delt_costh = torch.zeros_like(costh).scatter_(1, label.view(-1, 1), self.m)
        costh_m_s = self.s * (costh - delt_costh)
        loss = self.ce(costh_m_s, label)
        return loss, costh_m_s

    def predict(self, x):
        """Return the raw (batch, n_classes) cosines, with no margin and no scale."""
        return self._cosine(x)
43
44
class MultiSimilarityLoss(nn.Module):
    """Multi-Similarity loss over a batch of embeddings.

    For every anchor row, positives are the other rows with the same label and
    negatives are the rows with a different label. Pairs are mined with a margin
    against the hardest opposite pair. Each surviving set contributes a
    log-sum-exp term, and the total is divided by the batch size.

    Args:
        thresh: similarity offset subtracted inside both exponentials.
        margin: mining margin used to select hard pairs.
        scale_pos: temperature of the positive term.
        scale_neg: temperature of the negative term.
    """

    def __init__(self, thresh=0.5, margin=0.1, scale_pos=2.0, scale_neg=50.0):
        super(MultiSimilarityLoss, self).__init__()
        self.thresh = thresh
        self.margin = margin
        self.scale_pos = scale_pos
        self.scale_neg = scale_neg

    def forward(self, feats, labels):
        """Compute the loss.

        Args:
            feats: 2-D tensor, one embedding per row. Rows are L2-normalized here.
            labels: 1-D tensor of class ids, one per row of ``feats``.

        Returns:
            Scalar tensor equal to the sum of the mined terms divided by the batch
            size. When no pair is mined, returns a zero scalar that requires grad.
        """
        batch_size = feats.size(0)

        x_norm = nn.functional.normalize(feats, p=2, dim=1, eps=1e-12)
        sim_mat = torch.matmul(x_norm, x_norm.t())

        # Positives with similarity >= 1 - epsilon are dropped. This removes the
        # anchor's self-similarity and also any exact duplicates of the anchor.
        epsilon = 1e-5
        losses = []

        for i in range(batch_size):
            same = labels == labels[i]
            pos_pair_ = sim_mat[i][same]
            pos_pair_ = pos_pair_[pos_pair_ < 1 - epsilon]
            neg_pair_ = sim_mat[i][~same]

            if neg_pair_.numel() > 0:
                # Hard positives: within `margin` of (or below) the hardest negative.
                pos_pair = pos_pair_[pos_pair_ - self.margin < neg_pair_.max()]
                if pos_pair.numel() > 0:
                    pos_sum = torch.exp(-self.scale_pos * (pos_pair - self.thresh)).sum()
                    losses.append(torch.log1p(pos_sum) / self.scale_pos)

            if pos_pair_.numel() > 0:
                # Hard negatives: within `margin` of (or above) the hardest positive.
                neg_pair = neg_pair_[neg_pair_ + self.margin > pos_pair_.min()]
                if neg_pair.numel() > 0:
                    neg_sum = torch.exp(self.scale_neg * (neg_pair - self.thresh)).sum()
                    losses.append(torch.log1p(neg_sum) / self.scale_neg)

        if not losses:
            # Created directly on feats' device instead of on the CPU and then copied.
            return feats.new_zeros((), requires_grad=True)

        return torch.stack(losses).sum() / batch_size
97
98
if __name__ == '__main__':
    # Smoke test: one AM-Softmax forward/backward pass on random data.
    criteria = AMSoftmax(20, 5)
    a = torch.randn(10, 20)
    lb = torch.randint(0, 5, (10, ), dtype=torch.long)
    # forward() returns (loss, scaled logits); only the loss is backpropagated.
    loss, _ = criteria(a, lb)
    loss.backward()

    print(loss.detach().numpy())
    print(list(criteria.parameters())[0].shape)
    print(type(next(criteria.parameters())))
    print(lb)