--- a +++ b/util.py @@ -0,0 +1,84 @@ +"""Contains custom loss, dice coefficient, and optimizer classes.""" +import tensorflow as tf + + +class DiceVAELoss(object): + """Implements custom dice-VAE loss.""" + def __init__(self, + name='custom_loss', + data_format='channels_last', + **kwargs): + self.axis = (0, 1, 2, 3) if data_format == 'channels_last' else (0, 2, 3, 4) + + def __call__(self, x, y, y_pred, y_vae, z_mean, z_logvar, sample_weight=None): + l2_loss = tf.reduce_mean((x - y_vae) ** 2) + kld_loss = tf.reduce_mean(z_mean ** 2 + tf.math.exp(z_logvar) - z_logvar - 1.0) + + # Calculate dice loss. + intersection = tf.reduce_sum(y_pred * y, axis=self.axis) + pred = tf.reduce_sum(y_pred ** 2, axis=self.axis) + true = tf.reduce_sum(y ** 2, axis=self.axis) + + dice_loss = tf.reduce_mean(1.0 - (2.0 * intersection + 1.0) / (pred + true + 1.0)) + + return dice_loss + 0.1*l2_loss + 0.1*kld_loss + + +class DiceCoefficient(object): + """Implements dice coefficient for binary classification.""" + def __init__(self, + name='dice_coefficient', + data_format='channels_last'): + self.name = name + self.data_format = data_format + + def __call__(self, y_true, y_pred): + dice_axes = (0, 1, 2) if self.data_format == 'channels_last' else (0, 2, 3, 4) + onehot_axis = -1 if self.data_format == 'channels_last' else 1 + + # Mask out values that correspond to values < 0.5. + mask = tf.reduce_max(y_pred, axis=onehot_axis, keepdims=True) + mask = tf.cast(mask > 0.5, tf.float32) + + # Create one-hot encoding of predictions. + out_ch = y_pred.shape[onehot_axis] + y_pred = tf.argmax(y_pred, axis=onehot_axis, output_type=tf.int32) + y_pred = tf.one_hot(y_pred, out_ch, axis=onehot_axis, dtype=tf.float32) + y_pred *= mask + + # Compute dice score. + intersection = tf.reduce_sum(y_pred * y_true, axis=dice_axes) + pred = tf.reduce_sum(y_pred, axis=dice_axes) + true = tf.reduce_sum(y_true, axis=dice_axes) + + macroavg = tf.reduce_mean((2.0 * intersection + 1.0) / (pred + true + 1.0)) + microavg = tf.reduce_sum(y_pred * y_true) / (tf.reduce_sum(y_pred) + tf.reduce_sum(y_true)) + + return macroavg, microavg + + +class ScheduledOptim(tf.keras.optimizers.Adam): + """Adam optimizer that allows for scheduling every epoch.""" + def __init__(self, + learning_rate=1e-4, + beta_1=0.9, + beta_2=0.999, + epsilon=1e-7, + amsgrad=False, + name='Adam', + n_epochs=300, + **kwargs): + super(ScheduledOptim, self).__init__( + learning_rate=learning_rate, + beta_1=beta_1, + beta_2=beta_2, + epsilon=epsilon, + amsgrad=amsgrad, + name=name, + **kwargs) + self.init_lr = tf.constant(learning_rate, dtype=tf.float32) + self.n_epochs = float(n_epochs) + + def __call__(self, epoch): + new_lr = self.init_lr * ((1.0 - epoch / self.n_epochs) ** 0.9) + self._set_hyper('learning_rate', new_lr)