a b/loss/BCE.py
1
import torch.nn as nn
2
import torch
3
4
# Shared criterion; BCELoss() with no arguments uses the default reduction ("mean").
loss_function = torch.nn.BCELoss()


def BCE_loss(input, target):
    """Return the binary cross-entropy between ``input`` and ``target``.

    Thin wrapper around the module-level ``torch.nn.BCELoss`` instance that
    first aligns ``target`` with ``input`` so callers may pass integer/bool
    label tensors or labels that still live on another device.

    Args:
        input: Tensor of predicted probabilities (BCELoss expects values in
            [0, 1], e.g. sigmoid outputs).
        target: Tensor of labels; BCELoss requires it to have the same shape
            as ``input``. It is cast to ``input``'s dtype and moved to
            ``input``'s device before the loss is computed.

    Returns:
        The loss tensor produced by ``loss_function`` (a 0-dim tensor under
        the default mean reduction).
    """
    # Follow ``input`` instead of hard-coding ``.cuda()``/``.float()``: this
    # works on CPU-only machines, on any CUDA device index, and for float64
    # (or other floating) inputs, which BCELoss requires to match the target.
    target = target.to(device=input.device, dtype=input.dtype)
    return loss_function(input, target)