--- a +++ b/utils/losses.py @@ -0,0 +1,103 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +def dice_loss(score, target): + target = target.float() + smooth = 1e-5 + intersect = torch.sum(score * target) + y_sum = torch.sum(target * target) + z_sum = torch.sum(score * score) + loss = (2 * intersect + smooth) / (z_sum + y_sum + smooth) + loss = 1 - loss + return loss + +def Binary_dice_loss(predictive, target, ep=1e-8): + intersection = 2 * torch.sum(predictive * target) + ep + union = torch.sum(predictive) + torch.sum(target) + ep + loss = 1 - intersection / union + return loss + +def kl_loss(inputs, targets, ep=1e-8): + kl_loss=nn.KLDivLoss(reduction='mean') + consist_loss = kl_loss(torch.log(inputs+ep), targets) + return consist_loss + +def soft_ce_loss(inputs, target, ep=1e-8): + logprobs = torch.log(inputs+ep) + return torch.mean(-(target[:,0,...]*logprobs[:,0,...]+target[:,1,...]*logprobs[:,1,...])) + +def softmax_kl_loss(input_logits, target_logits, sigmoid=False): + """Takes softmax on both sides and returns KL divergence + + Note: + - Returns the sum over all examples. Divide by the batch size afterwards + if you want the mean. + - Sends gradients to inputs but not the targets. + """ + assert input_logits.size() == target_logits.size() + if sigmoid: + input_log_softmax = torch.log(torch.sigmoid(input_logits)) + target_softmax = torch.sigmoid(target_logits) + else: + input_log_softmax = F.log_softmax(input_logits, dim=1) + target_softmax = F.softmax(target_logits, dim=1) + + # return F.kl_div(input_log_softmax, target_softmax) + kl_div = F.kl_div(input_log_softmax, target_softmax, reduction='mean') + # mean_kl_div = torch.mean(0.2*kl_div[:,0,...]+0.8*kl_div[:,1,...]) + return kl_div + +def softmax_mse_loss(input_logits, target_logits): + """Takes softmax on both sides and returns MSE loss + + Note: + - Returns the sum over all examples. Divide by the batch size afterwards + if you want the mean. + - Sends gradients to inputs but not the targets. + """ + assert input_logits.size() == target_logits.size() + input_softmax = F.softmax(input_logits, dim=1) + target_softmax = F.softmax(target_logits, dim=1) + + mse_loss = F.mse_loss(input_softmax,target_softmax) + return mse_loss + +def mse_loss(input1, input2): + return torch.mean((input1 - input2)**2) + +class DiceLoss(nn.Module): + def __init__(self, n_classes): + super(DiceLoss, self).__init__() + self.n_classes = n_classes + + def _one_hot_encoder(self, input_tensor): + tensor_list = [] + for i in range(self.n_classes): + temp_prob = input_tensor == i * torch.ones_like(input_tensor) + tensor_list.append(temp_prob) + output_tensor = torch.cat(tensor_list, dim=1) + return output_tensor.float() + + def _dice_loss(self, score, target): + target = target.float() + smooth = 1e-10 + intersection = torch.sum(score * target) + union = torch.sum(score * score) + torch.sum(target * target) + smooth + loss = 1 - intersection / union + return loss + + def forward(self, inputs, target, weight=None, softmax=False): + if softmax: + inputs = torch.softmax(inputs, dim=1) + target = self._one_hot_encoder(target) + if weight is None: + weight = [1] * self.n_classes + assert inputs.size() == target.size(), 'predict & target shape do not match' + class_wise_dice = [] + loss = 0.0 + for i in range(0, self.n_classes): + dice = self._dice_loss(inputs[:, i], target[:, i]) + class_wise_dice.append(1.0 - dice.item()) + loss += dice * weight[i] + return loss / self.n_classes