[cf6a9e]: / loss / BCE.py

Download this file

9 lines (7 with data), 211 Bytes

1
2
3
4
5
6
7
8
9
import torch.nn as nn
import torch
# Shared criterion instance: binary cross-entropy with PyTorch's default
# settings, reused by every call to BCE_loss.
loss_function = torch.nn.BCELoss()


def BCE_loss(input, target):
    """Return the binary cross-entropy between ``input`` and ``target``.

    ``target`` is cast to ``float32`` and moved to the device that ``input``
    lives on, then both are passed to the module-level ``loss_function``.

    Args:
        input: Tensor of predictions, passed to ``loss_function`` unchanged.
            ``torch.nn.BCELoss`` expects probabilities in ``[0, 1]``.
        target: Tensor of labels; any dtype that ``Tensor.float()`` accepts
            (e.g. integer or boolean class labels).

    Returns:
        The tensor produced by ``loss_function(input, target)``.
    """
    # Follow input's device instead of forcing CUDA: an unconditional
    # ``.cuda()`` fails on CPU-only machines and produces a device mismatch
    # whenever ``input`` is not on the default GPU.
    return loss_function(input, target.float().to(input.device))