--- a +++ b/libs/losses/create_losses.py @@ -0,0 +1,25 @@ +import torch +import torch.nn as nn + +from libs.losses.df_loss import EuclideanLossWithOHEM +from libs.losses.mag_angle_loss import EuclideanAngleLossWithOHEM +from libs.losses.surface_loss import SurfaceLoss + +class Total_loss(): + def __init__(self, boundary=False): + self.df_loss = EuclideanAngleLossWithOHEM() + self.boundary = boundary + if boundary: + self.boundary_loss = SurfaceLoss(idc=[1,2,3]) + + def __call__(self, net_logit, dist_maps, df_out, gts_df, gts): + df_loss = self.df_loss(df_out, gts_df, gts[:, None, ...]) + + if self.boundary: + net_prob = nn.functional.softmax(net_logit, dim=1) + b_loss = self.boundary_loss(net_prob, dist_maps, gts) + else: + b_loss = torch.tensor([0.], device=net_logit.device) + + return df_loss, b_loss +