--- a +++ b/rvseg/loss.py @@ -0,0 +1,62 @@ +#!/usr/bin/env python + +from __future__ import division, print_function + +from keras import backend as K + + +def soft_sorensen_dice(y_true, y_pred, axis=None, smooth=1): + intersection = K.sum(y_true * y_pred, axis=axis) + area_true = K.sum(y_true, axis=axis) + area_pred = K.sum(y_pred, axis=axis) + return (2 * intersection + smooth) / (area_true + area_pred + smooth) + +def hard_sorensen_dice(y_true, y_pred, axis=None, smooth=1): + y_true_int = K.round(y_true) + y_pred_int = K.round(y_pred) + return soft_sorensen_dice(y_true_int, y_pred_int, axis, smooth) + +sorensen_dice = hard_sorensen_dice + +def sorensen_dice_loss(y_true, y_pred, weights): + # Input tensors have shape (batch_size, height, width, classes) + # User must input list of weights with length equal to number of classes + # + # Ex: for simple binary classification, with the 0th mask + # corresponding to the background and the 1st mask corresponding + # to the object of interest, we set weights = [0, 1] + batch_dice_coefs = soft_sorensen_dice(y_true, y_pred, axis=[1, 2]) + dice_coefs = K.mean(batch_dice_coefs, axis=0) + w = K.constant(weights) / sum(weights) + return 1 - K.sum(w * dice_coefs) + +def soft_jaccard(y_true, y_pred, axis=None, smooth=1): + intersection = K.sum(y_true * y_pred, axis=axis) + area_true = K.sum(y_true, axis=axis) + area_pred = K.sum(y_pred, axis=axis) + union = area_true + area_pred - intersection + return (intersection + smooth) / (union + smooth) + +def hard_jaccard(y_true, y_pred, axis=None, smooth=1): + y_true_int = K.round(y_true) + y_pred_int = K.round(y_pred) + return soft_jaccard(y_true_int, y_pred_int, axis, smooth) + +jaccard = hard_jaccard + +def jaccard_loss(y_true, y_pred, weights): + batch_jaccard_coefs = soft_jaccard(y_true, y_pred, axis=[1, 2]) + jaccard_coefs = K.mean(batch_jaccard_coefs, axis=0) + w = K.constant(weights) / sum(weights) + return 1 - K.sum(w * jaccard_coefs) + +def weighted_categorical_crossentropy(y_true, y_pred, weights, epsilon=1e-8): + ndim = K.ndim(y_pred) + ncategory = K.int_shape(y_pred)[-1] + # scale predictions so class probabilities of each pixel sum to 1 + y_pred /= K.sum(y_pred, axis=(ndim-1), keepdims=True) + y_pred = K.clip(y_pred, epsilon, 1-epsilon) + w = K.constant(weights) * (ncategory / sum(weights)) + # first, average over all axis except classes + cross_entropies = -K.mean(y_true * K.log(y_pred), axis=tuple(range(ndim-1))) + return K.sum(w * cross_entropies)