# opengait/modeling/losses/base.py
1
from ctypes import ArgumentError
2
import torch.nn as nn
3
import torch
4
from utils import Odict
5
import functools
6
from utils import ddp_all_gather
7
8
9
def gather_and_scale_wrapper(func):
    """Internal wrapper: gather the inputs from multiple cards and scale the loss by the number of cards.

    Every *keyword* argument passed to the wrapped callable is replaced by
    ``ddp_all_gather(value)`` before the call; positional arguments (e.g.
    ``self`` when decorating a method) are forwarded unchanged. The wrapped
    callable must return a ``(loss, loss_info)`` pair, and ``loss`` is
    multiplied by ``torch.distributed.get_world_size()`` before being returned.

    NOTE(review): ``ddp_all_gather`` lives in ``utils`` and is not visible
    here; scaling by the world size presumably offsets gradient averaging
    across ranks -- confirm against the trainer.

    Args:
        func: callable returning ``(loss, loss_info)``.

    Returns:
        The wrapped callable (name/docstring preserved via ``functools.wraps``).

    Raises (when the wrapper is called):
        ArgumentError: if gathering, the wrapped call, or the world-size query
            raises an ``Exception``. The original exception is attached as
            ``__cause__``. ``KeyboardInterrupt``/``SystemExit`` are no longer
            intercepted and propagate unchanged.
    """

    @functools.wraps(func)
    def inner(*args, **kwds):
        try:
            # Only keyword arguments are gathered across cards.
            gathered = {k: ddp_all_gather(v) for k, v in kwds.items()}
            loss, loss_info = func(*args, **gathered)
            # Out-of-place multiply: an in-place ``*=`` on a tensor would also
            # change any other reference sharing its storage (possibly one a
            # subclass keeps in loss_info).
            loss = loss * torch.distributed.get_world_size()
        except Exception as err:
            # ArgumentError is kept so existing callers that catch it still
            # work; the real cause is chained instead of being hidden.
            name = getattr(func, "__qualname__", repr(func))
            raise ArgumentError(
                f"{name} failed inside gather_and_scale_wrapper: {err!r}"
            ) from err
        return loss, loss_info
    return inner
26
27
28
class BaseLoss(nn.Module):
    """
    Common parent class for the losses in this package.

    Derive your own loss from this class and override ``forward``.
    """

    def __init__(self, loss_term_weight=1.0):
        """
        Record the term weight and create an empty info container.

        Args:
            loss_term_weight: weight associated with this loss term. It is
                only stored here; NOTE(review): where it is applied is not
                visible in this class -- presumably by the caller or subclass.
        """
        super().__init__()
        self.loss_term_weight = loss_term_weight
        # Odict is the project's info container type, imported from utils.
        self.info = Odict()

    def forward(self, logits, labels):
        """
        Default forward pass that subclasses are expected to override.

        Args:
            logits: model outputs (ignored by this default implementation).
            labels: data labels (ignored by this default implementation).

        Returns:
            A ``(loss, info)`` tuple: the constant ``0.0`` and ``self.info``.
        """
        zero_loss = 0.0
        return zero_loss, self.info