src/loss_functions/losses.py

import torch
import torch.nn as nn


class AsymmetricLoss(nn.Module):
    def __init__(self, gamma_neg=4, gamma_pos=1, clip=0.05, eps=1e-8, disable_torch_grad_focal_loss=True):
        super(AsymmetricLoss, self).__init__()

        self.gamma_neg = gamma_neg
        self.gamma_pos = gamma_pos
        self.clip = clip
        self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss
        self.eps = eps

    def forward(self, x, y):
        """
        Parameters
        ----------
        x: input logits
        y: targets (multi-label binarized vector)
        """

        # Calculating Probabilities
        x_sigmoid = torch.sigmoid(x)
        xs_pos = x_sigmoid
        xs_neg = 1 - x_sigmoid

        # Asymmetric Clipping
        if self.clip is not None and self.clip > 0:
            xs_neg = (xs_neg + self.clip).clamp(max=1)

        # Basic CE calculation
        los_pos = y * torch.log(xs_pos.clamp(min=self.eps))
        los_neg = (1 - y) * torch.log(xs_neg.clamp(min=self.eps))
        loss = los_pos + los_neg

        # Asymmetric Focusing
        if self.gamma_neg > 0 or self.gamma_pos > 0:
            if self.disable_torch_grad_focal_loss:
                # Compute the focal weight without tracking gradients; remember
                # the previous grad mode so it can be restored afterwards
                # (unconditionally re-enabling grad would break callers running
                # under torch.no_grad())
                prev_grad_enabled = torch.is_grad_enabled()
                torch.set_grad_enabled(False)
            pt0 = xs_pos * y
            pt1 = xs_neg * (1 - y)  # pt = p if t > 0 else 1-p
            pt = pt0 + pt1
            one_sided_gamma = self.gamma_pos * y + self.gamma_neg * (1 - y)
            one_sided_w = torch.pow(1 - pt, one_sided_gamma)
            if self.disable_torch_grad_focal_loss:
                torch.set_grad_enabled(prev_grad_enabled)
            loss *= one_sided_w

        return -loss.sum()
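

# A minimal usage sketch for AsymmetricLoss (illustrative only: the batch
# size, label count, and hyperparameter values below are assumptions, not
# values prescribed by this module).
def _example_asymmetric_loss():
    criterion = AsymmetricLoss(gamma_neg=4, gamma_pos=1, clip=0.05)
    logits = torch.randn(4, 10, requires_grad=True)  # (batch, num_labels)
    targets = torch.randint(0, 2, (4, 10)).float()   # multi-hot label vectors
    loss = criterion(logits, targets)
    loss.backward()
    return loss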


class AsymmetricLossOptimized(nn.Module):
    '''Notice: optimized version that minimizes memory allocation and GPU
    transfers, and favors in-place operations.'''

    def __init__(self, gamma_neg=4, gamma_pos=1, clip=0.05, eps=1e-8, disable_torch_grad_focal_loss=False):
        super(AsymmetricLossOptimized, self).__init__()

        self.gamma_neg = gamma_neg
        self.gamma_pos = gamma_pos
        self.clip = clip
        self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss
        self.eps = eps

        # Cache these tensors across iterations to prevent per-iteration memory
        # allocation and GPU transfers, and to encourage in-place operations
        self.targets = self.anti_targets = self.xs_pos = self.xs_neg = self.asymmetric_w = self.loss = None

    def forward(self, x, y):
        """
        Parameters
        ----------
        x: input logits
        y: targets (multi-label binarized vector)
        """

        self.targets = y
        self.anti_targets = 1 - y

        # Calculating Probabilities
        self.xs_pos = torch.sigmoid(x)
        self.xs_neg = 1.0 - self.xs_pos

        # Asymmetric Clipping
        if self.clip is not None and self.clip > 0:
            self.xs_neg.add_(self.clip).clamp_(max=1)

        # Basic CE calculation
        self.loss = self.targets * torch.log(self.xs_pos.clamp(min=self.eps))
        self.loss.add_(self.anti_targets * torch.log(self.xs_neg.clamp(min=self.eps)))

        # Asymmetric Focusing
        if self.gamma_neg > 0 or self.gamma_pos > 0:
            if self.disable_torch_grad_focal_loss:
                # Compute the focal weight without tracking gradients; remember
                # the previous grad mode so it can be restored afterwards
                prev_grad_enabled = torch.is_grad_enabled()
                torch.set_grad_enabled(False)
            self.xs_pos = self.xs_pos * self.targets
            self.xs_neg = self.xs_neg * self.anti_targets
            self.asymmetric_w = torch.pow(1 - self.xs_pos - self.xs_neg,
                                          self.gamma_pos * self.targets + self.gamma_neg * self.anti_targets)
            if self.disable_torch_grad_focal_loss:
                torch.set_grad_enabled(prev_grad_enabled)
            self.loss *= self.asymmetric_w

        return -self.loss.sum()
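

# A sketch checking that the optimized version matches the reference
# implementation on the same inputs; the forward values should agree up to
# floating-point error. Shapes, hyperparameters, and the tolerance are
# illustrative assumptions.
def _example_asymmetric_loss_optimized():
    logits = torch.randn(4, 10)
    targets = torch.randint(0, 2, (4, 10)).float()
    reference = AsymmetricLoss(gamma_neg=4, gamma_pos=1, clip=0.05)
    optimized = AsymmetricLossOptimized(gamma_neg=4, gamma_pos=1, clip=0.05)
    assert torch.allclose(reference(logits, targets), optimized(logits, targets), atol=1e-5)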


class ASLSingleLabel(nn.Module):
    '''
    This loss is intended for single-label classification problems
    '''
    def __init__(self, gamma_pos=0, gamma_neg=4, eps: float = 0.1, reduction='mean'):
        super(ASLSingleLabel, self).__init__()

        self.eps = eps
        self.logsoftmax = nn.LogSoftmax(dim=-1)
        self.targets_classes = []
        self.gamma_pos = gamma_pos
        self.gamma_neg = gamma_neg
        self.reduction = reduction

    def forward(self, inputs, target):
        '''
        "inputs" dimensions: (batch_size, number_classes)
        "target" dimensions: (batch_size,)
        '''
        num_classes = inputs.size()[-1]
        log_preds = self.logsoftmax(inputs)
        # One-hot encode the integer class targets
        self.targets_classes = torch.zeros_like(inputs).scatter_(1, target.long().unsqueeze(1), 1)

        # ASL weights
        targets = self.targets_classes
        anti_targets = 1 - targets
        xs_pos = torch.exp(log_preds)
        xs_neg = 1 - xs_pos
        xs_pos = xs_pos * targets
        xs_neg = xs_neg * anti_targets
        asymmetric_w = torch.pow(1 - xs_pos - xs_neg,
                                 self.gamma_pos * targets + self.gamma_neg * anti_targets)
        log_preds = log_preds * asymmetric_w

        if self.eps > 0:  # label smoothing
            self.targets_classes = self.targets_classes.mul(1 - self.eps).add(self.eps / num_classes)

        # loss calculation
        loss = -self.targets_classes.mul(log_preds)

        loss = loss.sum(dim=-1)
        if self.reduction == 'mean':
            loss = loss.mean()

        return loss
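

# A minimal usage sketch for the single-label variant: targets here are
# integer class indices rather than multi-hot vectors. The batch size and
# class count below are illustrative assumptions.
def _example_asl_single_label():
    criterion = ASLSingleLabel(gamma_pos=0, gamma_neg=4, eps=0.1)
    logits = torch.randn(4, 10, requires_grad=True)  # (batch, num_classes)
    targets = torch.randint(0, 10, (4,))             # class indices
    loss = criterion(logits, targets)
    loss.backward()
    return loss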