Diff of /src/losses.py [000000] .. [95f789]

Switch to side-by-side view

--- a
+++ b/src/losses.py
@@ -0,0 +1,19 @@
+import numpy as np
+import torch
+import torch.nn as nn
+
+
+class LogLoss(nn.BCEWithLogitsLoss):
+    def __init__(self, weight=None, size_average=None, reduce=None, reduction='mean'):
+
+        if weight is None:
+            pass
+        else:
+            weight = torch.tensor(weight, requires_grad=False, dtype=torch.float32).cuda()
+
+        super(LogLoss, self).__init__(
+            weight=weight,
+            size_average=size_average,
+            reduce=reduce,
+            reduction=reduction
+        )