# submission/baselines/common/schedules.py
"""This file is used for specifying various schedules that evolve over
time throughout the execution of the algorithm, such as:
 - learning rate for the optimizer
 - exploration epsilon for the epsilon greedy exploration strategy
 - beta parameter for prioritized replay

Each schedule has a function `value(t)` which returns the current value
of the parameter given the timestep t of the optimization procedure.
"""
10
11
12
class Schedule(object):
    """Interface for a parameter whose value depends on the timestep.

    Concrete schedules implement `value(t)`.
    """

    def value(self, t):
        """Value of the schedule at time t.

        Raises
        ------
        NotImplementedError
            Always; this base implementation has no value to return.
        """
        raise NotImplementedError()
16
17
18
class ConstantSchedule(object):
    """Schedule that returns the same value at every timestep."""

    def __init__(self, value):
        """Value remains constant over time.

        Parameters
        ----------
        value: float
            Constant value of the schedule
        """
        self._v = value

    def value(self, t):
        """Return the constant given at construction; `t` is ignored."""
        return self._v
32
33
34
def linear_interpolation(l, r, alpha):
    """Return the point a fraction `alpha` of the way from `l` to `r`.

    Computed as `l + alpha * (r - l)`, so alpha=0 yields `l`.
    """
    return l + alpha * (r - l)
36
37
38
class PiecewiseSchedule(object):
    """Schedule defined by interpolating between `(time, value)` endpoints."""

    def __init__(self, endpoints, interpolation=linear_interpolation, outside_value=None):
        """Piecewise schedule.

        Parameters
        ----------
        endpoints: [(int, int)]
            list of pairs `(time, value)` meaning that schedule should output
            `value` when `t==time`. All the values for time must be sorted in
            non-decreasing order. When t is between two times, e.g.
            `(time_a, value_a)` and `(time_b, value_b)`, such that
            `time_a <= t < time_b` then value outputs
            `interpolation(value_a, value_b, alpha)` where alpha is a fraction of
            time passed between `time_a` and `time_b` for time `t`.
        interpolation: lambda float, float, float: float
            a function that takes value to the left and to the right of t according
            to the `endpoints`. Alpha is the fraction of distance from left endpoint to
            right endpoint that t has covered. See linear_interpolation for example.
        outside_value: float
            if the value is requested outside of all the intervals specified in
            `endpoints` this value is returned. If None then AssertionError is
            raised when outside value is requested.

        Raises
        ------
        AssertionError
            If the endpoint times are not sorted.
        """
        # Keep a private copy so later mutation by the caller cannot bypass
        # the ordering check below.
        endpoints = list(endpoints)
        idxes = [e[0] for e in endpoints]
        # Raised explicitly (not via `assert`) so the check survives `python -O`.
        if idxes != sorted(idxes):
            raise AssertionError("endpoint times must be sorted in increasing order")
        self._interpolation = interpolation
        self._outside_value = outside_value
        self._endpoints = endpoints

    def value(self, t):
        """See Schedule.value

        Raises
        ------
        AssertionError
            If `t` lies outside every interval and no `outside_value` was given.
        """
        for (l_t, l), (r_t, r) in zip(self._endpoints[:-1], self._endpoints[1:]):
            # Half-open interval: a zero-width segment (l_t == r_t) never
            # matches, so the division below cannot be by zero.
            if l_t <= t < r_t:
                alpha = float(t - l_t) / (r_t - l_t)
                return self._interpolation(l, r, alpha)

        # t does not belong to any of the pieces.
        if self._outside_value is None:
            raise AssertionError(
                "t=%r is outside all schedule intervals and no outside_value was given" % (t,)
            )
        return self._outside_value
74
75
76
class LinearSchedule(object):
    """Schedule that anneals linearly from `initial_p` to `final_p`."""

    def __init__(self, schedule_timesteps, final_p, initial_p=1.0):
        """Linear interpolation between initial_p and final_p over
        schedule_timesteps. After this many timesteps pass final_p is
        returned.

        Parameters
        ----------
        schedule_timesteps: int
            Number of timesteps for which to linearly anneal initial_p
            to final_p
        initial_p: float
            initial output value
        final_p: float
            final output value
        """
        self.schedule_timesteps = schedule_timesteps
        self.final_p = final_p
        self.initial_p = initial_p

    def value(self, t):
        """See Schedule.value"""
        if self.schedule_timesteps <= 0:
            # Nothing to anneal over: treat the schedule as already finished
            # instead of dividing by zero.
            fraction = 1.0
        else:
            # Cap at 1.0 so the value stays at final_p once annealing is done.
            fraction = min(float(t) / self.schedule_timesteps, 1.0)
        return self.initial_p + fraction * (self.final_p - self.initial_p)