loss/focal.py

import torch
from torch import nn
import torch.nn.functional as F

from .utils import to_float_and_cuda, to_long_and_cuda


class FocalLoss(nn.Module):
    """
    Focal loss for binary classification.

    Expects raw logits of shape (M, 2), where M is the batch size,
    and binary targets of shape (M,).
    """

    def __init__(self, gamma=2, alpha=0.25):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha

    def forward(self, input, target):
        # input: (M, 2) logits; target: (M,) labels in {0, 1}.
        pt = torch.softmax(input, dim=1)
        p = pt[:, 1]  # probability of the positive class
        # Clamp away from 0 and 1 so log() cannot produce -inf/NaN
        # when the softmax saturates.
        p = p.clamp(min=1e-7, max=1 - 1e-7)
        loss = -self.alpha * (1 - p) ** self.gamma * (target * torch.log(p)) \
               - (1 - self.alpha) * p ** self.gamma * ((1 - target) * torch.log(1 - p))
        return loss.mean()


class FocalLoss2d(nn.Module):
    """
    Focal loss for dense multi-class outputs (e.g. semantic segmentation),
    computed on top of per-pixel cross-entropy.
    """

    def __init__(self, alpha=1, gamma=0, size_average=True, ignore_index=255):
        super(FocalLoss2d, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.ignore_index = ignore_index
        self.size_average = size_average

    def forward(self, inputs, targets):
        inputs = to_float_and_cuda(inputs)
        targets = to_long_and_cuda(targets)
        # Per-element cross-entropy; positions equal to ignore_index
        # contribute a loss of 0.
        ce_loss = F.cross_entropy(inputs, targets, reduction='none',
                                  ignore_index=self.ignore_index)
        # pt is the model's probability for the true class, recovered
        # from the cross-entropy value: ce = -log(pt).
        pt = torch.exp(-ce_loss)
        focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss
        if self.size_average:
            return focal_loss.mean()
        return focal_loss.sum()
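

# ---------------------------------------------------------------------------
# Usage sketch (illustrative; not part of the original module). The random
# tensors below are stand-ins for real model outputs, and the .utils helpers
# called by FocalLoss2d appear to move tensors to the GPU, so that example
# assumes a CUDA device is available.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Binary case: (M, 2) logits against (M,) 0/1 labels.
    logits = torch.randn(8, 2)
    labels = torch.randint(0, 2, (8,)).float()
    print("FocalLoss:", FocalLoss(gamma=2, alpha=0.25)(logits, labels).item())

    # Dense case: (N, C, H, W) logits against (N, H, W) class-index labels.
    seg_logits = torch.randn(2, 3, 4, 4)
    seg_labels = torch.randint(0, 3, (2, 4, 4))
    print("FocalLoss2d:",
          FocalLoss2d(alpha=1, gamma=2)(seg_logits, seg_labels).item())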