|
a |
|
b/loss/focal.py |
|
|
1 |
import torch |
|
|
2 |
from .utils import * |
|
|
3 |
from torch import nn |
|
|
4 |
import torch.nn.functional as F |
|
|
5 |
from torch.autograd import Variable |
|
|
6 |
|
|
|
7 |
|
|
|
8 |
class FocalLoss(nn.Module): |
|
|
9 |
""" |
|
|
10 |
Focal loss for binary classification |
|
|
11 |
""" |
|
|
12 |
def __init__(self, gamma=2, alpha=0.25): |
|
|
13 |
super(FocalLoss, self).__init__() |
|
|
14 |
self.gamma = gamma |
|
|
15 |
self.alpha = alpha |
|
|
16 |
|
|
|
17 |
def forward(self, input, target): |
|
|
18 |
# input:size is M*2. M is the batch number |
|
|
19 |
# target:size is M. |
|
|
20 |
pt = torch.softmax(input, dim=1) |
|
|
21 |
p = pt[:, 1] |
|
|
22 |
loss = -self.alpha * (1 - p)**self.gamma * (target * torch.log(p)) -\ |
|
|
23 |
(1 - self.alpha) * p**self.gamma * ((1 - target) * torch.log(1 - p)) |
|
|
24 |
return loss.mean() |
|
|
25 |
|
|
|
26 |
|
|
|
27 |
class FocalLoss2d(nn.Module): |
|
|
28 |
def __init__(self, alpha=1, gamma=0, size_average=True, ignore_index=255): |
|
|
29 |
super(FocalLoss2d, self).__init__() |
|
|
30 |
self.alpha = alpha |
|
|
31 |
self.gamma = gamma |
|
|
32 |
self.ignore_index = ignore_index |
|
|
33 |
self.size_average = size_average |
|
|
34 |
|
|
|
35 |
def forward(self, inputs, targets): |
|
|
36 |
inputs = to_float_and_cuda(inputs) |
|
|
37 |
targets = to_long_and_cuda(targets) |
|
|
38 |
ce_loss = F.cross_entropy(inputs, targets,reduction='none', ignore_index=self.ignore_index) |
|
|
39 |
pt = torch.exp(-ce_loss) |
|
|
40 |
focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss |
|
|
41 |
if self.size_average: |
|
|
42 |
return focal_loss.mean() |
|
|
43 |
else: |
|
|
44 |
return focal_loss.sum() |