<!doctype html>
<html lang="en">
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1, minimum-scale=1">
<meta name="generator" content="pdoc3 0.11.1">
<title>VITAE.model API documentation</title>
<meta name="description" content="">
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/10up-sanitize.css/13.0.0/sanitize.min.css" integrity="sha512-y1dtMcuvtTMJc1yPgEqF0ZjQbhnc/bFhyvIyVNb9Zk5mIGtqVaAB1Ttl28su8AvFMOY0EwRbAe+HCLqj6W7/KA==" crossorigin>
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/10up-sanitize.css/13.0.0/typography.min.css" integrity="sha512-Y1DYSb995BAfxobCkKepB1BqJJTPrOp3zPL74AWFugHHmmdcvO+C48WLrUOlhGMc0QG7AE3f7gmvvcrmX2fDoA==" crossorigin>
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.9.0/styles/default.min.css" crossorigin>
<style>:root{--highlight-color:#fe9}.flex{display:flex !important}body{line-height:1.5em}#content{padding:20px}#sidebar{padding:1.5em;overflow:hidden}#sidebar > *:last-child{margin-bottom:2cm}.http-server-breadcrumbs{font-size:130%;margin:0 0 15px 0}#footer{font-size:.75em;padding:5px 30px;border-top:1px solid #ddd;text-align:right}#footer p{margin:0 0 0 1em;display:inline-block}#footer p:last-child{margin-right:30px}h1,h2,h3,h4,h5{font-weight:300}h1{font-size:2.5em;line-height:1.1em}h2{font-size:1.75em;margin:2em 0 .50em 0}h3{font-size:1.4em;margin:1.6em 0 .7em 0}h4{margin:0;font-size:105%}h1:target,h2:target,h3:target,h4:target,h5:target,h6:target{background:var(--highlight-color);padding:.2em 0}a{color:#058;text-decoration:none;transition:color .2s ease-in-out}a:visited{color:#503}a:hover{color:#b62}.title code{font-weight:bold}h2[id^="header-"]{margin-top:2em}.ident{color:#900;font-weight:bold}pre code{font-size:.8em;line-height:1.4em;padding:1em;display:block}code{background:#f3f3f3;font-family:"DejaVu Sans Mono",monospace;padding:1px 4px;overflow-wrap:break-word}h1 code{background:transparent}pre{border-top:1px solid #ccc;border-bottom:1px solid #ccc;margin:1em 0}#http-server-module-list{display:flex;flex-flow:column}#http-server-module-list div{display:flex}#http-server-module-list dt{min-width:10%}#http-server-module-list p{margin-top:0}.toc ul,#index{list-style-type:none;margin:0;padding:0}#index code{background:transparent}#index h3{border-bottom:1px solid #ddd}#index ul{padding:0}#index h4{margin-top:.6em;font-weight:bold}@media (min-width:200ex){#index .two-column{column-count:2}}@media (min-width:300ex){#index .two-column{column-count:3}}dl{margin-bottom:2em}dl dl:last-child{margin-bottom:4em}dd{margin:0 0 1em 3em}#header-classes + dl > dd{margin-bottom:3em}dd dd{margin-left:2em}dd p{margin:10px 0}.name{background:#eee;font-size:.85em;padding:5px 10px;display:inline-block;min-width:40%}.name:hover{background:#e0e0e0}dt:target 
.name{background:var(--highlight-color)}.name > span:first-child{white-space:nowrap}.name.class > span:nth-child(2){margin-left:.4em}.inherited{color:#999;border-left:5px solid #eee;padding-left:1em}.inheritance em{font-style:normal;font-weight:bold}.desc h2{font-weight:400;font-size:1.25em}.desc h3{font-size:1em}.desc dt code{background:inherit}.source summary,.git-link-div{color:#666;text-align:right;font-weight:400;font-size:.8em;text-transform:uppercase}.source summary > *{white-space:nowrap;cursor:pointer}.git-link{color:inherit;margin-left:1em}.source pre{max-height:500px;overflow:auto;margin:0}.source pre code{font-size:12px;overflow:visible}.hlist{list-style:none}.hlist li{display:inline}.hlist li:after{content:',\2002'}.hlist li:last-child:after{content:none}.hlist .hlist{display:inline;padding-left:1em}img{max-width:100%}td{padding:0 .5em}.admonition{padding:.1em 1em;margin-bottom:1em}.admonition-title{font-weight:bold}.admonition.note,.admonition.info,.admonition.important{background:#aef}.admonition.todo,.admonition.versionadded,.admonition.tip,.admonition.hint{background:#dfd}.admonition.warning,.admonition.versionchanged,.admonition.deprecated{background:#fd4}.admonition.error,.admonition.danger,.admonition.caution{background:lightpink}</style>
<style media="screen and (min-width: 700px)">@media screen and (min-width:700px){#sidebar{width:30%;height:100vh;overflow:auto;position:sticky;top:0}#content{width:70%;max-width:100ch;padding:3em 4em;border-left:1px solid #ddd}pre code{font-size:1em}.name{font-size:1em}main{display:flex;flex-direction:row-reverse;justify-content:flex-end}.toc ul ul,#index ul ul{padding-left:1em}.toc > ul > li{margin-top:.5em}}</style>
<style media="print">@media print{#sidebar h1{page-break-before:always}.source{display:none}}@media print{*{background:transparent !important;color:#000 !important;box-shadow:none !important;text-shadow:none !important}a[href]:after{content:" (" attr(href) ")";font-size:90%}a[href][title]:after{content:none}abbr[title]:after{content:" (" attr(title) ")"}.ir a:after,a[href^="javascript:"]:after,a[href^="#"]:after{content:""}pre,blockquote{border:1px solid #999;page-break-inside:avoid}thead{display:table-header-group}tr,img{page-break-inside:avoid}img{max-width:100% !important}@page{margin:0.5cm}p,h2,h3{orphans:3;widows:3}h1,h2,h3,h4,h5,h6{page-break-after:avoid}}</style>
<script type="text/x-mathjax-config">MathJax.Hub.Config({ tex2jax: { inlineMath: [ ['$','$'], ["\\(","\\)"] ], processEscapes: true } });</script>
<script async src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.7/latest.js?config=TeX-AMS_CHTML" integrity="sha256-kZafAc6mZvK3W3v1pHOcUix30OHQN6pU/NO2oFkqZVw=" crossorigin></script>
<script defer src="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.9.0/highlight.min.js" integrity="sha512-D9gUyxqja7hBtkWpPWGt9wfbfaMGVt9gnyCvYa+jojwwPHLCzUm5i8rpk7vD7wNee9bA35eYIjobYPaQuKS1MQ==" crossorigin></script>
<script>window.addEventListener('DOMContentLoaded', () => {
hljs.configure({languages: ['bash', 'css', 'diff', 'graphql', 'ini', 'javascript', 'json', 'plaintext', 'python', 'python-repl', 'rust', 'shell', 'sql', 'typescript', 'xml', 'yaml']});
hljs.highlightAll();
})</script>
</head>
<body>
<main>
<article id="content">
<header>
<h1 class="title">Module <code>VITAE.model</code></h1>
</header>
<section id="section-intro">
</section>
<section>
</section>
<section>
</section>
<section>
</section>
<section>
<h2 class="section-title" id="header-classes">Classes</h2>
<dl>
<dt id="VITAE.model.cdf_layer"><code class="flex name class">
<span>class <span class="ident">cdf_layer</span></span>
</code></dt>
<dd>
<div class="desc"><p>The Normal cdf layer with custom gradients.</p></div>
<details class="source">
<summary>
<span>Expand source code</span>
</summary>
<pre><code class="python">class cdf_layer(Layer):
'''
The Normal cdf layer with custom gradients.
'''
def __init__(self):
'''
'''
super(cdf_layer, self).__init__()
@tf.function
def call(self, x):
return self.func(x)
@tf.custom_gradient
def func(self, x):
'''Return cdf(x) and pdf(x).
Parameters
----------
x : tf.Tensor
The input tensor.
Returns
----------
f : tf.Tensor
cdf(x).
grad : tf.Tensor
pdf(x).
'''
dist = tfp.distributions.Normal(
loc = tf.constant(0.0, tf.keras.backend.floatx()),
scale = tf.constant(1.0, tf.keras.backend.floatx()),
allow_nan_stats=False)
f = dist.cdf(x)
def grad(dy):
gradient = dist.prob(x)
return dy * gradient
return f, grad</code></pre>
</details>
<h3>Ancestors</h3>
<ul class="hlist">
<li>keras.src.engine.base_layer.Layer</li>
<li>tensorflow.python.module.module.Module</li>
<li>tensorflow.python.trackable.autotrackable.AutoTrackable</li>
<li>tensorflow.python.trackable.base.Trackable</li>
<li>keras.src.utils.version_utils.LayerVersionSelector</li>
</ul>
<h3>Methods</h3>
<dl>
<dt id="VITAE.model.cdf_layer.call"><code class="name flex">
<span>def <span class="ident">call</span></span>(<span>self, x)</span>
</code></dt>
<dd>
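<div class="desc"><p>Apply <code>func</code> to <code>x</code>, i.e. the standard Normal cdf with a custom gradient.</p></div>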
</dd>
<dt id="VITAE.model.cdf_layer.func"><code class="name flex">
<span>def <span class="ident">func</span></span>(<span>self, x)</span>
</code></dt>
<dd>
<div class="desc"><p>Return the standard Normal cdf(x) together with a gradient function based on pdf(x).</p>
<h2 id="parameters">Parameters</h2>
<dl>
<dt><strong><code>x</code></strong> : <code>tf.Tensor</code></dt>
<dd>The input tensor.</dd>
</dl>
<h2 id="returns">Returns</h2>
<dl>
<dt><strong><code>f</code></strong> : <code>tf.Tensor</code></dt>
<dd>cdf(x).</dd>
<dt><strong><code>grad</code></strong> : <code>callable</code></dt>
<dd>Gradient function that maps the upstream gradient <code>dy</code> to <code>dy * pdf(x)</code>.</dd>
</dl></div>
</dd>
</dl>
</dd>
<dt id="VITAE.model.Sampling"><code class="flex name class">
<span>class <span class="ident">Sampling</span></span>
<span>(</span><span>seed=0, **kwargs)</span>
</code></dt>
<dd>
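<div class="desc"><p>Sampling latent variable <span><span class="MathJax_Preview">z</span><script type="math/tex">z</script></span> from <span><span class="MathJax_Preview">N(\mu_z, \sigma_z^2)</span><script type="math/tex">N(\mu_z, \sigma_z^2)</script></span>, given the mean <span><span class="MathJax_Preview">\mu_z</span><script type="math/tex">\mu_z</script></span> and the log-variance <span><span class="MathJax_Preview">\log \sigma_z^2</span><script type="math/tex">\log \sigma_z^2</script></span>.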
<br>
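Used in Encoder. The <code>seed</code> argument is stored on the layer but is not currently passed to the random sampler.</p></div>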
<details class="source">
<summary>
<span>Expand source code</span>
</summary>
<pre><code class="python">class Sampling(Layer):
"""Sampling latent variable \(z\) from \(N(\\mu_z, \\log \\sigma_z^2\)).
Used in Encoder.
"""
def __init__(self, seed=0, **kwargs):
super(Sampling, self).__init__(**kwargs)
self.seed = seed
@tf.function
def call(self, z_mean, z_log_var):
'''Return cdf(x) and pdf(x).
Parameters
----------
z_mean : tf.Tensor
\([B, L, d]\) The mean of \(z\).
z_log_var : tf.Tensor
\([B, L, d]\) The log-variance of \(z\).
Returns
----------
z : tf.Tensor
\([B, L, d]\) The sampled \(z\).
'''
# seed = tfp.util.SeedStream(self.seed, salt="random_normal")
# epsilon = tf.random.normal(shape = tf.shape(z_mean), seed=seed(), dtype=tf.keras.backend.floatx())
epsilon = tf.random.normal(shape = tf.shape(z_mean), dtype=tf.keras.backend.floatx())
z = z_mean + tf.exp(0.5 * z_log_var) * epsilon
z = tf.clip_by_value(z, -1e6, 1e6)
return z</code></pre>
</details>
<h3>Ancestors</h3>
<ul class="hlist">
<li>keras.src.engine.base_layer.Layer</li>
<li>tensorflow.python.module.module.Module</li>
<li>tensorflow.python.trackable.autotrackable.AutoTrackable</li>
<li>tensorflow.python.trackable.base.Trackable</li>
<li>keras.src.utils.version_utils.LayerVersionSelector</li>
</ul>
<h3>Methods</h3>
<dl>
<dt id="VITAE.model.Sampling.call"><code class="name flex">
<span>def <span class="ident">call</span></span>(<span>self, z_mean, z_log_var)</span>
</code></dt>
<dd>
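<div class="desc"><p>Draw <span><span class="MathJax_Preview">z = \mu_z + \exp(0.5 \log \sigma_z^2)\,\epsilon</span><script type="math/tex">z = \mu_z + \exp(0.5 \log \sigma_z^2)\,\epsilon</script></span> with standard Normal noise <span><span class="MathJax_Preview">\epsilon</span><script type="math/tex">\epsilon</script></span>, clipped elementwise to <code>[-1e6, 1e6]</code>.</p>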
<h2 id="parameters">Parameters</h2>
<dl>
<dt><strong><code>z_mean</code></strong> : <code>tf.Tensor</code></dt>
<dd><span><span class="MathJax_Preview">[B, L, d]</span><script type="math/tex">[B, L, d]</script></span> The mean of <span><span class="MathJax_Preview">z</span><script type="math/tex">z</script></span>.</dd>
<dt><strong><code>z_log_var</code></strong> : <code>tf.Tensor</code></dt>
<dd><span><span class="MathJax_Preview">[B, L, d]</span><script type="math/tex">[B, L, d]</script></span> The log-variance of <span><span class="MathJax_Preview">z</span><script type="math/tex">z</script></span>.</dd>
</dl>
<h2 id="returns">Returns</h2>
<dl>
<dt><strong><code>z</code></strong> : <code>tf.Tensor</code></dt>
<dd><span><span class="MathJax_Preview">[B, L, d]</span><script type="math/tex">[B, L, d]</script></span> The sampled <span><span class="MathJax_Preview">z</span><script type="math/tex">z</script></span>.</dd>
</dl></div>
</dd>
</dl>
</dd>
<dt id="VITAE.model.Encoder"><code class="flex name class">
<span>class <span class="ident">Encoder</span></span>
<span>(</span><span>dimensions, dim_latent, name='encoder', **kwargs)</span>
</code></dt>
<dd>
<div class="desc"><p>Encoder, model <span><span class="MathJax_Preview">p(Z_i|Y_i,X_i)</span><script type="math/tex">p(Z_i|Y_i,X_i)</script></span>.</p>
<h2 id="parameters">Parameters</h2>
<dl>
<dt><strong><code>dimensions</code></strong> : <code>np.array</code></dt>
<dd>The dimensions of hidden layers of the encoder.</dd>
<dt><strong><code>dim_latent</code></strong> : <code>int</code></dt>
<dd>The latent dimension of the encoder.</dd>
<dt><strong><code>name</code></strong> : <code>str</code>, optional</dt>
<dd>The name of the layer.</dd>
<dt><strong><code>**kwargs</code></strong></dt>
<dd>Extra keyword arguments.</dd>
</dl></div>
<details class="source">
<summary>
<span>Expand source code</span>
</summary>
<pre><code class="python">class Encoder(Layer):
'''
Encoder, model \(p(Z_i|Y_i,X_i)\).
'''
def __init__(self, dimensions, dim_latent, name='encoder', **kwargs):
'''
Parameters
----------
dimensions : np.array
The dimensions of hidden layers of the encoder.
dim_latent : int
The latent dimension of the encoder.
name : str, optional
The name of the layer.
**kwargs :
Extra keyword arguments.
'''
super(Encoder, self).__init__(name = name, **kwargs)
self.dense_layers = [Dense(dim, activation = tf.nn.leaky_relu,
name = 'encoder_%i'%(i+1)) \
for (i, dim) in enumerate(dimensions)]
self.batch_norm_layers = [BatchNormalization(center=False) \
for _ in range(len((dimensions)))]
self.batch_norm_layers.append(BatchNormalization(center=False))
self.latent_mean = Dense(dim_latent, name = 'latent_mean')
self.latent_log_var = Dense(dim_latent, name = 'latent_log_var')
self.sampling = Sampling()
@tf.function
def call(self, x, L=1, is_training=True):
'''Encode the inputs and get the latent variables.
Parameters
----------
x : tf.Tensor
\([B, L, d]\) The input.
L : int, optional
The number of MC samples.
is_training : boolean, optional
Whether in the training or inference mode.
Returns
----------
z_mean : tf.Tensor
\([B, L, d]\) The mean of \(z\).
z_log_var : tf.Tensor
\([B, L, d]\) The log-variance of \(z\).
z : tf.Tensor
\([B, L, d]\) The sampled \(z\).
'''
for dense, bn in zip(self.dense_layers, self.batch_norm_layers):
x = dense(x)
x = bn(x, training=is_training)
z_mean = self.batch_norm_layers[-1](self.latent_mean(x), training=is_training)
z_log_var = self.latent_log_var(x)
_z_mean = tf.tile(tf.expand_dims(z_mean, 1), (1,L,1))
_z_log_var = tf.tile(tf.expand_dims(z_log_var, 1), (1,L,1))
z = self.sampling(_z_mean, _z_log_var)
return z_mean, z_log_var, z</code></pre>
</details>
<h3>Ancestors</h3>
<ul class="hlist">
<li>keras.src.engine.base_layer.Layer</li>
<li>tensorflow.python.module.module.Module</li>
<li>tensorflow.python.trackable.autotrackable.AutoTrackable</li>
<li>tensorflow.python.trackable.base.Trackable</li>
<li>keras.src.utils.version_utils.LayerVersionSelector</li>
</ul>
<h3>Methods</h3>
<dl>
<dt id="VITAE.model.Encoder.call"><code class="name flex">
<span>def <span class="ident">call</span></span>(<span>self, x, L=1, is_training=True)</span>
</code></dt>
<dd>
<div class="desc"><p>Encode the inputs and get the latent variables.</p>
<h2 id="parameters">Parameters</h2>
<dl>
<dt><strong><code>x</code></strong> : <code>tf.Tensor</code></dt>
<dd><span><span class="MathJax_Preview">[B, L, d]</span><script type="math/tex">[B, L, d]</script></span> The input.</dd>
<dt><strong><code>L</code></strong> : <code>int</code>, optional</dt>
<dd>The number of MC samples.</dd>
<dt><strong><code>is_training</code></strong> : <code>boolean</code>, optional</dt>
<dd>Whether in the training or inference mode.</dd>
</dl>
<h2 id="returns">Returns</h2>
<dl>
<dt><strong><code>z_mean</code></strong> : <code>tf.Tensor</code></dt>
<dd><span><span class="MathJax_Preview">[B, d]</span><script type="math/tex">[B, d]</script></span> The mean of <span><span class="MathJax_Preview">z</span><script type="math/tex">z</script></span> (before it is tiled over the <span><span class="MathJax_Preview">L</span><script type="math/tex">L</script></span> samples).</dd>
<dt><strong><code>z_log_var</code></strong> : <code>tf.Tensor</code></dt>
<dd><span><span class="MathJax_Preview">[B, d]</span><script type="math/tex">[B, d]</script></span> The log-variance of <span><span class="MathJax_Preview">z</span><script type="math/tex">z</script></span> (before it is tiled over the <span><span class="MathJax_Preview">L</span><script type="math/tex">L</script></span> samples).</dd>
<dt><strong><code>z</code></strong> : <code>tf.Tensor</code></dt>
<dd><span><span class="MathJax_Preview">[B, L, d]</span><script type="math/tex">[B, L, d]</script></span> The sampled <span><span class="MathJax_Preview">z</span><script type="math/tex">z</script></span>.</dd>
</dl></div>
</dd>
</dl>
</dd>
<dt id="VITAE.model.Decoder"><code class="flex name class">
<span>class <span class="ident">Decoder</span></span>
<span>(</span><span>dimensions, dim_origin, data_type='UMI', name='decoder', **kwargs)</span>
</code></dt>
<dd>
<div class="desc"><p>Decoder, model <span><span class="MathJax_Preview">p(Y_i|Z_i,X_i)</span><script type="math/tex">p(Y_i|Z_i,X_i)</script></span>.</p>
<h2 id="parameters">Parameters</h2>
<dl>
<dt><strong><code>dimensions</code></strong> : <code>np.array</code></dt>
<dd>The dimensions of hidden layers of the decoder.</dd>
<dt><strong><code>dim_origin</code></strong> : <code>int</code></dt>
<dd>The output dimension of the decoder.</dd>
<dt><strong><code>data_type</code></strong> : <code>str</code>, optional</dt>
<dd><code>'UMI'</code>, <code>'non-UMI'</code>, or <code>'Gaussian'</code>.</dd>
<dt><strong><code>name</code></strong> : <code>str</code>, optional</dt>
<dd>The name of the layer.</dd>
<dt><strong><code>**kwargs</code></strong></dt>
<dd>Extra keyword arguments.</dd>
</dl></div>
<details class="source">
<summary>
<span>Expand source code</span>
</summary>
<pre><code class="python">class Decoder(Layer):
'''
Decoder, model \(p(Y_i|Z_i,X_i)\).
'''
def __init__(self, dimensions, dim_origin, data_type = 'UMI',
name = 'decoder', **kwargs):
'''
Parameters
----------
dimensions : np.array
The dimensions of hidden layers of the encoder.
dim_origin : int
The output dimension of the decoder.
data_type : str, optional
`'UMI'`, `'non-UMI'`, or `'Gaussian'`.
name : str, optional
The name of the layer.
'''
super(Decoder, self).__init__(name = name, **kwargs)
self.data_type = data_type
self.dense_layers = [Dense(dim, activation = tf.nn.leaky_relu,
name = 'decoder_%i'%(i+1)) \
for (i,dim) in enumerate(dimensions)]
self.batch_norm_layers = [BatchNormalization(center=False) \
for _ in range(len((dimensions)))]
if data_type=='Gaussian':
self.nu_z = Dense(dim_origin, name = 'nu_z')
# common variance
self.log_tau = tf.Variable(tf.zeros([1, dim_origin], dtype=tf.keras.backend.floatx()),
constraint = lambda t: tf.clip_by_value(t,-30.,6.),
name = "log_tau")
else:
self.log_lambda_z = Dense(dim_origin, name = 'log_lambda_z')
# dispersion parameter
self.log_r = tf.Variable(tf.zeros([1, dim_origin], dtype=tf.keras.backend.floatx()),
constraint = lambda t: tf.clip_by_value(t,-30.,6.),
name = "log_r")
if self.data_type == 'non-UMI':
self.phi = Dense(dim_origin, activation = 'sigmoid', name = "phi")
@tf.function
def call(self, z, is_training=True):
'''Decode the latent variables and get the reconstructions.
Parameters
----------
z : tf.Tensor
\([B, L, d]\) the sampled \(z\).
is_training : boolean, optional
whether in the training or inference mode.
When `data_type=='Gaussian'`:
Returns
----------
nu_z : tf.Tensor
\([B, L, G]\) The mean of \(Y_i|Z_i,X_i\).
tau : tf.Tensor
\([1, G]\) The variance of \(Y_i|Z_i,X_i\).
When `data_type=='UMI'`:
Returns
----------
lambda_z : tf.Tensor
\([B, L, G]\) The mean of \(Y_i|Z_i,X_i\).
r : tf.Tensor
\([1, G]\) The dispersion parameters of \(Y_i|Z_i,X_i\).
When `data_type=='non-UMI'`:
Returns
----------
lambda_z : tf.Tensor
\([B, L, G]\) The mean of \(Y_i|Z_i,X_i\).
r : tf.Tensor
\([1, G]\) The dispersion parameters of \(Y_i|Z_i,X_i\).
phi_z : tf.Tensor
\([1, G]\) The zero inflated parameters of \(Y_i|Z_i,X_i\).
'''
for dense, bn in zip(self.dense_layers, self.batch_norm_layers):
z = dense(z)
z = bn(z, training=is_training)
if self.data_type=='Gaussian':
nu_z = self.nu_z(z)
tau = tf.exp(self.log_tau)
return nu_z, tau
else:
lambda_z = tf.math.exp(
tf.clip_by_value(self.log_lambda_z(z), -30., 6.)
)
r = tf.exp(self.log_r)
if self.data_type=='UMI':
return lambda_z, r
else:
return lambda_z, r, self.phi(z)</code></pre>
</details>
<h3>Ancestors</h3>
<ul class="hlist">
<li>keras.src.engine.base_layer.Layer</li>
<li>tensorflow.python.module.module.Module</li>
<li>tensorflow.python.trackable.autotrackable.AutoTrackable</li>
<li>tensorflow.python.trackable.base.Trackable</li>
<li>keras.src.utils.version_utils.LayerVersionSelector</li>
</ul>
<h3>Methods</h3>
<dl>
<dt id="VITAE.model.Decoder.call"><code class="name flex">
<span>def <span class="ident">call</span></span>(<span>self, z, is_training=True)</span>
</code></dt>
<dd>
<div class="desc"><p>Decode the latent variables and get the reconstructions.</p>
<h2 id="parameters">Parameters</h2>
<dl>
<dt><strong><code>z</code></strong> : <code>tf.Tensor</code></dt>
<dd><span><span class="MathJax_Preview">[B, L, d]</span><script type="math/tex">[B, L, d]</script></span> the sampled <span><span class="MathJax_Preview">z</span><script type="math/tex">z</script></span>.</dd>
<dt><strong><code>is_training</code></strong> : <code>boolean</code>, optional</dt>
<dd>whether in the training or inference mode.</dd>
</dl>
<p>When <code>data_type=='Gaussian'</code>:</p>
<h2 id="returns">Returns</h2>
<dl>
<dt><strong><code>nu_z</code></strong> : <code>tf.Tensor</code></dt>
<dd><span><span class="MathJax_Preview">[B, L, G]</span><script type="math/tex">[B, L, G]</script></span> The mean of <span><span class="MathJax_Preview">Y_i|Z_i,X_i</span><script type="math/tex">Y_i|Z_i,X_i</script></span>.</dd>
<dt><strong><code>tau</code></strong> : <code>tf.Tensor</code></dt>
<dd><span><span class="MathJax_Preview">[1, G]</span><script type="math/tex">[1, G]</script></span> The variance of <span><span class="MathJax_Preview">Y_i|Z_i,X_i</span><script type="math/tex">Y_i|Z_i,X_i</script></span>.</dd>
</dl>
<p>When <code>data_type=='UMI'</code>:</p>
<h2 id="returns_1">Returns</h2>
<dl>
<dt><strong><code>lambda_z</code></strong> : <code>tf.Tensor</code></dt>
<dd><span><span class="MathJax_Preview">[B, L, G]</span><script type="math/tex">[B, L, G]</script></span> The mean of <span><span class="MathJax_Preview">Y_i|Z_i,X_i</span><script type="math/tex">Y_i|Z_i,X_i</script></span>.</dd>
<dt><strong><code>r</code></strong> : <code>tf.Tensor</code></dt>
<dd><span><span class="MathJax_Preview">[1, G]</span><script type="math/tex">[1, G]</script></span> The dispersion parameters of <span><span class="MathJax_Preview">Y_i|Z_i,X_i</span><script type="math/tex">Y_i|Z_i,X_i</script></span>.</dd>
</dl>
<p>When <code>data_type=='non-UMI'</code>:</p>
<h2 id="returns_2">Returns</h2>
<dl>
<dt><strong><code>lambda_z</code></strong> : <code>tf.Tensor</code></dt>
<dd><span><span class="MathJax_Preview">[B, L, G]</span><script type="math/tex">[B, L, G]</script></span> The mean of <span><span class="MathJax_Preview">Y_i|Z_i,X_i</span><script type="math/tex">Y_i|Z_i,X_i</script></span>.</dd>
<dt><strong><code>r</code></strong> : <code>tf.Tensor</code></dt>
<dd><span><span class="MathJax_Preview">[1, G]</span><script type="math/tex">[1, G]</script></span> The dispersion parameters of <span><span class="MathJax_Preview">Y_i|Z_i,X_i</span><script type="math/tex">Y_i|Z_i,X_i</script></span>.</dd>
<dt><strong><code>phi_z</code></strong> : <code>tf.Tensor</code></dt>
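<dd><span><span class="MathJax_Preview">[B, L, G]</span><script type="math/tex">[B, L, G]</script></span> The zero-inflation parameters of <span><span class="MathJax_Preview">Y_i|Z_i,X_i</span><script type="math/tex">Y_i|Z_i,X_i</script></span>, in <span><span class="MathJax_Preview">(0, 1)</span><script type="math/tex">(0, 1)</script></span> via a sigmoid activation.</dd>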
</dl></div>
</dd>
</dl>
</dd>
<dt id="VITAE.model.LatentSpace"><code class="flex name class">
<span>class <span class="ident">LatentSpace</span></span>
<span>(</span><span>n_clusters, dim_latent, name='LatentSpace', seed=0, **kwargs)</span>
</code></dt>
<dd>
<div class="desc"><p>Layer for the Latent Space.</p>
<h2 id="parameters">Parameters</h2>
<dl>
<dt><strong><code>n_clusters</code></strong> : <code>int</code></dt>
<dd>The number of vertices in the latent space.</dd>
<dt><strong><code>dim_latent</code></strong> : <code>int</code></dt>
<dd>The latent dimension.</dd>
<dt><strong><code>seed</code></strong> : <code>int</code>, optional</dt>
<dd>The random seed used to initialize the vertex positions <code>mu</code>.</dd>
<dt><strong><code>name</code></strong> : <code>str</code>, optional</dt>
<dd>The name of the layer.</dd>
<dt><strong><code>**kwargs</code></strong></dt>
<dd>Extra keyword arguments.</dd>
</dl></div>
<details class="source">
<summary>
<span>Expand source code</span>
</summary>
<pre><code class="python">class LatentSpace(Layer):
'''
Layer for the Latent Space.
'''
def __init__(self, n_clusters, dim_latent,
name = 'LatentSpace', seed=0, **kwargs):
'''
Parameters
----------
n_clusters : int
The number of vertices in the latent space.
dim_latent : int
The latent dimension.
M : int, optional
The discretized number of uniform(0,1).
name : str, optional
The name of the layer.
**kwargs :
Extra keyword arguments.
'''
super(LatentSpace, self).__init__(name=name, **kwargs)
self.dim_latent = dim_latent
self.n_states = n_clusters
self.n_categories = int(n_clusters*(n_clusters+1)/2)
# nonzero indexes
# A = [0,0,...,0 , 1,1,...,1, ...]
# B = [0,1,...,k-1, 1,2,...,k-1, ...]
self.A, self.B = np.nonzero(np.triu(np.ones(n_clusters)))
self.A = tf.convert_to_tensor(self.A, tf.int32)
self.B = tf.convert_to_tensor(self.B, tf.int32)
self.clusters_ind = tf.boolean_mask(
tf.range(0,self.n_categories,1), self.A==self.B)
# [pi_1, ... , pi_K] in R^(n_categories)
self.pi = tf.Variable(tf.ones([1, self.n_categories], dtype=tf.keras.backend.floatx()) / self.n_categories,
name = 'pi')
# [mu_1, ... , mu_K] in R^(dim_latent * n_clusters)
self.mu = tf.Variable(tf.random.uniform([self.dim_latent, self.n_states],
minval = -1, maxval = 1, seed=seed, dtype=tf.keras.backend.floatx()),
name = 'mu')
self.cdf_layer = cdf_layer()
def initialize(self, mu, log_pi):
'''Initialize the latent space.
Parameters
----------
mu : np.array
\([d, k]\) The position matrix.
log_pi : np.array
\([1, K]\) \(\\log\\pi\).
'''
# Initialize parameters of the latent space
if mu is not None:
self.mu.assign(mu)
if log_pi is not None:
self.pi.assign(log_pi)
def normalize(self):
'''Normalize \(\\pi\).
'''
self.pi = tf.nn.softmax(self.pi)
@tf.function
def _get_normal_params(self, z, pi):
batch_size = tf.shape(z)[0]
L = tf.shape(z)[1]
# [batch_size, L, n_categories]
if pi is None:
# [batch_size, L, n_states]
temp_pi = tf.tile(
tf.expand_dims(tf.nn.softmax(self.pi), 1),
(batch_size,L,1))
else:
temp_pi = tf.expand_dims(tf.nn.softmax(pi), 1)
# [batch_size, L, d, n_categories]
alpha_zc = tf.expand_dims(tf.expand_dims(
tf.gather(self.mu, self.B, axis=1) - tf.gather(self.mu, self.A, axis=1), 0), 0)
beta_zc = tf.expand_dims(z,-1) - \
tf.expand_dims(tf.expand_dims(
tf.gather(self.mu, self.B, axis=1), 0), 0)
# [batch_size, L, n_categories]
inv_sig = tf.reduce_sum(alpha_zc * alpha_zc, axis=2)
nu = - tf.reduce_sum(alpha_zc * beta_zc, axis=2) * tf.math.reciprocal_no_nan(inv_sig)
_t = - tf.reduce_sum(beta_zc * beta_zc, axis=2) + nu**2*inv_sig
return temp_pi, beta_zc, inv_sig, nu, _t
@tf.function
def _get_pz(self, temp_pi, inv_sig, beta_zc, log_p_z_c_L):
# [batch_size, L, n_categories]
log_p_zc_L = - 0.5 * self.dim_latent * tf.math.log(tf.constant(2 * np.pi, tf.keras.backend.floatx())) + \
tf.math.log(temp_pi) + \
tf.where(inv_sig==0,
- 0.5 * tf.reduce_sum(beta_zc**2, axis=2),
log_p_z_c_L)
# [batch_size, L, 1]
log_p_z_L = tf.reduce_logsumexp(log_p_zc_L, axis=-1, keepdims=True)
# [1, ]
log_p_z = tf.reduce_mean(log_p_z_L)
return log_p_zc_L, log_p_z_L, log_p_z
@tf.function
def _get_posterior_c(self, log_p_zc_L, log_p_z_L):
L = tf.shape(log_p_z_L)[1]
# log_p_c_x - predicted probability distribution
# [batch_size, n_categories]
log_p_c_x = tf.reduce_logsumexp(
log_p_zc_L - log_p_z_L,
axis=1) - tf.math.log(tf.cast(L, tf.keras.backend.floatx()))
return log_p_c_x
@tf.function
def _get_inference(self, z, log_p_z_L, temp_pi, inv_sig, nu, beta_zc, log_eta0, eta1, eta2):
batch_size = tf.shape(z)[0]
L = tf.shape(z)[1]
dist = tfp.distributions.Normal(
loc = tf.constant(0.0, tf.keras.backend.floatx()),
scale = tf.constant(1.0, tf.keras.backend.floatx()),
allow_nan_stats=False)
# [batch_size, L, n_categories, n_clusters]
inv_sig = tf.expand_dims(inv_sig, -1)
_sig = tf.tile(tf.clip_by_value(tf.math.reciprocal_no_nan(inv_sig), 1e-12, 1e30), (1,1,1,self.n_states))
log_eta0 = tf.tile(tf.expand_dims(log_eta0, -1), (1,1,1,self.n_states))
eta1 = tf.tile(tf.expand_dims(eta1, -1), (1,1,1,self.n_states))
eta2 = tf.tile(tf.expand_dims(eta2, -1), (1,1,1,self.n_states))
nu = tf.tile(tf.expand_dims(nu, -1), (1,1,1,1))
A = tf.tile(tf.expand_dims(tf.expand_dims(
tf.one_hot(self.A, self.n_states, dtype=tf.keras.backend.floatx()),
0),0), (batch_size,L,1,1))
B = tf.tile(tf.expand_dims(tf.expand_dims(
tf.one_hot(self.B, self.n_states, dtype=tf.keras.backend.floatx()),
0),0), (batch_size,L,1,1))
temp_pi = tf.expand_dims(temp_pi, -1)
# w_tilde [batch_size, L, n_clusters]
w_tilde = log_eta0 + tf.math.log(
tf.clip_by_value(
(dist.cdf(eta1) - dist.cdf(eta2)) * (nu * A + (1-nu) * B) -
(dist.prob(eta1) - dist.prob(eta2)) * tf.math.sqrt(_sig) * (A - B),
0.0, 1e30)
)
w_tilde = - 0.5 * self.dim_latent * tf.math.log(tf.constant(2 * np.pi, tf.keras.backend.floatx())) + \
tf.math.log(temp_pi) + \
tf.where(inv_sig==0,
tf.where(B==1, - 0.5 * tf.expand_dims(tf.reduce_sum(beta_zc**2, axis=2), -1), -np.inf),
w_tilde)
w_tilde = tf.exp(tf.reduce_logsumexp(w_tilde, 2) - log_p_z_L)
# tf.debugging.assert_greater_equal(
# tf.reduce_sum(w_tilde, -1), tf.ones([batch_size, L], dtype=tf.keras.backend.floatx())*0.99,
# message='Wrong w_tilde', summarize=None, name=None
# )
# var_w_tilde [batch_size, L, n_clusters]
var_w_tilde = log_eta0 + tf.math.log(
tf.clip_by_value(
(dist.cdf(eta1) - dist.cdf(eta2)) * ((_sig + nu**2) * (A+B) + (1-2*nu) * B) -
(dist.prob(eta1) - dist.prob(eta2)) * tf.math.sqrt(_sig) * (nu *(A+B)-B )*2 -
(eta1*dist.prob(eta1) - eta2*dist.prob(eta2)) * _sig *(A+B),
0.0, 1e30)
)
var_w_tilde = - 0.5 * self.dim_latent * tf.math.log(tf.constant(2 * np.pi, tf.keras.backend.floatx())) + \
tf.math.log(temp_pi) + \
tf.where(inv_sig==0,
tf.where(B==1, - 0.5 * tf.expand_dims(tf.reduce_sum(beta_zc**2, axis=2), -1), -np.inf),
var_w_tilde)
var_w_tilde = tf.exp(tf.reduce_logsumexp(var_w_tilde, 2) - log_p_z_L) - w_tilde**2
w_tilde = tf.reduce_mean(w_tilde, 1)
var_w_tilde = tf.reduce_mean(var_w_tilde, 1)
return w_tilde, var_w_tilde
def get_pz(self, z, eps, pi):
'''Get \(\\log p(Z_i|Y_i,X_i)\).
Parameters
----------
z : tf.Tensor
\([B, L, d]\) The latent variables.
Returns
----------
temp_pi : tf.Tensor
\([B, L, K]\) \(\\pi\).
inv_sig : tf.Tensor
\([B, L, K]\) \(\\sigma_{Z_ic_i}^{-1}\).
nu : tf.Tensor
\([B, L, K]\) \(\\nu_{Z_ic_i}\).
beta_zc : tf.Tensor
\([B, L, d, K]\) \(\\beta_{Z_ic_i}\).
log_eta0 : tf.Tensor
\([B, L, K]\) \(\\log\\eta_{Z_ic_i,0}\).
eta1 : tf.Tensor
\([B, L, K]\) \(\\eta_{Z_ic_i,1}\).
eta2 : tf.Tensor
\([B, L, K]\) \(\\eta_{Z_ic_i,2}\).
log_p_zc_L : tf.Tensor
\([B, L, K]\) \(\\log p(Z_i,c_i|Y_i,X_i)\).
log_p_z_L : tf.Tensor
\([B, L]\) \(\\log p(Z_i|Y_i,X_i)\).
log_p_z : tf.Tensor
\([B, 1]\) The estimated \(\\log p(Z_i|Y_i,X_i)\).
'''
temp_pi, beta_zc, inv_sig, nu, _t = self._get_normal_params(z, pi)
temp_pi = tf.clip_by_value(temp_pi, eps, 1.0)
log_eta0 = 0.5 * (tf.math.log(tf.constant(2 * np.pi, tf.keras.backend.floatx())) - \
tf.math.log(tf.clip_by_value(inv_sig, 1e-12, 1e30)) + _t)
eta1 = (1-nu) * tf.math.sqrt(tf.clip_by_value(inv_sig, 1e-12, 1e30))
eta2 = -nu * tf.math.sqrt(tf.clip_by_value(inv_sig, 1e-12, 1e30))
log_p_z_c_L = log_eta0 + tf.math.log(tf.clip_by_value(
self.cdf_layer(eta1) - self.cdf_layer(eta2),
eps, 1e30))
log_p_zc_L, log_p_z_L, log_p_z = self._get_pz(temp_pi, inv_sig, beta_zc, log_p_z_c_L)
return temp_pi, inv_sig, nu, beta_zc, log_eta0, eta1, eta2, log_p_zc_L, log_p_z_L, log_p_z
def get_posterior_c(self, z):
'''Get \(p(c_i|Y_i,X_i)\).
Parameters
----------
z : tf.Tensor
\([B, L, d]\) The latent variables.
Returns
----------
p_c_x : np.array
\([B, K]\) \(p(c_i|Y_i,X_i)\).
'''
_,_,_,_,_,_,_, log_p_zc_L, log_p_z_L, _ = self.get_pz(z)
log_p_c_x = self._get_posterior_c(log_p_zc_L, log_p_z_L)
p_c_x = tf.exp(log_p_c_x).numpy()
return p_c_x
def call(self, z, pi=None, inference=False):
'''Get posterior estimations.
Parameters
----------
z : tf.Tensor
\([B, L, d]\) The latent variables.
inference : boolean
Whether in training or inference mode.
When `inference=False`:
Returns
----------
log_p_z_L : tf.Tensor
\([B, 1]\) The estimated \(\\log p(Z_i|Y_i,X_i)\).
When `inference=True`:
Returns
----------
res : dict
The dict of posterior estimations - \(p(c_i|Y_i,X_i)\), \(c\), \(E(\\tilde{w}_i|Y_i,X_i)\), \(Var(\\tilde{w}_i|Y_i,X_i)\), \(D_{JS}\).
'''
eps = 1e-16 if not inference else 0.
temp_pi, inv_sig, nu, beta_zc, log_eta0, eta1, eta2, log_p_zc_L, log_p_z_L, log_p_z = self.get_pz(z, eps, pi)
if not inference:
return log_p_z
else:
log_p_c_x = self._get_posterior_c(log_p_zc_L, log_p_z_L)
w_tilde, var_w_tilde = self._get_inference(z, log_p_z_L, temp_pi, inv_sig, nu, beta_zc, log_eta0, eta1, eta2)
res = {}
res['p_c_x'] = tf.exp(log_p_c_x).numpy()
res['w_tilde'] = w_tilde.numpy()
res['var_w_tilde'] = var_w_tilde.numpy()
return res</code></pre>
</details>
<h3>Ancestors</h3>
<ul class="hlist">
<li>keras.src.engine.base_layer.Layer</li>
<li>tensorflow.python.module.module.Module</li>
<li>tensorflow.python.trackable.autotrackable.AutoTrackable</li>
<li>tensorflow.python.trackable.base.Trackable</li>
<li>keras.src.utils.version_utils.LayerVersionSelector</li>
</ul>
<h3>Methods</h3>
<dl>
<dt id="VITAE.model.LatentSpace.initialize"><code class="name flex">
<span>def <span class="ident">initialize</span></span>(<span>self, mu, log_pi)</span>
</code></dt>
<dd>
<div class="desc"><p>Initialize the latent space.</p>
<h2 id="parameters">Parameters</h2>
<dl>
<dt><strong><code>mu</code></strong> : <code>np.array</code></dt>
<dd><span><span class="MathJax_Preview">[d, k]</span><script type="math/tex">[d, k]</script></span> The position matrix.</dd>
<dt><strong><code>log_pi</code></strong> : <code>np.array</code></dt>
<dd><span><span class="MathJax_Preview">[1, K]</span><script type="math/tex">[1, K]</script></span> <span><span class="MathJax_Preview">\log\pi</span><script type="math/tex">\log\pi</script></span>.</dd>
</dl></div>
</dd>
<dt id="VITAE.model.LatentSpace.normalize"><code class="name flex">
<span>def <span class="ident">normalize</span></span>(<span>self)</span>
</code></dt>
<dd>
<div class="desc"><p>Normalize <span><span class="MathJax_Preview">\pi</span><script type="math/tex">\pi</script></span>.</p></div>
</dd>
<dt id="VITAE.model.LatentSpace.get_pz"><code class="name flex">
<span>def <span class="ident">get_pz</span></span>(<span>self, z, eps, pi)</span>
</code></dt>
<dd>
<div class="desc"><p>Get <span><span class="MathJax_Preview">\log p(Z_i|Y_i,X_i)</span><script type="math/tex">\log p(Z_i|Y_i,X_i)</script></span>.</p>
<h2 id="parameters">Parameters</h2>
<dl>
<dt><strong><code>z</code></strong> : <code>tf.Tensor</code></dt>
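<dd><span><span class="MathJax_Preview">[B, L, d]</span><script type="math/tex">[B, L, d]</script></span> The latent variables.</dd>
<dt><strong><code>eps</code></strong> : <code>float</code></dt>
<dd>Lower bound used when clipping <span><span class="MathJax_Preview">\pi</span><script type="math/tex">\pi</script></span> and the difference <code>cdf_layer(eta1) - cdf_layer(eta2)</code> before taking logarithms; <code>call</code> passes <code>1e-16</code> in training mode and <code>0.</code> in inference mode.</dd>
<dt><strong><code>pi</code></strong> : <code>tf.Tensor</code> or <code>None</code></dt>
<dd>Forwarded to <code>_get_normal_params</code>; <code>VariationalAutoEncoder.call</code> passes the output of its <code>pi_layer</code> here, or <code>None</code>.</dd>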
</dl>
<h2 id="returns">Returns</h2>
<dl>
<dt><strong><code>temp_pi</code></strong> : <code>tf.Tensor</code></dt>
<dd><span><span class="MathJax_Preview">[B, L, K]</span><script type="math/tex">[B, L, K]</script></span> <span><span class="MathJax_Preview">\pi</span><script type="math/tex">\pi</script></span>.</dd>
<dt><strong><code>inv_sig</code></strong> : <code>tf.Tensor</code></dt>
<dd><span><span class="MathJax_Preview">[B, L, K]</span><script type="math/tex">[B, L, K]</script></span> <span><span class="MathJax_Preview">\sigma_{Z_ic_i}^{-1}</span><script type="math/tex">\sigma_{Z_ic_i}^{-1}</script></span>.</dd>
<dt><strong><code>nu</code></strong> : <code>tf.Tensor</code></dt>
<dd><span><span class="MathJax_Preview">[B, L, K]</span><script type="math/tex">[B, L, K]</script></span> <span><span class="MathJax_Preview">\nu_{Z_ic_i}</span><script type="math/tex">\nu_{Z_ic_i}</script></span>.</dd>
<dt><strong><code>beta_zc</code></strong> : <code>tf.Tensor</code></dt>
<dd><span><span class="MathJax_Preview">[B, L, d, K]</span><script type="math/tex">[B, L, d, K]</script></span> <span><span class="MathJax_Preview">\beta_{Z_ic_i}</span><script type="math/tex">\beta_{Z_ic_i}</script></span>.</dd>
<dt><strong><code>log_eta0</code></strong> : <code>tf.Tensor</code></dt>
<dd><span><span class="MathJax_Preview">[B, L, K]</span><script type="math/tex">[B, L, K]</script></span> <span><span class="MathJax_Preview">\log\eta_{Z_ic_i,0}</span><script type="math/tex">\log\eta_{Z_ic_i,0}</script></span>.</dd>
<dt><strong><code>eta1</code></strong> : <code>tf.Tensor</code></dt>
<dd><span><span class="MathJax_Preview">[B, L, K]</span><script type="math/tex">[B, L, K]</script></span> <span><span class="MathJax_Preview">\eta_{Z_ic_i,1}</span><script type="math/tex">\eta_{Z_ic_i,1}</script></span>.</dd>
<dt><strong><code>eta2</code></strong> : <code>tf.Tensor</code></dt>
<dd><span><span class="MathJax_Preview">[B, L, K]</span><script type="math/tex">[B, L, K]</script></span> <span><span class="MathJax_Preview">\eta_{Z_ic_i,2}</span><script type="math/tex">\eta_{Z_ic_i,2}</script></span>.</dd>
<dt><strong><code>log_p_zc_L</code></strong> : <code>tf.Tensor</code></dt>
<dd><span><span class="MathJax_Preview">[B, L, K]</span><script type="math/tex">[B, L, K]</script></span> <span><span class="MathJax_Preview">\log p(Z_i,c_i|Y_i,X_i)</span><script type="math/tex">\log p(Z_i,c_i|Y_i,X_i)</script></span>.</dd>
<dt><strong><code>log_p_z_L</code></strong> : <code>tf.Tensor</code></dt>
<dd><span><span class="MathJax_Preview">[B, L]</span><script type="math/tex">[B, L]</script></span> <span><span class="MathJax_Preview">\log p(Z_i|Y_i,X_i)</span><script type="math/tex">\log p(Z_i|Y_i,X_i)</script></span>.</dd>
<dt><strong><code>log_p_z</code></strong> : <code>tf.Tensor</code></dt>
<dd><span><span class="MathJax_Preview">[B, 1]</span><script type="math/tex">[B, 1]</script></span> The estimated <span><span class="MathJax_Preview">\log p(Z_i|Y_i,X_i)</span><script type="math/tex">\log p(Z_i|Y_i,X_i)</script></span>.</dd>
</dl></div>
</dd>
<dt id="VITAE.model.LatentSpace.get_posterior_c"><code class="name flex">
<span>def <span class="ident">get_posterior_c</span></span>(<span>self, z)</span>
</code></dt>
<dd>
<div class="desc"><p>Get <span><span class="MathJax_Preview">p(c_i|Y_i,X_i)</span><script type="math/tex">p(c_i|Y_i,X_i)</script></span>.</p>
<h2 id="parameters">Parameters</h2>
<dl>
<dt><strong><code>z</code></strong> : <code>tf.Tensor</code></dt>
<dd><span><span class="MathJax_Preview">[B, L, d]</span><script type="math/tex">[B, L, d]</script></span> The latent variables.</dd>
</dl>
<h2 id="returns">Returns</h2>
<dl>
<dt><strong><code>p_c_x</code></strong> : <code>np.array</code></dt>
<dd><span><span class="MathJax_Preview">[B, K]</span><script type="math/tex">[B, K]</script></span> <span><span class="MathJax_Preview">p(c_i|Y_i,X_i)</span><script type="math/tex">p(c_i|Y_i,X_i)</script></span>.</dd>
</dl></div>
</dd>
<dt id="VITAE.model.LatentSpace.call"><code class="name flex">
<span>def <span class="ident">call</span></span>(<span>self, z, pi=None, inference=False)</span>
</code></dt>
<dd>
<div class="desc"><p>Get posterior estimations.</p>
<h2 id="parameters">Parameters</h2>
<dl>
<dt><strong><code>z</code></strong> : <code>tf.Tensor</code></dt>
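<dd><span><span class="MathJax_Preview">[B, L, d]</span><script type="math/tex">[B, L, d]</script></span> The latent variables.</dd>
<dt><strong><code>pi</code></strong> : <code>tf.Tensor</code>, optional</dt>
<dd>Forwarded to <code>get_pz</code>; <code>VariationalAutoEncoder.call</code> passes the output of its <code>pi_layer</code> here, or <code>None</code>.</dd>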
<dt><strong><code>inference</code></strong> : <code>boolean</code></dt>
<dd>Whether in training or inference mode.</dd>
</dl>
<p>When <code>inference=False</code>:</p>
<h2 id="returns">Returns</h2>
<dl>
<dt><strong><code>log_p_z</code></strong> : <code>tf.Tensor</code></dt>
<dd><span><span class="MathJax_Preview">[B, 1]</span><script type="math/tex">[B, 1]</script></span> The estimated <span><span class="MathJax_Preview">\log p(Z_i|Y_i,X_i)</span><script type="math/tex">\log p(Z_i|Y_i,X_i)</script></span>.</dd>
</dl>
<p>When <code>inference=True</code>:</p>
<h2 id="returns_1">Returns</h2>
<dl>
<dt><strong><code>res</code></strong> : <code>dict</code></dt>
<dd>The dict of posterior estimations, as numpy arrays: <code>'p_c_x'</code> - <span><span class="MathJax_Preview">p(c_i|Y_i,X_i)</span><script type="math/tex">p(c_i|Y_i,X_i)</script></span>, <code>'w_tilde'</code> - <span><span class="MathJax_Preview">E(\tilde{w}_i|Y_i,X_i)</span><script type="math/tex">E(\tilde{w}_i|Y_i,X_i)</script></span>, <code>'var_w_tilde'</code> - <span><span class="MathJax_Preview">Var(\tilde{w}_i|Y_i,X_i)</span><script type="math/tex">Var(\tilde{w}_i|Y_i,X_i)</script></span>.</dd>
</dl></div>
</dd>
</dl>
</dd>
<dt id="VITAE.model.VariationalAutoEncoder"><code class="flex name class">
<span>class <span class="ident">VariationalAutoEncoder</span></span>
<span>(</span><span>dim_origin, dimensions, dim_latent, data_type='UMI', has_cov=False, name='autoencoder', **kwargs)</span>
</code></dt>
<dd>
<div class="desc"><p>Combines the encoder, decoder and LatentSpace into an end-to-end model for training and inference.</p>
<h2 id="parameters">Parameters</h2>
<dl>
<dt><strong><code>dim_origin</code></strong> : <code>int</code></dt>
<dd>The output dimension of the decoder.</dd>
<dt><strong><code>dimensions</code></strong> : <code>np.array</code></dt>
<dd>The dimensions of hidden layers of the encoder.</dd>
<dt><strong><code>dim_latent</code></strong> : <code>int</code></dt>
<dd>The latent dimension.</dd>
<dt><strong><code>data_type</code></strong> : <code>str</code>, optional</dt>
<dd><code>'UMI'</code>, <code>'non-UMI'</code>, or <code>'Gaussian'</code>.</dd>
<dt><strong><code>has_cov</code></strong> : <code>boolean</code></dt>
<dd>Whether has covariates or not.</dd>
<dt><strong><code>gamma</code></strong> : <code>float</code>, optional</dt>
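<dd>The weight of the MMD loss. Note that the constructor has no <code>gamma</code> parameter; pass it to <code>call</code> instead.</dd>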
<dt><strong><code>name</code></strong> : <code>str</code>, optional</dt>
<dd>The name of the layer.</dd>
<dt><strong><code>**kwargs</code></strong></dt>
<dd>Extra keyword arguments.</dd>
</dl></div>
<details class="source">
<summary>
<span>Expand source code</span>
</summary>
<pre><code class="python">class VariationalAutoEncoder(tf.keras.Model):
"""
Combines the encoder, decoder and LatentSpace into an end-to-end model for training and inference.
"""
def __init__(self, dim_origin, dimensions, dim_latent,
data_type = 'UMI', has_cov=False,
name = 'autoencoder', **kwargs):
'''
Parameters
----------
dim_origin : int
The output dimension of the decoder.
dimensions : np.array
The dimensions of hidden layers of the encoder.
dim_latent : int
The latent dimension.
data_type : str, optional
`'UMI'`, `'non-UMI'`, or `'Gaussian'`.
has_cov : boolean
Whether has covariates or not.
gamma : float, optional
The weights of the MMD loss
name : str, optional
The name of the layer.
**kwargs :
Extra keyword arguments.
'''
super(VariationalAutoEncoder, self).__init__(name = name, **kwargs)
self.data_type = data_type
self.dim_origin = dim_origin
self.dim_latent = dim_latent
self.encoder = Encoder(dimensions, dim_latent)
self.decoder = Decoder(dimensions[::-1], dim_origin, data_type, data_type)
self.has_cov = has_cov
def init_latent_space(self, n_clusters, mu, log_pi=None):
'''Initialze the latent space.
Parameters
----------
n_clusters : int
The number of vertices in the latent space.
mu : np.array
\([d, k]\) The position matrix.
log_pi : np.array, optional
\([1, K]\) \(\\log\\pi\).
'''
self.n_states = n_clusters
self.latent_space = LatentSpace(self.n_states, self.dim_latent)
self.latent_space.initialize(mu, log_pi)
self.pilayer = None
def create_pilayer(self):
self.pilayer = Dense(self.latent_space.n_categories, name = 'pi_layer')
def call(self, x_normalized, c_score, x = None, scale_factor = 1,
pre_train = False, L=1, alpha=0.0, gamma = 0.0, phi = 1.0, conditions = None, pi_cov = None):
'''Feed forward through encoder, LatentSpace layer and decoder.
Parameters
----------
x_normalized : np.array
\([B, G]\) The preprocessed data.
c_score : np.array
\([B, s]\) The covariates \(X_i\), only used when `has_cov=True`.
x : np.array, optional
\([B, G]\) The original count data \(Y_i\), only used when data_type is not `'Gaussian'`.
scale_factor : np.array, optional
\([B, ]\) The scale factors, only used when data_type is not `'Gaussian'`.
pre_train : boolean, optional
Whether in the pre-training phare or not.
L : int, optional
The number of MC samples.
alpha : float, optional
The penalty parameter for covariates adjustment.
gamma : float, optional
The weight of mmd loss
phi : float, optional
The weight of Jacob norm of the encoder.
conditions: str or list, optional
The conditions of different cells from the selected batch
Returns
----------
losses : float
the loss.
'''
if not pre_train and self.latent_space is None:
raise ReferenceError('Have not initialized the latent space.')
if self.has_cov:
x_normalized = tf.concat([x_normalized, c_score], -1)
else:
x_normalized
_, z_log_var, z = self.encoder(x_normalized, L)
if gamma == 0:
mmd_loss = 0.0
else:
mmd_loss = self._get_total_mmd_loss(conditions,z,gamma)
z_in = tf.concat([z, tf.tile(tf.expand_dims(c_score,1), (1,L,1))], -1) if self.has_cov else z
x = tf.tile(tf.expand_dims(x, 1), (1,L,1))
reconstruction_z_loss = self._get_reconstruction_loss(x, z_in, scale_factor, L)
if self.has_cov and alpha>0.0:
zero_in = tf.concat([tf.zeros([z.shape[0],1,z.shape[2]], dtype=tf.keras.backend.floatx()),
tf.tile(tf.expand_dims(c_score,1), (1,1,1))], -1)
reconstruction_zero_loss = self._get_reconstruction_loss(x, zero_in, scale_factor, 1)
reconstruction_z_loss = (1-alpha)*reconstruction_z_loss + alpha*reconstruction_zero_loss
self.add_loss(reconstruction_z_loss)
J_norm = self._get_Jacob(x_normalized, L)
self.add_loss((phi * J_norm))
# gamma weight has been used when call _mmd_loss function.
self.add_loss(mmd_loss)
if not pre_train:
pi = self.pilayer(pi_cov) if self.pilayer is not None else None
log_p_z = self.latent_space(z, pi, inference=False)
# - E_q[log p(z)]
self.add_loss(- log_p_z)
# - Eq[log q(z|x)]
E_qzx = - tf.reduce_mean(
0.5 * self.dim_latent *
(tf.math.log(tf.constant(2 * np.pi, tf.keras.backend.floatx())) + 1.0) +
0.5 * tf.reduce_sum(z_log_var, axis=-1)
)
self.add_loss(E_qzx)
return self.losses
@tf.function
def _get_reconstruction_loss(self, x, z_in, scale_factor, L):
if self.data_type=='Gaussian':
# Gaussian Log-Likelihood Loss function
nu_z, tau = self.decoder(z_in)
neg_E_Gaus = 0.5 * tf.math.log(tf.clip_by_value(tau, 1e-12, 1e30)) + 0.5 * tf.math.square(x - nu_z) / tau
neg_E_Gaus = tf.reduce_mean(tf.reduce_sum(neg_E_Gaus, axis=-1))
return neg_E_Gaus
else:
if self.data_type == 'UMI':
x_hat, r = self.decoder(z_in)
else:
x_hat, r, phi = self.decoder(z_in)
x_hat = x_hat*tf.expand_dims(scale_factor, -1)
# Negative Log-Likelihood Loss function
# Ref for NB & ZINB loss functions:
# https://github.com/gokceneraslan/neuralnet_countmodels/blob/master/Count%20models%20with%20neuralnets.ipynb
# Negative Binomial loss
neg_E_nb = tf.math.lgamma(r) + tf.math.lgamma(x+1.0) \
- tf.math.lgamma(x+r) + \
(r+x) * tf.math.log(1.0 + (x_hat/r)) + \
x * (tf.math.log(r) - tf.math.log(tf.clip_by_value(x_hat, 1e-12, 1e30)))
if self.data_type == 'non-UMI':
# Zero-Inflated Negative Binomial loss
nb_case = neg_E_nb - tf.math.log(tf.clip_by_value(1.0-phi, 1e-12, 1e30))
zero_case = - tf.math.log(tf.clip_by_value(
phi + (1.0-phi) * tf.pow(r * tf.math.reciprocal_no_nan(r + x_hat), r),
1e-12, 1e30))
neg_E_nb = tf.where(tf.less(x, 1e-8), zero_case, nb_case)
neg_E_nb = tf.reduce_mean(tf.reduce_sum(neg_E_nb, axis=-1))
return neg_E_nb
def _get_total_mmd_loss(self,conditions,z,gamma):
mmd_loss = 0.0
conditions = tf.cast(conditions,tf.int32)
n_group = conditions.shape[1]
for i in range(n_group):
sub_conditions = conditions[:, i]
# 0 means not participant in mmd
z_cond = z[sub_conditions != 0]
sub_conditions = sub_conditions[sub_conditions != 0]
n_sub_group = tf.unique(sub_conditions)[0].shape[0]
real_labels = K.reshape(sub_conditions, (-1,)).numpy()
unique_set = list(set(real_labels))
reindex_dict = dict(zip(unique_set, range(n_sub_group)))
real_labels = [reindex_dict[x] for x in real_labels]
real_labels = tf.convert_to_tensor(real_labels,dtype=tf.int32)
if (n_sub_group == 1) | (n_sub_group == 0):
_loss = 0
else:
_loss = self._mmd_loss(real_labels=real_labels, y_pred=z_cond, gamma=gamma,
n_conditions=n_sub_group,
kernel_method='multi-scale-rbf',
computation_method="general")
mmd_loss = mmd_loss + _loss
return mmd_loss
# each loop the inputed shape is changed. Can not use @tf.function
# tf graph requires static shape and tensor dtype
def _mmd_loss(self, real_labels, y_pred, gamma, n_conditions, kernel_method='multi-scale-rbf',
computation_method="general"):
conditions_mmd = tf.dynamic_partition(y_pred, real_labels, num_partitions=n_conditions)
loss = 0.0
if computation_method.isdigit():
boundary = int(computation_method)
## every pair of groups will calculate a distance
for i in range(boundary):
for j in range(boundary, n_conditions):
loss += _nan2zero(compute_mmd(conditions_mmd[i], conditions_mmd[j], kernel_method))
else:
for i in range(len(conditions_mmd)):
for j in range(i):
loss += _nan2zero(compute_mmd(conditions_mmd[i], conditions_mmd[j], kernel_method))
# print("The loss is ", loss)
return gamma * loss
@tf.function
def _get_Jacob(self, x, L):
with tf.GradientTape() as g:
g.watch(x)
z_mean, z_log_var, z = self.encoder(x, L)
# y_mean, y_log_var = self.decoder(z)
## just jacobian will cause shape (batch,16,batch,64) matrix
J = g.batch_jacobian(z, x)
J_norm = tf.norm(J)
# tf.print(J_norm)
return J_norm
def get_z(self, x_normalized, c_score):
'''Get \(q(Z_i|Y_i,X_i)\).
Parameters
----------
x_normalized : int
\([B, G]\) The preprocessed data.
c_score : np.array
\([B, s]\) The covariates \(X_i\), only used when `has_cov=True`.
Returns
----------
z_mean : np.array
\([B, d]\) The latent mean.
'''
x_normalized = x_normalized if (not self.has_cov or c_score is None) else tf.concat([x_normalized, c_score], -1)
z_mean, _, _ = self.encoder(x_normalized, 1, False)
return z_mean.numpy()
def get_pc_x(self, test_dataset):
'''Get \(p(c_i|Y_i,X_i)\).
Parameters
----------
test_dataset : tf.Dataset
the dataset object.
Returns
----------
pi_norm : np.array
\([1, K]\) The estimated \(\\pi\).
p_c_x : np.array
\([N, ]\) The estimated \(p(c_i|Y_i,X_i)\).
'''
if self.latent_space is None:
raise ReferenceError('Have not initialized the latent space.')
pi_norm = tf.nn.softmax(self.latent_space.pi).numpy()
p_c_x = []
for x,c_score in test_dataset:
x = tf.concat([x, c_score], -1) if self.has_cov else x
_, _, z = self.encoder(x, 1, False)
_p_c_x = self.latent_space.get_posterior_c(z)
p_c_x.append(_p_c_x)
p_c_x = np.concatenate(p_c_x)
return pi_norm, p_c_x
def inference(self, test_dataset, L=1):
'''Get \(p(c_i|Y_i,X_i)\).
Parameters
----------
test_dataset : tf.Dataset
The dataset object.
L : int
The number of MC samples.
Returns
----------
pi_norm : np.array
\([1, K]\) The estimated \(\\pi\).
mu : np.array
\([d, k]\) The estimated \(\\mu\).
p_c_x : np.array
\([N, ]\) The estimated \(p(c_i|Y_i,X_i)\).
w_tilde : np.array
\([N, k]\) The estimated \(E(\\tilde{w}_i|Y_i,X_i)\).
var_w_tilde : np.array
\([N, k]\) The estimated \(Var(\\tilde{w}_i|Y_i,X_i)\).
z_mean : np.array
\([N, d]\) The estimated latent mean.
'''
if self.latent_space is None:
raise ReferenceError('Have not initialized the latent space.')
print('Computing posterior estimations over mini-batches.')
progbar = Progbar(test_dataset.cardinality().numpy())
pi_norm = tf.nn.softmax(self.latent_space.pi).numpy()
mu = self.latent_space.mu.numpy()
z_mean = []
p_c_x = []
w_tilde = []
var_w_tilde = []
for step, (x,c_score, _, _) in enumerate(test_dataset):
x = tf.concat([x, c_score], -1) if self.has_cov else x
_z_mean, _, z = self.encoder(x, L, False)
res = self.latent_space(z, inference=True)
z_mean.append(_z_mean.numpy())
p_c_x.append(res['p_c_x'])
w_tilde.append(res['w_tilde'])
var_w_tilde.append(res['var_w_tilde'])
progbar.update(step+1)
z_mean = np.concatenate(z_mean)
p_c_x = np.concatenate(p_c_x)
w_tilde = np.concatenate(w_tilde)
w_tilde /= np.sum(w_tilde, axis=1, keepdims=True)
var_w_tilde = np.concatenate(var_w_tilde)
return pi_norm, mu, p_c_x, w_tilde, var_w_tilde, z_mean</code></pre>
</details>
<h3>Ancestors</h3>
<ul class="hlist">
<li>keras.src.engine.training.Model</li>
<li>keras.src.engine.base_layer.Layer</li>
<li>tensorflow.python.module.module.Module</li>
<li>tensorflow.python.trackable.autotrackable.AutoTrackable</li>
<li>tensorflow.python.trackable.base.Trackable</li>
<li>keras.src.utils.version_utils.LayerVersionSelector</li>
<li>keras.src.utils.version_utils.ModelVersionSelector</li>
</ul>
<h3>Methods</h3>
<dl>
<dt id="VITAE.model.VariationalAutoEncoder.init_latent_space"><code class="name flex">
<span>def <span class="ident">init_latent_space</span></span>(<span>self, n_clusters, mu, log_pi=None)</span>
</code></dt>
<dd>
<div class="desc"><p>Initialize the latent space.</p>
<h2 id="parameters">Parameters</h2>
<dl>
<dt><strong><code>n_clusters</code></strong> : <code>int</code></dt>
<dd>The number of vertices in the latent space.</dd>
<dt><strong><code>mu</code></strong> : <code>np.array</code></dt>
<dd><span><span class="MathJax_Preview">[d, k]</span><script type="math/tex">[d, k]</script></span> The position matrix.</dd>
<dt><strong><code>log_pi</code></strong> : <code>np.array</code>, optional</dt>
<dd><span><span class="MathJax_Preview">[1, K]</span><script type="math/tex">[1, K]</script></span> <span><span class="MathJax_Preview">\log\pi</span><script type="math/tex">\log\pi</script></span>.</dd>
</dl></div>
</dd>
<dt id="VITAE.model.VariationalAutoEncoder.create_pilayer"><code class="name flex">
<span>def <span class="ident">create_pilayer</span></span>(<span>self)</span>
</code></dt>
<dd>
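<div class="desc"><p>Create the dense layer <code>pi_layer</code> with one output per category of the latent space; <code>call</code> applies it to <code>pi_cov</code>. Requires the latent space to be initialized first.</p></div>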
</dd>
<dt id="VITAE.model.VariationalAutoEncoder.call"><code class="name flex">
<span>def <span class="ident">call</span></span>(<span>self, x_normalized, c_score, x=None, scale_factor=1, pre_train=False, L=1, alpha=0.0, gamma=0.0, phi=1.0, conditions=None, pi_cov=None)</span>
</code></dt>
<dd>
<div class="desc"><p>Feed forward through encoder, LatentSpace layer and decoder.</p>
<h2 id="parameters">Parameters</h2>
<dl>
<dt><strong><code>x_normalized</code></strong> : <code>np.array</code></dt>
<dd><span><span class="MathJax_Preview">[B, G]</span><script type="math/tex">[B, G]</script></span> The preprocessed data.</dd>
<dt><strong><code>c_score</code></strong> : <code>np.array</code></dt>
<dd><span><span class="MathJax_Preview">[B, s]</span><script type="math/tex">[B, s]</script></span> The covariates <span><span class="MathJax_Preview">X_i</span><script type="math/tex">X_i</script></span>, only used when <code>has_cov=True</code>.</dd>
<dt><strong><code>x</code></strong> : <code>np.array</code>, optional</dt>
<dd><span><span class="MathJax_Preview">[B, G]</span><script type="math/tex">[B, G]</script></span> The original data <span><span class="MathJax_Preview">Y_i</span><script type="math/tex">Y_i</script></span> (count data unless data_type is <code>'Gaussian'</code>), used as the reconstruction target for every data_type; required despite the <code>None</code> default.</dd>
<dt><strong><code>scale_factor</code></strong> : <code>np.array</code>, optional</dt>
<dd><span><span class="MathJax_Preview">[B, ]</span><script type="math/tex">[B, ]</script></span> The scale factors, only used when data_type is not <code>'Gaussian'</code>.</dd>
<dt><strong><code>pre_train</code></strong> : <code>boolean</code>, optional</dt>
<dd>Whether in the pre-training phase or not.</dd>
<dt><strong><code>L</code></strong> : <code>int</code>, optional</dt>
<dd>The number of MC samples.</dd>
<dt><strong><code>alpha</code></strong> : <code>float</code>, optional</dt>
<dd>The penalty parameter for covariates adjustment.</dd>
<dt><strong><code>gamma</code></strong> : <code>float</code>, optional</dt>
<dd>The weight of the MMD loss; 0 disables it.</dd>
<dt><strong><code>phi</code></strong> : <code>float</code>, optional</dt>
<dd>The weight of the Jacobian norm of the encoder.</dd>
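<dt><strong><code>conditions</code></strong> : <code>array-like</code>, optional</dt>
<dd>Integer condition labels of the cells in the batch, one column per grouping (cast to <code>int32</code>); a label of 0 excludes the cell from the MMD loss. Only used when <code>gamma</code> is nonzero.</dd>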
</dl>
<h2 id="returns">Returns</h2>
<dl>
<dt><strong><code>losses</code></strong> : <code>list</code></dt>
<dd>The list <code>self.losses</code> of losses added by the model.</dd>
</dl></div>
</dd>
<dt id="VITAE.model.VariationalAutoEncoder.get_z"><code class="name flex">
<span>def <span class="ident">get_z</span></span>(<span>self, x_normalized, c_score)</span>
</code></dt>
<dd>
<div class="desc"><p>Get <span><span class="MathJax_Preview">q(Z_i|Y_i,X_i)</span><script type="math/tex">q(Z_i|Y_i,X_i)</script></span>.</p>
<h2 id="parameters">Parameters</h2>
<dl>
<dt><strong><code>x_normalized</code></strong> : <code>np.array</code></dt>
<dd><span><span class="MathJax_Preview">[B, G]</span><script type="math/tex">[B, G]</script></span> The preprocessed data.</dd>
<dt><strong><code>c_score</code></strong> : <code>np.array</code></dt>
<dd><span><span class="MathJax_Preview">[B, s]</span><script type="math/tex">[B, s]</script></span> The covariates <span><span class="MathJax_Preview">X_i</span><script type="math/tex">X_i</script></span>, only used when <code>has_cov=True</code>.</dd>
</dl>
<h2 id="returns">Returns</h2>
<dl>
<dt><strong><code>z_mean</code></strong> : <code>np.array</code></dt>
<dd><span><span class="MathJax_Preview">[B, d]</span><script type="math/tex">[B, d]</script></span> The latent mean.</dd>
</dl></div>
</dd>
<dt id="VITAE.model.VariationalAutoEncoder.get_pc_x"><code class="name flex">
<span>def <span class="ident">get_pc_x</span></span>(<span>self, test_dataset)</span>
</code></dt>
<dd>
<div class="desc"><p>Get <span><span class="MathJax_Preview">p(c_i|Y_i,X_i)</span><script type="math/tex">p(c_i|Y_i,X_i)</script></span>.</p>
<h2 id="parameters">Parameters</h2>
<dl>
<dt><strong><code>test_dataset</code></strong> : <code>tf.Dataset</code></dt>
<dd>The dataset object.</dd>
</dl>
<h2 id="returns">Returns</h2>
<dl>
<dt><strong><code>pi_norm</code></strong> : <code>np.array</code></dt>
<dd><span><span class="MathJax_Preview">[1, K]</span><script type="math/tex">[1, K]</script></span> The estimated <span><span class="MathJax_Preview">\pi</span><script type="math/tex">\pi</script></span>.</dd>
<dt><strong><code>p_c_x</code></strong> : <code>np.array</code></dt>
<dd><span><span class="MathJax_Preview">[N, K]</span><script type="math/tex">[N, K]</script></span> The estimated <span><span class="MathJax_Preview">p(c_i|Y_i,X_i)</span><script type="math/tex">p(c_i|Y_i,X_i)</script></span>.</dd>
</dl></div>
</dd>
<dt id="VITAE.model.VariationalAutoEncoder.inference"><code class="name flex">
<span>def <span class="ident">inference</span></span>(<span>self, test_dataset, L=1)</span>
</code></dt>
<dd>
<div class="desc"><p>Compute posterior estimations, including <span><span class="MathJax_Preview">p(c_i|Y_i,X_i)</span><script type="math/tex">p(c_i|Y_i,X_i)</script></span>, over all mini-batches of <code>test_dataset</code>.</p>
<h2 id="parameters">Parameters</h2>
<dl>
<dt><strong><code>test_dataset</code></strong> : <code>tf.Dataset</code></dt>
<dd>The dataset object.</dd>
<dt><strong><code>L</code></strong> : <code>int</code></dt>
<dd>The number of MC samples.</dd>
</dl>
<h2 id="returns">Returns</h2>
<dl>
<dt><strong><code>pi_norm</code></strong> : <code>np.array</code></dt>
<dd><span><span class="MathJax_Preview">[1, K]</span><script type="math/tex">[1, K]</script></span> The estimated <span><span class="MathJax_Preview">\pi</span><script type="math/tex">\pi</script></span>.</dd>
<dt><strong><code>mu</code></strong> : <code>np.array</code></dt>
<dd><span><span class="MathJax_Preview">[d, k]</span><script type="math/tex">[d, k]</script></span> The estimated <span><span class="MathJax_Preview">\mu</span><script type="math/tex">\mu</script></span>.</dd>
<dt><strong><code>p_c_x</code></strong> : <code>np.array</code></dt>
<dd><span><span class="MathJax_Preview">[N, K]</span><script type="math/tex">[N, K]</script></span> The estimated <span><span class="MathJax_Preview">p(c_i|Y_i,X_i)</span><script type="math/tex">p(c_i|Y_i,X_i)</script></span>.</dd>
<dt><strong><code>w_tilde</code></strong> : <code>np.array</code></dt>
<dd><span><span class="MathJax_Preview">[N, k]</span><script type="math/tex">[N, k]</script></span> The estimated <span><span class="MathJax_Preview">E(\tilde{w}_i|Y_i,X_i)</span><script type="math/tex">E(\tilde{w}_i|Y_i,X_i)</script></span>.</dd>
<dt><strong><code>var_w_tilde</code></strong> : <code>np.array</code></dt>
<dd><span><span class="MathJax_Preview">[N, k]</span><script type="math/tex">[N, k]</script></span> The estimated <span><span class="MathJax_Preview">Var(\tilde{w}_i|Y_i,X_i)</span><script type="math/tex">Var(\tilde{w}_i|Y_i,X_i)</script></span>.</dd>
<dt><strong><code>z_mean</code></strong> : <code>np.array</code></dt>
<dd><span><span class="MathJax_Preview">[N, d]</span><script type="math/tex">[N, d]</script></span> The estimated latent mean.</dd>
</dl></div>
</dd>
</dl>
</dd>
</dl>
</section>
</article>
<nav id="sidebar">
<div class="toc">
<ul></ul>
</div>
<ul id="index">
<li><h3>Super-module</h3>
<ul>
<li><code><a title="VITAE" href="index.html">VITAE</a></code></li>
</ul>
</li>
<li><h3><a href="#header-classes">Classes</a></h3>
<ul>
<li>
<h4><code><a title="VITAE.model.cdf_layer" href="#VITAE.model.cdf_layer">cdf_layer</a></code></h4>
<ul class="">
<li><code><a title="VITAE.model.cdf_layer.call" href="#VITAE.model.cdf_layer.call">call</a></code></li>
<li><code><a title="VITAE.model.cdf_layer.func" href="#VITAE.model.cdf_layer.func">func</a></code></li>
</ul>
</li>
<li>
<h4><code><a title="VITAE.model.Sampling" href="#VITAE.model.Sampling">Sampling</a></code></h4>
<ul class="">
<li><code><a title="VITAE.model.Sampling.call" href="#VITAE.model.Sampling.call">call</a></code></li>
</ul>
</li>
<li>
<h4><code><a title="VITAE.model.Encoder" href="#VITAE.model.Encoder">Encoder</a></code></h4>
<ul class="">
<li><code><a title="VITAE.model.Encoder.call" href="#VITAE.model.Encoder.call">call</a></code></li>
</ul>
</li>
<li>
<h4><code><a title="VITAE.model.Decoder" href="#VITAE.model.Decoder">Decoder</a></code></h4>
<ul class="">
<li><code><a title="VITAE.model.Decoder.call" href="#VITAE.model.Decoder.call">call</a></code></li>
</ul>
</li>
<li>
<h4><code><a title="VITAE.model.LatentSpace" href="#VITAE.model.LatentSpace">LatentSpace</a></code></h4>
<ul class="">
<li><code><a title="VITAE.model.LatentSpace.initialize" href="#VITAE.model.LatentSpace.initialize">initialize</a></code></li>
<li><code><a title="VITAE.model.LatentSpace.normalize" href="#VITAE.model.LatentSpace.normalize">normalize</a></code></li>
<li><code><a title="VITAE.model.LatentSpace.get_pz" href="#VITAE.model.LatentSpace.get_pz">get_pz</a></code></li>
<li><code><a title="VITAE.model.LatentSpace.get_posterior_c" href="#VITAE.model.LatentSpace.get_posterior_c">get_posterior_c</a></code></li>
<li><code><a title="VITAE.model.LatentSpace.call" href="#VITAE.model.LatentSpace.call">call</a></code></li>
</ul>
</li>
<li>
<h4><code><a title="VITAE.model.VariationalAutoEncoder" href="#VITAE.model.VariationalAutoEncoder">VariationalAutoEncoder</a></code></h4>
<ul class="two-column">
<li><code><a title="VITAE.model.VariationalAutoEncoder.init_latent_space" href="#VITAE.model.VariationalAutoEncoder.init_latent_space">init_latent_space</a></code></li>
<li><code><a title="VITAE.model.VariationalAutoEncoder.create_pilayer" href="#VITAE.model.VariationalAutoEncoder.create_pilayer">create_pilayer</a></code></li>
<li><code><a title="VITAE.model.VariationalAutoEncoder.call" href="#VITAE.model.VariationalAutoEncoder.call">call</a></code></li>
<li><code><a title="VITAE.model.VariationalAutoEncoder.get_z" href="#VITAE.model.VariationalAutoEncoder.get_z">get_z</a></code></li>
<li><code><a title="VITAE.model.VariationalAutoEncoder.get_pc_x" href="#VITAE.model.VariationalAutoEncoder.get_pc_x">get_pc_x</a></code></li>
<li><code><a title="VITAE.model.VariationalAutoEncoder.inference" href="#VITAE.model.VariationalAutoEncoder.inference">inference</a></code></li>
</ul>
</li>
</ul>
</li>
</ul>
</nav>
</main>
<footer id="footer">
<p>Generated by <a href="https://pdoc3.github.io/pdoc" title="pdoc: Python API documentation generator"><cite>pdoc</cite> 0.11.1</a>.</p>
</footer>
</body>
</html>