|
a |
|
b/opengait/modeling/loss_aggregator.py |
|
|
1 |
"""The loss aggregator.""" |
|
|
2 |
|
|
|
3 |
import torch |
|
|
4 |
import torch.nn as nn |
|
|
5 |
from . import losses |
|
|
6 |
from utils import is_dict, get_attr_from, get_valid_args, is_tensor, get_ddp_module |
|
|
7 |
from utils import Odict |
|
|
8 |
from utils import get_msg_mgr |
|
|
9 |
|
|
|
10 |
|
|
|
11 |
class LossAggregator(nn.Module):
    """Combine several configured loss modules into one scalar objective.

    Each configured loss is built once and stored under its ``log_prefix``.
    At call time every entry of the training-feature dict is routed to the
    loss registered under the same key, and the weighted results are summed,
    e.g. for a triplet loss plus a cross-entropy loss:
        loss_sum = triplet_loss + cross_entropy_loss

    Attributes:
        losses: An ``nn.ModuleDict`` mapping ``log_prefix`` -> built loss module.
    """

    def __init__(self, loss_cfg) -> None:
        """Build every configured loss and register it under its log prefix.

        The losses live in an ``nn.ModuleDict``, so the aggregator can be
        indexed like a dictionary while the contained modules stay registered
        with this module; their parameters are therefore reachable through
        ``self.parameters()`` and can be trained.

        Args:
            loss_cfg: A single loss config (dict) or an iterable of loss
                configs. Each config must provide 'log_prefix' and 'type'.
        """
        super().__init__()
        # Treat the single-config case as a one-element list so both shapes
        # of input go through the same construction path.
        cfg_list = [loss_cfg] if is_dict(loss_cfg) else loss_cfg
        built = {
            single_cfg['log_prefix']: self._build_loss_(single_cfg)
            for single_cfg in cfg_list
        }
        self.losses = nn.ModuleDict(built)

    def _build_loss_(self, loss_cfg):
        """Instantiate one loss module from its config.

        Args:
            loss_cfg: Config of a single loss; 'type' names the class to look
                up in the ``losses`` package.

        Returns:
            The constructed loss, moved to CUDA and passed through
            ``get_ddp_module``.
        """
        loss_cls = get_attr_from([losses], loss_cfg['type'])
        # NOTE(review): get_valid_args presumably filters loss_cfg down to the
        # arguments loss_cls accepts, treating the listed keys specially --
        # confirm against utils.
        ctor_kwargs = get_valid_args(
            loss_cls, loss_cfg, ['type', 'gather_and_scale'])
        loss_module = loss_cls(**ctor_kwargs).cuda()
        return get_ddp_module(loss_module)

    def forward(self, training_feats):
        """Compute the total loss and collect per-loss logging info.

        For each key of ``training_feats``:
          * if a loss is registered under that key, the value is unpacked as
            keyword arguments into that loss, whose mean is scaled by the
            loss's ``loss_term_weight`` and added to the total;
          * otherwise, a tensor value is treated as an already computed loss
            and its mean is added to the total directly.

        Args:
            training_feats: A dict of features, the same as
                output["training_feat"] of the model.

        Returns:
            A tuple ``(loss_sum, loss_info)``. ``loss_sum`` is the accumulated
            loss (it stays the Python float 0.0 if ``training_feats`` is
            empty); ``loss_info`` is an ``Odict`` of values keyed by
            'scalar/<key>/<name>' or 'scalar/<key>'.

        Raises:
            ValueError: If a key has no registered loss and its value is a
                dict, or is neither a dict nor a tensor.
        """
        loss_info = Odict()
        loss_sum = .0

        for feat_key, feat_val in training_feats.items():
            if feat_key in self.losses:
                criterion = self.losses[feat_key]
                raw_loss, stats = criterion(**feat_val)
                for stat_name, stat_value in stats.items():
                    loss_info['scalar/%s/%s' % (feat_key, stat_name)] = stat_value
                weighted = raw_loss.mean() * criterion.loss_term_weight
                # `+=` is kept on purpose so accumulation semantics are unchanged.
                loss_sum += weighted
                continue

            # No loss is registered under this key.
            if isinstance(feat_val, dict):
                raise ValueError(
                    "The key %s in -Trainng-Feat- should be stated in your loss_cfg as log_prefix."%feat_key
                )
            if not is_tensor(feat_val):
                raise ValueError(
                    "Error type for -Trainng-Feat-, supported: A feature dict or loss tensor.")

            # A bare tensor is taken as a precomputed loss term.
            precomputed = feat_val.mean()
            loss_info['scalar/%s' % feat_key] = precomputed
            loss_sum += precomputed
            get_msg_mgr().log_debug(
                "Please check whether %s needed in training." % feat_key)

        return loss_sum, loss_info