
--- /dev/null
+++ b/tests/training/test_train.py
@@ -0,0 +1,209 @@
+# ruff: noqa: E402
+
+import pytest
+
+try:
+    import torch.nn
+except ImportError:
+    torch = None
+
+if torch is None:
+    pytest.skip("torch not installed", allow_module_level=True)
+pytest.importorskip("rich")
+
+import shutil
+from typing import (
+    Optional,
+    Sequence,
+    Union,
+)
+
+import spacy.tokenizer
+import torch.nn
+from confit import Config
+from confit.utils.random import set_seed
+from spacy.tokens import Doc, Span
+
+from edsnlp.core.registries import registry
+from edsnlp.data.converters import AttributesMappingArg, get_current_tokenizer
+from edsnlp.metrics.dep_parsing import DependencyParsingMetric
+from edsnlp.training.optimizer import LinearSchedule, ScheduledOptimizer
+from edsnlp.training.trainer import GenericScorer, train
+from edsnlp.utils.span_getters import SpanSetterArg, set_spans
+
+
+@registry.factory.register("myproject.custom_dict2doc", spacy_compatible=False)
+class CustomSampleGenerator:
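+    """Convert a raw sample dict (note text plus entity dicts) into a spaCy
+    Doc with entity spans and span-level attributes set."""
+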
+    def __init__(
+        self,
+        *,
+        tokenizer: Optional[spacy.tokenizer.Tokenizer] = None,
+        name: str = "myproject.custom_dict2doc",
+        span_setter: SpanSetterArg = {"ents": True, "*": True},
+        bool_attributes: Union[str, Sequence[str]] = [],
+        span_attributes: Optional[AttributesMappingArg] = None,
+    ):
+        self.tokenizer = tokenizer
+        self.name = name
+        self.span_setter = span_setter
+        self.bool_attributes = bool_attributes
+        self.span_attributes = span_attributes
+
+    def __call__(self, obj):
+        tok = get_current_tokenizer() if self.tokenizer is None else self.tokenizer
+        doc = tok(obj["note_text"] or "")
+        doc._.note_id = obj.get("note_id", obj.get("__FILENAME__"))
+        doc._.note_datetime = obj.get("note_datetime")
+
+        spans = []
+
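+        # Pre-declare the destination extensions for remapped span attributes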
+        if self.span_attributes is not None:
+            for dst in self.span_attributes.values():
+                if not Span.has_extension(dst):
+                    Span.set_extension(dst, default=None)
+
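+        # Build one Span per entity; leftover fields become span extensions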
+        for ent in obj.get("entities") or ():
+            ent = dict(ent)
+            span = doc.char_span(
+                ent.pop("start"),
+                ent.pop("end"),
+                label=ent.pop("label"),
+                alignment_mode="expand",
+            )
+            for label, value in ent.items():
+                new_name = (
+                    self.span_attributes.get(label, None)
+                    if self.span_attributes is not None
+                    else label
+                )
+                if self.span_attributes is None and not Span.has_extension(new_name):
+                    Span.set_extension(new_name, default=None)
+
+                if new_name:
+                    span._.set(new_name, value)
+            spans.append(span)
+
+        set_spans(doc, spans, span_setter=self.span_setter)
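+        # Default any unset boolean qualifier attributes to False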
+        for attr in self.bool_attributes:
+            for span in spans:
+                if span._.get(attr) is None:
+                    span._.set(attr, False)
+        return doc
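+
+# A minimal usage sketch (hypothetical sample; the tests below instead
+# resolve this converter from their YAML configs through the registry):
+#
+#     gen = CustomSampleGenerator(tokenizer=spacy.blank("eds").tokenizer)
+#     doc = gen({
+#         "note_text": "Douleur thoracique.",
+#         "entities": [{"start": 0, "end": 7, "label": "symptom"}],
+#     })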
+
+
+def test_ner_qualif_train_diff_bert(run_in_test_dir, tmp_path):
+    set_seed(42)
+    config = Config.from_disk("ner_qlf_diff_bert_config.yml")
+    shutil.rmtree(tmp_path, ignore_errors=True)
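+    # Resolve the "train" section of the config into train(...) kwargs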
+    kwargs = Config.resolve(config["train"], registry=registry, root=config)
+    nlp = train(**kwargs, output_dir=tmp_path, cpu=True)
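+    # Score the freshly trained pipeline on the validation set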
+    scorer = GenericScorer(**kwargs["scorer"])
+    val_data = kwargs["val_data"]
+    last_scores = scorer(nlp, val_data)
+
+    # Check empty doc
+    nlp("")
+
+    assert last_scores["ner"]["micro"]["f"] > 0.4
+    assert last_scores["qual"]["micro"]["f"] > 0.4
+
+
+def test_ner_qualif_train_same_bert(run_in_test_dir, tmp_path):
+    set_seed(42)
+    config = Config.from_disk("ner_qlf_same_bert_config.yml")
+    shutil.rmtree(tmp_path, ignore_errors=True)
+    kwargs = Config.resolve(config["train"], registry=registry, root=config)
+    nlp = train(**kwargs, output_dir=tmp_path, cpu=True)
+    scorer = GenericScorer(**kwargs["scorer"])
+    val_data = kwargs["val_data"]
+    last_scores = scorer(nlp, val_data)
+
+    # Check empty doc
+    nlp("")
+
+    assert last_scores["ner"]["micro"]["f"] > 0.4
+    assert last_scores["qual"]["micro"]["f"] > 0.4
+
+
+def test_qualif_train(run_in_test_dir, tmp_path):
+    set_seed(42)
+    config = Config.from_disk("qlf_config.yml")
+    shutil.rmtree(tmp_path, ignore_errors=True)
+    kwargs = Config.resolve(config["train"], registry=registry, root=config)
+    nlp = train(**kwargs, output_dir=tmp_path, cpu=True)
+    scorer = GenericScorer(**kwargs["scorer"])
+    val_data = kwargs["val_data"]
+    last_scores = scorer(nlp, val_data)
+
+    # Check empty doc
+    nlp("")
+
+    assert last_scores["qual"]["micro"]["f"] >= 0.4
+
+
+def test_dep_parser_train(run_in_test_dir, tmp_path):
+    set_seed(42)
+    config = Config.from_disk("dep_parser_config.yml")
+    shutil.rmtree(tmp_path, ignore_errors=True)
+    kwargs = Config.resolve(config["train"], registry=registry, root=config)
+    nlp = train(**kwargs, output_dir=tmp_path, cpu=True)
+    scorer = GenericScorer(**kwargs["scorer"])
+    val_data = list(kwargs["val_data"])
+    last_scores = scorer(nlp, val_data)
+
+    scorer_bis = GenericScorer(parser=DependencyParsingMetric(filter_expr="False"))
+    # Duplicating a doc yields a document with two sentence roots; check that
+    # MST decoding still runs, and that filter_expr="False" (which filters
+    # out every doc) brings the UAS down to 0
+    val_data_bis = [Doc.from_docs([val_data[0], val_data[0]])]
+    nlp.pipes.parser.decoding_mode = "mst"
+    last_scores_bis = scorer_bis(nlp, val_data_bis)
+    assert last_scores_bis["parser"]["uas"] == 0.0
+
+    # Check empty doc
+    nlp("")
+
+    assert last_scores["dep"]["las"] >= 0.4
+
+
+def test_optimizer():
+    net = torch.nn.Linear(10, 10)
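+    # A single parameter group (".*" matches every parameter) with a linear
+    # warmup-then-decay schedule over 10 total steps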
+    optim = ScheduledOptimizer(
+        torch.optim.AdamW,
+        module=net,
+        total_steps=10,
+        groups={
+            ".*": {
+                "lr": 9e-4,
+                "schedules": LinearSchedule(
+                    warmup_rate=0.1,
+                    start_value=0,
+                ),
+            }
+        },
+    )
+    for param in net.parameters():
+        assert "exp_avg" not in optim.optim.state[param]
+    optim.initialize()
+    for param in net.parameters():
+        assert "exp_avg" in optim.optim.state[param]
+    lr_values = [optim.optim.param_groups[0]["lr"]]
+    for i in range(10):
+        optim.step()
+        lr_values.append(optim.optim.param_groups[0]["lr"])
+
+    # Expected: warmup_rate=0.1 of total_steps=10 gives one linear warmup
+    # step from 0 to the peak lr of 9e-4, then a linear decay back to 0
+    assert lr_values == pytest.approx(
+        [
+            0.0,
+            0.0009,
+            0.0008,
+            0.0007,
+            0.0006,
+            0.0005,
+            0.0004,
+            0.0003,
+            0.0002,
+            0.0001,
+            0.0,
+        ]
+    )