Switch to side-by-side view

--- a
+++ b/model/lavis/common/optims.py
@@ -0,0 +1,117 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+import math
+
+from model.lavis.common.registry import registry
+
+
+@registry.register_lr_scheduler("linear_warmup_step_lr")
+class LinearWarmupStepLRScheduler:
+    def __init__(
+        self,
+        optimizer,
+        max_epoch,
+        min_lr,
+        init_lr,
+        decay_rate=1,
+        warmup_start_lr=-1,
+        warmup_steps=0,
+        **kwargs
+    ):
+        self.optimizer = optimizer
+
+        self.max_epoch = max_epoch
+        self.min_lr = min_lr
+
+        self.decay_rate = decay_rate
+
+        self.init_lr = init_lr
+        self.warmup_steps = warmup_steps
+        self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr
+
+    def step(self, cur_epoch, cur_step):
+        if cur_epoch == 0:
+            warmup_lr_schedule(
+                step=cur_step,
+                optimizer=self.optimizer,
+                max_step=self.warmup_steps,
+                init_lr=self.warmup_start_lr,
+                max_lr=self.init_lr,
+            )
+        else:
+            step_lr_schedule(
+                epoch=cur_epoch,
+                optimizer=self.optimizer,
+                init_lr=self.init_lr,
+                min_lr=self.min_lr,
+                decay_rate=self.decay_rate,
+            )
+
+
+@registry.register_lr_scheduler("linear_warmup_cosine_lr")
+class LinearWarmupCosineLRScheduler:
+    def __init__(
+        self,
+        optimizer,
+        max_epoch,
+        min_lr,
+        init_lr,
+        warmup_steps=0,
+        warmup_start_lr=-1,
+        **kwargs
+    ):
+        self.optimizer = optimizer
+
+        self.max_epoch = max_epoch
+        self.min_lr = min_lr
+
+        self.init_lr = init_lr
+        self.warmup_steps = warmup_steps
+        self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr
+
+    def step(self, cur_epoch, cur_step):
+        # assuming the warmup iters less than one epoch
+        if cur_epoch == 0:
+            warmup_lr_schedule(
+                step=cur_step,
+                optimizer=self.optimizer,
+                max_step=self.warmup_steps,
+                init_lr=self.warmup_start_lr,
+                max_lr=self.init_lr,
+            )
+        else:
+            cosine_lr_schedule(
+                epoch=cur_epoch,
+                optimizer=self.optimizer,
+                max_epoch=self.max_epoch,
+                init_lr=self.init_lr,
+                min_lr=self.min_lr,
+            )
+
+
+def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr):
+    """Decay the learning rate"""
+    lr = (init_lr - min_lr) * 0.5 * (
+        1.0 + math.cos(math.pi * epoch / max_epoch)
+    ) + min_lr
+    for param_group in optimizer.param_groups:
+        param_group["lr"] = lr
+
+
+def warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr):
+    """Warmup the learning rate"""
+    lr = min(max_lr, init_lr + (max_lr - init_lr) * step / max(max_step, 1))
+    for param_group in optimizer.param_groups:
+        param_group["lr"] = lr
+
+
+def step_lr_schedule(optimizer, epoch, init_lr, min_lr, decay_rate):
+    """Decay the learning rate"""
+    lr = max(min_lr, init_lr * (decay_rate**epoch))
+    for param_group in optimizer.param_groups:
+        param_group["lr"] = lr