Diff of /util.py [000000] .. [408896]

Switch to side-by-side view

--- a
+++ b/util.py
@@ -0,0 +1,84 @@
+"""Contains custom loss, dice coefficient, and optimizer classes."""
+import tensorflow as tf
+
+
+class DiceVAELoss(object):
+    """Implements custom dice-VAE loss."""
+    def __init__(self,
+                 name='custom_loss',
+                 data_format='channels_last',
+                 **kwargs):
+        self.axis = (0, 1, 2, 3) if data_format == 'channels_last' else (0, 2, 3, 4)
+
+    def __call__(self, x, y, y_pred, y_vae, z_mean, z_logvar, sample_weight=None):
+        l2_loss = tf.reduce_mean((x - y_vae) ** 2)
+        kld_loss = tf.reduce_mean(z_mean ** 2 + tf.math.exp(z_logvar) - z_logvar - 1.0)
+
+        # Calculate dice loss.
+        intersection = tf.reduce_sum(y_pred * y, axis=self.axis)
+        pred = tf.reduce_sum(y_pred ** 2, axis=self.axis)
+        true = tf.reduce_sum(y ** 2, axis=self.axis)
+
+        dice_loss = tf.reduce_mean(1.0 - (2.0 * intersection + 1.0) / (pred + true + 1.0))
+
+        return dice_loss + 0.1*l2_loss + 0.1*kld_loss
+
+
+class DiceCoefficient(object):
+    """Implements dice coefficient for binary classification."""
+    def __init__(self,
+                 name='dice_coefficient',
+                 data_format='channels_last'):
+        self.name = name
+        self.data_format = data_format
+
+    def __call__(self, y_true, y_pred):
+        dice_axes = (0, 1, 2) if self.data_format == 'channels_last' else (0, 2, 3, 4)
+        onehot_axis = -1 if self.data_format == 'channels_last' else 1
+
+        # Mask out values that correspond to values < 0.5.
+        mask = tf.reduce_max(y_pred, axis=onehot_axis, keepdims=True)
+        mask = tf.cast(mask > 0.5, tf.float32)
+
+        # Create one-hot encoding of predictions.
+        out_ch = y_pred.shape[onehot_axis]
+        y_pred = tf.argmax(y_pred, axis=onehot_axis, output_type=tf.int32)
+        y_pred = tf.one_hot(y_pred, out_ch, axis=onehot_axis, dtype=tf.float32)
+        y_pred *= mask
+
+        # Compute dice score.
+        intersection = tf.reduce_sum(y_pred * y_true, axis=dice_axes)
+        pred = tf.reduce_sum(y_pred, axis=dice_axes)
+        true = tf.reduce_sum(y_true, axis=dice_axes)
+
+        macroavg = tf.reduce_mean((2.0 * intersection + 1.0) / (pred + true + 1.0))
+        microavg = tf.reduce_sum(y_pred * y_true) / (tf.reduce_sum(y_pred) + tf.reduce_sum(y_true))
+
+        return macroavg, microavg
+
+
+class ScheduledOptim(tf.keras.optimizers.Adam):
+    """Adam optimizer that allows for scheduling every epoch."""
+    def __init__(self,
+                 learning_rate=1e-4,
+                 beta_1=0.9,
+                 beta_2=0.999,
+                 epsilon=1e-7,
+                 amsgrad=False,
+                 name='Adam',
+                 n_epochs=300,
+                 **kwargs):
+        super(ScheduledOptim, self).__init__(
+                                        learning_rate=learning_rate,
+                                        beta_1=beta_1,
+                                        beta_2=beta_2,
+                                        epsilon=epsilon,
+                                        amsgrad=amsgrad,
+                                        name=name,
+                                        **kwargs)
+        self.init_lr = tf.constant(learning_rate, dtype=tf.float32)
+        self.n_epochs = float(n_epochs)
+
+    def __call__(self, epoch):
+        new_lr = self.init_lr * ((1.0 - epoch / self.n_epochs) ** 0.9)
+        self._set_hyper('learning_rate', new_lr)