model/lavis/common/optims.py
|
"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import math

from model.lavis.common.registry import registry
|
|
@registry.register_lr_scheduler("linear_warmup_step_lr")
class LinearWarmupStepLRScheduler:
    """Linear warmup during the first epoch, then per-epoch step (exponential) decay."""

    def __init__(
        self,
        optimizer,
        max_epoch,
        min_lr,
        init_lr,
        decay_rate=1,
        warmup_start_lr=-1,
        warmup_steps=0,
        **kwargs
    ):
        self.optimizer = optimizer

        self.max_epoch = max_epoch
        self.min_lr = min_lr

        self.decay_rate = decay_rate

        self.init_lr = init_lr
        self.warmup_steps = warmup_steps
        # A negative warmup_start_lr means "start warmup at init_lr", i.e. no ramp.
        self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr

    def step(self, cur_epoch, cur_step):
        # Warmup applies only within the first epoch; afterwards, step decay takes over.
        if cur_epoch == 0:
            warmup_lr_schedule(
                step=cur_step,
                optimizer=self.optimizer,
                max_step=self.warmup_steps,
                init_lr=self.warmup_start_lr,
                max_lr=self.init_lr,
            )
        else:
            step_lr_schedule(
                epoch=cur_epoch,
                optimizer=self.optimizer,
                init_lr=self.init_lr,
                min_lr=self.min_lr,
                decay_rate=self.decay_rate,
            )
|
|
@registry.register_lr_scheduler("linear_warmup_cosine_lr")
class LinearWarmupCosineLRScheduler:
    """Linear warmup during the first epoch, then cosine decay to min_lr over max_epoch epochs."""

    def __init__(
        self,
        optimizer,
        max_epoch,
        min_lr,
        init_lr,
        warmup_steps=0,
        warmup_start_lr=-1,
        **kwargs
    ):
        self.optimizer = optimizer

        self.max_epoch = max_epoch
        self.min_lr = min_lr

        self.init_lr = init_lr
        self.warmup_steps = warmup_steps
        # A negative warmup_start_lr means "start warmup at init_lr", i.e. no ramp.
        self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr

    def step(self, cur_epoch, cur_step):
        # Assumes warmup lasts fewer iterations than one epoch: warmup runs only
        # while cur_epoch == 0, then cosine decay takes over.
        if cur_epoch == 0:
            warmup_lr_schedule(
                step=cur_step,
                optimizer=self.optimizer,
                max_step=self.warmup_steps,
                init_lr=self.warmup_start_lr,
                max_lr=self.init_lr,
            )
        else:
            cosine_lr_schedule(
                epoch=cur_epoch,
                optimizer=self.optimizer,
                max_epoch=self.max_epoch,
                init_lr=self.init_lr,
                min_lr=self.min_lr,
            )
|
|
def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr):
    """Decay the learning rate from init_lr to min_lr along a half-cosine curve."""
    lr = (init_lr - min_lr) * 0.5 * (
        1.0 + math.cos(math.pi * epoch / max_epoch)
    ) + min_lr
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr
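# Worked example with illustrative values (not from the original file): for
# init_lr=1e-4, min_lr=1e-6, max_epoch=10, at epoch 5 the cosine term is
# cos(pi * 5 / 10) = 0, so lr = 0.5 * (init_lr + min_lr) = 5.05e-05.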
|
|
def warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr):
    """Linearly warm up the learning rate from init_lr to max_lr over max_step steps."""
    lr = min(max_lr, init_lr + (max_lr - init_lr) * step / max(max_step, 1))
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr
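# Worked example with illustrative values: for init_lr=1e-6, max_lr=1e-4,
# max_step=1000, step 500 is halfway up the ramp, so
# lr = 1e-6 + (1e-4 - 1e-6) * 500 / 1000 = 5.05e-05.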
|
|
def step_lr_schedule(optimizer, epoch, init_lr, min_lr, decay_rate):
    """Decay the learning rate by decay_rate every epoch, floored at min_lr."""
    lr = max(min_lr, init_lr * (decay_rate**epoch))
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr
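

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). It drives the cosine
# scheduler with a hypothetical stub optimizer so the schedule can be printed
# without torch; a real optimizer (e.g. torch.optim.AdamW) exposes the same
# `param_groups` list of dicts that the helpers above mutate.
# ---------------------------------------------------------------------------
if __name__ == "__main__":

    class _StubOptimizer:
        """Hypothetical stand-in exposing only the param_groups interface."""

        def __init__(self):
            self.param_groups = [{"lr": 0.0}]

    opt = _StubOptimizer()
    sched = LinearWarmupCosineLRScheduler(
        optimizer=opt,
        max_epoch=10,
        min_lr=1e-6,
        init_lr=1e-4,
        warmup_steps=5,
        warmup_start_lr=1e-6,
    )
    for step in range(5):  # linear warmup within epoch 0
        sched.step(cur_epoch=0, cur_step=step)
        print(f"warmup step {step}: lr={opt.param_groups[0]['lr']:.2e}")
    for epoch in range(1, 10):  # cosine decay afterwards
        sched.step(cur_epoch=epoch, cur_step=0)
        print(f"epoch {epoch}: lr={opt.param_groups[0]['lr']:.2e}")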