edsnlp/pipes/base.py

import abc
import inspect
import warnings
from operator import attrgetter
from typing import (
    Iterable,
    List,
    Optional,
    Tuple,
)

from spacy.tokens import Doc, Span

from edsnlp.core import PipelineProtocol
from edsnlp.core.registries import CurriedFactory
from edsnlp.utils.span_getters import (
    SpanGetter,  # noqa: F401
    SpanGetterArg,  # noqa: F401
    SpanSetter,
    SpanSetterArg,
    get_spans,  # noqa: F401
    set_spans,
    validate_span_getter,  # noqa: F401
    validate_span_setter,
)


def value_getter(span: Span):
    # Return the value explicitly stored on the span if there is one,
    # otherwise fall back to the extension named after the span label.
    key = span._._get_key("value")
    if key in span.doc.user_data:
        return span.doc.user_data[key]
    return span._.get(span.label_) if span._.has(span.label_) else None
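
# Illustrative sketch (editor addition, not part of the original module): how
# the ``value`` getter above resolves for a span with, say, the hypothetical
# label "date":
#
#     span._.value   # the explicitly stored value if one was set, otherwise
#                    # span._.date (the extension named after the label),
#                    # or None if no such extension exists.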


class BaseComponentMeta(abc.ABCMeta):
    def __init__(cls, name, bases, dct):
        super().__init__(name, bases, dct)

        # Expose the class signature without the `self` parameter
        sig = inspect.signature(cls.__init__)
        sig = sig.replace(parameters=tuple(sig.parameters.values())[1:])
        cls.__signature__ = sig

    def __call__(cls, nlp=inspect.Signature.empty, *args, **kwargs):
        # If this component is missing the nlp argument, we curry it with the
        # provided arguments and return a CurriedFactory object.
        sig = inspect.signature(cls.__init__)
        try:
            bound = sig.bind_partial(None, nlp, *args, **kwargs)
            bound.arguments.pop("self", None)
            if (
                "nlp" in sig.parameters
                and sig.parameters["nlp"].default is sig.empty
                and bound.arguments.get("nlp", sig.empty) is sig.empty
            ):
                return CurriedFactory(cls, bound.arguments)
            if nlp is inspect.Signature.empty:
                bound.arguments.pop("nlp", None)
        except TypeError:  # pragma: no cover
            # Binding failed: call __init__ anyway so that the caller gets the
            # original, informative TypeError.
            if nlp is inspect.Signature.empty:
                super().__call__(*args, **kwargs)
            else:
                super().__call__(nlp, *args, **kwargs)
        return super().__call__(**bound.arguments)
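
# Illustrative sketch (editor addition): the effect of the metaclass. For a
# hypothetical component whose __init__ requires ``nlp``, calling the class
# without it does not fail but returns a CurriedFactory, which the pipeline
# can resolve later once ``nlp`` is known:
#
#     class MyComponent(BaseComponent):
#         def __init__(self, nlp, name: str = "my_component"):
#             super().__init__(nlp, name)
#
#     MyComponent(name="custom")   # -> CurriedFactory(MyComponent, ...)
#     MyComponent(nlp)             # -> a regular MyComponent instance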


class BaseComponent(abc.ABC, metaclass=BaseComponentMeta):
    """
    The `BaseComponent` adds a `set_extensions` method,
    called at the creation of the object.

    It helps decouple the initialisation of the pipeline from
    the creation of extensions, and is particularly useful when
    distributing EDSNLP on a cluster, since the serialisation mechanism
    requires that the extensions be reset.
    """

    def __init__(
        self,
        nlp: Optional[PipelineProtocol] = None,
        name: Optional[str] = None,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.name = name
        self.set_extensions()

    def set_extensions(self):
        """
        Set `Doc`, `Span` and `Token` extensions.
        """
        if Span.has_extension("value"):
            if Span.get_extension("value")[2] is not value_getter:
                warnings.warn(
                    "A Span extension 'value' already exists with a different getter. "
                    "Keeping the existing extension, but some components of edsnlp may "
                    "not work as expected."
                )
            return
        Span.set_extension(
            "value",
            getter=value_getter,
        )
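
    # Illustrative sketch (editor addition): subclasses typically extend
    # set_extensions to declare their own extensions, calling
    # super().set_extensions() so that Span._.value stays registered:
    #
    #     def set_extensions(self):
    #         super().set_extensions()
    #         if not Span.has_extension("my_attr"):  # hypothetical extension
    #             Span.set_extension("my_attr", default=None)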

    def get_spans(self, doc: Doc):  # noqa: F811
        """
        Returns the spans of interest, sorted by position, according to the
        possible value of `on_ents_only`.
        Includes `doc.ents` by default, and adds the spans of any requested
        span groups.
        """
        ents = list(doc.ents) + list(doc.spans.get("discarded", []))

        on_ents_only = getattr(self, "on_ents_only", None)

        if isinstance(on_ents_only, str):
            on_ents_only = [on_ents_only]
        if isinstance(on_ents_only, (set, list)):
            for spankey in set(on_ents_only) & set(doc.spans.keys()):
                ents.extend(doc.spans.get(spankey, []))

        return sorted(list(set(ents)), key=(attrgetter("start", "end")))
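
    # Illustrative sketch (editor addition): ``get_spans`` is driven by an
    # optional ``on_ents_only`` attribute that subclasses may set. For example,
    # with a hypothetical ``self.on_ents_only = "dates"``, the method returns
    # doc.ents plus the spans of doc.spans["dates"], sorted by (start, end).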

    def _boundaries(
        self, doc: Doc, terminations: Optional[List[Span]] = None
    ) -> List[Tuple[int, int]]:
        """
        Create sub-sentence boundaries based on the sentences and the
        termination spans found in the text.

        Parameters
        ----------
        doc:
            spaCy Doc object
        terminations:
            List of termination spans

        Returns
        -------
        boundaries:
            List of tuples with the (start, end) token indices of each span
        """

        if terminations is None:
            terminations = []

        sent_starts = [sent.start for sent in doc.sents]
        termination_starts = [t.start for t in terminations]

        starts = sent_starts + termination_starts + [len(doc)]

        # Remove duplicates
        starts = list(set(starts))

        # Sort starts
        starts.sort()

        boundaries = [(start, end) for start, end in zip(starts[:-1], starts[1:])]

        return boundaries
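
    # Illustrative sketch (editor addition): a worked example of _boundaries.
    # For a 10-token doc with sentences starting at tokens 0 and 6 and a
    # termination span starting at token 3, the collected starts are
    # {0, 3, 6, 10} and the returned boundaries are [(0, 3), (3, 6), (6, 10)].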

    def __setstate__(self, state):
        self.__dict__.update(state)
        # Re-register the extensions after unpickling, e.g. on a remote worker
        self.set_extensions()


class BaseNERComponent(BaseComponent, abc.ABC):
    span_setter: SpanSetter

    def __init__(
        self,
        nlp: PipelineProtocol = None,
        name: str = None,
        *args,
        span_setter: SpanSetterArg,
        **kwargs,
    ):
        super().__init__(nlp, name, *args, **kwargs)
        self.span_setter: SpanSetter = validate_span_setter(span_setter)  # type: ignore

    def set_spans(self, doc, matches):
        return set_spans(doc, matches, self.span_setter)
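
# Illustrative sketch (editor addition): a minimal NER component built on
# BaseNERComponent. The matching logic is hypothetical; the point is that
# entities are written back through self.set_spans / the span_setter:
#
#     class QuickMatcher(BaseNERComponent):
#         def __init__(self, nlp=None, name="quick_matcher", *,
#                      span_setter={"ents": True}):
#             super().__init__(nlp, name, span_setter=span_setter)
#
#         def __call__(self, doc):
#             matches = []  # hypothetical: collect Span objects here
#             self.set_spans(doc, matches)
#             return doc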


class BaseSpanAttributeClassifierComponent(BaseComponent, abc.ABC):
    span_getter: SpanGetter
    attributes: Iterable[str]

    def __init__(
        self,
        nlp: PipelineProtocol = None,
        name: str = None,
        *args,
        span_getter: SpanGetterArg,
        **kwargs,
    ):
        super().__init__(nlp, name, *args, **kwargs)
        self.span_getter: SpanGetter = validate_span_getter(span_getter)  # type: ignore

    # For backwards compatibility
    @property
    def qualifiers(self):  # pragma: no cover
        return self.attributes

    @qualifiers.setter
    def qualifiers(self, value):  # pragma: no cover
        self.attributes = value
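
# Illustrative sketch (editor addition): a minimal attribute classifier built
# on BaseSpanAttributeClassifierComponent. The rule itself is hypothetical;
# the point is that the spans to annotate come from the validated span_getter:
#
#     class NegationFlagger(BaseSpanAttributeClassifierComponent):
#         attributes = ["negation"]
#
#         def __init__(self, nlp=None, name="negation_flagger", *,
#                      span_getter={"ents": True}):
#             super().__init__(nlp, name, span_getter=span_getter)
#
#         def set_extensions(self):
#             super().set_extensions()
#             if not Span.has_extension("negation"):
#                 Span.set_extension("negation", default=None)
#
#         def __call__(self, doc):
#             for span in get_spans(doc, self.span_getter):
#                 span._.negation = False  # hypothetical rule goes here
#             return doc
#
# The ``qualifiers`` property above is kept only as a backward-compatible
# alias for ``attributes``.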