[95f789]: / src / losses.py

Download this file

20 lines (15 with data), 505 Bytes

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
import numpy as np
import torch
import torch.nn as nn
class LogLoss(nn.BCEWithLogitsLoss):
    """Binary cross-entropy on logits that accepts array-like weights.

    Thin wrapper around ``nn.BCEWithLogitsLoss``. ``weight`` (and
    ``pos_weight``) may be given as a list, a NumPy array, or a tensor; each
    is converted to a detached ``float32`` tensor on ``device`` before being
    handed to the parent constructor.

    Args:
        weight: Optional rescaling weight, forwarded to
            ``nn.BCEWithLogitsLoss`` after conversion. ``None`` disables it.
        size_average: Forwarded unchanged.
        reduce: Forwarded unchanged.
        reduction: Forwarded unchanged (``'mean'`` by default).
        pos_weight: Optional positive-class weight, converted the same way as
            ``weight`` and forwarded to the parent. Defaults to ``None``.
        device: Where to place the converted weights. Defaults to CUDA when
            ``torch.cuda.is_available()``, otherwise CPU.
    """

    def __init__(self, weight=None, size_average=None, reduce=None,
                 reduction='mean', *, pos_weight=None, device=None):
        if device is None:
            # Preserve the historical placement (current CUDA device) when a
            # GPU exists, but fall back to CPU instead of crashing without one.
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        super().__init__(
            weight=self._as_float_tensor(weight, device),
            size_average=size_average,
            reduce=reduce,
            reduction=reduction,
            pos_weight=self._as_float_tensor(pos_weight, device),
        )

    @staticmethod
    def _as_float_tensor(value, device):
        """Return ``value`` as a detached float32 tensor copy on ``device``.

        ``None`` is passed through unchanged. Existing tensors are copied via
        ``detach().to(copy=True)`` because ``torch.tensor(tensor)`` warns; the
        copy keeps later in-place edits of the caller's object from leaking
        into the loss.
        """
        if value is None:
            return None
        if isinstance(value, torch.Tensor):
            return value.detach().to(device=device, dtype=torch.float32, copy=True)
        return torch.tensor(value, dtype=torch.float32, device=device)