Diff of /pathflowai/schedulers.py [000000] .. [e9500f]

b/pathflowai/schedulers.py
"""
schedulers.py
=======================
Modulates the learning rate during the training process.
"""
import torch
import math
from torch.optim.lr_scheduler import ExponentialLR

class CosineAnnealingWithRestartsLR(torch.optim.lr_scheduler._LRScheduler):
    r"""Set the learning rate of each parameter group using a cosine annealing
    schedule, where :math:`\eta_{max}` is set to the initial lr and
    :math:`T_{cur}` is the number of epochs since the last restart in SGDR:

    .. math::

        \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})
        (1 + \cos(\frac{T_{cur}}{T_{max}}\pi))

    When last_epoch=-1, sets initial lr as lr.

    It has been proposed in
    `SGDR: Stochastic Gradient Descent with Warm Restarts`_. This implements
    the cosine annealing part of SGDR, as well as the restarts and the
    multiplier applied to the annealing period after each restart.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        T_max (int): Maximum number of iterations (length of the first annealing period).
        T_mult (float): Multiply T_max by this number after each restart. Default: 1.
        eta_min (float): Minimum learning rate. Default: 0.
        alpha_decay (float): Multiply the peak learning rate by this factor after each restart. Default: 1.0.
        last_epoch (int): The index of the last epoch. Default: -1.

    .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
        https://arxiv.org/abs/1608.03983
    """
    def __init__(self, optimizer, T_max, eta_min=0, last_epoch=-1, T_mult=1., alpha_decay=1.0):
        self.T_max = T_max
        self.T_mult = T_mult
        self.restart_every = T_max
        self.eta_min = eta_min
        self.restarts = 0
        self.restarted_at = 0
        self.alpha = alpha_decay
        super().__init__(optimizer, last_epoch)

    def restart(self):
        # Begin a new annealing period: count the restart, lengthen the period by
        # T_mult, and remember which epoch the new period started at.
        self.restarts += 1
        self.restart_every = int(round(self.restart_every * self.T_mult))
        self.restarted_at = self.last_epoch

    def cosine(self, base_lr):
        # Cosine-annealed learning rate for the current period; the peak is damped
        # by alpha**restarts after each restart.
        return self.eta_min + self.alpha**self.restarts * (base_lr - self.eta_min) * (1 + math.cos(math.pi * self.step_n / self.restart_every)) / 2

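    # Worked example of cosine() with illustrative values (base_lr=1e-2, eta_min=0,
    # alpha_decay=1.0, restart_every=10; these numbers are not from the original code):
    # step_n=0 gives 1e-2 (cos(0)=1), step_n=5 gives 5e-3 (cos(pi/2)=0), and the rate
    # approaches eta_min as step_n nears 10, where get_lr() triggers restart().
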
    @property
    def step_n(self):
        # Number of epochs elapsed since the last restart.
        return self.last_epoch - self.restarted_at

    def get_lr(self):
        # Restart once the current annealing period is exhausted, then anneal each base LR.
        if self.step_n >= self.restart_every:
            self.restart()
        return [self.cosine(base_lr) for base_lr in self.base_lrs]

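# Illustrative sketch, not part of the original module: how the warm-restart
# schedule above behaves around a toy optimizer. The model, learning rate, and
# epoch count are arbitrary values chosen only for demonstration.
#
#     model = torch.nn.Linear(10, 2)
#     optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
#     scheduler = CosineAnnealingWithRestartsLR(optimizer, T_max=10, eta_min=1e-6,
#                                               T_mult=2, alpha_decay=0.5)
#     for epoch in range(70):
#         # ... run one training epoch ...
#         scheduler.step()
#
# With T_max=10 and T_mult=2 the annealing periods last 10, 20, and 40 epochs;
# within each period the learning rate falls from roughly alpha_decay**restarts * 1e-2
# toward eta_min along the cosine curve, then jumps back up at the restart.
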
class Scheduler:
    """Scheduler class that modulates the learning rate of torch optimizers over epochs.

    Parameters
    ----------
    optimizer : torch.optim.Optimizer
        Wrapped optimizer.
    opts : dict
        Options for setting up the learning rate scheduler; see the default value
        for the expected keys.

    Attributes
    ----------
    schedulers : dict
        Constructors for the available scheduler types.
    scheduler_step_fn : dict
        How each scheduler updates the learning rate.
    initial_lr : float
        Initial learning rate taken from the optimizer.
    scheduler_choice : str
        Which scheduler type was chosen.
    scheduler : torch.optim.lr_scheduler._LRScheduler
        Scheduler object that directly updates the optimizer's learning rate.

    """
    def __init__(self, optimizer=None, opts=dict(scheduler='null', lr_scheduler_decay=0.5, T_max=10, eta_min=5e-8, T_mult=2)):
        self.schedulers = {'exp': (lambda optimizer: ExponentialLR(optimizer, opts["lr_scheduler_decay"])),
                           'null': (lambda optimizer: None),
                           'warm_restarts': (lambda optimizer: CosineAnnealingWithRestartsLR(optimizer, T_max=opts['T_max'], eta_min=opts['eta_min'], last_epoch=-1, T_mult=opts['T_mult']))}
        self.scheduler_step_fn = {'exp': (lambda scheduler: scheduler.step()),
                                  'warm_restarts': (lambda scheduler: scheduler.step()),
                                  'null': (lambda scheduler: None)}
        self.initial_lr = optimizer.param_groups[0]['lr'] if optimizer is not None else None
        self.scheduler_choice = opts['scheduler']
        self.scheduler = self.schedulers[self.scheduler_choice](optimizer) if optimizer is not None else None

    def step(self):
        """Update the optimizer's learning rate."""
        self.scheduler_step_fn[self.scheduler_choice](self.scheduler)

    def get_lr(self):
        """Return the current learning rate.

        Returns
        -------
        float
            Current learning rate.

        """
        lr = (self.initial_lr if self.scheduler_choice == 'null' else self.scheduler.optimizer.param_groups[0]['lr'])
        return lr
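
if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module: wrap a toy SGD
    # optimizer with each scheduler option and watch the learning rate evolve.
    # The model, base learning rate, and epoch count are arbitrary choices made
    # only for illustration.
    model = torch.nn.Linear(10, 2)
    for choice in ('null', 'exp', 'warm_restarts'):
        optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
        scheduler = Scheduler(optimizer=optimizer,
                              opts=dict(scheduler=choice, lr_scheduler_decay=0.5,
                                        T_max=10, eta_min=5e-8, T_mult=2))
        lrs = []
        for epoch in range(30):
            optimizer.step()  # a real run would compute a loss and call backward() first
            scheduler.step()  # 'null' keeps the LR fixed, 'exp' halves it each epoch,
                              # 'warm_restarts' follows the cosine/restart schedule
            lrs.append(scheduler.get_lr())
        print(choice, ' '.join('{:.2e}'.format(lr) for lr in lrs[:5]), '...')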