--- a +++ b/docproduct/metrics.py @@ -0,0 +1,12 @@ +import tensorflow as tf + + +def qa_pair_batch_accuracy(y_true, y_pred): + y_true = tf.eye(tf.shape(y_pred)[0]) + q_embedding, a_embedding = tf.unstack(y_pred, axis=1) + similarity_matrix = tf.matmul( + q_embedding, a_embedding, transpose_b=True) + y_true = tf.argmax(y_true, axis=1) + y_pred = tf.argmax(similarity_matrix, axis=1) + acc = tf.reduce_mean(tf.cast(tf.equal(y_pred, y_true), tf.float32)) + return acc