import tensorflow as tf
################################# loss functions ##########################################################
def denoise_loss_mse(denoise, clean):
    """Return the mean squared error between ``denoise`` and ``clean`` as a scalar.

    Args:
        denoise: Denoised (predicted) signal tensor.
        clean: Clean reference signal tensor; presumably the same shape as
            ``denoise`` -- TODO confirm against callers.

    Returns:
        Scalar tensor: the output of ``tf.losses.mean_squared_error`` averaged
        with ``tf.reduce_mean``.
    """
    # NOTE(review): in TF2 ``tf.losses.mean_squared_error`` averages only over
    # the last axis, so the outer reduce_mean collapses the remaining axes; in
    # TF1 it already yields a scalar and this reduction is a no-op.
    return tf.reduce_mean(tf.losses.mean_squared_error(denoise, clean))
def denoise_loss_rmse(denoise, clean):
    """Return the root-mean-squared error between ``denoise`` and ``clean``.

    Args:
        denoise: Denoised (predicted) signal tensor.
        clean: Clean reference signal tensor; presumably the same shape as
            ``denoise`` -- TODO confirm against callers.

    Returns:
        Scalar tensor: square root of the mean of the values returned by
        ``tf.losses.mean_squared_error(denoise, clean)``.
    """
    squared_error = tf.losses.mean_squared_error(denoise, clean)
    # Collapse any axes the loss op left unreduced before taking the root.
    mean_squared = tf.reduce_mean(squared_error)
    root_mean_squared = tf.math.sqrt(mean_squared)
    return root_mean_squared
def denoise_loss_rrmset(denoise, clean):
    """Return the relative RMSE of ``denoise`` with respect to ``clean``.

    Computed as ``RMSE(denoise, clean) / RMSE(clean, 0)``: the RMS of the
    reconstruction error normalised by the RMS of the clean reference. The
    trailing "t" presumably denotes the temporal-domain variant of RRMSE --
    TODO confirm.

    Args:
        denoise: Denoised (predicted) signal tensor.
        clean: Clean reference signal tensor; presumably the same shape as
            ``denoise`` -- TODO confirm against callers.

    Returns:
        Scalar tensor with the same dtype as ``clean``. If ``clean`` is all
        zeros the denominator is zero and the result is inf/nan.
    """
    error_rms = denoise_loss_rmse(denoise, clean)
    # The zero reference must share clean's full (possibly dynamic) shape and
    # its dtype. Then RMSE(clean, 0) is the RMS of every element of ``clean``,
    # and the numerator and denominator have matching dtypes for the division.
    reference_rms = denoise_loss_rmse(clean, tf.zeros_like(clean))
    return error_rms / reference_rms