Switch to unified view

a b/src/LFBNet/losses/losses.py
1
"""" Script to compute different loss functions in Keras based on tensorflow.
2
3
This script compute dice loss, binary cross entropy loss, focal loss, and their combinations.
4
It also computes hard and soft dice metric as well as loss.
5
6
"""
7
8
# Import libraries
9
from numpy.random import seed
10
from typing import List, Tuple
11
from keras import backend as K
12
import tensorflow as tf
13
from numpy import ndarray
14
15
# seed random number generator
16
seed(1)
17
18
19
class LossMetric:
20
    """ compute loss and metrics
21
22
    Attributes:
23
        y_true: the reference value,
24
        y_predicted: the predicted value to compare with y_true.
25
26
    Returns:
27
        Returns the loss or metric.
28
29
    """
30
    def __init__(self, y_true: List[float] = None, y_predicted: List[float] = None):
31
        self.y_true = y_true
32
        self.y_predicted = y_predicted
33
34
    @staticmethod
35
    def dice_metric(y_true: ndarray = None, y_predicted: ndarray = None, soft_dice: bool = False,
36
            threshold_value: float = 0.5, smooth=1) -> float:
37
        """compute the dice coefficient between the reference and target
38
        Threshold dice similarity coefficient
39
40
        Args:
41
            y_true: reference target.
42
            y_predicted:  predicted target by the model.
43
            soft_dice: apply soft dice or not.
44
            threshold_value:  thresholding value for soft-dice application.
45
            smooth: avoid division by zero values.
46
47
        Returns:
48
            Returns dice similarity coefficient, with threshold predicted values
49
50
        """
51
        y_true = K.flatten(y_true)
52
        y_predicted = K.flatten(y_predicted)
53
        # prevent from log(0)
54
        y_true = K.clip(y_true, 10e-8, 1. - 10e-8)
55
        y_predicted = K.clip(y_predicted, 10e-8, 1. - 10e-8)
56
57
        # soft dice
58
        if soft_dice:
59
            y_predicted = K.cast(K.greater(y_predicted, threshold_value), dtype='float32')
60
61
        intersection = K.sum(y_true * y_predicted)
62
63
        # smooth: avoid by zero division
64
        dice = (2. * intersection + smooth) / (K.sum(y_true) + K.sum(y_predicted) + smooth)
65
66
        return dice
67
68
    def dice_loss(self, y_true: ndarray, y_predicted: ndarray) -> float:
69
        """ Compute the dice loss
70
71
        Args:
72
            y_true: reference target.
73
            y_predicted:  predicted target by the model.
74
75
        Returns:
76
            Returns dice loss.
77
78
        """
79
        return 1 - self.dice_metric(y_true, y_predicted)
80
81
    @staticmethod
82
    def binary_cross_entropy_loss(y_true: ndarray = None, y_predicted: ndarray = None) -> float:
83
        """ compute the binary cross entropy loss
84
85
        Args:
86
            y_true: reference target.
87
            y_predicted:  predicted target by the model.
88
89
        Returns:
90
            Returns binary cross entropy between the target and predicted value.
91
92
        """
93
        # prevent from log(0)
94
        y_true = K.clip(y_true, 10e-8, 1. - 10e-8)
95
        y_predicted = K.clip(y_predicted, 10e-8, 1. - 10e-8)
96
97
        return K.binary_crossentropy(y_true, y_predicted)
98
99
    @staticmethod
100
    def binary_focal_loss(y_true: ndarray = None, y_predicted: ndarray = None, gamma: int = 2,
101
            alpha: float = .25) -> float:
102
        """ computes the focal loss
103
104
        Args:
105
            y_true: reference target.
106
            y_predicted: predicted target by the model.
107
            gamma: constant value
108
            alpha: constant value
109
110
        Returns:
111
            Returns focal loss.
112
113
        """
114
        y_true = K.cast(y_true, dtype='float32')
115
        # Define epsilon so that the back-propagation will not result in NaN for 0 divisor case
116
        epsilon = K.epsilon()
117
        # Add the epsilon to prediction value
118
        y_predicted = y_predicted + epsilon
119
        # Clip the prediction value
120
        y_predicted = K.clip(y_predicted, epsilon, 1.0 - epsilon)
121
122
        # Calculate p_t
123
        p_t = tf.where(K.equal(y_true, 1), y_predicted, 1 - y_predicted)
124
        # Calculate alpha_t
125
        alpha_factor = K.ones_like(y_true) * alpha
126
        alpha_t = tf.where(K.equal(y_true, 1), alpha_factor, 1 - alpha_factor)
127
        # Calculate cross entropy
128
        cross_entropy = -K.log(p_t)
129
        weight = alpha_t * K.pow((1 - p_t), gamma)
130
        # Calculate focal loss f
131
        loss = weight * cross_entropy
132
        # Sum the losses in mini_batch
133
        loss = K.mean(K.sum(loss, axis=1))
134
        return loss
135
136
    @staticmethod
137
    def focal_loss(y_true: ndarray = None, y_predicted: ndarray = None, gamma: int = 2, alpha: float = .25) -> float:
138
        """ computes the focal loss
139
140
        Adapted from: https://www.kaggle.com/code/bigironsphere/loss-function-library-keras-pytorch/notebook
141
        """
142
143
        y_true = K.flatten(y_true)
144
        y_predicted = K.flatten(y_predicted)
145
146
        bce = K.binary_crossentropy(y_true, y_predicted)
147
        bce_exp = K.exp(-bce)
148
        focal_loss = K.mean(alpha * K.pow((1 - bce_exp), gamma) * bce)
149
150
        return focal_loss
151
152
    def dice_plus_binary_cross_entropy_loss(self, y_true, y_predicted):
153
        """ compute the average of the sum of dice and binary cross entropy loss.
154
155
        Args:
156
            y_true: reference target.
157
            y_predicted:  predicted target by the model.
158
159
        Returns:
160
            Returns the average of the sum of dice and binary cross entropy losses.
161
162
        """
163
        loss = 0.5 * (self.dice_loss(y_true, y_predicted) + self.binary_cross_entropy_loss(y_true=y_true,
164
            y_predicted=y_predicted))
165
166
        return loss
167
168
    def dice_plus_focal_loss(self, y_true: ndarray, y_predicted: ndarray) -> float:
169
        """ compute the sum of the dice and focal loss
170
171
        Args:
172
            y_true: reference target.
173
            y_predicted:  predicted target by the model.
174
175
        Returns:
176
            Returns the sum of the dice and focal loss.
177
178
        """
179
        return self.dice_loss(y_true, y_predicted) + self.binary_focal_loss(y_true, y_predicted)
180
181
    @staticmethod
182
    def iou_loss(y_true, y_predicted, smooth=1e-8):
183
        """ compute the intersection over union loss.
184
185
        Args:
186
            y_true: reference target.
187
            y_predicted:  predicted target by the model.
188
            smooth: avoid division by zero.
189
190
        Returns:
191
            Returns intersection over union loss.
192
193
        """
194
        y_true = K.flatten(y_true)
195
        y_predicted = K.flatten(y_predicted)
196
197
        intersection = K.sum(K.dot(y_true, y_predicted))
198
        total = K.sum(y_true) + K.sum(y_predicted)
199
200
        union = total - intersection
201
202
        iou = (intersection + smooth) / (union + smooth)
203
204
        return iou