# edsnlp/pipes/misc/split/split.py
import random
import re
from typing import Iterable, Optional

from spacy.tokens import Doc, Span

import edsnlp
from edsnlp import Pipeline

EMPTY = object()


def make_shifter(start, end, new_doc):
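    """
    Build a function that recursively shifts values from the original doc's
    token coordinates into those of `new_doc`, which covers the token window
    [start, end). Spans are clipped to the window and spans that do not
    overlap it are mapped to the EMPTY sentinel; lists, tuples, sets and dict
    values are processed recursively, dropping EMPTY results. Any other value
    is returned unchanged.
    """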
    cache = {}

    def rec(obj):
        if isinstance(obj, Span):
            if obj in cache:
                return cache[obj]
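            # Keep spans that overlap the [start, end) token window, clipped
            # to it; spans entirely outside are mapped to the EMPTY sentinel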
            if obj.end > start and obj.start < end:
                res = Span(
                    new_doc,
                    max(0, obj.start - start),
                    min(obj.end - start, end - start),
                    obj.label,
                )
            else:
                res = EMPTY
            cache[obj] = res
        elif isinstance(obj, (list, tuple, set)):
            res = type(obj)(
                filter(
                    lambda x: x is not EMPTY,
                    (rec(span) for span in obj),
                )
            )
        elif isinstance(obj, dict):
            res = {}
            for k, v in obj.items():
                new_v = rec(v)
                if new_v is not EMPTY:
                    res[k] = new_v
        else:
            res = obj
        return res

    return rec


def subset_doc(doc: Doc, start: int, end: int) -> Doc:
    """
    Subset a doc given a start and end token index.

    Parameters
    ----------
    doc: Doc
        The doc to subset
    start: int
        The index of the first token of the subset
    end: int
        The index of the first token after the subset (exclusive)

    Returns
    -------
    Doc
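
    Examples
    --------
    A minimal, illustrative sketch (assuming this module is importable as
    `edsnlp.pipes.misc.split.split`; the exact tokenization depends on the
    tokenizer in use):

    ```python
    import edsnlp
    from edsnlp.pipes.misc.split.split import subset_doc

    nlp = edsnlp.blank("eds")
    doc = nlp("First part. Second part.")

    # Keep only the first three tokens: "First", "part" and "."
    sub = subset_doc(doc, 0, 3)
    ```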
    """
    new_doc = doc[start:end].as_doc()

    shifter = make_shifter(start, end, new_doc)

    char_beg = doc[start].idx if start < len(doc) else 0
    char_end = doc[end - 1].idx + len(doc[end - 1].text)
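    # Custom extension values are stored in doc.user_data under keys of the
    # form ("._.", attr_name, char_start, char_end): shift those character
    # offsets into the new doc's coordinates and drop values whose spans fall
    # outside the window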
    for k, val in list(doc.user_data.items()):
        new_value = shifter(val)
        if k[0] == "._." and new_value is not EMPTY:
            new_doc.user_data[
                (
                    k[0],
                    k[1],
                    None if k[2] is None else max(0, k[2] - char_beg),
                    None if k[3] is None else min(k[3] - char_beg, char_end - char_beg),
                )
            ] = new_value
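    # Remap the span groups (doc.spans) as well, dropping spans outside the window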
    for name, group in doc.spans.items():
        new_doc.spans[name] = shifter(list(group))

    return new_doc


@edsnlp.registry.factory.register("eds.split", spacy_compatible=False)
class Split:
    def __init__(
        self,
        nlp: Optional[Pipeline] = None,
        name: str = "split",
        *,
        max_length: int = 0,
        regex: Optional[str] = "\n{2,}",
        filter_expr: Optional[str] = None,
        randomize: float = 0.0,
    ):
        """
104
        The `eds.split` component splits a document into multiple documents
105
        based on a regex pattern or a maximum length.
106
107
        !!! warning "Not for pipelines"
108
109
            This component is not meant to be used in a pipeline, but rather
110
            as a preprocessing step when dealing with a stream of documents
111
            as in the example below.
112
113
        Examples
114
        --------
115
116
        ```python
117
        import edsnlp, edsnlp.pipes as eds
118
119
        # Create the stream
120
        stream = edsnlp.data.from_iterable(
121
            ["Sentence 1\\n\\nThis is another longer sentence more than 5 words"]
122
        )
123
124
        # Convert texts into docs
125
        stream = stream.map_pipeline(edsnlp.blank("eds"))
126
127
        # Apply the split component
128
        stream = stream.map(eds.split(max_length=5, regex="\\n{2,}"))
129
130
        print(" | ".join(doc.text.strip() for doc in stream))
131
        # Out: Sentence 1 | This is another longer sentence | more than 5 words
132
        ```
133
134
        Parameters
135
        ----------
136
        nlp: Optional[Pipeline]
137
            The pipeline
138
        name: str
139
            The component name
140
        max_length: int
141
            The maximum length of the produced documents.
142
            If 0, the document will not be split based on length.
143
        regex: Optional[str]
144
            The regex pattern to split the document on
145
        filter_expr: Optional[str]
146
            An optional filter expression to filter the produced documents
147
        randomize: float
148
            The randomization factor to split the documents, to avoid
149
            producing documents that are all `max_length` tokens long
150
            (0 means all documents will have the maximum possible length
151
            while 1 will produce documents with a length varying between
152
            0 and `max_length` uniformly)
153
        """
        self.max_length = max_length
        self.regex = re.compile(regex) if regex else None
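        # filter_expr is evaluated as the body of a `lambda doc: ...`;
        # e.g. filter_expr="len(doc) > 5" (an illustrative expression, not a
        # default) keeps only sub-docs with more than 5 tokens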
        self.filter_fn = eval(f"lambda doc:{filter_expr}") if filter_expr else None
        self.randomize = randomize

    def __call__(self, doc: Doc) -> Iterable[Doc]:
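        """
        Apply the split to `doc` and yield the resulting sub-docs, skipping
        empty ones and those rejected by the `filter_expr` filter.
        """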
        for sub_doc in self.split_doc(doc):
            if sub_doc.text.strip():
                if not self.filter_fn or self.filter_fn(sub_doc):
                    yield sub_doc

    def split_doc(
        self,
        doc: Doc,
    ) -> Iterable[Doc]:
        """
        Split a doc, first on the regex matches, then into chunks of at most
        `max_length` tokens.

        Parameters
        ----------
        doc: Doc
            The doc to split

        Returns
        -------
        Iterable[Doc]
        """
        max_length = self.max_length

        if max_length <= 0 and self.regex is None:
            yield doc
            return

        start = 0
        end = 0
        # Split doc into segments between the regex matches
        matches = (
            [
                next(
                    m.end(g)
                    for g in range(self.regex.groups + 1)
                    if m.end(g) is not None
                )
                for m in self.regex.finditer(doc.text)
            ]
            if self.regex
            else []
        ) + [len(doc.text)]
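        # End character offset of every token; searchsorted converts the
        # character offsets of the regex boundaries into token indices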
        word_ends = doc.to_array("IDX") + doc.to_array("LENGTH")
        segments_end = word_ends.searchsorted(matches, side="right")

        for end in segments_end:
            # If the segment is longer than max_length, cut it into chunks
            if end - start > max_length > 0:
                # Emit chunks of at most max_length tokens (shorter when
                # randomize > 0) until the remainder fits
                while end - start > max_length:
                    subset_end = (
                        start + int(max_length * (random.random() ** self.randomize))
                        if self.randomize
                        else start + max_length
                    )
                    yield subset_doc(doc, start, subset_end)
                    start = subset_end
                yield subset_doc(doc, start, end)
                start = end

            if end > start:
                yield subset_doc(doc, start, end)
                start = end
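
        # Yield whatever remains after the last boundary (an empty trailing
        # doc is filtered out in __call__)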
        yield subset_doc(doc, start, end)