tests/training/test_train.py

# ruff:noqa:E402

import pytest

try:
    import torch.nn
except ImportError:
    torch = None

if torch is None:
    pytest.skip("torch not installed", allow_module_level=True)
pytest.importorskip("rich")

import shutil
from typing import (
    Optional,
    Sequence,
    Union,
)

import pytest
import spacy.tokenizer
import torch.nn
from confit import Config
from confit.utils.random import set_seed
from spacy.tokens import Doc, Span

from edsnlp.core.registries import registry
from edsnlp.data.converters import AttributesMappingArg, get_current_tokenizer
from edsnlp.metrics.dep_parsing import DependencyParsingMetric
from edsnlp.training.optimizer import LinearSchedule, ScheduledOptimizer
from edsnlp.training.trainer import GenericScorer, train
from edsnlp.utils.span_getters import SpanSetterArg, set_spans
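

# Converter factory registered as "myproject.custom_dict2doc": it builds a spaCy
# Doc from a raw dict record ({"note_text": ..., "entities": [...]}), creating
# entity spans and span attribute extensions as needed.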
@registry.factory.register("myproject.custom_dict2doc", spacy_compatible=False)
class CustomSampleGenerator:
    def __init__(
        self,
        *,
        tokenizer: Optional[spacy.tokenizer.Tokenizer] = None,
        name: str = "myproject.custom_dict2doc",
        span_setter: SpanSetterArg = {"ents": True, "*": True},
        bool_attributes: Union[str, Sequence[str]] = [],
        span_attributes: Optional[AttributesMappingArg] = None,
    ):
        self.tokenizer = tokenizer
        self.name = name
        self.span_setter = span_setter
        self.bool_attributes = bool_attributes
        self.span_attributes = span_attributes

    def __call__(self, obj):
        tok = get_current_tokenizer() if self.tokenizer is None else self.tokenizer
        doc = tok(obj["note_text"] or "")
        doc._.note_id = obj.get("note_id", obj.get("__FILENAME__"))
        doc._.note_datetime = obj.get("note_datetime")

        spans = []

        if self.span_attributes is not None:
            for dst in self.span_attributes.values():
                if not Span.has_extension(dst):
                    Span.set_extension(dst, default=None)

        for ent in obj.get("entities") or ():
            ent = dict(ent)
            span = doc.char_span(
                ent.pop("start"),
                ent.pop("end"),
                label=ent.pop("label"),
                alignment_mode="expand",
            )
            for label, value in ent.items():
                new_name = (
                    self.span_attributes.get(label, None)
                    if self.span_attributes is not None
                    else label
                )
                if self.span_attributes is None and not Span.has_extension(new_name):
                    Span.set_extension(new_name, default=None)

                if new_name:
                    span._.set(new_name, value)
            spans.append(span)

        set_spans(doc, spans, span_setter=self.span_setter)
        for attr in self.bool_attributes:
            for span in spans:
                if span._.get(attr) is None:
                    span._.set(attr, False)
        return doc
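

# End-to-end training smoke tests: each resolves the "train" section of a YAML
# config with confit, runs train() on CPU into tmp_path, re-scores the validation
# data, and checks that the metrics clear loose thresholds.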
def test_ner_qualif_train_diff_bert(run_in_test_dir, tmp_path):
    set_seed(42)
    config = Config.from_disk("ner_qlf_diff_bert_config.yml")
    shutil.rmtree(tmp_path, ignore_errors=True)
    kwargs = Config.resolve(config["train"], registry=registry, root=config)
    nlp = train(**kwargs, output_dir=tmp_path, cpu=True)
    scorer = GenericScorer(**kwargs["scorer"])
    val_data = kwargs["val_data"]
    last_scores = scorer(nlp, val_data)

    # Check that the trained pipeline does not crash on an empty doc
    nlp("")

    assert last_scores["ner"]["micro"]["f"] > 0.4
    assert last_scores["qual"]["micro"]["f"] > 0.4


def test_ner_qualif_train_same_bert(run_in_test_dir, tmp_path):
    set_seed(42)
    config = Config.from_disk("ner_qlf_same_bert_config.yml")
    shutil.rmtree(tmp_path, ignore_errors=True)
    kwargs = Config.resolve(config["train"], registry=registry, root=config)
    nlp = train(**kwargs, output_dir=tmp_path, cpu=True)
    scorer = GenericScorer(**kwargs["scorer"])
    val_data = kwargs["val_data"]
    last_scores = scorer(nlp, val_data)

    # Check that the trained pipeline does not crash on an empty doc
    nlp("")

    assert last_scores["ner"]["micro"]["f"] > 0.4
    assert last_scores["qual"]["micro"]["f"] > 0.4


def test_qualif_train(run_in_test_dir, tmp_path):
    set_seed(42)
    config = Config.from_disk("qlf_config.yml")
    shutil.rmtree(tmp_path, ignore_errors=True)
    kwargs = Config.resolve(config["train"], registry=registry, root=config)
    nlp = train(**kwargs, output_dir=tmp_path, cpu=True)
    scorer = GenericScorer(**kwargs["scorer"])
    val_data = kwargs["val_data"]
    last_scores = scorer(nlp, val_data)

    # Check that the trained pipeline does not crash on an empty doc
    nlp("")

    assert last_scores["qual"]["micro"]["f"] >= 0.4


def test_dep_parser_train(run_in_test_dir, tmp_path):
    set_seed(42)
    config = Config.from_disk("dep_parser_config.yml")
    shutil.rmtree(tmp_path, ignore_errors=True)
    kwargs = Config.resolve(config["train"], registry=registry, root=config)
    nlp = train(**kwargs, output_dir=tmp_path, cpu=True)
    scorer = GenericScorer(**kwargs["scorer"])
    val_data = list(kwargs["val_data"])
    last_scores = scorer(nlp, val_data)

    scorer_bis = GenericScorer(parser=DependencyParsingMetric(filter_expr="False"))
    # Just to check what happens when the scores indicate 2 roots
    val_data_bis = [Doc.from_docs([val_data[0], val_data[0]])]
    nlp.pipes.parser.decoding_mode = "mst"
    last_scores_bis = scorer_bis(nlp, val_data_bis)
    assert last_scores_bis["parser"]["uas"] == 0.0

    # Check that the trained pipeline does not crash on an empty doc
    nlp("")

    assert last_scores["dep"]["las"] >= 0.4
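

# Schedule check for ScheduledOptimizer: with total_steps=10 and warmup_rate=0.1,
# the learning rate should warm up linearly from 0 to 9e-4 within the first step,
# then decay linearly back to 0, which is what the asserted values below encode.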
def test_optimizer():
    net = torch.nn.Linear(10, 10)
    optim = ScheduledOptimizer(
        torch.optim.AdamW,
        module=net,
        total_steps=10,
        groups={
            ".*": {
                "lr": 9e-4,
                "schedules": LinearSchedule(
                    warmup_rate=0.1,
                    start_value=0,
                ),
            }
        },
    )
    for param in net.parameters():
        assert "exp_avg" not in optim.optim.state[param]
    optim.initialize()
    for param in net.parameters():
        assert "exp_avg" in optim.optim.state[param]
    lr_values = [optim.optim.param_groups[0]["lr"]]
    for _ in range(10):
        optim.step()
        lr_values.append(optim.optim.param_groups[0]["lr"])

    # close enough
    assert lr_values == pytest.approx(
        [
            0.0,
            0.0009,
            0.0008,
            0.0007,
            0.0006,
            0.0005,
            0.0004,
            0.0003,
            0.0002,
            0.0001,
            0.0,
        ]
    )