--- /dev/null
+++ b/tests/training/test_optimizer.py
@@ -0,0 +1,181 @@
+# ruff:noqa:E402
+import pytest
+
+try:
+    import torch.nn
+except ImportError:
+    torch = None
+
+if torch is None:
+    pytest.skip("torch not installed", allow_module_level=True)
+pytest.importorskip("rich")
+
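+# This import must come after the optional-dependency guards above, hence the E402 noqa.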
+from edsnlp.training.optimizer import LinearSchedule, ScheduledOptimizer
+
+
+class Net(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.fc1 = torch.nn.Linear(10, 1)
+        self.fc2 = torch.nn.Linear(1, 1)
+
+    def forward(self, x):
+        return self.fc2(self.fc1(x))
+
+
+@pytest.fixture(scope="module")
+def net():
+    return Net()
+
+
+@pytest.mark.parametrize(
+    "groups",
+    [
+        # Old schedule API
+        {
+            "fc1[.].*": {
+                "lr": 0.1,
+                "weight_decay": 0.01,
+                "schedules": [
+                    {
+                        "@schedules": "linear",
+                        "start_value": 0.0,
+                        "warmup_rate": 0.2,
+                    },
+                ],
+            },
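+            # Mapping a pattern to False excludes the matching parameters from optimization.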
+            "fc2[.]bias": False,
+            "": {
+                "lr": 0.0001,
+                "weight_decay": 0.0,
+            },
+        },
+        # New schedule API
+        {
+            "fc1[.].*": {
+                "lr": {
+                    "@schedules": "linear",
+                    "start_value": 0.0,
+                    "max_value": 0.1,
+                    "warmup_rate": 0.2,
+                },
+                "weight_decay": 0.01,
+            },
+            "fc2[.]bias": False,
+            "": {
+                "lr": 0.0001,
+                "weight_decay": 0.0,
+            },
+        },
+    ],
+)
+def test_parameter_selection(net, groups):
+    optim = ScheduledOptimizer(
+        optim="adamw",
+        module=net,
+        groups=groups,
+        total_steps=10,
+    )
+    assert len(optim.state) == 0
+    optim.initialize()
+    assert all(p in optim.state for p in net.fc1.parameters())
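+    # Assigning state to itself exercises the state property setter round-trip.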
+    optim.state = optim.state
+
+    fc1_group = optim.param_groups[1]
+    assert fc1_group["lr"] == pytest.approx(0.0)
+    assert fc1_group["weight_decay"] == pytest.approx(0.01)
+    assert set(fc1_group["params"]) == {net.fc1.weight, net.fc1.bias}
+
+    fc2_group = optim.param_groups[0]
+    assert fc2_group["lr"] == pytest.approx(0.0001)
+    assert set(fc2_group["params"]) == {net.fc2.weight}
+
+    lr_values = [fc1_group["lr"]]
+
+    for _ in range(10):
+        optim.step()
+        lr_values.append(fc1_group["lr"])
+
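+    # 2 warmup steps (0 -> 0.1, from warmup_rate=0.2), then linear decay to 0 over 8 steps.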
+    assert lr_values == pytest.approx(
+        [
+            0.0,
+            0.05,
+            0.1,
+            0.0875,
+            0.075,
+            0.0625,
+            0.05,
+            0.0375,
+            0.025,
+            0.0125,
+            0.0,
+        ]
+    )
+
+
+def test_serialization(net):
+    optim = ScheduledOptimizer(
+        optim="adamw",
+        module=net,
+        groups={
+            "fc1[.].*": {
+                "lr": 0.1,
+                "weight_decay": 0.01,
+                "schedules": LinearSchedule(start_value=0.0, warmup_rate=0.2),
+            },
+            "fc2[.]bias": False,
+            "": {
+                "lr": 0.0001,
+                "weight_decay": 0.0,
+            },
+        },
+        total_steps=10,
+    )
+    optim.initialize()
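+    # Assigning param_groups to itself exercises the property setter round-trip.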
+    optim.param_groups = optim.param_groups
+
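+    # Snapshot the optimizer state after 5 steps so we can restore it below.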
+    state_dict = None
+    for i in range(10):
+        if i == 5:
+            state_dict = optim.state_dict()
+        optim.step()
+
+    assert optim.param_groups[-1]["lr"] == pytest.approx(0.0)
+    optim.load_state_dict(state_dict)
+    assert optim.param_groups[-1]["lr"] == pytest.approx(0.0625)
+
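+    # Smoke test: resetting the optimizer should not raise.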
+    optim.reset()
+
+
+def test_repr(net):
+    optim = ScheduledOptimizer(
+        optim="adamw",
+        module=net,
+        groups={
+            "fc1[.].*": {
+                "lr": 0.1,
+                "weight_decay": 0.01,
+                "schedules": [
+                    LinearSchedule(start_value=0.0, warmup_rate=0.2),
+                    LinearSchedule(path="weight_decay"),
+                ],
+            },
+            "fc2[.]bias": False,
+            ".*": {
+                "lr": 0.0001,
+                "weight_decay": 0.0,
+            },
+        },
+        total_steps=10,
+    )
+    optim.initialize()
+
+    assert "ScheduledOptimizer[AdamW]" in repr(optim)