tests/tuning/test_update_config.py
import pytest

try:
    import optuna
except ImportError:
    optuna = None

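# Skip the whole module when optuna is not installed.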
if optuna is None:
    pytest.skip("optuna not installed", allow_module_level=True)


from edsnlp.tune import update_config


@pytest.fixture
def minimal_config():
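    """Config skeleton whose placeholder values update_config should fill in."""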
    return {
        "train": {
            "layers": None,
        },
        ".lr": {
            "learning_rate": None,
        },
    }


@pytest.fixture
def hyperparameters():
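    """Search space covering int, float (log-scale) and categorical parameters.

    Note: the ".lr" section name itself contains a dot, so it is quoted
    inside the hyperparameter path ("'.lr'.learning_rate").
    """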
    return {
        "train.layers": {
            "type": "int",
            "low": 2,
            "high": 8,
            "step": 2,
        },
        "'.lr'.learning_rate": {
            "alias": "learning_rate",
            "type": "float",
            "low": 0.001,
            "high": 0.1,
            "log": True,
        },
        "train.batch_size": {
            "alias": "batch_size",
            "type": "categorical",
            "choices": [32, 64, 128],
        },
    }


@pytest.fixture
def hyperparameters_with_invalid_type():
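    """Search space using an unsupported parameter type ("string")."""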
    return {
        "train.optimizer": {
            "type": "string",
            "choices": ["adam", "sgd"],
        }
    }


@pytest.fixture
def hyperparameters_with_invalid_path():
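    """Search space whose path targets a "model" section missing from the config."""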
    return {
        "model.layers": {
            "type": "int",
            "low": 2,
            "high": 8,
            "step": 2,
        },
    }


@pytest.fixture
def trial():
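    """Bare trial from Optuna's ask/tell API; no objective function is run."""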
    study = optuna.create_study(direction="maximize")
    trial = study.ask()
    return trial


def test_update_config_with_values(minimal_config, hyperparameters):
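    """Explicit values are written into the config, resolving aliases and dotted paths."""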
    values = {"learning_rate": 0.05, "train.layers": 6, "batch_size": 64}
    _, updated_config = update_config(minimal_config, hyperparameters, values=values)

    assert updated_config[".lr"]["learning_rate"] == values["learning_rate"]
    assert updated_config["train"]["layers"] == values["train.layers"]
    assert updated_config["train"]["batch_size"] == values["batch_size"]


def test_update_config_with_trial(minimal_config, hyperparameters, trial):
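    """Sampled values end up in the config and respect bounds, step and choices."""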
    _, updated_config = update_config(minimal_config, hyperparameters, trial=trial)

    learning_rate = updated_config[".lr"]["learning_rate"]
    layers = updated_config["train"]["layers"]
    batch_size = updated_config["train"]["batch_size"]

    assert (
        hyperparameters["'.lr'.learning_rate"]["low"]
        <= learning_rate
        <= hyperparameters["'.lr'.learning_rate"]["high"]
    )
    assert (
        hyperparameters["train.layers"]["low"]
        <= layers
        <= hyperparameters["train.layers"]["high"]
    )
    assert layers % hyperparameters["train.layers"]["step"] == 0
    assert batch_size in hyperparameters["train.batch_size"]["choices"]


def test_update_config_raises_error_on_unknown_parameter_type(
    minimal_config, hyperparameters_with_invalid_type, trial
):
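    """An unknown parameter type in the search space raises a ValueError."""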
    with pytest.raises(
        ValueError,
        match="Unknown parameter type 'string' for hyperparameter 'train.optimizer'.",
    ):
        update_config(minimal_config, hyperparameters_with_invalid_type, trial=trial)


def test_update_config_raises_error_on_wrong_path(
    minimal_config, hyperparameters_with_invalid_path, trial
):
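    """A hyperparameter path pointing outside the config raises a KeyError."""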
    with pytest.raises(KeyError, match="Path 'model' not found in config."):
        update_config(minimal_config, hyperparameters_with_invalid_path, trial=trial)