--- /dev/null
+++ b/baselines/common/mpi_adam.py
@@ -0,0 +1,79 @@
+from mpi4py import MPI
+import baselines.common.tf_util as U
+import tensorflow as tf
+import numpy as np
+
+class MpiAdam(object):
+    """Adam optimizer that sums (optionally averages) gradients across MPI workers before each step."""
+    def __init__(self, var_list, *, beta1=0.9, beta2=0.999, epsilon=1e-08, scale_grad_by_procs=True, comm=None):
+        self.var_list = var_list
+        self.beta1 = beta1
+        self.beta2 = beta2
+        self.epsilon = epsilon
+        self.scale_grad_by_procs = scale_grad_by_procs
+        size = sum(U.numel(v) for v in var_list)
+        # Flat first- and second-moment accumulators spanning all variables.
+        self.m = np.zeros(size, 'float32')
+        self.v = np.zeros(size, 'float32')
+        self.t = 0
+        # TF ops to read/write the variables as a single flat vector.
+        self.setfromflat = U.SetFromFlat(var_list)
+        self.getflat = U.GetFlat(var_list)
+        self.comm = MPI.COMM_WORLD if comm is None else comm
+
+    def update(self, localg, stepsize):
+        # Periodically verify that all workers still hold identical parameters.
+        if self.t % 100 == 0:
+            self.check_synced()
+        localg = localg.astype('float32')
+        globalg = np.zeros_like(localg)
+        # Sum the local gradients across all workers, optionally averaging
+        # so the effective step size is independent of the number of processes.
+        self.comm.Allreduce(localg, globalg, op=MPI.SUM)
+        if self.scale_grad_by_procs:
+            globalg /= self.comm.Get_size()
+
+        self.t += 1
+        # Adam step in the "efficient" form from Kingma & Ba (2015): the bias
+        # corrections for m and v are folded into a single step-size factor a,
+        # with self.epsilon playing the role of epsilon-hat.
+        a = stepsize * np.sqrt(1 - self.beta2**self.t) / (1 - self.beta1**self.t)
+        self.m = self.beta1 * self.m + (1 - self.beta1) * globalg
+        self.v = self.beta2 * self.v + (1 - self.beta2) * (globalg * globalg)
+        step = (-a) * self.m / (np.sqrt(self.v) + self.epsilon)
+        self.setfromflat(self.getflat() + step)
+
+    def sync(self):
+        # Broadcast rank 0's parameters so every worker starts from the same point.
+        theta = self.getflat()
+        self.comm.Bcast(theta, root=0)
+        self.setfromflat(theta)
+
+    def check_synced(self):
+        # Root broadcasts its parameters; every other worker asserts equality.
+        if self.comm.Get_rank() == 0:
+            theta = self.getflat()
+            self.comm.Bcast(theta, root=0)
+        else:
+            thetalocal = self.getflat()
+            thetaroot = np.empty_like(thetalocal)
+            self.comm.Bcast(thetaroot, root=0)
+            assert (thetaroot == thetalocal).all(), (thetaroot, thetalocal)
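+
+# Usage sketch (hypothetical training loop; `lossandgrad` stands for any
+# U.function returning (loss, flat_gradient), as in the test below):
+#
+#   adam = MpiAdam(var_list)
+#   adam.sync()                      # start all workers from rank 0's params
+#   for batch in batches:
+#       loss, g = lossandgrad(*batch)
+#       adam.update(g, stepsize)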
+
+@U.in_session
+def test_MpiAdam():
+    # Sanity check: tf.train.AdamOptimizer and MpiAdam, run on the same toy
+    # problem from the same seed, should print (near-)identical loss sequences.
+    np.random.seed(0)
+    tf.set_random_seed(0)
+
+    a = tf.Variable(np.random.randn(3).astype('float32'))
+    b = tf.Variable(np.random.randn(2,5).astype('float32'))
+    loss = tf.reduce_sum(tf.square(a)) + tf.reduce_sum(tf.sin(b))
+
+    stepsize = 1e-2
+    update_op = tf.train.AdamOptimizer(stepsize).minimize(loss)
+    do_update = U.function([], loss, updates=[update_op])
+
+    tf.get_default_session().run(tf.global_variables_initializer())
+    for i in range(10):
+        print(i, do_update())
+
+    # Re-seed and re-initialize so the MpiAdam run starts from identical state.
+    tf.set_random_seed(0)
+    tf.get_default_session().run(tf.global_variables_initializer())
+
+    var_list = [a, b]
+    # No TF-side update op here: MpiAdam applies the step itself, so running
+    # update_op as well would take two Adam steps per iteration.
+    lossandgrad = U.function([], [loss, U.flatgrad(loss, var_list)], updates=[])
+    adam = MpiAdam(var_list)
+
+    for i in range(10):
+        l, g = lossandgrad()
+        adam.update(g, stepsize)
+        print(i, l)
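+
+if __name__ == '__main__':
+    # Run the check directly; under MPI this would be something like
+    # `mpirun -np 2 python -m baselines.common.mpi_adam` (assuming the
+    # package is on PYTHONPATH).
+    test_MpiAdam()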