[66de0a]: / opengait / modeling / losses / bce.py

Download this file

42 lines (33 with data), 1.2 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import torch
from .base import BaseLoss
from evaluation import mean_iou
class BinaryCrossEntropyLoss(BaseLoss):
    """Element-wise binary cross-entropy evaluated directly on probabilities.

    Despite the parameter name ``logits``, ``forward`` applies no sigmoid:
    the values are passed straight into ``log``. Callers must therefore
    supply probabilities in ``[0, 1]``; anything outside that range makes
    one of the ``log`` arguments negative and the loss NaN.
    """

    def __init__(self, loss_term_weight=1.0, eps=1.0e-9):
        """Store the numerical-stability constant and initialise the base loss.

        Args:
            loss_term_weight: forwarded unchanged to ``BaseLoss.__init__``.
            eps: small constant added inside each ``log`` so that a
                prediction of exactly 0 or 1 does not evaluate ``log(0)``.
        """
        super().__init__(loss_term_weight)
        self.eps = eps

    def forward(self, logits, labels):
        """Compute the mean BCE and record monitoring statistics.

        Args:
            logits: [n, 1, h, w] predicted foreground probabilities
                (already squashed into [0, 1]; see the class docstring).
            labels: [n, 1, h, w] binary targets.

        Returns:
            A ``(mean_loss, info)`` tuple. ``mean_loss`` is the scalar mean
            over every element and stays attached to the autograd graph.
            ``info`` is ``self.info`` after being updated with detached
            copies of ``'loss'``, ``'hard_loss'`` and ``'miou'``.
        """
        # NOTE: a sigmoid used to be applied here (it was left commented out
        # in the original code); the inputs are now consumed as probabilities.
        labels = labels.float()
        logits = logits.float()

        loss = -(labels * torch.log(logits + self.eps) +
                 (1 - labels) * torch.log(1. - logits + self.eps))

        n = loss.size(0)
        # reshape (not view): the element-wise result can inherit a
        # non-contiguous layout from its inputs, which view() rejects.
        loss = loss.reshape(n, -1)
        mean_loss = loss.mean()
        # Largest single-element loss, kept for monitoring only.
        hard_loss = loss.max()
        # Predictions are binarised at 0.5 before being scored by the
        # project's mean_iou helper (defined outside this file).
        miou = mean_iou((logits > 0.5).float(), labels)

        # self.info is not created in this class; presumably it is provided
        # by BaseLoss -- confirm against base.py.
        self.info.update({
            'loss': mean_loss.detach().clone(),
            'hard_loss': hard_loss.detach().clone(),
            'miou': miou.detach().clone()})

        return mean_loss, self.info
if __name__ == "__main__":
    # Smoke test. Because this module uses a relative import, it has to be
    # launched as part of its package (python -m ...) rather than as a
    # stand-alone script.
    loss_func = BinaryCrossEntropyLoss()
    # The loss takes log(p) and log(1 - p), so it needs probabilities in
    # [0, 1]. Raw randn values (negative or > 1) would yield NaN, hence the
    # sigmoid.
    ipts = torch.sigmoid(torch.randn(1, 1, 128, 64))
    tags = (torch.randn(1, 1, 128, 64) > 0.).float()
    loss = loss_func(ipts, tags)
    print(loss)