Switch to side-by-side view

--- a
+++ b/adpkd_segmentation/utils/criterions.py
@@ -0,0 +1,31 @@
+"""Dictionary of criterions"""
+
+
+class LossesMetrics:
+    """ Produces function to generate dict of keys: losses/metrics for batch"""
+
+    def __init__(self, criterions_dict, requires_extra_info=None):
+        """
+        Args:
+            criterions_dict {dict} -- key (str) : criterion_losses
+            requires_extra_info {list or None}, keys in `criterions_dict`
+                noting losses which require extra information
+        """
+        self.criterions_dict = criterions_dict
+        if requires_extra_info is None:
+            self.requires_extra_info = set()
+        else:
+            self.requires_extra_info = set(requires_extra_info)
+
+    def __call__(self):
+        def losses_dict(y_hat, y, extra_dict=None):
+            res = {}
+            for c_name, criterion in self.criterions_dict.items():
+                # optional info for some criterions
+                if c_name in self.requires_extra_info:
+                    res[c_name] = criterion(y_hat, y, extra_dict)
+                else:
+                    res[c_name] = criterion(y_hat, y)
+            return res
+
+        return losses_dict