b/src/optimizers.py

import math
import torch
from torch.optim.optimizer import Optimizer


class AdamW(Optimizer):
    r"""Implements the AdamW algorithm.

    The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_.
    The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of the gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay coefficient (default: 1e-2)
        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
            algorithm from the paper `On the Convergence of Adam and Beyond`_
            (default: False)

    .. _Adam\: A Method for Stochastic Optimization:
        https://arxiv.org/abs/1412.6980
    .. _Decoupled Weight Decay Regularization:
        https://arxiv.org/abs/1711.05101
    .. _On the Convergence of Adam and Beyond:
        https://openreview.net/forum?id=ryQu7f-RZ
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=1e-2, amsgrad=False):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay, amsgrad=amsgrad)
        super(AdamW, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(AdamW, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('amsgrad', False)

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                # Perform decoupled weight decay: shrink the weights directly
                # rather than adding weight_decay * p to the gradient.
                p.data.mul_(1 - group['lr'] * group['weight_decay'])

                # Perform optimization step
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('AdamW does not support sparse gradients, please consider SparseAdam instead')
                amsgrad = group['amsgrad']

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p.data)
                    if amsgrad:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state['max_exp_avg_sq'] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                if amsgrad:
                    max_exp_avg_sq = state['max_exp_avg_sq']
                beta1, beta2 = group['betas']

                state['step'] += 1

                # Decay the first and second moment running average coefficients
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                if amsgrad:
                    # Maintains the maximum of all 2nd moment running avg. till now
                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                    # Use the max. for normalizing running avg. of gradient
                    denom = max_exp_avg_sq.sqrt().add_(group['eps'])
                else:
                    denom = exp_avg_sq.sqrt().add_(group['eps'])

                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']
                step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1

                p.data.addcdiv_(exp_avg, denom, value=-step_size)

        return loss


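# For reference, the decoupled update performed in AdamW.step above is, per
# parameter p (non-AMSGrad case):
#     p <- p * (1 - lr * weight_decay) - step_size * exp_avg / (sqrt(exp_avg_sq) + eps)
# so the weight decay never enters the moment estimates, unlike Adam with an L2
# penalty folded into the gradient (see https://arxiv.org/abs/1711.05101).

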
class Nadam(Optimizer):
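    """Implements the Nadam algorithm (Adam with Nesterov momentum).

    The update follows `Incorporating Nesterov Momentum into Adam`
    (https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ), with the
    Keras-style momentum schedule controlled by ``schedule_decay``.
    It accepts the same ``params``, ``lr``, ``betas``, ``eps`` and
    ``amsgrad`` arguments as ``AdamW``, but has no weight decay term.
    """
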
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 schedule_decay=0.004, amsgrad=False):
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        amsgrad=amsgrad, schedule_decay=schedule_decay)
        super(Nadam, self).__init__(params, defaults)

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('Nadam does not support sparse gradients, please consider SparseAdam instead')
                amsgrad = group['amsgrad']

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p.data)

                    # Running product of the momentum schedule coefficients
                    state['m_schedule'] = 1
                    if amsgrad:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state['max_exp_avg_sq'] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                if amsgrad:
                    max_exp_avg_sq = state['max_exp_avg_sq']
                beta1, beta2 = group['betas']

                state['step'] += 1
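
                # Keras-style momentum schedule: the momentum coefficient
                # mu_t = beta1 * (1 - 0.5 * 0.96 ** (t * schedule_decay)) warms up
                # over training, and state['m_schedule'] accumulates the product of
                # all mu_t for the bias corrections below.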
                momentum_cache_t = beta1 * (
                    1. - 0.5 * math.pow(0.96, state['step'] * group['schedule_decay']))
                momentum_cache_t_1 = beta1 * (
                    1. - 0.5 * math.pow(0.96, (state['step'] + 1) * group['schedule_decay']))
                state['m_schedule'] = state['m_schedule'] * momentum_cache_t

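                # Nesterov correction: bias-correct both the first moment (m_t_prime)
                # and the raw gradient (g_prime), then blend them with the current and
                # next momentum coefficients to form m_t_bar, the look-ahead moment
                # used in the parameter update below.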
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                m_t_prime = exp_avg / (1 - state['m_schedule'] * momentum_cache_t_1)

                g_prime = grad.div(1 - state['m_schedule'])
                m_t_bar = (1. - momentum_cache_t) * g_prime + momentum_cache_t_1 * m_t_prime

                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                if amsgrad:
                    # Maintains the maximum of all 2nd moment running avg. till now
                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                    # Use the max. for normalizing running avg. of gradient
                    v_t_prime = max_exp_avg_sq / (1 - beta2 ** state['step'])
                else:
                    v_t_prime = exp_avg_sq / (1 - beta2 ** state['step'])

                denom = v_t_prime.sqrt().add_(group['eps'])
                p.data.addcdiv_(m_t_bar, denom, value=-group['lr'])

        return loss
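

if __name__ == "__main__":
    # Minimal smoke test (an illustrative usage sketch, not part of the optimizer
    # API): fit a small linear regression with each optimizer for a few steps to
    # check that the update rules run end to end on random data.
    torch.manual_seed(0)
    inputs = torch.randn(64, 4)
    targets = torch.randn(64, 1)
    for opt_cls in (AdamW, Nadam):
        model = torch.nn.Linear(4, 1)
        optimizer = opt_cls(model.parameters(), lr=1e-2)
        for _ in range(10):
            optimizer.zero_grad()
            loss = torch.nn.functional.mse_loss(model(inputs), targets)
            loss.backward()
            optimizer.step()
        print("{}: final loss {:.4f}".format(opt_cls.__name__, loss.item()))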