model/lavis/common/optims.py
"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import math

from model.lavis.common.registry import registry


@registry.register_lr_scheduler("linear_warmup_step_lr")
class LinearWarmupStepLRScheduler:
    """Linear warmup during the first epoch, then step decay of the learning
    rate by ``decay_rate`` once per epoch, floored at ``min_lr``."""

    def __init__(
        self,
        optimizer,
        max_epoch,
        min_lr,
        init_lr,
        decay_rate=1,
        warmup_start_lr=-1,
        warmup_steps=0,
        **kwargs
    ):
        self.optimizer = optimizer

        self.max_epoch = max_epoch
        self.min_lr = min_lr

        self.decay_rate = decay_rate

        self.init_lr = init_lr
        self.warmup_steps = warmup_steps
        # A negative warmup_start_lr means "start the warmup directly at init_lr".
        self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr

    def step(self, cur_epoch, cur_step):
        # Warmup is assumed to fit within the first epoch; afterwards the LR
        # is decayed once per epoch.
        if cur_epoch == 0:
            warmup_lr_schedule(
                step=cur_step,
                optimizer=self.optimizer,
                max_step=self.warmup_steps,
                init_lr=self.warmup_start_lr,
                max_lr=self.init_lr,
            )
        else:
            step_lr_schedule(
                epoch=cur_epoch,
                optimizer=self.optimizer,
                init_lr=self.init_lr,
                min_lr=self.min_lr,
                decay_rate=self.decay_rate,
            )


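# Illustration for LinearWarmupStepLRScheduler above (hypothetical numbers): with
# init_lr=1e-4, min_lr=1e-6 and decay_rate=0.1, epoch 0 linearly warms the LR up
# to 1e-4, then step_lr_schedule sets lr = max(1e-6, 1e-4 * 0.1**epoch), i.e.
# 1e-5 at epoch 1, 1e-6 at epoch 2, and the 1e-6 floor thereafter.

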
@registry.register_lr_scheduler("linear_warmup_cosine_lr")
class LinearWarmupCosineLRScheduler:
    """Linear warmup during the first epoch, then cosine decay of the learning
    rate from ``init_lr`` down to ``min_lr`` over ``max_epoch`` epochs."""

    def __init__(
        self,
        optimizer,
        max_epoch,
        min_lr,
        init_lr,
        warmup_steps=0,
        warmup_start_lr=-1,
        **kwargs
    ):
        self.optimizer = optimizer

        self.max_epoch = max_epoch
        self.min_lr = min_lr

        self.init_lr = init_lr
        self.warmup_steps = warmup_steps
        # A negative warmup_start_lr means "start the warmup directly at init_lr".
        self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr

    def step(self, cur_epoch, cur_step):
        # Warmup is assumed to take less than one epoch; afterwards the LR
        # follows the cosine decay.
        if cur_epoch == 0:
            warmup_lr_schedule(
                step=cur_step,
                optimizer=self.optimizer,
                max_step=self.warmup_steps,
                init_lr=self.warmup_start_lr,
                max_lr=self.init_lr,
            )
        else:
            cosine_lr_schedule(
                epoch=cur_epoch,
                optimizer=self.optimizer,
                max_epoch=self.max_epoch,
                init_lr=self.init_lr,
                min_lr=self.min_lr,
            )


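# Usage sketch for LinearWarmupCosineLRScheduler above (illustrative only; the
# names `optim` and `loader` and the epoch/step counts are hypothetical, and the
# lookup assumes the usual LAVIS `registry.get_lr_scheduler_class` helper):
#
#   sched_cls = registry.get_lr_scheduler_class("linear_warmup_cosine_lr")
#   scheduler = sched_cls(optimizer=optim, max_epoch=10, min_lr=1e-6,
#                         init_lr=1e-4, warmup_steps=500, warmup_start_lr=1e-7)
#   for epoch in range(10):
#       for step, batch in enumerate(loader):
#           scheduler.step(cur_epoch=epoch, cur_step=step)
#           ...  # forward / backward / optimizer.step()

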
def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr):
    """Cosine-decay the learning rate from init_lr (epoch 0) to min_lr (max_epoch)."""
    lr = (init_lr - min_lr) * 0.5 * (
        1.0 + math.cos(math.pi * epoch / max_epoch)
    ) + min_lr
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr


def warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr):
    """Linearly warm up the learning rate from init_lr to max_lr over max_step steps."""
    lr = min(max_lr, init_lr + (max_lr - init_lr) * step / max(max_step, 1))
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr


def step_lr_schedule(optimizer, epoch, init_lr, min_lr, decay_rate):
    """Decay the learning rate by decay_rate every epoch, with a floor of min_lr."""
    lr = max(min_lr, init_lr * (decay_rate**epoch))
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr
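

# Minimal smoke-test sketch (illustrative; all values are arbitrary). It exercises
# the schedule functions above with a dummy optimizer stand-in, which works because
# they only touch `optimizer.param_groups`.
if __name__ == "__main__":
    class _DummyOptimizer:
        def __init__(self, lr):
            self.param_groups = [{"lr": lr}]

    optim = _DummyOptimizer(lr=0.0)

    # Linear warmup: lr should grow from 1e-7 towards 1e-4 over 5 steps.
    for step in range(5):
        warmup_lr_schedule(optim, step=step, max_step=5, init_lr=1e-7, max_lr=1e-4)
        print(f"warmup step {step}: lr={optim.param_groups[0]['lr']:.2e}")

    # Cosine decay: lr should fall from 1e-4 at epoch 0 to 1e-6 at epoch 10.
    for epoch in range(11):
        cosine_lr_schedule(optim, epoch=epoch, max_epoch=10, init_lr=1e-4, min_lr=1e-6)
        print(f"cosine epoch {epoch}: lr={optim.param_groups[0]['lr']:.2e}")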