Diff of /src/losses.py [000000] .. [95f789]

Switch to unified view

a b/src/losses.py
1
import numpy as np
2
import torch
3
import torch.nn as nn
4
5
6
class LogLoss(nn.BCEWithLogitsLoss):
7
    def __init__(self, weight=None, size_average=None, reduce=None, reduction='mean'):
8
9
        if weight is None:
10
            pass
11
        else:
12
            weight = torch.tensor(weight, requires_grad=False, dtype=torch.float32).cuda()
13
14
        super(LogLoss, self).__init__(
15
            weight=weight,
16
            size_average=size_average,
17
            reduce=reduce,
18
            reduction=reduction
19
        )