a b/submission/baselines/common/input.py
1
import tensorflow as tf
2
from gym.spaces import Discrete, Box
3
4
def observation_input(ob_space, batch_size=None, name='Ob'):
    '''
    Build an observation input placeholder, with an encoding that depends
    on the observation space type.

    Params:

    ob_space: observation space; only gym.spaces.Discrete and
        gym.spaces.Box are supported
    batch_size: batch size for input (default is None, so that the resulting
        input placeholder can take tensors with any batch size)
    name: tensorflow name for the input placeholder

    returns: tuple (input_placeholder, processed_input_tensor)
        - Discrete: int32 placeholder of shape (batch_size,); the processed
          tensor is its one-hot encoding of depth ob_space.n, as float32
        - Box: placeholder of shape (batch_size,) + ob_space.shape with
          dtype ob_space.dtype; the processed tensor is it cast to float32

    raises: NotImplementedError for any other observation space type
    '''
    if isinstance(ob_space, Discrete):
        input_x = tf.placeholder(shape=(batch_size,), dtype=tf.int32, name=name)
        # tf.cast replaces the deprecated tf.to_float; the processed tensor
        # is pinned to float32 so downstream layers always receive floats.
        processed_x = tf.cast(tf.one_hot(input_x, ob_space.n), tf.float32)
        return input_x, processed_x

    if isinstance(ob_space, Box):
        # Prepend the (possibly unknown) batch dimension to the per-observation shape.
        input_shape = (batch_size,) + ob_space.shape
        # Feed raw observations in their native dtype and convert inside the graph.
        input_x = tf.placeholder(shape=input_shape, dtype=ob_space.dtype, name=name)
        processed_x = tf.cast(input_x, tf.float32)
        return input_x, processed_x

    raise NotImplementedError(
        'observation_input: unsupported observation space type {}'.format(
            type(ob_space).__name__))
29
30