[66de0a]: / opengait / modeling / losses / base.py

Download this file

60 lines (45 with data), 1.4 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
from ctypes import ArgumentError
import torch.nn as nn
import torch
from utils import Odict
import functools
from utils import ddp_all_gather
def gather_and_scale_wrapper(func):
    """Internal wrapper: gather the inputs from multiple cards onto one card and
    scale the loss by the number of cards.

    Every keyword argument passed to the wrapped callable is replaced by
    ``ddp_all_gather(value)`` before ``func`` is invoked; positional arguments
    are forwarded unchanged. The loss returned by ``func`` is then multiplied
    by ``torch.distributed.get_world_size()``.

    Args:
        func: a loss callable returning a ``(loss, loss_info)`` pair.

    Returns:
        The wrapped callable (metadata copied from ``func`` via
        ``functools.wraps``), returning ``(scaled_loss, loss_info)``.

    Raises:
        ArgumentError: (``ctypes.ArgumentError``, kept for backward
            compatibility) if gathering, the wrapped call, or the world-size
            query raises an ``Exception``. The original exception is chained
            as ``__cause__`` so its traceback is preserved.
    """

    @functools.wraps(func)
    def inner(*args, **kwds):
        try:
            # Reassigning existing keys while iterating is safe: the dict's
            # size does not change.
            for k, v in kwds.items():
                kwds[k] = ddp_all_gather(v)
            loss, loss_info = func(*args, **kwds)
            # Out-of-place multiply: an in-place ``*=`` on a tensor that
            # autograd tracks can trip its in-place/version checks.
            loss = loss * torch.distributed.get_world_size()
            return loss, loss_info
        except Exception as err:
            # ``Exception`` rather than a bare ``except`` so KeyboardInterrupt /
            # SystemExit propagate untouched; chain the real cause instead of
            # discarding it.
            raise ArgumentError(
                f"{getattr(func, '__name__', func)}: failed while gathering "
                f"inputs or scaling the loss: {err!r}"
            ) from err

    return inner
class BaseLoss(nn.Module):
    """
    Common parent class for the losses in this package.

    Concrete losses subclass this and override ``forward``. The base
    implementation returns a zero loss together with ``self.info``.
    """

    def __init__(self, loss_term_weight=1.0):
        """
        Set up the state shared by every loss.

        Args:
            loss_term_weight: weight for this loss term; stored on the
                instance as ``loss_term_weight`` (not applied in this class).
        """
        super().__init__()
        self.loss_term_weight = loss_term_weight
        # Info container handed back alongside the loss by ``forward``.
        self.info = Odict()

    def forward(self, logits, labels):
        """
        Default forward pass; subclasses are expected to override it.

        Args:
            logits: model outputs (ignored by this default implementation).
            labels: data labels (ignored by this default implementation).

        Returns:
            A ``(0.0, self.info)`` tuple.
        """
        zero_loss = 0.0
        return zero_loss, self.info