b/summarization/rouge_git/rouge.py
# -*- coding: utf-8 -*-
from __future__ import absolute_import

import io
import os

import six

# Earlier attempts to import this as "rouge_git.rouge_score" (or via a
# sys.path hack) did not work; the plain sibling-module import is what's used.
import rouge_score


class FilesRouge:
    def __init__(self, *args, **kwargs):
        """See the `Rouge` class for args."""
        self.rouge = Rouge(*args, **kwargs)

    def _check_files(self, hyp_path, ref_path):
        assert os.path.isfile(hyp_path)
        assert os.path.isfile(ref_path)

        def line_count(path):
            count = 0
            with open(path, "rb") as f:
                for line in f:
                    count += 1
            return count

        # Scores are computed pairwise, so both files must have the
        # same number of lines.
        hyp_lc = line_count(hyp_path)
        ref_lc = line_count(ref_path)
        assert hyp_lc == ref_lc

    def get_scores(self, hyp_path, ref_path, avg=False, ignore_empty=False):
        """Calculate ROUGE scores between each pair of
        lines (hyp_file[i], ref_file[i]).

        Args:
          * hyp_path: hypothesis file path
          * ref_path: references file path
          * avg (False): whether to return an average score or a list of
            per-line scores
        """
        self._check_files(hyp_path, ref_path)

        # rstrip("\n") rather than line[:-1], so the last line is not
        # truncated when the file lacks a trailing newline.
        with io.open(hyp_path, encoding="utf-8", mode="r") as hyp_file:
            hyps = [line.rstrip("\n") for line in hyp_file]

        with io.open(ref_path, encoding="utf-8", mode="r") as ref_file:
            refs = [line.rstrip("\n") for line in ref_file]

        return self.rouge.get_scores(hyps, refs, avg=avg,
                                     ignore_empty=ignore_empty)

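# Usage sketch (illustrative, not part of the original module; the file names
# below are hypothetical, and both files are assumed to be line-aligned, one
# summary per line):
#
#     files_rouge = FilesRouge()
#     scores = files_rouge.get_scores("hyps.txt", "refs.txt", avg=True)
#     # -> {"rouge-1": {"f": ..., "p": ..., "r": ...}, "rouge-2": {...},
#     #     "rouge-l": {...}}
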
class Rouge:
    DEFAULT_METRICS = ["rouge-1", "rouge-2", "rouge-l"]
    AVAILABLE_METRICS = {
        "rouge-1": lambda hyp, ref, **k: rouge_score.rouge_n(hyp, ref, 1, **k),
        "rouge-2": lambda hyp, ref, **k: rouge_score.rouge_n(hyp, ref, 2, **k),
        "rouge-l": lambda hyp, ref, **k:
            rouge_score.rouge_l_summary_level(hyp, ref, **k),
    }
    DEFAULT_STATS = ["f", "p", "r"]
    AVAILABLE_STATS = ["f", "p", "r"]
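    # Each metric maps to a callable, so dispatch is a plain dict lookup; e.g.
    # Rouge.AVAILABLE_METRICS["rouge-2"](hyp, ref) is equivalent to
    # rouge_score.rouge_n(hyp, ref, 2), where hyp and ref are lists of
    # sentence strings.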

    def __init__(self, metrics=None, stats=None, return_lengths=False,
                 raw_results=False, exclusive=False):
        self.return_lengths = return_lengths
        self.raw_results = raw_results
        self.exclusive = exclusive

        if metrics is not None:
            self.metrics = [m.lower() for m in metrics]

            for m in self.metrics:
                if m not in Rouge.AVAILABLE_METRICS:
                    raise ValueError("Unknown metric '%s'" % m)
        else:
            self.metrics = Rouge.DEFAULT_METRICS

        if self.raw_results:
            self.stats = ["hyp", "ref", "overlap"]
        else:
            if stats is not None:
                self.stats = [s.lower() for s in stats]

                for s in self.stats:
                    if s not in Rouge.AVAILABLE_STATS:
                        raise ValueError("Unknown stat '%s'" % s)
            else:
                self.stats = Rouge.DEFAULT_STATS

    def get_scores(self, hyps, refs, avg=False, ignore_empty=False):
        # Accept a single pair of strings as well as lists of strings.
        if isinstance(hyps, six.string_types):
            hyps, refs = [hyps], [refs]

        if ignore_empty:
            # Filter out pairs where either the hypothesis or the
            # reference is empty.
            hyps_and_refs = zip(hyps, refs)
            hyps_and_refs = [_ for _ in hyps_and_refs
                             if len(_[0]) > 0
                             and len(_[1]) > 0]
            hyps, refs = zip(*hyps_and_refs)

        assert isinstance(hyps, type(refs))
        assert len(hyps) == len(refs)

        if not avg:
            return self._get_scores(hyps, refs)
        return self._get_avg_scores(hyps, refs)

    def _get_scores(self, hyps, refs):
        scores = []
        for hyp, ref in zip(hyps, refs):
            sen_score = {}

            # Split each text into sentences on "." and normalize whitespace.
            hyp = [" ".join(_.split()) for _ in hyp.split(".") if len(_) > 0]
            ref = [" ".join(_.split()) for _ in ref.split(".") if len(_) > 0]

            for m in self.metrics:
                fn = Rouge.AVAILABLE_METRICS[m]
                sc = fn(
                    hyp,
                    ref,
                    raw_results=self.raw_results,
                    exclusive=self.exclusive)
                sen_score[m] = {s: sc[s] for s in self.stats}

            if self.return_lengths:
                lengths = {
                    "hyp": len(" ".join(hyp).split()),
                    "ref": len(" ".join(ref).split())
                }
                sen_score["lengths"] = lengths
            scores.append(sen_score)
        return scores

    def _get_avg_scores(self, hyps, refs):
        # Accumulate per-pair stats, then divide by the number of pairs.
        scores = {m: {s: 0 for s in self.stats} for m in self.metrics}
        if self.return_lengths:
            scores["lengths"] = {"hyp": 0, "ref": 0}

        count = 0
        for (hyp, ref) in zip(hyps, refs):
            hyp = [" ".join(_.split()) for _ in hyp.split(".") if len(_) > 0]
            ref = [" ".join(_.split()) for _ in ref.split(".") if len(_) > 0]

            for m in self.metrics:
                fn = Rouge.AVAILABLE_METRICS[m]
                sc = fn(hyp, ref, exclusive=self.exclusive)
                scores[m] = {s: scores[m][s] + sc[s] for s in self.stats}

            if self.return_lengths:
                scores["lengths"]["hyp"] += len(" ".join(hyp).split())
                scores["lengths"]["ref"] += len(" ".join(ref).split())

            count += 1
        avg_scores = {
            m: {s: scores[m][s] / count for s in self.stats}
            for m in self.metrics
        }

        if self.return_lengths:
            avg_scores["lengths"] = {
                k: scores["lengths"][k] / count
                for k in ["hyp", "ref"]
            }

        return avg_scores
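

# Minimal usage sketch (illustrative, not part of the original module; the
# example strings are made up). With avg=False, get_scores returns a list with
# one score dict per hypothesis/reference pair; with avg=True, a single dict
# of averaged stats.
if __name__ == "__main__":
    rouge = Rouge()
    hyp = "the cat sat on the mat."
    ref = "a cat was sitting on the mat."
    print(rouge.get_scores(hyp, ref))
    # e.g. [{"rouge-1": {"f": ..., "p": ..., "r": ...},
    #        "rouge-2": {...}, "rouge-l": {...}}]
    print(rouge.get_scores([hyp], [ref], avg=True))
    # e.g. {"rouge-1": {"f": ..., "p": ..., "r": ...}, ...}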