Diff of /util.py [000000] .. [408896]

Switch to unified view

a b/util.py
1
"""Contains custom loss, dice coefficient, and optimizer classes."""
2
import tensorflow as tf
3
4
5
class DiceVAELoss(object):
6
    """Implements custom dice-VAE loss."""
7
    def __init__(self,
8
                 name='custom_loss',
9
                 data_format='channels_last',
10
                 **kwargs):
11
        self.axis = (0, 1, 2, 3) if data_format == 'channels_last' else (0, 2, 3, 4)
12
13
    def __call__(self, x, y, y_pred, y_vae, z_mean, z_logvar, sample_weight=None):
14
        l2_loss = tf.reduce_mean((x - y_vae) ** 2)
15
        kld_loss = tf.reduce_mean(z_mean ** 2 + tf.math.exp(z_logvar) - z_logvar - 1.0)
16
17
        # Calculate dice loss.
18
        intersection = tf.reduce_sum(y_pred * y, axis=self.axis)
19
        pred = tf.reduce_sum(y_pred ** 2, axis=self.axis)
20
        true = tf.reduce_sum(y ** 2, axis=self.axis)
21
22
        dice_loss = tf.reduce_mean(1.0 - (2.0 * intersection + 1.0) / (pred + true + 1.0))
23
24
        return dice_loss + 0.1*l2_loss + 0.1*kld_loss
25
26
27
class DiceCoefficient(object):
28
    """Implements dice coefficient for binary classification."""
29
    def __init__(self,
30
                 name='dice_coefficient',
31
                 data_format='channels_last'):
32
        self.name = name
33
        self.data_format = data_format
34
35
    def __call__(self, y_true, y_pred):
36
        dice_axes = (0, 1, 2) if self.data_format == 'channels_last' else (0, 2, 3, 4)
37
        onehot_axis = -1 if self.data_format == 'channels_last' else 1
38
39
        # Mask out values that correspond to values < 0.5.
40
        mask = tf.reduce_max(y_pred, axis=onehot_axis, keepdims=True)
41
        mask = tf.cast(mask > 0.5, tf.float32)
42
43
        # Create one-hot encoding of predictions.
44
        out_ch = y_pred.shape[onehot_axis]
45
        y_pred = tf.argmax(y_pred, axis=onehot_axis, output_type=tf.int32)
46
        y_pred = tf.one_hot(y_pred, out_ch, axis=onehot_axis, dtype=tf.float32)
47
        y_pred *= mask
48
49
        # Compute dice score.
50
        intersection = tf.reduce_sum(y_pred * y_true, axis=dice_axes)
51
        pred = tf.reduce_sum(y_pred, axis=dice_axes)
52
        true = tf.reduce_sum(y_true, axis=dice_axes)
53
54
        macroavg = tf.reduce_mean((2.0 * intersection + 1.0) / (pred + true + 1.0))
55
        microavg = tf.reduce_sum(y_pred * y_true) / (tf.reduce_sum(y_pred) + tf.reduce_sum(y_true))
56
57
        return macroavg, microavg
58
59
60
class ScheduledOptim(tf.keras.optimizers.Adam):
61
    """Adam optimizer that allows for scheduling every epoch."""
62
    def __init__(self,
63
                 learning_rate=1e-4,
64
                 beta_1=0.9,
65
                 beta_2=0.999,
66
                 epsilon=1e-7,
67
                 amsgrad=False,
68
                 name='Adam',
69
                 n_epochs=300,
70
                 **kwargs):
71
        super(ScheduledOptim, self).__init__(
72
                                        learning_rate=learning_rate,
73
                                        beta_1=beta_1,
74
                                        beta_2=beta_2,
75
                                        epsilon=epsilon,
76
                                        amsgrad=amsgrad,
77
                                        name=name,
78
                                        **kwargs)
79
        self.init_lr = tf.constant(learning_rate, dtype=tf.float32)
80
        self.n_epochs = float(n_epochs)
81
82
    def __call__(self, epoch):
83
        new_lr = self.init_lr * ((1.0 - epoch / self.n_epochs) ** 0.9)
84
        self._set_hyper('learning_rate', new_lr)