summarization/rouge_git/rouge.py
# -*- coding: utf-8 -*-
from __future__ import absolute_import
import six
# import rouge_git.rouge_score as rouge_score
# sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__)))) # Didn't do anything
import rouge_score  # as rouge_score
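# NOTE: `rouge_score` here appears to be the sibling rouge_score.py module in
# this directory, imported flat; the commented-out lines above show an earlier
# attempt to import it through the rouge_git package path.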
import io
import os

class FilesRouge:
    def __init__(self, *args, **kwargs):
        """See the `Rouge` class for args."""
        self.rouge = Rouge(*args, **kwargs)

    def _check_files(self, hyp_path, ref_path):
        assert os.path.isfile(hyp_path)
        assert os.path.isfile(ref_path)

        def line_count(path):
            count = 0
            with open(path, "rb") as f:
                for line in f:
                    count += 1
            return count

        hyp_lc = line_count(hyp_path)
        ref_lc = line_count(ref_path)
        assert hyp_lc == ref_lc

    def get_scores(self, hyp_path, ref_path, avg=False, ignore_empty=False):
        """Calculate ROUGE scores between each pair of
        lines (hyp_file[i], ref_file[i]).
        Args:
          * hyp_path: hypothesis file path
          * ref_path: reference file path
          * avg (False): whether to return a single averaged score dict
            instead of a list of per-line scores
          * ignore_empty (False): whether to skip line pairs where either
            side is empty
        """
        self._check_files(hyp_path, ref_path)

        with io.open(hyp_path, encoding="utf-8", mode="r") as hyp_file:
            hyps = [line[:-1] for line in hyp_file]

        with io.open(ref_path, encoding="utf-8", mode="r") as ref_file:
            refs = [line[:-1] for line in ref_file]

        return self.rouge.get_scores(hyps, refs, avg=avg,
                                     ignore_empty=ignore_empty)


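# Usage sketch (illustrative, not part of the original module): "hyps.txt" and
# "refs.txt" below are hypothetical line-aligned files, one summary per line.
#
#     files_rouge = FilesRouge()
#     scores = files_rouge.get_scores("hyps.txt", "refs.txt", avg=True)
#     # -> {"rouge-1": {"f": ..., "p": ..., "r": ...}, "rouge-2": {...}, "rouge-l": {...}}
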
class Rouge:
    DEFAULT_METRICS = ["rouge-1", "rouge-2", "rouge-l"]
    AVAILABLE_METRICS = {
        "rouge-1": lambda hyp, ref, **k: rouge_score.rouge_n(hyp, ref, 1, **k),
        "rouge-2": lambda hyp, ref, **k: rouge_score.rouge_n(hyp, ref, 2, **k),
        "rouge-l": lambda hyp, ref, **k:
            rouge_score.rouge_l_summary_level(hyp, ref, **k),
    }
    DEFAULT_STATS = ["f", "p", "r"]
    AVAILABLE_STATS = ["f", "p", "r"]

    def __init__(self, metrics=None, stats=None, return_lengths=False,
                 raw_results=False, exclusive=False):
        self.return_lengths = return_lengths
        self.raw_results = raw_results
        self.exclusive = exclusive

        if metrics is not None:
            self.metrics = [m.lower() for m in metrics]

            for m in self.metrics:
                if m not in Rouge.AVAILABLE_METRICS:
                    raise ValueError("Unknown metric '%s'" % m)
        else:
            self.metrics = Rouge.DEFAULT_METRICS

        if self.raw_results:
            self.stats = ["hyp", "ref", "overlap"]
        else:
            if stats is not None:
                self.stats = [s.lower() for s in stats]

                for s in self.stats:
                    if s not in Rouge.AVAILABLE_STATS:
                        raise ValueError("Unknown stat '%s'" % s)
            else:
                self.stats = Rouge.DEFAULT_STATS

    def get_scores(self, hyps, refs, avg=False, ignore_empty=False):
        """Score each hypothesis against the reference at the same index.

        `hyps` and `refs` may be single strings or equal-length lists of
        strings. With avg=False a list of per-pair score dicts is returned;
        with avg=True a single dict of averaged scores is returned.
        """
        if isinstance(hyps, six.string_types):
            hyps, refs = [hyps], [refs]

        if ignore_empty:
            # Filter out pairs where either the hypothesis or the reference
            # is empty
            hyps_and_refs = zip(hyps, refs)
            hyps_and_refs = [_ for _ in hyps_and_refs
                             if len(_[0]) > 0
                             and len(_[1]) > 0]
            hyps, refs = zip(*hyps_and_refs)

        assert isinstance(hyps, type(refs))
        assert len(hyps) == len(refs)

        if not avg:
            return self._get_scores(hyps, refs)
        return self._get_avg_scores(hyps, refs)

    def _get_scores(self, hyps, refs):
        scores = []
        for hyp, ref in zip(hyps, refs):
            sen_score = {}

            # Split on "." into sentences and normalize whitespace
            hyp = [" ".join(_.split()) for _ in hyp.split(".") if len(_) > 0]
            ref = [" ".join(_.split()) for _ in ref.split(".") if len(_) > 0]

            for m in self.metrics:
                fn = Rouge.AVAILABLE_METRICS[m]
                sc = fn(
                    hyp,
                    ref,
                    raw_results=self.raw_results,
                    exclusive=self.exclusive)
                sen_score[m] = {s: sc[s] for s in self.stats}

            if self.return_lengths:
                lengths = {
                    "hyp": len(" ".join(hyp).split()),
                    "ref": len(" ".join(ref).split())
                }
                sen_score["lengths"] = lengths
            scores.append(sen_score)
        return scores

    def _get_avg_scores(self, hyps, refs):
        scores = {m: {s: 0 for s in self.stats} for m in self.metrics}
        if self.return_lengths:
            scores["lengths"] = {"hyp": 0, "ref": 0}

        count = 0
        for (hyp, ref) in zip(hyps, refs):
            # Split on "." into sentences and normalize whitespace
            hyp = [" ".join(_.split()) for _ in hyp.split(".") if len(_) > 0]
            ref = [" ".join(_.split()) for _ in ref.split(".") if len(_) > 0]

            for m in self.metrics:
                fn = Rouge.AVAILABLE_METRICS[m]
                sc = fn(hyp, ref, exclusive=self.exclusive)
                scores[m] = {s: scores[m][s] + sc[s] for s in self.stats}

            if self.return_lengths:
                scores["lengths"]["hyp"] += len(" ".join(hyp).split())
                scores["lengths"]["ref"] += len(" ".join(ref).split())

            count += 1
        # Average the accumulated sums over the number of pairs
        avg_scores = {
            m: {s: scores[m][s] / count for s in self.stats}
            for m in self.metrics
        }

        if self.return_lengths:
            avg_scores["lengths"] = {
                k: scores["lengths"][k] / count
                for k in ["hyp", "ref"]
            }

        return avg_scores
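

# Minimal usage sketch (not part of the original module). The two strings below
# are illustrative inputs, and the example assumes the sibling `rouge_score`
# module is importable as above.
if __name__ == "__main__":
    hyp = "the cat was found under the bed"
    ref = "the cat was under the bed"

    rouge = Rouge(metrics=["rouge-1", "rouge-2", "rouge-l"], stats=["f", "p", "r"])
    print(rouge.get_scores(hyp, ref, avg=True))
    # Each metric maps to its f/p/r stats, e.g.
    # {"rouge-1": {"f": ..., "p": ..., "r": ...}, "rouge-2": {...}, "rouge-l": {...}}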