a | b/loss/BCE.py | ||
---|---|---|---|
1 | import torch.nn as nn |
||
2 | import torch |
||
3 | |||
4 | loss_function = torch.nn.BCELoss() |
||
5 | |||
6 | def BCE_loss(input, target): |
||
7 | # input = input.cuda() |
||
8 | # target = target.cuda() |
||
9 | return loss_function(input, target.float().cuda()) |