--- a +++ b/docproduct/loss.py @@ -0,0 +1,21 @@ +import tensorflow as tf + +def qa_pair_loss(y_true, y_pred): + y_true = tf.eye(tf.shape(y_pred)[0])*2-1 + q_embedding, a_embedding = tf.unstack(y_pred, axis=1) + q_embedding = q_embedding / \ + tf.norm(q_embedding, axis=-1, keepdims=True) + a_embedding = a_embedding / \ + tf.norm(a_embedding, axis=-1, keepdims=True) + similarity_matrix = tf.matmul( + q_embedding, a_embedding, transpose_b=True) + return tf.reduce_mean(tf.norm(y_true - similarity_matrix, axis=-1)) + + +def qa_pair_cross_entropy_loss(y_true, y_pred): + y_true = tf.eye(tf.shape(y_pred)[0]) + q_embedding, a_embedding = tf.unstack(y_pred, axis=1) + similarity_matrix = tf.matmul( + a=q_embedding, b=a_embedding, transpose_b=True) + similarity_matrix_softmaxed = tf.nn.softmax(similarity_matrix) + return tf.keras.losses.categorical_crossentropy(y_true, similarity_matrix_softmaxed, from_logits=False)