Diff of /docs/model.html [000000] .. [2c6b19]

Switch to unified view

a b/docs/model.html
1
<!doctype html>
2
<html lang="en">
3
<head>
4
<meta charset="utf-8">
5
<meta name="viewport" content="width=device-width, initial-scale=1, minimum-scale=1">
6
<meta name="generator" content="pdoc3 0.11.1">
7
<title>VITAE.model API documentation</title>
8
<meta name="description" content="">
9
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/10up-sanitize.css/13.0.0/sanitize.min.css" integrity="sha512-y1dtMcuvtTMJc1yPgEqF0ZjQbhnc/bFhyvIyVNb9Zk5mIGtqVaAB1Ttl28su8AvFMOY0EwRbAe+HCLqj6W7/KA==" crossorigin>
10
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/10up-sanitize.css/13.0.0/typography.min.css" integrity="sha512-Y1DYSb995BAfxobCkKepB1BqJJTPrOp3zPL74AWFugHHmmdcvO+C48WLrUOlhGMc0QG7AE3f7gmvvcrmX2fDoA==" crossorigin>
11
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.9.0/styles/default.min.css" crossorigin>
12
<style>:root{--highlight-color:#fe9}.flex{display:flex !important}body{line-height:1.5em}#content{padding:20px}#sidebar{padding:1.5em;overflow:hidden}#sidebar > *:last-child{margin-bottom:2cm}.http-server-breadcrumbs{font-size:130%;margin:0 0 15px 0}#footer{font-size:.75em;padding:5px 30px;border-top:1px solid #ddd;text-align:right}#footer p{margin:0 0 0 1em;display:inline-block}#footer p:last-child{margin-right:30px}h1,h2,h3,h4,h5{font-weight:300}h1{font-size:2.5em;line-height:1.1em}h2{font-size:1.75em;margin:2em 0 .50em 0}h3{font-size:1.4em;margin:1.6em 0 .7em 0}h4{margin:0;font-size:105%}h1:target,h2:target,h3:target,h4:target,h5:target,h6:target{background:var(--highlight-color);padding:.2em 0}a{color:#058;text-decoration:none;transition:color .2s ease-in-out}a:visited{color:#503}a:hover{color:#b62}.title code{font-weight:bold}h2[id^="header-"]{margin-top:2em}.ident{color:#900;font-weight:bold}pre code{font-size:.8em;line-height:1.4em;padding:1em;display:block}code{background:#f3f3f3;font-family:"DejaVu Sans Mono",monospace;padding:1px 4px;overflow-wrap:break-word}h1 code{background:transparent}pre{border-top:1px solid #ccc;border-bottom:1px solid #ccc;margin:1em 0}#http-server-module-list{display:flex;flex-flow:column}#http-server-module-list div{display:flex}#http-server-module-list dt{min-width:10%}#http-server-module-list p{margin-top:0}.toc ul,#index{list-style-type:none;margin:0;padding:0}#index code{background:transparent}#index h3{border-bottom:1px solid #ddd}#index ul{padding:0}#index h4{margin-top:.6em;font-weight:bold}@media (min-width:200ex){#index .two-column{column-count:2}}@media (min-width:300ex){#index .two-column{column-count:3}}dl{margin-bottom:2em}dl dl:last-child{margin-bottom:4em}dd{margin:0 0 1em 3em}#header-classes + dl > dd{margin-bottom:3em}dd dd{margin-left:2em}dd p{margin:10px 0}.name{background:#eee;font-size:.85em;padding:5px 10px;display:inline-block;min-width:40%}.name:hover{background:#e0e0e0}dt:target 
.name{background:var(--highlight-color)}.name > span:first-child{white-space:nowrap}.name.class > span:nth-child(2){margin-left:.4em}.inherited{color:#999;border-left:5px solid #eee;padding-left:1em}.inheritance em{font-style:normal;font-weight:bold}.desc h2{font-weight:400;font-size:1.25em}.desc h3{font-size:1em}.desc dt code{background:inherit}.source summary,.git-link-div{color:#666;text-align:right;font-weight:400;font-size:.8em;text-transform:uppercase}.source summary > *{white-space:nowrap;cursor:pointer}.git-link{color:inherit;margin-left:1em}.source pre{max-height:500px;overflow:auto;margin:0}.source pre code{font-size:12px;overflow:visible}.hlist{list-style:none}.hlist li{display:inline}.hlist li:after{content:',\2002'}.hlist li:last-child:after{content:none}.hlist .hlist{display:inline;padding-left:1em}img{max-width:100%}td{padding:0 .5em}.admonition{padding:.1em 1em;margin-bottom:1em}.admonition-title{font-weight:bold}.admonition.note,.admonition.info,.admonition.important{background:#aef}.admonition.todo,.admonition.versionadded,.admonition.tip,.admonition.hint{background:#dfd}.admonition.warning,.admonition.versionchanged,.admonition.deprecated{background:#fd4}.admonition.error,.admonition.danger,.admonition.caution{background:lightpink}</style>
13
<style media="screen and (min-width: 700px)">@media screen and (min-width:700px){#sidebar{width:30%;height:100vh;overflow:auto;position:sticky;top:0}#content{width:70%;max-width:100ch;padding:3em 4em;border-left:1px solid #ddd}pre code{font-size:1em}.name{font-size:1em}main{display:flex;flex-direction:row-reverse;justify-content:flex-end}.toc ul ul,#index ul ul{padding-left:1em}.toc > ul > li{margin-top:.5em}}</style>
14
<style media="print">@media print{#sidebar h1{page-break-before:always}.source{display:none}}@media print{*{background:transparent !important;color:#000 !important;box-shadow:none !important;text-shadow:none !important}a[href]:after{content:" (" attr(href) ")";font-size:90%}a[href][title]:after{content:none}abbr[title]:after{content:" (" attr(title) ")"}.ir a:after,a[href^="javascript:"]:after,a[href^="#"]:after{content:""}pre,blockquote{border:1px solid #999;page-break-inside:avoid}thead{display:table-header-group}tr,img{page-break-inside:avoid}img{max-width:100% !important}@page{margin:0.5cm}p,h2,h3{orphans:3;widows:3}h1,h2,h3,h4,h5,h6{page-break-after:avoid}}</style>
15
<script type="text/x-mathjax-config">/* Configure the tex2jax preprocessor: treat $...$ and \(...\) as inline math delimiters and honor \$ escapes in page text. */MathJax.Hub.Config({ tex2jax: { inlineMath: [ ['$','$'], ["\\(","\\)"] ], processEscapes: true } });</script>
16
<script async src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.7/latest.js?config=TeX-AMS_CHTML" integrity="sha256-kZafAc6mZvK3W3v1pHOcUix30OHQN6pU/NO2oFkqZVw=" crossorigin></script>
17
<script defer src="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.9.0/highlight.min.js" integrity="sha512-D9gUyxqja7hBtkWpPWGt9wfbfaMGVt9gnyCvYa+jojwwPHLCzUm5i8rpk7vD7wNee9bA35eYIjobYPaQuKS1MQ==" crossorigin></script>
18
<script>window.addEventListener('DOMContentLoaded', () => {
19
// Restrict highlight.js auto-detection to the languages pdoc emits, then highlight every code block on the page.
hljs.configure({languages: ['bash', 'css', 'diff', 'graphql', 'ini', 'javascript', 'json', 'plaintext', 'python', 'python-repl', 'rust', 'shell', 'sql', 'typescript', 'xml', 'yaml']});
20
hljs.highlightAll();
21
})</script>
22
</head>
23
<body>
24
<main>
25
<article id="content">
26
<header>
27
<h1 class="title">Module <code>VITAE.model</code></h1>
28
</header>
29
<section id="section-intro">
30
</section>
31
<section>
32
</section>
33
<section>
34
</section>
35
<section>
36
</section>
37
<section>
38
<h2 class="section-title" id="header-classes">Classes</h2>
39
<dl>
40
<dt id="VITAE.model.cdf_layer"><code class="flex name class">
41
<span>class <span class="ident">cdf_layer</span></span>
42
</code></dt>
43
<dd>
44
<div class="desc"><p>The Normal cdf layer with custom gradients.</p></div>
45
<details class="source">
46
<summary>
47
<span>Expand source code</span>
48
</summary>
49
<pre><code class="python">class cdf_layer(Layer):
50
    &#39;&#39;&#39;
51
    The Normal cdf layer with custom gradients.
52
    &#39;&#39;&#39;
53
    def __init__(self):
54
        &#39;&#39;&#39;
55
        &#39;&#39;&#39;
56
        super(cdf_layer, self).__init__()
57
        
58
    @tf.function
59
    def call(self, x):
60
        return self.func(x)
61
        
62
    @tf.custom_gradient
63
    def func(self, x):
64
        &#39;&#39;&#39;Return cdf(x) and pdf(x).
65
66
        Parameters
67
        ----------
68
        x : tf.Tensor
69
            The input tensor.
70
        
71
        Returns
72
        ----------
73
        f : tf.Tensor
74
            cdf(x).
75
        grad : callable
76
            The gradient function, returning dy * pdf(x).
77
        &#39;&#39;&#39;   
78
        dist = tfp.distributions.Normal(
79
            loc = tf.constant(0.0, tf.keras.backend.floatx()), 
80
            scale = tf.constant(1.0, tf.keras.backend.floatx()), 
81
            allow_nan_stats=False)
82
        f = dist.cdf(x)
83
        def grad(dy):
84
            gradient = dist.prob(x)
85
            return dy * gradient
86
        return f, grad</code></pre>
87
</details>
88
<h3>Ancestors</h3>
89
<ul class="hlist">
90
<li>keras.src.engine.base_layer.Layer</li>
91
<li>tensorflow.python.module.module.Module</li>
92
<li>tensorflow.python.trackable.autotrackable.AutoTrackable</li>
93
<li>tensorflow.python.trackable.base.Trackable</li>
94
<li>keras.src.utils.version_utils.LayerVersionSelector</li>
95
</ul>
96
<h3>Methods</h3>
97
<dl>
98
<dt id="VITAE.model.cdf_layer.call"><code class="name flex">
99
<span>def <span class="ident">call</span></span>(<span>self, x)</span>
100
</code></dt>
101
<dd>
102
<div class="desc"></div>
103
</dd>
104
<dt id="VITAE.model.cdf_layer.func"><code class="name flex">
105
<span>def <span class="ident">func</span></span>(<span>self, x)</span>
106
</code></dt>
107
<dd>
108
<div class="desc"><p>Return cdf(x) and pdf(x).</p>
109
<h2 id="parameters">Parameters</h2>
110
<dl>
111
<dt><strong><code>x</code></strong> :&ensp;<code>tf.Tensor</code></dt>
112
<dd>The input tensor.</dd>
113
</dl>
114
<h2 id="returns">Returns</h2>
115
<dl>
116
<dt><strong><code>f</code></strong> :&ensp;<code>tf.Tensor</code></dt>
117
<dd>cdf(x).</dd>
118
<dt><strong><code>grad</code></strong> :&ensp;<code>callable</code></dt>
119
<dd>The gradient function; returns <code>dy * pdf(x)</code>.</dd>
120
</dl></div>
121
</dd>
122
</dl>
123
</dd>
124
<dt id="VITAE.model.Sampling"><code class="flex name class">
125
<span>class <span class="ident">Sampling</span></span>
126
<span>(</span><span>seed=0, **kwargs)</span>
127
</code></dt>
128
<dd>
129
<div class="desc"><p>Sampling latent variable <span><span class="MathJax_Preview">z</span><script type="math/tex">z</script></span> from <span><span class="MathJax_Preview">N(\mu_z, \log \sigma_z^2)</span><script type="math/tex">N(\mu_z, \log \sigma_z^2)</script></span>.
130
<br>
131
Used in Encoder.</p></div>
132
<details class="source">
133
<summary>
134
<span>Expand source code</span>
135
</summary>
136
<pre><code class="python">class Sampling(Layer):
137
    &#34;&#34;&#34;Sampling latent variable \(z\) from \(N(\\mu_z, \\log \\sigma_z^2)\).    
138
    Used in Encoder.
139
    &#34;&#34;&#34;
140
    def __init__(self, seed=0, **kwargs):
141
        super(Sampling, self).__init__(**kwargs)
142
        self.seed = seed
143
144
    @tf.function
145
    def call(self, z_mean, z_log_var):
146
        &#39;&#39;&#39;Sample \(z\) from its mean and log-variance.
147
148
        Parameters
149
        ----------
150
        z_mean : tf.Tensor
151
            \([B, L, d]\) The mean of \(z\).
152
        z_log_var : tf.Tensor
153
            \([B, L, d]\) The log-variance of \(z\).
154
155
        Returns
156
        ----------
157
        z : tf.Tensor
158
            \([B, L, d]\) The sampled \(z\).
159
        &#39;&#39;&#39;   
160
   #     seed = tfp.util.SeedStream(self.seed, salt=&#34;random_normal&#34;)
161
   #     epsilon = tf.random.normal(shape = tf.shape(z_mean), seed=seed(), dtype=tf.keras.backend.floatx())
162
        epsilon = tf.random.normal(shape = tf.shape(z_mean), dtype=tf.keras.backend.floatx())
163
        z = z_mean + tf.exp(0.5 * z_log_var) * epsilon
164
        z = tf.clip_by_value(z, -1e6, 1e6)
165
        return z</code></pre>
166
</details>
167
<h3>Ancestors</h3>
168
<ul class="hlist">
169
<li>keras.src.engine.base_layer.Layer</li>
170
<li>tensorflow.python.module.module.Module</li>
171
<li>tensorflow.python.trackable.autotrackable.AutoTrackable</li>
172
<li>tensorflow.python.trackable.base.Trackable</li>
173
<li>keras.src.utils.version_utils.LayerVersionSelector</li>
174
</ul>
175
<h3>Methods</h3>
176
<dl>
177
<dt id="VITAE.model.Sampling.call"><code class="name flex">
178
<span>def <span class="ident">call</span></span>(<span>self, z_mean, z_log_var)</span>
179
</code></dt>
180
<dd>
181
<div class="desc"><p>Sample the latent variable from its mean and log-variance.</p>
182
<h2 id="parameters">Parameters</h2>
183
<dl>
184
<dt><strong><code>z_mean</code></strong> :&ensp;<code>tf.Tensor</code></dt>
185
<dd><span><span class="MathJax_Preview">[B, L, d]</span><script type="math/tex">[B, L, d]</script></span> The mean of <span><span class="MathJax_Preview">z</span><script type="math/tex">z</script></span>.</dd>
186
<dt><strong><code>z_log_var</code></strong> :&ensp;<code>tf.Tensor</code></dt>
187
<dd><span><span class="MathJax_Preview">[B, L, d]</span><script type="math/tex">[B, L, d]</script></span> The log-variance of <span><span class="MathJax_Preview">z</span><script type="math/tex">z</script></span>.</dd>
188
</dl>
189
<h2 id="returns">Returns</h2>
190
<dl>
191
<dt><strong><code>z</code></strong> :&ensp;<code>tf.Tensor</code></dt>
192
<dd><span><span class="MathJax_Preview">[B, L, d]</span><script type="math/tex">[B, L, d]</script></span> The sampled <span><span class="MathJax_Preview">z</span><script type="math/tex">z</script></span>.</dd>
193
</dl></div>
194
</dd>
195
</dl>
196
</dd>
197
<dt id="VITAE.model.Encoder"><code class="flex name class">
198
<span>class <span class="ident">Encoder</span></span>
199
<span>(</span><span>dimensions, dim_latent, name='encoder', **kwargs)</span>
200
</code></dt>
201
<dd>
202
<div class="desc"><p>Encoder, model <span><span class="MathJax_Preview">p(Z_i|Y_i,X_i)</span><script type="math/tex">p(Z_i|Y_i,X_i)</script></span>.</p>
203
<h2 id="parameters">Parameters</h2>
204
<dl>
205
<dt><strong><code>dimensions</code></strong> :&ensp;<code>np.array</code></dt>
206
<dd>The dimensions of hidden layers of the encoder.</dd>
207
<dt><strong><code>dim_latent</code></strong> :&ensp;<code>int</code></dt>
208
<dd>The latent dimension of the encoder.</dd>
209
<dt><strong><code>name</code></strong> :&ensp;<code>str</code>, optional</dt>
210
<dd>The name of the layer.</dd>
211
<dt><strong><code>**kwargs</code></strong></dt>
212
<dd>Extra keyword arguments.</dd>
213
</dl></div>
214
<details class="source">
215
<summary>
216
<span>Expand source code</span>
217
</summary>
218
<pre><code class="python">class Encoder(Layer):
219
    &#39;&#39;&#39;
220
    Encoder, model \(p(Z_i|Y_i,X_i)\).
221
    &#39;&#39;&#39;
222
    def __init__(self, dimensions, dim_latent, name=&#39;encoder&#39;, **kwargs):
223
        &#39;&#39;&#39;
224
        Parameters
225
        ----------
226
        dimensions : np.array
227
            The dimensions of hidden layers of the encoder.
228
        dim_latent : int
229
            The latent dimension of the encoder.
230
        name : str, optional
231
            The name of the layer.
232
        **kwargs : 
233
            Extra keyword arguments.
234
        &#39;&#39;&#39; 
235
        super(Encoder, self).__init__(name = name, **kwargs)
236
        self.dense_layers = [Dense(dim, activation = tf.nn.leaky_relu,
237
                                          name = &#39;encoder_%i&#39;%(i+1)) \
238
                             for (i, dim) in enumerate(dimensions)]
239
        self.batch_norm_layers = [BatchNormalization(center=False) \
240
                                    for _ in range(len((dimensions)))]
241
        self.batch_norm_layers.append(BatchNormalization(center=False))
242
        self.latent_mean = Dense(dim_latent, name = &#39;latent_mean&#39;)
243
        self.latent_log_var = Dense(dim_latent, name = &#39;latent_log_var&#39;)
244
        self.sampling = Sampling()
245
    
246
    @tf.function
247
    def call(self, x, L=1, is_training=True):
248
        &#39;&#39;&#39;Encode the inputs and get the latent variables.
249
250
        Parameters
251
        ----------
252
        x : tf.Tensor
253
            \([B, L, d]\) The input.
254
        L : int, optional
255
            The number of MC samples.
256
        is_training : boolean, optional
257
            Whether in the training or inference mode.
258
        
259
        Returns
260
        ----------
261
        z_mean : tf.Tensor
262
            \([B, L, d]\) The mean of \(z\).
263
        z_log_var : tf.Tensor
264
            \([B, L, d]\) The log-variance of \(z\).
265
        z : tf.Tensor
266
            \([B, L, d]\) The sampled \(z\).
267
        &#39;&#39;&#39;         
268
        for dense, bn in zip(self.dense_layers, self.batch_norm_layers):
269
            x = dense(x)
270
            x = bn(x, training=is_training)
271
        z_mean = self.batch_norm_layers[-1](self.latent_mean(x), training=is_training)
272
        z_log_var = self.latent_log_var(x)
273
        _z_mean = tf.tile(tf.expand_dims(z_mean, 1), (1,L,1))
274
        _z_log_var = tf.tile(tf.expand_dims(z_log_var, 1), (1,L,1))
275
        z = self.sampling(_z_mean, _z_log_var)
276
        return z_mean, z_log_var, z</code></pre>
277
</details>
278
<h3>Ancestors</h3>
279
<ul class="hlist">
280
<li>keras.src.engine.base_layer.Layer</li>
281
<li>tensorflow.python.module.module.Module</li>
282
<li>tensorflow.python.trackable.autotrackable.AutoTrackable</li>
283
<li>tensorflow.python.trackable.base.Trackable</li>
284
<li>keras.src.utils.version_utils.LayerVersionSelector</li>
285
</ul>
286
<h3>Methods</h3>
287
<dl>
288
<dt id="VITAE.model.Encoder.call"><code class="name flex">
289
<span>def <span class="ident">call</span></span>(<span>self, x, L=1, is_training=True)</span>
290
</code></dt>
291
<dd>
292
<div class="desc"><p>Encode the inputs and get the latent variables.</p>
293
<h2 id="parameters">Parameters</h2>
294
<dl>
295
<dt><strong><code>x</code></strong> :&ensp;<code>tf.Tensor</code></dt>
296
<dd><span><span class="MathJax_Preview">[B, L, d]</span><script type="math/tex">[B, L, d]</script></span> The input.</dd>
297
<dt><strong><code>L</code></strong> :&ensp;<code>int</code>, optional</dt>
298
<dd>The number of MC samples.</dd>
299
<dt><strong><code>is_training</code></strong> :&ensp;<code>boolean</code>, optional</dt>
300
<dd>Whether in the training or inference mode.</dd>
301
</dl>
302
<h2 id="returns">Returns</h2>
303
<dl>
304
<dt><strong><code>z_mean</code></strong> :&ensp;<code>tf.Tensor</code></dt>
305
<dd><span><span class="MathJax_Preview">[B, L, d]</span><script type="math/tex">[B, L, d]</script></span> The mean of <span><span class="MathJax_Preview">z</span><script type="math/tex">z</script></span>.</dd>
306
<dt><strong><code>z_log_var</code></strong> :&ensp;<code>tf.Tensor</code></dt>
307
<dd><span><span class="MathJax_Preview">[B, L, d]</span><script type="math/tex">[B, L, d]</script></span> The log-variance of <span><span class="MathJax_Preview">z</span><script type="math/tex">z</script></span>.</dd>
308
<dt><strong><code>z</code></strong> :&ensp;<code>tf.Tensor</code></dt>
309
<dd><span><span class="MathJax_Preview">[B, L, d]</span><script type="math/tex">[B, L, d]</script></span> The sampled <span><span class="MathJax_Preview">z</span><script type="math/tex">z</script></span>.</dd>
310
</dl></div>
311
</dd>
312
</dl>
313
</dd>
314
<dt id="VITAE.model.Decoder"><code class="flex name class">
315
<span>class <span class="ident">Decoder</span></span>
316
<span>(</span><span>dimensions, dim_origin, data_type='UMI', name='decoder', **kwargs)</span>
317
</code></dt>
318
<dd>
319
<div class="desc"><p>Decoder, model <span><span class="MathJax_Preview">p(Y_i|Z_i,X_i)</span><script type="math/tex">p(Y_i|Z_i,X_i)</script></span>.</p>
320
<h2 id="parameters">Parameters</h2>
321
<dl>
322
<dt><strong><code>dimensions</code></strong> :&ensp;<code>np.array</code></dt>
323
<dd>The dimensions of hidden layers of the encoder.</dd>
324
<dt><strong><code>dim_origin</code></strong> :&ensp;<code>int</code></dt>
325
<dd>The output dimension of the decoder.</dd>
326
<dt><strong><code>data_type</code></strong> :&ensp;<code>str</code>, optional</dt>
327
<dd><code>'UMI'</code>, <code>'non-UMI'</code>, or <code>'Gaussian'</code>.</dd>
328
<dt><strong><code>name</code></strong> :&ensp;<code>str</code>, optional</dt>
329
<dd>The name of the layer.</dd>
330
</dl></div>
331
<details class="source">
332
<summary>
333
<span>Expand source code</span>
334
</summary>
335
<pre><code class="python">class Decoder(Layer):
336
    &#39;&#39;&#39;
337
    Decoder, model \(p(Y_i|Z_i,X_i)\).
338
    &#39;&#39;&#39;
339
    def __init__(self, dimensions, dim_origin, data_type = &#39;UMI&#39;, 
340
                name = &#39;decoder&#39;, **kwargs):
341
        &#39;&#39;&#39;
342
        Parameters
343
        ----------
344
        dimensions : np.array
345
            The dimensions of hidden layers of the encoder.
346
        dim_origin : int
347
            The output dimension of the decoder.
348
        data_type : str, optional
349
            `&#39;UMI&#39;`, `&#39;non-UMI&#39;`, or `&#39;Gaussian&#39;`.
350
        name : str, optional
351
            The name of the layer.
352
        &#39;&#39;&#39;
353
        super(Decoder, self).__init__(name = name, **kwargs)
354
        self.data_type = data_type
355
        self.dense_layers = [Dense(dim, activation = tf.nn.leaky_relu,
356
                                          name = &#39;decoder_%i&#39;%(i+1)) \
357
                             for (i,dim) in enumerate(dimensions)]
358
        self.batch_norm_layers = [BatchNormalization(center=False) \
359
                                    for _ in range(len((dimensions)))]
360
361
        if data_type==&#39;Gaussian&#39;:
362
            self.nu_z = Dense(dim_origin, name = &#39;nu_z&#39;)
363
            # common variance
364
            self.log_tau = tf.Variable(tf.zeros([1, dim_origin], dtype=tf.keras.backend.floatx()),
365
                                 constraint = lambda t: tf.clip_by_value(t,-30.,6.),
366
                                 name = &#34;log_tau&#34;)
367
        else:
368
            self.log_lambda_z = Dense(dim_origin, name = &#39;log_lambda_z&#39;)
369
370
            # dispersion parameter
371
            self.log_r = tf.Variable(tf.zeros([1, dim_origin], dtype=tf.keras.backend.floatx()),
372
                                     constraint = lambda t: tf.clip_by_value(t,-30.,6.),
373
                                     name = &#34;log_r&#34;)
374
            
375
            if self.data_type == &#39;non-UMI&#39;:
376
                self.phi = Dense(dim_origin, activation = &#39;sigmoid&#39;, name = &#34;phi&#34;)
377
          
378
    @tf.function  
379
    def call(self, z, is_training=True):
380
        &#39;&#39;&#39;Decode the latent variables and get the reconstructions.
381
382
        Parameters
383
        ----------
384
        z : tf.Tensor
385
            \([B, L, d]\) the sampled \(z\).
386
        is_training : boolean, optional
387
            whether in the training or inference mode.
388
389
        When `data_type==&#39;Gaussian&#39;`:
390
391
        Returns
392
        ----------
393
        nu_z : tf.Tensor
394
            \([B, L, G]\) The mean of \(Y_i|Z_i,X_i\).
395
        tau : tf.Tensor
396
            \([1, G]\) The variance of \(Y_i|Z_i,X_i\).
397
398
        When `data_type==&#39;UMI&#39;`:
399
400
        Returns
401
        ----------
402
        lambda_z : tf.Tensor
403
            \([B, L, G]\) The mean of \(Y_i|Z_i,X_i\).
404
        r : tf.Tensor
405
            \([1, G]\) The dispersion parameters of \(Y_i|Z_i,X_i\).
406
407
        When `data_type==&#39;non-UMI&#39;`:
408
409
        Returns
410
        ----------
411
        lambda_z : tf.Tensor
412
            \([B, L, G]\) The mean of \(Y_i|Z_i,X_i\).
413
        r : tf.Tensor
414
            \([1, G]\) The dispersion parameters of \(Y_i|Z_i,X_i\).
415
        phi_z : tf.Tensor
416
            \([1, G]\) The zero inflated parameters of \(Y_i|Z_i,X_i\).
417
        &#39;&#39;&#39;
418
        for dense, bn in zip(self.dense_layers, self.batch_norm_layers):
419
            z = dense(z)
420
            z = bn(z, training=is_training)
421
        if self.data_type==&#39;Gaussian&#39;:
422
            nu_z = self.nu_z(z)
423
            tau = tf.exp(self.log_tau)
424
            return nu_z, tau
425
        else:
426
            lambda_z = tf.math.exp(
427
                tf.clip_by_value(self.log_lambda_z(z), -30., 6.)
428
                )
429
            r = tf.exp(self.log_r)
430
            if self.data_type==&#39;UMI&#39;:
431
                return lambda_z, r
432
            else:
433
                return lambda_z, r, self.phi(z)</code></pre>
434
</details>
435
<h3>Ancestors</h3>
436
<ul class="hlist">
437
<li>keras.src.engine.base_layer.Layer</li>
438
<li>tensorflow.python.module.module.Module</li>
439
<li>tensorflow.python.trackable.autotrackable.AutoTrackable</li>
440
<li>tensorflow.python.trackable.base.Trackable</li>
441
<li>keras.src.utils.version_utils.LayerVersionSelector</li>
442
</ul>
443
<h3>Methods</h3>
444
<dl>
445
<dt id="VITAE.model.Decoder.call"><code class="name flex">
446
<span>def <span class="ident">call</span></span>(<span>self, z, is_training=True)</span>
447
</code></dt>
448
<dd>
449
<div class="desc"><p>Decode the latent variables and get the reconstructions.</p>
450
<h2 id="parameters">Parameters</h2>
451
<dl>
452
<dt><strong><code>z</code></strong> :&ensp;<code>tf.Tensor</code></dt>
453
<dd><span><span class="MathJax_Preview">[B, L, d]</span><script type="math/tex">[B, L, d]</script></span> the sampled <span><span class="MathJax_Preview">z</span><script type="math/tex">z</script></span>.</dd>
454
<dt><strong><code>is_training</code></strong> :&ensp;<code>boolean</code>, optional</dt>
455
<dd>whether in the training or inference mode.</dd>
456
</dl>
457
<p>When <code>data_type=='Gaussian'</code>:</p>
458
<h2 id="returns">Returns</h2>
459
<dl>
460
<dt><strong><code>nu_z</code></strong> :&ensp;<code>tf.Tensor</code></dt>
461
<dd><span><span class="MathJax_Preview">[B, L, G]</span><script type="math/tex">[B, L, G]</script></span> The mean of <span><span class="MathJax_Preview">Y_i|Z_i,X_i</span><script type="math/tex">Y_i|Z_i,X_i</script></span>.</dd>
462
<dt><strong><code>tau</code></strong> :&ensp;<code>tf.Tensor</code></dt>
463
<dd><span><span class="MathJax_Preview">[1, G]</span><script type="math/tex">[1, G]</script></span> The variance of <span><span class="MathJax_Preview">Y_i|Z_i,X_i</span><script type="math/tex">Y_i|Z_i,X_i</script></span>.</dd>
464
</dl>
465
<p>When <code>data_type=='UMI'</code>:</p>
466
<h2 id="returns_1">Returns</h2>
467
<dl>
468
<dt><strong><code>lambda_z</code></strong> :&ensp;<code>tf.Tensor</code></dt>
469
<dd><span><span class="MathJax_Preview">[B, L, G]</span><script type="math/tex">[B, L, G]</script></span> The mean of <span><span class="MathJax_Preview">Y_i|Z_i,X_i</span><script type="math/tex">Y_i|Z_i,X_i</script></span>.</dd>
470
<dt><strong><code>r</code></strong> :&ensp;<code>tf.Tensor</code></dt>
471
<dd><span><span class="MathJax_Preview">[1, G]</span><script type="math/tex">[1, G]</script></span> The dispersion parameters of <span><span class="MathJax_Preview">Y_i|Z_i,X_i</span><script type="math/tex">Y_i|Z_i,X_i</script></span>.</dd>
472
</dl>
473
<p>When <code>data_type=='non-UMI'</code>:</p>
474
<h2 id="returns_2">Returns</h2>
475
<dl>
476
<dt><strong><code>lambda_z</code></strong> :&ensp;<code>tf.Tensor</code></dt>
477
<dd><span><span class="MathJax_Preview">[B, L, G]</span><script type="math/tex">[B, L, G]</script></span> The mean of <span><span class="MathJax_Preview">Y_i|Z_i,X_i</span><script type="math/tex">Y_i|Z_i,X_i</script></span>.</dd>
478
<dt><strong><code>r</code></strong> :&ensp;<code>tf.Tensor</code></dt>
479
<dd><span><span class="MathJax_Preview">[1, G]</span><script type="math/tex">[1, G]</script></span> The dispersion parameters of <span><span class="MathJax_Preview">Y_i|Z_i,X_i</span><script type="math/tex">Y_i|Z_i,X_i</script></span>.</dd>
480
<dt><strong><code>phi_z</code></strong> :&ensp;<code>tf.Tensor</code></dt>
481
<dd><span><span class="MathJax_Preview">[1, G]</span><script type="math/tex">[1, G]</script></span> The zero inflated parameters of <span><span class="MathJax_Preview">Y_i|Z_i,X_i</span><script type="math/tex">Y_i|Z_i,X_i</script></span>.</dd>
482
</dl></div>
483
</dd>
484
</dl>
485
</dd>
486
<dt id="VITAE.model.LatentSpace"><code class="flex name class">
487
<span>class <span class="ident">LatentSpace</span></span>
488
<span>(</span><span>n_clusters, dim_latent, name='LatentSpace', seed=0, **kwargs)</span>
489
</code></dt>
490
<dd>
491
<div class="desc"><p>Layer for the Latent Space.</p>
492
<h2 id="parameters">Parameters</h2>
493
<dl>
494
<dt><strong><code>n_clusters</code></strong> :&ensp;<code>int</code></dt>
495
<dd>The number of vertices in the latent space.</dd>
496
<dt><strong><code>dim_latent</code></strong> :&ensp;<code>int</code></dt>
497
<dd>The latent dimension.</dd>
498
<dt><strong><code>M</code></strong> :&ensp;<code>int</code>, optional</dt>
499
<dd>The discretized number of uniform(0,1).</dd>
500
<dt><strong><code>name</code></strong> :&ensp;<code>str</code>, optional</dt>
501
<dd>The name of the layer.</dd>
502
<dt><strong><code>**kwargs</code></strong></dt>
503
<dd>Extra keyword arguments.</dd>
504
</dl></div>
505
<details class="source">
506
<summary>
507
<span>Expand source code</span>
508
</summary>
509
<pre><code class="python">class LatentSpace(Layer):
510
    &#39;&#39;&#39;
511
    Layer for the Latent Space.
512
    &#39;&#39;&#39;
513
    def __init__(self, n_clusters, dim_latent,
514
            name = &#39;LatentSpace&#39;, seed=0, **kwargs):
515
        &#39;&#39;&#39;
516
        Parameters
517
        ----------
518
        n_clusters : int
519
            The number of vertices in the latent space.
520
        dim_latent : int
521
            The latent dimension.
522
        M : int, optional
523
            The discretized number of uniform(0,1).
524
        name : str, optional
525
            The name of the layer.
526
        **kwargs : 
527
            Extra keyword arguments.
528
        &#39;&#39;&#39;
529
        super(LatentSpace, self).__init__(name=name, **kwargs)
530
        self.dim_latent = dim_latent
531
        self.n_states = n_clusters
532
        self.n_categories = int(n_clusters*(n_clusters+1)/2)
533
534
        # nonzero indexes
535
        # A = [0,0,...,0  , 1,1,...,1,   ...]
536
        # B = [0,1,...,k-1, 1,2,...,k-1, ...]
537
        self.A, self.B = np.nonzero(np.triu(np.ones(n_clusters)))
538
        self.A = tf.convert_to_tensor(self.A, tf.int32)
539
        self.B = tf.convert_to_tensor(self.B, tf.int32)
540
        self.clusters_ind = tf.boolean_mask(
541
            tf.range(0,self.n_categories,1), self.A==self.B)
542
543
        # [pi_1, ... , pi_K] in R^(n_categories)
544
        self.pi = tf.Variable(tf.ones([1, self.n_categories], dtype=tf.keras.backend.floatx()) / self.n_categories,
545
                                name = &#39;pi&#39;)
546
        
547
        # [mu_1, ... , mu_K] in R^(dim_latent * n_clusters)
548
        self.mu = tf.Variable(tf.random.uniform([self.dim_latent, self.n_states],
549
                                                minval = -1, maxval = 1, seed=seed, dtype=tf.keras.backend.floatx()),
550
                                name = &#39;mu&#39;)
551
        self.cdf_layer = cdf_layer()       
552
        
553
    def initialize(self, mu, log_pi):
554
        &#39;&#39;&#39;Initialize the latent space.
555
556
        Parameters
557
        ----------
558
        mu : np.array
559
            \([d, k]\) The position matrix.
560
        log_pi : np.array
561
            \([1, K]\) \(\\log\\pi\).
562
        &#39;&#39;&#39;
563
        # Initialize parameters of the latent space
564
        if mu is not None:
565
            self.mu.assign(mu)
566
        if log_pi is not None:
567
            self.pi.assign(log_pi)
568
569
    def normalize(self):
570
        &#39;&#39;&#39;Normalize \(\\pi\).
571
        &#39;&#39;&#39;
572
        self.pi = tf.nn.softmax(self.pi)
573
574
    @tf.function
575
    def _get_normal_params(self, z, pi):
576
        batch_size = tf.shape(z)[0]
577
        L = tf.shape(z)[1]
578
        
579
        # [batch_size, L, n_categories]
580
        if pi is None:
581
            # [batch_size, L, n_states]
582
            temp_pi = tf.tile(
583
                tf.expand_dims(tf.nn.softmax(self.pi), 1),
584
                (batch_size,L,1))
585
        else:
586
            temp_pi = tf.expand_dims(tf.nn.softmax(pi), 1)
587
588
        # [batch_size, L, d, n_categories]
589
        alpha_zc = tf.expand_dims(tf.expand_dims(
590
            tf.gather(self.mu, self.B, axis=1) - tf.gather(self.mu, self.A, axis=1), 0), 0)
591
        beta_zc = tf.expand_dims(z,-1) - \
592
            tf.expand_dims(tf.expand_dims(
593
            tf.gather(self.mu, self.B, axis=1), 0), 0)
594
            
595
        # [batch_size, L, n_categories]
596
        inv_sig = tf.reduce_sum(alpha_zc * alpha_zc, axis=2)
597
        nu = - tf.reduce_sum(alpha_zc * beta_zc, axis=2) * tf.math.reciprocal_no_nan(inv_sig)
598
        _t = - tf.reduce_sum(beta_zc * beta_zc, axis=2) + nu**2*inv_sig
599
        return temp_pi, beta_zc, inv_sig, nu, _t
600
    
601
    @tf.function
602
    def _get_pz(self, temp_pi, inv_sig, beta_zc, log_p_z_c_L):
603
        # [batch_size, L, n_categories]
604
        log_p_zc_L = - 0.5 * self.dim_latent * tf.math.log(tf.constant(2 * np.pi, tf.keras.backend.floatx())) + \
605
            tf.math.log(temp_pi) + \
606
            tf.where(inv_sig==0, 
607
                    - 0.5 * tf.reduce_sum(beta_zc**2, axis=2), 
608
                    log_p_z_c_L)
609
        
610
        # [batch_size, L, 1]
611
        log_p_z_L = tf.reduce_logsumexp(log_p_zc_L, axis=-1, keepdims=True)
612
        
613
        # [1, ]
614
        log_p_z = tf.reduce_mean(log_p_z_L)
615
        return log_p_zc_L, log_p_z_L, log_p_z
616
    
617
    @tf.function
618
    def _get_posterior_c(self, log_p_zc_L, log_p_z_L):
619
        L = tf.shape(log_p_z_L)[1]
620
621
        # log_p_c_x     -   predicted probability distribution
622
        # [batch_size, n_categories]
623
        log_p_c_x = tf.reduce_logsumexp(
624
                        log_p_zc_L - log_p_z_L,
625
                    axis=1) - tf.math.log(tf.cast(L, tf.keras.backend.floatx()))
626
        return log_p_c_x
627
628
    @tf.function
629
    def _get_inference(self, z, log_p_z_L, temp_pi, inv_sig, nu, beta_zc, log_eta0, eta1, eta2):
630
        batch_size = tf.shape(z)[0]
631
        L = tf.shape(z)[1]
632
        dist = tfp.distributions.Normal(
633
            loc = tf.constant(0.0, tf.keras.backend.floatx()), 
634
            scale = tf.constant(1.0, tf.keras.backend.floatx()), 
635
            allow_nan_stats=False)
636
        
637
        # [batch_size, L, n_categories, n_clusters]
638
        inv_sig = tf.expand_dims(inv_sig, -1)
639
        _sig = tf.tile(tf.clip_by_value(tf.math.reciprocal_no_nan(inv_sig), 1e-12, 1e30), (1,1,1,self.n_states))
640
        log_eta0 = tf.tile(tf.expand_dims(log_eta0, -1), (1,1,1,self.n_states))
641
        eta1 = tf.tile(tf.expand_dims(eta1, -1), (1,1,1,self.n_states))
642
        eta2 = tf.tile(tf.expand_dims(eta2, -1), (1,1,1,self.n_states))
643
        nu = tf.tile(tf.expand_dims(nu, -1), (1,1,1,1))
644
        A = tf.tile(tf.expand_dims(tf.expand_dims(
645
            tf.one_hot(self.A, self.n_states, dtype=tf.keras.backend.floatx()), 
646
            0),0), (batch_size,L,1,1))
647
        B = tf.tile(tf.expand_dims(tf.expand_dims(
648
            tf.one_hot(self.B, self.n_states, dtype=tf.keras.backend.floatx()), 
649
            0),0), (batch_size,L,1,1))
650
        temp_pi = tf.expand_dims(temp_pi, -1)
651
652
        # w_tilde [batch_size, L, n_clusters]
653
        w_tilde = log_eta0 + tf.math.log(
654
            tf.clip_by_value(
655
                (dist.cdf(eta1) - dist.cdf(eta2)) * (nu * A + (1-nu) * B)  -
656
                (dist.prob(eta1) - dist.prob(eta2)) * tf.math.sqrt(_sig) * (A - B), 
657
                0.0, 1e30)
658
            )
659
        w_tilde = - 0.5 * self.dim_latent * tf.math.log(tf.constant(2 * np.pi, tf.keras.backend.floatx())) + \
660
            tf.math.log(temp_pi) + \
661
            tf.where(inv_sig==0, 
662
                    tf.where(B==1, - 0.5 * tf.expand_dims(tf.reduce_sum(beta_zc**2, axis=2), -1), -np.inf), 
663
                    w_tilde)
664
        w_tilde = tf.exp(tf.reduce_logsumexp(w_tilde, 2) - log_p_z_L)
665
666
        # tf.debugging.assert_greater_equal(
667
        #     tf.reduce_sum(w_tilde, -1), tf.ones([batch_size, L], dtype=tf.keras.backend.floatx())*0.99, 
668
        #     message=&#39;Wrong w_tilde&#39;, summarize=None, name=None
669
        # )
670
        
671
        # var_w_tilde [batch_size, L, n_clusters]
672
        var_w_tilde = log_eta0 + tf.math.log(
673
            tf.clip_by_value(
674
                (dist.cdf(eta1) -  dist.cdf(eta2)) * ((_sig + nu**2) * (A+B) + (1-2*nu) * B)  -
675
                (dist.prob(eta1) - dist.prob(eta2)) * tf.math.sqrt(_sig) * (nu *(A+B)-B )*2 -
676
                (eta1*dist.prob(eta1) - eta2*dist.prob(eta2)) * _sig *(A+B), 
677
                0.0, 1e30)
678
            )
679
        var_w_tilde = - 0.5 * self.dim_latent * tf.math.log(tf.constant(2 * np.pi, tf.keras.backend.floatx())) + \
680
            tf.math.log(temp_pi) + \
681
            tf.where(inv_sig==0, 
682
                    tf.where(B==1, - 0.5 * tf.expand_dims(tf.reduce_sum(beta_zc**2, axis=2), -1), -np.inf), 
683
                    var_w_tilde) 
684
        var_w_tilde = tf.exp(tf.reduce_logsumexp(var_w_tilde, 2) - log_p_z_L) - w_tilde**2  
685
686
687
        w_tilde = tf.reduce_mean(w_tilde, 1)
688
        var_w_tilde = tf.reduce_mean(var_w_tilde, 1)
689
        return w_tilde, var_w_tilde
690
691
    def get_pz(self, z, eps, pi):
692
        &#39;&#39;&#39;Get \(\\log p(Z_i|Y_i,X_i)\).
693
694
        Parameters
695
        ----------
696
        z : tf.Tensor
697
            \([B, L, d]\) The latent variables.
698
699
        Returns
700
        ----------
701
        temp_pi : tf.Tensor
702
            \([B, L, K]\) \(\\pi\).
703
        inv_sig : tf.Tensor
704
            \([B, L, K]\) \(\\sigma_{Z_ic_i}^{-1}\).
705
        nu : tf.Tensor
706
            \([B, L, K]\) \(\\nu_{Z_ic_i}\).
707
        beta_zc : tf.Tensor
708
            \([B, L, d, K]\) \(\\beta_{Z_ic_i}\).
709
        log_eta0 : tf.Tensor
710
            \([B, L, K]\) \(\\log\\eta_{Z_ic_i,0}\).
711
        eta1 : tf.Tensor
712
            \([B, L, K]\) \(\\eta_{Z_ic_i,1}\).
713
        eta2 : tf.Tensor
714
            \([B, L, K]\) \(\\eta_{Z_ic_i,2}\).
715
        log_p_zc_L : tf.Tensor
716
            \([B, L, K]\) \(\\log p(Z_i,c_i|Y_i,X_i)\).
717
        log_p_z_L : tf.Tensor
718
            \([B, L]\) \(\\log p(Z_i|Y_i,X_i)\).
719
        log_p_z : tf.Tensor
720
            \([B, 1]\) The estimated \(\\log p(Z_i|Y_i,X_i)\). 
721
        &#39;&#39;&#39;        
722
        temp_pi, beta_zc, inv_sig, nu, _t = self._get_normal_params(z, pi)
723
        temp_pi = tf.clip_by_value(temp_pi, eps, 1.0)
724
725
        log_eta0 = 0.5 * (tf.math.log(tf.constant(2 * np.pi, tf.keras.backend.floatx())) - \
726
                    tf.math.log(tf.clip_by_value(inv_sig, 1e-12, 1e30)) + _t)
727
        eta1 = (1-nu) * tf.math.sqrt(tf.clip_by_value(inv_sig, 1e-12, 1e30))
728
        eta2 = -nu * tf.math.sqrt(tf.clip_by_value(inv_sig, 1e-12, 1e30))
729
730
        log_p_z_c_L =  log_eta0 + tf.math.log(tf.clip_by_value(
731
            self.cdf_layer(eta1) - self.cdf_layer(eta2),
732
            eps, 1e30))
733
        
734
        log_p_zc_L, log_p_z_L, log_p_z = self._get_pz(temp_pi, inv_sig, beta_zc, log_p_z_c_L)
735
        return temp_pi, inv_sig, nu, beta_zc, log_eta0, eta1, eta2, log_p_zc_L, log_p_z_L, log_p_z
736
737
    def get_posterior_c(self, z):
738
        &#39;&#39;&#39;Get \(p(c_i|Y_i,X_i)\).
739
740
        Parameters
741
        ----------
742
        z : tf.Tensor
743
            \([B, L, d]\) The latent variables.
744
745
        Returns
746
        ----------
747
        p_c_x : np.array
748
            \([B, K]\) \(p(c_i|Y_i,X_i)\).
749
        &#39;&#39;&#39;  
750
        _,_,_,_,_,_,_, log_p_zc_L, log_p_z_L, _ = self.get_pz(z)
751
        log_p_c_x = self._get_posterior_c(log_p_zc_L, log_p_z_L)
752
        p_c_x = tf.exp(log_p_c_x).numpy()
753
        return p_c_x
754
755
    def call(self, z, pi=None, inference=False):
756
        &#39;&#39;&#39;Get posterior estimations.
757
758
        Parameters
759
        ----------
760
        z : tf.Tensor
761
            \([B, L, d]\) The latent variables.
762
        inference : boolean
763
            Whether in training or inference mode.
764
765
        When `inference=False`:
766
767
        Returns
768
        ----------
769
        log_p_z_L : tf.Tensor
770
            \([B, 1]\) The estimated \(\\log p(Z_i|Y_i,X_i)\).
771
772
        When `inference=True`:
773
774
        Returns
775
        ----------
776
        res : dict
777
            The dict of posterior estimations - \(p(c_i|Y_i,X_i)\), \(c\), \(E(\\tilde{w}_i|Y_i,X_i)\), \(Var(\\tilde{w}_i|Y_i,X_i)\), \(D_{JS}\).
778
        &#39;&#39;&#39;                 
779
        eps = 1e-16 if not inference else 0.
780
        temp_pi, inv_sig, nu, beta_zc, log_eta0, eta1, eta2, log_p_zc_L, log_p_z_L, log_p_z = self.get_pz(z, eps, pi)
781
782
        if not inference:
783
            return log_p_z
784
        else:
785
            log_p_c_x = self._get_posterior_c(log_p_zc_L, log_p_z_L)
786
            w_tilde, var_w_tilde = self._get_inference(z, log_p_z_L, temp_pi, inv_sig, nu, beta_zc, log_eta0, eta1, eta2)
787
            
788
            res = {}
789
            res[&#39;p_c_x&#39;] = tf.exp(log_p_c_x).numpy()
790
            res[&#39;w_tilde&#39;] = w_tilde.numpy()
791
            res[&#39;var_w_tilde&#39;] = var_w_tilde.numpy()
792
            return res</code></pre>
793
</details>
794
<h3>Ancestors</h3>
795
<ul class="hlist">
796
<li>keras.src.engine.base_layer.Layer</li>
797
<li>tensorflow.python.module.module.Module</li>
798
<li>tensorflow.python.trackable.autotrackable.AutoTrackable</li>
799
<li>tensorflow.python.trackable.base.Trackable</li>
800
<li>keras.src.utils.version_utils.LayerVersionSelector</li>
801
</ul>
802
<h3>Methods</h3>
803
<dl>
804
<dt id="VITAE.model.LatentSpace.initialize"><code class="name flex">
805
<span>def <span class="ident">initialize</span></span>(<span>self, mu, log_pi)</span>
806
</code></dt>
807
<dd>
808
<div class="desc"><p>Initialize the latent space.</p>
809
<h2 id="parameters">Parameters</h2>
810
<dl>
811
<dt><strong><code>mu</code></strong> :&ensp;<code>np.array</code></dt>
812
<dd><span><span class="MathJax_Preview">[d, k]</span><script type="math/tex">[d, k]</script></span> The position matrix.</dd>
813
<dt><strong><code>log_pi</code></strong> :&ensp;<code>np.array</code></dt>
814
<dd><span><span class="MathJax_Preview">[1, K]</span><script type="math/tex">[1, K]</script></span> <span><span class="MathJax_Preview">\log\pi</span><script type="math/tex">\log\pi</script></span>.</dd>
815
</dl></div>
816
</dd>
817
<dt id="VITAE.model.LatentSpace.normalize"><code class="name flex">
818
<span>def <span class="ident">normalize</span></span>(<span>self)</span>
819
</code></dt>
820
<dd>
821
<div class="desc"><p>Normalize <span><span class="MathJax_Preview">\pi</span><script type="math/tex">\pi</script></span>.</p></div>
822
</dd>
823
<dt id="VITAE.model.LatentSpace.get_pz"><code class="name flex">
824
<span>def <span class="ident">get_pz</span></span>(<span>self, z, eps, pi)</span>
825
</code></dt>
826
<dd>
827
<div class="desc"><p>Get <span><span class="MathJax_Preview">\log p(Z_i|Y_i,X_i)</span><script type="math/tex">\log p(Z_i|Y_i,X_i)</script></span>.</p>
828
<h2 id="parameters">Parameters</h2>
829
<dl>
830
<dt><strong><code>z</code></strong> :&ensp;<code>tf.Tensor</code></dt>
831
<dd><span><span class="MathJax_Preview">[B, L, d]</span><script type="math/tex">[B, L, d]</script></span> The latent variables.</dd>
832
</dl>
833
<h2 id="returns">Returns</h2>
834
<dl>
835
<dt><strong><code>temp_pi</code></strong> :&ensp;<code>tf.Tensor</code></dt>
836
<dd><span><span class="MathJax_Preview">[B, L, K]</span><script type="math/tex">[B, L, K]</script></span> <span><span class="MathJax_Preview">\pi</span><script type="math/tex">\pi</script></span>.</dd>
837
<dt><strong><code>inv_sig</code></strong> :&ensp;<code>tf.Tensor</code></dt>
838
<dd><span><span class="MathJax_Preview">[B, L, K]</span><script type="math/tex">[B, L, K]</script></span> <span><span class="MathJax_Preview">\sigma_{Z_ic_i}^{-1}</span><script type="math/tex">\sigma_{Z_ic_i}^{-1}</script></span>.</dd>
839
<dt><strong><code>nu</code></strong> :&ensp;<code>tf.Tensor</code></dt>
840
<dd><span><span class="MathJax_Preview">[B, L, K]</span><script type="math/tex">[B, L, K]</script></span> <span><span class="MathJax_Preview">\nu_{Z_ic_i}</span><script type="math/tex">\nu_{Z_ic_i}</script></span>.</dd>
841
<dt><strong><code>beta_zc</code></strong> :&ensp;<code>tf.Tensor</code></dt>
842
<dd><span><span class="MathJax_Preview">[B, L, d, K]</span><script type="math/tex">[B, L, d, K]</script></span> <span><span class="MathJax_Preview">\beta_{Z_ic_i}</span><script type="math/tex">\beta_{Z_ic_i}</script></span>.</dd>
843
<dt><strong><code>log_eta0</code></strong> :&ensp;<code>tf.Tensor</code></dt>
844
<dd><span><span class="MathJax_Preview">[B, L, K]</span><script type="math/tex">[B, L, K]</script></span> <span><span class="MathJax_Preview">\log\eta_{Z_ic_i,0}</span><script type="math/tex">\log\eta_{Z_ic_i,0}</script></span>.</dd>
845
<dt><strong><code>eta1</code></strong> :&ensp;<code>tf.Tensor</code></dt>
846
<dd><span><span class="MathJax_Preview">[B, L, K]</span><script type="math/tex">[B, L, K]</script></span> <span><span class="MathJax_Preview">\eta_{Z_ic_i,1}</span><script type="math/tex">\eta_{Z_ic_i,1}</script></span>.</dd>
847
<dt><strong><code>eta2</code></strong> :&ensp;<code>tf.Tensor</code></dt>
848
<dd><span><span class="MathJax_Preview">[B, L, K]</span><script type="math/tex">[B, L, K]</script></span> <span><span class="MathJax_Preview">\eta_{Z_ic_i,2}</span><script type="math/tex">\eta_{Z_ic_i,2}</script></span>.</dd>
849
<dt><strong><code>log_p_zc_L</code></strong> :&ensp;<code>tf.Tensor</code></dt>
850
<dd><span><span class="MathJax_Preview">[B, L, K]</span><script type="math/tex">[B, L, K]</script></span> <span><span class="MathJax_Preview">\log p(Z_i,c_i|Y_i,X_i)</span><script type="math/tex">\log p(Z_i,c_i|Y_i,X_i)</script></span>.</dd>
851
<dt><strong><code>log_p_z_L</code></strong> :&ensp;<code>tf.Tensor</code></dt>
852
<dd><span><span class="MathJax_Preview">[B, L]</span><script type="math/tex">[B, L]</script></span> <span><span class="MathJax_Preview">\log p(Z_i|Y_i,X_i)</span><script type="math/tex">\log p(Z_i|Y_i,X_i)</script></span>.</dd>
853
<dt><strong><code>log_p_z</code></strong> :&ensp;<code>tf.Tensor</code></dt>
854
<dd><span><span class="MathJax_Preview">[B, 1]</span><script type="math/tex">[B, 1]</script></span> The estimated <span><span class="MathJax_Preview">\log p(Z_i|Y_i,X_i)</span><script type="math/tex">\log p(Z_i|Y_i,X_i)</script></span>.</dd>
855
</dl></div>
856
</dd>
857
<dt id="VITAE.model.LatentSpace.get_posterior_c"><code class="name flex">
858
<span>def <span class="ident">get_posterior_c</span></span>(<span>self, z)</span>
859
</code></dt>
860
<dd>
861
<div class="desc"><p>Get <span><span class="MathJax_Preview">p(c_i|Y_i,X_i)</span><script type="math/tex">p(c_i|Y_i,X_i)</script></span>.</p>
862
<h2 id="parameters">Parameters</h2>
863
<dl>
864
<dt><strong><code>z</code></strong> :&ensp;<code>tf.Tensor</code></dt>
865
<dd><span><span class="MathJax_Preview">[B, L, d]</span><script type="math/tex">[B, L, d]</script></span> The latent variables.</dd>
866
</dl>
867
<h2 id="returns">Returns</h2>
868
<dl>
869
<dt><strong><code>p_c_x</code></strong> :&ensp;<code>np.array</code></dt>
870
<dd><span><span class="MathJax_Preview">[B, K]</span><script type="math/tex">[B, K]</script></span> <span><span class="MathJax_Preview">p(c_i|Y_i,X_i)</span><script type="math/tex">p(c_i|Y_i,X_i)</script></span>.</dd>
871
</dl></div>
872
</dd>
873
<dt id="VITAE.model.LatentSpace.call"><code class="name flex">
874
<span>def <span class="ident">call</span></span>(<span>self, z, pi=None, inference=False)</span>
875
</code></dt>
876
<dd>
877
<div class="desc"><p>Get posterior estimations.</p>
878
<h2 id="parameters">Parameters</h2>
879
<dl>
880
<dt><strong><code>z</code></strong> :&ensp;<code>tf.Tensor</code></dt>
881
<dd><span><span class="MathJax_Preview">[B, L, d]</span><script type="math/tex">[B, L, d]</script></span> The latent variables.</dd>
882
<dt><strong><code>inference</code></strong> :&ensp;<code>boolean</code></dt>
883
<dd>Whether in training or inference mode.</dd>
884
</dl>
885
<p>When <code>inference=False</code>:</p>
886
<h2 id="returns">Returns</h2>
887
<dl>
888
<dt><strong><code>log_p_z_L</code></strong> :&ensp;<code>tf.Tensor</code></dt>
889
<dd><span><span class="MathJax_Preview">[B, 1]</span><script type="math/tex">[B, 1]</script></span> The estimated <span><span class="MathJax_Preview">\log p(Z_i|Y_i,X_i)</span><script type="math/tex">\log p(Z_i|Y_i,X_i)</script></span>.</dd>
890
</dl>
891
<p>When <code>inference=True</code>:</p>
892
<h2 id="returns_1">Returns</h2>
893
<dl>
894
<dt><strong><code>res</code></strong> :&ensp;<code>dict</code></dt>
895
<dd>The dict of posterior estimations - <span><span class="MathJax_Preview">p(c_i|Y_i,X_i)</span><script type="math/tex">p(c_i|Y_i,X_i)</script></span>, <span><span class="MathJax_Preview">c</span><script type="math/tex">c</script></span>, <span><span class="MathJax_Preview">E(\tilde{w}_i|Y_i,X_i)</span><script type="math/tex">E(\tilde{w}_i|Y_i,X_i)</script></span>, <span><span class="MathJax_Preview">Var(\tilde{w}_i|Y_i,X_i)</span><script type="math/tex">Var(\tilde{w}_i|Y_i,X_i)</script></span>, <span><span class="MathJax_Preview">D_{JS}</span><script type="math/tex">D_{JS}</script></span>.</dd>
896
</dl></div>
897
</dd>
898
</dl>
899
</dd>
900
<dt id="VITAE.model.VariationalAutoEncoder"><code class="flex name class">
901
<span>class <span class="ident">VariationalAutoEncoder</span></span>
902
<span>(</span><span>dim_origin, dimensions, dim_latent, data_type='UMI', has_cov=False, name='autoencoder', **kwargs)</span>
903
</code></dt>
904
<dd>
905
<div class="desc"><p>Combines the encoder, decoder and LatentSpace into an end-to-end model for training and inference.</p>
906
<h2 id="parameters">Parameters</h2>
907
<dl>
908
<dt><strong><code>dim_origin</code></strong> :&ensp;<code>int</code></dt>
909
<dd>The output dimension of the decoder.</dd>
910
<dt><strong><code>dimensions</code></strong> :&ensp;<code>np.array</code></dt>
911
<dd>The dimensions of hidden layers of the encoder.</dd>
912
<dt><strong><code>dim_latent</code></strong> :&ensp;<code>int</code></dt>
913
<dd>The latent dimension.</dd>
914
<dt><strong><code>data_type</code></strong> :&ensp;<code>str</code>, optional</dt>
915
<dd><code>'UMI'</code>, <code>'non-UMI'</code>, or <code>'Gaussian'</code>.</dd>
916
<dt><strong><code>has_cov</code></strong> :&ensp;<code>boolean</code></dt>
917
<dd>Whether covariates are included or not.</dd>
918
<dt><strong><code>gamma</code></strong> :&ensp;<code>float</code>, optional</dt>
919
<dd>The weight of the MMD loss.</dd>
920
<dt><strong><code>name</code></strong> :&ensp;<code>str</code>, optional</dt>
921
<dd>The name of the layer.</dd>
922
<dt><strong><code>**kwargs</code></strong></dt>
923
<dd>Extra keyword arguments.</dd>
924
</dl></div>
925
<details class="source">
926
<summary>
927
<span>Expand source code</span>
928
</summary>
929
<pre><code class="python">class VariationalAutoEncoder(tf.keras.Model):
930
    &#34;&#34;&#34;
931
    Combines the encoder, decoder and LatentSpace into an end-to-end model for training and inference.
932
    &#34;&#34;&#34;
933
    def __init__(self, dim_origin, dimensions, dim_latent,
934
                 data_type = &#39;UMI&#39;, has_cov=False,
935
                 name = &#39;autoencoder&#39;, **kwargs):
936
        &#39;&#39;&#39;
937
        Parameters
938
        ----------
939
        dim_origin : int
940
            The output dimension of the decoder.        
941
        dimensions : np.array
942
            The dimensions of hidden layers of the encoder.
943
        dim_latent : int
944
            The latent dimension.
945
        data_type : str, optional
946
            `&#39;UMI&#39;`, `&#39;non-UMI&#39;`, or `&#39;Gaussian&#39;`.
947
        has_cov : boolean
948
            Whether covariates are included or not.
949
        gamma : float, optional
950
            The weight of the MMD loss.
951
        name : str, optional
952
            The name of the layer.
953
        **kwargs : 
954
            Extra keyword arguments.
955
        &#39;&#39;&#39;
956
        super(VariationalAutoEncoder, self).__init__(name = name, **kwargs)
957
        self.data_type = data_type
958
        self.dim_origin = dim_origin
959
        self.dim_latent = dim_latent
960
        self.encoder = Encoder(dimensions, dim_latent)
961
        self.decoder = Decoder(dimensions[::-1], dim_origin, data_type, data_type)        
962
        self.has_cov = has_cov
963
        
964
    def init_latent_space(self, n_clusters, mu, log_pi=None):
965
        &#39;&#39;&#39;Initialize the latent space.
966
967
        Parameters
968
        ----------
969
        n_clusters : int
970
            The number of vertices in the latent space.
971
        mu : np.array
972
            \([d, k]\) The position matrix.
973
        log_pi : np.array, optional
974
            \([1, K]\) \(\\log\\pi\).
975
        &#39;&#39;&#39;
976
        self.n_states = n_clusters
977
        self.latent_space = LatentSpace(self.n_states, self.dim_latent)
978
        self.latent_space.initialize(mu, log_pi)
979
        self.pilayer = None
980
981
    def create_pilayer(self):
982
        self.pilayer = Dense(self.latent_space.n_categories, name = &#39;pi_layer&#39;)
983
984
    def call(self, x_normalized, c_score, x = None, scale_factor = 1,
985
             pre_train = False, L=1, alpha=0.0, gamma = 0.0, phi = 1.0, conditions = None, pi_cov = None):
986
        &#39;&#39;&#39;Feed forward through encoder, LatentSpace layer and decoder.
987
988
        Parameters
989
        ----------
990
        x_normalized : np.array
991
            \([B, G]\) The preprocessed data.
992
        c_score : np.array
993
            \([B, s]\) The covariates \(X_i\), only used when `has_cov=True`.
994
        x : np.array, optional
995
            \([B, G]\) The original count data \(Y_i\), only used when data_type is not `&#39;Gaussian&#39;`.
996
        scale_factor : np.array, optional
997
            \([B, ]\) The scale factors, only used when data_type is not `&#39;Gaussian&#39;`.
998
        pre_train : boolean, optional
999
            Whether in the pre-training phase or not.
1000
        L : int, optional
1001
            The number of MC samples.
1002
        alpha : float, optional
1003
            The penalty parameter for covariates adjustment.
1004
        gamma : float, optional
1005
            The weight of the MMD loss.
1006
        phi : float, optional
1007
            The weight of Jacob norm of the encoder.
1008
        conditions: str or list, optional
1009
            The conditions of different cells from the selected batch
1010
1011
        Returns
1012
        ----------
1013
        losses : float
1014
            The loss.
1015
        &#39;&#39;&#39;
1016
1017
        if not pre_train and self.latent_space is None:
1018
            raise ReferenceError(&#39;Have not initialized the latent space.&#39;)
1019
                    
1020
        if self.has_cov:
1021
            x_normalized = tf.concat([x_normalized, c_score], -1)
1022
        else:
1023
            x_normalized
1024
        _, z_log_var, z = self.encoder(x_normalized, L)
1025
1026
        if gamma == 0:
1027
            mmd_loss = 0.0
1028
        else:
1029
            mmd_loss = self._get_total_mmd_loss(conditions,z,gamma)
1030
1031
        z_in = tf.concat([z, tf.tile(tf.expand_dims(c_score,1), (1,L,1))], -1) if self.has_cov else z
1032
        
1033
        x = tf.tile(tf.expand_dims(x, 1), (1,L,1))
1034
        reconstruction_z_loss = self._get_reconstruction_loss(x, z_in, scale_factor, L)
1035
        
1036
        if self.has_cov and alpha&gt;0.0:
1037
            zero_in = tf.concat([tf.zeros([z.shape[0],1,z.shape[2]], dtype=tf.keras.backend.floatx()), 
1038
                                tf.tile(tf.expand_dims(c_score,1), (1,1,1))], -1)
1039
            reconstruction_zero_loss = self._get_reconstruction_loss(x, zero_in, scale_factor, 1)
1040
            reconstruction_z_loss = (1-alpha)*reconstruction_z_loss + alpha*reconstruction_zero_loss
1041
1042
        self.add_loss(reconstruction_z_loss)
1043
        J_norm = self._get_Jacob(x_normalized, L)
1044
        self.add_loss((phi * J_norm))
1045
        # gamma weight has been used when call _mmd_loss function.
1046
        self.add_loss(mmd_loss)
1047
1048
        if not pre_train:
1049
            pi = self.pilayer(pi_cov) if self.pilayer is not None else None
1050
            log_p_z = self.latent_space(z, pi, inference=False)
1051
1052
            # - E_q[log p(z)]
1053
            self.add_loss(- log_p_z)
1054
1055
            # - Eq[log q(z|x)]
1056
            E_qzx = - tf.reduce_mean(
1057
                            0.5 * self.dim_latent *
1058
                            (tf.math.log(tf.constant(2 * np.pi, tf.keras.backend.floatx())) + 1.0) +
1059
                            0.5 * tf.reduce_sum(z_log_var, axis=-1)
1060
                            )
1061
            self.add_loss(E_qzx)
1062
        return self.losses
1063
    
1064
    @tf.function
1065
    def _get_reconstruction_loss(self, x, z_in, scale_factor, L):
1066
        if self.data_type==&#39;Gaussian&#39;:
1067
            # Gaussian Log-Likelihood Loss function
1068
            nu_z, tau = self.decoder(z_in)
1069
            neg_E_Gaus = 0.5 * tf.math.log(tf.clip_by_value(tau, 1e-12, 1e30)) + 0.5 * tf.math.square(x - nu_z) / tau
1070
            neg_E_Gaus = tf.reduce_mean(tf.reduce_sum(neg_E_Gaus, axis=-1))
1071
1072
            return neg_E_Gaus
1073
        else:
1074
            if self.data_type == &#39;UMI&#39;:
1075
                x_hat, r = self.decoder(z_in)
1076
            else:
1077
                x_hat, r, phi = self.decoder(z_in)
1078
1079
            x_hat = x_hat*tf.expand_dims(scale_factor, -1)
1080
1081
            # Negative Log-Likelihood Loss function
1082
1083
            # Ref for NB &amp; ZINB loss functions:
1084
            # https://github.com/gokceneraslan/neuralnet_countmodels/blob/master/Count%20models%20with%20neuralnets.ipynb
1085
            # Negative Binomial loss
1086
1087
            neg_E_nb = tf.math.lgamma(r) + tf.math.lgamma(x+1.0) \
1088
                        - tf.math.lgamma(x+r) + \
1089
                        (r+x) * tf.math.log(1.0 + (x_hat/r)) + \
1090
                        x * (tf.math.log(r) - tf.math.log(tf.clip_by_value(x_hat, 1e-12, 1e30)))
1091
            
1092
            if self.data_type == &#39;non-UMI&#39;:
1093
                # Zero-Inflated Negative Binomial loss
1094
                nb_case = neg_E_nb - tf.math.log(tf.clip_by_value(1.0-phi, 1e-12, 1e30))
1095
                zero_case = - tf.math.log(tf.clip_by_value(
1096
                    phi + (1.0-phi) * tf.pow(r * tf.math.reciprocal_no_nan(r + x_hat), r),
1097
                    1e-12, 1e30))
1098
                neg_E_nb = tf.where(tf.less(x, 1e-8), zero_case, nb_case)
1099
1100
            neg_E_nb = tf.reduce_mean(tf.reduce_sum(neg_E_nb, axis=-1))
1101
            return neg_E_nb
1102
1103
    def _get_total_mmd_loss(self,conditions,z,gamma):
1104
        mmd_loss = 0.0
1105
        conditions = tf.cast(conditions,tf.int32)
1106
        n_group = conditions.shape[1]
1107
1108
        for i in range(n_group):
1109
            sub_conditions = conditions[:, i]
1110
            # 0 means not participating in MMD
1111
            z_cond = z[sub_conditions != 0]
1112
            sub_conditions = sub_conditions[sub_conditions != 0]
1113
            n_sub_group = tf.unique(sub_conditions)[0].shape[0]
1114
            real_labels = K.reshape(sub_conditions, (-1,)).numpy()
1115
            unique_set = list(set(real_labels))
1116
            reindex_dict = dict(zip(unique_set, range(n_sub_group)))
1117
            real_labels = [reindex_dict[x] for x in real_labels]
1118
            real_labels = tf.convert_to_tensor(real_labels,dtype=tf.int32)
1119
1120
            if (n_sub_group == 1) | (n_sub_group == 0):
1121
                _loss = 0
1122
            else:
1123
                _loss = self._mmd_loss(real_labels=real_labels, y_pred=z_cond, gamma=gamma,
1124
                                       n_conditions=n_sub_group,
1125
                                       kernel_method=&#39;multi-scale-rbf&#39;,
1126
                                       computation_method=&#34;general&#34;)
1127
            mmd_loss = mmd_loss + _loss
1128
        return mmd_loss
1129
1130
    # each loop the inputed shape is changed. Can not use @tf.function
1131
    # tf graph requires static shape and tensor dtype
1132
    def _mmd_loss(self, real_labels, y_pred, gamma, n_conditions, kernel_method=&#39;multi-scale-rbf&#39;,
1133
                  computation_method=&#34;general&#34;):
1134
        conditions_mmd = tf.dynamic_partition(y_pred, real_labels, num_partitions=n_conditions)
1135
        loss = 0.0
1136
        if computation_method.isdigit():
1137
            boundary = int(computation_method)
1138
            ## every pair of groups will calculate a distance
1139
            for i in range(boundary):
1140
                for j in range(boundary, n_conditions):
1141
                    loss += _nan2zero(compute_mmd(conditions_mmd[i], conditions_mmd[j], kernel_method))
1142
        else:
1143
            for i in range(len(conditions_mmd)):
1144
                for j in range(i):
1145
                    loss += _nan2zero(compute_mmd(conditions_mmd[i], conditions_mmd[j], kernel_method))
1146
1147
        # print(&#34;The loss is &#34;, loss)
1148
        return gamma * loss
1149
1150
    @tf.function
1151
    def _get_Jacob(self, x, L):
1152
        with tf.GradientTape() as g:
1153
            g.watch(x)
1154
            z_mean, z_log_var, z = self.encoder(x, L)
1155
            # y_mean, y_log_var = self.decoder(z)
1156
        ## just jacobian will cause shape (batch,16,batch,64) matrix
1157
        J = g.batch_jacobian(z, x)
1158
        J_norm = tf.norm(J)
1159
        # tf.print(J_norm)
1160
1161
        return J_norm
1162
    
1163
    def get_z(self, x_normalized, c_score):    
1164
        &#39;&#39;&#39;Get \(q(Z_i|Y_i,X_i)\).
1165
1166
        Parameters
1167
        ----------
1168
        x_normalized : int
1169
            \([B, G]\) The preprocessed data.
1170
        c_score : np.array
1171
            \([B, s]\) The covariates \(X_i\), only used when `has_cov=True`.
1172
1173
        Returns
1174
        ----------
1175
        z_mean : np.array
1176
            \([B, d]\) The latent mean.
1177
        &#39;&#39;&#39;    
1178
        x_normalized = x_normalized if (not self.has_cov or c_score is None) else tf.concat([x_normalized, c_score], -1)
1179
        z_mean, _, _ = self.encoder(x_normalized, 1, False)
1180
        return z_mean.numpy()
1181
1182
    def get_pc_x(self, test_dataset):
1183
        &#39;&#39;&#39;Get \(p(c_i|Y_i,X_i)\).
1184
1185
        Parameters
1186
        ----------
1187
        test_dataset : tf.Dataset
1188
            the dataset object.
1189
1190
        Returns
1191
        ----------
1192
        pi_norm : np.array
1193
            \([1, K]\) The estimated \(\\pi\).
1194
        p_c_x : np.array
1195
            \([N, ]\) The estimated \(p(c_i|Y_i,X_i)\).
1196
        &#39;&#39;&#39;    
1197
        if self.latent_space is None:
1198
            raise ReferenceError(&#39;Have not initialized the latent space.&#39;)
1199
        
1200
        pi_norm = tf.nn.softmax(self.latent_space.pi).numpy()
1201
        p_c_x = []
1202
        for x,c_score in test_dataset:
1203
            x = tf.concat([x, c_score], -1) if self.has_cov else x
1204
            _, _, z = self.encoder(x, 1, False)
1205
            _p_c_x = self.latent_space.get_posterior_c(z)            
1206
            p_c_x.append(_p_c_x)
1207
        p_c_x = np.concatenate(p_c_x)         
1208
        return pi_norm, p_c_x
1209
1210
    def inference(self, test_dataset, L=1):
1211
        &#39;&#39;&#39;Get \(p(c_i|Y_i,X_i)\).
1212
1213
        Parameters
1214
        ----------
1215
        test_dataset : tf.Dataset
1216
            The dataset object.
1217
        L : int
1218
            The number of MC samples.
1219
1220
        Returns
1221
        ----------
1222
        pi_norm : np.array
1223
            \([1, K]\) The estimated \(\\pi\).
1224
        mu : np.array
1225
            \([d, k]\) The estimated \(\\mu\).
1226
        p_c_x : np.array
1227
            \([N, ]\) The estimated \(p(c_i|Y_i,X_i)\).
1228
        w_tilde : np.array
1229
            \([N, k]\) The estimated \(E(\\tilde{w}_i|Y_i,X_i)\).
1230
        var_w_tilde : np.array
1231
            \([N, k]\) The estimated \(Var(\\tilde{w}_i|Y_i,X_i)\).
1232
        z_mean : np.array
1233
            \([N, d]\) The estimated latent mean.
1234
        &#39;&#39;&#39;   
1235
        if self.latent_space is None:
1236
            raise ReferenceError(&#39;Have not initialized the latent space.&#39;)
1237
        
1238
        print(&#39;Computing posterior estimations over mini-batches.&#39;)
1239
        progbar = Progbar(test_dataset.cardinality().numpy())
1240
        pi_norm = tf.nn.softmax(self.latent_space.pi).numpy()
1241
        mu = self.latent_space.mu.numpy()
1242
        z_mean = []
1243
        p_c_x = []
1244
        w_tilde = []
1245
        var_w_tilde = []
1246
        for step, (x,c_score, _, _) in enumerate(test_dataset):
1247
            x = tf.concat([x, c_score], -1) if self.has_cov else x
1248
            _z_mean, _, z = self.encoder(x, L, False)
1249
            res = self.latent_space(z, inference=True)
1250
            
1251
            z_mean.append(_z_mean.numpy())
1252
            p_c_x.append(res[&#39;p_c_x&#39;])            
1253
            w_tilde.append(res[&#39;w_tilde&#39;])
1254
            var_w_tilde.append(res[&#39;var_w_tilde&#39;])
1255
            progbar.update(step+1)
1256
1257
        z_mean = np.concatenate(z_mean)
1258
        p_c_x = np.concatenate(p_c_x)
1259
        w_tilde = np.concatenate(w_tilde)
1260
        w_tilde /= np.sum(w_tilde, axis=1, keepdims=True)
1261
        var_w_tilde = np.concatenate(var_w_tilde)
1262
        return pi_norm, mu, p_c_x, w_tilde, var_w_tilde, z_mean</code></pre>
1263
</details>
1264
<h3>Ancestors</h3>
1265
<ul class="hlist">
1266
<li>keras.src.engine.training.Model</li>
1267
<li>keras.src.engine.base_layer.Layer</li>
1268
<li>tensorflow.python.module.module.Module</li>
1269
<li>tensorflow.python.trackable.autotrackable.AutoTrackable</li>
1270
<li>tensorflow.python.trackable.base.Trackable</li>
1271
<li>keras.src.utils.version_utils.LayerVersionSelector</li>
1272
<li>keras.src.utils.version_utils.ModelVersionSelector</li>
1273
</ul>
1274
<h3>Methods</h3>
1275
<dl>
1276
<dt id="VITAE.model.VariationalAutoEncoder.init_latent_space"><code class="name flex">
1277
<span>def <span class="ident">init_latent_space</span></span>(<span>self, n_clusters, mu, log_pi=None)</span>
1278
</code></dt>
1279
<dd>
1280
<div class="desc"><p>Initialize the latent space.</p>
1281
<h2 id="parameters">Parameters</h2>
1282
<dl>
1283
<dt><strong><code>n_clusters</code></strong> :&ensp;<code>int</code></dt>
1284
<dd>The number of vertices in the latent space.</dd>
1285
<dt><strong><code>mu</code></strong> :&ensp;<code>np.array</code></dt>
1286
<dd><span><span class="MathJax_Preview">[d, k]</span><script type="math/tex">[d, k]</script></span> The position matrix.</dd>
1287
<dt><strong><code>log_pi</code></strong> :&ensp;<code>np.array</code>, optional</dt>
1288
<dd><span><span class="MathJax_Preview">[1, K]</span><script type="math/tex">[1, K]</script></span> <span><span class="MathJax_Preview">\log\pi</span><script type="math/tex">\log\pi</script></span>.</dd>
1289
</dl></div>
1290
</dd>
1291
<dt id="VITAE.model.VariationalAutoEncoder.create_pilayer"><code class="name flex">
1292
<span>def <span class="ident">create_pilayer</span></span>(<span>self)</span>
1293
</code></dt>
1294
<dd>
1295
<div class="desc"></div>
1296
</dd>
1297
<dt id="VITAE.model.VariationalAutoEncoder.call"><code class="name flex">
1298
<span>def <span class="ident">call</span></span>(<span>self, x_normalized, c_score, x=None, scale_factor=1, pre_train=False, L=1, alpha=0.0, gamma=0.0, phi=1.0, conditions=None, pi_cov=None)</span>
1299
</code></dt>
1300
<dd>
1301
<div class="desc"><p>Feed forward through encoder, LatentSpace layer and decoder.</p>
1302
<h2 id="parameters">Parameters</h2>
1303
<dl>
1304
<dt><strong><code>x_normalized</code></strong> :&ensp;<code>np.array</code></dt>
1305
<dd><span><span class="MathJax_Preview">[B, G]</span><script type="math/tex">[B, G]</script></span> The preprocessed data.</dd>
1306
<dt><strong><code>c_score</code></strong> :&ensp;<code>np.array</code></dt>
1307
<dd><span><span class="MathJax_Preview">[B, s]</span><script type="math/tex">[B, s]</script></span> The covariates <span><span class="MathJax_Preview">X_i</span><script type="math/tex">X_i</script></span>, only used when <code>has_cov=True</code>.</dd>
1308
<dt><strong><code>x</code></strong> :&ensp;<code>np.array</code>, optional</dt>
1309
<dd><span><span class="MathJax_Preview">[B, G]</span><script type="math/tex">[B, G]</script></span> The original count data <span><span class="MathJax_Preview">Y_i</span><script type="math/tex">Y_i</script></span>, only used when data_type is not <code>'Gaussian'</code>.</dd>
1310
<dt><strong><code>scale_factor</code></strong> :&ensp;<code>np.array</code>, optional</dt>
1311
<dd><span><span class="MathJax_Preview">[B, ]</span><script type="math/tex">[B, ]</script></span> The scale factors, only used when data_type is not <code>'Gaussian'</code>.</dd>
1312
<dt><strong><code>pre_train</code></strong> :&ensp;<code>boolean</code>, optional</dt>
1313
<dd>Whether in the pre-training phase or not.</dd>
1314
<dt><strong><code>L</code></strong> :&ensp;<code>int</code>, optional</dt>
1315
<dd>The number of MC samples.</dd>
1316
<dt><strong><code>alpha</code></strong> :&ensp;<code>float</code>, optional</dt>
1317
<dd>The penalty parameter for covariates adjustment.</dd>
1318
<dt><strong><code>gamma</code></strong> :&ensp;<code>float</code>, optional</dt>
1319
<dd>The weight of the MMD loss.</dd>
1320
<dt><strong><code>phi</code></strong> :&ensp;<code>float</code>, optional</dt>
1321
<dd>The weight of the Jacobian norm of the encoder.</dd>
1322
<dt><strong><code>conditions</code></strong> :&ensp;<code>str</code> or <code>list</code>, optional</dt>
1323
<dd>The conditions of different cells from the selected batch</dd>
1324
</dl>
1325
<h2 id="returns">Returns</h2>
1326
<dl>
1327
<dt><strong><code>losses</code></strong> :&ensp;<code>float</code></dt>
1328
<dd>the loss.</dd>
1329
</dl></div>
1330
</dd>
1331
<dt id="VITAE.model.VariationalAutoEncoder.get_z"><code class="name flex">
1332
<span>def <span class="ident">get_z</span></span>(<span>self, x_normalized, c_score)</span>
1333
</code></dt>
1334
<dd>
1335
<div class="desc"><p>Get <span><span class="MathJax_Preview">q(Z_i|Y_i,X_i)</span><script type="math/tex">q(Z_i|Y_i,X_i)</script></span>.</p>
1336
<h2 id="parameters">Parameters</h2>
1337
<dl>
1338
<dt><strong><code>x_normalized</code></strong> :&ensp;<code>np.array</code></dt>
1339
<dd><span><span class="MathJax_Preview">[B, G]</span><script type="math/tex">[B, G]</script></span> The preprocessed data.</dd>
1340
<dt><strong><code>c_score</code></strong> :&ensp;<code>np.array</code></dt>
1341
<dd><span><span class="MathJax_Preview">[B, s]</span><script type="math/tex">[B, s]</script></span> The covariates <span><span class="MathJax_Preview">X_i</span><script type="math/tex">X_i</script></span>, only used when <code>has_cov=True</code>.</dd>
1342
</dl>
1343
<h2 id="returns">Returns</h2>
1344
<dl>
1345
<dt><strong><code>z_mean</code></strong> :&ensp;<code>np.array</code></dt>
1346
<dd><span><span class="MathJax_Preview">[B, d]</span><script type="math/tex">[B, d]</script></span> The latent mean.</dd>
1347
</dl></div>
1348
</dd>
1349
<dt id="VITAE.model.VariationalAutoEncoder.get_pc_x"><code class="name flex">
1350
<span>def <span class="ident">get_pc_x</span></span>(<span>self, test_dataset)</span>
1351
</code></dt>
1352
<dd>
1353
<div class="desc"><p>Get <span><span class="MathJax_Preview">p(c_i|Y_i,X_i)</span><script type="math/tex">p(c_i|Y_i,X_i)</script></span>.</p>
1354
<h2 id="parameters">Parameters</h2>
1355
<dl>
1356
<dt><strong><code>test_dataset</code></strong> :&ensp;<code>tf.Dataset</code></dt>
1357
<dd>The dataset object.</dd>
1358
</dl>
1359
<h2 id="returns">Returns</h2>
1360
<dl>
1361
<dt><strong><code>pi_norm</code></strong> :&ensp;<code>np.array</code></dt>
1362
<dd><span><span class="MathJax_Preview">[1, K]</span><script type="math/tex">[1, K]</script></span> The estimated <span><span class="MathJax_Preview">\pi</span><script type="math/tex">\pi</script></span>.</dd>
1363
<dt><strong><code>p_c_x</code></strong> :&ensp;<code>np.array</code></dt>
1364
<dd><span><span class="MathJax_Preview">[N, ]</span><script type="math/tex">[N, ]</script></span> The estimated <span><span class="MathJax_Preview">p(c_i|Y_i,X_i)</span><script type="math/tex">p(c_i|Y_i,X_i)</script></span>.</dd>
1365
</dl></div>
1366
</dd>
1367
<dt id="VITAE.model.VariationalAutoEncoder.inference"><code class="name flex">
1368
<span>def <span class="ident">inference</span></span>(<span>self, test_dataset, L=1)</span>
1369
</code></dt>
1370
<dd>
1371
<div class="desc"><p>Get <span><span class="MathJax_Preview">p(c_i|Y_i,X_i)</span><script type="math/tex">p(c_i|Y_i,X_i)</script></span>.</p>
1372
<h2 id="parameters">Parameters</h2>
1373
<dl>
1374
<dt><strong><code>test_dataset</code></strong> :&ensp;<code>tf.Dataset</code></dt>
1375
<dd>The dataset object.</dd>
1376
<dt><strong><code>L</code></strong> :&ensp;<code>int</code></dt>
1377
<dd>The number of MC samples.</dd>
1378
</dl>
1379
<h2 id="returns">Returns</h2>
1380
<dl>
1381
<dt><strong><code>pi_norm</code></strong> :&ensp;<code>np.array</code></dt>
1383
<dd><span><span class="MathJax_Preview">[1, K]</span><script type="math/tex">[1, K]</script></span> The estimated <span><span class="MathJax_Preview">\pi</span><script type="math/tex">\pi</script></span>.</dd>
1384
<dt><strong><code>mu</code></strong> :&ensp;<code>np.array</code></dt>
1385
<dd><span><span class="MathJax_Preview">[d, k]</span><script type="math/tex">[d, k]</script></span> The estimated <span><span class="MathJax_Preview">\mu</span><script type="math/tex">\mu</script></span>.</dd>
1386
<dt><strong><code>p_c_x</code></strong> :&ensp;<code>np.array</code></dt>
1387
<dd><span><span class="MathJax_Preview">[N, ]</span><script type="math/tex">[N, ]</script></span> The estimated <span><span class="MathJax_Preview">p(c_i|Y_i,X_i)</span><script type="math/tex">p(c_i|Y_i,X_i)</script></span>.</dd>
1388
<dt><strong><code>w_tilde</code></strong> :&ensp;<code>np.array</code></dt>
1389
<dd><span><span class="MathJax_Preview">[N, k]</span><script type="math/tex">[N, k]</script></span> The estimated <span><span class="MathJax_Preview">E(\tilde{w}_i|Y_i,X_i)</span><script type="math/tex">E(\tilde{w}_i|Y_i,X_i)</script></span>.</dd>
1390
<dt><strong><code>var_w_tilde</code></strong> :&ensp;<code>np.array</code></dt>
1392
<dd><span><span class="MathJax_Preview">[N, k]</span><script type="math/tex">[N, k]</script></span> The estimated <span><span class="MathJax_Preview">Var(\tilde{w}_i|Y_i,X_i)</span><script type="math/tex">Var(\tilde{w}_i|Y_i,X_i)</script></span>.</dd>
1393
<dt><strong><code>z_mean</code></strong> :&ensp;<code>np.array</code></dt>
1394
<dd><span><span class="MathJax_Preview">[N, d]</span><script type="math/tex">[N, d]</script></span> The estimated latent mean.</dd>
1395
</dl></div>
1396
</dd>
1397
</dl>
1398
</dd>
1399
</dl>
1400
</section>
1401
</article>
1402
<nav id="sidebar">
1403
<div class="toc">
1404
<ul></ul>
1405
</div>
1406
<ul id="index">
1407
<li><h3>Super-module</h3>
1408
<ul>
1409
<li><code><a title="VITAE" href="index.html">VITAE</a></code></li>
1410
</ul>
1411
</li>
1412
<li><h3><a href="#header-classes">Classes</a></h3>
1413
<ul>
1414
<li>
1415
<h4><code><a title="VITAE.model.cdf_layer" href="#VITAE.model.cdf_layer">cdf_layer</a></code></h4>
1416
<ul class="">
1417
<li><code><a title="VITAE.model.cdf_layer.call" href="#VITAE.model.cdf_layer.call">call</a></code></li>
1418
<li><code><a title="VITAE.model.cdf_layer.func" href="#VITAE.model.cdf_layer.func">func</a></code></li>
1419
</ul>
1420
</li>
1421
<li>
1422
<h4><code><a title="VITAE.model.Sampling" href="#VITAE.model.Sampling">Sampling</a></code></h4>
1423
<ul class="">
1424
<li><code><a title="VITAE.model.Sampling.call" href="#VITAE.model.Sampling.call">call</a></code></li>
1425
</ul>
1426
</li>
1427
<li>
1428
<h4><code><a title="VITAE.model.Encoder" href="#VITAE.model.Encoder">Encoder</a></code></h4>
1429
<ul class="">
1430
<li><code><a title="VITAE.model.Encoder.call" href="#VITAE.model.Encoder.call">call</a></code></li>
1431
</ul>
1432
</li>
1433
<li>
1434
<h4><code><a title="VITAE.model.Decoder" href="#VITAE.model.Decoder">Decoder</a></code></h4>
1435
<ul class="">
1436
<li><code><a title="VITAE.model.Decoder.call" href="#VITAE.model.Decoder.call">call</a></code></li>
1437
</ul>
1438
</li>
1439
<li>
1440
<h4><code><a title="VITAE.model.LatentSpace" href="#VITAE.model.LatentSpace">LatentSpace</a></code></h4>
1441
<ul class="">
1442
<li><code><a title="VITAE.model.LatentSpace.initialize" href="#VITAE.model.LatentSpace.initialize">initialize</a></code></li>
1443
<li><code><a title="VITAE.model.LatentSpace.normalize" href="#VITAE.model.LatentSpace.normalize">normalize</a></code></li>
1444
<li><code><a title="VITAE.model.LatentSpace.get_pz" href="#VITAE.model.LatentSpace.get_pz">get_pz</a></code></li>
1445
<li><code><a title="VITAE.model.LatentSpace.get_posterior_c" href="#VITAE.model.LatentSpace.get_posterior_c">get_posterior_c</a></code></li>
1446
<li><code><a title="VITAE.model.LatentSpace.call" href="#VITAE.model.LatentSpace.call">call</a></code></li>
1447
</ul>
1448
</li>
1449
<li>
1450
<h4><code><a title="VITAE.model.VariationalAutoEncoder" href="#VITAE.model.VariationalAutoEncoder">VariationalAutoEncoder</a></code></h4>
1451
<ul class="two-column">
1452
<li><code><a title="VITAE.model.VariationalAutoEncoder.init_latent_space" href="#VITAE.model.VariationalAutoEncoder.init_latent_space">init_latent_space</a></code></li>
1453
<li><code><a title="VITAE.model.VariationalAutoEncoder.create_pilayer" href="#VITAE.model.VariationalAutoEncoder.create_pilayer">create_pilayer</a></code></li>
1454
<li><code><a title="VITAE.model.VariationalAutoEncoder.call" href="#VITAE.model.VariationalAutoEncoder.call">call</a></code></li>
1455
<li><code><a title="VITAE.model.VariationalAutoEncoder.get_z" href="#VITAE.model.VariationalAutoEncoder.get_z">get_z</a></code></li>
1456
<li><code><a title="VITAE.model.VariationalAutoEncoder.get_pc_x" href="#VITAE.model.VariationalAutoEncoder.get_pc_x">get_pc_x</a></code></li>
1457
<li><code><a title="VITAE.model.VariationalAutoEncoder.inference" href="#VITAE.model.VariationalAutoEncoder.inference">inference</a></code></li>
1458
</ul>
1459
</li>
1460
</ul>
1461
</li>
1462
</ul>
1463
</nav>
1464
</main>
1465
<footer id="footer">
1466
<p>Generated by <a href="https://pdoc3.github.io/pdoc" title="pdoc: Python API documentation generator"><cite>pdoc</cite> 0.11.1</a>.</p>
1467
</footer>
1468
</body>
1469
</html>