# scvae/distributions/zero_inflated.py
1
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#         http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14
# ==========================================================================
15
16
# ======================================================================== #
17
#
18
# Copyright (c) 2017 - 2020 scVAE authors
19
#
20
# Licensed under the Apache License, Version 2.0 (the "License");
21
# you may not use this file except in compliance with the License.
22
# You may obtain a copy of the License at
23
#
24
#        http://www.apache.org/licenses/LICENSE-2.0
25
#
26
# Unless required by applicable law or agreed to in writing, software
27
# distributed under the License is distributed on an "AS IS" BASIS,
28
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
29
# See the License for the specific language governing permissions and
30
# limitations under the License.
31
#
32
# ======================================================================== #
33
34
"""The ZeroInflated distribution class."""
35
36
from __future__ import absolute_import
37
from __future__ import division
38
from __future__ import print_function
39
40
from tensorflow import where
41
from tensorflow.python.framework import ops
42
from tensorflow.python.ops import array_ops
43
from tensorflow.python.ops import check_ops
44
from tensorflow.python.ops import math_ops
45
from tensorflow_probability.python.distributions import distribution
46
from tensorflow_probability.python.internal import reparameterization
47
48
49
class ZeroInflated(distribution.Distribution):
    """Zero-inflated distribution.

    The `ZeroInflated` object implements batched zero-inflated distributions.
    A zero-inflated model combines a point mass at zero, weighted by the
    zero-inflation rate `pi`, with a single base `Distribution` (`dist`),
    weighted by `1 - pi`:

        P(X = x) = pi + (1 - pi) * P_dist(x),   for x <= 0
        P(X = x) = (1 - pi) * P_dist(x),        for x > 0

    Methods implemented here include `log_prob`, `prob`, `mean`, and
    `variance`.

    NOTE(review): the elementwise formulas below combine `pi` (batch-shaped)
    directly with `dist` statistics and with `x`, which presumably assumes a
    scalar event shape for `dist`; confirm before using vector-valued dists.
    """

    def __init__(self,
                 dist,
                 pi,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="ZeroInflated"):
        """Initialise a zero-inflated distribution.

        Args:
            dist: A `Distribution` instance with a statically known event
                rank. Its batch shape must match the shape of `pi`.
            pi: Zero-inflation rate, i.e. the probabilities of excess zeroes.
                Anything convertible to a tensor; its shape is used as the
                batch shape of this distribution.
            validate_args: Python `bool`, default `False`. If `True`, add
                runtime assertions that `0 <= pi <= 1` and that the batch
                rank and batch shape of `dist` match those of `pi`.
            allow_nan_stats: Python `bool`, default `True`. If `False`, raise
                an exception if a statistic (e.g. mean/mode/etc...) is
                undefined for any batch member. If `True`, batch members with
                valid parameters leading to undefined statistics will return
                NaN for this statistic.
            name: A name for this distribution (optional).

        Raises:
            ValueError: If `dist` is falsy (e.g. `None`), or the rank of its
                event shape is not statically known.
            TypeError: If `dist` is not a `Distribution` instance.
        """
        # Snapshot the constructor arguments before any local is rebound.
        parameters = dict(locals())

        if not dist:
            raise ValueError("dist must be non-empty")

        if not isinstance(dist, distribution.Distribution):
            raise TypeError(
                "dist must be a Distribution instance"
                " but saw: %s" % dist)

        dtype = dist.dtype
        static_event_shape = dist.event_shape

        if static_event_shape.ndims is None:
            raise ValueError(
                "Expected to know rank(event_shape) from dist, but "
                "the distribution does not provide a static number of ndims")

        with ops.name_scope(name, values=[pi]):
            # Accept Python/NumPy values as well as tensors: `get_shape`
            # below requires a tensor (existing tensors pass through as-is).
            pi = ops.convert_to_tensor(pi, name="pi")
            static_batch_shape = pi.get_shape()

            if validate_args:
                check_message = "dist batch shape must match pi batch shape"
                pi_batch_shape = array_ops.shape(pi)
                dist_batch_shape = dist.batch_shape_tensor()
                self._assertions = [
                    # `pi` is a probability, so both 0 and 1 are valid.
                    check_ops.assert_non_negative(
                        pi, message="pi must be non-negative"),
                    check_ops.assert_less_equal(
                        pi, array_ops.ones_like(pi),
                        message="pi must be less than or equal to 1"),
                    # Batch ranks and batch shapes of `pi` and `dist` agree.
                    check_ops.assert_equal(
                        array_ops.size(pi_batch_shape),
                        array_ops.size(dist_batch_shape),
                        message=check_message),
                    check_ops.assert_equal(
                        pi_batch_shape, dist_batch_shape,
                        message=check_message),
                ]
            else:
                self._assertions = []

        self._pi = pi
        self._dist = dist
        self._static_event_shape = static_event_shape
        self._static_batch_shape = static_batch_shape

        # We let the zero-inflated distribution access _graph_parents since
        # it is arguably more like a base class.
        graph_parents = [self._pi]
        graph_parents += self._dist._graph_parents

        super(ZeroInflated, self).__init__(
            dtype=dtype,
            reparameterization_type=reparameterization.NOT_REPARAMETERIZED,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            parameters=parameters,
            graph_parents=graph_parents,
            name=name)

    @property
    def pi(self):
        """Zero-inflation rate (probabilities of excess zeroes) as a tensor."""
        return self._pi

    @property
    def dist(self):
        """The base `Distribution` being zero-inflated."""
        return self._dist

    def _batch_shape_tensor(self):
        # The batch shape is defined by the shape of `pi`.
        return array_ops.shape(self._pi)

    def _batch_shape(self):
        return self._static_batch_shape

    def _event_shape_tensor(self):
        # The event shape is inherited from the base distribution.
        return self._dist.event_shape_tensor()

    def _event_shape(self):
        return self._static_event_shape

    def _mean(self):
        with ops.control_dependencies(self._assertions):
            # E[X] = pi * 0 + (1 - pi) * E_dist[X].
            return (1 - self._pi) * self._dist.mean()

    def _variance(self):
        with ops.control_dependencies(self._assertions):
            # Var[X] = E[X^2] - E[X]^2
            #        = (1 - pi) (sigma^2 + mu^2) - (1 - pi)^2 mu^2
            #        = (1 - pi) (sigma^2 + pi mu^2).
            # The last form is used because it avoids subtracting two
            # nearly equal large terms.
            mu = self._dist.mean()
            return (1 - self._pi) * (
                self._dist.variance() + self._pi * math_ops.square(mu))

    def _log_prob(self, x):
        with ops.control_dependencies(self._assertions):
            x = ops.convert_to_tensor(x, name="x")
            positive = x > 0
            # `where` evaluates both branches. The zero branch is only
            # selected where x <= 0, so it is fed 0 wherever x > 0 instead of
            # the real x. This keeps log(pi + (1 - pi) * P_dist(x)) from
            # hitting log(0) = -inf at large x in the discarded branch, whose
            # gradient would otherwise come through as NaN. Selected values
            # are unchanged.
            x_zero_branch = where(positive, array_ops.zeros_like(x), x)
            log_prob_zero = math_ops.log(
                self._pi + (1 - self._pi) * self._dist.prob(x_zero_branch))
            log_prob_positive = (
                math_ops.log(1 - self._pi) + self._dist.log_prob(x))
            return where(positive, log_prob_positive, log_prob_zero)

    def _prob(self, x):
        return math_ops.exp(self._log_prob(x))