Diff of /losses.py [000000] .. [277df6]


a b/losses.py
import tensorflow as tf
import numpy as np
from tensorflow.keras import losses

# Add perceptual loss

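# A minimal sketch of the perceptual loss mentioned in the TODO above, assuming
# images are in [-1, 1] (as the commented-out rescaling below suggests) and that a
# pretrained VGG16 feature extractor is acceptable. The helper name
# `perceptual_loss` and the layer choice 'block3_conv3' are assumptions, not part
# of the original code.
_vgg = tf.keras.applications.VGG16(include_top=False, weights='imagenet')
_feat = tf.keras.Model(_vgg.input, _vgg.get_layer('block3_conv3').output)
_feat.trainable = False

def perceptual_loss(real_img, fake_img):
    # Map images from [-1, 1] to [0, 255] and apply VGG preprocessing.
    real = tf.keras.applications.vgg16.preprocess_input(255.0 * (real_img * 0.5 + 0.5))
    fake = tf.keras.applications.vgg16.preprocess_input(255.0 * (fake_img * 0.5 + 0.5))
    # L2 distance between intermediate VGG feature maps.
    return tf.reduce_mean(tf.square(_feat(real) - _feat(fake)))
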
def disc_hinge(dis_real, dis_fake):
    # Hinge loss for the discriminator: push real logits above +1 and fake logits below -1.
    real_loss = -1.0 * tf.reduce_mean(tf.minimum(0.0, -1.0 + dis_real))
    fake_loss = -1.0 * tf.reduce_mean(tf.minimum(0.0, -1.0 - dis_fake))
    return (real_loss + fake_loss) / 2.0

def gen_hinge(dis_fake):
    # Hinge loss for the generator: maximise the discriminator score on generated samples.
    fake_loss = -1.0 * tf.reduce_mean(dis_fake)
    return fake_loss

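# A minimal usage sketch for the hinge losses above, assuming `disc`/`gen` are
# Keras models, `d_opt` is an optimizer, and `real`/`noise` are batched tensors;
# all of those names are assumptions for illustration only.
def disc_hinge_step(disc, gen, d_opt, real, noise):
    with tf.GradientTape() as tape:
        fake = gen(noise, training=True)
        loss = disc_hinge(disc(real, training=True), disc(fake, training=True))
    grads = tape.gradient(loss, disc.trainable_variables)
    d_opt.apply_gradients(zip(grads, disc.trainable_variables))
    return loss
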
def disc_loss(dis_real, dis_fake, dis_wrong=None):
    # Binary cross-entropy with noisy labels: real targets drawn from [0.7, 1.1],
    # fake targets from [0.0, 0.2] (label smoothing to regularise the discriminator).
    # real = tf.ones_like(dis_real)
    # fake = tf.zeros_like(dis_fake)
    real = tf.convert_to_tensor(np.random.randint(low=7, high=12, size=dis_real.shape) / 10.0)
    fake = tf.convert_to_tensor(np.random.randint(low=0, high=3, size=dis_fake.shape) / 10.0)
    real_loss  = losses.BinaryCrossentropy(reduction=tf.keras.losses.Reduction.NONE)(real, dis_real)
    fake_loss  = losses.BinaryCrossentropy(reduction=tf.keras.losses.Reduction.NONE)(fake, dis_fake)
    # wrong_loss = losses.BinaryCrossentropy()(fake, dis_wrong)
    # total_loss = (real_loss + fake_loss + wrong_loss)/3.0
    # total_loss = tf.reduce_mean(real_loss**2 + fake_loss**2)
    total_loss = (real_loss + fake_loss) / 2.0
    return total_loss

def gen_loss(dis_fake):
    # Generator BCE loss: generated samples are labelled as real.
    real = tf.ones_like(dis_fake)
    fake_loss = losses.BinaryCrossentropy(reduction=tf.keras.losses.Reduction.NONE)(real, dis_fake)
    return fake_loss

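# With Reduction.NONE, disc_loss and gen_loss return one loss value per sample
# rather than a scalar, so the caller is expected to reduce them explicitly.
# A minimal sketch (the reduction choice is an assumption, not part of the
# original code):
#
#     d_loss = tf.reduce_mean(disc_loss(disc(real), disc(fake)))
#     g_loss = tf.reduce_mean(gen_loss(disc(fake)))
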
def critic_loss(D_real, D_fake):
    # WGAN critic loss: widen the score gap between real and generated samples.
    # real = -tf.ones_like(D_real)
    # fake =  tf.ones_like(D_fake)
    # return tf.reduce_mean(D_real*real) + tf.reduce_mean(D_fake*fake)
    return tf.reduce_mean(D_fake) - tf.reduce_mean(D_real)

# def gen_loss(D_fake, real_img, fake_img):
#     # fake =  tf.ones_like(D_fake)
#     # return tf.reduce_mean(D_fake*fake)
#     # real_img = tf.clip_by_value(255.0*(real_img*0.5+0.5), 0.0, 255.0)
#     # fake_img = tf.clip_by_value(255.0*(fake_img*0.5+0.5), 0.0, 255.0)
#     return -tf.reduce_mean(D_fake)  # + 0.6 * tf.keras.losses.MeanSquaredError()(real_img, fake_img)

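# The WGAN generator objective only appears in commented-out form above; a minimal
# active sketch is given here for clarity. The name `wgan_gen_loss` is an
# assumption, not part of the original code.
def wgan_gen_loss(D_fake):
    # The generator tries to maximise the critic score on generated samples.
    return -tf.reduce_mean(D_fake)
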
def wgan_gp_loss(D_real, D_fake, Y, Y_cap, model, batch_size):
    # WGAN-GP critic loss: Wasserstein estimate plus a gradient penalty computed
    # on samples interpolated between real (Y) and generated (Y_cap) images.
    dloss = tf.reduce_mean(D_fake) - tf.reduce_mean(D_real)
    lam   = 10
    eps   = tf.random.uniform(shape=[batch_size, 1, 1, 1], minval=0, maxval=1)
    x_cap = eps * Y + (1 - eps) * Y_cap
    with tf.GradientTape() as gptape:
        gptape.watch(x_cap)
        out = model.critic(x_cap, training=True)
    # Fetch only the gradient w.r.t. the interpolated images x_cap
    grad = gptape.gradient(out, [x_cap])[0]
    # grad_norm = tf.norm(grad, ord='euclidean', axis=1)
    # grad_pen  = tf.reduce_mean(tf.square(grad_norm - 1))
    grad_norm = tf.sqrt(tf.reduce_sum(tf.square(grad), axis=[1, 2, 3]))
    grad_pen  = tf.reduce_mean((grad_norm - 1.0) ** 2)
    dloss     = dloss + lam * grad_pen
    return dloss
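
# A minimal critic-update sketch for wgan_gp_loss, assuming `model` bundles a
# `generator` and a `critic` (as implied by `model.critic` above) and that
# `c_opt`, `real` and `noise` are an optimizer and batched tensors; all of these
# names are assumptions for illustration only.
def critic_train_step(model, c_opt, real, noise, batch_size):
    with tf.GradientTape() as tape:
        fake = model.generator(noise, training=True)
        d_loss = wgan_gp_loss(model.critic(real, training=True),
                              model.critic(fake, training=True),
                              real, fake, model, batch_size)
    grads = tape.gradient(d_loss, model.critic.trainable_variables)
    c_opt.apply_gradients(zip(grads, model.critic.trainable_variables))
    return d_loss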