|
a |
|
b/submission/baselines/common/input.py |
|
|
1 |
import tensorflow as tf |
|
|
2 |
from gym.spaces import Discrete, Box |
|
|
3 |
|
|
|
4 |
def observation_input(ob_space, batch_size=None, name='Ob'):
    """Build an observation input placeholder and its float32 encoding.

    The encoding depends on the observation space type:

    * ``Discrete``: an int32 placeholder of shape ``(batch_size,)``, encoded
      as a one-hot tensor of depth ``ob_space.n`` and cast to float32.
    * ``Box``: a placeholder of shape ``(batch_size,) + ob_space.shape`` with
      dtype ``ob_space.dtype``, encoded by casting to float32.

    Params:
        ob_space: observation space (should be one of gym.spaces; only
            Discrete and Box are currently supported).
        batch_size: batch size for the input (default is None, so that the
            resulting input placeholder can take tensors with any batch size).
        name: tensorflow name for the input placeholder.

    Returns:
        tuple (input_placeholder, processed_input_tensor), where the
        processed tensor is float32 in both supported cases.

    Raises:
        NotImplementedError: if ob_space is neither Discrete nor Box.
    """
    if isinstance(ob_space, Discrete):
        input_x = tf.placeholder(shape=(batch_size,), dtype=tf.int32, name=name)
        # tf.to_float is deprecated; tf.cast with name='ToFloat' is its exact
        # equivalent and keeps the same op name in the graph.
        processed_x = tf.cast(tf.one_hot(input_x, ob_space.n), tf.float32,
                              name='ToFloat')
        return input_x, processed_x

    if isinstance(ob_space, Box):
        # Prepend the (possibly unknown) batch dimension to the
        # per-observation shape; tuple() guards against non-tuple shapes.
        input_shape = (batch_size,) + tuple(ob_space.shape)
        input_x = tf.placeholder(shape=input_shape, dtype=ob_space.dtype, name=name)
        processed_x = tf.cast(input_x, tf.float32, name='ToFloat')
        return input_x, processed_x

    # Name the offending type so callers can see which space is unsupported.
    raise NotImplementedError(
        'observation_input does not support observation space of type {}'.format(
            type(ob_space).__name__))
|
|
29 |
|
|
|
30 |
|