a b/docproduct/loss.py
1
import tensorflow as tf
2
3
def qa_pair_loss(y_true, y_pred):
    """Regress in-batch Q/A cosine similarities toward +1 / -1 targets.

    Every question in the batch is compared with every answer in the batch.
    The matching pair (same batch index) should have cosine similarity +1,
    and every mismatched pair should have -1.

    Args:
        y_true: Ignored. Keras supplies a label tensor, but the targets are
            rebuilt here from the batch size: +1 on the diagonal, -1
            elsewhere.
        y_pred: Question and answer embeddings stacked along axis 1, i.e.
            shape ``(batch, 2, dim)``. ``y_pred[:, 0]`` holds the questions
            and ``y_pred[:, 1]`` the answers.

    Returns:
        Scalar tensor: the mean, over questions, of the L2 distance between
        that question's row of cosine similarities and its target row.
    """
    batch_size = tf.shape(y_pred)[0]
    # Build the targets in y_pred's dtype. tf.eye defaults to float32, which
    # would break the subtraction below for float16/float64 models.
    y_true = tf.eye(batch_size, dtype=y_pred.dtype) * 2 - 1
    q_embedding, a_embedding = tf.unstack(y_pred, axis=1)
    # Unit-normalize so the dot products below are cosine similarities.
    # l2_normalize equals x / ||x|| except that it guards the denominator
    # with a tiny epsilon, so an all-zero embedding gives zeros, not NaN.
    q_embedding = tf.math.l2_normalize(q_embedding, axis=-1)
    a_embedding = tf.math.l2_normalize(a_embedding, axis=-1)
    # (batch, batch): entry [i, j] is cos(question_i, answer_j).
    similarity_matrix = tf.matmul(q_embedding, a_embedding, transpose_b=True)
    return tf.reduce_mean(tf.norm(y_true - similarity_matrix, axis=-1))
13
14
15
def qa_pair_cross_entropy_loss(y_true, y_pred):
    """In-batch softmax cross-entropy over question/answer pairs.

    Each question is scored against every answer in the batch by a raw
    (unnormalized) dot product. A softmax over those scores is trained to
    pick the question's own answer, i.e. the diagonal entry.

    Args:
        y_true: Ignored. One-hot targets (the identity matrix) are rebuilt
            here from the batch size.
        y_pred: Question and answer embeddings stacked along axis 1, i.e.
            shape ``(batch, 2, dim)``. ``y_pred[:, 0]`` holds the questions
            and ``y_pred[:, 1]`` the answers.

    Returns:
        Tensor with one cross-entropy value per question (Keras reduces over
        the last axis).
    """
    # One-hot targets in y_pred's dtype (tf.eye defaults to float32).
    y_true = tf.eye(tf.shape(y_pred)[0], dtype=y_pred.dtype)
    q_embedding, a_embedding = tf.unstack(y_pred, axis=1)
    # (batch, batch) logits: entry [i, j] = <question_i, answer_j>.
    logits = tf.matmul(a=q_embedding, b=a_embedding, transpose_b=True)
    # Hand the raw logits to Keras with from_logits=True. The fused
    # softmax-cross-entropy is numerically stable, whereas softmax followed
    # by from_logits=False takes the log of clipped probabilities and loses
    # precision for large-magnitude scores.
    return tf.keras.losses.categorical_crossentropy(
        y_true, logits, from_logits=True)