--- /dev/null
+++ b/pretrain/loss.py
@@ -0,0 +1,108 @@
+import torch
+import torch.nn as nn
+
+
+class AMSoftmax(nn.Module):
+    """Additive-margin softmax (AM-Softmax) loss: subtracts a fixed margin m from
+    the target-class cosine similarity and scales all logits by s before cross-entropy."""
+
+    def __init__(self,
+                 in_feats,
+                 n_classes=10,
+                 m=0.35,
+                 s=30):
+        super(AMSoftmax, self).__init__()
+        self.m = m
+        self.s = s
+        self.in_feats = in_feats
+        self.W = torch.nn.Parameter(torch.randn(in_feats, n_classes), requires_grad=True)
+        self.ce = nn.CrossEntropyLoss()
+        nn.init.xavier_normal_(self.W, gain=1)
+
+    def forward(self, x, label):
+        assert x.size(0) == label.size(0)
+        assert x.size(1) == self.in_feats
+        # L2-normalize the features (rows) and class weights (columns) so that
+        # their product is a matrix of cosine similarities.
+        x_norm = torch.norm(x, p=2, dim=1, keepdim=True).clamp(min=1e-12)
+        x_norm = torch.div(x, x_norm)
+        w_norm = torch.norm(self.W, p=2, dim=0, keepdim=True).clamp(min=1e-12)
+        w_norm = torch.div(self.W, w_norm)
+        costh = torch.mm(x_norm, w_norm)
+        # Subtract the margin m from the target-class cosine only.
+        lb_view = label.view(-1, 1).to(x.device)
+        delt_costh = torch.zeros_like(costh).scatter_(1, lb_view, self.m)
+        costh_m = costh - delt_costh
+        costh_m_s = self.s * costh_m
+        loss = self.ce(costh_m_s, label)
+        return loss, costh_m_s
+
+    def predict(self, x):
+        # Cosine similarity between normalized features and class weights.
+        x_norm = torch.norm(x, p=2, dim=1, keepdim=True).clamp(min=1e-12)
+        x_norm = torch.div(x, x_norm)
+        w_norm = torch.norm(self.W, p=2, dim=0, keepdim=True).clamp(min=1e-12)
+        w_norm = torch.div(self.W, w_norm)
+        costh = torch.mm(x_norm, w_norm)
+        return costh
+
+
+class MultiSimilarityLoss(nn.Module):
+    """Multi-Similarity loss: mines hard positive/negative pairs per anchor, then
+    weights positives and negatives with separate soft-plus terms."""
+
+    def __init__(self):
+        super(MultiSimilarityLoss, self).__init__()
+        self.thresh = 0.5
+        self.margin = 0.1
+        self.scale_pos = 2.0
+        self.scale_neg = 50.0
+
+    def forward(self, feats, labels):
+        assert feats.size(0) == labels.size(0), \
+            f"feats.size(0): {feats.size(0)} is not equal to labels.size(0): {labels.size(0)}"
+        batch_size = feats.size(0)
+
+        # L2-normalize the features so sim_mat holds pairwise cosine similarities.
+        x_norm = torch.norm(feats, p=2, dim=1, keepdim=True).clamp(min=1e-12)
+        x_norm = torch.div(feats, x_norm)
+        sim_mat = torch.matmul(x_norm, torch.t(x_norm))
+
+        epsilon = 1e-5
+        loss = []
+
+        for i in range(batch_size):
+            # Similarities to same-label and different-label samples; drop the
+            # (near-1) self-similarity from the positives.
+            pos_pair_ = sim_mat[i][labels == labels[i]]
+            pos_pair_ = pos_pair_[pos_pair_ < 1 - epsilon]
+            neg_pair_ = sim_mat[i][labels != labels[i]]
+
+            if len(neg_pair_) >= 1:
+                # Hard positives: exceed the most similar negative by less than the margin.
+                pos_pair = pos_pair_[pos_pair_ - self.margin < neg_pair_.max()]
+                if len(pos_pair) >= 1:
+                    pos_loss = 1.0 / self.scale_pos * torch.log(
+                        1 + torch.sum(torch.exp(-self.scale_pos * (pos_pair - self.thresh))))
+                    loss.append(pos_loss)
+
+            if len(pos_pair_) >= 1:
+                # Hard negatives: within the margin of (or above) the least similar positive.
+                neg_pair = neg_pair_[neg_pair_ + self.margin > pos_pair_.min()]
+                if len(neg_pair) >= 1:
+                    neg_loss = 1.0 / self.scale_neg * torch.log(
+                        1 + torch.sum(torch.exp(self.scale_neg * (neg_pair - self.thresh))))
+                    loss.append(neg_loss)
+
+        if len(loss) == 0:
+            # No informative pairs in this batch: return a zero that still participates in autograd.
+            return torch.zeros([], requires_grad=True).to(feats.device)
+
+        loss = sum(loss) / batch_size
+        return loss
+
+
+if __name__ == '__main__':
+    criteria = AMSoftmax(20, 5)
+    a = torch.randn(10, 20)
+    lb = torch.randint(0, 5, (10,), dtype=torch.long)
+    loss, logits = criteria(a, lb)  # forward returns (loss, scaled logits)
+    loss.backward()
+
+    print(loss.detach().numpy())
+    print(list(criteria.parameters())[0].shape)
+    print(type(next(criteria.parameters())))
+    print(lb)
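
The __main__ smoke test above only exercises AMSoftmax. A minimal sketch of driving MultiSimilarityLoss on random embeddings is given below; the batch size, embedding dimension, number of classes, and the pretrain.loss import path are illustrative assumptions, not part of the diff.

    # Usage sketch (not part of the diff above); shapes and import path are assumptions.
    import torch
    from pretrain.loss import MultiSimilarityLoss  # assumes pretrain/ is importable as a package

    criterion = MultiSimilarityLoss()
    feats = torch.randn(8, 128, requires_grad=True)        # 8 embeddings of dimension 128
    labels = torch.randint(0, 3, (8,), dtype=torch.long)   # 3 arbitrary class IDs

    loss = criterion(feats, labels)
    loss.backward()                                         # gradients flow back into the embeddings
    print(loss.item())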