Diff of /loss/focal.py [000000] .. [f77492]


import torch
from torch import nn
import torch.nn.functional as F

from .utils import to_float_and_cuda, to_long_and_cuda


class FocalLoss(nn.Module):
    """
    Focal loss for binary classification (Lin et al., 2017):
    FL(p_t) = -alpha_t * (1 - p_t)**gamma * log(p_t)
    """
    def __init__(self, gamma=2, alpha=0.25):
        super(FocalLoss, self).__init__()
        self.gamma = gamma  # focusing parameter; down-weights easy examples
        self.alpha = alpha  # balance weight for the positive class

    def forward(self, input, target):
        # input:  M x 2 logits, where M is the batch size
        # target: M binary labels (0 or 1)
        pt = torch.softmax(input, dim=1)
        p = pt[:, 1]  # predicted probability of the positive class
        # Clamp away from 0 and 1 so log() cannot produce -inf / NaN gradients.
        p = p.clamp(min=1e-7, max=1 - 1e-7)
        target = target.float()
        loss = -self.alpha * (1 - p) ** self.gamma * (target * torch.log(p)) \
            - (1 - self.alpha) * p ** self.gamma * ((1 - target) * torch.log(1 - p))
        return loss.mean()
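
# --- Usage sketch (illustrative; not part of the original file) -----------
# A minimal smoke test for FocalLoss above; shapes follow the comments in
# forward(). The helper name below is hypothetical.
def _demo_focal_loss():
    criterion = FocalLoss(gamma=2, alpha=0.25)
    logits = torch.randn(8, 2, requires_grad=True)  # M x 2 raw scores
    labels = torch.randint(0, 2, (8,))              # M binary targets
    loss = criterion(logits, labels)
    loss.backward()  # gradients flow back through the softmax
    return loss.item()
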

class FocalLoss2d(nn.Module):
    """
    Multi-class focal loss for dense prediction (e.g. semantic
    segmentation), built on top of per-pixel cross-entropy.
    """
    def __init__(self, alpha=1, gamma=0, size_average=True, ignore_index=255):
        super(FocalLoss2d, self).__init__()
        self.alpha = alpha
        self.gamma = gamma  # gamma=0 reduces this to plain cross-entropy
        self.ignore_index = ignore_index
        self.size_average = size_average

    def forward(self, inputs, targets):
        inputs = to_float_and_cuda(inputs)
        targets = to_long_and_cuda(targets)
        # Per-element cross-entropy, i.e. ce_loss = -log(p_t).
        ce_loss = F.cross_entropy(inputs, targets, reduction='none',
                                  ignore_index=self.ignore_index)
        # Recover p_t, the predicted probability of the true class.
        pt = torch.exp(-ce_loss)
        focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss
        if self.size_average:
            # Note: ignored positions contribute 0 to the numerator but are
            # still counted in the denominator of this mean.
            return focal_loss.mean()
        else:
            return focal_loss.sum()
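
# --- Usage sketch (illustrative; not part of the original file) -----------
# Exercises FocalLoss2d on segmentation-shaped tensors. forward() routes
# tensors through to_float_and_cuda / to_long_and_cuda from .utils, so a
# CUDA device is assumed to be available. The helper name is hypothetical.
def _demo_focal_loss_2d():
    criterion = FocalLoss2d(alpha=1, gamma=2, ignore_index=255)
    logits = torch.randn(4, 21, 64, 64)         # N x C x H x W class scores
    labels = torch.randint(0, 21, (4, 64, 64))  # N x H x W class indices
    labels[:, :8, :] = 255                      # region skipped via ignore_index
    loss = criterion(logits, labels)
    return loss.item()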