Switch to side-by-side view

--- a
+++ b/opengait/modeling/loss_aggregator.py
@@ -0,0 +1,86 @@
+"""The loss aggregator."""
+
+import torch
+import torch.nn as nn
+from . import losses
+from utils import is_dict, get_attr_from, get_valid_args, is_tensor, get_ddp_module
+from utils import Odict
+from utils import get_msg_mgr
+
+
class LossAggregator(nn.Module):
    """The loss aggregator.

    This class is used to aggregate the losses.
    For example, if you have two losses, one is triplet loss, the other is
    cross entropy loss, you can aggregate them as follows:
        loss_sum = triplet_loss + cross_entropy_loss

    Attributes:
        losses: An nn.ModuleDict mapping each loss's ``log_prefix`` to the
            built loss module.
    """

    def __init__(self, loss_cfg) -> None:
        """Initialize the loss aggregator.

        LossAggregator can be indexed like a regular Python dictionary,
        but modules it contains are properly registered, and will be visible
        by all Module methods. All parameters registered in losses can be
        accessed by the method 'self.parameters()', thus they can be trained
        properly.

        Args:
            loss_cfg: Config of losses. A single dict for one loss, or a
                list of dicts for multiple losses. Each config must carry a
                'log_prefix' key used as the loss's name.
        """
        super().__init__()
        # Normalize the single-loss case (a bare dict) to a one-element list
        # so both cases share the same construction path.
        if is_dict(loss_cfg):
            loss_cfg = [loss_cfg]
        self.losses = nn.ModuleDict(
            {cfg['log_prefix']: self._build_loss_(cfg) for cfg in loss_cfg})

    def _build_loss_(self, loss_cfg):
        """Build one loss module from its config.

        Args:
            loss_cfg: Config of a single loss; 'type' names the loss class
                inside the ``losses`` package.

        Returns:
            The instantiated loss, moved to CUDA and wrapped for DDP.
        """
        Loss = get_attr_from([losses], loss_cfg['type'])
        # 'type' and 'gather_and_scale' are meta keys, not constructor
        # arguments, so they are filtered out of the ctor kwargs.
        valid_loss_arg = get_valid_args(
            Loss, loss_cfg, ['type', 'gather_and_scale'])
        return get_ddp_module(Loss(**valid_loss_arg).cuda())

    def forward(self, training_feats):
        """Compute the weighted sum of all losses.

        The input is a dict of features. The key is the name of a loss and
        the value is the feature/label kwargs for it. If the key is not in
        the built losses and the value is a torch.Tensor, it is treated as a
        pre-computed loss and added to loss_sum directly.

        Args:
            training_feats: A dict of features. The same as the
                output["training_feat"] of the model.

        Returns:
            A (loss_sum, loss_info) pair: the scalar total loss and an
            Odict of per-loss scalars for logging.

        Raises:
            ValueError: If a dict value has no matching built loss, or a
                value is neither a dict nor a tensor.
        """
        loss_sum = 0.
        loss_info = Odict()

        for k, v in training_feats.items():
            if k in self.losses:
                loss_func = self.losses[k]
                loss, info = loss_func(**v)
                for name, value in info.items():
                    loss_info['scalar/%s/%s' % (k, name)] = value
                # Reduce to a scalar and apply the per-loss weight.
                loss = loss.mean() * loss_func.loss_term_weight
                loss_sum += loss
            elif is_tensor(v):
                # A bare tensor is taken as an already-computed loss term.
                loss = v.mean()
                loss_info['scalar/%s' % k] = loss
                loss_sum += loss
                get_msg_mgr().log_debug(
                    "Please check whether %s needed in training." % k)
            elif isinstance(v, dict):
                raise ValueError(
                    "The key %s in -Training-Feat- should be stated in your loss_cfg as log_prefix." % k
                )
            else:
                raise ValueError(
                    "Error type for -Training-Feat-, supported: A feature dict or loss tensor.")

        return loss_sum, loss_info