# ruff:noqa:E402
import pytest
try:
import torch.nn
except ImportError:
torch = None
if torch is None:
pytest.skip("torch not installed", allow_module_level=True)
pytest.importorskip("rich")
import shutil
from typing import (
Optional,
Sequence,
Union,
)
import pytest
import spacy.tokenizer
import torch.nn
from confit import Config
from confit.utils.random import set_seed
from spacy.tokens import Doc, Span
from edsnlp.core.registries import registry
from edsnlp.data.converters import AttributesMappingArg, get_current_tokenizer
from edsnlp.metrics.dep_parsing import DependencyParsingMetric
from edsnlp.training.optimizer import LinearSchedule, ScheduledOptimizer
from edsnlp.training.trainer import GenericScorer, train
from edsnlp.utils.span_getters import SpanSetterArg, set_spans
@registry.factory.register("myproject.custom_dict2doc", spacy_compatible=False)
class CustomSampleGenerator:
    """Convert a raw dict sample into a spaCy ``Doc`` with entity spans.

    The input dict must contain ``note_text``. It may also contain:

    - ``note_id``, or ``__FILENAME__`` as a fallback;
    - ``note_datetime``;
    - ``entities``: a list of mappings holding ``start``/``end`` character
      offsets, a ``label``, and any number of extra keys. The extra keys are
      copied onto the created span as ``span._`` extension attributes.

    Parameters
    ----------
    tokenizer : Optional[spacy.tokenizer.Tokenizer]
        Tokenizer used to build the ``Doc``. When ``None``, the tokenizer
        returned by ``get_current_tokenizer()`` is used at call time.
    name : str
        Name of the component.
    span_setter : SpanSetterArg
        Where the created entity spans are stored (forwarded to ``set_spans``).
    bool_attributes : Union[str, Sequence[str]]
        Span extension attributes that are set to ``False`` on every created
        span whose value is still ``None``. A single string is treated as one
        attribute name, not as a sequence of characters.
    span_attributes : Optional[AttributesMappingArg]
        Mapping from entity keys to span extension names. When given, only
        mapped keys are copied and their destination extensions are registered
        up front. When ``None``, every extra key is copied under its own name.
    """

    def __init__(
        self,
        *,
        tokenizer: Optional[spacy.tokenizer.Tokenizer] = None,
        name: str = "myproject.custom_dict2doc",
        span_setter: SpanSetterArg = {"ents": True, "*": True},
        bool_attributes: Union[str, Sequence[str]] = [],
        span_attributes: Optional[AttributesMappingArg] = None,
    ):
        self.tokenizer = tokenizer
        self.name = name
        self.span_setter = span_setter
        # Normalize to a fresh list. A bare string would otherwise be iterated
        # character by character, and copying avoids aliasing the shared
        # mutable default.
        self.bool_attributes = (
            [bool_attributes]
            if isinstance(bool_attributes, str)
            else list(bool_attributes)
        )
        self.span_attributes = span_attributes

    def __call__(self, obj):
        """Build and return a ``Doc`` from the sample dict ``obj``.

        Entity spans are created with ``alignment_mode="expand"`` and stored
        according to ``self.span_setter``. Extra entity keys are copied onto
        the spans, and the ``bool_attributes`` are defaulted to ``False``.
        """
        tok = get_current_tokenizer() if self.tokenizer is None else self.tokenizer
        doc = tok(obj["note_text"] or "")
        doc._.note_id = obj.get("note_id", obj.get("__FILENAME__"))
        doc._.note_datetime = obj.get("note_datetime")

        # With an explicit mapping, make sure every destination extension
        # exists before any span attribute is written.
        if self.span_attributes is not None:
            for dst in self.span_attributes.values():
                if not Span.has_extension(dst):
                    Span.set_extension(dst, default=None)

        spans = []
        for ent in obj.get("entities") or ():
            # Copy so that the pops below do not mutate the caller's sample.
            ent = dict(ent)
            # NOTE(review): char_span may return None for offsets outside the
            # text; that case is not handled here.
            span = doc.char_span(
                ent.pop("start"),
                ent.pop("end"),
                label=ent.pop("label"),
                alignment_mode="expand",
            )
            for label, value in ent.items():
                new_name = (
                    self.span_attributes.get(label, None)
                    if self.span_attributes is not None
                    else label
                )
                # Without a mapping, keys are copied verbatim, so register
                # the extension lazily on first sight.
                if self.span_attributes is None and not Span.has_extension(new_name):
                    Span.set_extension(new_name, default=None)
                # Unmapped keys (new_name is None) are dropped.
                if new_name:
                    span._.set(new_name, value)
            spans.append(span)

        set_spans(doc, spans, span_setter=self.span_setter)

        for attr in self.bool_attributes:
            # The attribute may never have been registered, e.g. when no
            # entity carried that key. Reading an unregistered extension
            # would raise, so register it first.
            if not Span.has_extension(attr):
                Span.set_extension(attr, default=None)
            for span in spans:
                if span._.get(attr) is None:
                    span._.set(attr, False)
        return doc
def test_ner_qualif_train_diff_bert(run_in_test_dir, tmp_path):
    """Train from ``ner_qlf_diff_bert_config.yml`` and check validation scores.

    Both the ``ner`` and ``qual`` micro F1 must be strictly above 0.4.
    """
    set_seed(42)
    cfg = Config.from_disk("ner_qlf_diff_bert_config.yml")
    shutil.rmtree(tmp_path, ignore_errors=True)
    train_kwargs = Config.resolve(cfg["train"], registry=registry, root=cfg)
    pipeline = train(**train_kwargs, output_dir=tmp_path, cpu=True)

    evaluate = GenericScorer(**train_kwargs["scorer"])
    scores = evaluate(pipeline, train_kwargs["val_data"])

    # The trained pipeline must cope with an empty text.
    pipeline("")

    for task in ("ner", "qual"):
        assert scores[task]["micro"]["f"] > 0.4
def test_ner_qualif_train_same_bert(run_in_test_dir, tmp_path):
    """Train from ``ner_qlf_same_bert_config.yml`` and check validation scores.

    Both the ``ner`` and ``qual`` micro F1 must be strictly above 0.4.
    """
    set_seed(42)
    cfg = Config.from_disk("ner_qlf_same_bert_config.yml")
    shutil.rmtree(tmp_path, ignore_errors=True)
    train_kwargs = Config.resolve(cfg["train"], registry=registry, root=cfg)
    pipeline = train(**train_kwargs, output_dir=tmp_path, cpu=True)

    evaluate = GenericScorer(**train_kwargs["scorer"])
    scores = evaluate(pipeline, train_kwargs["val_data"])

    # The trained pipeline must cope with an empty text.
    pipeline("")

    for task in ("ner", "qual"):
        assert scores[task]["micro"]["f"] > 0.4
def test_qualif_train(run_in_test_dir, tmp_path):
    """Train from ``qlf_config.yml`` and require ``qual`` micro F1 >= 0.4."""
    set_seed(42)
    cfg = Config.from_disk("qlf_config.yml")
    shutil.rmtree(tmp_path, ignore_errors=True)
    train_kwargs = Config.resolve(cfg["train"], registry=registry, root=cfg)
    pipeline = train(**train_kwargs, output_dir=tmp_path, cpu=True)

    scores = GenericScorer(**train_kwargs["scorer"])(
        pipeline, train_kwargs["val_data"]
    )

    # The trained pipeline must cope with an empty text.
    pipeline("")

    qual_f1 = scores["qual"]["micro"]["f"]
    assert qual_f1 >= 0.4
def test_dep_parser_train(run_in_test_dir, tmp_path):
    """Train from ``dep_parser_config.yml`` and check the parser's behavior.

    The test also runs MST decoding on a doc made of the first validation doc
    concatenated with itself. Its ``parser`` UAS is expected to be 0.0 under a
    ``filter_expr="False"`` metric. Finally, the validation LAS must be >= 0.4.
    """
    set_seed(42)
    cfg = Config.from_disk("dep_parser_config.yml")
    shutil.rmtree(tmp_path, ignore_errors=True)
    train_kwargs = Config.resolve(cfg["train"], registry=registry, root=cfg)
    pipeline = train(**train_kwargs, output_dir=tmp_path, cpu=True)

    evaluate = GenericScorer(**train_kwargs["scorer"])
    val_docs = list(train_kwargs["val_data"])
    scores = evaluate(pipeline, val_docs)

    evaluate_bis = GenericScorer(parser=DependencyParsingMetric(filter_expr="False"))
    # Just to test what happens if the scores indicate 2 roots: glue the
    # first validation doc to a copy of itself.
    head_doc = val_docs[0]
    doubled_docs = [Doc.from_docs([head_doc, head_doc])]
    pipeline.pipes.parser.decoding_mode = "mst"
    scores_bis = evaluate_bis(pipeline, doubled_docs)
    assert scores_bis["parser"]["uas"] == 0.0

    # The trained pipeline must cope with an empty text.
    pipeline("")

    dep_las = scores["dep"]["las"]
    assert dep_las >= 0.4
def test_optimizer():
    """Check ScheduledOptimizer state initialization and its LR schedule.

    Before ``initialize()``, no AdamW ``exp_avg`` state exists for the
    module's parameters. After it, every parameter has one. Stepping
    ``total_steps`` times must then produce the learning-rate curve listed
    at the end of the test.
    """
    net = torch.nn.Linear(10, 10)
    # Shared by the optimizer config and the stepping loop so they cannot
    # drift apart.
    total_steps = 10
    optim = ScheduledOptimizer(
        torch.optim.AdamW,
        module=net,
        total_steps=total_steps,
        groups={
            ".*": {
                "lr": 9e-4,
                "schedules": LinearSchedule(
                    warmup_rate=0.1,
                    start_value=0,
                ),
            }
        },
    )

    # No optimizer state yet...
    for param in net.parameters():
        assert "exp_avg" not in optim.optim.state[param]
    optim.initialize()
    # ...and initialize() must have created it for every parameter.
    for param in net.parameters():
        assert "exp_avg" in optim.optim.state[param]

    lr_values = [optim.optim.param_groups[0]["lr"]]
    for _ in range(total_steps):
        optim.step()
        lr_values.append(optim.optim.param_groups[0]["lr"])

    # Expected curve: 0 at step 0, peak 9e-4 at step 1, then linear decay back
    # to 0 at the last step. Compared approximately because of float rounding.
    assert lr_values == pytest.approx(
        [
            0.0,
            0.0009,
            0.0008,
            0.0007,
            0.0006,
            0.0005,
            0.0004,
            0.0003,
            0.0002,
            0.0001,
            0.0,
        ]
    )