--- a +++ b/baselines/common/mpi_moments.py @@ -0,0 +1,60 @@ +from mpi4py import MPI +import numpy as np +from baselines.common import zipsame + + +def mpi_mean(x, axis=0, comm=None, keepdims=False): + x = np.asarray(x) + assert x.ndim > 0 + if comm is None: comm = MPI.COMM_WORLD + xsum = x.sum(axis=axis, keepdims=keepdims) + n = xsum.size + localsum = np.zeros(n+1, x.dtype) + localsum[:n] = xsum.ravel() + localsum[n] = x.shape[axis] + globalsum = np.zeros_like(localsum) + comm.Allreduce(localsum, globalsum, op=MPI.SUM) + return globalsum[:n].reshape(xsum.shape) / globalsum[n], globalsum[n] + +def mpi_moments(x, axis=0, comm=None, keepdims=False): + x = np.asarray(x) + assert x.ndim > 0 + mean, count = mpi_mean(x, axis=axis, comm=comm, keepdims=True) + sqdiffs = np.square(x - mean) + meansqdiff, count1 = mpi_mean(sqdiffs, axis=axis, comm=comm, keepdims=True) + assert count1 == count + std = np.sqrt(meansqdiff) + if not keepdims: + newshape = mean.shape[:axis] + mean.shape[axis+1:] + mean = mean.reshape(newshape) + std = std.reshape(newshape) + return mean, std, count + + +def test_runningmeanstd(): + import subprocess + subprocess.check_call(['mpirun', '-np', '3', + 'python','-c', + 'from baselines.common.mpi_moments import _helper_runningmeanstd; _helper_runningmeanstd()']) + +def _helper_runningmeanstd(): + comm = MPI.COMM_WORLD + np.random.seed(0) + for (triple,axis) in [ + ((np.random.randn(3), np.random.randn(4), np.random.randn(5)),0), + ((np.random.randn(3,2), np.random.randn(4,2), np.random.randn(5,2)),0), + ((np.random.randn(2,3), np.random.randn(2,4), np.random.randn(2,4)),1), + ]: + + + x = np.concatenate(triple, axis=axis) + ms1 = [x.mean(axis=axis), x.std(axis=axis), x.shape[axis]] + + + ms2 = mpi_moments(triple[comm.Get_rank()],axis=axis) + + for (a1,a2) in zipsame(ms1, ms2): + print(a1, a2) + assert np.allclose(a1, a2) + print("ok!") +