--- a
+++ b/baselines/common/distributions.py
@@ -0,0 +1,310 @@
+import tensorflow as tf
+import numpy as np
+import baselines.common.tf_util as U
+from baselines.a2c.utils import fc
+from tensorflow.python.ops import math_ops
+
+class Pd(object):
+    """
+    A particular probability distribution
+    """
+    def flatparam(self):
+        raise NotImplementedError
+    def mode(self):
+        raise NotImplementedError
+    def neglogp(self, x):
+        # Usually it's easier to define the negative logprob
+        raise NotImplementedError
+    def kl(self, other):
+        raise NotImplementedError
+    def entropy(self):
+        raise NotImplementedError
+    def sample(self):
+        raise NotImplementedError
+    def logp(self, x):
+        return - self.neglogp(x)
+
+class PdType(object):
+    """
+    Parametrized family of probability distributions
+    """
+    def pdclass(self):
+        raise NotImplementedError
+    def pdfromflat(self, flat):
+        return self.pdclass()(flat)
+    def pdfromlatent(self, latent_vector):
+        raise NotImplementedError
+    def param_shape(self):
+        raise NotImplementedError
+    def sample_shape(self):
+        raise NotImplementedError
+    def sample_dtype(self):
+        raise NotImplementedError
+
+    def param_placeholder(self, prepend_shape, name=None):
+        return tf.placeholder(dtype=tf.float32, shape=prepend_shape+self.param_shape(), name=name)
+    def sample_placeholder(self, prepend_shape, name=None):
+        return tf.placeholder(dtype=self.sample_dtype(), shape=prepend_shape+self.sample_shape(), name=name)
+
+class CategoricalPdType(PdType):
+    def __init__(self, ncat):
+        self.ncat = ncat
+    def pdclass(self):
+        return CategoricalPd
+    def pdfromlatent(self, latent_vector, init_scale=1.0, init_bias=0.0):
+        pdparam = fc(latent_vector, 'pi', self.ncat, init_scale=init_scale, init_bias=init_bias)
+        return self.pdfromflat(pdparam), pdparam
+
+    def param_shape(self):
+        return [self.ncat]
+    def sample_shape(self):
+        return []
+    def sample_dtype(self):
+        return tf.int32
+
+
+class MultiCategoricalPdType(PdType):
+    def __init__(self, nvec):
+        self.ncats = nvec
+    def pdclass(self):
+        return MultiCategoricalPd
+    def pdfromflat(self, flat):
+        return MultiCategoricalPd(self.ncats, flat)
+    def param_shape(self):
+        return [sum(self.ncats)]
+    def sample_shape(self):
+        return [len(self.ncats)]
+    def sample_dtype(self):
+        return tf.int32
+
+class DiagGaussianPdType(PdType):
+    def __init__(self, size):
+        self.size = size
+    def pdclass(self):
+        return DiagGaussianPd
+
+    def pdfromlatent(self, latent_vector, init_scale=1.0, init_bias=0.0):
+        mean = fc(latent_vector, 'pi', self.size, init_scale=init_scale, init_bias=init_bias)
+        logstd = tf.get_variable(name='logstd', shape=[1, self.size], initializer=tf.zeros_initializer())
+        pdparam = tf.concat([mean, mean * 0.0 + logstd], axis=1)
+        return self.pdfromflat(pdparam), mean
+
+    def param_shape(self):
+        return [2*self.size]
+    def sample_shape(self):
+        return [self.size]
+    def sample_dtype(self):
+        return tf.float32
+
+class BernoulliPdType(PdType):
+    def __init__(self, size):
+        self.size = size
+    def pdclass(self):
+        return BernoulliPd
+    def param_shape(self):
+        return [self.size]
+    def sample_shape(self):
+        return [self.size]
+    def sample_dtype(self):
+        return tf.int32
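+
+# Usage sketch (illustrative only; not exercised in this file): a policy
+# typically builds a pdtype from its action space with make_pdtype (defined
+# near the bottom of this file), then maps a latent feature tensor to a
+# distribution. `latent` below is a hypothetical [batch, nfeat] float tensor:
+#
+#   pdtype = make_pdtype(env.action_space)
+#   pd, pdparam = pdtype.pdfromlatent(latent)   # Categorical / DiagGaussian only
+#   action = pd.sample()
+#   neglogp_action = pd.neglogp(action)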
+
+# WRONG SECOND DERIVATIVES
+# class CategoricalPd(Pd):
+#     def __init__(self, logits):
+#         self.logits = logits
+#         self.ps = tf.nn.softmax(logits)
+#     @classmethod
+#     def fromflat(cls, flat):
+#         return cls(flat)
+#     def flatparam(self):
+#         return self.logits
+#     def mode(self):
+#         return U.argmax(self.logits, axis=-1)
+#     def logp(self, x):
+#         return -tf.nn.sparse_softmax_cross_entropy_with_logits(self.logits, x)
+#     def kl(self, other):
+#         return tf.nn.softmax_cross_entropy_with_logits(other.logits, self.ps) \
+#             - tf.nn.softmax_cross_entropy_with_logits(self.logits, self.ps)
+#     def entropy(self):
+#         return tf.nn.softmax_cross_entropy_with_logits(self.logits, self.ps)
+#     def sample(self):
+#         u = tf.random_uniform(tf.shape(self.logits))
+#         return U.argmax(self.logits - tf.log(-tf.log(u)), axis=-1)
+
+class CategoricalPd(Pd):
+    def __init__(self, logits):
+        self.logits = logits
+    def flatparam(self):
+        return self.logits
+    def mode(self):
+        return tf.argmax(self.logits, axis=-1)
+    def neglogp(self, x):
+        # return tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.logits, labels=x)
+        # Note: we can't use sparse_softmax_cross_entropy_with_logits because
+        # the implementation does not allow second-order derivatives...
+        one_hot_actions = tf.one_hot(x, self.logits.get_shape().as_list()[-1])
+        return tf.nn.softmax_cross_entropy_with_logits(
+            logits=self.logits,
+            labels=one_hot_actions)
+    def kl(self, other):
+        # Subtract the per-row max before exponentiating, for numerical stability.
+        a0 = self.logits - tf.reduce_max(self.logits, axis=-1, keep_dims=True)
+        a1 = other.logits - tf.reduce_max(other.logits, axis=-1, keep_dims=True)
+        ea0 = tf.exp(a0)
+        ea1 = tf.exp(a1)
+        z0 = tf.reduce_sum(ea0, axis=-1, keep_dims=True)
+        z1 = tf.reduce_sum(ea1, axis=-1, keep_dims=True)
+        p0 = ea0 / z0
+        return tf.reduce_sum(p0 * (a0 - tf.log(z0) - a1 + tf.log(z1)), axis=-1)
+    def entropy(self):
+        # Same stabilized log-sum-exp computation as in kl().
+        a0 = self.logits - tf.reduce_max(self.logits, axis=-1, keep_dims=True)
+        ea0 = tf.exp(a0)
+        z0 = tf.reduce_sum(ea0, axis=-1, keep_dims=True)
+        p0 = ea0 / z0
+        return tf.reduce_sum(p0 * (tf.log(z0) - a0), axis=-1)
+    def sample(self):
+        # Gumbel-max trick: perturb the logits with Gumbel noise and take the
+        # argmax; this yields an exact sample from softmax(logits).
+        u = tf.random_uniform(tf.shape(self.logits))
+        return tf.argmax(self.logits - tf.log(-tf.log(u)), axis=-1)
+    @classmethod
+    def fromflat(cls, flat):
+        return cls(flat)
+
+class MultiCategoricalPd(Pd):
+    # A product of independent categoricals, one per action component.
+    def __init__(self, nvec, flat):
+        self.flat = flat
+        self.categoricals = list(map(CategoricalPd, tf.split(flat, nvec, axis=-1)))
+    def flatparam(self):
+        return self.flat
+    def mode(self):
+        return tf.cast(tf.stack([p.mode() for p in self.categoricals], axis=-1), tf.int32)
+    def neglogp(self, x):
+        return tf.add_n([p.neglogp(px) for p, px in zip(self.categoricals, tf.unstack(x, axis=-1))])
+    def kl(self, other):
+        return tf.add_n([p.kl(q) for p, q in zip(self.categoricals, other.categoricals)])
+    def entropy(self):
+        return tf.add_n([p.entropy() for p in self.categoricals])
+    def sample(self):
+        return tf.cast(tf.stack([p.sample() for p in self.categoricals], axis=-1), tf.int32)
+    @classmethod
+    def fromflat(cls, flat):
+        raise NotImplementedError
+
+class DiagGaussianPd(Pd):
+    def __init__(self, flat):
+        self.flat = flat
+        mean, logstd = tf.split(axis=len(flat.shape)-1, num_or_size_splits=2, value=flat)
+        self.mean = mean
+        self.logstd = logstd
+        self.std = tf.exp(logstd)
+    def flatparam(self):
+        return self.flat
+    def mode(self):
+        return self.mean
+    def neglogp(self, x):
+        return 0.5 * tf.reduce_sum(tf.square((x - self.mean) / self.std), axis=-1) \
+               + 0.5 * np.log(2.0 * np.pi) * tf.to_float(tf.shape(x)[-1]) \
+               + tf.reduce_sum(self.logstd, axis=-1)
+    def kl(self, other):
+        assert isinstance(other, DiagGaussianPd)
+        # Closed-form KL between two diagonal Gaussians, summed over dimensions.
+        return tf.reduce_sum(other.logstd - self.logstd + (tf.square(self.std) + tf.square(self.mean - other.mean)) / (2.0 * tf.square(other.std)) - 0.5, axis=-1)
+    def entropy(self):
+        # Differential entropy of a diagonal Gaussian: sum_i (logstd_i + 0.5*log(2*pi*e)).
+        return tf.reduce_sum(self.logstd + .5 * np.log(2.0 * np.pi * np.e), axis=-1)
+    def sample(self):
+        # Reparameterized sample; drop the noise term for a deterministic action.
+        return self.mean + self.std * tf.random_normal(tf.shape(self.mean))
+    @classmethod
+    def fromflat(cls, flat):
+        return cls(flat)
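+
+# Sanity sketch (commented out; assumes only numpy): the Gumbel-max trick used
+# in CategoricalPd.sample() above. With u ~ Uniform(0, 1) i.i.d. per class,
+# argmax_i(logits_i - log(-log(u_i))) is an exact draw from softmax(logits):
+#
+#   logits = np.array([1.0, 2.0, 0.5])
+#   u = np.random.uniform(size=(100000, logits.size))
+#   samples = np.argmax(logits - np.log(-np.log(u)), axis=-1)
+#   freqs = np.bincount(samples) / samples.size   # approx. softmax(logits)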
+
+class BernoulliPd(Pd):
+    def __init__(self, logits):
+        self.logits = logits
+        self.ps = tf.sigmoid(logits)
+    def flatparam(self):
+        return self.logits
+    def mode(self):
+        return tf.round(self.ps)
+    def neglogp(self, x):
+        return tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.logits, labels=tf.to_float(x)), axis=-1)
+    def kl(self, other):
+        # KL(p || q) = H(p, q) - H(p), computed per bit and summed.
+        return tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(logits=other.logits, labels=self.ps), axis=-1) \
+            - tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.logits, labels=self.ps), axis=-1)
+    def entropy(self):
+        return tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.logits, labels=self.ps), axis=-1)
+    def sample(self):
+        u = tf.random_uniform(tf.shape(self.ps))
+        return tf.to_float(math_ops.less(u, self.ps))
+    @classmethod
+    def fromflat(cls, flat):
+        return cls(flat)
+
+def make_pdtype(ac_space):
+    from gym import spaces
+    if isinstance(ac_space, spaces.Box):
+        assert len(ac_space.shape) == 1
+        return DiagGaussianPdType(ac_space.shape[0])
+    elif isinstance(ac_space, spaces.Discrete):
+        return CategoricalPdType(ac_space.n)
+    elif isinstance(ac_space, spaces.MultiDiscrete):
+        return MultiCategoricalPdType(ac_space.nvec)
+    elif isinstance(ac_space, spaces.MultiBinary):
+        return BernoulliPdType(ac_space.n)
+    else:
+        raise NotImplementedError
+
+def shape_el(v, i):
+    # Return the static size of dimension i if known, else the dynamic shape op.
+    maybe = v.get_shape()[i]
+    if maybe is not None:
+        return maybe
+    else:
+        return tf.shape(v)[i]
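+
+# Examples (sketch) of the space -> pdtype mapping performed by make_pdtype
+# above, assuming standard gym spaces:
+#
+#   make_pdtype(spaces.Box(low=-1.0, high=1.0, shape=(3,)))  # DiagGaussianPdType(3)
+#   make_pdtype(spaces.Discrete(4))                          # CategoricalPdType(4)
+#   make_pdtype(spaces.MultiDiscrete([2, 3]))                # MultiCategoricalPdType([2, 3])
+#   make_pdtype(spaces.MultiBinary(5))                       # BernoulliPdType(5)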
+
+@U.in_session
+def test_probtypes():
+    np.random.seed(0)
+
+    pdparam_diag_gauss = np.array([-.2, .3, .4, -.5, .1, -.5, .1, 0.8])
+    diag_gauss = DiagGaussianPdType(pdparam_diag_gauss.size // 2) #pylint: disable=E1101
+    validate_probtype(diag_gauss, pdparam_diag_gauss)
+
+    pdparam_categorical = np.array([-.2, .3, .5])
+    categorical = CategoricalPdType(pdparam_categorical.size) #pylint: disable=E1101
+    validate_probtype(categorical, pdparam_categorical)
+
+    nvec = [1,2,3]
+    pdparam_multicategorical = np.array([-.2, .3, .5, .1, 1, -.1])
+    multicategorical = MultiCategoricalPdType(nvec) #pylint: disable=E1101
+    validate_probtype(multicategorical, pdparam_multicategorical)
+
+    pdparam_bernoulli = np.array([-.2, .3, .5])
+    bernoulli = BernoulliPdType(pdparam_bernoulli.size) #pylint: disable=E1101
+    validate_probtype(bernoulli, pdparam_bernoulli)
+
+
+def validate_probtype(probtype, pdparam):
+    N = 100000
+    # Check to see if mean negative log likelihood == differential entropy
+    Mval = np.repeat(pdparam[None, :], N, axis=0)
+    M = probtype.param_placeholder([N])
+    X = probtype.sample_placeholder([N])
+    pd = probtype.pdfromflat(M)
+    calcloglik = U.function([X, M], pd.logp(X))
+    calcent = U.function([M], pd.entropy())
+    Xval = tf.get_default_session().run(pd.sample(), feed_dict={M: Mval})
+    logliks = calcloglik(Xval, Mval)
+    entval_ll = - logliks.mean() #pylint: disable=E1101
+    entval_ll_stderr = logliks.std() / np.sqrt(N) #pylint: disable=E1101
+    entval = calcent(Mval).mean() #pylint: disable=E1101
+    assert np.abs(entval - entval_ll) < 3 * entval_ll_stderr # within 3 sigmas
+
+    # Check to see if kldiv[p,q] = - ent[p] - E_p[log q]
+    M2 = probtype.param_placeholder([N])
+    pd2 = probtype.pdfromflat(M2)
+    q = pdparam + np.random.randn(pdparam.size) * 0.1
+    Mval2 = np.repeat(q[None, :], N, axis=0)
+    calckl = U.function([M, M2], pd.kl(pd2))
+    klval = calckl(Mval, Mval2).mean() #pylint: disable=E1101
+    logliks = calcloglik(Xval, Mval2)
+    klval_ll = - entval - logliks.mean() #pylint: disable=E1101
+    klval_ll_stderr = logliks.std() / np.sqrt(N) #pylint: disable=E1101
+    assert np.abs(klval - klval_ll) < 3 * klval_ll_stderr # within 3 sigmas
+    print('ok on', probtype, pdparam)
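+
+if __name__ == '__main__':
+    # Convenience entry point (sketch): assumes U.in_session wraps
+    # test_probtypes in a TensorFlow session so it runs standalone.
+    test_probtypes()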