Switch to side-by-side view

--- a
+++ b/opengait/modeling/losses/base.py
@@ -0,0 +1,59 @@
+from ctypes import ArgumentError
+import torch.nn as nn
+import torch
+from utils import Odict
+import functools
+from utils import ddp_all_gather
+
+
+def gather_and_scale_wrapper(func):
+    """Internal wrapper: gather the input from multple cards to one card, and scale the loss by the number of cards.
+    """
+
+    @functools.wraps(func)
+    def inner(*args, **kwds):
+        try:
+
+            for k, v in kwds.items():
+                kwds[k] = ddp_all_gather(v)
+
+            loss, loss_info = func(*args, **kwds)
+            loss *= torch.distributed.get_world_size()
+            return loss, loss_info
+        except:
+            raise ArgumentError
+    return inner
+
+
+class BaseLoss(nn.Module):
+    """
+    Base class for all losses.
+
+    Your loss should also subclass this class.
+    """
+
+    def __init__(self, loss_term_weight=1.0):
+        """
+        Initialize the base class.
+
+        Args:
+            loss_term_weight: the weight of the loss term.
+        """
+        super(BaseLoss, self).__init__()
+        self.loss_term_weight = loss_term_weight
+        self.info = Odict()
+
+    def forward(self, logits, labels):
+        """
+        The default forward function.
+
+        This function should be overridden by the subclass. 
+
+        Args:
+            logits: the logits of the model.
+            labels: the labels of the data.
+
+        Returns:
+            tuple of loss and info.
+        """
+        return .0, self.info