a b/edsnlp/pipes/base.py
1
import abc
2
import inspect
3
import warnings
4
from operator import attrgetter
5
from typing import (
6
    Iterable,
7
    List,
8
    Optional,
9
    Tuple,
10
)
11
12
from spacy.tokens import Doc, Span
13
14
from edsnlp.core import PipelineProtocol
15
from edsnlp.core.registries import CurriedFactory
16
from edsnlp.utils.span_getters import (
17
    SpanGetter,  # noqa: F401
18
    SpanGetterArg,  # noqa: F401
19
    SpanSetter,
20
    SpanSetterArg,
21
    get_spans,  # noqa: F401
22
    set_spans,
23
    validate_span_getter,  # noqa: F401
24
    validate_span_setter,
25
)
26
27
28
def value_getter(span: Span):
    """
    Getter used for the `value` extension of a span.

    Lookup order:

    1. the entry stored in `span.doc.user_data` under the key returned by
       `span._._get_key("value")`, when such an entry exists;
    2. otherwise, the span extension named after the span's label
       (`span.label_`), when that extension exists;
    3. otherwise `None`.

    Parameters
    ----------
    span:
        The span whose value is requested

    Returns
    -------
    The value found by the lookup above, or `None`
    """
    storage_key = span._._get_key("value")
    user_data = span.doc.user_data
    if storage_key in user_data:
        return user_data[storage_key]

    # No explicit entry: fall back on the extension named like the label
    label = span.label_
    if span._.has(label):
        return span._.get(label)
    return None
33
34
35
class BaseComponentMeta(abc.ABCMeta):
    """
    Metaclass of pipeline components.

    It does two things:

    - at class creation, it stores on the class a `__signature__` equal to the
      signature of `__init__` without its first parameter (`self`);
    - at instantiation, it returns a `CurriedFactory` instead of an instance
      when `__init__` requires an `nlp` argument that was not provided.
    """

    def __init__(cls, name, bases, dct):
        super().__init__(name, bases, dct)

        # Drop the first parameter (`self`) so that `inspect.signature(cls)`
        # reports the arguments a caller actually has to provide.
        sig = inspect.signature(cls.__init__)
        sig = sig.replace(parameters=tuple(sig.parameters.values())[1:])
        cls.__signature__ = sig

    def __call__(cls, nlp=inspect.Signature.empty, *args, **kwargs):
        """
        Instantiate the component, or curry it when `nlp` is missing.

        Parameters
        ----------
        nlp:
            First positional argument of the component. The sentinel
            `inspect.Signature.empty` means "not provided".
        *args, **kwargs:
            Remaining arguments, forwarded to `__init__`.

        Returns
        -------
        A `CurriedFactory` wrapping `cls` and the bound arguments if `nlp` is
        a parameter of `__init__` without default and was not provided,
        otherwise a new instance of `cls`.

        Raises
        ------
        TypeError
            If the arguments are rejected by `__init__` itself.
        """
        # If this component is missing the nlp argument, we curry it with the
        # provided arguments and return a CurriedFactory object.
        sig = inspect.signature(cls.__init__)
        try:
            # `None` stands for `self`: it is only there to make the binding
            # line up with the signature and is removed right after.
            bound = sig.bind_partial(None, nlp, *args, **kwargs)
            bound.arguments.pop("self", None)
            if (
                "nlp" in sig.parameters
                and sig.parameters["nlp"].default is sig.empty
                and bound.arguments.get("nlp", sig.empty) is sig.empty
            ):
                return CurriedFactory(cls, bound.arguments)
            if nlp is inspect.Signature.empty:
                # Do not forward the sentinel: let `__init__` use its default
                bound.arguments.pop("nlp", None)
        except TypeError:  # pragma: no cover
            # The arguments could not be bound (e.g. `__init__` has no `nlp`
            # parameter). Delegate to the regular constructor call and return
            # its result: it either builds the instance, or raises the
            # TypeError with the usual message. `bound` is not defined here,
            # so we must not fall through to the final return.
            if nlp is inspect.Signature.empty:
                return super().__call__(*args, **kwargs)
            return super().__call__(nlp, *args, **kwargs)
        return super().__call__(**bound.arguments)
64
65
66
class BaseComponent(abc.ABC, metaclass=BaseComponentMeta):
    """
    Base class of pipeline components.

    It provides a `set_extensions` method, called when the object is created
    and again when its state is restored through `__setstate__`.

    This helps decouple the initialisation of the pipeline from the creation
    of extensions, and is particularly useful when distributing EDSNLP on a
    cluster, since the serialisation mechanism imposes that the extensions be
    reset.
    """

    def __init__(
        self,
        nlp: Optional[PipelineProtocol] = None,
        name: Optional[str] = None,
        *args,
        **kwargs,
    ):
        # `nlp` is not used by this base constructor; only `name` is stored.
        super().__init__(*args, **kwargs)
        self.name = name
        self.set_extensions()

    def set_extensions(self):
        """
        Register the `value` extension on `Span`, backed by `value_getter`.

        If a `value` extension is already registered, it is left untouched;
        a warning is emitted when its getter is not `value_getter`.
        """
        if not Span.has_extension("value"):
            Span.set_extension(
                "value",
                getter=value_getter,
            )
            return

        # Index 2 of the registered extension is compared with our getter
        existing = Span.get_extension("value")
        if existing[2] is not value_getter:
            warnings.warn(
                "A Span extension 'value' already exists with a different getter. "
                "Keeping the existing extension, but some components of edsnlp may "
                "not work as expected."
            )

    def get_spans(self, doc: Doc):  # noqa: F811
        """
        Collect the spans of interest of a document, without duplicates,
        sorted by `(start, end)`.

        The result always contains `doc.ents` and the spans of the
        `"discarded"` group of `doc.spans` (if any). When the component has an
        `on_ents_only` attribute that is a string, a list or a set, the spans
        of the groups it names are added, for the groups present in
        `doc.spans`.
        """
        candidates = [*doc.ents, *doc.spans.get("discarded", [])]

        group_names = getattr(self, "on_ents_only", None)

        # A single group name is handled like a one-element list
        if isinstance(group_names, str):
            group_names = [group_names]
        if isinstance(group_names, (set, list)):
            # Only consider the requested groups that exist in the document
            for group_name in set(group_names) & set(doc.spans.keys()):
                candidates.extend(doc.spans.get(group_name, []))

        unique_spans = set(candidates)
        return sorted(unique_spans, key=attrgetter("start", "end"))

    def _boundaries(
        self, doc: Doc, terminations: Optional[List[Span]] = None
    ) -> List[Tuple[int, int]]:
        """
        Split the document into consecutive chunks delimited by sentence
        starts and termination starts.

        Parameters
        ----------
        doc:
            spaCy Doc object
        terminations:
            Optional list of spans; the `start` of each one opens a new chunk

        Returns
        -------
        boundaries:
            List of tuples with (start, end) of the chunks, sorted by start
        """
        extra = () if terminations is None else terminations

        # A set removes the cut points that appear more than once
        cut_points = {sent.start for sent in doc.sents}
        cut_points.update(termination.start for termination in extra)
        # The end of the document closes the last chunk
        cut_points.add(len(doc))

        ordered = sorted(cut_points)

        # Pair each cut point with the next one
        return list(zip(ordered, ordered[1:]))

    def __setstate__(self, state):
        """
        Restore the instance attributes from `state`, then register the
        extensions again through `set_extensions`.
        """
        self.__dict__.update(state)
        self.set_extensions()
163
164
165
class BaseNERComponent(BaseComponent, abc.ABC):
    """
    Base class of components that produce spans and store them in the
    document according to a span setter.

    Parameters
    ----------
    nlp:
        The pipeline object, forwarded to `BaseComponent`
    name:
        The name of the component, forwarded to `BaseComponent`
    span_setter:
        Where to store the spans (keyword-only). It is passed through
        `validate_span_setter` and the result is kept in `self.span_setter`.
    """

    # Validated span setter, used by `set_spans`
    span_setter: SpanSetter

    def __init__(
        self,
        nlp: Optional[PipelineProtocol] = None,
        name: Optional[str] = None,
        *args,
        span_setter: SpanSetterArg,
        **kwargs,
    ):
        super().__init__(nlp, name, *args, **kwargs)
        self.span_setter: SpanSetter = validate_span_setter(span_setter)  # type: ignore

    def set_spans(self, doc, matches):
        """
        Store `matches` in `doc` by delegating to the module-level `set_spans`
        helper with this component's `span_setter`, and return its result.
        """
        return set_spans(doc, matches, self.span_setter)
181
182
183
class BaseSpanAttributeClassifierComponent(BaseComponent, abc.ABC):
    """
    Base class of components that work on spans selected by a span getter
    and expose a collection of attribute names.

    Parameters
    ----------
    nlp:
        The pipeline object, forwarded to `BaseComponent`
    name:
        The name of the component, forwarded to `BaseComponent`
    span_getter:
        Which spans to process (keyword-only). It is passed through
        `validate_span_getter` and the result is kept in `self.span_getter`.
    """

    # Validated span getter
    span_getter: SpanGetter
    # Attribute names; not assigned by this constructor
    attributes: Iterable[str]

    def __init__(
        self,
        nlp: Optional[PipelineProtocol] = None,
        name: Optional[str] = None,
        *args,
        span_getter: SpanGetterArg,
        **kwargs,
    ):
        super().__init__(nlp, name, *args, **kwargs)
        self.span_getter: SpanGetter = validate_span_getter(span_getter)  # type: ignore

    # For backwards compatibility: `qualifiers` is an alias of `attributes`
    @property
    def qualifiers(self):  # pragma: no cover
        """Alias of `attributes`."""
        return self.attributes

    @qualifiers.setter
    def qualifiers(self, value):  # pragma: no cover
        self.attributes = value