pathflowai/schedulers.py
"""
schedulers.py
=======================
Modulates the learning rate during the training process.
"""
import math

import torch
from torch.optim.lr_scheduler import ExponentialLR

class CosineAnnealingWithRestartsLR(torch.optim.lr_scheduler._LRScheduler):
    r"""Set the learning rate of each parameter group using a cosine annealing
    schedule, where :math:`\eta_{max}` is set to the initial lr and
    :math:`T_{cur}` is the number of epochs since the last restart in SGDR:

    .. math::

        \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})
        \left(1 + \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right)

    When last_epoch=-1, sets initial lr as lr.

    It has been proposed in
    `SGDR: Stochastic Gradient Descent with Warm Restarts`_. This implements
    the cosine annealing part of SGDR together with the warm restarts and the
    multiplicative growth of the restart period.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        T_max (int): Maximum number of iterations in the first annealing cycle.
        T_mult (float): Multiply the length of the annealing cycle by this
            factor after each restart. Default: 1.
        eta_min (float): Minimum learning rate. Default: 0.
        last_epoch (int): The index of the last epoch. Default: -1.
        alpha_decay (float): Factor by which the peak learning rate is decayed
            after each restart. Default: 1.0.

    .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
        https://arxiv.org/abs/1608.03983
    """

    def __init__(self, optimizer, T_max, eta_min=0, last_epoch=-1, T_mult=1., alpha_decay=1.0):
        self.T_max = T_max
        self.T_mult = T_mult
        self.restart_every = T_max
        self.eta_min = eta_min
        self.restarts = 0
        self.restarted_at = 0
        self.alpha = alpha_decay
        super().__init__(optimizer, last_epoch)

    def restart(self):
        # Start a new annealing cycle and grow its length by T_mult.
        self.restarts += 1
        self.restart_every = int(round(self.restart_every * self.T_mult))
        self.restarted_at = self.last_epoch

    def cosine(self, base_lr):
        # Cosine-annealed learning rate for the current cycle; the peak is
        # additionally decayed by alpha_decay**restarts after each restart.
        return self.eta_min + self.alpha**self.restarts * (base_lr - self.eta_min) * (1 + math.cos(math.pi * self.step_n / self.restart_every)) / 2

    @property
    def step_n(self):
        # Number of epochs elapsed since the last restart.
        return self.last_epoch - self.restarted_at

    def get_lr(self):
        if self.step_n >= self.restart_every:
            self.restart()
        return [self.cosine(base_lr) for base_lr in self.base_lrs]

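# A minimal usage sketch (not part of the original module): it shows how the
# warm-restart schedule above might be attached to an optimizer and stepped
# once per epoch. The toy model, optimizer, and hyperparameter values below
# are illustrative assumptions, not PathFlowAI defaults.
def _demo_cosine_annealing_with_restarts():
    model = torch.nn.Linear(10, 2)  # hypothetical toy model
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
    scheduler = CosineAnnealingWithRestartsLR(optimizer, T_max=10, eta_min=5e-8, T_mult=2)
    lrs = []
    for _epoch in range(40):
        # ... forward/backward pass would go here ...
        optimizer.step()
        scheduler.step()
        lrs.append(optimizer.param_groups[0]['lr'])
    # Anneals over the first 10 epochs, then restarts with cycles of 20, 40, ... epochs.
    return lrs
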
class Scheduler:
    """Scheduler class that modulates the learning rate of torch optimizers over epochs.

    Parameters
    ----------
    optimizer : torch.optim.Optimizer
        Wrapped optimizer whose learning rate will be modulated.
    opts : dict
        Options for the learning rate scheduler; see the default value for the expected keys.

    Attributes
    ----------
    schedulers : dict
        Factory functions for the available scheduler types.
    scheduler_step_fn : dict
        How each scheduler type updates the learning rate on ``step``.
    initial_lr : float
        Initial learning rate taken from the optimizer.
    scheduler_choice : str
        Which scheduler type was chosen.
    scheduler : object
        Scheduler object that directly updates the optimizer's learning rate.

    """

    def __init__(self, optimizer=None, opts=dict(scheduler='null', lr_scheduler_decay=0.5, T_max=10, eta_min=5e-8, T_mult=2)):
        self.schedulers = {'exp': (lambda optimizer: ExponentialLR(optimizer, opts["lr_scheduler_decay"])),
                           'null': (lambda optimizer: None),
                           'warm_restarts': (lambda optimizer: CosineAnnealingWithRestartsLR(optimizer, T_max=opts['T_max'], eta_min=opts['eta_min'], last_epoch=-1, T_mult=opts['T_mult']))}
        self.scheduler_step_fn = {'exp': (lambda scheduler: scheduler.step()),
                                  'warm_restarts': (lambda scheduler: scheduler.step()),
                                  'null': (lambda scheduler: None)}
        self.initial_lr = optimizer.param_groups[0]['lr'] if optimizer is not None else None
        self.scheduler_choice = opts['scheduler']
        self.scheduler = self.schedulers[self.scheduler_choice](optimizer) if optimizer is not None else None

    def step(self):
        """Update the optimizer learning rate."""
        self.scheduler_step_fn[self.scheduler_choice](self.scheduler)

    def get_lr(self):
        """Return the current learning rate.

        Returns
        -------
        float
            Current learning rate.

        """
        lr = (self.initial_lr if self.scheduler_choice == 'null' else self.scheduler.optimizer.param_groups[0]['lr'])
        return lr
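

# A minimal usage sketch (not part of the original module): it wires the
# Scheduler wrapper to a toy optimizer. The model, optimizer, and epoch count
# are illustrative assumptions; the `opts` keys mirror the defaults expected
# by Scheduler.__init__.
def _demo_scheduler_wrapper():
    model = torch.nn.Linear(10, 2)  # hypothetical toy model
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    opts = dict(scheduler='warm_restarts', lr_scheduler_decay=0.5,
                T_max=10, eta_min=5e-8, T_mult=2)
    scheduler = Scheduler(optimizer=optimizer, opts=opts)
    for epoch in range(25):
        # ... run one training epoch here ...
        optimizer.step()
        scheduler.step()
        print(f"epoch {epoch}: lr={scheduler.get_lr():.2e}")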