[f9c9f2]: / baselines / common / mpi_running_mean_std.py

Download this file

108 lines (87 with data), 3.7 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
from mpi4py import MPI
import tensorflow as tf, baselines.common.tf_util as U, numpy as np
class RunningMeanStd(object):
    """Running mean and standard deviation aggregated over MPI workers.

    The running sum, sum of squares and sample count are stored in float64
    TensorFlow variables; ``mean`` and ``std`` are float32 tensors derived
    from them.  ``update`` sums the per-process batch statistics across
    ``MPI.COMM_WORLD`` and adds the totals to the variables.

    The variables are created with ``tf.get_variable`` under the fixed names
    "runningsum", "runningsumsq" and "count", so each instance should be
    built inside its own ``tf.variable_scope``.

    Reference:
    https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
    """

    def __init__(self, epsilon=1e-2, shape=()):
        """Build the accumulator variables, derived tensors and update op.

        Args:
            epsilon: Initial value of both the sum-of-squares and the count
                variables.  With ``epsilon=0.0`` the mean/std are 0/0 until
                the first ``update``.
            shape: Shape of the per-sample statistics (the shape of
                ``x.sum(axis=0)`` for batches passed to ``update``).
        """
        self._sum = tf.get_variable(
            dtype=tf.float64,
            shape=shape,
            initializer=tf.constant_initializer(0.0),
            name="runningsum", trainable=False)
        self._sumsq = tf.get_variable(
            dtype=tf.float64,
            shape=shape,
            initializer=tf.constant_initializer(epsilon),
            name="runningsumsq", trainable=False)
        self._count = tf.get_variable(
            dtype=tf.float64,
            shape=(),
            initializer=tf.constant_initializer(epsilon),
            name="count", trainable=False)
        self.shape = shape

        # Do the E[x^2] - E[x]^2 subtraction in float64 and only then cast.
        # Casting each term to float32 first throws away the precision the
        # float64 accumulators exist to provide and makes the difference
        # unreliable whenever |mean| is large relative to the std.
        mean64 = self._sum / self._count
        var64 = self._sumsq / self._count - tf.square(mean64)
        self.mean = tf.to_float(mean64)
        # The variance is floored at 1e-2 before the square root.
        self.std = tf.sqrt(tf.maximum(tf.to_float(var64), 1e-2))

        newsum = tf.placeholder(shape=self.shape, dtype=tf.float64, name='sum')
        newsumsq = tf.placeholder(shape=self.shape, dtype=tf.float64, name='var')
        newcount = tf.placeholder(shape=[], dtype=tf.float64, name='count')
        # Callable that adds the given increments to the three accumulators.
        self.incfiltparams = U.function(
            [newsum, newsumsq, newcount], [],
            updates=[tf.assign_add(self._sum, newsum),
                     tf.assign_add(self._sumsq, newsumsq),
                     tf.assign_add(self._count, newcount)])

    def update(self, x):
        """Fold a batch of samples into the running statistics.

        The local sum, sum of squares and row count are packed into one
        vector, summed over all processes in ``MPI.COMM_WORLD`` and the
        totals are added to the TensorFlow variables.  ``Allreduce`` is a
        collective call, so every rank must call ``update`` together.

        Args:
            x: Array-like batch whose first axis indexes samples; it is
                converted to a float64 NumPy array.
        """
        x = np.asarray(x, dtype='float64')
        n = int(np.prod(self.shape))
        # Layout: [sum (n values), sum of squares (n values), count (1 value)]
        addvec = np.concatenate([
            x.sum(axis=0).ravel(),
            np.square(x).sum(axis=0).ravel(),
            np.array([len(x)], dtype='float64'),
        ])
        totalvec = np.zeros(n * 2 + 1, 'float64')
        MPI.COMM_WORLD.Allreduce(addvec, totalvec, op=MPI.SUM)
        self.incfiltparams(
            totalvec[0:n].reshape(self.shape),
            totalvec[n:2 * n].reshape(self.shape),
            totalvec[2 * n])
@U.in_session
def test_runningmeanstd():
    """Check RunningMeanStd against NumPy for scalar and 2-vector samples.

    Three batches are fed one at a time and the resulting mean/std are
    compared with the statistics of the concatenated data.
    """
    cases = [
        (np.random.randn(3), np.random.randn(4), np.random.randn(5)),
        (np.random.randn(3, 2), np.random.randn(4, 2), np.random.randn(5, 2)),
    ]
    for (x1, x2, x3) in cases:
        # RunningMeanStd creates its variables with tf.get_variable under
        # fixed names, so a second instance in the same scope would fail
        # with "Variable ... already exists".  A default_name scope is
        # uniquified by TensorFlow, giving every instance its own namespace.
        with tf.variable_scope(None, default_name="test_runningmeanstd"):
            rms = RunningMeanStd(epsilon=0.0, shape=x1.shape[1:])
        U.initialize()

        x = np.concatenate([x1, x2, x3], axis=0)
        ms1 = [x.mean(axis=0), x.std(axis=0)]
        rms.update(x1)
        rms.update(x2)
        rms.update(x3)
        ms2 = [rms.mean.eval(), rms.std.eval()]

        assert np.allclose(ms1, ms2)
@U.in_session
def test_dist():
    """Two-process check that statistics are pooled across MPI ranks.

    Both ranks seed NumPy identically and draw the same six batches; rank 0
    feeds the first three and rank 1 the last three.  The resulting
    mean/std on each rank must match the statistics of all six batches
    concatenated.
    """
    np.random.seed(0)
    # Batches are column vectors of shape (rows, 1), matching shape=(1,) below.
    rank0_batches = [np.random.randn(rows, 1) for rows in (3, 4, 5)]
    rank1_batches = [np.random.randn(rows, 1) for rows in (6, 7, 8)]

    world = MPI.COMM_WORLD
    assert world.Get_size()==2
    rank = world.Get_rank()
    if rank == 0:
        local_batches = rank0_batches
    elif rank == 1:
        local_batches = rank1_batches
    else:
        assert False

    rms = RunningMeanStd(epsilon=0.0, shape=(1,))
    U.initialize()

    for batch in local_batches:
        rms.update(batch)

    everything = np.concatenate(rank0_batches + rank1_batches)

    def _report_close(expected, actual):
        # Print both values so a failing run shows what was compared.
        print(expected, actual)
        return np.allclose(expected, actual)

    assert _report_close(everything.mean(axis=0), rms.mean.eval())
    assert _report_close(everything.std(axis=0), rms.std.eval())
if __name__ == "__main__":
    # test_dist asserts exactly two MPI ranks; launch with:
    #   mpirun -np 2 python <filename>
    test_dist()