utils/losses.py


import torch
import torch.nn as nn
import torch.nn.functional as F


def dice_loss(score, target):
    """Soft Dice loss for a single probability map against a binary target."""
    target = target.float()
    smooth = 1e-5
    intersect = torch.sum(score * target)
    y_sum = torch.sum(target * target)
    z_sum = torch.sum(score * score)
    dice = (2 * intersect + smooth) / (z_sum + y_sum + smooth)
    loss = 1 - dice
    return loss


def Binary_dice_loss(predictive, target, ep=1e-8):
    """Soft Dice loss for binary foreground probabilities."""
    intersection = 2 * torch.sum(predictive * target) + ep
    union = torch.sum(predictive) + torch.sum(target) + ep
    loss = 1 - intersection / union
    return loss


def kl_loss(inputs, targets, ep=1e-8):
    """KL divergence between probability maps.

    `inputs` and `targets` are expected to be probabilities (not logits);
    nn.KLDivLoss with reduction='mean' averages over all elements.
    """
    kl_criterion = nn.KLDivLoss(reduction='mean')
    consist_loss = kl_criterion(torch.log(inputs + ep), targets)
    return consist_loss


def soft_ce_loss(inputs, target, ep=1e-8):
    """Soft cross-entropy; assumes exactly two class channels along dim 1."""
    logprobs = torch.log(inputs + ep)
    return torch.mean(-(target[:, 0, ...] * logprobs[:, 0, ...]
                        + target[:, 1, ...] * logprobs[:, 1, ...]))


def softmax_kl_loss(input_logits, target_logits, sigmoid=False):
    """Takes softmax (or sigmoid) on both sides and returns the KL divergence.

    Note:
    - F.kl_div is called with reduction='mean', so the result is the mean over
      all elements, not a per-example sum.
    - Gradients flow into both arguments; detach target_logits at the call site
      if they come from a teacher/EMA model and should not receive gradients.
    """
    assert input_logits.size() == target_logits.size()
    if sigmoid:
        input_log_softmax = torch.log(torch.sigmoid(input_logits))
        target_softmax = torch.sigmoid(target_logits)
    else:
        input_log_softmax = F.log_softmax(input_logits, dim=1)
        target_softmax = F.softmax(target_logits, dim=1)
    kl_div = F.kl_div(input_log_softmax, target_softmax, reduction='mean')
    # mean_kl_div = torch.mean(0.2 * kl_div[:, 0, ...] + 0.8 * kl_div[:, 1, ...])
    return kl_div
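
# Usage sketch for softmax_kl_loss in a mean-teacher style setup; `student_model`
# and `ema_model` are placeholder names, and the teacher side is computed without
# gradients as suggested in the docstring above:
#
#     student_logits = student_model(images)
#     with torch.no_grad():
#         teacher_logits = ema_model(images)
#     consistency = softmax_kl_loss(student_logits, teacher_logits)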


def softmax_mse_loss(input_logits, target_logits):
    """Takes softmax on both sides and returns the MSE between the probability maps.

    Note:
    - F.mse_loss defaults to reduction='mean', so the result is the mean over
      all elements, not a per-example sum.
    - Gradients flow into both arguments; detach target_logits if they come
      from a teacher/EMA model.
    """
    assert input_logits.size() == target_logits.size()
    input_softmax = F.softmax(input_logits, dim=1)
    target_softmax = F.softmax(target_logits, dim=1)
    return F.mse_loss(input_softmax, target_softmax)
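
# A similar sketch for softmax_mse_loss, e.g. as a consistency term between two
# stochastic forward passes; `model`, `weak_aug`, and `strong_aug` are placeholder
# names for illustration only:
#
#     logits_a = model(weak_aug(images))
#     logits_b = model(strong_aug(images))
#     consistency = softmax_mse_loss(logits_a, logits_b.detach())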


def mse_loss(input1, input2):
    return torch.mean((input1 - input2) ** 2)


class DiceLoss(nn.Module):
    """Multi-class soft Dice loss averaged over classes."""

    def __init__(self, n_classes):
        super(DiceLoss, self).__init__()
        self.n_classes = n_classes

    def _one_hot_encoder(self, input_tensor):
        # Expects an integer label map with a singleton channel dim, e.g. [B, 1, ...];
        # the per-class masks are concatenated along dim 1.
        tensor_list = []
        for i in range(self.n_classes):
            temp_prob = input_tensor == i * torch.ones_like(input_tensor)
            tensor_list.append(temp_prob)
        output_tensor = torch.cat(tensor_list, dim=1)
        return output_tensor.float()

    def _dice_loss(self, score, target):
        # Note: unlike dice_loss above, the numerator here is not multiplied by 2.
        target = target.float()
        smooth = 1e-10
        intersection = torch.sum(score * target)
        union = torch.sum(score * score) + torch.sum(target * target) + smooth
        loss = 1 - intersection / union
        return loss

    def forward(self, inputs, target, weight=None, softmax=False):
        if softmax:
            inputs = torch.softmax(inputs, dim=1)
        target = self._one_hot_encoder(target)
        if weight is None:
            weight = [1] * self.n_classes
        assert inputs.size() == target.size(), 'predict & target shape do not match'
        class_wise_dice = []
        loss = 0.0
        for i in range(0, self.n_classes):
            dice = self._dice_loss(inputs[:, i], target[:, i])
            class_wise_dice.append(1.0 - dice.item())
            loss += dice * weight[i]
        return loss / self.n_classes
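

if __name__ == "__main__":
    # Minimal smoke-test sketch; the shapes below are illustrative assumptions,
    # not part of the library API: a 4-class DiceLoss on random logits and an
    # integer label map with a singleton channel dim, as _one_hot_encoder expects.
    dice_criterion = DiceLoss(n_classes=4)
    logits = torch.randn(2, 4, 16, 16)
    labels = torch.randint(0, 4, (2, 1, 16, 16))
    print(dice_criterion(logits, labels, softmax=True).item())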