[98e649]: / libs / losses / create_losses.py

Download this file

26 lines (19 with data), 840 Bytes

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import torch
import torch.nn as nn
from libs.losses.df_loss import EuclideanLossWithOHEM
from libs.losses.mag_angle_loss import EuclideanAngleLossWithOHEM
from libs.losses.surface_loss import SurfaceLoss
class Total_loss:
    """Combine the distance-field loss with an optional boundary (surface) loss.

    Args:
        boundary: When truthy, a ``SurfaceLoss`` is built here and evaluated in
            ``__call__``. Otherwise the boundary term is a constant zero tensor.
        idc: Class indices forwarded to ``SurfaceLoss(idc=...)``. The default
            ``(1, 2, 3)`` matches the value that was previously hard-coded.
            This argument is only used when ``boundary`` is truthy.
    """

    def __init__(self, boundary=False, idc=(1, 2, 3)):
        self.df_loss = EuclideanAngleLossWithOHEM()
        self.boundary = boundary
        if boundary:
            # Pass a fresh list, as the original literal did. This keeps the
            # argument type unchanged for SurfaceLoss and avoids sharing the
            # caller's sequence object with it.
            self.boundary_loss = SurfaceLoss(idc=list(idc))

    def __call__(self, net_logit, dist_maps, df_out, gts_df, gts):
        """Compute ``(df_loss, b_loss)`` for one batch.

        Args:
            net_logit: Network logits. Softmax is taken over dim 1, presumably
                the class axis of an (N, C, ...) tensor (confirm with caller).
                When the boundary loss is disabled, only its ``.device`` is read.
            dist_maps: Forwarded unchanged to the boundary loss.
            df_out: Prediction passed as the first argument to the
                distance-field loss.
            gts_df: Target passed as the second argument to the distance-field
                loss.
            gts: Ground-truth labels. A singleton axis is inserted at dim 1
                (``gts[:, None, ...]``) before they go to the distance-field
                loss. They are passed as-is to the boundary loss.

        Returns:
            A tuple ``(df_loss, b_loss)``. When ``self.boundary`` is falsy,
            ``b_loss`` is ``torch.tensor([0.])``, with shape ``(1,)`` and the
            default dtype, placed on ``net_logit.device``.
        """
        df_loss = self.df_loss(df_out, gts_df, gts[:, None, ...])
        if self.boundary:
            net_prob = nn.functional.softmax(net_logit, dim=1)
            b_loss = self.boundary_loss(net_prob, dist_maps, gts)
        else:
            # Return a tensor rather than a Python 0 on the logits' device,
            # presumably so callers can combine and log both terms uniformly.
            b_loss = torch.tensor([0.], device=net_logit.device)
        return df_loss, b_loss