--- a
+++ b/baselines/common/tf_util.py
@@ -0,0 +1,281 @@
+import numpy as np
+import tensorflow as tf  # pylint: ignore-module
+import copy
+import os
+import functools
+import collections
+import multiprocessing
+
+def switch(condition, then_expression, else_expression):
+    """Switches between two operations depending on a scalar value (int or bool).
+    Note that both `then_expression` and `else_expression`
+    should be symbolic tensors of the *same shape*.
+
+    # Arguments
+        condition: scalar tensor.
+        then_expression: TensorFlow operation.
+        else_expression: TensorFlow operation.
+    """
+    x_shape = copy.copy(then_expression.get_shape())
+    x = tf.cond(tf.cast(condition, 'bool'),
+                lambda: then_expression,
+                lambda: else_expression)
+    x.set_shape(x_shape)
+    return x
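+
+# A minimal usage sketch for `switch` (illustrative only; assumes a default
+# session is installed, e.g. via single_threaded_session defined below):
+#
+#     cond = tf.placeholder(tf.bool, (), name="cond")
+#     out = switch(cond, tf.constant(1.0), tf.constant(-1.0))
+#     with single_threaded_session():
+#         assert out.eval(feed_dict={cond: True}) == 1.0
+#         assert out.eval(feed_dict={cond: False}) == -1.0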
+
+# ================================================================
+# Extras
+# ================================================================
+
+def lrelu(x, leak=0.2):
+    f1 = 0.5 * (1 + leak)
+    f2 = 0.5 * (1 - leak)
+    return f1 * x + f2 * abs(x)
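+
+# Note: for leak in [0, 1], f1 * x + f2 * abs(x) equals max(x, leak * x):
+# for x >= 0 it reduces to x, and for x < 0 it reduces to leak * x.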
+
+# ================================================================
+# Mathematical utils
+# ================================================================
+
+def huber_loss(x, delta=1.0):
+    """Reference: https://en.wikipedia.org/wiki/Huber_loss"""
+    return tf.where(
+        tf.abs(x) < delta,
+        tf.square(x) * 0.5,
+        delta * (tf.abs(x) - 0.5 * delta)
+    )
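+
+# Worked values at delta=1.0: the loss is 0.5 * x^2 inside |x| < delta
+# (x=0.5 -> 0.125) and delta * (|x| - 0.5 * delta) outside (x=2.0 -> 1.5);
+# the two branches meet at |x| = delta, where both give 0.5 * delta^2.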
+
+# ================================================================
+# Global session
+# ================================================================
+
+def make_session(num_cpu=None, make_default=False, graph=None):
+    """Returns a session that will use <num_cpu> CPU's only"""
+    if num_cpu is None:
+        num_cpu = int(os.getenv('RCALL_NUM_CPU', multiprocessing.cpu_count()))
+    tf_config = tf.ConfigProto(
+        inter_op_parallelism_threads=num_cpu,
+        intra_op_parallelism_threads=num_cpu)
+    tf_config.gpu_options.allocator_type = 'BFC'
+    if make_default:
+        return tf.InteractiveSession(config=tf_config, graph=graph)
+    else:
+        return tf.Session(config=tf_config, graph=graph)
+
+def single_threaded_session():
+    """Returns a session which will only use a single CPU"""
+    return make_session(num_cpu=1)
+
+def in_session(f):
+    @functools.wraps(f)
+    def newfunc(*args, **kwargs):
+        with tf.Session():
+            return f(*args, **kwargs)
+    return newfunc
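+
+# A usage sketch for the `in_session` decorator (`demo` is illustrative): the
+# wrapped function runs inside a fresh default session that is closed on exit.
+#
+#     @in_session
+#     def demo():
+#         initialize()
+#         ...  # run ops against tf.get_default_session()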
+
+ALREADY_INITIALIZED = set()
+
+def initialize():
+    """Initialize all the uninitialized variables in the global scope."""
+    new_variables = set(tf.global_variables()) - ALREADY_INITIALIZED
+    tf.get_default_session().run(tf.variables_initializer(new_variables))
+    ALREADY_INITIALIZED.update(new_variables)
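+
+# Because initialized variables are tracked in ALREADY_INITIALIZED, calling
+# initialize() again after adding ops to the graph only initializes the new
+# variables and leaves already-trained values untouched.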
+
+# ================================================================
+# Model components
+# ================================================================
+
+def normc_initializer(std=1.0, axis=0):
+    def _initializer(shape, dtype=None, partition_info=None):  # pylint: disable=W0613
+        out = np.random.randn(*shape).astype(np.float32)
+        out *= std / np.sqrt(np.square(out).sum(axis=axis, keepdims=True))
+        return tf.constant(out)
+    return _initializer
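+
+# normc_initializer rescales Gaussian draws so each slice along `axis` has L2
+# norm `std` (with axis=0, each column of a dense weight matrix). A sketch:
+#
+#     w = tf.get_variable("w", shape=(64, 32), initializer=normc_initializer(1.0))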
+
+def conv2d(x, num_filters, name, filter_size=(3, 3), stride=(1, 1), pad="SAME", dtype=tf.float32, collections=None,
+           summary_tag=None):
+    with tf.variable_scope(name):
+        stride_shape = [1, stride[0], stride[1], 1]
+        filter_shape = [filter_size[0], filter_size[1], int(x.get_shape()[3]), num_filters]
+
+        # there are "num input feature maps * filter height * filter width"
+        # inputs to each hidden unit
+        fan_in = intprod(filter_shape[:3])
+        # each unit in the lower layer receives a gradient from:
+        # "num output feature maps * filter height * filter width" /
+        #   pooling size
+        fan_out = intprod(filter_shape[:2]) * num_filters
+        # Glorot/Xavier uniform bound: sqrt(6 / (fan_in + fan_out))
+        w_bound = np.sqrt(6. / (fan_in + fan_out))
+
+        w = tf.get_variable("W", filter_shape, dtype, tf.random_uniform_initializer(-w_bound, w_bound),
+                            collections=collections)
+        b = tf.get_variable("b", [1, 1, 1, num_filters], initializer=tf.zeros_initializer(),
+                            collections=collections)
+
+        if summary_tag is not None:
+            tf.summary.image(summary_tag,
+                             tf.transpose(tf.reshape(w, [filter_size[0], filter_size[1], -1, 1]),
+                                          [2, 0, 1, 3]),
+                             max_outputs=10)
+
+        return tf.nn.conv2d(x, w, stride_shape, pad) + b
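+
+# A hypothetical usage sketch for conv2d ("obs" and "conv1" are illustrative):
+# maps an NHWC batch of 84x84x4 frames to 16 feature maps with 8x8 filters,
+# stride 4.
+#
+#     obs = tf.placeholder(tf.float32, [None, 84, 84, 4], name="obs")
+#     h = tf.nn.relu(conv2d(obs, 16, "conv1", filter_size=(8, 8), stride=(4, 4)))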
+
+# ================================================================
+# Theano-like Function
+# ================================================================
+
+def function(inputs, outputs, updates=None, givens=None):
+    """Just like Theano function. Take a bunch of tensorflow placeholders and expressions
+    computed based on those placeholders and produces f(inputs) -> outputs. Function f takes
+    values to be fed to the input's placeholders and produces the values of the expressions
+    in outputs.
+
+    Input values can be passed in the same order as inputs or can be provided as kwargs based
+    on placeholder name (passed to constructor or accessible via placeholder.op.name).
+
+    Example:
+        x = tf.placeholder(tf.int32, (), name="x")
+        y = tf.placeholder(tf.int32, (), name="y")
+        z = 3 * x + 2 * y
+        lin = function([x, y], z, givens={y: 0})
+
+        with single_threaded_session():
+            initialize()
+
+            assert lin(2) == 6
+            assert lin(2, 2) == 10
+
+    Parameters
+    ----------
+    inputs: [tf.placeholder, tf.constant, or object with make_feed_dict method]
+        list of input arguments
+    outputs: [tf.Tensor], dict of tf.Tensor, or tf.Tensor
+        list of outputs, a dict mapping names to outputs, or a single output to be
+        returned from the function. The returned value mirrors this structure.
+    """
+    if isinstance(outputs, list):
+        return _Function(inputs, outputs, updates, givens=givens)
+    elif isinstance(outputs, (dict, collections.OrderedDict)):
+        f = _Function(inputs, outputs.values(), updates, givens=givens)
+        return lambda *args, **kwargs: type(outputs)(zip(outputs.keys(), f(*args, **kwargs)))
+    else:
+        f = _Function(inputs, [outputs], updates, givens=givens)
+        return lambda *args, **kwargs: f(*args, **kwargs)[0]
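+
+# When `outputs` is a dict, the wrapper preserves the dict structure. A brief
+# sketch (reusing x and y from the docstring example above, without givens):
+#
+#     out = function([x, y], {"z": 3 * x + 2 * y, "w": x - y})
+#     out(2, 1)   # -> {"z": 8, "w": 1}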
+
+
+class _Function(object):
+    def __init__(self, inputs, outputs, updates, givens):
+        for inpt in inputs:
+            if not hasattr(inpt, 'make_feed_dict') and not (type(inpt) is tf.Tensor and len(inpt.op.inputs) == 0):
+                assert False, "inputs should all be placeholders, constants, or have a make_feed_dict method"
+        self.inputs = inputs
+        updates = updates or []
+        self.update_group = tf.group(*updates)
+        self.outputs_update = list(outputs) + [self.update_group]
+        self.givens = {} if givens is None else givens
+
+    def _feed_input(self, feed_dict, inpt, value):
+        if hasattr(inpt, 'make_feed_dict'):
+            feed_dict.update(inpt.make_feed_dict(value))
+        else:
+            feed_dict[inpt] = value
+
+    def __call__(self, *args):
+        assert len(args) <= len(self.inputs), "Too many arguments provided"
+        feed_dict = {}
+        # Update the args
+        for inpt, value in zip(self.inputs, args):
+            self._feed_input(feed_dict, inpt, value)
+        # Add givens as defaults; values explicitly fed above take precedence.
+        for inpt in self.givens:
+            feed_dict[inpt] = feed_dict.get(inpt, self.givens[inpt])
+        results = tf.get_default_session().run(self.outputs_update, feed_dict=feed_dict)[:-1]
+        return results
+
+# ================================================================
+# Flat vectors
+# ================================================================
+
+def var_shape(x):
+    out = x.get_shape().as_list()
+    assert all(isinstance(a, int) for a in out), \
+        "shape function assumes that shape is fully known"
+    return out
+
+def numel(x):
+    return intprod(var_shape(x))
+
+def intprod(x):
+    return int(np.prod(x))
+
+def flatgrad(loss, var_list, clip_norm=None):
+    grads = tf.gradients(loss, var_list)
+    if clip_norm is not None:
+        grads = [tf.clip_by_norm(grad, clip_norm=clip_norm) if grad is not None else grad
+                 for grad in grads]
+    return tf.concat(axis=0, values=[
+        tf.reshape(grad if grad is not None else tf.zeros_like(v), [numel(v)])
+        for (v, grad) in zip(var_list, grads)
+    ])
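+
+# A hedged usage sketch for flatgrad; `loss` stands for any scalar objective:
+#
+#     params = tf.trainable_variables()
+#     g = flatgrad(loss, params, clip_norm=10.0)   # 1-D tensor of length sum(numel(v))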
+
+class SetFromFlat(object):
+    def __init__(self, var_list, dtype=tf.float32):
+        shapes = list(map(var_shape, var_list))
+        total_size = np.sum([intprod(shape) for shape in shapes])
+
+        self.theta = theta = tf.placeholder(dtype, [total_size])
+        start = 0
+        assigns = []
+        for (shape, v) in zip(shapes, var_list):
+            size = intprod(shape)
+            assigns.append(tf.assign(v, tf.reshape(theta[start:start + size], shape)))
+            start += size
+        self.op = tf.group(*assigns)
+
+    def __call__(self, theta):
+        tf.get_default_session().run(self.op, feed_dict={self.theta: theta})
+
+class GetFlat(object):
+    def __init__(self, var_list):
+        self.op = tf.concat(axis=0, values=[tf.reshape(v, [numel(v)]) for v in var_list])
+
+    def __call__(self):
+        return tf.get_default_session().run(self.op)
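+
+# GetFlat and SetFromFlat round-trip over the same var_list. A sketch where
+# `var_list` and `step` are illustrative placeholders:
+#
+#     get_flat = GetFlat(var_list)
+#     set_from_flat = SetFromFlat(var_list)
+#     theta = get_flat()              # 1-D numpy vector of all parameters
+#     set_from_flat(theta + step)     # write a perturbed vector back into the variables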
+
+_PLACEHOLDER_CACHE = {}  # name -> (placeholder, dtype, shape)
+
+def get_placeholder(name, dtype, shape):
+    if name in _PLACEHOLDER_CACHE:
+        out, dtype1, shape1 = _PLACEHOLDER_CACHE[name]
+        assert dtype1 == dtype and shape1 == shape
+        return out
+    else:
+        out = tf.placeholder(dtype=dtype, shape=shape, name=name)
+        _PLACEHOLDER_CACHE[name] = (out, dtype, shape)
+        return out
+
+def get_placeholder_cached(name):
+    return _PLACEHOLDER_CACHE[name][0]
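+
+# The cache lets distant call sites share a single placeholder by name; asking
+# again with a different dtype or shape trips the assert in get_placeholder.
+#
+#     ob = get_placeholder("ob", tf.float32, [None, 4])
+#     assert get_placeholder_cached("ob") is ob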
+
+def flattenallbut0(x):
+    return tf.reshape(x, [-1, intprod(x.get_shape().as_list()[1:])])
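+
+# flattenallbut0 keeps the batch dimension and flattens the rest: e.g. a
+# [None, 7, 7, 64] conv activation becomes [None, 3136] (7 * 7 * 64 = 3136).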
+
+
+# ================================================================
+# Diagnostics
+# ================================================================
+
+def display_var_info(vars):
+    from baselines import logger
+    count_params = 0
+    for v in vars:
+        name = v.name
+        if "/Adam" in name or "beta1_power" in name or "beta2_power" in name: continue
+        v_params = np.prod(v.shape.as_list())
+        count_params += v_params
+        if "/b:" in name or "/biases" in name: continue    # Wx+b, bias is not interesting to look at => count params, but not print
+        logger.info("   %s%s %i params %s" % (name, " "*(55-len(name)), v_params, str(v.shape)))
+
+    logger.info("Total model parameters: %0.2f million" % (count_params*1e-6))