Switch to side-by-side view

--- a
+++ b/libs/losses/create_losses.py
@@ -0,0 +1,25 @@
+import torch
+import torch.nn as nn
+
+from libs.losses.df_loss import EuclideanLossWithOHEM
+from libs.losses.mag_angle_loss import EuclideanAngleLossWithOHEM
+from libs.losses.surface_loss import SurfaceLoss
+
+class Total_loss():
+    def __init__(self, boundary=False):
+        self.df_loss = EuclideanAngleLossWithOHEM()
+        self.boundary = boundary
+        if boundary:
+            self.boundary_loss = SurfaceLoss(idc=[1,2,3])
+        
+    def __call__(self, net_logit, dist_maps, df_out, gts_df, gts):
+        df_loss = self.df_loss(df_out, gts_df, gts[:, None, ...])
+
+        if self.boundary:
+            net_prob = nn.functional.softmax(net_logit, dim=1)
+            b_loss = self.boundary_loss(net_prob, dist_maps, gts)
+        else:
+            b_loss = torch.tensor([0.], device=net_logit.device)
+
+        return df_loss, b_loss
+