a b/tests/training/test_train.py
1
# ruff:noqa:E402
2
3
import pytest
4
5
# Optional training dependencies: skip this whole test module instead of
# erroring when they are missing. The checks sit above the imports further
# down, which is why the file suppresses E402 (module-level import not at top).
try:
    import torch.nn
except ImportError:
    pytest.skip("torch not installed", allow_module_level=True)

pytest.importorskip("rich")
13
14
import shutil
15
from typing import (
16
    Optional,
17
    Sequence,
18
    Union,
19
)
20
21
import pytest
22
import spacy.tokenizer
23
import torch.nn
24
from confit import Config
25
from confit.utils.random import set_seed
26
from spacy.tokens import Doc, Span
27
28
from edsnlp.core.registries import registry
29
from edsnlp.data.converters import AttributesMappingArg, get_current_tokenizer
30
from edsnlp.metrics.dep_parsing import DependencyParsingMetric
31
from edsnlp.training.optimizer import LinearSchedule, ScheduledOptimizer
32
from edsnlp.training.trainer import GenericScorer, train
33
from edsnlp.utils.span_getters import SpanSetterArg, set_spans
34
35
36
@registry.factory.register("myproject.custom_dict2doc", spacy_compatible=False)
class CustomSampleGenerator:
    """Dict-to-Doc converter registered as ``myproject.custom_dict2doc``.

    Turns a sample dict (text, note metadata and a list of character-offset
    entities) into a spaCy ``Doc`` whose entity spans carry the sample's
    extra entity keys as ``Span`` extensions.

    Parameters
    ----------
    tokenizer : Optional[spacy.tokenizer.Tokenizer]
        Tokenizer used to build the doc. When ``None``, the tokenizer returned
        by ``get_current_tokenizer()`` is looked up on every call.
    name : str
        Name of the converter.
    span_setter : SpanSetterArg
        Forwarded to ``set_spans`` to store the created spans on the doc.
    bool_attributes : Union[str, Sequence[str]]
        Span extensions whose value is replaced by ``False`` when it is still
        ``None`` after conversion. A single string is treated as one name.
    span_attributes : Optional[AttributesMappingArg]
        Mapping from entity keys to span extension names. Entity keys absent
        from the mapping (or mapped to an empty name) are ignored. When
        ``None``, every extra entity key is stored under its own name.
    """

    def __init__(
        self,
        *,
        tokenizer: Optional[spacy.tokenizer.Tokenizer] = None,
        name: str = "myproject.custom_dict2doc",
        span_setter: SpanSetterArg = {"ents": True, "*": True},
        bool_attributes: Union[str, Sequence[str]] = [],
        span_attributes: Optional[AttributesMappingArg] = None,
    ):
        self.tokenizer = tokenizer
        self.name = name
        self.span_setter = span_setter
        # A bare string would otherwise be iterated character by character in
        # __call__; copying a sequence also avoids aliasing the shared default.
        self.bool_attributes = (
            [bool_attributes]
            if isinstance(bool_attributes, str)
            else list(bool_attributes)
        )
        self.span_attributes = span_attributes

    def __call__(self, obj):
        """Convert one sample dict into a ``Doc``.

        ``obj`` must contain ``note_text`` (a falsy value yields an empty
        doc) and may contain ``note_id`` (falling back to ``__FILENAME__``),
        ``note_datetime`` and ``entities``. Each entity is a mapping with
        ``start``/``end`` character offsets, a ``label``, and any number of
        extra keys that are stored as span extensions.
        """
        tok = get_current_tokenizer() if self.tokenizer is None else self.tokenizer
        doc = tok(obj["note_text"] or "")
        doc._.note_id = obj.get("note_id", obj.get("__FILENAME__"))
        doc._.note_datetime = obj.get("note_datetime")

        # Register target extensions up front so the span._.set / span._.get
        # calls below do not raise on unregistered attribute names.
        if self.span_attributes is not None:
            for dst in self.span_attributes.values():
                if not Span.has_extension(dst):
                    Span.set_extension(dst, default=None)
        for attr in self.bool_attributes:
            if not Span.has_extension(attr):
                Span.set_extension(attr, default=None)

        spans = []
        for ent in obj.get("entities") or ():
            # Copy so that popping offsets/label does not mutate the input.
            ent = dict(ent)
            span = doc.char_span(
                ent.pop("start"),
                ent.pop("end"),
                label=ent.pop("label"),
                alignment_mode="expand",
            )
            # Every remaining key is an attribute value for this span.
            for key, value in ent.items():
                if self.span_attributes is None:
                    new_name = key
                    if not Span.has_extension(new_name):
                        Span.set_extension(new_name, default=None)
                else:
                    # Keys missing from the mapping are dropped.
                    new_name = self.span_attributes.get(key, None)

                if new_name:
                    span._.set(new_name, value)
            spans.append(span)

        set_spans(doc, spans, span_setter=self.span_setter)
        # Boolean attributes left unset (None) default to False.
        for attr in self.bool_attributes:
            for span in spans:
                if span._.get(attr) is None:
                    span._.set(attr, False)
        return doc
93
94
95
def test_ner_qualif_train_diff_bert(run_in_test_dir, tmp_path):
    """Train from ``ner_qlf_diff_bert_config.yml`` and check the scores.

    The trained pipeline must handle an empty text, and both the ``ner`` and
    ``qual`` micro F1 scores on the validation data must exceed 0.4.
    The config path is relative; presumably the ``run_in_test_dir`` fixture
    makes it resolvable (TODO confirm).
    """
    set_seed(42)
    config = Config.from_disk("ner_qlf_diff_bert_config.yml")
    shutil.rmtree(tmp_path, ignore_errors=True)
    resolved = Config.resolve(config["train"], registry=registry, root=config)
    nlp = train(**resolved, output_dir=tmp_path, cpu=True)
    scores = GenericScorer(**resolved["scorer"])(nlp, resolved["val_data"])

    # The trained pipeline must not crash on an empty document.
    nlp("")

    for metric in ("ner", "qual"):
        assert scores[metric]["micro"]["f"] > 0.4
110
111
112
def test_ner_qualif_train_same_bert(run_in_test_dir, tmp_path):
    """Train from ``ner_qlf_same_bert_config.yml`` and check the scores.

    The trained pipeline must handle an empty text, and both the ``ner`` and
    ``qual`` micro F1 scores on the validation data must exceed 0.4.
    The config path is relative; presumably the ``run_in_test_dir`` fixture
    makes it resolvable (TODO confirm).
    """
    set_seed(42)
    config = Config.from_disk("ner_qlf_same_bert_config.yml")
    shutil.rmtree(tmp_path, ignore_errors=True)
    resolved = Config.resolve(config["train"], registry=registry, root=config)
    nlp = train(**resolved, output_dir=tmp_path, cpu=True)
    scores = GenericScorer(**resolved["scorer"])(nlp, resolved["val_data"])

    # The trained pipeline must not crash on an empty document.
    nlp("")

    for metric in ("ner", "qual"):
        assert scores[metric]["micro"]["f"] > 0.4
127
128
129
def test_qualif_train(run_in_test_dir, tmp_path):
    """Train from ``qlf_config.yml`` and check the ``qual`` micro F1.

    The trained pipeline must handle an empty text, and the ``qual`` micro
    F1 on the validation data must be at least 0.4.
    """
    set_seed(42)
    config = Config.from_disk("qlf_config.yml")
    shutil.rmtree(tmp_path, ignore_errors=True)
    resolved = Config.resolve(config["train"], registry=registry, root=config)
    nlp = train(**resolved, output_dir=tmp_path, cpu=True)
    scores = GenericScorer(**resolved["scorer"])(nlp, resolved["val_data"])

    # The trained pipeline must not crash on an empty document.
    nlp("")

    qual_f1 = scores["qual"]["micro"]["f"]
    assert qual_f1 >= 0.4
143
144
145
def test_dep_parser_train(run_in_test_dir, tmp_path):
    """Train the dependency parser from ``dep_parser_config.yml``.

    Besides the final LAS threshold, this scores a doc built from two copies
    of the first validation doc with ``"mst"`` decoding. The original intent
    is to exercise the case where the scores indicate two roots. That check
    expects a UAS of 0.0 from a metric built with ``filter_expr="False"``.
    """
    set_seed(42)
    config = Config.from_disk("dep_parser_config.yml")
    shutil.rmtree(tmp_path, ignore_errors=True)
    resolved = Config.resolve(config["train"], registry=registry, root=config)
    nlp = train(**resolved, output_dir=tmp_path, cpu=True)

    scorer = GenericScorer(**resolved["scorer"])
    # Materialized because the first doc is reused below.
    val_docs = list(resolved["val_data"])
    scores = scorer(nlp, val_docs)

    two_root_scorer = GenericScorer(
        parser=DependencyParsingMetric(filter_expr="False")
    )
    two_root_docs = [Doc.from_docs([val_docs[0]] * 2)]
    nlp.pipes.parser.decoding_mode = "mst"
    assert two_root_scorer(nlp, two_root_docs)["parser"]["uas"] == 0.0

    # Empty-doc check, run with the "mst" decoding mode set above.
    nlp("")

    assert scores["dep"]["las"] >= 0.4
166
167
168
def test_optimizer():
    """Check ScheduledOptimizer's eager state creation and its LR schedule.

    AdamW state (``exp_avg``) must be absent before ``initialize()`` and
    present afterwards. The learning rate, sampled before the first step
    and after each of the 10 steps, must follow the expected sequence:
    0, then 9e-4, then linearly down by 1e-4 per step to 0.
    """
    net = torch.nn.Linear(10, 10)
    optim = ScheduledOptimizer(
        torch.optim.AdamW,
        module=net,
        total_steps=10,
        groups={
            ".*": {
                "lr": 9e-4,
                "schedules": LinearSchedule(warmup_rate=0.1, start_value=0),
            }
        },
    )

    params = list(net.parameters())
    assert all("exp_avg" not in optim.optim.state[p] for p in params)
    optim.initialize()
    assert all("exp_avg" in optim.optim.state[p] for p in params)

    def current_lr():
        return optim.optim.param_groups[0]["lr"]

    lr_values = [current_lr()]
    for _ in range(10):
        optim.step()
        lr_values.append(current_lr())

    expected = [0.0, 9e-4, 8e-4, 7e-4, 6e-4, 5e-4, 4e-4, 3e-4, 2e-4, 1e-4, 0.0]
    # Float arithmetic in the schedule: compare approximately.
    assert lr_values == pytest.approx(expected)