|
a |
|
b/src/LFBNet/losses/losses.py |
|
|
1 |
"""" Script to compute different loss functions in Keras based on tensorflow. |
|
|
2 |
|
|
|
3 |
This script compute dice loss, binary cross entropy loss, focal loss, and their combinations. |
|
|
4 |
It also computes hard and soft dice metric as well as loss. |
|
|
5 |
|
|
|
6 |
""" |
|
|
7 |
|
|
|
8 |
# Import libraries |
|
|
9 |
from numpy.random import seed |
|
|
10 |
from typing import List, Tuple |
|
|
11 |
from keras import backend as K |
|
|
12 |
import tensorflow as tf |
|
|
13 |
from numpy import ndarray |
|
|
14 |
|
|
|
15 |
# seed random number generator |
|
|
16 |
seed(1) |
|
|
17 |
|
|
|
18 |
|
|
|
19 |
class LossMetric: |
|
|
20 |
""" compute loss and metrics |
|
|
21 |
|
|
|
22 |
Attributes: |
|
|
23 |
y_true: the reference value, |
|
|
24 |
y_predicted: the predicted value to compare with y_true. |
|
|
25 |
|
|
|
26 |
Returns: |
|
|
27 |
Returns the loss or metric. |
|
|
28 |
|
|
|
29 |
""" |
|
|
30 |
def __init__(self, y_true: List[float] = None, y_predicted: List[float] = None): |
|
|
31 |
self.y_true = y_true |
|
|
32 |
self.y_predicted = y_predicted |
|
|
33 |
|
|
|
34 |
@staticmethod |
|
|
35 |
def dice_metric(y_true: ndarray = None, y_predicted: ndarray = None, soft_dice: bool = False, |
|
|
36 |
threshold_value: float = 0.5, smooth=1) -> float: |
|
|
37 |
"""compute the dice coefficient between the reference and target |
|
|
38 |
Threshold dice similarity coefficient |
|
|
39 |
|
|
|
40 |
Args: |
|
|
41 |
y_true: reference target. |
|
|
42 |
y_predicted: predicted target by the model. |
|
|
43 |
soft_dice: apply soft dice or not. |
|
|
44 |
threshold_value: thresholding value for soft-dice application. |
|
|
45 |
smooth: avoid division by zero values. |
|
|
46 |
|
|
|
47 |
Returns: |
|
|
48 |
Returns dice similarity coefficient, with threshold predicted values |
|
|
49 |
|
|
|
50 |
""" |
|
|
51 |
y_true = K.flatten(y_true) |
|
|
52 |
y_predicted = K.flatten(y_predicted) |
|
|
53 |
# prevent from log(0) |
|
|
54 |
y_true = K.clip(y_true, 10e-8, 1. - 10e-8) |
|
|
55 |
y_predicted = K.clip(y_predicted, 10e-8, 1. - 10e-8) |
|
|
56 |
|
|
|
57 |
# soft dice |
|
|
58 |
if soft_dice: |
|
|
59 |
y_predicted = K.cast(K.greater(y_predicted, threshold_value), dtype='float32') |
|
|
60 |
|
|
|
61 |
intersection = K.sum(y_true * y_predicted) |
|
|
62 |
|
|
|
63 |
# smooth: avoid by zero division |
|
|
64 |
dice = (2. * intersection + smooth) / (K.sum(y_true) + K.sum(y_predicted) + smooth) |
|
|
65 |
|
|
|
66 |
return dice |
|
|
67 |
|
|
|
68 |
def dice_loss(self, y_true: ndarray, y_predicted: ndarray) -> float: |
|
|
69 |
""" Compute the dice loss |
|
|
70 |
|
|
|
71 |
Args: |
|
|
72 |
y_true: reference target. |
|
|
73 |
y_predicted: predicted target by the model. |
|
|
74 |
|
|
|
75 |
Returns: |
|
|
76 |
Returns dice loss. |
|
|
77 |
|
|
|
78 |
""" |
|
|
79 |
return 1 - self.dice_metric(y_true, y_predicted) |
|
|
80 |
|
|
|
81 |
@staticmethod |
|
|
82 |
def binary_cross_entropy_loss(y_true: ndarray = None, y_predicted: ndarray = None) -> float: |
|
|
83 |
""" compute the binary cross entropy loss |
|
|
84 |
|
|
|
85 |
Args: |
|
|
86 |
y_true: reference target. |
|
|
87 |
y_predicted: predicted target by the model. |
|
|
88 |
|
|
|
89 |
Returns: |
|
|
90 |
Returns binary cross entropy between the target and predicted value. |
|
|
91 |
|
|
|
92 |
""" |
|
|
93 |
# prevent from log(0) |
|
|
94 |
y_true = K.clip(y_true, 10e-8, 1. - 10e-8) |
|
|
95 |
y_predicted = K.clip(y_predicted, 10e-8, 1. - 10e-8) |
|
|
96 |
|
|
|
97 |
return K.binary_crossentropy(y_true, y_predicted) |
|
|
98 |
|
|
|
99 |
@staticmethod |
|
|
100 |
def binary_focal_loss(y_true: ndarray = None, y_predicted: ndarray = None, gamma: int = 2, |
|
|
101 |
alpha: float = .25) -> float: |
|
|
102 |
""" computes the focal loss |
|
|
103 |
|
|
|
104 |
Args: |
|
|
105 |
y_true: reference target. |
|
|
106 |
y_predicted: predicted target by the model. |
|
|
107 |
gamma: constant value |
|
|
108 |
alpha: constant value |
|
|
109 |
|
|
|
110 |
Returns: |
|
|
111 |
Returns focal loss. |
|
|
112 |
|
|
|
113 |
""" |
|
|
114 |
y_true = K.cast(y_true, dtype='float32') |
|
|
115 |
# Define epsilon so that the back-propagation will not result in NaN for 0 divisor case |
|
|
116 |
epsilon = K.epsilon() |
|
|
117 |
# Add the epsilon to prediction value |
|
|
118 |
y_predicted = y_predicted + epsilon |
|
|
119 |
# Clip the prediction value |
|
|
120 |
y_predicted = K.clip(y_predicted, epsilon, 1.0 - epsilon) |
|
|
121 |
|
|
|
122 |
# Calculate p_t |
|
|
123 |
p_t = tf.where(K.equal(y_true, 1), y_predicted, 1 - y_predicted) |
|
|
124 |
# Calculate alpha_t |
|
|
125 |
alpha_factor = K.ones_like(y_true) * alpha |
|
|
126 |
alpha_t = tf.where(K.equal(y_true, 1), alpha_factor, 1 - alpha_factor) |
|
|
127 |
# Calculate cross entropy |
|
|
128 |
cross_entropy = -K.log(p_t) |
|
|
129 |
weight = alpha_t * K.pow((1 - p_t), gamma) |
|
|
130 |
# Calculate focal loss f |
|
|
131 |
loss = weight * cross_entropy |
|
|
132 |
# Sum the losses in mini_batch |
|
|
133 |
loss = K.mean(K.sum(loss, axis=1)) |
|
|
134 |
return loss |
|
|
135 |
|
|
|
136 |
@staticmethod |
|
|
137 |
def focal_loss(y_true: ndarray = None, y_predicted: ndarray = None, gamma: int = 2, alpha: float = .25) -> float: |
|
|
138 |
""" computes the focal loss |
|
|
139 |
|
|
|
140 |
Adapted from: https://www.kaggle.com/code/bigironsphere/loss-function-library-keras-pytorch/notebook |
|
|
141 |
""" |
|
|
142 |
|
|
|
143 |
y_true = K.flatten(y_true) |
|
|
144 |
y_predicted = K.flatten(y_predicted) |
|
|
145 |
|
|
|
146 |
bce = K.binary_crossentropy(y_true, y_predicted) |
|
|
147 |
bce_exp = K.exp(-bce) |
|
|
148 |
focal_loss = K.mean(alpha * K.pow((1 - bce_exp), gamma) * bce) |
|
|
149 |
|
|
|
150 |
return focal_loss |
|
|
151 |
|
|
|
152 |
def dice_plus_binary_cross_entropy_loss(self, y_true, y_predicted): |
|
|
153 |
""" compute the average of the sum of dice and binary cross entropy loss. |
|
|
154 |
|
|
|
155 |
Args: |
|
|
156 |
y_true: reference target. |
|
|
157 |
y_predicted: predicted target by the model. |
|
|
158 |
|
|
|
159 |
Returns: |
|
|
160 |
Returns the average of the sum of dice and binary cross entropy losses. |
|
|
161 |
|
|
|
162 |
""" |
|
|
163 |
loss = 0.5 * (self.dice_loss(y_true, y_predicted) + self.binary_cross_entropy_loss(y_true=y_true, |
|
|
164 |
y_predicted=y_predicted)) |
|
|
165 |
|
|
|
166 |
return loss |
|
|
167 |
|
|
|
168 |
def dice_plus_focal_loss(self, y_true: ndarray, y_predicted: ndarray) -> float: |
|
|
169 |
""" compute the sum of the dice and focal loss |
|
|
170 |
|
|
|
171 |
Args: |
|
|
172 |
y_true: reference target. |
|
|
173 |
y_predicted: predicted target by the model. |
|
|
174 |
|
|
|
175 |
Returns: |
|
|
176 |
Returns the sum of the dice and focal loss. |
|
|
177 |
|
|
|
178 |
""" |
|
|
179 |
return self.dice_loss(y_true, y_predicted) + self.binary_focal_loss(y_true, y_predicted) |
|
|
180 |
|
|
|
181 |
@staticmethod |
|
|
182 |
def iou_loss(y_true, y_predicted, smooth=1e-8): |
|
|
183 |
""" compute the intersection over union loss. |
|
|
184 |
|
|
|
185 |
Args: |
|
|
186 |
y_true: reference target. |
|
|
187 |
y_predicted: predicted target by the model. |
|
|
188 |
smooth: avoid division by zero. |
|
|
189 |
|
|
|
190 |
Returns: |
|
|
191 |
Returns intersection over union loss. |
|
|
192 |
|
|
|
193 |
""" |
|
|
194 |
y_true = K.flatten(y_true) |
|
|
195 |
y_predicted = K.flatten(y_predicted) |
|
|
196 |
|
|
|
197 |
intersection = K.sum(K.dot(y_true, y_predicted)) |
|
|
198 |
total = K.sum(y_true) + K.sum(y_predicted) |
|
|
199 |
|
|
|
200 |
union = total - intersection |
|
|
201 |
|
|
|
202 |
iou = (intersection + smooth) / (union + smooth) |
|
|
203 |
|
|
|
204 |
return iou |