[6969be]: / rocaseg / components / losses.py

Download this file

50 lines (40 with data), 1.3 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import logging
from torch import nn
# Module-level logger named 'losses'; it emits everything from DEBUG upwards.
logger = logging.getLogger('losses')
logger.setLevel(logging.DEBUG)
# Install the default root handler (no-op if the root logger already has one).
logging.basicConfig()
class CrossEntropyLoss(nn.Module):
    """Thin wrapper around ``torch.nn.CrossEntropyLoss``.

    The constructor accepts a wider set of arguments than it uses. Only
    ``class_weight`` is forwarded to the underlying criterion; the other
    arguments are stored as attributes and do not influence ``forward``.
    """

    def __init__(self, num_classes, batch_avg=True, batch_weight=None,
                 class_avg=True, class_weight=None, **kwargs):
        """
        Parameters
        ----------
        num_classes:
            Number of classes. Stored only; not used in the computation.
        batch_avg:
            Whether to average over the batch dimension.
            Stored only; not used in the computation.
        batch_weight:
            Batch samples importance coefficients.
            Stored only; not used in the computation.
        class_avg:
            Whether to average over the class dimension.
            Stored only; not used in the computation.
        class_weight:
            Classes importance coefficients. Passed as ``weight`` to
            ``nn.CrossEntropyLoss``.
        **kwargs:
            Any other arguments. They are ignored; a warning listing them
            is logged when at least one is given.
        """
        super().__init__()
        self.num_classes = num_classes
        self.batch_avg = batch_avg
        self.class_avg = class_avg
        self.batch_weight = batch_weight
        self.class_weight = class_weight

        # Warn only when there actually are unexpected arguments; previously
        # the warning was emitted on every construction, even for `{}`.
        if kwargs:
            logger.warning('Redundant loss function arguments:\n%s',
                           repr(kwargs))

        self.ce = nn.CrossEntropyLoss(weight=class_weight)

    def forward(self, input_, target, **kwargs):
        """
        Parameters
        ----------
        input_: (b, ch, d0, d1) tensor
        target: (b, d0, d1) tensor
        **kwargs:
            Accepted for call-site compatibility and ignored.

        Returns
        -------
        out: float tensor
            Value returned by the wrapped ``nn.CrossEntropyLoss``.
        """
        return self.ce(input_, target)