opengait/modeling/loss_aggregator.py
"""The loss aggregator."""

import torch
import torch.nn as nn
from . import losses
from utils import is_dict, get_attr_from, get_valid_args, is_tensor, get_ddp_module
from utils import Odict
from utils import get_msg_mgr


class LossAggregator(nn.Module):
    """The loss aggregator.

    This class is used to aggregate the losses. For example, if you have two
    losses, one a triplet loss and the other a cross-entropy loss, they are
    aggregated as:
        loss_sum = triplet_loss + cross_entropy_loss

    Attributes:
        losses: A dict of losses.
    """
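    # Illustrative loss_cfg forms ('type' and 'log_prefix' are read by this class; the
    # remaining keys, e.g. 'margin' below, are assumptions forwarded to the loss class):
    #   single loss:     {'type': 'TripletLoss', 'margin': 0.2,
    #                     'loss_term_weight': 1.0, 'log_prefix': 'triplet'}
    #   multiple losses: a list of such dicts, one per loss term.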
    def __init__(self, loss_cfg) -> None:
        """Initialize the loss aggregator.

        LossAggregator can be indexed like a regular Python dictionary,
        but the modules it contains are properly registered and will be visible to all Module methods.
        All parameters registered in the losses can be accessed by the method 'self.parameters()',
        so they can be trained properly.

        Args:
            loss_cfg: Config of losses. A list for multiple losses.
        """
        super().__init__()
        self.losses = nn.ModuleDict({loss_cfg['log_prefix']: self._build_loss_(loss_cfg)} if is_dict(loss_cfg) \
            else {cfg['log_prefix']: self._build_loss_(cfg) for cfg in loss_cfg})

    def _build_loss_(self, loss_cfg):
        """Build a loss from loss_cfg.

        Args:
            loss_cfg: Config of a single loss.
        """
        # Look up the loss class in the losses module by its 'type' name.
        Loss = get_attr_from([losses], loss_cfg['type'])
        # Keep only the cfg entries that match the loss constructor's arguments.
        valid_loss_arg = get_valid_args(
            Loss, loss_cfg, ['type', 'gather_and_scale'])
        # Instantiate the loss on GPU and wrap it for distributed training.
        loss = get_ddp_module(Loss(**valid_loss_arg).cuda())
        return loss

    def forward(self, training_feats):
        """Compute the sum of all losses.

        The input is a dict of features. Each key is the name of a loss, and its value holds the
        corresponding features and labels. If a key is not among the built losses and its value is
        a torch.Tensor, that value is treated as an already-computed loss and added to loss_sum.

        Args:
            training_feats: A dict of features. The same as the output["training_feat"] of the model.
        """
        loss_sum = 0.
        loss_info = Odict()

        for k, v in training_feats.items():
            if k in self.losses:
                # A registered loss: compute it from the provided features and labels.
                loss_func = self.losses[k]
                loss, info = loss_func(**v)
                for name, value in info.items():
                    loss_info['scalar/%s/%s' % (k, name)] = value
                loss = loss.mean() * loss_func.loss_term_weight
                loss_sum += loss

            else:
                if isinstance(v, dict):
                    raise ValueError(
                        "The key %s in -Training-Feat- should be stated in your loss_cfg as log_prefix." % k
                    )
                elif is_tensor(v):
                    # An already-computed loss tensor: log it and add it to the sum directly.
                    _ = v.mean()
                    loss_info['scalar/%s' % k] = _
                    loss_sum += _
                    get_msg_mgr().log_debug(
                        "Please check whether %s is needed in training." % k)
                else:
                    raise ValueError(
                        "Error type for -Training-Feat-, supported: A feature dict or loss tensor.")

        return loss_sum, loss_info
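

# Minimal usage sketch (illustrative only; the cfg values, loss names, and the
# `model_output` variable below are assumptions, not part of this module):
#
#   loss_cfg = [
#       {'type': 'TripletLoss', 'loss_term_weight': 1.0, 'log_prefix': 'triplet'},
#       {'type': 'CrossEntropyLoss', 'loss_term_weight': 0.1, 'log_prefix': 'softmax'},
#   ]
#   loss_aggregator = LossAggregator(loss_cfg)
#   loss_sum, loss_info = loss_aggregator(model_output['training_feat'])
#   loss_sum.backward()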