a | b/src/losses.py | ||
---|---|---|---|
1 | import numpy as np |
||
2 | import torch |
||
3 | import torch.nn as nn |
||
4 | |||
5 | |||
6 | class LogLoss(nn.BCEWithLogitsLoss): |
||
7 | def __init__(self, weight=None, size_average=None, reduce=None, reduction='mean'): |
||
8 | |||
9 | if weight is None: |
||
10 | pass |
||
11 | else: |
||
12 | weight = torch.tensor(weight, requires_grad=False, dtype=torch.float32).cuda() |
||
13 | |||
14 | super(LogLoss, self).__init__( |
||
15 | weight=weight, |
||
16 | size_average=size_average, |
||
17 | reduce=reduce, |
||
18 | reduction=reduction |
||
19 | ) |