|
a |
|
b/util.py |
|
|
1 |
"""Contains custom loss, dice coefficient, and optimizer classes.""" |
|
|
2 |
import tensorflow as tf |
|
|
3 |
|
|
|
4 |
|
|
|
5 |
class DiceVAELoss(object):
    """Combined soft-dice + VAE loss.

    Total loss = dice_loss + 0.1 * l2_loss + 0.1 * kld_loss, where the L2
    term is the VAE reconstruction error and the KLD term regularizes the
    latent distribution toward a standard normal.

    Args:
        name: Loss name. Stored for consistency with `DiceCoefficient`
            (previously accepted but silently discarded).
        data_format: 'channels_last' or 'channels_first'; selects the axes
            reduced by the dice terms. The axis tuples imply 5-D inputs,
            e.g. (batch, D, H, W, C) -- TODO confirm against callers.
        **kwargs: Ignored; accepted for call-site compatibility.
    """
    def __init__(self,
                 name='custom_loss',
                 data_format='channels_last',
                 **kwargs):
        # Store the name so this loss exposes the same attributes as
        # DiceCoefficient.
        self.name = name
        # Reduce over batch + spatial axes, keeping the channel axis.
        self.axis = (0, 1, 2, 3) if data_format == 'channels_last' else (0, 2, 3, 4)

    def __call__(self, x, y, y_pred, y_vae, z_mean, z_logvar, sample_weight=None):
        """Compute the combined loss.

        Args:
            x: Input to the VAE branch.
            y: Ground-truth segmentation.
            y_pred: Predicted segmentation probabilities.
            y_vae: VAE reconstruction of `x`.
            z_mean: Latent mean.
            z_logvar: Latent log-variance.
            sample_weight: Unused; present for Keras-style signatures.

        Returns:
            Scalar tensor: dice_loss + 0.1 * l2_loss + 0.1 * kld_loss.
        """
        # Mean-squared reconstruction error of the VAE branch.
        l2_loss = tf.reduce_mean((x - y_vae) ** 2)
        # Closed-form KL divergence between N(z_mean, exp(z_logvar)) and N(0, I).
        kld_loss = tf.reduce_mean(z_mean ** 2 + tf.math.exp(z_logvar) - z_logvar - 1.0)

        # Soft dice loss per channel, with +1 smoothing to avoid division by
        # zero on empty channels; averaged over channels.
        intersection = tf.reduce_sum(y_pred * y, axis=self.axis)
        pred = tf.reduce_sum(y_pred ** 2, axis=self.axis)
        true = tf.reduce_sum(y ** 2, axis=self.axis)

        dice_loss = tf.reduce_mean(1.0 - (2.0 * intersection + 1.0) / (pred + true + 1.0))

        return dice_loss + 0.1 * l2_loss + 0.1 * kld_loss
|
|
25 |
|
|
|
26 |
|
|
|
27 |
class DiceCoefficient(object):
    """Implements dice coefficient for binary classification.

    Predictions are first hardened: each position is assigned its argmax
    class, and positions whose maximum probability is <= 0.5 are zeroed out
    before scoring.

    Args:
        name: Metric name.
        data_format: 'channels_last' or 'channels_first'.
    """
    def __init__(self,
                 name='dice_coefficient',
                 data_format='channels_last'):
        self.name = name
        self.data_format = data_format

    def __call__(self, y_true, y_pred):
        """Return (macro-averaged, micro-averaged) dice scores.

        Args:
            y_true: One-hot ground truth.
            y_pred: Predicted class probabilities.

        Returns:
            Tuple (macroavg, microavg) of scalar dice scores.
        """
        # Reduce over batch + spatial axes, keeping the channel axis.
        # BUGFIX: the channels_last tuple was (0, 1, 2), which left one
        # spatial axis unreduced; (0, 1, 2, 3) mirrors the channels_first
        # case (0, 2, 3, 4) and matches DiceVAELoss (assumes 5-D inputs).
        dice_axes = (0, 1, 2, 3) if self.data_format == 'channels_last' else (0, 2, 3, 4)
        onehot_axis = -1 if self.data_format == 'channels_last' else 1

        # Mask out positions whose maximum probability is <= 0.5.
        mask = tf.reduce_max(y_pred, axis=onehot_axis, keepdims=True)
        mask = tf.cast(mask > 0.5, tf.float32)

        # One-hot encode the argmax predictions, then apply the confidence mask.
        out_ch = y_pred.shape[onehot_axis]
        y_pred = tf.argmax(y_pred, axis=onehot_axis, output_type=tf.int32)
        y_pred = tf.one_hot(y_pred, out_ch, axis=onehot_axis, dtype=tf.float32)
        y_pred *= mask

        # Per-channel dice with +1 smoothing, averaged over channels (macro).
        intersection = tf.reduce_sum(y_pred * y_true, axis=dice_axes)
        pred = tf.reduce_sum(y_pred, axis=dice_axes)
        true = tf.reduce_sum(y_true, axis=dice_axes)

        macroavg = tf.reduce_mean((2.0 * intersection + 1.0) / (pred + true + 1.0))

        # Global dice over all channels at once (micro).
        # BUGFIX: restore the factor of 2 (dice = 2|A.B| / (|A| + |B|)) so a
        # perfect prediction scores 1.0 rather than 0.5, consistent with the
        # macro average above. NOTE(review): still divides by zero when both
        # tensors are all-zero -- unchanged from the original.
        microavg = 2.0 * tf.reduce_sum(y_pred * y_true) / (
            tf.reduce_sum(y_pred) + tf.reduce_sum(y_true))

        return macroavg, microavg
|
|
58 |
|
|
|
59 |
|
|
|
60 |
class ScheduledOptim(tf.keras.optimizers.Adam):
    """Adam optimizer that allows for scheduling every epoch.

    Calling the instance with the current epoch applies a polynomial decay:
        lr = init_lr * (1 - epoch / n_epochs) ** 0.9
    with the decay fraction clamped at 0 so the rate never goes NaN.

    Args:
        learning_rate: Initial learning rate; also the base of the schedule.
        beta_1: Adam first-moment decay rate.
        beta_2: Adam second-moment decay rate.
        epsilon: Adam numerical-stability constant.
        amsgrad: Whether to use the AMSGrad variant.
        name: Optimizer name.
        n_epochs: Total number of epochs the schedule spans.
        **kwargs: Forwarded to `tf.keras.optimizers.Adam`.
    """
    def __init__(self,
                 learning_rate=1e-4,
                 beta_1=0.9,
                 beta_2=0.999,
                 epsilon=1e-7,
                 amsgrad=False,
                 name='Adam',
                 n_epochs=300,
                 **kwargs):
        super(ScheduledOptim, self).__init__(
            learning_rate=learning_rate,
            beta_1=beta_1,
            beta_2=beta_2,
            epsilon=epsilon,
            amsgrad=amsgrad,
            name=name,
            **kwargs)
        self.init_lr = tf.constant(learning_rate, dtype=tf.float32)
        self.n_epochs = float(n_epochs)

    def __call__(self, epoch):
        """Set the learning rate for the given epoch.

        Args:
            epoch: Current epoch (numeric). Values beyond `n_epochs` now
                yield a learning rate of 0 instead of NaN.
        """
        # BUGFIX: clamp the fraction at 0 -- a negative base raised to the
        # fractional power 0.9 produces NaN once epoch > n_epochs.
        frac = tf.maximum(1.0 - epoch / self.n_epochs, 0.0)
        new_lr = self.init_lr * frac ** 0.9
        self._set_hyper('learning_rate', new_lr)