Switch to unified view

a b/baselines/common/running_stat.py
1
import numpy as np
2
3
# http://www.johndcook.com/blog/standard_deviation/
4
class RunningStat(object):
5
    def __init__(self, shape):
6
        self._n = 0
7
        self._M = np.zeros(shape)
8
        self._S = np.zeros(shape)
9
    def push(self, x):
10
        x = np.asarray(x)
11
        assert x.shape == self._M.shape
12
        self._n += 1
13
        if self._n == 1:
14
            self._M[...] = x
15
        else:
16
            oldM = self._M.copy()
17
            self._M[...] = oldM + (x - oldM)/self._n
18
            self._S[...] = self._S + (x - oldM)*(x - self._M)
19
    @property
20
    def n(self):
21
        return self._n
22
    @property
23
    def mean(self):
24
        return self._M
25
    @property
26
    def var(self):
27
        return self._S/(self._n - 1) if self._n > 1 else np.square(self._M)
28
    @property
29
    def std(self):
30
        return np.sqrt(self.var)
31
    @property
32
    def shape(self):
33
        return self._M.shape
34
35
def test_running_stat():
36
    for shp in ((), (3,), (3,4)):
37
        li = []
38
        rs = RunningStat(shp)
39
        for _ in range(5):
40
            val = np.random.randn(*shp)
41
            rs.push(val)
42
            li.append(val)
43
            m = np.mean(li, axis=0)
44
            assert np.allclose(rs.mean, m)
45
            v = np.square(m) if (len(li) == 1) else np.var(li, ddof=1, axis=0)
46
            assert np.allclose(rs.var, v)