--- a +++ b/submission/baselines/common/input.py @@ -0,0 +1,30 @@ +import tensorflow as tf +from gym.spaces import Discrete, Box + +def observation_input(ob_space, batch_size=None, name='Ob'): + ''' + Build observation input with encoding depending on the + observation space type + Params: + + ob_space: observation space (should be one of gym.spaces) + batch_size: batch size for input (default is None, so that resulting input placeholder can take tensors with any batch size) + name: tensorflow variable name for input placeholder + + returns: tuple (input_placeholder, processed_input_tensor) + ''' + if isinstance(ob_space, Discrete): + input_x = tf.placeholder(shape=(batch_size,), dtype=tf.int32, name=name) + processed_x = tf.to_float(tf.one_hot(input_x, ob_space.n)) + return input_x, processed_x + + elif isinstance(ob_space, Box): + input_shape = (batch_size,) + ob_space.shape + input_x = tf.placeholder(shape=input_shape, dtype=ob_space.dtype, name=name) + processed_x = tf.to_float(input_x) + return input_x, processed_x + + else: + raise NotImplementedError + +