--- a +++ b/loss/BCE.py @@ -0,0 +1,9 @@ +import torch.nn as nn +import torch + +loss_function = torch.nn.BCELoss() + +def BCE_loss(input, target): + # input = input.cuda() + # target = target.cuda() + return loss_function(input, target.float().cuda()) \ No newline at end of file