Diff of /docs/model.html [000000] .. [2c6b19]

--- a
+++ b/docs/model.html
@@ -0,0 +1,1469 @@
+<!doctype html>
+<html lang="en">
+<head>
+<meta charset="utf-8">
+<meta name="viewport" content="width=device-width, initial-scale=1, minimum-scale=1">
+<meta name="generator" content="pdoc3 0.11.1">
+<title>VITAE.model API documentation</title>
+<meta name="description" content="">
+<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/10up-sanitize.css/13.0.0/sanitize.min.css" integrity="sha512-y1dtMcuvtTMJc1yPgEqF0ZjQbhnc/bFhyvIyVNb9Zk5mIGtqVaAB1Ttl28su8AvFMOY0EwRbAe+HCLqj6W7/KA==" crossorigin>
+<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/10up-sanitize.css/13.0.0/typography.min.css" integrity="sha512-Y1DYSb995BAfxobCkKepB1BqJJTPrOp3zPL74AWFugHHmmdcvO+C48WLrUOlhGMc0QG7AE3f7gmvvcrmX2fDoA==" crossorigin>
+<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.9.0/styles/default.min.css" crossorigin>
+<style>:root{--highlight-color:#fe9}.flex{display:flex !important}body{line-height:1.5em}#content{padding:20px}#sidebar{padding:1.5em;overflow:hidden}#sidebar > *:last-child{margin-bottom:2cm}.http-server-breadcrumbs{font-size:130%;margin:0 0 15px 0}#footer{font-size:.75em;padding:5px 30px;border-top:1px solid #ddd;text-align:right}#footer p{margin:0 0 0 1em;display:inline-block}#footer p:last-child{margin-right:30px}h1,h2,h3,h4,h5{font-weight:300}h1{font-size:2.5em;line-height:1.1em}h2{font-size:1.75em;margin:2em 0 .50em 0}h3{font-size:1.4em;margin:1.6em 0 .7em 0}h4{margin:0;font-size:105%}h1:target,h2:target,h3:target,h4:target,h5:target,h6:target{background:var(--highlight-color);padding:.2em 0}a{color:#058;text-decoration:none;transition:color .2s ease-in-out}a:visited{color:#503}a:hover{color:#b62}.title code{font-weight:bold}h2[id^="header-"]{margin-top:2em}.ident{color:#900;font-weight:bold}pre code{font-size:.8em;line-height:1.4em;padding:1em;display:block}code{background:#f3f3f3;font-family:"DejaVu Sans Mono",monospace;padding:1px 4px;overflow-wrap:break-word}h1 code{background:transparent}pre{border-top:1px solid #ccc;border-bottom:1px solid #ccc;margin:1em 0}#http-server-module-list{display:flex;flex-flow:column}#http-server-module-list div{display:flex}#http-server-module-list dt{min-width:10%}#http-server-module-list p{margin-top:0}.toc ul,#index{list-style-type:none;margin:0;padding:0}#index code{background:transparent}#index h3{border-bottom:1px solid #ddd}#index ul{padding:0}#index h4{margin-top:.6em;font-weight:bold}@media (min-width:200ex){#index .two-column{column-count:2}}@media (min-width:300ex){#index .two-column{column-count:3}}dl{margin-bottom:2em}dl dl:last-child{margin-bottom:4em}dd{margin:0 0 1em 3em}#header-classes + dl > dd{margin-bottom:3em}dd dd{margin-left:2em}dd p{margin:10px 0}.name{background:#eee;font-size:.85em;padding:5px 10px;display:inline-block;min-width:40%}.name:hover{background:#e0e0e0}dt:target .name{background:var(--highlight-color)}.name > span:first-child{white-space:nowrap}.name.class > span:nth-child(2){margin-left:.4em}.inherited{color:#999;border-left:5px solid #eee;padding-left:1em}.inheritance em{font-style:normal;font-weight:bold}.desc h2{font-weight:400;font-size:1.25em}.desc h3{font-size:1em}.desc dt code{background:inherit}.source summary,.git-link-div{color:#666;text-align:right;font-weight:400;font-size:.8em;text-transform:uppercase}.source summary > *{white-space:nowrap;cursor:pointer}.git-link{color:inherit;margin-left:1em}.source pre{max-height:500px;overflow:auto;margin:0}.source pre code{font-size:12px;overflow:visible}.hlist{list-style:none}.hlist li{display:inline}.hlist li:after{content:',\2002'}.hlist li:last-child:after{content:none}.hlist .hlist{display:inline;padding-left:1em}img{max-width:100%}td{padding:0 .5em}.admonition{padding:.1em 1em;margin-bottom:1em}.admonition-title{font-weight:bold}.admonition.note,.admonition.info,.admonition.important{background:#aef}.admonition.todo,.admonition.versionadded,.admonition.tip,.admonition.hint{background:#dfd}.admonition.warning,.admonition.versionchanged,.admonition.deprecated{background:#fd4}.admonition.error,.admonition.danger,.admonition.caution{background:lightpink}</style>
+<style media="screen and (min-width: 700px)">@media screen and (min-width:700px){#sidebar{width:30%;height:100vh;overflow:auto;position:sticky;top:0}#content{width:70%;max-width:100ch;padding:3em 4em;border-left:1px solid #ddd}pre code{font-size:1em}.name{font-size:1em}main{display:flex;flex-direction:row-reverse;justify-content:flex-end}.toc ul ul,#index ul ul{padding-left:1em}.toc > ul > li{margin-top:.5em}}</style>
+<style media="print">@media print{#sidebar h1{page-break-before:always}.source{display:none}}@media print{*{background:transparent !important;color:#000 !important;box-shadow:none !important;text-shadow:none !important}a[href]:after{content:" (" attr(href) ")";font-size:90%}a[href][title]:after{content:none}abbr[title]:after{content:" (" attr(title) ")"}.ir a:after,a[href^="javascript:"]:after,a[href^="#"]:after{content:""}pre,blockquote{border:1px solid #999;page-break-inside:avoid}thead{display:table-header-group}tr,img{page-break-inside:avoid}img{max-width:100% !important}@page{margin:0.5cm}p,h2,h3{orphans:3;widows:3}h1,h2,h3,h4,h5,h6{page-break-after:avoid}}</style>
+<script type="text/x-mathjax-config">MathJax.Hub.Config({ tex2jax: { inlineMath: [ ['$','$'], ["\\(","\\)"] ], processEscapes: true } });</script>
+<script async src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.7/latest.js?config=TeX-AMS_CHTML" integrity="sha256-kZafAc6mZvK3W3v1pHOcUix30OHQN6pU/NO2oFkqZVw=" crossorigin></script>
+<script defer src="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.9.0/highlight.min.js" integrity="sha512-D9gUyxqja7hBtkWpPWGt9wfbfaMGVt9gnyCvYa+jojwwPHLCzUm5i8rpk7vD7wNee9bA35eYIjobYPaQuKS1MQ==" crossorigin></script>
+<script>window.addEventListener('DOMContentLoaded', () => {
+hljs.configure({languages: ['bash', 'css', 'diff', 'graphql', 'ini', 'javascript', 'json', 'plaintext', 'python', 'python-repl', 'rust', 'shell', 'sql', 'typescript', 'xml', 'yaml']});
+hljs.highlightAll();
+})</script>
+</head>
+<body>
+<main>
+<article id="content">
+<header>
+<h1 class="title">Module <code>VITAE.model</code></h1>
+</header>
+<section id="section-intro">
+</section>
+<section>
+</section>
+<section>
+</section>
+<section>
+</section>
+<section>
+<h2 class="section-title" id="header-classes">Classes</h2>
+<dl>
+<dt id="VITAE.model.cdf_layer"><code class="flex name class">
+<span>class <span class="ident">cdf_layer</span></span>
+</code></dt>
+<dd>
+<div class="desc"><p>The Normal cdf layer with custom gradients.</p></div>
+<details class="source">
+<summary>
+<span>Expand source code</span>
+</summary>
+<pre><code class="python">class cdf_layer(Layer):
+    &#39;&#39;&#39;
+    The Normal cdf layer with custom gradients.
+    &#39;&#39;&#39;
+    def __init__(self):
+        super(cdf_layer, self).__init__()
+        
+    @tf.function
+    def call(self, x):
+        return self.func(x)
+        
+    @tf.custom_gradient
+    def func(self, x):
+        &#39;&#39;&#39;Return cdf(x) together with its custom gradient function.
+
+        Parameters
+        ----------
+        x : tf.Tensor
+            The input tensor.
+        
+        Returns
+        ----------
+        f : tf.Tensor
+            cdf(x).
+        grad : Callable
+            The gradient function, which scales the upstream gradient by pdf(x).
+        &#39;&#39;&#39;   
+        dist = tfp.distributions.Normal(
+            loc = tf.constant(0.0, tf.keras.backend.floatx()), 
+            scale = tf.constant(1.0, tf.keras.backend.floatx()), 
+            allow_nan_stats=False)
+        f = dist.cdf(x)
+        def grad(dy):
+            gradient = dist.prob(x)
+            return dy * gradient
+        return f, grad</code></pre>
+</details>
+<h3>Ancestors</h3>
+<ul class="hlist">
+<li>keras.src.engine.base_layer.Layer</li>
+<li>tensorflow.python.module.module.Module</li>
+<li>tensorflow.python.trackable.autotrackable.AutoTrackable</li>
+<li>tensorflow.python.trackable.base.Trackable</li>
+<li>keras.src.utils.version_utils.LayerVersionSelector</li>
+</ul>
+<h3>Methods</h3>
+<dl>
+<dt id="VITAE.model.cdf_layer.call"><code class="name flex">
+<span>def <span class="ident">call</span></span>(<span>self, x)</span>
+</code></dt>
+<dd>
+<div class="desc"></div>
+</dd>
+<dt id="VITAE.model.cdf_layer.func"><code class="name flex">
+<span>def <span class="ident">func</span></span>(<span>self, x)</span>
+</code></dt>
+<dd>
+<div class="desc"><p>Return cdf(x) and pdf(x).</p>
+<h2 id="parameters">Parameters</h2>
+<dl>
+<dt><strong><code>x</code></strong> :&ensp;<code>tf.Tensor</code></dt>
+<dd>The input tensor.</dd>
+</dl>
+<h2 id="returns">Returns</h2>
+<dl>
+<dt><strong><code>f</code></strong> :&ensp;<code>tf.Tensor</code></dt>
+<dd>cdf(x).</dd>
+<dt><strong><code>grad</code></strong> :&ensp;<code>Callable</code></dt>
+<dd>The gradient function, which scales the upstream gradient by pdf(x).</dd>
+</dl></div>
+</dd>
+</dl>
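+<p>A minimal usage sketch (not from the package itself): it assumes TensorFlow is available and simply probes the layer's forward value and custom gradient, which are the standard Normal cdf and pdf respectively.</p>
+<pre><code class="python">import tensorflow as tf
+from VITAE.model import cdf_layer
+
+layer = cdf_layer()
+x = tf.constant([0.0, 1.0], dtype=tf.keras.backend.floatx())
+with tf.GradientTape() as tape:
+    tape.watch(x)
+    y = layer(x)                   # forward pass: cdf(x)
+grad = tape.gradient(y, x)         # custom gradient: pdf(x), about 0.399 at x=0
+</code></pre>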
+</dd>
+<dt id="VITAE.model.Sampling"><code class="flex name class">
+<span>class <span class="ident">Sampling</span></span>
+<span>(</span><span>seed=0, **kwargs)</span>
+</code></dt>
+<dd>
+<div class="desc"><p>Sampling latent variable <span><span class="MathJax_Preview">z</span><script type="math/tex">z</script></span> from <span><span class="MathJax_Preview">N(\mu_z, \log \sigma_z^2</span><script type="math/tex">N(\mu_z, \log \sigma_z^2</script></span>).
+<br>
+Used in Encoder.</p></div>
+<details class="source">
+<summary>
+<span>Expand source code</span>
+</summary>
+<pre><code class="python">class Sampling(Layer):
+    &#34;&#34;&#34;Sampling latent variable \(z\) from \(N(\\mu_z, \\sigma_z^2)\), parameterized by the mean and log-variance.
+    Used in Encoder.
+    &#34;&#34;&#34;
+    def __init__(self, seed=0, **kwargs):
+        super(Sampling, self).__init__(**kwargs)
+        self.seed = seed
+
+    @tf.function
+    def call(self, z_mean, z_log_var):
+        &#39;&#39;&#39;Sample \(z\) via the reparameterization trick.
+
+        Parameters
+        ----------
+        z_mean : tf.Tensor
+            \([B, L, d]\) The mean of \(z\).
+        z_log_var : tf.Tensor
+            \([B, L, d]\) The log-variance of \(z\).
+
+        Returns
+        ----------
+        z : tf.Tensor
+            \([B, L, d]\) The sampled \(z\).
+        &#39;&#39;&#39;   
+   #     seed = tfp.util.SeedStream(self.seed, salt=&#34;random_normal&#34;)
+   #     epsilon = tf.random.normal(shape = tf.shape(z_mean), seed=seed(), dtype=tf.keras.backend.floatx())
+        epsilon = tf.random.normal(shape = tf.shape(z_mean), dtype=tf.keras.backend.floatx())
+        z = z_mean + tf.exp(0.5 * z_log_var) * epsilon
+        z = tf.clip_by_value(z, -1e6, 1e6)
+        return z</code></pre>
+</details>
+<h3>Ancestors</h3>
+<ul class="hlist">
+<li>keras.src.engine.base_layer.Layer</li>
+<li>tensorflow.python.module.module.Module</li>
+<li>tensorflow.python.trackable.autotrackable.AutoTrackable</li>
+<li>tensorflow.python.trackable.base.Trackable</li>
+<li>keras.src.utils.version_utils.LayerVersionSelector</li>
+</ul>
+<h3>Methods</h3>
+<dl>
+<dt id="VITAE.model.Sampling.call"><code class="name flex">
+<span>def <span class="ident">call</span></span>(<span>self, z_mean, z_log_var)</span>
+</code></dt>
+<dd>
+<div class="desc"><p>Return cdf(x) and pdf(x).</p>
+<h2 id="parameters">Parameters</h2>
+<dl>
+<dt><strong><code>z_mean</code></strong> :&ensp;<code>tf.Tensor</code></dt>
+<dd><span><span class="MathJax_Preview">[B, L, d]</span><script type="math/tex">[B, L, d]</script></span> The mean of <span><span class="MathJax_Preview">z</span><script type="math/tex">z</script></span>.</dd>
+<dt><strong><code>z_log_var</code></strong> :&ensp;<code>tf.Tensor</code></dt>
+<dd><span><span class="MathJax_Preview">[B, L, d]</span><script type="math/tex">[B, L, d]</script></span> The log-variance of <span><span class="MathJax_Preview">z</span><script type="math/tex">z</script></span>.</dd>
+</dl>
+<h2 id="returns">Returns</h2>
+<dl>
+<dt><strong><code>z</code></strong> :&ensp;<code>tf.Tensor</code></dt>
+<dd><span><span class="MathJax_Preview">[B, L, d]</span><script type="math/tex">[B, L, d]</script></span> The sampled <span><span class="MathJax_Preview">z</span><script type="math/tex">z</script></span>.</dd>
+</dl></div>
+</dd>
+</dl>
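+<p>A short illustrative sketch of the reparameterization step; the shapes (B=2, L=1, d=3) are arbitrary, and a log-variance of zero corresponds to unit variance.</p>
+<pre><code class="python">import tensorflow as tf
+from VITAE.model import Sampling
+
+sampler = Sampling()
+z_mean = tf.zeros([2, 1, 3])       # [B, L, d]
+z_log_var = tf.zeros([2, 1, 3])    # log-variance of zero, i.e. unit variance
+z = sampler(z_mean, z_log_var)     # z = mu + exp(0.5 * log_var) * epsilon
+</code></pre>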
+</dd>
+<dt id="VITAE.model.Encoder"><code class="flex name class">
+<span>class <span class="ident">Encoder</span></span>
+<span>(</span><span>dimensions, dim_latent, name='encoder', **kwargs)</span>
+</code></dt>
+<dd>
+<div class="desc"><p>Encoder, model <span><span class="MathJax_Preview">p(Z_i|Y_i,X_i)</span><script type="math/tex">p(Z_i|Y_i,X_i)</script></span>.</p>
+<h2 id="parameters">Parameters</h2>
+<dl>
+<dt><strong><code>dimensions</code></strong> :&ensp;<code>np.array</code></dt>
+<dd>The dimensions of hidden layers of the encoder.</dd>
+<dt><strong><code>dim_latent</code></strong> :&ensp;<code>int</code></dt>
+<dd>The latent dimension of the encoder.</dd>
+<dt><strong><code>name</code></strong> :&ensp;<code>str</code>, optional</dt>
+<dd>The name of the layer.</dd>
+<dt><strong><code>**kwargs</code></strong></dt>
+<dd>Extra keyword arguments.</dd>
+</dl></div>
+<details class="source">
+<summary>
+<span>Expand source code</span>
+</summary>
+<pre><code class="python">class Encoder(Layer):
+    &#39;&#39;&#39;
+    Encoder, model \(p(Z_i|Y_i,X_i)\).
+    &#39;&#39;&#39;
+    def __init__(self, dimensions, dim_latent, name=&#39;encoder&#39;, **kwargs):
+        &#39;&#39;&#39;
+        Parameters
+        ----------
+        dimensions : np.array
+            The dimensions of hidden layers of the encoder.
+        dim_latent : int
+            The latent dimension of the encoder.
+        name : str, optional
+            The name of the layer.
+        **kwargs : 
+            Extra keyword arguments.
+        &#39;&#39;&#39; 
+        super(Encoder, self).__init__(name = name, **kwargs)
+        self.dense_layers = [Dense(dim, activation = tf.nn.leaky_relu,
+                                          name = &#39;encoder_%i&#39;%(i+1)) \
+                             for (i, dim) in enumerate(dimensions)]
+        self.batch_norm_layers = [BatchNormalization(center=False) \
+                                    for _ in range(len(dimensions))]
+        self.batch_norm_layers.append(BatchNormalization(center=False))
+        self.latent_mean = Dense(dim_latent, name = &#39;latent_mean&#39;)
+        self.latent_log_var = Dense(dim_latent, name = &#39;latent_log_var&#39;)
+        self.sampling = Sampling()
+    
+    @tf.function
+    def call(self, x, L=1, is_training=True):
+        &#39;&#39;&#39;Encode the inputs and get the latent variables.
+
+        Parameters
+        ----------
+        x : tf.Tensor
+            \([B, G]\) The input.
+        L : int, optional
+            The number of MC samples.
+        is_training : boolean, optional
+            Whether in the training or inference mode.
+        
+        Returns
+        ----------
+        z_mean : tf.Tensor
+            \([B, d]\) The mean of \(z\).
+        z_log_var : tf.Tensor
+            \([B, d]\) The log-variance of \(z\).
+        z : tf.Tensor
+            \([B, L, d]\) The sampled \(z\).
+        &#39;&#39;&#39;         
+        for dense, bn in zip(self.dense_layers, self.batch_norm_layers):
+            x = dense(x)
+            x = bn(x, training=is_training)
+        z_mean = self.batch_norm_layers[-1](self.latent_mean(x), training=is_training)
+        z_log_var = self.latent_log_var(x)
+        _z_mean = tf.tile(tf.expand_dims(z_mean, 1), (1,L,1))
+        _z_log_var = tf.tile(tf.expand_dims(z_log_var, 1), (1,L,1))
+        z = self.sampling(_z_mean, _z_log_var)
+        return z_mean, z_log_var, z</code></pre>
+</details>
+<h3>Ancestors</h3>
+<ul class="hlist">
+<li>keras.src.engine.base_layer.Layer</li>
+<li>tensorflow.python.module.module.Module</li>
+<li>tensorflow.python.trackable.autotrackable.AutoTrackable</li>
+<li>tensorflow.python.trackable.base.Trackable</li>
+<li>keras.src.utils.version_utils.LayerVersionSelector</li>
+</ul>
+<h3>Methods</h3>
+<dl>
+<dt id="VITAE.model.Encoder.call"><code class="name flex">
+<span>def <span class="ident">call</span></span>(<span>self, x, L=1, is_training=True)</span>
+</code></dt>
+<dd>
+<div class="desc"><p>Encode the inputs and get the latent variables.</p>
+<h2 id="parameters">Parameters</h2>
+<dl>
+<dt><strong><code>x</code></strong> :&ensp;<code>tf.Tensor</code></dt>
+<dd><span><span class="MathJax_Preview">[B, G]</span><script type="math/tex">[B, G]</script></span> The input.</dd>
+<dt><strong><code>L</code></strong> :&ensp;<code>int</code>, optional</dt>
+<dd>The number of MC samples.</dd>
+<dt><strong><code>is_training</code></strong> :&ensp;<code>boolean</code>, optional</dt>
+<dd>Whether in the training or inference mode.</dd>
+</dl>
+<h2 id="returns">Returns</h2>
+<dl>
+<dt><strong><code>z_mean</code></strong> :&ensp;<code>tf.Tensor</code></dt>
+<dd><span><span class="MathJax_Preview">[B, d]</span><script type="math/tex">[B, d]</script></span> The mean of <span><span class="MathJax_Preview">z</span><script type="math/tex">z</script></span>.</dd>
+<dt><strong><code>z_log_var</code></strong> :&ensp;<code>tf.Tensor</code></dt>
+<dd><span><span class="MathJax_Preview">[B, d]</span><script type="math/tex">[B, d]</script></span> The log-variance of <span><span class="MathJax_Preview">z</span><script type="math/tex">z</script></span>.</dd>
+<dt><strong><code>z</code></strong> :&ensp;<code>tf.Tensor</code></dt>
+<dd><span><span class="MathJax_Preview">[B, L, d]</span><script type="math/tex">[B, L, d]</script></span> The sampled <span><span class="MathJax_Preview">z</span><script type="math/tex">z</script></span>.</dd>
+</dl></div>
+</dd>
+</dl>
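+<p>A hedged sketch of one encoder pass; the hidden dimensions, batch size, and input width below are arbitrary choices for illustration.</p>
+<pre><code class="python">import numpy as np
+import tensorflow as tf
+from VITAE.model import Encoder
+
+encoder = Encoder(dimensions=np.array([64, 32]), dim_latent=8)
+x = tf.random.normal([16, 128])           # [B, G] preprocessed input
+z_mean, z_log_var, z = encoder(x, L=5)    # shapes [16, 8], [16, 8], [16, 5, 8]
+</code></pre>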
+</dd>
+<dt id="VITAE.model.Decoder"><code class="flex name class">
+<span>class <span class="ident">Decoder</span></span>
+<span>(</span><span>dimensions, dim_origin, data_type='UMI', name='decoder', **kwargs)</span>
+</code></dt>
+<dd>
+<div class="desc"><p>Decoder, model <span><span class="MathJax_Preview">p(Y_i|Z_i,X_i)</span><script type="math/tex">p(Y_i|Z_i,X_i)</script></span>.</p>
+<h2 id="parameters">Parameters</h2>
+<dl>
+<dt><strong><code>dimensions</code></strong> :&ensp;<code>np.array</code></dt>
+<dd>The dimensions of hidden layers of the decoder.</dd>
+<dt><strong><code>dim_origin</code></strong> :&ensp;<code>int</code></dt>
+<dd>The output dimension of the decoder.</dd>
+<dt><strong><code>data_type</code></strong> :&ensp;<code>str</code>, optional</dt>
+<dd><code>'UMI'</code>, <code>'non-UMI'</code>, or <code>'Gaussian'</code>.</dd>
+<dt><strong><code>name</code></strong> :&ensp;<code>str</code>, optional</dt>
+<dd>The name of the layer.</dd>
+</dl></div>
+<details class="source">
+<summary>
+<span>Expand source code</span>
+</summary>
+<pre><code class="python">class Decoder(Layer):
+    &#39;&#39;&#39;
+    Decoder, model \(p(Y_i|Z_i,X_i)\).
+    &#39;&#39;&#39;
+    def __init__(self, dimensions, dim_origin, data_type = &#39;UMI&#39;, 
+                name = &#39;decoder&#39;, **kwargs):
+        &#39;&#39;&#39;
+        Parameters
+        ----------
+        dimensions : np.array
+            The dimensions of hidden layers of the decoder.
+        dim_origin : int
+            The output dimension of the decoder.
+        data_type : str, optional
+            `&#39;UMI&#39;`, `&#39;non-UMI&#39;`, or `&#39;Gaussian&#39;`.
+        name : str, optional
+            The name of the layer.
+        &#39;&#39;&#39;
+        super(Decoder, self).__init__(name = name, **kwargs)
+        self.data_type = data_type
+        self.dense_layers = [Dense(dim, activation = tf.nn.leaky_relu,
+                                          name = &#39;decoder_%i&#39;%(i+1)) \
+                             for (i,dim) in enumerate(dimensions)]
+        self.batch_norm_layers = [BatchNormalization(center=False) \
+                                    for _ in range(len(dimensions))]
+
+        if data_type==&#39;Gaussian&#39;:
+            self.nu_z = Dense(dim_origin, name = &#39;nu_z&#39;)
+            # common variance
+            self.log_tau = tf.Variable(tf.zeros([1, dim_origin], dtype=tf.keras.backend.floatx()),
+                                 constraint = lambda t: tf.clip_by_value(t,-30.,6.),
+                                 name = &#34;log_tau&#34;)
+        else:
+            self.log_lambda_z = Dense(dim_origin, name = &#39;log_lambda_z&#39;)
+
+            # dispersion parameter
+            self.log_r = tf.Variable(tf.zeros([1, dim_origin], dtype=tf.keras.backend.floatx()),
+                                     constraint = lambda t: tf.clip_by_value(t,-30.,6.),
+                                     name = &#34;log_r&#34;)
+            
+            if self.data_type == &#39;non-UMI&#39;:
+                self.phi = Dense(dim_origin, activation = &#39;sigmoid&#39;, name = &#34;phi&#34;)
+          
+    @tf.function  
+    def call(self, z, is_training=True):
+        &#39;&#39;&#39;Decode the latent variables and get the reconstructions.
+
+        Parameters
+        ----------
+        z : tf.Tensor
+            \([B, L, d]\) The sampled \(z\).
+        is_training : boolean, optional
+            Whether in the training or inference mode.
+
+        When `data_type==&#39;Gaussian&#39;`:
+
+        Returns
+        ----------
+        nu_z : tf.Tensor
+            \([B, L, G]\) The mean of \(Y_i|Z_i,X_i\).
+        tau : tf.Tensor
+            \([1, G]\) The variance of \(Y_i|Z_i,X_i\).
+
+        When `data_type==&#39;UMI&#39;`:
+
+        Returns
+        ----------
+        lambda_z : tf.Tensor
+            \([B, L, G]\) The mean of \(Y_i|Z_i,X_i\).
+        r : tf.Tensor
+            \([1, G]\) The dispersion parameters of \(Y_i|Z_i,X_i\).
+
+        When `data_type==&#39;non-UMI&#39;`:
+
+        Returns
+        ----------
+        lambda_z : tf.Tensor
+            \([B, L, G]\) The mean of \(Y_i|Z_i,X_i\).
+        r : tf.Tensor
+            \([1, G]\) The dispersion parameters of \(Y_i|Z_i,X_i\).
+        phi_z : tf.Tensor
+            \([1, G]\) The zero-inflation parameters of \(Y_i|Z_i,X_i\).
+        &#39;&#39;&#39;
+        for dense, bn in zip(self.dense_layers, self.batch_norm_layers):
+            z = dense(z)
+            z = bn(z, training=is_training)
+        if self.data_type==&#39;Gaussian&#39;:
+            nu_z = self.nu_z(z)
+            tau = tf.exp(self.log_tau)
+            return nu_z, tau
+        else:
+            lambda_z = tf.math.exp(
+                tf.clip_by_value(self.log_lambda_z(z), -30., 6.)
+                )
+            r = tf.exp(self.log_r)
+            if self.data_type==&#39;UMI&#39;:
+                return lambda_z, r
+            else:
+                return lambda_z, r, self.phi(z)</code></pre>
+</details>
+<h3>Ancestors</h3>
+<ul class="hlist">
+<li>keras.src.engine.base_layer.Layer</li>
+<li>tensorflow.python.module.module.Module</li>
+<li>tensorflow.python.trackable.autotrackable.AutoTrackable</li>
+<li>tensorflow.python.trackable.base.Trackable</li>
+<li>keras.src.utils.version_utils.LayerVersionSelector</li>
+</ul>
+<h3>Methods</h3>
+<dl>
+<dt id="VITAE.model.Decoder.call"><code class="name flex">
+<span>def <span class="ident">call</span></span>(<span>self, z, is_training=True)</span>
+</code></dt>
+<dd>
+<div class="desc"><p>Decode the latent variables and get the reconstructions.</p>
+<h2 id="parameters">Parameters</h2>
+<dl>
+<dt><strong><code>z</code></strong> :&ensp;<code>tf.Tensor</code></dt>
+<dd><span><span class="MathJax_Preview">[B, L, d]</span><script type="math/tex">[B, L, d]</script></span> The sampled <span><span class="MathJax_Preview">z</span><script type="math/tex">z</script></span>.</dd>
+<dt><strong><code>is_training</code></strong> :&ensp;<code>boolean</code>, optional</dt>
+<dd>Whether in the training or inference mode.</dd>
+</dl>
+<p>When <code>data_type=='Gaussian'</code>:</p>
+<h2 id="returns">Returns</h2>
+<dl>
+<dt><strong><code>nu_z</code></strong> :&ensp;<code>tf.Tensor</code></dt>
+<dd><span><span class="MathJax_Preview">[B, L, G]</span><script type="math/tex">[B, L, G]</script></span> The mean of <span><span class="MathJax_Preview">Y_i|Z_i,X_i</span><script type="math/tex">Y_i|Z_i,X_i</script></span>.</dd>
+<dt><strong><code>tau</code></strong> :&ensp;<code>tf.Tensor</code></dt>
+<dd><span><span class="MathJax_Preview">[1, G]</span><script type="math/tex">[1, G]</script></span> The variance of <span><span class="MathJax_Preview">Y_i|Z_i,X_i</span><script type="math/tex">Y_i|Z_i,X_i</script></span>.</dd>
+</dl>
+<p>When <code>data_type=='UMI'</code>:</p>
+<h2 id="returns_1">Returns</h2>
+<dl>
+<dt><strong><code>lambda_z</code></strong> :&ensp;<code>tf.Tensor</code></dt>
+<dd><span><span class="MathJax_Preview">[B, L, G]</span><script type="math/tex">[B, L, G]</script></span> The mean of <span><span class="MathJax_Preview">Y_i|Z_i,X_i</span><script type="math/tex">Y_i|Z_i,X_i</script></span>.</dd>
+<dt><strong><code>r</code></strong> :&ensp;<code>tf.Tensor</code></dt>
+<dd><span><span class="MathJax_Preview">[1, G]</span><script type="math/tex">[1, G]</script></span> The dispersion parameters of <span><span class="MathJax_Preview">Y_i|Z_i,X_i</span><script type="math/tex">Y_i|Z_i,X_i</script></span>.</dd>
+</dl>
+<p>When <code>data_type=='non-UMI'</code>:</p>
+<h2 id="returns_2">Returns</h2>
+<dl>
+<dt><strong><code>lambda_z</code></strong> :&ensp;<code>tf.Tensor</code></dt>
+<dd><span><span class="MathJax_Preview">[B, L, G]</span><script type="math/tex">[B, L, G]</script></span> The mean of <span><span class="MathJax_Preview">Y_i|Z_i,X_i</span><script type="math/tex">Y_i|Z_i,X_i</script></span>.</dd>
+<dt><strong><code>r</code></strong> :&ensp;<code>tf.Tensor</code></dt>
+<dd><span><span class="MathJax_Preview">[1, G]</span><script type="math/tex">[1, G]</script></span> The dispersion parameters of <span><span class="MathJax_Preview">Y_i|Z_i,X_i</span><script type="math/tex">Y_i|Z_i,X_i</script></span>.</dd>
+<dt><strong><code>phi_z</code></strong> :&ensp;<code>tf.Tensor</code></dt>
+<dd><span><span class="MathJax_Preview">[1, G]</span><script type="math/tex">[1, G]</script></span> The zero-inflation parameters of <span><span class="MathJax_Preview">Y_i|Z_i,X_i</span><script type="math/tex">Y_i|Z_i,X_i</script></span>.</dd>
+</dl></div>
+</dd>
+</dl>
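+<p>A hedged sketch of decoding sampled latent variables under the <code>'UMI'</code> (negative binomial) model; all shapes below are illustrative.</p>
+<pre><code class="python">import numpy as np
+import tensorflow as tf
+from VITAE.model import Decoder
+
+decoder = Decoder(dimensions=np.array([32, 64]), dim_origin=128, data_type=&#39;UMI&#39;)
+z = tf.random.normal([16, 5, 8])    # [B, L, d] sampled latent variables
+lambda_z, r = decoder(z)            # [16, 5, 128] means, [1, 128] dispersions
+</code></pre>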
+</dd>
+<dt id="VITAE.model.LatentSpace"><code class="flex name class">
+<span>class <span class="ident">LatentSpace</span></span>
+<span>(</span><span>n_clusters, dim_latent, name='LatentSpace', seed=0, **kwargs)</span>
+</code></dt>
+<dd>
+<div class="desc"><p>Layer for the Latent Space.</p>
+<h2 id="parameters">Parameters</h2>
+<dl>
+<dt><strong><code>n_clusters</code></strong> :&ensp;<code>int</code></dt>
+<dd>The number of vertices in the latent space.</dd>
+<dt><strong><code>dim_latent</code></strong> :&ensp;<code>int</code></dt>
+<dd>The latent dimension.</dd>
+<dt><strong><code>seed</code></strong> :&ensp;<code>int</code>, optional</dt>
+<dd>The random seed used to initialize the vertex positions.</dd>
+<dt><strong><code>name</code></strong> :&ensp;<code>str</code>, optional</dt>
+<dd>The name of the layer.</dd>
+<dt><strong><code>**kwargs</code></strong></dt>
+<dd>Extra keyword arguments.</dd>
+</dl></div>
+<details class="source">
+<summary>
+<span>Expand source code</span>
+</summary>
+<pre><code class="python">class LatentSpace(Layer):
+    &#39;&#39;&#39;
+    Layer for the Latent Space.
+    &#39;&#39;&#39;
+    def __init__(self, n_clusters, dim_latent,
+            name = &#39;LatentSpace&#39;, seed=0, **kwargs):
+        &#39;&#39;&#39;
+        Parameters
+        ----------
+        n_clusters : int
+            The number of vertices in the latent space.
+        dim_latent : int
+            The latent dimension.
+        seed : int, optional
+            The random seed used to initialize the vertex positions.
+        name : str, optional
+            The name of the layer.
+        **kwargs : 
+            Extra keyword arguments.
+        &#39;&#39;&#39;
+        super(LatentSpace, self).__init__(name=name, **kwargs)
+        self.dim_latent = dim_latent
+        self.n_states = n_clusters
+        self.n_categories = int(n_clusters*(n_clusters+1)/2)
+
+        # nonzero indexes
+        # A = [0,0,...,0  , 1,1,...,1,   ...]
+        # B = [0,1,...,k-1, 1,2,...,k-1, ...]
+        self.A, self.B = np.nonzero(np.triu(np.ones(n_clusters)))
+        self.A = tf.convert_to_tensor(self.A, tf.int32)
+        self.B = tf.convert_to_tensor(self.B, tf.int32)
+        self.clusters_ind = tf.boolean_mask(
+            tf.range(0,self.n_categories,1), self.A==self.B)
+
+        # [pi_1, ... , pi_K] in R^(n_categories)
+        self.pi = tf.Variable(tf.ones([1, self.n_categories], dtype=tf.keras.backend.floatx()) / self.n_categories,
+                                name = &#39;pi&#39;)
+        
+        # [mu_1, ... , mu_K] in R^(dim_latent * n_clusters)
+        self.mu = tf.Variable(tf.random.uniform([self.dim_latent, self.n_states],
+                                                minval = -1, maxval = 1, seed=seed, dtype=tf.keras.backend.floatx()),
+                                name = &#39;mu&#39;)
+        self.cdf_layer = cdf_layer()       
+        
+    def initialize(self, mu, log_pi):
+        &#39;&#39;&#39;Initialize the latent space.
+
+        Parameters
+        ----------
+        mu : np.array
+            \([d, k]\) The position matrix.
+        log_pi : np.array
+            \([1, K]\) \(\\log\\pi\).
+        &#39;&#39;&#39;
+        # Initialize parameters of the latent space
+        if mu is not None:
+            self.mu.assign(mu)
+        if log_pi is not None:
+            self.pi.assign(log_pi)
+
+    def normalize(self):
+        &#39;&#39;&#39;Normalize \(\\pi\).
+        &#39;&#39;&#39;
+        self.pi = tf.nn.softmax(self.pi)
+
+    @tf.function
+    def _get_normal_params(self, z, pi):
+        batch_size = tf.shape(z)[0]
+        L = tf.shape(z)[1]
+        
+        # [batch_size, L, n_categories]
+        if pi is None:
+            # [batch_size, L, n_states]
+            temp_pi = tf.tile(
+                tf.expand_dims(tf.nn.softmax(self.pi), 1),
+                (batch_size,L,1))
+        else:
+            temp_pi = tf.expand_dims(tf.nn.softmax(pi), 1)
+
+        # [batch_size, L, d, n_categories]
+        alpha_zc = tf.expand_dims(tf.expand_dims(
+            tf.gather(self.mu, self.B, axis=1) - tf.gather(self.mu, self.A, axis=1), 0), 0)
+        beta_zc = tf.expand_dims(z,-1) - \
+            tf.expand_dims(tf.expand_dims(
+            tf.gather(self.mu, self.B, axis=1), 0), 0)
+            
+        # [batch_size, L, n_categories]
+        inv_sig = tf.reduce_sum(alpha_zc * alpha_zc, axis=2)
+        nu = - tf.reduce_sum(alpha_zc * beta_zc, axis=2) * tf.math.reciprocal_no_nan(inv_sig)
+        _t = - tf.reduce_sum(beta_zc * beta_zc, axis=2) + nu**2*inv_sig
+        return temp_pi, beta_zc, inv_sig, nu, _t
+    
+    @tf.function
+    def _get_pz(self, temp_pi, inv_sig, beta_zc, log_p_z_c_L):
+        # [batch_size, L, n_categories]
+        log_p_zc_L = - 0.5 * self.dim_latent * tf.math.log(tf.constant(2 * np.pi, tf.keras.backend.floatx())) + \
+            tf.math.log(temp_pi) + \
+            tf.where(inv_sig==0, 
+                    - 0.5 * tf.reduce_sum(beta_zc**2, axis=2), 
+                    log_p_z_c_L)
+        
+        # [batch_size, L, 1]
+        log_p_z_L = tf.reduce_logsumexp(log_p_zc_L, axis=-1, keepdims=True)
+        
+        # [1, ]
+        log_p_z = tf.reduce_mean(log_p_z_L)
+        return log_p_zc_L, log_p_z_L, log_p_z
+    
+    @tf.function
+    def _get_posterior_c(self, log_p_zc_L, log_p_z_L):
+        L = tf.shape(log_p_z_L)[1]
+
+        # log_p_c_x     -   predicted probability distribution
+        # [batch_size, n_categories]
+        log_p_c_x = tf.reduce_logsumexp(
+                        log_p_zc_L - log_p_z_L,
+                    axis=1) - tf.math.log(tf.cast(L, tf.keras.backend.floatx()))
+        return log_p_c_x
+
+    @tf.function
+    def _get_inference(self, z, log_p_z_L, temp_pi, inv_sig, nu, beta_zc, log_eta0, eta1, eta2):
+        batch_size = tf.shape(z)[0]
+        L = tf.shape(z)[1]
+        dist = tfp.distributions.Normal(
+            loc = tf.constant(0.0, tf.keras.backend.floatx()), 
+            scale = tf.constant(1.0, tf.keras.backend.floatx()), 
+            allow_nan_stats=False)
+        
+        # [batch_size, L, n_categories, n_clusters]
+        inv_sig = tf.expand_dims(inv_sig, -1)
+        _sig = tf.tile(tf.clip_by_value(tf.math.reciprocal_no_nan(inv_sig), 1e-12, 1e30), (1,1,1,self.n_states))
+        log_eta0 = tf.tile(tf.expand_dims(log_eta0, -1), (1,1,1,self.n_states))
+        eta1 = tf.tile(tf.expand_dims(eta1, -1), (1,1,1,self.n_states))
+        eta2 = tf.tile(tf.expand_dims(eta2, -1), (1,1,1,self.n_states))
+        nu = tf.tile(tf.expand_dims(nu, -1), (1,1,1,1))
+        A = tf.tile(tf.expand_dims(tf.expand_dims(
+            tf.one_hot(self.A, self.n_states, dtype=tf.keras.backend.floatx()), 
+            0),0), (batch_size,L,1,1))
+        B = tf.tile(tf.expand_dims(tf.expand_dims(
+            tf.one_hot(self.B, self.n_states, dtype=tf.keras.backend.floatx()), 
+            0),0), (batch_size,L,1,1))
+        temp_pi = tf.expand_dims(temp_pi, -1)
+
+        # w_tilde [batch_size, L, n_clusters]
+        w_tilde = log_eta0 + tf.math.log(
+            tf.clip_by_value(
+                (dist.cdf(eta1) - dist.cdf(eta2)) * (nu * A + (1-nu) * B)  -
+                (dist.prob(eta1) - dist.prob(eta2)) * tf.math.sqrt(_sig) * (A - B), 
+                0.0, 1e30)
+            )
+        w_tilde = - 0.5 * self.dim_latent * tf.math.log(tf.constant(2 * np.pi, tf.keras.backend.floatx())) + \
+            tf.math.log(temp_pi) + \
+            tf.where(inv_sig==0, 
+                    tf.where(B==1, - 0.5 * tf.expand_dims(tf.reduce_sum(beta_zc**2, axis=2), -1), -np.inf), 
+                    w_tilde)
+        w_tilde = tf.exp(tf.reduce_logsumexp(w_tilde, 2) - log_p_z_L)
+
+        # tf.debugging.assert_greater_equal(
+        #     tf.reduce_sum(w_tilde, -1), tf.ones([batch_size, L], dtype=tf.keras.backend.floatx())*0.99, 
+        #     message=&#39;Wrong w_tilde&#39;, summarize=None, name=None
+        # )
+        
+        # var_w_tilde [batch_size, L, n_clusters]
+        var_w_tilde = log_eta0 + tf.math.log(
+            tf.clip_by_value(
+                (dist.cdf(eta1) -  dist.cdf(eta2)) * ((_sig + nu**2) * (A+B) + (1-2*nu) * B)  -
+                (dist.prob(eta1) - dist.prob(eta2)) * tf.math.sqrt(_sig) * (nu *(A+B)-B )*2 -
+                (eta1*dist.prob(eta1) - eta2*dist.prob(eta2)) * _sig *(A+B), 
+                0.0, 1e30)
+            )
+        var_w_tilde = - 0.5 * self.dim_latent * tf.math.log(tf.constant(2 * np.pi, tf.keras.backend.floatx())) + \
+            tf.math.log(temp_pi) + \
+            tf.where(inv_sig==0, 
+                    tf.where(B==1, - 0.5 * tf.expand_dims(tf.reduce_sum(beta_zc**2, axis=2), -1), -np.inf), 
+                    var_w_tilde) 
+        var_w_tilde = tf.exp(tf.reduce_logsumexp(var_w_tilde, 2) - log_p_z_L) - w_tilde**2  
+
+
+        w_tilde = tf.reduce_mean(w_tilde, 1)
+        var_w_tilde = tf.reduce_mean(var_w_tilde, 1)
+        return w_tilde, var_w_tilde
+
+    def get_pz(self, z, eps, pi):
+        &#39;&#39;&#39;Get \(\\log p(Z_i|Y_i,X_i)\).
+
+        Parameters
+        ----------
+        z : tf.Tensor
+            \([B, L, d]\) The latent variables.
+        eps : float
+            The lower bound used when clipping probabilities.
+        pi : tf.Tensor or None
+            Optional logits for \(\\pi\); if `None`, the trained \(\\pi\) is used.
+
+        Returns
+        ----------
+        temp_pi : tf.Tensor
+            \([B, L, K]\) \(\\pi\).
+        inv_sig : tf.Tensor
+            \([B, L, K]\) \(\\sigma_{Z_ic_i}^{-1}\).
+        nu : tf.Tensor
+            \([B, L, K]\) \(\\nu_{Z_ic_i}\).
+        beta_zc : tf.Tensor
+            \([B, L, d, K]\) \(\\beta_{Z_ic_i}\).
+        log_eta0 : tf.Tensor
+            \([B, L, K]\) \(\\log\\eta_{Z_ic_i,0}\).
+        eta1 : tf.Tensor
+            \([B, L, K]\) \(\\eta_{Z_ic_i,1}\).
+        eta2 : tf.Tensor
+            \([B, L, K]\) \(\\eta_{Z_ic_i,2}\).
+        log_p_zc_L : tf.Tensor
+            \([B, L, K]\) \(\\log p(Z_i,c_i|Y_i,X_i)\).
+        log_p_z_L : tf.Tensor
+            \([B, L]\) \(\\log p(Z_i|Y_i,X_i)\).
+        log_p_z : tf.Tensor
+            \([B, 1]\) The estimated \(\\log p(Z_i|Y_i,X_i)\). 
+        &#39;&#39;&#39;        
+        temp_pi, beta_zc, inv_sig, nu, _t = self._get_normal_params(z, pi)
+        temp_pi = tf.clip_by_value(temp_pi, eps, 1.0)
+
+        log_eta0 = 0.5 * (tf.math.log(tf.constant(2 * np.pi, tf.keras.backend.floatx())) - \
+                    tf.math.log(tf.clip_by_value(inv_sig, 1e-12, 1e30)) + _t)
+        eta1 = (1-nu) * tf.math.sqrt(tf.clip_by_value(inv_sig, 1e-12, 1e30))
+        eta2 = -nu * tf.math.sqrt(tf.clip_by_value(inv_sig, 1e-12, 1e30))
+
+        log_p_z_c_L =  log_eta0 + tf.math.log(tf.clip_by_value(
+            self.cdf_layer(eta1) - self.cdf_layer(eta2),
+            eps, 1e30))
+        
+        log_p_zc_L, log_p_z_L, log_p_z = self._get_pz(temp_pi, inv_sig, beta_zc, log_p_z_c_L)
+        return temp_pi, inv_sig, nu, beta_zc, log_eta0, eta1, eta2, log_p_zc_L, log_p_z_L, log_p_z
+
+    def get_posterior_c(self, z):
+        &#39;&#39;&#39;Get \(p(c_i|Y_i,X_i)\).
+
+        Parameters
+        ----------
+        z : tf.Tensor
+            \([B, L, d]\) The latent variables.
+
+        Returns
+        ----------
+        p_c_x : np.array
+            \([B, K]\) \(p(c_i|Y_i,X_i)\).
+        &#39;&#39;&#39;  
+        _,_,_,_,_,_,_, log_p_zc_L, log_p_z_L, _ = self.get_pz(z, 0., None)
+        log_p_c_x = self._get_posterior_c(log_p_zc_L, log_p_z_L)
+        p_c_x = tf.exp(log_p_c_x).numpy()
+        return p_c_x
+
+    def call(self, z, pi=None, inference=False):
+        &#39;&#39;&#39;Get posterior estimations.
+
+        Parameters
+        ----------
+        z : tf.Tensor
+            \([B, L, d]\) The latent variables.
+        pi : tf.Tensor, optional
+            Optional logits for \(\\pi\) produced by the pi layer.
+        inference : boolean
+            Whether in training or inference mode.
+
+        When `inference=False`:
+
+        Returns
+        ----------
+        log_p_z_L : tf.Tensor
+            \([B, 1]\) The estimated \(\\log p(Z_i|Y_i,X_i)\).
+
+        When `inference=True`:
+
+        Returns
+        ----------
+        res : dict
+            The dict of posterior estimations: \(p(c_i|Y_i,X_i)\), \(E(\\tilde{w}_i|Y_i,X_i)\), and \(Var(\\tilde{w}_i|Y_i,X_i)\).
+        &#39;&#39;&#39;                 
+        eps = 1e-16 if not inference else 0.
+        temp_pi, inv_sig, nu, beta_zc, log_eta0, eta1, eta2, log_p_zc_L, log_p_z_L, log_p_z = self.get_pz(z, eps, pi)
+
+        if not inference:
+            return log_p_z
+        else:
+            log_p_c_x = self._get_posterior_c(log_p_zc_L, log_p_z_L)
+            w_tilde, var_w_tilde = self._get_inference(z, log_p_z_L, temp_pi, inv_sig, nu, beta_zc, log_eta0, eta1, eta2)
+            
+            res = {}
+            res[&#39;p_c_x&#39;] = tf.exp(log_p_c_x).numpy()
+            res[&#39;w_tilde&#39;] = w_tilde.numpy()
+            res[&#39;var_w_tilde&#39;] = var_w_tilde.numpy()
+            return res</code></pre>
+</details>
+<h3>Ancestors</h3>
+<ul class="hlist">
+<li>keras.src.engine.base_layer.Layer</li>
+<li>tensorflow.python.module.module.Module</li>
+<li>tensorflow.python.trackable.autotrackable.AutoTrackable</li>
+<li>tensorflow.python.trackable.base.Trackable</li>
+<li>keras.src.utils.version_utils.LayerVersionSelector</li>
+</ul>
+<h3>Methods</h3>
+<dl>
+<dt id="VITAE.model.LatentSpace.initialize"><code class="name flex">
+<span>def <span class="ident">initialize</span></span>(<span>self, mu, log_pi)</span>
+</code></dt>
+<dd>
+<div class="desc"><p>Initialize the latent space.</p>
+<h2 id="parameters">Parameters</h2>
+<dl>
+<dt><strong><code>mu</code></strong> :&ensp;<code>np.array</code></dt>
+<dd><span><span class="MathJax_Preview">[d, k]</span><script type="math/tex">[d, k]</script></span> The position matrix.</dd>
+<dt><strong><code>log_pi</code></strong> :&ensp;<code>np.array</code></dt>
+<dd><span><span class="MathJax_Preview">[1, K]</span><script type="math/tex">[1, K]</script></span> <span><span class="MathJax_Preview">\log\pi</span><script type="math/tex">\log\pi</script></span>.</dd>
+</dl></div>
+</dd>
+<dt id="VITAE.model.LatentSpace.normalize"><code class="name flex">
+<span>def <span class="ident">normalize</span></span>(<span>self)</span>
+</code></dt>
+<dd>
+<div class="desc"><p>Normalize <span><span class="MathJax_Preview">\pi</span><script type="math/tex">\pi</script></span>.</p></div>
+</dd>
+<dt id="VITAE.model.LatentSpace.get_pz"><code class="name flex">
+<span>def <span class="ident">get_pz</span></span>(<span>self, z, eps, pi)</span>
+</code></dt>
+<dd>
+<div class="desc"><p>Get <span><span class="MathJax_Preview">\log p(Z_i|Y_i,X_i)</span><script type="math/tex">\log p(Z_i|Y_i,X_i)</script></span>.</p>
+<h2 id="parameters">Parameters</h2>
+<dl>
+<dt><strong><code>z</code></strong> :&ensp;<code>tf.Tensor</code></dt>
+<dd><span><span class="MathJax_Preview">[B, L, d]</span><script type="math/tex">[B, L, d]</script></span> The latent variables.</dd>
+<dt><strong><code>eps</code></strong> :&ensp;<code>float</code></dt>
+<dd>The lower bound used when clipping probabilities.</dd>
+<dt><strong><code>pi</code></strong> :&ensp;<code>tf.Tensor</code> or <code>None</code></dt>
+<dd>Optional logits for <span><span class="MathJax_Preview">\pi</span><script type="math/tex">\pi</script></span>; if <code>None</code>, the trained <span><span class="MathJax_Preview">\pi</span><script type="math/tex">\pi</script></span> is used.</dd>
+</dl>
+<h2 id="returns">Returns</h2>
+<dl>
+<dt><strong><code>temp_pi</code></strong> :&ensp;<code>tf.Tensor</code></dt>
+<dd><span><span class="MathJax_Preview">[B, L, K]</span><script type="math/tex">[B, L, K]</script></span> <span><span class="MathJax_Preview">\pi</span><script type="math/tex">\pi</script></span>.</dd>
+<dt><strong><code>inv_sig</code></strong> :&ensp;<code>tf.Tensor</code></dt>
+<dd><span><span class="MathJax_Preview">[B, L, K]</span><script type="math/tex">[B, L, K]</script></span> <span><span class="MathJax_Preview">\sigma_{Z_ic_i}^{-1}</span><script type="math/tex">\sigma_{Z_ic_i}^{-1}</script></span>.</dd>
+<dt><strong><code>nu</code></strong> :&ensp;<code>tf.Tensor</code></dt>
+<dd><span><span class="MathJax_Preview">[B, L, K]</span><script type="math/tex">[B, L, K]</script></span> <span><span class="MathJax_Preview">\nu_{Z_ic_i}</span><script type="math/tex">\nu_{Z_ic_i}</script></span>.</dd>
+<dt><strong><code>beta_zc</code></strong> :&ensp;<code>tf.Tensor</code></dt>
+<dd><span><span class="MathJax_Preview">[B, L, d, K]</span><script type="math/tex">[B, L, d, K]</script></span> <span><span class="MathJax_Preview">\beta_{Z_ic_i}</span><script type="math/tex">\beta_{Z_ic_i}</script></span>.</dd>
+<dt><strong><code>log_eta0</code></strong> :&ensp;<code>tf.Tensor</code></dt>
+<dd><span><span class="MathJax_Preview">[B, L, K]</span><script type="math/tex">[B, L, K]</script></span> <span><span class="MathJax_Preview">\log\eta_{Z_ic_i,0}</span><script type="math/tex">\log\eta_{Z_ic_i,0}</script></span>.</dd>
+<dt><strong><code>eta1</code></strong> :&ensp;<code>tf.Tensor</code></dt>
+<dd><span><span class="MathJax_Preview">[B, L, K]</span><script type="math/tex">[B, L, K]</script></span> <span><span class="MathJax_Preview">\eta_{Z_ic_i,1}</span><script type="math/tex">\eta_{Z_ic_i,1}</script></span>.</dd>
+<dt><strong><code>eta2</code></strong> :&ensp;<code>tf.Tensor</code></dt>
+<dd><span><span class="MathJax_Preview">[B, L, K]</span><script type="math/tex">[B, L, K]</script></span> <span><span class="MathJax_Preview">\eta_{Z_ic_i,2}</span><script type="math/tex">\eta_{Z_ic_i,2}</script></span>.</dd>
+<dt><strong><code>log_p_zc_L</code></strong> :&ensp;<code>tf.Tensor</code></dt>
+<dd><span><span class="MathJax_Preview">[B, L, K]</span><script type="math/tex">[B, L, K]</script></span> <span><span class="MathJax_Preview">\log p(Z_i,c_i|Y_i,X_i)</span><script type="math/tex">\log p(Z_i,c_i|Y_i,X_i)</script></span>.</dd>
+<dt><strong><code>log_p_z_L</code></strong> :&ensp;<code>tf.Tensor</code></dt>
+<dd><span><span class="MathJax_Preview">[B, L]</span><script type="math/tex">[B, L]</script></span> <span><span class="MathJax_Preview">\log p(Z_i|Y_i,X_i)</span><script type="math/tex">\log p(Z_i|Y_i,X_i)</script></span>.</dd>
+<dt><strong><code>log_p_z</code></strong> :&ensp;<code>tf.Tensor</code></dt>
+<dd><span><span class="MathJax_Preview">[B, 1]</span><script type="math/tex">[B, 1]</script></span> The estimated <span><span class="MathJax_Preview">\log p(Z_i|Y_i,X_i)</span><script type="math/tex">\log p(Z_i|Y_i,X_i)</script></span>.</dd>
+</dl></div>
+</dd>
+<dt id="VITAE.model.LatentSpace.get_posterior_c"><code class="name flex">
+<span>def <span class="ident">get_posterior_c</span></span>(<span>self, z)</span>
+</code></dt>
+<dd>
+<div class="desc"><p>Get <span><span class="MathJax_Preview">p(c_i|Y_i,X_i)</span><script type="math/tex">p(c_i|Y_i,X_i)</script></span>.</p>
+<h2 id="parameters">Parameters</h2>
+<dl>
+<dt><strong><code>z</code></strong> :&ensp;<code>tf.Tensor</code></dt>
+<dd><span><span class="MathJax_Preview">[B, L, d]</span><script type="math/tex">[B, L, d]</script></span> The latent variables.</dd>
+</dl>
+<h2 id="returns">Returns</h2>
+<dl>
+<dt><strong><code>p_c_x</code></strong> :&ensp;<code>np.array</code></dt>
+<dd><span><span class="MathJax_Preview">[B, K]</span><script type="math/tex">[B, K]</script></span> <span><span class="MathJax_Preview">p(c_i|Y_i,X_i)</span><script type="math/tex">p(c_i|Y_i,X_i)</script></span>.</dd>
+</dl></div>
+</dd>
+<dt id="VITAE.model.LatentSpace.call"><code class="name flex">
+<span>def <span class="ident">call</span></span>(<span>self, z, pi=None, inference=False)</span>
+</code></dt>
+<dd>
+<div class="desc"><p>Get posterior estimations.</p>
+<h2 id="parameters">Parameters</h2>
+<dl>
+<dt><strong><code>z</code></strong> :&ensp;<code>tf.Tensor</code></dt>
+<dd><span><span class="MathJax_Preview">[B, L, d]</span><script type="math/tex">[B, L, d]</script></span> The latent variables.</dd>
+<dt><strong><code>pi</code></strong> :&ensp;<code>tf.Tensor</code>, optional</dt>
+<dd>Optional logits for <span><span class="MathJax_Preview">\pi</span><script type="math/tex">\pi</script></span> produced by the pi layer.</dd>
+<dt><strong><code>inference</code></strong> :&ensp;<code>boolean</code></dt>
+<dd>Whether in training or inference mode.</dd>
+</dl>
+<p>When <code>inference=False</code>:</p>
+<h2 id="returns">Returns</h2>
+<dl>
+<dt><strong><code>log_p_z_L</code></strong> :&ensp;<code>tf.Tensor</code></dt>
+<dd><span><span class="MathJax_Preview">[B, 1]</span><script type="math/tex">[B, 1]</script></span> The estimated <span><span class="MathJax_Preview">\log p(Z_i|Y_i,X_i)</span><script type="math/tex">\log p(Z_i|Y_i,X_i)</script></span>.</dd>
+</dl>
+<p>When <code>inference=True</code>:</p>
+<h2 id="returns_1">Returns</h2>
+<dl>
+<dt><strong><code>res</code></strong> :&ensp;<code>dict</code></dt>
+<dd>The dict of posterior estimations: <span><span class="MathJax_Preview">p(c_i|Y_i,X_i)</span><script type="math/tex">p(c_i|Y_i,X_i)</script></span>, <span><span class="MathJax_Preview">E(\tilde{w}_i|Y_i,X_i)</span><script type="math/tex">E(\tilde{w}_i|Y_i,X_i)</script></span>, and <span><span class="MathJax_Preview">Var(\tilde{w}_i|Y_i,X_i)</span><script type="math/tex">Var(\tilde{w}_i|Y_i,X_i)</script></span>.</dd>
+</dl></div>
+</dd>
+</dl>
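+<p>A hedged sketch of calling the layer in its two modes; the cluster count and shapes are arbitrary, and the inference branch assumes eager execution since it converts results to NumPy.</p>
+<pre><code class="python">import tensorflow as tf
+from VITAE.model import LatentSpace
+
+latent = LatentSpace(n_clusters=3, dim_latent=8)
+z = tf.random.normal([16, 5, 8])    # [B, L, d] latent samples
+log_p_z = latent(z)                 # training mode: scalar estimate of log p(z)
+res = latent(z, inference=True)     # dict with &#39;p_c_x&#39;, &#39;w_tilde&#39;, &#39;var_w_tilde&#39;
+</code></pre>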
+</dd>
+<dt id="VITAE.model.VariationalAutoEncoder"><code class="flex name class">
+<span>class <span class="ident">VariationalAutoEncoder</span></span>
+<span>(</span><span>dim_origin, dimensions, dim_latent, data_type='UMI', has_cov=False, name='autoencoder', **kwargs)</span>
+</code></dt>
+<dd>
+<div class="desc"><p>Combines the encoder, decoder and LatentSpace into an end-to-end model for training and inference.</p>
+<h2 id="parameters">Parameters</h2>
+<dl>
+<dt><strong><code>dim_origin</code></strong> :&ensp;<code>int</code></dt>
+<dd>The output dimension of the decoder.</dd>
+<dt><strong><code>dimensions</code></strong> :&ensp;<code>np.array</code></dt>
+<dd>The dimensions of hidden layers of the encoder.</dd>
+<dt><strong><code>dim_latent</code></strong> :&ensp;<code>int</code></dt>
+<dd>The latent dimension.</dd>
+<dt><strong><code>data_type</code></strong> :&ensp;<code>str</code>, optional</dt>
+<dd><code>'UMI'</code>, <code>'non-UMI'</code>, or <code>'Gaussian'</code>.</dd>
+<dt><strong><code>has_cov</code></strong> :&ensp;<code>boolean</code></dt>
+<dd>Whether covariates are included.</dd>
+<dt><strong><code>name</code></strong> :&ensp;<code>str</code>, optional</dt>
+<dd>The name of the layer.</dd>
+<dt><strong><code>**kwargs</code></strong></dt>
+<dd>Extra keyword arguments.</dd>
+</dl></div>
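+<p>A hedged construction sketch; the dimensions are illustrative, and the random position matrix stands in for the data-driven initialization used in practice. The latent space must be initialized before full (non-pre-train) training.</p>
+<pre><code class="python">import numpy as np
+from VITAE.model import VariationalAutoEncoder
+
+vae = VariationalAutoEncoder(dim_origin=2000, dimensions=np.array([128, 64]),
+                             dim_latent=8, data_type=&#39;UMI&#39;)
+mu = np.random.uniform(-1, 1, size=(8, 5)).astype(&#39;float32&#39;)   # [d, k] vertex positions
+vae.init_latent_space(n_clusters=5, mu=mu)
+</code></pre>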
+<details class="source">
+<summary>
+<span>Expand source code</span>
+</summary>
+<pre><code class="python">class VariationalAutoEncoder(tf.keras.Model):
+    &#34;&#34;&#34;
+    Combines the encoder, decoder and LatentSpace into an end-to-end model for training and inference.
+    &#34;&#34;&#34;
+    def __init__(self, dim_origin, dimensions, dim_latent,
+                 data_type = &#39;UMI&#39;, has_cov=False,
+                 name = &#39;autoencoder&#39;, **kwargs):
+        &#39;&#39;&#39;
+        Parameters
+        ----------
+        dim_origin : int
+            The output dimension of the decoder.        
+        dimensions : np.array
+            The dimensions of hidden layers of the encoder.
+        dim_latent : int
+            The latent dimension.
+        data_type : str, optional
+            `&#39;UMI&#39;`, `&#39;non-UMI&#39;`, or `&#39;Gaussian&#39;`.
+        has_cov : boolean
+            Whether covariates are included.
+        name : str, optional
+            The name of the layer.
+        **kwargs : 
+            Extra keyword arguments.
+        &#39;&#39;&#39;
+        super(VariationalAutoEncoder, self).__init__(name = name, **kwargs)
+        self.data_type = data_type
+        self.dim_origin = dim_origin
+        self.dim_latent = dim_latent
+        self.encoder = Encoder(dimensions, dim_latent)
+        self.decoder = Decoder(dimensions[::-1], dim_origin, data_type)
+        self.has_cov = has_cov
+        self.latent_space = None
+        
+    def init_latent_space(self, n_clusters, mu, log_pi=None):
+        &#39;&#39;&#39;Initialize the latent space.
+
+        Parameters
+        ----------
+        n_clusters : int
+            The number of vertices in the latent space.
+        mu : np.array
+            \([d, k]\) The position matrix.
+        log_pi : np.array, optional
+            \([1, K]\) \(\\log\\pi\).
+        &#39;&#39;&#39;
+        self.n_states = n_clusters
+        self.latent_space = LatentSpace(self.n_states, self.dim_latent)
+        self.latent_space.initialize(mu, log_pi)
+        self.pilayer = None
+
+    def create_pilayer(self):
+        self.pilayer = Dense(self.latent_space.n_categories, name = &#39;pi_layer&#39;)
+
+    def call(self, x_normalized, c_score, x = None, scale_factor = 1,
+             pre_train = False, L=1, alpha=0.0, gamma = 0.0, phi = 1.0, conditions = None, pi_cov = None):
+        &#39;&#39;&#39;Feed forward through encoder, LatentSpace layer and decoder.
+
+        Parameters
+        ----------
+        x_normalized : np.array
+            \([B, G]\) The preprocessed data.
+        c_score : np.array
+            \([B, s]\) The covariates \(X_i\), only used when `has_cov=True`.
+        x : np.array, optional
+            \([B, G]\) The original count data \(Y_i\), only used when data_type is not `&#39;Gaussian&#39;`.
+        scale_factor : np.array, optional
+            \([B, ]\) The scale factors, only used when data_type is not `&#39;Gaussian&#39;`.
+        pre_train : boolean, optional
+            Whether in the pre-training phase or not.
+        L : int, optional
+            The number of MC samples.
+        alpha : float, optional
+            The penalty parameter for covariates adjustment.
+        gamma : float, optional
+            The weight of the MMD loss.
+        phi : float, optional
+            The weight of the Jacobian norm penalty on the encoder.
+        conditions : str or list, optional
+            The conditions of different cells from the selected batch.
+        pi_cov : tf.Tensor, optional
+            The covariates fed to the pi layer when it has been created.
+
+        Returns
+        ----------
+        losses : list of tf.Tensor
+            The list of loss terms accumulated via `add_loss`.
+        &#39;&#39;&#39;
+
+        if not pre_train and self.latent_space is None:
+            raise ReferenceError(&#39;Have not initialized the latent space.&#39;)
+                    
+        if self.has_cov:
+            x_normalized = tf.concat([x_normalized, c_score], -1)
+        _, z_log_var, z = self.encoder(x_normalized, L)
+
+        if gamma == 0:
+            mmd_loss = 0.0
+        else:
+            mmd_loss = self._get_total_mmd_loss(conditions,z,gamma)
+
+        z_in = tf.concat([z, tf.tile(tf.expand_dims(c_score,1), (1,L,1))], -1) if self.has_cov else z
+        
+        x = tf.tile(tf.expand_dims(x, 1), (1,L,1))
+        reconstruction_z_loss = self._get_reconstruction_loss(x, z_in, scale_factor, L)
+        
+        if self.has_cov and alpha&gt;0.0:
+            zero_in = tf.concat([tf.zeros([z.shape[0],1,z.shape[2]], dtype=tf.keras.backend.floatx()), 
+                                tf.tile(tf.expand_dims(c_score,1), (1,1,1))], -1)
+            reconstruction_zero_loss = self._get_reconstruction_loss(x, zero_in, scale_factor, 1)
+            reconstruction_z_loss = (1-alpha)*reconstruction_z_loss + alpha*reconstruction_zero_loss
+
+        self.add_loss(reconstruction_z_loss)
+        J_norm = self._get_Jacob(x_normalized, L)
+        self.add_loss((phi * J_norm))
+        # the gamma weight has already been applied inside _mmd_loss
+        self.add_loss(mmd_loss)
+
+        if not pre_train:
+            pi = self.pilayer(pi_cov) if self.pilayer is not None else None
+            log_p_z = self.latent_space(z, pi, inference=False)
+
+            # - E_q[log p(z)]
+            self.add_loss(- log_p_z)
+
+            # E_q[log q(z|x)] (negative entropy of the variational posterior)
+            E_qzx = - tf.reduce_mean(
+                            0.5 * self.dim_latent *
+                            (tf.math.log(tf.constant(2 * np.pi, tf.keras.backend.floatx())) + 1.0) +
+                            0.5 * tf.reduce_sum(z_log_var, axis=-1)
+                            )
+            self.add_loss(E_qzx)
+        return self.losses
+    
+    @tf.function
+    def _get_reconstruction_loss(self, x, z_in, scale_factor, L):
+        if self.data_type==&#39;Gaussian&#39;:
+            # Gaussian Log-Likelihood Loss function
+            nu_z, tau = self.decoder(z_in)
+            neg_E_Gaus = 0.5 * tf.math.log(tf.clip_by_value(tau, 1e-12, 1e30)) + 0.5 * tf.math.square(x - nu_z) / tau
+            neg_E_Gaus = tf.reduce_mean(tf.reduce_sum(neg_E_Gaus, axis=-1))
+
+            return neg_E_Gaus
+        else:
+            if self.data_type == &#39;UMI&#39;:
+                x_hat, r = self.decoder(z_in)
+            else:
+                x_hat, r, phi = self.decoder(z_in)
+
+            x_hat = x_hat*tf.expand_dims(scale_factor, -1)
+
+            # Negative Log-Likelihood Loss function
+
+            # Ref for NB &amp; ZINB loss functions:
+            # https://github.com/gokceneraslan/neuralnet_countmodels/blob/master/Count%20models%20with%20neuralnets.ipynb
+            # Negative Binomial loss
+
+            neg_E_nb = tf.math.lgamma(r) + tf.math.lgamma(x+1.0) \
+                        - tf.math.lgamma(x+r) + \
+                        (r+x) * tf.math.log(1.0 + (x_hat/r)) + \
+                        x * (tf.math.log(r) - tf.math.log(tf.clip_by_value(x_hat, 1e-12, 1e30)))
+            
+            if self.data_type == &#39;non-UMI&#39;:
+                # Zero-Inflated Negative Binomial loss
+                nb_case = neg_E_nb - tf.math.log(tf.clip_by_value(1.0-phi, 1e-12, 1e30))
+                zero_case = - tf.math.log(tf.clip_by_value(
+                    phi + (1.0-phi) * tf.pow(r * tf.math.reciprocal_no_nan(r + x_hat), r),
+                    1e-12, 1e30))
+                neg_E_nb = tf.where(tf.less(x, 1e-8), zero_case, nb_case)
+
+            neg_E_nb = tf.reduce_mean(tf.reduce_sum(neg_E_nb, axis=-1))
+            return neg_E_nb
+
+    def _get_total_mmd_loss(self,conditions,z,gamma):
+        mmd_loss = 0.0
+        conditions = tf.cast(conditions,tf.int32)
+        n_group = conditions.shape[1]
+
+        for i in range(n_group):
+            sub_conditions = conditions[:, i]
+            # condition 0 means the cell does not participate in the MMD loss
+            z_cond = z[sub_conditions != 0]
+            sub_conditions = sub_conditions[sub_conditions != 0]
+            n_sub_group = tf.unique(sub_conditions)[0].shape[0]
+            real_labels = K.reshape(sub_conditions, (-1,)).numpy()
+            unique_set = list(set(real_labels))
+            reindex_dict = dict(zip(unique_set, range(n_sub_group)))
+            real_labels = [reindex_dict[x] for x in real_labels]
+            real_labels = tf.convert_to_tensor(real_labels,dtype=tf.int32)
+
+            if n_sub_group <= 1:
+                _loss = 0.0
+            else:
+                _loss = self._mmd_loss(real_labels=real_labels, y_pred=z_cond, gamma=gamma,
+                                       n_conditions=n_sub_group,
+                                       kernel_method=&#39;multi-scale-rbf&#39;,
+                                       computation_method=&#34;general&#34;)
+            mmd_loss = mmd_loss + _loss
+        return mmd_loss
+
+    # The input shape changes on every call, so @tf.function cannot be used:
+    # a tf graph requires static shapes and tensor dtypes.
+    def _mmd_loss(self, real_labels, y_pred, gamma, n_conditions, kernel_method=&#39;multi-scale-rbf&#39;,
+                  computation_method=&#34;general&#34;):
+        conditions_mmd = tf.dynamic_partition(y_pred, real_labels, num_partitions=n_conditions)
+        loss = 0.0
+        if computation_method.isdigit():
+            boundary = int(computation_method)
+            ## compute a distance for every pair of groups across the boundary
+            for i in range(boundary):
+                for j in range(boundary, n_conditions):
+                    loss += _nan2zero(compute_mmd(conditions_mmd[i], conditions_mmd[j], kernel_method))
+        else:
+            for i in range(len(conditions_mmd)):
+                for j in range(i):
+                    loss += _nan2zero(compute_mmd(conditions_mmd[i], conditions_mmd[j], kernel_method))
+
+        # print(&#34;The loss is &#34;, loss)
+        return gamma * loss
+
+    @tf.function
+    def _get_Jacob(self, x, L):
+        with tf.GradientTape() as g:
+            g.watch(x)
+            z_mean, z_log_var, z = self.encoder(x, L)
+            # y_mean, y_log_var = self.decoder(z)
+        ## a full jacobian would have shape (batch, d, batch, G); batch_jacobian keeps only the per-sample blocks
+        J = g.batch_jacobian(z, x)
+        J_norm = tf.norm(J)
+        # tf.print(J_norm)
+
+        return J_norm
+    
+    def get_z(self, x_normalized, c_score):    
+        &#39;&#39;&#39;Get \(q(Z_i|Y_i,X_i)\).
+
+        Parameters
+        ----------
+        x_normalized : np.array
+            \([B, G]\) The preprocessed data.
+        c_score : np.array
+            \([B, s]\) The covariates \(X_i\), only used when `has_cov=True`.
+
+        Returns
+        ----------
+        z_mean : np.array
+            \([B, d]\) The latent mean.
+        &#39;&#39;&#39;    
+        x_normalized = x_normalized if (not self.has_cov or c_score is None) else tf.concat([x_normalized, c_score], -1)
+        z_mean, _, _ = self.encoder(x_normalized, 1, False)
+        return z_mean.numpy()
+
+    def get_pc_x(self, test_dataset):
+        &#39;&#39;&#39;Get \(p(c_i|Y_i,X_i)\).
+
+        Parameters
+        ----------
+        test_dataset : tf.Dataset
+            The dataset object.
+
+        Returns
+        ----------
+        pi_norm : np.array
+            \([1, K]\) The estimated \(\\pi\).
+        p_c_x : np.array
+            \([N, ]\) The estimated \(p(c_i|Y_i,X_i)\).
+        &#39;&#39;&#39;    
+        if self.latent_space is None:
+            raise ReferenceError(&#39;The latent space has not been initialized.&#39;)
+        
+        pi_norm = tf.nn.softmax(self.latent_space.pi).numpy()
+        p_c_x = []
+        for x,c_score in test_dataset:
+            x = tf.concat([x, c_score], -1) if self.has_cov else x
+            _, _, z = self.encoder(x, 1, False)
+            _p_c_x = self.latent_space.get_posterior_c(z)            
+            p_c_x.append(_p_c_x)
+        p_c_x = np.concatenate(p_c_x)         
+        return pi_norm, p_c_x
+
+    def inference(self, test_dataset, L=1):
+        &#39;&#39;&#39;Compute posterior estimates, including \(p(c_i|Y_i,X_i)\), over the whole dataset.
+
+        Parameters
+        ----------
+        test_dataset : tf.Dataset
+            The dataset object.
+        L : int
+            The number of MC samples.
+
+        Returns
+        ----------
+        pi_norm : np.array
+            \([1, K]\) The estimated \(\\pi\).
+        mu : np.array
+            \([d, k]\) The estimated \(\\mu\).
+        p_c_x : np.array
+            \([N, ]\) The estimated \(p(c_i|Y_i,X_i)\).
+        w_tilde : np.array
+            \([N, k]\) The estimated \(E(\\tilde{w}_i|Y_i,X_i)\).
+        var_w_tilde : np.array
+            \([N, k]\) The estimated \(Var(\\tilde{w}_i|Y_i,X_i)\).
+        z_mean : np.array
+            \([N, d]\) The estimated latent mean.
+        &#39;&#39;&#39;   
+        if self.latent_space is None:
+            raise ReferenceError(&#39;The latent space has not been initialized.&#39;)
+        
+        print(&#39;Computing posterior estimates over mini-batches.&#39;)
+        progbar = Progbar(test_dataset.cardinality().numpy())
+        pi_norm = tf.nn.softmax(self.latent_space.pi).numpy()
+        mu = self.latent_space.mu.numpy()
+        z_mean = []
+        p_c_x = []
+        w_tilde = []
+        var_w_tilde = []
+        for step, (x,c_score, _, _) in enumerate(test_dataset):
+            x = tf.concat([x, c_score], -1) if self.has_cov else x
+            _z_mean, _, z = self.encoder(x, L, False)
+            res = self.latent_space(z, inference=True)
+            
+            z_mean.append(_z_mean.numpy())
+            p_c_x.append(res[&#39;p_c_x&#39;])            
+            w_tilde.append(res[&#39;w_tilde&#39;])
+            var_w_tilde.append(res[&#39;var_w_tilde&#39;])
+            progbar.update(step+1)
+
+        z_mean = np.concatenate(z_mean)
+        p_c_x = np.concatenate(p_c_x)
+        w_tilde = np.concatenate(w_tilde)
+        w_tilde /= np.sum(w_tilde, axis=1, keepdims=True)
+        var_w_tilde = np.concatenate(var_w_tilde)
+        return pi_norm, mu, p_c_x, w_tilde, var_w_tilde, z_mean</code></pre>
+</details>
+<h3>Ancestors</h3>
+<ul class="hlist">
+<li>keras.src.engine.training.Model</li>
+<li>keras.src.engine.base_layer.Layer</li>
+<li>tensorflow.python.module.module.Module</li>
+<li>tensorflow.python.trackable.autotrackable.AutoTrackable</li>
+<li>tensorflow.python.trackable.base.Trackable</li>
+<li>keras.src.utils.version_utils.LayerVersionSelector</li>
+<li>keras.src.utils.version_utils.ModelVersionSelector</li>
+</ul>
+<h3>Methods</h3>
+<dl>
+<dt id="VITAE.model.VariationalAutoEncoder.init_latent_space"><code class="name flex">
+<span>def <span class="ident">init_latent_space</span></span>(<span>self, n_clusters, mu, log_pi=None)</span>
+</code></dt>
+<dd>
+<div class="desc"><p>Initialize the latent space.</p>
+<h2 id="parameters">Parameters</h2>
+<dl>
+<dt><strong><code>n_clusters</code></strong> :&ensp;<code>int</code></dt>
+<dd>The number of vertices in the latent space.</dd>
+<dt><strong><code>mu</code></strong> :&ensp;<code>np.array</code></dt>
+<dd><span><span class="MathJax_Preview">[d, k]</span><script type="math/tex">[d, k]</script></span> The position matrix.</dd>
+<dt><strong><code>log_pi</code></strong> :&ensp;<code>np.array</code>, optional</dt>
+<dd><span><span class="MathJax_Preview">[1, K]</span><script type="math/tex">[1, K]</script></span> <span><span class="MathJax_Preview">\log\pi</span><script type="math/tex">\log\pi</script></span>.</dd>
+</dl></div>
+</dd>
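+<p>A minimal sketch of initializing the latent space from k-means centers. It assumes a trained <code>model</code> and an <code>[N, d]</code> matrix <code>z</code> of latent means (e.g. from <code>get_z</code>), and uses scikit-learn purely for illustration:</p>
+<pre><code class="language-python">from sklearn.cluster import KMeans
+
+n_clusters = 8                                  # user-chosen number of vertices
+kmeans = KMeans(n_clusters=n_clusters, n_init=10).fit(z)
+mu = kmeans.cluster_centers_.T                  # [d, k] position matrix
+model.init_latent_space(n_clusters, mu)         # log_pi is optional
+</code></pre>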
+<dt id="VITAE.model.VariationalAutoEncoder.create_pilayer"><code class="name flex">
+<span>def <span class="ident">create_pilayer</span></span>(<span>self)</span>
+</code></dt>
+<dd>
+<div class="desc"><p>Create the layer that maps the covariates <code>pi_cov</code> to the prior logits <span><span class="MathJax_Preview">\pi</span><script type="math/tex">\pi</script></span> used by the latent space in <code>call</code>.</p></div>
+</dd>
+<dt id="VITAE.model.VariationalAutoEncoder.call"><code class="name flex">
+<span>def <span class="ident">call</span></span>(<span>self, x_normalized, c_score, x=None, scale_factor=1, pre_train=False, L=1, alpha=0.0, gamma=0.0, phi=1.0, conditions=None, pi_cov=None)</span>
+</code></dt>
+<dd>
+<div class="desc"><p>Feed forward through encoder, LatentSpace layer and decoder.</p>
+<h2 id="parameters">Parameters</h2>
+<dl>
+<dt><strong><code>x_normalized</code></strong> :&ensp;<code>np.array</code></dt>
+<dd><span><span class="MathJax_Preview">[B, G]</span><script type="math/tex">[B, G]</script></span> The preprocessed data.</dd>
+<dt><strong><code>c_score</code></strong> :&ensp;<code>np.array</code></dt>
+<dd><span><span class="MathJax_Preview">[B, s]</span><script type="math/tex">[B, s]</script></span> The covariates <span><span class="MathJax_Preview">X_i</span><script type="math/tex">X_i</script></span>, only used when <code>has_cov=True</code>.</dd>
+<dt><strong><code>x</code></strong> :&ensp;<code>np.array</code>, optional</dt>
+<dd><span><span class="MathJax_Preview">[B, G]</span><script type="math/tex">[B, G]</script></span> The original count data <span><span class="MathJax_Preview">Y_i</span><script type="math/tex">Y_i</script></span>, only used when data_type is not <code>'Gaussian'</code>.</dd>
+<dt><strong><code>scale_factor</code></strong> :&ensp;<code>np.array</code>, optional</dt>
+<dd><span><span class="MathJax_Preview">[B, ]</span><script type="math/tex">[B, ]</script></span> The scale factors, only used when data_type is not <code>'Gaussian'</code>.</dd>
+<dt><strong><code>pre_train</code></strong> :&ensp;<code>boolean</code>, optional</dt>
+<dd>Whether in the pre-training phase or not.</dd>
+<dt><strong><code>L</code></strong> :&ensp;<code>int</code>, optional</dt>
+<dd>The number of MC samples.</dd>
+<dt><strong><code>alpha</code></strong> :&ensp;<code>float</code>, optional</dt>
+<dd>The penalty parameter for covariates adjustment.</dd>
+<dt><strong><code>gamma</code></strong> :&ensp;<code>float</code>, optional</dt>
+<dd>The weight of the MMD loss.</dd>
+<dt><strong><code>phi</code></strong> :&ensp;<code>float</code>, optional</dt>
+<dd>The weight of the Jacobian norm of the encoder.</dd>
+<dt><strong><code>conditions</code></strong> :&ensp;<code>str</code> or <code>list</code>, optional</dt>
+<dd>The conditions of different cells from the selected batch.</dd>
+<dt><strong><code>pi_cov</code></strong> :&ensp;<code>tf.Tensor</code>, optional</dt>
+<dd>The covariates fed to the <code>pilayer</code> to produce the prior logits, only used when the <code>pilayer</code> has been created.</dd>
+</dl>
+<h2 id="returns">Returns</h2>
+<dl>
+<dt><strong><code>losses</code></strong> :&ensp;<code>list</code></dt>
+<dd>The list of loss tensors accumulated via <code>add_loss</code> during the forward pass.</dd>
+</dl></div>
+</dd>
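+<p>A hedged sketch of one pre-training gradient step (not the training loop shipped with the package); <code>model</code>, <code>optimizer</code> and the batch tensors are assumed to exist. The returned losses sum to the negative reconstruction log-likelihood plus the weighted Jacobian-norm and MMD terms and, when <code>pre_train=False</code>, the prior and entropy terms of the ELBO:</p>
+<pre><code class="language-python">import tensorflow as tf
+
+with tf.GradientTape() as tape:
+    losses = model(x_normalized, c_score, x=x, scale_factor=scale_factor,
+                   pre_train=True, L=1)
+    loss = tf.add_n(losses)   # sum the terms registered via add_loss
+grads = tape.gradient(loss, model.trainable_weights)
+optimizer.apply_gradients(zip(grads, model.trainable_weights))
+</code></pre>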
+<dt id="VITAE.model.VariationalAutoEncoder.get_z"><code class="name flex">
+<span>def <span class="ident">get_z</span></span>(<span>self, x_normalized, c_score)</span>
+</code></dt>
+<dd>
+<div class="desc"><p>Get <span><span class="MathJax_Preview">q(Z_i|Y_i,X_i)</span><script type="math/tex">q(Z_i|Y_i,X_i)</script></span>.</p>
+<h2 id="parameters">Parameters</h2>
+<dl>
+<dt><strong><code>x_normalized</code></strong> :&ensp;<code>np.array</code></dt>
+<dd><span><span class="MathJax_Preview">[B, G]</span><script type="math/tex">[B, G]</script></span> The preprocessed data.</dd>
+<dt><strong><code>c_score</code></strong> :&ensp;<code>np.array</code></dt>
+<dd><span><span class="MathJax_Preview">[B, s]</span><script type="math/tex">[B, s]</script></span> The covariates <span><span class="MathJax_Preview">X_i</span><script type="math/tex">X_i</script></span>, only used when <code>has_cov=True</code>.</dd>
+</dl>
+<h2 id="returns">Returns</h2>
+<dl>
+<dt><strong><code>z_mean</code></strong> :&ensp;<code>np.array</code></dt>
+<dd><span><span class="MathJax_Preview">[B, d]</span><script type="math/tex">[B, d]</script></span> The latent mean.</dd>
+</dl></div>
+</dd>
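+<p>A minimal usage sketch, assuming a trained <code>model</code> and a preprocessed matrix <code>x_normalized</code>; pass <code>c_score=None</code> when the model has no covariates:</p>
+<pre><code class="language-python">z_mean = model.get_z(x_normalized, None)   # [B, d] latent means
+print(z_mean.shape)
+</code></pre>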
+<dt id="VITAE.model.VariationalAutoEncoder.get_pc_x"><code class="name flex">
+<span>def <span class="ident">get_pc_x</span></span>(<span>self, test_dataset)</span>
+</code></dt>
+<dd>
+<div class="desc"><p>Get <span><span class="MathJax_Preview">p(c_i|Y_i,X_i)</span><script type="math/tex">p(c_i|Y_i,X_i)</script></span>.</p>
+<h2 id="parameters">Parameters</h2>
+<dl>
+<dt><strong><code>test_dataset</code></strong> :&ensp;<code>tf.Dataset</code></dt>
+<dd>The dataset object.</dd>
+</dl>
+<h2 id="returns">Returns</h2>
+<dl>
+<dt><strong><code>pi_norm</code></strong> :&ensp;<code>np.array</code></dt>
+<dd><span><span class="MathJax_Preview">[1, K]</span><script type="math/tex">[1, K]</script></span> The estimated <span><span class="MathJax_Preview">\pi</span><script type="math/tex">\pi</script></span>.</dd>
+<dt><strong><code>p_c_x</code></strong> :&ensp;<code>np.array</code></dt>
+<dd><span><span class="MathJax_Preview">[N, ]</span><script type="math/tex">[N, ]</script></span> The estimated <span><span class="MathJax_Preview">p(c_i|Y_i,X_i)</span><script type="math/tex">p(c_i|Y_i,X_i)</script></span>.</dd>
+</dl></div>
+</dd>
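+<p>A minimal usage sketch, assuming an initialized latent space and a batched <code>tf.data.Dataset</code> yielding <code>(x, c_score)</code> tuples, as in the source above:</p>
+<pre><code class="language-python">pi_norm, p_c_x = model.get_pc_x(test_dataset)
+print(pi_norm.shape, p_c_x.shape)   # the estimated prior and per-cell posteriors
+</code></pre>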
+<dt id="VITAE.model.VariationalAutoEncoder.inference"><code class="name flex">
+<span>def <span class="ident">inference</span></span>(<span>self, test_dataset, L=1)</span>
+</code></dt>
+<dd>
+<div class="desc"><p>Compute posterior estimates, including <span><span class="MathJax_Preview">p(c_i|Y_i,X_i)</span><script type="math/tex">p(c_i|Y_i,X_i)</script></span>, over the whole dataset.</p>
+<h2 id="parameters">Parameters</h2>
+<dl>
+<dt><strong><code>test_dataset</code></strong> :&ensp;<code>tf.Dataset</code></dt>
+<dd>The dataset object.</dd>
+<dt><strong><code>L</code></strong> :&ensp;<code>int</code></dt>
+<dd>The number of MC samples.</dd>
+</dl>
+<h2 id="returns">Returns</h2>
+<dl>
+<dt><strong><code>pi_norm</code></strong> :&ensp;<code>np.array</code></dt>
+<dd><span><span class="MathJax_Preview">[1, K]</span><script type="math/tex">[1, K]</script></span> The estimated <span><span class="MathJax_Preview">\pi</span><script type="math/tex">\pi</script></span>.</dd>
+<dt><strong><code>mu</code></strong> :&ensp;<code>np.array</code></dt>
+<dd><span><span class="MathJax_Preview">[d, k]</span><script type="math/tex">[d, k]</script></span> The estimated <span><span class="MathJax_Preview">\mu</span><script type="math/tex">\mu</script></span>.</dd>
+<dt><strong><code>p_c_x</code></strong> :&ensp;<code>np.array</code></dt>
+<dd><span><span class="MathJax_Preview">[N, ]</span><script type="math/tex">[N, ]</script></span> The estimated <span><span class="MathJax_Preview">p(c_i|Y_i,X_i)</span><script type="math/tex">p(c_i|Y_i,X_i)</script></span>.</dd>
+<dt><strong><code>w_tilde</code></strong> :&ensp;<code>np.array</code></dt>
+<dd><span><span class="MathJax_Preview">[N, k]</span><script type="math/tex">[N, k]</script></span> The estimated <span><span class="MathJax_Preview">E(\tilde{w}_i|Y_i,X_i)</span><script type="math/tex">E(\tilde{w}_i|Y_i,X_i)</script></span>.</dd>
+<dt><strong><code>var_w_tilde</code></strong> :&ensp;<code>np.array</code></dt>
+<dd><span><span class="MathJax_Preview">[N, k]</span><script type="math/tex">[N, k]</script></span> The estimated <span><span class="MathJax_Preview">Var(\tilde{w}_i|Y_i,X_i)</span><script type="math/tex">Var(\tilde{w}_i|Y_i,X_i)</script></span>.</dd>
+<dt><strong><code>z_mean</code></strong> :&ensp;<code>np.array</code></dt>
+<dd><span><span class="MathJax_Preview">[N, d]</span><script type="math/tex">[N, d]</script></span> The estimated latent mean.</dd>
+</dl></div>
+</dd>
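+<p>A minimal usage sketch, assuming an initialized latent space and a batched <code>tf.data.Dataset</code> yielding <code>(x, c_score, _, _)</code> 4-tuples, as in the source above:</p>
+<pre><code class="language-python">pi_norm, mu, p_c_x, w_tilde, var_w_tilde, z_mean = model.inference(
+    test_dataset, L=10)
+# rows of w_tilde are renormalized to sum to one inside inference
+</code></pre>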
+</dl>
+</dd>
+</dl>
+</section>
+</article>
+<nav id="sidebar">
+<div class="toc">
+<ul></ul>
+</div>
+<ul id="index">
+<li><h3>Super-module</h3>
+<ul>
+<li><code><a title="VITAE" href="index.html">VITAE</a></code></li>
+</ul>
+</li>
+<li><h3><a href="#header-classes">Classes</a></h3>
+<ul>
+<li>
+<h4><code><a title="VITAE.model.cdf_layer" href="#VITAE.model.cdf_layer">cdf_layer</a></code></h4>
+<ul class="">
+<li><code><a title="VITAE.model.cdf_layer.call" href="#VITAE.model.cdf_layer.call">call</a></code></li>
+<li><code><a title="VITAE.model.cdf_layer.func" href="#VITAE.model.cdf_layer.func">func</a></code></li>
+</ul>
+</li>
+<li>
+<h4><code><a title="VITAE.model.Sampling" href="#VITAE.model.Sampling">Sampling</a></code></h4>
+<ul class="">
+<li><code><a title="VITAE.model.Sampling.call" href="#VITAE.model.Sampling.call">call</a></code></li>
+</ul>
+</li>
+<li>
+<h4><code><a title="VITAE.model.Encoder" href="#VITAE.model.Encoder">Encoder</a></code></h4>
+<ul class="">
+<li><code><a title="VITAE.model.Encoder.call" href="#VITAE.model.Encoder.call">call</a></code></li>
+</ul>
+</li>
+<li>
+<h4><code><a title="VITAE.model.Decoder" href="#VITAE.model.Decoder">Decoder</a></code></h4>
+<ul class="">
+<li><code><a title="VITAE.model.Decoder.call" href="#VITAE.model.Decoder.call">call</a></code></li>
+</ul>
+</li>
+<li>
+<h4><code><a title="VITAE.model.LatentSpace" href="#VITAE.model.LatentSpace">LatentSpace</a></code></h4>
+<ul class="">
+<li><code><a title="VITAE.model.LatentSpace.initialize" href="#VITAE.model.LatentSpace.initialize">initialize</a></code></li>
+<li><code><a title="VITAE.model.LatentSpace.normalize" href="#VITAE.model.LatentSpace.normalize">normalize</a></code></li>
+<li><code><a title="VITAE.model.LatentSpace.get_pz" href="#VITAE.model.LatentSpace.get_pz">get_pz</a></code></li>
+<li><code><a title="VITAE.model.LatentSpace.get_posterior_c" href="#VITAE.model.LatentSpace.get_posterior_c">get_posterior_c</a></code></li>
+<li><code><a title="VITAE.model.LatentSpace.call" href="#VITAE.model.LatentSpace.call">call</a></code></li>
+</ul>
+</li>
+<li>
+<h4><code><a title="VITAE.model.VariationalAutoEncoder" href="#VITAE.model.VariationalAutoEncoder">VariationalAutoEncoder</a></code></h4>
+<ul class="two-column">
+<li><code><a title="VITAE.model.VariationalAutoEncoder.init_latent_space" href="#VITAE.model.VariationalAutoEncoder.init_latent_space">init_latent_space</a></code></li>
+<li><code><a title="VITAE.model.VariationalAutoEncoder.create_pilayer" href="#VITAE.model.VariationalAutoEncoder.create_pilayer">create_pilayer</a></code></li>
+<li><code><a title="VITAE.model.VariationalAutoEncoder.call" href="#VITAE.model.VariationalAutoEncoder.call">call</a></code></li>
+<li><code><a title="VITAE.model.VariationalAutoEncoder.get_z" href="#VITAE.model.VariationalAutoEncoder.get_z">get_z</a></code></li>
+<li><code><a title="VITAE.model.VariationalAutoEncoder.get_pc_x" href="#VITAE.model.VariationalAutoEncoder.get_pc_x">get_pc_x</a></code></li>
+<li><code><a title="VITAE.model.VariationalAutoEncoder.inference" href="#VITAE.model.VariationalAutoEncoder.inference">inference</a></code></li>
+</ul>
+</li>
+</ul>
+</li>
+</ul>
+</nav>
+</main>
+<footer id="footer">
+<p>Generated by <a href="https://pdoc3.github.io/pdoc" title="pdoc: Python API documentation generator"><cite>pdoc</cite> 0.11.1</a>.</p>
+</footer>
+</body>
+</html>