Diff of /rvseg/loss.py [000000] .. [6673ef]

Switch to unified view

a b/rvseg/loss.py
1
#!/usr/bin/env python
2
3
from __future__ import division, print_function
4
5
from keras import backend as K
6
7
8
def soft_sorensen_dice(y_true, y_pred, axis=None, smooth=1):
9
    intersection = K.sum(y_true * y_pred, axis=axis)
10
    area_true = K.sum(y_true, axis=axis)
11
    area_pred = K.sum(y_pred, axis=axis)
12
    return (2 * intersection + smooth) / (area_true + area_pred + smooth)
13
    
14
def hard_sorensen_dice(y_true, y_pred, axis=None, smooth=1):
15
    y_true_int = K.round(y_true)
16
    y_pred_int = K.round(y_pred)
17
    return soft_sorensen_dice(y_true_int, y_pred_int, axis, smooth)
18
19
sorensen_dice = hard_sorensen_dice
20
21
def sorensen_dice_loss(y_true, y_pred, weights):
22
    # Input tensors have shape (batch_size, height, width, classes)
23
    # User must input list of weights with length equal to number of classes
24
    #
25
    # Ex: for simple binary classification, with the 0th mask
26
    # corresponding to the background and the 1st mask corresponding
27
    # to the object of interest, we set weights = [0, 1]
28
    batch_dice_coefs = soft_sorensen_dice(y_true, y_pred, axis=[1, 2])
29
    dice_coefs = K.mean(batch_dice_coefs, axis=0)
30
    w = K.constant(weights) / sum(weights)
31
    return 1 - K.sum(w * dice_coefs)
32
33
def soft_jaccard(y_true, y_pred, axis=None, smooth=1):
34
    intersection = K.sum(y_true * y_pred, axis=axis)
35
    area_true = K.sum(y_true, axis=axis)
36
    area_pred = K.sum(y_pred, axis=axis)
37
    union = area_true + area_pred - intersection
38
    return (intersection + smooth) / (union + smooth)
39
40
def hard_jaccard(y_true, y_pred, axis=None, smooth=1):
41
    y_true_int = K.round(y_true)
42
    y_pred_int = K.round(y_pred)
43
    return soft_jaccard(y_true_int, y_pred_int, axis, smooth)
44
45
jaccard = hard_jaccard
46
47
def jaccard_loss(y_true, y_pred, weights):
48
    batch_jaccard_coefs = soft_jaccard(y_true, y_pred, axis=[1, 2])
49
    jaccard_coefs = K.mean(batch_jaccard_coefs, axis=0)
50
    w = K.constant(weights) / sum(weights)
51
    return 1 - K.sum(w * jaccard_coefs)
52
53
def weighted_categorical_crossentropy(y_true, y_pred, weights, epsilon=1e-8):
54
    ndim = K.ndim(y_pred)
55
    ncategory = K.int_shape(y_pred)[-1]
56
    # scale predictions so class probabilities of each pixel sum to 1
57
    y_pred /= K.sum(y_pred, axis=(ndim-1), keepdims=True)
58
    y_pred = K.clip(y_pred, epsilon, 1-epsilon)
59
    w = K.constant(weights) * (ncategory / sum(weights))
60
    # first, average over all axis except classes
61
    cross_entropies = -K.mean(y_true * K.log(y_pred), axis=tuple(range(ndim-1)))
62
    return K.sum(w * cross_entropies)