b/src/scheduler.py:

```python
import torch.optim as optim


def get_scheduler(conf):
    """Set up the learning rate scheduler from the config."""
    scheduler_name = conf['scheduler']['name']
    if scheduler_name == 'StepLR':
        # StepLR decays the learning rate by `gamma` every `step_size` epochs.
        optimizer = conf['optimizer']
        step_size = conf['scheduler']['step_size']
        gamma = conf['scheduler']['gamma']
        scheduler = optim.lr_scheduler.StepLR(optimizer=optimizer,
                                              step_size=step_size,
                                              gamma=gamma)
    else:
        raise ValueError(f'Scheduler {scheduler_name} not supported.')
    return scheduler
```
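For reference, a minimal sketch of how `get_scheduler` might be called. The config layout shown here (an optimizer object stored under `conf['optimizer']` next to the scheduler settings) follows what the function reads, but the model, the `SGD` hyperparameters, and the training loop are illustrative assumptions, not part of the file above.

```python
import torch.nn as nn
import torch.optim as optim

from src.scheduler import get_scheduler  # assuming the repo root is on the path

# Hypothetical model and optimizer, purely for illustration.
model = nn.Linear(10, 2)
optimizer = optim.SGD(model.parameters(), lr=0.1)

# Assumed config layout: the optimizer travels inside `conf`,
# alongside the scheduler settings that get_scheduler() reads.
conf = {
    'optimizer': optimizer,
    'scheduler': {'name': 'StepLR', 'step_size': 30, 'gamma': 0.1},
}

scheduler = get_scheduler(conf)
for epoch in range(100):
    # ... forward/backward pass and optimizer.step() would go here ...
    scheduler.step()  # multiplies the LR by gamma every step_size epochs
```

One consequence of this design is that the optimizer must already be constructed and placed in `conf` before `get_scheduler` is called, since PyTorch schedulers are bound to an optimizer at creation time.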