Switch to side-by-side view

--- a
+++ b/baselines/common/schedules.py
@@ -0,0 +1,99 @@
+"""This file is used for specifying various schedules that evolve over
+time throughout the execution of the algorithm, such as:
+ - learning rate for the optimizer
+ - exploration epsilon for the epsilon greedy exploration strategy
+ - beta parameter for beta parameter in prioritized replay
+
+Each schedule has a function `value(t)` which returns the current value
+of the parameter given the timestep t of the optimization procedure.
+"""
+
+
+class Schedule(object):
+    def value(self, t):
+        """Value of the schedule at time t"""
+        raise NotImplementedError()
+
+
+class ConstantSchedule(object):
+    def __init__(self, value):
+        """Value remains constant over time.
+
+        Parameters
+        ----------
+        value: float
+            Constant value of the schedule
+        """
+        self._v = value
+
+    def value(self, t):
+        """See Schedule.value"""
+        return self._v
+
+
+def linear_interpolation(l, r, alpha):
+    return l + alpha * (r - l)
+
+
+class PiecewiseSchedule(object):
+    def __init__(self, endpoints, interpolation=linear_interpolation, outside_value=None):
+        """Piecewise schedule.
+
+        endpoints: [(int, int)]
+            list of pairs `(time, value)` meanining that schedule should output
+            `value` when `t==time`. All the values for time must be sorted in
+            an increasing order. When t is between two times, e.g. `(time_a, value_a)`
+            and `(time_b, value_b)`, such that `time_a <= t < time_b` then value outputs
+            `interpolation(value_a, value_b, alpha)` where alpha is a fraction of
+            time passed between `time_a` and `time_b` for time `t`.
+        interpolation: lambda float, float, float: float
+            a function that takes value to the left and to the right of t according
+            to the `endpoints`. Alpha is the fraction of distance from left endpoint to
+            right endpoint that t has covered. See linear_interpolation for example.
+        outside_value: float
+            if the value is requested outside of all the intervals sepecified in
+            `endpoints` this value is returned. If None then AssertionError is
+            raised when outside value is requested.
+        """
+        idxes = [e[0] for e in endpoints]
+        assert idxes == sorted(idxes)
+        self._interpolation = interpolation
+        self._outside_value = outside_value
+        self._endpoints = endpoints
+
+    def value(self, t):
+        """See Schedule.value"""
+        for (l_t, l), (r_t, r) in zip(self._endpoints[:-1], self._endpoints[1:]):
+            if l_t <= t and t < r_t:
+                alpha = float(t - l_t) / (r_t - l_t)
+                return self._interpolation(l, r, alpha)
+
+        # t does not belong to any of the pieces, so doom.
+        assert self._outside_value is not None
+        return self._outside_value
+
+
+class LinearSchedule(object):
+    def __init__(self, schedule_timesteps, final_p, initial_p=1.0):
+        """Linear interpolation between initial_p and final_p over
+        schedule_timesteps. After this many timesteps pass final_p is
+        returned.
+
+        Parameters
+        ----------
+        schedule_timesteps: int
+            Number of timesteps for which to linearly anneal initial_p
+            to final_p
+        initial_p: float
+            initial output value
+        final_p: float
+            final output value
+        """
+        self.schedule_timesteps = schedule_timesteps
+        self.final_p = final_p
+        self.initial_p = initial_p
+
+    def value(self, t):
+        """See Schedule.value"""
+        fraction = min(float(t) / self.schedule_timesteps, 1.0)
+        return self.initial_p + fraction * (self.final_p - self.initial_p)