Diff of /loss/focal.py [000000] .. [f77492]

--- a
+++ b/loss/focal.py
@@ -0,0 +1,53 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+from .utils import to_float_and_cuda, to_long_and_cuda
+
+
+class FocalLoss(nn.Module):
+    """
+    Focal loss for binary classification
+    """
+    def __init__(self, gamma=2, alpha=0.25):
+        super(FocalLoss, self).__init__()
+        self.gamma = gamma
+        self.alpha = alpha
+
+    def forward(self, input, target):
+        # input:  shape (M, 2) raw logits, where M is the batch size
+        # target: shape (M,) with binary labels in {0, 1}
+        pt = torch.softmax(input, dim=1)
+        # probability of the positive class, clamped to avoid log(0)
+        p = pt[:, 1].clamp(min=1e-7, max=1 - 1e-7)
+        loss = -self.alpha * (1 - p) ** self.gamma * (target * torch.log(p)) \
+            - (1 - self.alpha) * p ** self.gamma * ((1 - target) * torch.log(1 - p))
+        return loss.mean()
+
+
+class FocalLoss2d(nn.Module):
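+    """
+    Focal loss for multi-class logits (e.g. 2-D segmentation maps),
+    computed on top of F.cross_entropy.
+    """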
+    def __init__(self, alpha=1, gamma=0, size_average=True, ignore_index=255):
+        super(FocalLoss2d, self).__init__()
+        self.alpha = alpha
+        self.gamma = gamma
+        self.ignore_index = ignore_index
+        self.size_average = size_average
+
+    def forward(self, inputs, targets):
+        inputs = to_float_and_cuda(inputs)
+        targets = to_long_and_cuda(targets)
+        # per-element cross-entropy; entries at ignore_index are set to zero
+        ce_loss = F.cross_entropy(inputs, targets, reduction='none',
+                                  ignore_index=self.ignore_index)
+        # ce_loss = -log(p_t), so exp(-ce_loss) recovers the probability
+        # the model assigns to the true class
+        pt = torch.exp(-ce_loss)
+        focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss
+        if self.size_average:
+            return focal_loss.mean()
+        else:
+            return focal_loss.sum()
\ No newline at end of file
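
A minimal usage sketch of both classes (assuming the package is importable as loss.focal and, for FocalLoss2d, that a CUDA device is available, since its forward pass moves tensors to the GPU through the utils helpers):

import torch
from loss.focal import FocalLoss, FocalLoss2d

# Binary case: (M, 2) logits and 0/1 labels of shape (M,).
logits = torch.randn(8, 2)
labels = torch.randint(0, 2, (8,))
print(FocalLoss(gamma=2, alpha=0.25)(logits, labels))

# Segmentation case: (N, C, H, W) logits and (N, H, W) class indices.
seg_logits = torch.randn(2, 5, 16, 16)
seg_labels = torch.randint(0, 5, (2, 16, 16))
print(FocalLoss2d(alpha=1, gamma=2)(seg_logits, seg_labels))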