|
a |
|
b/baselines/common/running_stat.py |
|
|
1 |
import numpy as np |
|
|
2 |
|
|
|
3 |
# http://www.johndcook.com/blog/standard_deviation/ |
|
|
4 |
class RunningStat(object): |
|
|
5 |
def __init__(self, shape): |
|
|
6 |
self._n = 0 |
|
|
7 |
self._M = np.zeros(shape) |
|
|
8 |
self._S = np.zeros(shape) |
|
|
9 |
def push(self, x): |
|
|
10 |
x = np.asarray(x) |
|
|
11 |
assert x.shape == self._M.shape |
|
|
12 |
self._n += 1 |
|
|
13 |
if self._n == 1: |
|
|
14 |
self._M[...] = x |
|
|
15 |
else: |
|
|
16 |
oldM = self._M.copy() |
|
|
17 |
self._M[...] = oldM + (x - oldM)/self._n |
|
|
18 |
self._S[...] = self._S + (x - oldM)*(x - self._M) |
|
|
19 |
@property |
|
|
20 |
def n(self): |
|
|
21 |
return self._n |
|
|
22 |
@property |
|
|
23 |
def mean(self): |
|
|
24 |
return self._M |
|
|
25 |
@property |
|
|
26 |
def var(self): |
|
|
27 |
return self._S/(self._n - 1) if self._n > 1 else np.square(self._M) |
|
|
28 |
@property |
|
|
29 |
def std(self): |
|
|
30 |
return np.sqrt(self.var) |
|
|
31 |
@property |
|
|
32 |
def shape(self): |
|
|
33 |
return self._M.shape |
|
|
34 |
|
|
|
35 |
def test_running_stat(): |
|
|
36 |
for shp in ((), (3,), (3,4)): |
|
|
37 |
li = [] |
|
|
38 |
rs = RunningStat(shp) |
|
|
39 |
for _ in range(5): |
|
|
40 |
val = np.random.randn(*shp) |
|
|
41 |
rs.push(val) |
|
|
42 |
li.append(val) |
|
|
43 |
m = np.mean(li, axis=0) |
|
|
44 |
assert np.allclose(rs.mean, m) |
|
|
45 |
v = np.square(m) if (len(li) == 1) else np.var(li, ddof=1, axis=0) |
|
|
46 |
assert np.allclose(rs.var, v) |