--- a
+++ b/adpkd_segmentation/utils/losses.py
@@ -0,0 +1,380 @@
+"""Loss utilities and definitions"""
+
+# %%
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from adpkd_segmentation.data.data_utils import (
+    KIDNEY_PIXELS,
+    STUDY_TKV,
+    VOXEL_VOLUME,
+)
+
+
+# %%
+def binarize_thresholds(pred, thresholds=[0.5]):
+    """Binarize predictions with per-channel thresholds.
+
+    Args:
+        pred: model prediction tensor of shape (N, C, H, W)
+        thresholds: list of floats, one per channel, e.g. [0.6, 0.5, 0.4]
+
+    Returns:
+        float tensor: binary values
+    """
+    C = len(thresholds)
+    thresholds = torch.tensor(thresholds)
+    thresholds = thresholds.reshape(1, C, 1, 1)
+    thresholds = thresholds.expand_as(pred).to(pred.device)
+    res = pred > thresholds
+
+    return res.float()
+
+
+# %%
+def binarize_argmax(pred):
+    """Binarize predictions by one-hot encoding the channel argmax.
+
+    Args:
+        pred: model prediction tensor of shape (N, C, H, W)
+
+    Returns:
+        float tensor: binary values
+    """
+    max_c = torch.argmax(pred, 1)  # argmax across the C axis
+    num_classes = pred.shape[1]
+    encoded = torch.nn.functional.one_hot(max_c, num_classes)
+    encoded = encoded.permute([0, 3, 1, 2])
+
+    return encoded.float()
+
+
+class SigmoidBinarize:
+    def __init__(self, thresholds):
+        self.thresholds = thresholds
+
+    def __call__(self, pred):
+        # Expects (N, C, H, W) format
+        return binarize_thresholds(torch.sigmoid(pred), self.thresholds)
+
+
+class SigmoidForwardBinarize:
+    def __init__(self, thresholds):
+        self.thresholds = thresholds
+
+    def __call__(self, pred):
+        # Expects (N, C, H, W) format
+        # straight-through estimator: hard values in the forward pass,
+        # sigmoid gradients in the backward pass
+        soft = torch.sigmoid(pred)
+        hard = binarize_thresholds(soft, self.thresholds)
+        return hard.detach() + soft - soft.detach()
+
+
+class SoftmaxBinarize:
+    def __call__(self, pred):
+        # Expects (N, C, H, W) format
+        return binarize_argmax(pred)
+
+
+class SoftmaxForwardBinarize:
+    def __call__(self, pred):
+        # Expects (N, C, H, W) format
+        # straight-through estimator: hard values in the forward pass,
+        # softmax gradients in the backward pass
+        soft = F.softmax(pred, dim=1)
+        hard = binarize_argmax(soft)
+        return hard.detach() + soft - soft.detach()
+
+
+class StandardizeModels:
+    def __init__(self, ignore_channel=2):
+        # channel treated as background in 3-channel setups
+        self.ignore_channel = ignore_channel
+
+    def __call__(self, binary_mask):
+        # (N, C, H, W) mask
+        num_channels = binary_mask.shape[1]
+        if num_channels == 1:
+            return binary_mask
+        elif num_channels == 2:
+            return torch.sum(binary_mask, dim=1, keepdim=True)
+        elif num_channels == 3:
+            sum_all = torch.sum(binary_mask, dim=1)
+            sum_all = sum_all - binary_mask[:, self.ignore_channel, ...]
+            sum_all = sum_all.unsqueeze(1)
+            return sum_all
+        else:
+            raise ValueError(
+                "Unsupported number of channels: {}".format(num_channels)
+            )
+
+
+class Dice(nn.Module):
+    """Dice metric/loss.
+
+    Supports different Dice variants.
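+
+    For prediction ``p`` and target ``t``, summed over ``dim``:
+
+        dice = (2 * sum(p * t) + epsilon)
+               / (sum(p ** power) + sum(t ** power) + epsilon)
+
+    With ``use_as_loss=True`` the returned value is ``1 - dice``.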
+ """ + + def __init__( + self, + pred_process=None, + epsilon=1e-8, + power=2, + dim=(2, 3), + standardize_func=None, + use_as_loss=True, + ): + super().__init__() + self.pred_process = pred_process + self.epsilon = epsilon + self.power = power + self.dim = dim + self.standardize_func = standardize_func + self.use_as_loss = use_as_loss + + def __call__(self, pred, target): + + if self.pred_process is not None: + pred = self.pred_process(pred) + + if self.standardize_func is not None: + pred = self.standardize_func(pred) + target = self.standardize_func(target) + + intersection = torch.sum(pred * target, dim=self.dim) + set_add = torch.sum( + pred ** self.power + target ** self.power, dim=self.dim + ) + score = (2 * intersection + self.epsilon) / (set_add + self.epsilon) + score = score.mean() + if self.use_as_loss: + return 1 - score + return score + + +class PredictionEntropy(nn.Module): + """ + Calculates average entropy of the predicted soft mask. + + Doesn't depend on ground truth mask. + """ + + def __init__(self, pred_process, epsilon=1e-8, standardize_func=None): + super().__init__() + self.pred_process = pred_process + self.epsilon = epsilon + self.standardize_func = standardize_func + + def __call__(self, pred, target): + pred = self.pred_process(pred) + if self.standardize_func is not None: + pred = self.standardize_func(pred) + target = self.standardize_func(target) + + entropy = -pred * torch.log(pred + self.epsilon) + return entropy.mean() + + +class KidneyPixelMAPE(nn.Module): + """ + Calculates the absolute percentage error for predicted kidney pixel counts + + (label kidney pixel count - predicted k.p. count) / (label k.p. count) + + By default, kidney pixel summation is done for each image separately, and + averaged over the entire batch. + + Depending on the `pred_process` function, + predicted kidney pixel count can be soft or hard. + """ + + def __init__( + self, pred_process, epsilon=1.0, dim=(2, 3), standardize_func=None + ): + super().__init__() + self.pred_process = pred_process + self.epsilon = epsilon + self.dim = dim + self.standardize_func = standardize_func + + def __call__(self, pred, target): + + pred = self.pred_process(pred) + if self.standardize_func is not None: + pred = self.standardize_func(pred) + target = self.standardize_func(target) + + target_count = target.sum(dim=self.dim).detach() + pred_count = pred.sum(dim=self.dim) + + kp_batch_MAPE = torch.abs( + (target_count - pred_count) / (target_count + self.epsilon) + ).mean() + + return kp_batch_MAPE + + +class KidneyPixelMSLE(nn.Module): + """ + Mean square error for the log of kidney pixel counts. + + MSE of ln(label kidney pixel count) - ln(predicted k.p. count) + + By default, pixels are counted separetely for each image, with final + averaging across all images + + Depending on the `pred_process` function, + predicted kidney pixel count can be soft or hard. 
+ """ + + def __init__( + self, pred_process, epsilon=1.0, dim=(2, 3), standardize_func=None + ): + super().__init__() + self.pred_process = pred_process + self.epsilon = epsilon + self.dim = dim + self.standardize_func = standardize_func + + def __call__(self, pred, target): + pred = self.pred_process(pred) + if self.standardize_func is not None: + pred = self.standardize_func(pred) + target = self.standardize_func(target) + + target_count = target.sum(dim=self.dim).detach() + pred_count = pred.sum(dim=self.dim) + + sle = ( + torch.log(target_count + self.epsilon) + - torch.log(pred_count + self.epsilon) + ) ** 2 + msle = torch.mean(sle) + return msle + + +class WeightedLosses(nn.Module): + def __init__(self, criterions, weights, requires_extra_dict=None): + super().__init__() + self.criterions = criterions + self.weights = weights + self.requires_extra_dict = requires_extra_dict + if requires_extra_dict is None: + self.requires_extra_dict = [False for c in self.criterions] + + def __call__(self, pred, target, extra_dict=None): + losses = [] + for c, w, e in zip( + self.criterions, self.weights, self.requires_extra_dict + ): + loss = c(pred, target, extra_dict) if e else c(pred, target) + losses.append(loss * w) + return torch.sum(torch.stack(losses)) + + +class DynamicBalanceLosses(nn.Module): + def __init__( + self, criterions, epsilon=1e-6, weights=None, requires_extra_dict=None + ): + self.criterions = criterions + self.epsilon = epsilon + self.requires_extra_dict = requires_extra_dict + self.weights = weights + if weights is None: + self.weights = [1.0] * len(self.criterions) + self.weights = torch.tensor(self.weights) + if requires_extra_dict is None: + self.requires_extra_dict = [False for c in self.criterions] + + def __call__(self, pred, target, extra_dict=None): + # first, scale losses such that + # L_1 * s_1 = L_2 * s_2 = ... L_n * s_n = + # L_1 * L_2 * ... * L_n + # e.g. s_2 = L_1 * L_3 * ... * L_n + # calculate scaling factors dynamically + partial_losses = [] + for c, e in zip(self.criterions, self.requires_extra_dict): + loss = c(pred, target, extra_dict) if e else c(pred, target) + partial_losses.append(loss) + partial_losses = torch.stack(partial_losses) + self.epsilon + # no backprop through dynamic scaling factors + detached = partial_losses.detach() + prod = torch.prod(detached) + # divide the total product by the vector of loss values + # to get scaling factors such as e.g. s_2 = L_1 * L_3 * ... 
+        scales = prod / detached
+        # final weighting by external weights
+        self.weights = self.weights.to(scales.device)
+        scales = scales * self.weights
+        normalization = torch.sum(scales)
+
+        loss = (partial_losses * scales).sum() / normalization
+        return loss
+
+
+class ErrorLogTKVRelative(nn.Module):
+    def __init__(self, pred_process, epsilon=1.0, standardize_func=None):
+        super().__init__()
+        self.pred_process = pred_process
+        self.epsilon = epsilon
+        self.standardize_func = standardize_func
+
+    def __call__(self, pred, target, extra_dict):
+        pred = self.pred_process(pred)
+        if self.standardize_func is not None:
+            pred = self.standardize_func(pred)
+            target = self.standardize_func(target)
+
+        intersection = torch.sum(pred * target, dim=(1, 2, 3))
+        # squared error between pred and target masks:
+        # sum(pred**2) + sum(target**2) - 2 * sum(pred * target)
+        error = (
+            torch.sum(pred ** 2, dim=(1, 2, 3))
+            + torch.sum(target ** 2, dim=(1, 2, 3))
+            - 2 * intersection
+        )
+
+        # augmentation correction for the original kidney pixel count;
+        # also, convert to voxel volume
+        scale = (extra_dict[KIDNEY_PIXELS] + self.epsilon) / (
+            torch.sum(target, dim=(1, 2, 3)) + self.epsilon
+        )
+        scaled_vol_error = scale * error * extra_dict[VOXEL_VOLUME]
+        # error is more important if the kidneys are smaller,
+        # but for the same kidney volume, error on any slice
+        # matters equally
+        # use log due to different orders of magnitude
+        weight = 1 / (torch.log(extra_dict[STUDY_TKV]) + self.epsilon)
+        log_error = (scaled_vol_error * weight).mean()
+
+        return log_error
+
+
+class BiasReductionLoss(nn.Module):
+    def __init__(
+        self, pred_process, standardize_func=None, w1=0.5, w2=0.5, epsilon=1e-8
+    ):
+        super().__init__()
+        self.pred_process = pred_process
+        self.standardize_func = standardize_func
+        self.w1 = w1
+        self.w2 = w2
+        self.epsilon = epsilon
+
+    def __call__(self, pred, target):
+        pred = self.pred_process(pred)
+        if self.standardize_func is not None:
+            pred = self.standardize_func(pred)
+            target = self.standardize_func(target)
+
+        intersection = torch.sum(pred * target, dim=(1, 2, 3))
+        # count what's missing from the target area (false negatives)
+        missing = target.sum(dim=(1, 2, 3)) - intersection
+        # count all extra predictions outside the target area
+        # (false positives)
+        false_pos = torch.sum(pred * (1 - target), dim=(1, 2, 3))
+
+        # both terms should go to zero, but they should also stay balanced
+        loss = (
+            self.w1 * (missing ** 2 + false_pos ** 2)
+            + self.w2 * (missing - false_pos) ** 2
+        )
+        # sqrt is not differentiable at zero
+        loss = (loss.mean() + self.epsilon) ** 0.5
+
+        return loss
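+
+
+# %%
+# Usage sketch (illustrative only; the shapes, thresholds, and weights
+# below are arbitrary, not values used in training): combine a soft Dice
+# loss with a straight-through binarized KidneyPixelMAPE term.
+if __name__ == "__main__":
+    binarize = SigmoidForwardBinarize(thresholds=[0.5])
+    dice_loss = Dice(pred_process=torch.sigmoid, use_as_loss=True)
+    kp_mape = KidneyPixelMAPE(pred_process=binarize)
+    criterion = WeightedLosses([dice_loss, kp_mape], weights=[1.0, 0.1])
+
+    logits = torch.randn(4, 1, 64, 64, requires_grad=True)
+    mask = (torch.rand(4, 1, 64, 64) > 0.5).float()
+    loss = criterion(logits, mask)
+    loss.backward()
+    print("combined loss:", loss.item())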