|
a |
|
b/opengait/modeling/losses/base.py |
|
|
1 |
from ctypes import ArgumentError |
|
|
2 |
import torch.nn as nn |
|
|
3 |
import torch |
|
|
4 |
from utils import Odict |
|
|
5 |
import functools |
|
|
6 |
from utils import ddp_all_gather |
|
|
7 |
|
|
|
8 |
|
|
|
9 |
def gather_and_scale_wrapper(func):
    """Internal wrapper: gather inputs from all cards and scale the loss by the number of cards.

    Every *keyword* argument passed to the wrapped loss is replaced by the
    result of ``ddp_all_gather`` before ``func`` is called. Positional
    arguments (e.g. ``self`` when wrapping a method) are passed through
    untouched, so tensors meant to be gathered must be passed by keyword.
    The loss returned by ``func`` is then multiplied by
    ``torch.distributed.get_world_size()``.

    Args:
        func: callable returning a ``(loss, loss_info)`` pair.

    Returns:
        The wrapped callable (metadata preserved via ``functools.wraps``),
        which returns ``(scaled_loss, loss_info)``.

    Raises:
        ArgumentError: (``ctypes.ArgumentError``, kept for backward
            compatibility) if gathering, the wrapped call, or the world-size
            query raises any ``Exception``. The original error is chained as
            ``__cause__`` so its traceback is not lost.
    """

    @functools.wraps(func)
    def inner(*args, **kwds):
        try:
            # Build a fresh dict instead of rewriting ``kwds`` while iterating it.
            gathered = {k: ddp_all_gather(v) for k, v in kwds.items()}
            loss, loss_info = func(*args, **gathered)
            # Out-of-place multiply: avoids mutating the object returned by
            # ``func`` in place (which would also change any other reference
            # to it, and fails on leaf tensors that require grad).
            loss = loss * torch.distributed.get_world_size()
        except Exception as err:
            # Narrowed from a bare ``except`` so KeyboardInterrupt/SystemExit
            # propagate unchanged; chain the real error instead of hiding it.
            name = getattr(func, "__qualname__", repr(func))
            raise ArgumentError(
                f"{name} failed inside gather_and_scale_wrapper: {err!r}"
            ) from err
        return loss, loss_info

    return inner
|
|
26 |
|
|
|
27 |
|
|
|
28 |
class BaseLoss(nn.Module):
    """
    Common parent for every loss module.

    Concrete losses should derive from this class and override ``forward``.
    """

    def __init__(self, loss_term_weight=1.0):
        """
        Set up the state shared by all losses.

        Args:
            loss_term_weight: weight attached to this loss term, stored on
                the instance as ``self.loss_term_weight``.
        """
        super().__init__()
        self.loss_term_weight = loss_term_weight
        # Info container (project ``Odict``) returned alongside the loss by ``forward``.
        self.info = Odict()

    def forward(self, logits, labels):
        """
        Placeholder forward pass; subclasses are expected to override it.

        Args:
            logits: the model's logits (unused here).
            labels: the data labels (unused here).

        Returns:
            A ``(loss, info)`` pair: the float ``0.0`` and ``self.info``.
        """
        zero_loss = 0.0
        info = self.info
        return zero_loss, info