# code/losses.py
import torch
import numpy as np
from pytorch_metric_learning import losses  # NTXentLoss, used by SupervisedContrastiveLoss


def log_barrier(z, t=5):
    """Smooth log-barrier penalty for a constraint z <= 0: the standard
    barrier -log(-z)/t inside the feasible region, extended linearly past
    z = -1/t**2 so the loss stays finite and differentiable when the
    constraint is violated."""
    # Only one constraint value
    if z.shape[0] == 1:
        if z <= -1 / t ** 2:
            log_barrier_loss = -torch.log(-z) / t
        else:
            log_barrier_loss = t * z - np.log(1 / t ** 2) / t + 1 / t

    # Constraint over multiple values
    else:
        # Accumulate on the input's device instead of assuming CUDA
        log_barrier_loss = torch.zeros((), device=z.device)
        for i in range(z.shape[0]):
            zi = z[i, 0]
            if zi <= -1 / t ** 2:
                log_barrier_loss += -torch.log(-zi) / t
            else:
                log_barrier_loss += t * zi - np.log(1 / t ** 2) / t + 1 / t

    return log_barrier_loss
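

# Illustrative usage sketch for log_barrier: the constraint values below are
# made up for demonstration; in practice z would hold per-sample constraint
# outputs g(x) <= 0 from the model.
def _demo_log_barrier():
    # One constraint value per sample, as a column vector; negative = satisfied.
    z = torch.tensor([[-0.5], [-0.1], [0.2]], requires_grad=True)
    loss = log_barrier(z, t=5)
    loss.backward()
    # The violated entry (0.2 > -1/t**2) falls in the linear branch and gets
    # a constant gradient of t; satisfied entries get the barrier gradient.
    print(loss.item(), z.grad.squeeze().tolist())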


class SupConLoss(torch.nn.Module):
    """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.
    It also supports the unsupervised contrastive loss in SimCLR."""
    def __init__(self, temperature=0.07, contrast_mode='all',
                 base_temperature=0.07):
        super(SupConLoss, self).__init__()
        self.temperature = temperature
        self.contrast_mode = contrast_mode
        self.base_temperature = base_temperature

    def forward(self, features, labels=None, mask=None):
        """Compute loss for model. If both `labels` and `mask` are None,
        it degenerates to the SimCLR unsupervised loss:
        https://arxiv.org/pdf/2002.05709.pdf

        Args:
            features: hidden vector of shape [bsz, n_views, ...].
            labels: ground truth of shape [bsz].
            mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j
                has the same class as sample i. Can be asymmetric.
        Returns:
            A loss scalar.
        """
        device = (torch.device('cuda')
                  if features.is_cuda
                  else torch.device('cpu'))

        if len(features.shape) < 3:
            raise ValueError('`features` needs to be [bsz, n_views, ...], '
                             'at least 3 dimensions are required')
        if len(features.shape) > 3:
            features = features.view(features.shape[0], features.shape[1], -1)

        batch_size = features.shape[0]
        if labels is not None and mask is not None:
            raise ValueError('Cannot define both `labels` and `mask`')
        elif labels is None and mask is None:
            mask = torch.eye(batch_size, dtype=torch.float32).to(device)
        elif labels is not None:
            labels = labels.contiguous().view(-1, 1)
            if labels.shape[0] != batch_size:
                raise ValueError('Num of labels does not match num of features')
            mask = torch.eq(labels, labels.T).float().to(device)
        else:
            mask = mask.float().to(device)

        contrast_count = features.shape[1]
        contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
        if self.contrast_mode == 'one':
            anchor_feature = features[:, 0]
            anchor_count = 1
        elif self.contrast_mode == 'all':
            anchor_feature = contrast_feature
            anchor_count = contrast_count
        else:
            raise ValueError('Unknown mode: {}'.format(self.contrast_mode))

        # compute logits
        anchor_dot_contrast = torch.div(
            torch.matmul(anchor_feature, contrast_feature.T),
            self.temperature)
        # for numerical stability
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()

        # tile mask
        mask = mask.repeat(anchor_count, contrast_count)
        # mask out self-contrast cases
        logits_mask = torch.scatter(
            torch.ones_like(mask),
            1,
            torch.arange(batch_size * anchor_count).view(-1, 1).to(device),
            0
        )
        mask = mask * logits_mask

        # compute log_prob (the 1e-3 guards against log(0))
        exp_logits = torch.exp(logits) * logits_mask
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True) + 1e-3)

        # compute mean of log-likelihood over positives
        # (the 1e-3 guards against division by zero when a row has no positives)
        mean_log_prob_pos = (mask * log_prob).sum(1) / (mask.sum(1) + 1e-3)

        # loss
        loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
        loss = loss.view(anchor_count, batch_size).mean()

        return loss
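

# Illustrative usage sketch for SupConLoss: the batch size, feature dimension,
# and two-view setup are assumed values for demonstration. In practice the
# features would be L2-normalized projection-head outputs of augmented views.
def _demo_supcon_loss():
    bsz, n_views, dim = 8, 2, 128
    features = torch.nn.functional.normalize(
        torch.randn(bsz, n_views, dim), dim=2)
    labels = torch.randint(0, 4, (bsz,))
    criterion = SupConLoss(temperature=0.07)
    supervised = criterion(features, labels=labels)  # SupCon: same-label positives
    unsupervised = criterion(features)               # SimCLR: other views only
    print(supervised.item(), unsupervised.item())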


class SupervisedContrastiveLoss(torch.nn.Module):
    def __init__(self, temperature=0.1):
        super(SupervisedContrastiveLoss, self).__init__()
        self.temperature = temperature

    def forward(self, feature_vectors, labels):
        # Normalize feature vectors
        feature_vectors_normalized = torch.nn.functional.normalize(
            feature_vectors, p=2, dim=1)
        # Compute pairwise-similarity logits
        logits = torch.div(
            torch.matmul(
                feature_vectors_normalized,
                torch.transpose(feature_vectors_normalized, 0, 1)
            ),
            self.temperature,
        )
        # Note: NTXentLoss applies its own temperature (hard-coded to 0.07
        # here) on top of the division by self.temperature above.
        return losses.NTXentLoss(temperature=0.07)(logits, torch.squeeze(labels))
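

# Illustrative usage sketch for SupervisedContrastiveLoss: the embedding shape
# and label count are assumed values for demonstration. The wrapper expects
# raw embeddings of shape [bsz, dim] and integer class labels.
def _demo_supervised_contrastive_loss():
    embeddings = torch.randn(16, 64)
    labels = torch.randint(0, 4, (16, 1))
    criterion = SupervisedContrastiveLoss(temperature=0.1)
    print(criterion(embeddings, labels).item())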