--- a +++ b/loss/focal.py @@ -0,0 +1,44 @@ +import torch +from .utils import * +from torch import nn +import torch.nn.functional as F +from torch.autograd import Variable + + +class FocalLoss(nn.Module): + """ + Focal loss for binary classification + """ + def __init__(self, gamma=2, alpha=0.25): + super(FocalLoss, self).__init__() + self.gamma = gamma + self.alpha = alpha + + def forward(self, input, target): + # input:size is M*2. M is the batch number + # target:size is M. + pt = torch.softmax(input, dim=1) + p = pt[:, 1] + loss = -self.alpha * (1 - p)**self.gamma * (target * torch.log(p)) -\ + (1 - self.alpha) * p**self.gamma * ((1 - target) * torch.log(1 - p)) + return loss.mean() + + +class FocalLoss2d(nn.Module): + def __init__(self, alpha=1, gamma=0, size_average=True, ignore_index=255): + super(FocalLoss2d, self).__init__() + self.alpha = alpha + self.gamma = gamma + self.ignore_index = ignore_index + self.size_average = size_average + + def forward(self, inputs, targets): + inputs = to_float_and_cuda(inputs) + targets = to_long_and_cuda(targets) + ce_loss = F.cross_entropy(inputs, targets,reduction='none', ignore_index=self.ignore_index) + pt = torch.exp(-ce_loss) + focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss + if self.size_average: + return focal_loss.mean() + else: + return focal_loss.sum() \ No newline at end of file