[8adc28]: / code / benchmark_networks / loss_function.py

Download this file

20 lines (14 with data), 730 Bytes

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
import tensorflow as tf
################################# loss functions ##########################################################
def denoise_loss_mse(denoise, clean):
    """Return the mean squared error between `denoise` and `clean`.

    The values produced by ``tf.losses.mean_squared_error`` are averaged
    with ``tf.reduce_mean`` (no axis given), so the result is a single
    scalar tensor.

    Args:
        denoise: Tensor holding the denoised estimate.
        clean: Tensor holding the reference signal.

    Returns:
        Scalar tensor with the mean squared error.
    """
    return tf.reduce_mean(tf.losses.mean_squared_error(denoise, clean))
def denoise_loss_rmse(denoise, clean):
    """Return the root mean squared error between `denoise` and `clean`.

    The values produced by ``tf.losses.mean_squared_error`` are averaged
    to a scalar with ``tf.reduce_mean`` and the square root of that
    average is returned.

    Args:
        denoise: Tensor holding the denoised estimate.
        clean: Tensor holding the reference signal.

    Returns:
        Scalar tensor with the root mean squared error.
    """
    squared_error = tf.losses.mean_squared_error(denoise, clean)
    mean_squared_error = tf.reduce_mean(squared_error)
    return tf.math.sqrt(mean_squared_error)
def denoise_loss_rrmset(denoise, clean):
    """Return the RMSE of `denoise` relative to the RMS magnitude of `clean`.

    Computes ``denoise_loss_rmse(denoise, clean)`` divided by
    ``denoise_loss_rmse(clean, 0)``, i.e. the reconstruction error
    normalised by the root-mean-square value of the reference signal.

    Args:
        denoise: Tensor holding the denoised estimate.
        clean: Tensor holding the reference signal. Assumed to be
            shape-compatible with `denoise` (TODO confirm against callers).

    Returns:
        Scalar tensor with the ratio of the two RMSE values. There is no
        guard against an all-zero `clean`: the denominator is then zero
        and the result is not finite.
    """
    rmse_error = denoise_loss_rmse(denoise, clean)
    # The zero reference must mirror `clean` exactly. A 1-D float64 vector
    # sized by `clean.shape[0]` only lines up with 1-D float64 input, breaks
    # when the leading dimension is unknown, and can leave the denominator
    # in a different dtype than the numerator.
    rmse_reference = denoise_loss_rmse(clean, tf.zeros_like(clean))
    return rmse_error / rmse_reference