# tests/tuning/test_tuning.py
import datetime
import os
from unittest.mock import Mock, patch

import pytest

try:
    import optuna
except ImportError:
    optuna = None

if optuna is None:
    pytest.skip("optuna not installed", allow_module_level=True)

from confit import Config

from edsnlp.tune import (
    compute_importances,
    compute_n_trials,
    compute_remaining_n_trials_possible,
    compute_time_per_trial,
    is_plotly_install,
    load_config,
    optimize,
    process_results,
    tune,
)


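# Helper: builds a mock optuna FrozenTrial with two float hyperparameters
# ("param1", "param2"), a COMPLETE state and explicit start/end timestamps.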
def build_trial(number, value, params, datetime_start, datetime_complete):
    trial = Mock(spec=optuna.trial.FrozenTrial)
    trial.number = number
    trial.value = value
    trial.values = [value]
    trial.params = params
    trial.distributions = {
        "param1": optuna.distributions.FloatDistribution(
            high=0.3,
            log=False,
            low=0.0,
            step=0.05,
        ),
        "param2": optuna.distributions.FloatDistribution(
            high=0.3,
            log=False,
            low=0.0,
            step=0.05,
        ),
    }
    trial.datetime_start = datetime_start
    trial.datetime_complete = datetime_complete
    trial.state = optuna.trial.TrialState.COMPLETE
    trial.system_attrs = {}
    trial.user_attrs = {}
    return trial


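# Fixture: a mocked optuna Study with three completed trials (values 0.9,
# 0.75 and 0.99), each spanning exactly 5 minutes; the best trial is trial 2.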
@pytest.fixture
def study():
    study = Mock(spec=optuna.study.Study)
    study.study_name = "mock_study"
    study._is_multi_objective.return_value = False

    trials = []
    trial_0 = build_trial(
        number=0,
        value=0.9,
        params={"param1": 0.15, "param2": 0.3},
        datetime_start=datetime.datetime(2025, 1, 1, 12, 0, 0),
        datetime_complete=datetime.datetime(2025, 1, 1, 12, 5, 0),
    )
    trials.append(trial_0)

    trial_1 = build_trial(
        number=1,
        value=0.75,
        params={"param1": 0.05, "param2": 0.2},
        datetime_start=datetime.datetime(2025, 1, 1, 12, 5, 0),
        datetime_complete=datetime.datetime(2025, 1, 1, 12, 10, 0),
    )
    trials.append(trial_1)

    trial_2 = build_trial(
        number=2,
        value=0.99,
        params={"param1": 0.3, "param2": 0.25},
        datetime_start=datetime.datetime(2025, 1, 1, 12, 10, 0),
        datetime_complete=datetime.datetime(2025, 1, 1, 12, 15, 0),
    )
    trials.append(trial_2)

    study.trials = trials
    study.get_trials.return_value = trials
    study.best_trial = trials[2]
    return study


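# Every mocked trial lasts exactly 300 seconds, so an exponential moving
# average and a plain mean of the durations coincide; both ema=True and
# ema=False should therefore yield 300.0.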
@pytest.mark.parametrize("ema", [True, False])
def test_compute_time_per_trial_with_ema(study, ema):
    result = compute_time_per_trial(study, ema=ema, alpha=0.1)
    assert result == pytest.approx(300.00)


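# With 1 GPU hour and 120 s per trial, 30 trials fit in the budget. With
# 0.5 GPU hour and 3600 s per trial, not even one trial fits, which is
# presumably why compute_n_trials raises ValueError in that case.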
@pytest.mark.parametrize(
    "gpu_hours, time_per_trial, expected_n_trials, raises_exception",
    [
        (1, 120, 30, False),
        (0.5, 3600, None, True),
    ],
)
def test_compute_n_trials(
    gpu_hours, time_per_trial, expected_n_trials, raises_exception
):
    if raises_exception:
        with pytest.raises(ValueError):
            compute_n_trials(gpu_hours, time_per_trial)
    else:
        result = compute_n_trials(gpu_hours, time_per_trial)
        assert result == expected_n_trials


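# The expected importances below sum to 1; they are the values returned by
# optuna's parameter-importance evaluation for the three mocked trials
# (assumption based on the function name), so they may shift across optuna
# versions.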
def test_compute_importances(study):
    importance = compute_importances(study)
    assert importance == {"param2": 0.5239814153755754, "param1": 0.4760185846244246}


@pytest.mark.parametrize("viz", [True, False])
@pytest.mark.parametrize(
    "config_path", ["tests/tuning/config.yml", "tests/tuning/config.cfg"]
)
def test_process_results(study, tmpdir, viz, config_path):
    output_dir = tmpdir.mkdir("output")
    config = {
        "train": {
            "param1": None,
        },
        ".lr": {
            "param2": 0.01,
        },
    }
    hyperparameters = {
        "train.param1": {
            "type": "int",
            "alias": "param1",
            "low": 2,
            "high": 8,
            "step": 2,
        },
    }
    best_params, importances = process_results(
        study, output_dir, viz, config, config_path, hyperparameters
    )

    assert isinstance(best_params, dict)
    assert isinstance(importances, dict)

    results_file = os.path.join(output_dir, "results_summary.txt")
    assert os.path.exists(results_file)

    with open(results_file, "r") as f:
        content = f.read()
        assert "Study Summary" in content
        assert "Best trial" in content
        assert "Value" in content
        assert "Params" in content
        assert "Importances" in content

    if config_path.endswith("yml") or config_path.endswith("yaml"):
        config_file = os.path.join(output_dir, "config.yml")
    else:
        config_file = os.path.join(output_dir, "config.cfg")
    assert os.path.exists(config_file), f"Expected file {config_file} not found"

    with open(config_file, "r", encoding="utf-8") as f:
        content = f.read()
    # Note: "usefull" is spelled this way on purpose; it must match the comment
    # presumably carried over from the fixture config at config_path.
    assert (
        "# My usefull comment" in content
    ), f"Expected comment not found in {config_file}"

    if viz:
        optimization_history_file = os.path.join(
            output_dir, "optimization_history.html"
        )
        assert os.path.exists(
            optimization_history_file
        ), f"Expected file {optimization_history_file} not found"

        parallel_coord_file = os.path.join(output_dir, "parallel_coordinate.html")
        assert os.path.exists(
            parallel_coord_file
        ), f"Expected file {parallel_coord_file} not found"

        contour_file = os.path.join(output_dir, "contour.html")
        assert os.path.exists(contour_file), f"Expected file {contour_file} not found"

        edf_file = os.path.join(output_dir, "edf.html")
        assert os.path.exists(edf_file), f"Expected file {edf_file} not found"

        timeline_file = os.path.join(output_dir, "timeline.html")
        assert os.path.exists(timeline_file), f"Expected file {timeline_file} not found"


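# With a 0.5 GPU-hour (1800 s) budget and 300 s per trial, the mocked study has
# already consumed 3 x 300 s = 900 s, leaving room for 3 more trials (assumed
# semantics, inferred from the function name and the fixture).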
def test_compute_remaining_n_trials_possible(study):
    gpu_hours = 0.5
    remaining_trials = compute_remaining_n_trials_possible(study, gpu_hours)
    assert remaining_trials == 3


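# Stacked @patch decorators inject their mocks bottom-up, so the mock of the
# decorator closest to the function signature comes first in the argument
# list; the parameter names below follow that order.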
@patch("edsnlp.tune.objective_with_param")
@patch("optuna.study.Study.optimize")
@pytest.mark.parametrize("has_study", [True, False])
def test_optimize(mock_optimize_study, mock_objective_with_param, has_study, study):
    mock_objective_with_param.return_value = 0.9
    metric = ("ner", "micro", "f")
    checkpoint_dir = "./checkpoint"

    if has_study:

        def pass_fn(obj, n_trials, callbacks):
            pass

        study.optimize = pass_fn
        study = optimize(
            "config_path",
            tuned_parameters={},
            n_trials=1,
            metric=metric,
            checkpoint_dir=checkpoint_dir,
            study=study,
        )
        assert isinstance(study, Mock)
        assert len(study.trials) == 3

    else:
        study = optimize(
            "config_path",
            tuned_parameters={},
            n_trials=1,
            metric=metric,
            checkpoint_dir=checkpoint_dir,
            study=None,
        )
        assert isinstance(study, optuna.study.Study)
        assert len(study.trials) == 0


@patch("edsnlp.tune.optimize")
@patch("edsnlp.tune.process_results")
@patch("edsnlp.tune.load_config")
@patch("edsnlp.tune.compute_n_trials")
@patch("edsnlp.tune.update_config")
@pytest.mark.parametrize("n_trials", [10, None])
@pytest.mark.parametrize("two_phase_tuning", [False, True])
def test_tune(
    mock_update_config,
    mock_compute_n_trials,
    mock_load_config,
    mock_process_results,
    mock_optimize,
    study,
    n_trials,
    two_phase_tuning,
):
    mock_load_config.return_value = {"train": {}, "scorer": {}, "val_data": {}}
    mock_update_config.return_value = None, {"train": {}, "scorer": {}, "val_data": {}}
    mock_optimize.return_value = study
    mock_process_results.return_value = ({}, {})
    mock_compute_n_trials.return_value = 10
    config_meta = {"config_path": ["fake_path"]}
    hyperparameters = {
        "param1": {"type": "float", "low": 0.0, "high": 1.0},
        "param2": {"type": "float", "low": 0.0, "high": 1.0},
    }
    output_dir = "output_dir"
    checkpoint_dir = "checkpoint_dir"
    gpu_hours = 0.25
    seed = 42

    tune(
        config_meta=config_meta,
        hyperparameters=hyperparameters,
        output_dir=output_dir,
        checkpoint_dir=checkpoint_dir,
        gpu_hours=gpu_hours,
        n_trials=n_trials,
        two_phase_tuning=two_phase_tuning,
        seed=seed,
    )

    mock_load_config.assert_called_once()

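    # Expected call counts mirror the tune() control flow: when n_trials is
    # None, tune presumably runs one preliminary optimize call to estimate the
    # time per trial, derives n_trials via compute_n_trials, and checks once
    # more at the end whether the leftover GPU budget allows extra trials.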
    if two_phase_tuning:
        if n_trials is None:
            assert mock_compute_n_trials.call_count == 2  # 1 at beginning + 1 at end
            assert mock_optimize.call_count == 3  # 1 at beginning + 2 for tuning
        else:
            mock_compute_n_trials.assert_not_called()
            assert mock_optimize.call_count == 2  # 2 for tuning

        assert mock_process_results.call_count == 2  # one for each phase

    else:
        if n_trials is None:
            assert mock_compute_n_trials.call_count == 2  # 1 at beginning + 1 at end
            assert (
                mock_optimize.call_count == 3
            )  # 1 at beginning + 1 for tuning + 1 at end
        else:
            mock_compute_n_trials.assert_not_called()
            assert mock_optimize.call_count == 1  # 1 for tuning

        mock_process_results.assert_called_once()


@patch("importlib.util.find_spec")
def test_plotly(mock_importlib_util_find_spec):
    mock_importlib_util_find_spec.return_value = None
    assert not is_plotly_install()


def test_load_config(tmpdir):
    cfg = """\
    "a":
        "aa": 1
    "b": 2
    "c": "test"
    """
    config_dir = tmpdir.mkdir("configs")
    config_path = os.path.join(config_dir, "config.yml")
    Config.from_yaml_str(cfg).to_disk(config_path)
    config = load_config(config_path)
    assert isinstance(config, Config)
    with pytest.raises(FileNotFoundError):
        load_config("wrong_path")