[1180c1]: / llava / eval / eval_metrics / utils.py

Download this file

41 lines (35 with data), 1.4 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
from collections import defaultdict
import re
import math
def brevity_penalty(candidate, references):
    """Return the BLEU brevity penalty of ``candidate`` against ``references``.

    Lengths are measured with ``len()``, so the unit (characters for
    strings, items for lists) is whatever the caller passes in.

    The effective reference length ``r`` is the reference length closest to
    the candidate length ``c``; ties are broken in favour of the shorter
    reference.

    Args:
        candidate: The hypothesis; any sized object.
        references: A non-empty iterable of sized reference objects.

    Returns:
        ``1`` if the candidate is longer than the effective reference
        length, ``exp(1 - r / c)`` otherwise, and ``0`` for an empty
        candidate (which would otherwise divide by zero).

    Raises:
        ValueError: If ``references`` is empty.
    """
    c = len(candidate)
    # Sort key: closest length first, then the shorter reference on ties.
    r = min(
        (len(reference) for reference in references),
        key=lambda ref_len: (abs(ref_len - c), ref_len),
    )
    if c == 0:
        # An empty candidate gets the maximum penalty instead of raising
        # ZeroDivisionError in ``r / c`` below.
        return 0
    if c > r:
        return 1
    return math.exp(1 - r / c)
def modified_precision(candidate, references, n):
    """Return the BLEU modified (clipped) n-gram precision.

    Each candidate n-gram count is clipped to the largest count of that
    n-gram found in any single reference; the clipped counts are summed and
    divided by the total number of candidate n-grams.

    Args:
        candidate: The hypothesis sentence, as accepted by ``split_sentence``.
        references: An iterable of reference sentences.
        n: The n-gram order.

    Returns:
        The clipped precision as a float in ``[0, 1]``. Returns ``0.0`` when
        the candidate yields no n-grams (for example, it has fewer than
        ``n`` tokens), which would otherwise divide by zero.
    """
    candidate_words = split_sentence(candidate, n)
    total = sum(candidate_words.values())
    if total == 0:
        # No candidate n-grams of this order: precision is defined as 0
        # rather than raising ZeroDivisionError.
        return 0.0

    # Largest count of each candidate n-gram across all references.
    max_frequency = defaultdict(int)
    for reference in references:
        reference_words = split_sentence(reference, n)
        for word in candidate_words:
            # ``.get`` avoids inserting zero entries into the reference counts.
            max_frequency[word] = max(
                max_frequency[word], reference_words.get(word, 0)
            )

    # Clip every candidate count to its best reference count.
    clipped = sum(
        min(count, max_frequency[word])
        for word, count in candidate_words.items()
    )
    return clipped / total
def split_sentence(sentence, n):
    """Count the word n-grams of ``sentence``.

    The sentence is lower-cased and split on whitespace; punctuation is kept
    as part of the tokens. Every window of ``n`` consecutive tokens is joined
    with single spaces and tallied.

    Args:
        sentence: The text to tokenise.
        n: The n-gram order (window size in tokens).

    Returns:
        A ``defaultdict(int)`` mapping each n-gram string to the number of
        times it occurs. It is empty when the sentence has fewer than ``n``
        tokens.
    """
    tokens = sentence.lower().strip().split()
    ngram_counts = defaultdict(int)
    window_starts = range(len(tokens) - n + 1)
    for ngram in (" ".join(tokens[start:start + n]) for start in window_starts):
        if not ngram:
            # An empty window (e.g. n == 0) contributes nothing.
            continue
        ngram_counts[ngram] += 1
    return ngram_counts