b/edsnlp/pipes/misc/split/split.py
|
import random
import re
from typing import Iterable, Optional

from spacy.tokens import Doc, Span

import edsnlp
from edsnlp import Pipeline

EMPTY = object()
|
|
def make_shifter(start, end, new_doc):
    # Build a function that maps spans (and containers of spans) defined on the
    # original doc onto `new_doc`, which holds its tokens [start:end).
    cache = {}

    def rec(obj):
        if isinstance(obj, Span):
            if obj in cache:
                return cache[obj]
            if obj.end > start and obj.start < end:
                # The span overlaps the window: shift and clip it to the new doc
                res = Span(
                    new_doc,
                    max(0, obj.start - start),
                    min(obj.end - start, end - start),
                    obj.label,
                )
            else:
                # The span falls entirely outside the window: drop it
                res = EMPTY
            cache[obj] = res
        elif isinstance(obj, (list, tuple, set)):
            res = type(obj)(
                filter(
                    lambda x: x is not EMPTY,
                    (rec(span) for span in obj),
                )
            )
        elif isinstance(obj, dict):
            res = {}
            for k, v in obj.items():
                new_v = rec(v)
                if new_v is not EMPTY:
                    res[k] = new_v
        else:
            res = obj
        return res

    return rec
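
# A minimal sketch of what the returned shifter does, assuming a 10-token `doc`
# and `new_doc = doc[2:6].as_doc()` (names and labels below are illustrative):
#
#     shift = make_shifter(2, 6, new_doc)
#     shift(Span(doc, 3, 5, "ENT"))    # -> Span(new_doc, 1, 3, "ENT")
#     shift([Span(doc, 8, 9, "ENT")])  # -> [] (outside the window, dropped)
#     shift({"kept": Span(doc, 2, 4, "ENT"), "n": 1})  # spans shifted, others kept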
|
|
def subset_doc(doc: Doc, start: int, end: int) -> Doc:
    """
    Extract the tokens between the `start` and `end` token indices as a
    standalone doc, carrying over the user data and span groups that fall
    inside the new boundaries.

    Parameters
    ----------
    doc: Doc
        The doc to subset
    start: int
        The start token index
    end: int
        The end token index

    Returns
    -------
    Doc
    """
    new_doc = doc[start:end].as_doc()

    shifter = make_shifter(start, end, new_doc)

    char_beg = doc[start].idx if start < len(doc) else 0
    char_end = doc[end - 1].idx + len(doc[end - 1].text)
    # Shift the character-indexed extension attributes stored in user_data
    for k, val in list(doc.user_data.items()):
        new_value = shifter(val)
        if k[0] == "._." and new_value is not EMPTY:
            new_doc.user_data[
                (
                    k[0],
                    k[1],
                    None if k[2] is None else max(0, k[2] - char_beg),
                    None if k[3] is None else min(k[3] - char_beg, char_end - char_beg),
                )
            ] = new_value

    # Rebuild the span groups on the new doc
    for name, group in doc.spans.items():
        new_doc.spans[name] = shifter(list(group))

    return new_doc
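
# A minimal sketch of how `subset_doc` behaves, assuming a blank "eds" pipeline
# and a hypothetical "ents" span group (indices refer to tokens):
#
#     nlp = edsnlp.blank("eds")
#     doc = nlp("Lorem ipsum dolor")
#     doc.spans["ents"] = [Span(doc, 0, 1, "WORD"), Span(doc, 2, 3, "WORD")]
#     sub = subset_doc(doc, 0, 2)  # tokens "Lorem ipsum"
#     assert sub.text == "Lorem ipsum"
#     assert [s.text for s in sub.spans["ents"]] == ["Lorem"]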
|
|
@edsnlp.registry.factory.register("eds.split", spacy_compatible=False)
class Split:
    def __init__(
        self,
        nlp: Optional[Pipeline] = None,
        name: str = "split",
        *,
        max_length: int = 0,
        regex: Optional[str] = "\n{2,}",
        filter_expr: Optional[str] = None,
        randomize: float = 0.0,
    ):
        """
        The `eds.split` component splits a document into multiple documents
        based on a regex pattern or a maximum length.

        !!! warning "Not for pipelines"

            This component is not meant to be used in a pipeline, but rather
            as a preprocessing step when dealing with a stream of documents,
            as in the example below.

        Examples
        --------

        ```python
        import edsnlp, edsnlp.pipes as eds

        # Create the stream
        stream = edsnlp.data.from_iterable(
            ["Sentence 1\\n\\nThis is another longer sentence more than 5 words"]
        )

        # Convert texts into docs
        stream = stream.map_pipeline(edsnlp.blank("eds"))

        # Apply the split component
        stream = stream.map(eds.split(max_length=5, regex="\\n{2,}"))

        print(" | ".join(doc.text.strip() for doc in stream))
        # Out: Sentence 1 | This is another longer sentence | more than 5 words
        ```

        Parameters
        ----------
        nlp: Optional[Pipeline]
            The pipeline
        name: str
            The component name
        max_length: int
            The maximum length (in tokens) of the produced documents.
            If 0, the document will not be split based on length.
        regex: Optional[str]
            The regex pattern to split the document on
        filter_expr: Optional[str]
            An optional Python expression, evaluated against each produced
            `doc`, used to keep only the documents for which it is truthy
        randomize: float
            The randomization factor applied when splitting on length, to
            avoid producing documents that are all exactly `max_length`
            tokens long (0 means every produced document has the maximum
            possible length, while 1 produces documents whose lengths vary
            uniformly between 0 and `max_length`)
        """
        self.max_length = max_length
        self.regex = re.compile(regex) if regex else None
        self.filter_fn = eval(f"lambda doc:{filter_expr}") if filter_expr else None
        self.randomize = randomize
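
    # A minimal sketch of how `filter_expr` can be written: any Python
    # expression evaluated against each produced `doc` (values below are
    # illustrative):
    #
    #     eds.split(regex="\n{2,}", filter_expr="len(doc) > 3")
    #     eds.split(regex="\n{2,}", filter_expr="'conclusion' in doc.text.lower()")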
|
|
    def __call__(self, doc: Doc) -> Iterable[Doc]:
        for sub_doc in self.split_doc(doc):
            if sub_doc.text.strip():
                if not self.filter_fn or self.filter_fn(sub_doc):
                    yield sub_doc

    def split_doc(
        self,
        doc: Doc,
    ) -> Iterable[Doc]:
        """
        Split a doc into multiple docs, on the regex matches and/or such that
        no produced doc exceeds `max_length` tokens.

        Parameters
        ----------
        doc: Doc
            The doc to split

        Returns
        -------
        Iterable[Doc]
        """
        max_length = self.max_length

        if max_length <= 0 and self.regex is None:
            yield doc
            return

        start = 0
        end = 0
        # Split doc into segments between the regex matches
        matches = (
            [
                next(
                    m.end(g)
                    for g in range(self.regex.groups + 1)
                    if m.end(g) is not None
                )
                for m in self.regex.finditer(doc.text)
            ]
            if self.regex
            else []
        ) + [len(doc.text)]
        # Convert the match character offsets into token indices
        word_ends = doc.to_array("IDX") + doc.to_array("LENGTH")
        segments_end = word_ends.searchsorted([m for m in matches], side="right")

        for end in segments_end:
            # If the segment is longer than max_length (when splitting on length)
            if end - start > max_length > 0:
                # Emit chunks of at most max_length tokens until the rest fits
                while end - start > max_length:
                    subset_end = (
                        start + int(max_length * (random.random() ** self.randomize))
                        if self.randomize
                        else start + max_length
                    )
                    yield subset_doc(doc, start, subset_end)
                    start = subset_end
                yield subset_doc(doc, start, end)
                start = end

            # Emit the segment ending on the current regex match
            if end > start:
                yield subset_doc(doc, start, end)
                start = end

        yield subset_doc(doc, start, end)
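

# A minimal usage sketch, assuming the module is run directly: with the default
# max_length=0, each regex match simply starts a new document.
if __name__ == "__main__":
    import edsnlp.pipes as eds

    nlp = edsnlp.blank("eds")
    doc = nlp("Section one\n\nSection two\n\nSection three")
    splitter = eds.split(regex="\n{2,}")
    print([d.text.strip() for d in splitter(doc)])
    # Expected: ['Section one', 'Section two', 'Section three']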