--- a +++ b/docs/model.html @@ -0,0 +1,1469 @@ +<!doctype html> +<html lang="en"> +<head> +<meta charset="utf-8"> +<meta name="viewport" content="width=device-width, initial-scale=1, minimum-scale=1"> +<meta name="generator" content="pdoc3 0.11.1"> +<title>VITAE.model API documentation</title> +<meta name="description" content=""> +<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/10up-sanitize.css/13.0.0/sanitize.min.css" integrity="sha512-y1dtMcuvtTMJc1yPgEqF0ZjQbhnc/bFhyvIyVNb9Zk5mIGtqVaAB1Ttl28su8AvFMOY0EwRbAe+HCLqj6W7/KA==" crossorigin> +<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/10up-sanitize.css/13.0.0/typography.min.css" integrity="sha512-Y1DYSb995BAfxobCkKepB1BqJJTPrOp3zPL74AWFugHHmmdcvO+C48WLrUOlhGMc0QG7AE3f7gmvvcrmX2fDoA==" crossorigin> +<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.9.0/styles/default.min.css" crossorigin> +<style>:root{--highlight-color:#fe9}.flex{display:flex !important}body{line-height:1.5em}#content{padding:20px}#sidebar{padding:1.5em;overflow:hidden}#sidebar > *:last-child{margin-bottom:2cm}.http-server-breadcrumbs{font-size:130%;margin:0 0 15px 0}#footer{font-size:.75em;padding:5px 30px;border-top:1px solid #ddd;text-align:right}#footer p{margin:0 0 0 1em;display:inline-block}#footer p:last-child{margin-right:30px}h1,h2,h3,h4,h5{font-weight:300}h1{font-size:2.5em;line-height:1.1em}h2{font-size:1.75em;margin:2em 0 .50em 0}h3{font-size:1.4em;margin:1.6em 0 .7em 0}h4{margin:0;font-size:105%}h1:target,h2:target,h3:target,h4:target,h5:target,h6:target{background:var(--highlight-color);padding:.2em 0}a{color:#058;text-decoration:none;transition:color .2s ease-in-out}a:visited{color:#503}a:hover{color:#b62}.title code{font-weight:bold}h2[id^="header-"]{margin-top:2em}.ident{color:#900;font-weight:bold}pre code{font-size:.8em;line-height:1.4em;padding:1em;display:block}code{background:#f3f3f3;font-family:"DejaVu Sans Mono",monospace;padding:1px 4px;overflow-wrap:break-word}h1 code{background:transparent}pre{border-top:1px solid #ccc;border-bottom:1px solid #ccc;margin:1em 0}#http-server-module-list{display:flex;flex-flow:column}#http-server-module-list div{display:flex}#http-server-module-list dt{min-width:10%}#http-server-module-list p{margin-top:0}.toc ul,#index{list-style-type:none;margin:0;padding:0}#index code{background:transparent}#index h3{border-bottom:1px solid #ddd}#index ul{padding:0}#index h4{margin-top:.6em;font-weight:bold}@media (min-width:200ex){#index .two-column{column-count:2}}@media (min-width:300ex){#index .two-column{column-count:3}}dl{margin-bottom:2em}dl dl:last-child{margin-bottom:4em}dd{margin:0 0 1em 3em}#header-classes + dl > dd{margin-bottom:3em}dd dd{margin-left:2em}dd p{margin:10px 0}.name{background:#eee;font-size:.85em;padding:5px 10px;display:inline-block;min-width:40%}.name:hover{background:#e0e0e0}dt:target .name{background:var(--highlight-color)}.name > span:first-child{white-space:nowrap}.name.class > span:nth-child(2){margin-left:.4em}.inherited{color:#999;border-left:5px solid #eee;padding-left:1em}.inheritance em{font-style:normal;font-weight:bold}.desc h2{font-weight:400;font-size:1.25em}.desc h3{font-size:1em}.desc dt code{background:inherit}.source summary,.git-link-div{color:#666;text-align:right;font-weight:400;font-size:.8em;text-transform:uppercase}.source summary > *{white-space:nowrap;cursor:pointer}.git-link{color:inherit;margin-left:1em}.source pre{max-height:500px;overflow:auto;margin:0}.source pre code{font-size:12px;overflow:visible}.hlist{list-style:none}.hlist li{display:inline}.hlist li:after{content:',\2002'}.hlist li:last-child:after{content:none}.hlist .hlist{display:inline;padding-left:1em}img{max-width:100%}td{padding:0 .5em}.admonition{padding:.1em 1em;margin-bottom:1em}.admonition-title{font-weight:bold}.admonition.note,.admonition.info,.admonition.important{background:#aef}.admonition.todo,.admonition.versionadded,.admonition.tip,.admonition.hint{background:#dfd}.admonition.warning,.admonition.versionchanged,.admonition.deprecated{background:#fd4}.admonition.error,.admonition.danger,.admonition.caution{background:lightpink}</style> +<style media="screen and (min-width: 700px)">@media screen and (min-width:700px){#sidebar{width:30%;height:100vh;overflow:auto;position:sticky;top:0}#content{width:70%;max-width:100ch;padding:3em 4em;border-left:1px solid #ddd}pre code{font-size:1em}.name{font-size:1em}main{display:flex;flex-direction:row-reverse;justify-content:flex-end}.toc ul ul,#index ul ul{padding-left:1em}.toc > ul > li{margin-top:.5em}}</style> +<style media="print">@media print{#sidebar h1{page-break-before:always}.source{display:none}}@media print{*{background:transparent !important;color:#000 !important;box-shadow:none !important;text-shadow:none !important}a[href]:after{content:" (" attr(href) ")";font-size:90%}a[href][title]:after{content:none}abbr[title]:after{content:" (" attr(title) ")"}.ir a:after,a[href^="javascript:"]:after,a[href^="#"]:after{content:""}pre,blockquote{border:1px solid #999;page-break-inside:avoid}thead{display:table-header-group}tr,img{page-break-inside:avoid}img{max-width:100% !important}@page{margin:0.5cm}p,h2,h3{orphans:3;widows:3}h1,h2,h3,h4,h5,h6{page-break-after:avoid}}</style> +<script type="text/x-mathjax-config">MathJax.Hub.Config({ tex2jax: { inlineMath: [ ['$','$'], ["\\(","\\)"] ], processEscapes: true } });</script> +<script async src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.7/latest.js?config=TeX-AMS_CHTML" integrity="sha256-kZafAc6mZvK3W3v1pHOcUix30OHQN6pU/NO2oFkqZVw=" crossorigin></script> +<script defer src="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.9.0/highlight.min.js" integrity="sha512-D9gUyxqja7hBtkWpPWGt9wfbfaMGVt9gnyCvYa+jojwwPHLCzUm5i8rpk7vD7wNee9bA35eYIjobYPaQuKS1MQ==" crossorigin></script> +<script>window.addEventListener('DOMContentLoaded', () => { +hljs.configure({languages: ['bash', 'css', 'diff', 'graphql', 'ini', 'javascript', 'json', 'plaintext', 'python', 'python-repl', 'rust', 'shell', 'sql', 'typescript', 'xml', 'yaml']}); +hljs.highlightAll(); +})</script> +</head> +<body> +<main> +<article id="content"> +<header> +<h1 class="title">Module <code>VITAE.model</code></h1> +</header> +<section id="section-intro"> +</section> +<section> +</section> +<section> +</section> +<section> +</section> +<section> +<h2 class="section-title" id="header-classes">Classes</h2> +<dl> +<dt id="VITAE.model.cdf_layer"><code class="flex name class"> +<span>class <span class="ident">cdf_layer</span></span> +</code></dt> +<dd> +<div class="desc"><p>The Normal cdf layer with custom gradients.</p></div> +<details class="source"> +<summary> +<span>Expand source code</span> +</summary> +<pre><code class="python">class cdf_layer(Layer): + ''' + The Normal cdf layer with custom gradients. + ''' + def __init__(self): + ''' + ''' + super(cdf_layer, self).__init__() + + @tf.function + def call(self, x): + return self.func(x) + + @tf.custom_gradient + def func(self, x): + '''Return cdf(x) and pdf(x). + + Parameters + ---------- + x : tf.Tensor + The input tensor. + + Returns + ---------- + f : tf.Tensor + cdf(x). + grad : tf.Tensor + pdf(x). + ''' + dist = tfp.distributions.Normal( + loc = tf.constant(0.0, tf.keras.backend.floatx()), + scale = tf.constant(1.0, tf.keras.backend.floatx()), + allow_nan_stats=False) + f = dist.cdf(x) + def grad(dy): + gradient = dist.prob(x) + return dy * gradient + return f, grad</code></pre> +</details> +<h3>Ancestors</h3> +<ul class="hlist"> +<li>keras.src.engine.base_layer.Layer</li> +<li>tensorflow.python.module.module.Module</li> +<li>tensorflow.python.trackable.autotrackable.AutoTrackable</li> +<li>tensorflow.python.trackable.base.Trackable</li> +<li>keras.src.utils.version_utils.LayerVersionSelector</li> +</ul> +<h3>Methods</h3> +<dl> +<dt id="VITAE.model.cdf_layer.call"><code class="name flex"> +<span>def <span class="ident">call</span></span>(<span>self, x)</span> +</code></dt> +<dd> +<div class="desc"></div> +</dd> +<dt id="VITAE.model.cdf_layer.func"><code class="name flex"> +<span>def <span class="ident">func</span></span>(<span>self, x)</span> +</code></dt> +<dd> +<div class="desc"><p>Return cdf(x) and pdf(x).</p> +<h2 id="parameters">Parameters</h2> +<dl> +<dt><strong><code>x</code></strong> : <code>tf.Tensor</code></dt> +<dd>The input tensor.</dd> +</dl> +<h2 id="returns">Returns</h2> +<dl> +<dt><strong><code>f</code></strong> : <code>tf.Tensor</code></dt> +<dd>cdf(x).</dd> +<dt><strong><code>grad</code></strong> : <code>tf.Tensor</code></dt> +<dd>pdf(x).</dd> +</dl></div> +</dd> +</dl> +</dd> +<dt id="VITAE.model.Sampling"><code class="flex name class"> +<span>class <span class="ident">Sampling</span></span> +<span>(</span><span>seed=0, **kwargs)</span> +</code></dt> +<dd> +<div class="desc"><p>Sampling latent variable <span><span class="MathJax_Preview">z</span><script type="math/tex">z</script></span> from <span><span class="MathJax_Preview">N(\mu_z, \log \sigma_z^2</span><script type="math/tex">N(\mu_z, \log \sigma_z^2</script></span>). +<br> +Used in Encoder.</p></div> +<details class="source"> +<summary> +<span>Expand source code</span> +</summary> +<pre><code class="python">class Sampling(Layer): + """Sampling latent variable \(z\) from \(N(\\mu_z, \\log \\sigma_z^2\)). + Used in Encoder. + """ + def __init__(self, seed=0, **kwargs): + super(Sampling, self).__init__(**kwargs) + self.seed = seed + + @tf.function + def call(self, z_mean, z_log_var): + '''Return cdf(x) and pdf(x). + + Parameters + ---------- + z_mean : tf.Tensor + \([B, L, d]\) The mean of \(z\). + z_log_var : tf.Tensor + \([B, L, d]\) The log-variance of \(z\). + + Returns + ---------- + z : tf.Tensor + \([B, L, d]\) The sampled \(z\). + ''' + # seed = tfp.util.SeedStream(self.seed, salt="random_normal") + # epsilon = tf.random.normal(shape = tf.shape(z_mean), seed=seed(), dtype=tf.keras.backend.floatx()) + epsilon = tf.random.normal(shape = tf.shape(z_mean), dtype=tf.keras.backend.floatx()) + z = z_mean + tf.exp(0.5 * z_log_var) * epsilon + z = tf.clip_by_value(z, -1e6, 1e6) + return z</code></pre> +</details> +<h3>Ancestors</h3> +<ul class="hlist"> +<li>keras.src.engine.base_layer.Layer</li> +<li>tensorflow.python.module.module.Module</li> +<li>tensorflow.python.trackable.autotrackable.AutoTrackable</li> +<li>tensorflow.python.trackable.base.Trackable</li> +<li>keras.src.utils.version_utils.LayerVersionSelector</li> +</ul> +<h3>Methods</h3> +<dl> +<dt id="VITAE.model.Sampling.call"><code class="name flex"> +<span>def <span class="ident">call</span></span>(<span>self, z_mean, z_log_var)</span> +</code></dt> +<dd> +<div class="desc"><p>Return cdf(x) and pdf(x).</p> +<h2 id="parameters">Parameters</h2> +<dl> +<dt><strong><code>z_mean</code></strong> : <code>tf.Tensor</code></dt> +<dd><span><span class="MathJax_Preview">[B, L, d]</span><script type="math/tex">[B, L, d]</script></span> The mean of <span><span class="MathJax_Preview">z</span><script type="math/tex">z</script></span>.</dd> +<dt><strong><code>z_log_var</code></strong> : <code>tf.Tensor</code></dt> +<dd><span><span class="MathJax_Preview">[B, L, d]</span><script type="math/tex">[B, L, d]</script></span> The log-variance of <span><span class="MathJax_Preview">z</span><script type="math/tex">z</script></span>.</dd> +</dl> +<h2 id="returns">Returns</h2> +<dl> +<dt><strong><code>z</code></strong> : <code>tf.Tensor</code></dt> +<dd><span><span class="MathJax_Preview">[B, L, d]</span><script type="math/tex">[B, L, d]</script></span> The sampled <span><span class="MathJax_Preview">z</span><script type="math/tex">z</script></span>.</dd> +</dl></div> +</dd> +</dl> +</dd> +<dt id="VITAE.model.Encoder"><code class="flex name class"> +<span>class <span class="ident">Encoder</span></span> +<span>(</span><span>dimensions, dim_latent, name='encoder', **kwargs)</span> +</code></dt> +<dd> +<div class="desc"><p>Encoder, model <span><span class="MathJax_Preview">p(Z_i|Y_i,X_i)</span><script type="math/tex">p(Z_i|Y_i,X_i)</script></span>.</p> +<h2 id="parameters">Parameters</h2> +<dl> +<dt><strong><code>dimensions</code></strong> : <code>np.array</code></dt> +<dd>The dimensions of hidden layers of the encoder.</dd> +<dt><strong><code>dim_latent</code></strong> : <code>int</code></dt> +<dd>The latent dimension of the encoder.</dd> +<dt><strong><code>name</code></strong> : <code>str</code>, optional</dt> +<dd>The name of the layer.</dd> +<dt><strong><code>**kwargs</code></strong></dt> +<dd>Extra keyword arguments.</dd> +</dl></div> +<details class="source"> +<summary> +<span>Expand source code</span> +</summary> +<pre><code class="python">class Encoder(Layer): + ''' + Encoder, model \(p(Z_i|Y_i,X_i)\). + ''' + def __init__(self, dimensions, dim_latent, name='encoder', **kwargs): + ''' + Parameters + ---------- + dimensions : np.array + The dimensions of hidden layers of the encoder. + dim_latent : int + The latent dimension of the encoder. + name : str, optional + The name of the layer. + **kwargs : + Extra keyword arguments. + ''' + super(Encoder, self).__init__(name = name, **kwargs) + self.dense_layers = [Dense(dim, activation = tf.nn.leaky_relu, + name = 'encoder_%i'%(i+1)) \ + for (i, dim) in enumerate(dimensions)] + self.batch_norm_layers = [BatchNormalization(center=False) \ + for _ in range(len((dimensions)))] + self.batch_norm_layers.append(BatchNormalization(center=False)) + self.latent_mean = Dense(dim_latent, name = 'latent_mean') + self.latent_log_var = Dense(dim_latent, name = 'latent_log_var') + self.sampling = Sampling() + + @tf.function + def call(self, x, L=1, is_training=True): + '''Encode the inputs and get the latent variables. + + Parameters + ---------- + x : tf.Tensor + \([B, L, d]\) The input. + L : int, optional + The number of MC samples. + is_training : boolean, optional + Whether in the training or inference mode. + + Returns + ---------- + z_mean : tf.Tensor + \([B, L, d]\) The mean of \(z\). + z_log_var : tf.Tensor + \([B, L, d]\) The log-variance of \(z\). + z : tf.Tensor + \([B, L, d]\) The sampled \(z\). + ''' + for dense, bn in zip(self.dense_layers, self.batch_norm_layers): + x = dense(x) + x = bn(x, training=is_training) + z_mean = self.batch_norm_layers[-1](self.latent_mean(x), training=is_training) + z_log_var = self.latent_log_var(x) + _z_mean = tf.tile(tf.expand_dims(z_mean, 1), (1,L,1)) + _z_log_var = tf.tile(tf.expand_dims(z_log_var, 1), (1,L,1)) + z = self.sampling(_z_mean, _z_log_var) + return z_mean, z_log_var, z</code></pre> +</details> +<h3>Ancestors</h3> +<ul class="hlist"> +<li>keras.src.engine.base_layer.Layer</li> +<li>tensorflow.python.module.module.Module</li> +<li>tensorflow.python.trackable.autotrackable.AutoTrackable</li> +<li>tensorflow.python.trackable.base.Trackable</li> +<li>keras.src.utils.version_utils.LayerVersionSelector</li> +</ul> +<h3>Methods</h3> +<dl> +<dt id="VITAE.model.Encoder.call"><code class="name flex"> +<span>def <span class="ident">call</span></span>(<span>self, x, L=1, is_training=True)</span> +</code></dt> +<dd> +<div class="desc"><p>Encode the inputs and get the latent variables.</p> +<h2 id="parameters">Parameters</h2> +<dl> +<dt><strong><code>x</code></strong> : <code>tf.Tensor</code></dt> +<dd><span><span class="MathJax_Preview">[B, L, d]</span><script type="math/tex">[B, L, d]</script></span> The input.</dd> +<dt><strong><code>L</code></strong> : <code>int</code>, optional</dt> +<dd>The number of MC samples.</dd> +<dt><strong><code>is_training</code></strong> : <code>boolean</code>, optional</dt> +<dd>Whether in the training or inference mode.</dd> +</dl> +<h2 id="returns">Returns</h2> +<dl> +<dt><strong><code>z_mean</code></strong> : <code>tf.Tensor</code></dt> +<dd><span><span class="MathJax_Preview">[B, L, d]</span><script type="math/tex">[B, L, d]</script></span> The mean of <span><span class="MathJax_Preview">z</span><script type="math/tex">z</script></span>.</dd> +<dt><strong><code>z_log_var</code></strong> : <code>tf.Tensor</code></dt> +<dd><span><span class="MathJax_Preview">[B, L, d]</span><script type="math/tex">[B, L, d]</script></span> The log-variance of <span><span class="MathJax_Preview">z</span><script type="math/tex">z</script></span>.</dd> +<dt><strong><code>z</code></strong> : <code>tf.Tensor</code></dt> +<dd><span><span class="MathJax_Preview">[B, L, d]</span><script type="math/tex">[B, L, d]</script></span> The sampled <span><span class="MathJax_Preview">z</span><script type="math/tex">z</script></span>.</dd> +</dl></div> +</dd> +</dl> +</dd> +<dt id="VITAE.model.Decoder"><code class="flex name class"> +<span>class <span class="ident">Decoder</span></span> +<span>(</span><span>dimensions, dim_origin, data_type='UMI', name='decoder', **kwargs)</span> +</code></dt> +<dd> +<div class="desc"><p>Decoder, model <span><span class="MathJax_Preview">p(Y_i|Z_i,X_i)</span><script type="math/tex">p(Y_i|Z_i,X_i)</script></span>.</p> +<h2 id="parameters">Parameters</h2> +<dl> +<dt><strong><code>dimensions</code></strong> : <code>np.array</code></dt> +<dd>The dimensions of hidden layers of the encoder.</dd> +<dt><strong><code>dim_origin</code></strong> : <code>int</code></dt> +<dd>The output dimension of the decoder.</dd> +<dt><strong><code>data_type</code></strong> : <code>str</code>, optional</dt> +<dd><code>'UMI'</code>, <code>'non-UMI'</code>, or <code>'Gaussian'</code>.</dd> +<dt><strong><code>name</code></strong> : <code>str</code>, optional</dt> +<dd>The name of the layer.</dd> +</dl></div> +<details class="source"> +<summary> +<span>Expand source code</span> +</summary> +<pre><code class="python">class Decoder(Layer): + ''' + Decoder, model \(p(Y_i|Z_i,X_i)\). + ''' + def __init__(self, dimensions, dim_origin, data_type = 'UMI', + name = 'decoder', **kwargs): + ''' + Parameters + ---------- + dimensions : np.array + The dimensions of hidden layers of the encoder. + dim_origin : int + The output dimension of the decoder. + data_type : str, optional + `'UMI'`, `'non-UMI'`, or `'Gaussian'`. + name : str, optional + The name of the layer. + ''' + super(Decoder, self).__init__(name = name, **kwargs) + self.data_type = data_type + self.dense_layers = [Dense(dim, activation = tf.nn.leaky_relu, + name = 'decoder_%i'%(i+1)) \ + for (i,dim) in enumerate(dimensions)] + self.batch_norm_layers = [BatchNormalization(center=False) \ + for _ in range(len((dimensions)))] + + if data_type=='Gaussian': + self.nu_z = Dense(dim_origin, name = 'nu_z') + # common variance + self.log_tau = tf.Variable(tf.zeros([1, dim_origin], dtype=tf.keras.backend.floatx()), + constraint = lambda t: tf.clip_by_value(t,-30.,6.), + name = "log_tau") + else: + self.log_lambda_z = Dense(dim_origin, name = 'log_lambda_z') + + # dispersion parameter + self.log_r = tf.Variable(tf.zeros([1, dim_origin], dtype=tf.keras.backend.floatx()), + constraint = lambda t: tf.clip_by_value(t,-30.,6.), + name = "log_r") + + if self.data_type == 'non-UMI': + self.phi = Dense(dim_origin, activation = 'sigmoid', name = "phi") + + @tf.function + def call(self, z, is_training=True): + '''Decode the latent variables and get the reconstructions. + + Parameters + ---------- + z : tf.Tensor + \([B, L, d]\) the sampled \(z\). + is_training : boolean, optional + whether in the training or inference mode. + + When `data_type=='Gaussian'`: + + Returns + ---------- + nu_z : tf.Tensor + \([B, L, G]\) The mean of \(Y_i|Z_i,X_i\). + tau : tf.Tensor + \([1, G]\) The variance of \(Y_i|Z_i,X_i\). + + When `data_type=='UMI'`: + + Returns + ---------- + lambda_z : tf.Tensor + \([B, L, G]\) The mean of \(Y_i|Z_i,X_i\). + r : tf.Tensor + \([1, G]\) The dispersion parameters of \(Y_i|Z_i,X_i\). + + When `data_type=='non-UMI'`: + + Returns + ---------- + lambda_z : tf.Tensor + \([B, L, G]\) The mean of \(Y_i|Z_i,X_i\). + r : tf.Tensor + \([1, G]\) The dispersion parameters of \(Y_i|Z_i,X_i\). + phi_z : tf.Tensor + \([1, G]\) The zero inflated parameters of \(Y_i|Z_i,X_i\). + ''' + for dense, bn in zip(self.dense_layers, self.batch_norm_layers): + z = dense(z) + z = bn(z, training=is_training) + if self.data_type=='Gaussian': + nu_z = self.nu_z(z) + tau = tf.exp(self.log_tau) + return nu_z, tau + else: + lambda_z = tf.math.exp( + tf.clip_by_value(self.log_lambda_z(z), -30., 6.) + ) + r = tf.exp(self.log_r) + if self.data_type=='UMI': + return lambda_z, r + else: + return lambda_z, r, self.phi(z)</code></pre> +</details> +<h3>Ancestors</h3> +<ul class="hlist"> +<li>keras.src.engine.base_layer.Layer</li> +<li>tensorflow.python.module.module.Module</li> +<li>tensorflow.python.trackable.autotrackable.AutoTrackable</li> +<li>tensorflow.python.trackable.base.Trackable</li> +<li>keras.src.utils.version_utils.LayerVersionSelector</li> +</ul> +<h3>Methods</h3> +<dl> +<dt id="VITAE.model.Decoder.call"><code class="name flex"> +<span>def <span class="ident">call</span></span>(<span>self, z, is_training=True)</span> +</code></dt> +<dd> +<div class="desc"><p>Decode the latent variables and get the reconstructions.</p> +<h2 id="parameters">Parameters</h2> +<dl> +<dt><strong><code>z</code></strong> : <code>tf.Tensor</code></dt> +<dd><span><span class="MathJax_Preview">[B, L, d]</span><script type="math/tex">[B, L, d]</script></span> the sampled <span><span class="MathJax_Preview">z</span><script type="math/tex">z</script></span>.</dd> +<dt><strong><code>is_training</code></strong> : <code>boolean</code>, optional</dt> +<dd>whether in the training or inference mode.</dd> +</dl> +<p>When <code>data_type=='Gaussian'</code>:</p> +<h2 id="returns">Returns</h2> +<dl> +<dt><strong><code>nu_z</code></strong> : <code>tf.Tensor</code></dt> +<dd><span><span class="MathJax_Preview">[B, L, G]</span><script type="math/tex">[B, L, G]</script></span> The mean of <span><span class="MathJax_Preview">Y_i|Z_i,X_i</span><script type="math/tex">Y_i|Z_i,X_i</script></span>.</dd> +<dt><strong><code>tau</code></strong> : <code>tf.Tensor</code></dt> +<dd><span><span class="MathJax_Preview">[1, G]</span><script type="math/tex">[1, G]</script></span> The variance of <span><span class="MathJax_Preview">Y_i|Z_i,X_i</span><script type="math/tex">Y_i|Z_i,X_i</script></span>.</dd> +</dl> +<p>When <code>data_type=='UMI'</code>:</p> +<h2 id="returns_1">Returns</h2> +<dl> +<dt><strong><code>lambda_z</code></strong> : <code>tf.Tensor</code></dt> +<dd><span><span class="MathJax_Preview">[B, L, G]</span><script type="math/tex">[B, L, G]</script></span> The mean of <span><span class="MathJax_Preview">Y_i|Z_i,X_i</span><script type="math/tex">Y_i|Z_i,X_i</script></span>.</dd> +<dt><strong><code>r</code></strong> : <code>tf.Tensor</code></dt> +<dd><span><span class="MathJax_Preview">[1, G]</span><script type="math/tex">[1, G]</script></span> The dispersion parameters of <span><span class="MathJax_Preview">Y_i|Z_i,X_i</span><script type="math/tex">Y_i|Z_i,X_i</script></span>.</dd> +</dl> +<p>When <code>data_type=='non-UMI'</code>:</p> +<h2 id="returns_2">Returns</h2> +<dl> +<dt><strong><code>lambda_z</code></strong> : <code>tf.Tensor</code></dt> +<dd><span><span class="MathJax_Preview">[B, L, G]</span><script type="math/tex">[B, L, G]</script></span> The mean of <span><span class="MathJax_Preview">Y_i|Z_i,X_i</span><script type="math/tex">Y_i|Z_i,X_i</script></span>.</dd> +<dt><strong><code>r</code></strong> : <code>tf.Tensor</code></dt> +<dd><span><span class="MathJax_Preview">[1, G]</span><script type="math/tex">[1, G]</script></span> The dispersion parameters of <span><span class="MathJax_Preview">Y_i|Z_i,X_i</span><script type="math/tex">Y_i|Z_i,X_i</script></span>.</dd> +<dt><strong><code>phi_z</code></strong> : <code>tf.Tensor</code></dt> +<dd><span><span class="MathJax_Preview">[1, G]</span><script type="math/tex">[1, G]</script></span> The zero inflated parameters of <span><span class="MathJax_Preview">Y_i|Z_i,X_i</span><script type="math/tex">Y_i|Z_i,X_i</script></span>.</dd> +</dl></div> +</dd> +</dl> +</dd> +<dt id="VITAE.model.LatentSpace"><code class="flex name class"> +<span>class <span class="ident">LatentSpace</span></span> +<span>(</span><span>n_clusters, dim_latent, name='LatentSpace', seed=0, **kwargs)</span> +</code></dt> +<dd> +<div class="desc"><p>Layer for the Latent Space.</p> +<h2 id="parameters">Parameters</h2> +<dl> +<dt><strong><code>n_clusters</code></strong> : <code>int</code></dt> +<dd>The number of vertices in the latent space.</dd> +<dt><strong><code>dim_latent</code></strong> : <code>int</code></dt> +<dd>The latent dimension.</dd> +<dt><strong><code>M</code></strong> : <code>int</code>, optional</dt> +<dd>The discretized number of uniform(0,1).</dd> +<dt><strong><code>name</code></strong> : <code>str</code>, optional</dt> +<dd>The name of the layer.</dd> +<dt><strong><code>**kwargs</code></strong></dt> +<dd>Extra keyword arguments.</dd> +</dl></div> +<details class="source"> +<summary> +<span>Expand source code</span> +</summary> +<pre><code class="python">class LatentSpace(Layer): + ''' + Layer for the Latent Space. + ''' + def __init__(self, n_clusters, dim_latent, + name = 'LatentSpace', seed=0, **kwargs): + ''' + Parameters + ---------- + n_clusters : int + The number of vertices in the latent space. + dim_latent : int + The latent dimension. + M : int, optional + The discretized number of uniform(0,1). + name : str, optional + The name of the layer. + **kwargs : + Extra keyword arguments. + ''' + super(LatentSpace, self).__init__(name=name, **kwargs) + self.dim_latent = dim_latent + self.n_states = n_clusters + self.n_categories = int(n_clusters*(n_clusters+1)/2) + + # nonzero indexes + # A = [0,0,...,0 , 1,1,...,1, ...] + # B = [0,1,...,k-1, 1,2,...,k-1, ...] + self.A, self.B = np.nonzero(np.triu(np.ones(n_clusters))) + self.A = tf.convert_to_tensor(self.A, tf.int32) + self.B = tf.convert_to_tensor(self.B, tf.int32) + self.clusters_ind = tf.boolean_mask( + tf.range(0,self.n_categories,1), self.A==self.B) + + # [pi_1, ... , pi_K] in R^(n_categories) + self.pi = tf.Variable(tf.ones([1, self.n_categories], dtype=tf.keras.backend.floatx()) / self.n_categories, + name = 'pi') + + # [mu_1, ... , mu_K] in R^(dim_latent * n_clusters) + self.mu = tf.Variable(tf.random.uniform([self.dim_latent, self.n_states], + minval = -1, maxval = 1, seed=seed, dtype=tf.keras.backend.floatx()), + name = 'mu') + self.cdf_layer = cdf_layer() + + def initialize(self, mu, log_pi): + '''Initialize the latent space. + + Parameters + ---------- + mu : np.array + \([d, k]\) The position matrix. + log_pi : np.array + \([1, K]\) \(\\log\\pi\). + ''' + # Initialize parameters of the latent space + if mu is not None: + self.mu.assign(mu) + if log_pi is not None: + self.pi.assign(log_pi) + + def normalize(self): + '''Normalize \(\\pi\). + ''' + self.pi = tf.nn.softmax(self.pi) + + @tf.function + def _get_normal_params(self, z, pi): + batch_size = tf.shape(z)[0] + L = tf.shape(z)[1] + + # [batch_size, L, n_categories] + if pi is None: + # [batch_size, L, n_states] + temp_pi = tf.tile( + tf.expand_dims(tf.nn.softmax(self.pi), 1), + (batch_size,L,1)) + else: + temp_pi = tf.expand_dims(tf.nn.softmax(pi), 1) + + # [batch_size, L, d, n_categories] + alpha_zc = tf.expand_dims(tf.expand_dims( + tf.gather(self.mu, self.B, axis=1) - tf.gather(self.mu, self.A, axis=1), 0), 0) + beta_zc = tf.expand_dims(z,-1) - \ + tf.expand_dims(tf.expand_dims( + tf.gather(self.mu, self.B, axis=1), 0), 0) + + # [batch_size, L, n_categories] + inv_sig = tf.reduce_sum(alpha_zc * alpha_zc, axis=2) + nu = - tf.reduce_sum(alpha_zc * beta_zc, axis=2) * tf.math.reciprocal_no_nan(inv_sig) + _t = - tf.reduce_sum(beta_zc * beta_zc, axis=2) + nu**2*inv_sig + return temp_pi, beta_zc, inv_sig, nu, _t + + @tf.function + def _get_pz(self, temp_pi, inv_sig, beta_zc, log_p_z_c_L): + # [batch_size, L, n_categories] + log_p_zc_L = - 0.5 * self.dim_latent * tf.math.log(tf.constant(2 * np.pi, tf.keras.backend.floatx())) + \ + tf.math.log(temp_pi) + \ + tf.where(inv_sig==0, + - 0.5 * tf.reduce_sum(beta_zc**2, axis=2), + log_p_z_c_L) + + # [batch_size, L, 1] + log_p_z_L = tf.reduce_logsumexp(log_p_zc_L, axis=-1, keepdims=True) + + # [1, ] + log_p_z = tf.reduce_mean(log_p_z_L) + return log_p_zc_L, log_p_z_L, log_p_z + + @tf.function + def _get_posterior_c(self, log_p_zc_L, log_p_z_L): + L = tf.shape(log_p_z_L)[1] + + # log_p_c_x - predicted probability distribution + # [batch_size, n_categories] + log_p_c_x = tf.reduce_logsumexp( + log_p_zc_L - log_p_z_L, + axis=1) - tf.math.log(tf.cast(L, tf.keras.backend.floatx())) + return log_p_c_x + + @tf.function + def _get_inference(self, z, log_p_z_L, temp_pi, inv_sig, nu, beta_zc, log_eta0, eta1, eta2): + batch_size = tf.shape(z)[0] + L = tf.shape(z)[1] + dist = tfp.distributions.Normal( + loc = tf.constant(0.0, tf.keras.backend.floatx()), + scale = tf.constant(1.0, tf.keras.backend.floatx()), + allow_nan_stats=False) + + # [batch_size, L, n_categories, n_clusters] + inv_sig = tf.expand_dims(inv_sig, -1) + _sig = tf.tile(tf.clip_by_value(tf.math.reciprocal_no_nan(inv_sig), 1e-12, 1e30), (1,1,1,self.n_states)) + log_eta0 = tf.tile(tf.expand_dims(log_eta0, -1), (1,1,1,self.n_states)) + eta1 = tf.tile(tf.expand_dims(eta1, -1), (1,1,1,self.n_states)) + eta2 = tf.tile(tf.expand_dims(eta2, -1), (1,1,1,self.n_states)) + nu = tf.tile(tf.expand_dims(nu, -1), (1,1,1,1)) + A = tf.tile(tf.expand_dims(tf.expand_dims( + tf.one_hot(self.A, self.n_states, dtype=tf.keras.backend.floatx()), + 0),0), (batch_size,L,1,1)) + B = tf.tile(tf.expand_dims(tf.expand_dims( + tf.one_hot(self.B, self.n_states, dtype=tf.keras.backend.floatx()), + 0),0), (batch_size,L,1,1)) + temp_pi = tf.expand_dims(temp_pi, -1) + + # w_tilde [batch_size, L, n_clusters] + w_tilde = log_eta0 + tf.math.log( + tf.clip_by_value( + (dist.cdf(eta1) - dist.cdf(eta2)) * (nu * A + (1-nu) * B) - + (dist.prob(eta1) - dist.prob(eta2)) * tf.math.sqrt(_sig) * (A - B), + 0.0, 1e30) + ) + w_tilde = - 0.5 * self.dim_latent * tf.math.log(tf.constant(2 * np.pi, tf.keras.backend.floatx())) + \ + tf.math.log(temp_pi) + \ + tf.where(inv_sig==0, + tf.where(B==1, - 0.5 * tf.expand_dims(tf.reduce_sum(beta_zc**2, axis=2), -1), -np.inf), + w_tilde) + w_tilde = tf.exp(tf.reduce_logsumexp(w_tilde, 2) - log_p_z_L) + + # tf.debugging.assert_greater_equal( + # tf.reduce_sum(w_tilde, -1), tf.ones([batch_size, L], dtype=tf.keras.backend.floatx())*0.99, + # message='Wrong w_tilde', summarize=None, name=None + # ) + + # var_w_tilde [batch_size, L, n_clusters] + var_w_tilde = log_eta0 + tf.math.log( + tf.clip_by_value( + (dist.cdf(eta1) - dist.cdf(eta2)) * ((_sig + nu**2) * (A+B) + (1-2*nu) * B) - + (dist.prob(eta1) - dist.prob(eta2)) * tf.math.sqrt(_sig) * (nu *(A+B)-B )*2 - + (eta1*dist.prob(eta1) - eta2*dist.prob(eta2)) * _sig *(A+B), + 0.0, 1e30) + ) + var_w_tilde = - 0.5 * self.dim_latent * tf.math.log(tf.constant(2 * np.pi, tf.keras.backend.floatx())) + \ + tf.math.log(temp_pi) + \ + tf.where(inv_sig==0, + tf.where(B==1, - 0.5 * tf.expand_dims(tf.reduce_sum(beta_zc**2, axis=2), -1), -np.inf), + var_w_tilde) + var_w_tilde = tf.exp(tf.reduce_logsumexp(var_w_tilde, 2) - log_p_z_L) - w_tilde**2 + + + w_tilde = tf.reduce_mean(w_tilde, 1) + var_w_tilde = tf.reduce_mean(var_w_tilde, 1) + return w_tilde, var_w_tilde + + def get_pz(self, z, eps, pi): + '''Get \(\\log p(Z_i|Y_i,X_i)\). + + Parameters + ---------- + z : tf.Tensor + \([B, L, d]\) The latent variables. + + Returns + ---------- + temp_pi : tf.Tensor + \([B, L, K]\) \(\\pi\). + inv_sig : tf.Tensor + \([B, L, K]\) \(\\sigma_{Z_ic_i}^{-1}\). + nu : tf.Tensor + \([B, L, K]\) \(\\nu_{Z_ic_i}\). + beta_zc : tf.Tensor + \([B, L, d, K]\) \(\\beta_{Z_ic_i}\). + log_eta0 : tf.Tensor + \([B, L, K]\) \(\\log\\eta_{Z_ic_i,0}\). + eta1 : tf.Tensor + \([B, L, K]\) \(\\eta_{Z_ic_i,1}\). + eta2 : tf.Tensor + \([B, L, K]\) \(\\eta_{Z_ic_i,2}\). + log_p_zc_L : tf.Tensor + \([B, L, K]\) \(\\log p(Z_i,c_i|Y_i,X_i)\). + log_p_z_L : tf.Tensor + \([B, L]\) \(\\log p(Z_i|Y_i,X_i)\). + log_p_z : tf.Tensor + \([B, 1]\) The estimated \(\\log p(Z_i|Y_i,X_i)\). + ''' + temp_pi, beta_zc, inv_sig, nu, _t = self._get_normal_params(z, pi) + temp_pi = tf.clip_by_value(temp_pi, eps, 1.0) + + log_eta0 = 0.5 * (tf.math.log(tf.constant(2 * np.pi, tf.keras.backend.floatx())) - \ + tf.math.log(tf.clip_by_value(inv_sig, 1e-12, 1e30)) + _t) + eta1 = (1-nu) * tf.math.sqrt(tf.clip_by_value(inv_sig, 1e-12, 1e30)) + eta2 = -nu * tf.math.sqrt(tf.clip_by_value(inv_sig, 1e-12, 1e30)) + + log_p_z_c_L = log_eta0 + tf.math.log(tf.clip_by_value( + self.cdf_layer(eta1) - self.cdf_layer(eta2), + eps, 1e30)) + + log_p_zc_L, log_p_z_L, log_p_z = self._get_pz(temp_pi, inv_sig, beta_zc, log_p_z_c_L) + return temp_pi, inv_sig, nu, beta_zc, log_eta0, eta1, eta2, log_p_zc_L, log_p_z_L, log_p_z + + def get_posterior_c(self, z): + '''Get \(p(c_i|Y_i,X_i)\). + + Parameters + ---------- + z : tf.Tensor + \([B, L, d]\) The latent variables. + + Returns + ---------- + p_c_x : np.array + \([B, K]\) \(p(c_i|Y_i,X_i)\). + ''' + _,_,_,_,_,_,_, log_p_zc_L, log_p_z_L, _ = self.get_pz(z) + log_p_c_x = self._get_posterior_c(log_p_zc_L, log_p_z_L) + p_c_x = tf.exp(log_p_c_x).numpy() + return p_c_x + + def call(self, z, pi=None, inference=False): + '''Get posterior estimations. + + Parameters + ---------- + z : tf.Tensor + \([B, L, d]\) The latent variables. + inference : boolean + Whether in training or inference mode. + + When `inference=False`: + + Returns + ---------- + log_p_z_L : tf.Tensor + \([B, 1]\) The estimated \(\\log p(Z_i|Y_i,X_i)\). + + When `inference=True`: + + Returns + ---------- + res : dict + The dict of posterior estimations - \(p(c_i|Y_i,X_i)\), \(c\), \(E(\\tilde{w}_i|Y_i,X_i)\), \(Var(\\tilde{w}_i|Y_i,X_i)\), \(D_{JS}\). + ''' + eps = 1e-16 if not inference else 0. + temp_pi, inv_sig, nu, beta_zc, log_eta0, eta1, eta2, log_p_zc_L, log_p_z_L, log_p_z = self.get_pz(z, eps, pi) + + if not inference: + return log_p_z + else: + log_p_c_x = self._get_posterior_c(log_p_zc_L, log_p_z_L) + w_tilde, var_w_tilde = self._get_inference(z, log_p_z_L, temp_pi, inv_sig, nu, beta_zc, log_eta0, eta1, eta2) + + res = {} + res['p_c_x'] = tf.exp(log_p_c_x).numpy() + res['w_tilde'] = w_tilde.numpy() + res['var_w_tilde'] = var_w_tilde.numpy() + return res</code></pre> +</details> +<h3>Ancestors</h3> +<ul class="hlist"> +<li>keras.src.engine.base_layer.Layer</li> +<li>tensorflow.python.module.module.Module</li> +<li>tensorflow.python.trackable.autotrackable.AutoTrackable</li> +<li>tensorflow.python.trackable.base.Trackable</li> +<li>keras.src.utils.version_utils.LayerVersionSelector</li> +</ul> +<h3>Methods</h3> +<dl> +<dt id="VITAE.model.LatentSpace.initialize"><code class="name flex"> +<span>def <span class="ident">initialize</span></span>(<span>self, mu, log_pi)</span> +</code></dt> +<dd> +<div class="desc"><p>Initialize the latent space.</p> +<h2 id="parameters">Parameters</h2> +<dl> +<dt><strong><code>mu</code></strong> : <code>np.array</code></dt> +<dd><span><span class="MathJax_Preview">[d, k]</span><script type="math/tex">[d, k]</script></span> The position matrix.</dd> +<dt><strong><code>log_pi</code></strong> : <code>np.array</code></dt> +<dd><span><span class="MathJax_Preview">[1, K]</span><script type="math/tex">[1, K]</script></span> <span><span class="MathJax_Preview">\log\pi</span><script type="math/tex">\log\pi</script></span>.</dd> +</dl></div> +</dd> +<dt id="VITAE.model.LatentSpace.normalize"><code class="name flex"> +<span>def <span class="ident">normalize</span></span>(<span>self)</span> +</code></dt> +<dd> +<div class="desc"><p>Normalize <span><span class="MathJax_Preview">\pi</span><script type="math/tex">\pi</script></span>.</p></div> +</dd> +<dt id="VITAE.model.LatentSpace.get_pz"><code class="name flex"> +<span>def <span class="ident">get_pz</span></span>(<span>self, z, eps, pi)</span> +</code></dt> +<dd> +<div class="desc"><p>Get <span><span class="MathJax_Preview">\log p(Z_i|Y_i,X_i)</span><script type="math/tex">\log p(Z_i|Y_i,X_i)</script></span>.</p> +<h2 id="parameters">Parameters</h2> +<dl> +<dt><strong><code>z</code></strong> : <code>tf.Tensor</code></dt> +<dd><span><span class="MathJax_Preview">[B, L, d]</span><script type="math/tex">[B, L, d]</script></span> The latent variables.</dd> +</dl> +<h2 id="returns">Returns</h2> +<dl> +<dt><strong><code>temp_pi</code></strong> : <code>tf.Tensor</code></dt> +<dd><span><span class="MathJax_Preview">[B, L, K]</span><script type="math/tex">[B, L, K]</script></span> <span><span class="MathJax_Preview">\pi</span><script type="math/tex">\pi</script></span>.</dd> +<dt><strong><code>inv_sig</code></strong> : <code>tf.Tensor</code></dt> +<dd><span><span class="MathJax_Preview">[B, L, K]</span><script type="math/tex">[B, L, K]</script></span> <span><span class="MathJax_Preview">\sigma_{Z_ic_i}^{-1}</span><script type="math/tex">\sigma_{Z_ic_i}^{-1}</script></span>.</dd> +<dt><strong><code>nu</code></strong> : <code>tf.Tensor</code></dt> +<dd><span><span class="MathJax_Preview">[B, L, K]</span><script type="math/tex">[B, L, K]</script></span> <span><span class="MathJax_Preview">\nu_{Z_ic_i}</span><script type="math/tex">\nu_{Z_ic_i}</script></span>.</dd> +<dt><strong><code>beta_zc</code></strong> : <code>tf.Tensor</code></dt> +<dd><span><span class="MathJax_Preview">[B, L, d, K]</span><script type="math/tex">[B, L, d, K]</script></span> <span><span class="MathJax_Preview">\beta_{Z_ic_i}</span><script type="math/tex">\beta_{Z_ic_i}</script></span>.</dd> +<dt><strong><code>log_eta0</code></strong> : <code>tf.Tensor</code></dt> +<dd><span><span class="MathJax_Preview">[B, L, K]</span><script type="math/tex">[B, L, K]</script></span> <span><span class="MathJax_Preview">\log\eta_{Z_ic_i,0}</span><script type="math/tex">\log\eta_{Z_ic_i,0}</script></span>.</dd> +<dt><strong><code>eta1</code></strong> : <code>tf.Tensor</code></dt> +<dd><span><span class="MathJax_Preview">[B, L, K]</span><script type="math/tex">[B, L, K]</script></span> <span><span class="MathJax_Preview">\eta_{Z_ic_i,1}</span><script type="math/tex">\eta_{Z_ic_i,1}</script></span>.</dd> +<dt><strong><code>eta2</code></strong> : <code>tf.Tensor</code></dt> +<dd><span><span class="MathJax_Preview">[B, L, K]</span><script type="math/tex">[B, L, K]</script></span> <span><span class="MathJax_Preview">\eta_{Z_ic_i,2}</span><script type="math/tex">\eta_{Z_ic_i,2}</script></span>.</dd> +<dt><strong><code>log_p_zc_L</code></strong> : <code>tf.Tensor</code></dt> +<dd><span><span class="MathJax_Preview">[B, L, K]</span><script type="math/tex">[B, L, K]</script></span> <span><span class="MathJax_Preview">\log p(Z_i,c_i|Y_i,X_i)</span><script type="math/tex">\log p(Z_i,c_i|Y_i,X_i)</script></span>.</dd> +<dt><strong><code>log_p_z_L</code></strong> : <code>tf.Tensor</code></dt> +<dd><span><span class="MathJax_Preview">[B, L]</span><script type="math/tex">[B, L]</script></span> <span><span class="MathJax_Preview">\log p(Z_i|Y_i,X_i)</span><script type="math/tex">\log p(Z_i|Y_i,X_i)</script></span>.</dd> +<dt><strong><code>log_p_z</code></strong> : <code>tf.Tensor</code></dt> +<dd><span><span class="MathJax_Preview">[B, 1]</span><script type="math/tex">[B, 1]</script></span> The estimated <span><span class="MathJax_Preview">\log p(Z_i|Y_i,X_i)</span><script type="math/tex">\log p(Z_i|Y_i,X_i)</script></span>.</dd> +</dl></div> +</dd> +<dt id="VITAE.model.LatentSpace.get_posterior_c"><code class="name flex"> +<span>def <span class="ident">get_posterior_c</span></span>(<span>self, z)</span> +</code></dt> +<dd> +<div class="desc"><p>Get <span><span class="MathJax_Preview">p(c_i|Y_i,X_i)</span><script type="math/tex">p(c_i|Y_i,X_i)</script></span>.</p> +<h2 id="parameters">Parameters</h2> +<dl> +<dt><strong><code>z</code></strong> : <code>tf.Tensor</code></dt> +<dd><span><span class="MathJax_Preview">[B, L, d]</span><script type="math/tex">[B, L, d]</script></span> The latent variables.</dd> +</dl> +<h2 id="returns">Returns</h2> +<dl> +<dt><strong><code>p_c_x</code></strong> : <code>np.array</code></dt> +<dd><span><span class="MathJax_Preview">[B, K]</span><script type="math/tex">[B, K]</script></span> <span><span class="MathJax_Preview">p(c_i|Y_i,X_i)</span><script type="math/tex">p(c_i|Y_i,X_i)</script></span>.</dd> +</dl></div> +</dd> +<dt id="VITAE.model.LatentSpace.call"><code class="name flex"> +<span>def <span class="ident">call</span></span>(<span>self, z, pi=None, inference=False)</span> +</code></dt> +<dd> +<div class="desc"><p>Get posterior estimations.</p> +<h2 id="parameters">Parameters</h2> +<dl> +<dt><strong><code>z</code></strong> : <code>tf.Tensor</code></dt> +<dd><span><span class="MathJax_Preview">[B, L, d]</span><script type="math/tex">[B, L, d]</script></span> The latent variables.</dd> +<dt><strong><code>inference</code></strong> : <code>boolean</code></dt> +<dd>Whether in training or inference mode.</dd> +</dl> +<p>When <code>inference=False</code>:</p> +<h2 id="returns">Returns</h2> +<dl> +<dt><strong><code>log_p_z_L</code></strong> : <code>tf.Tensor</code></dt> +<dd><span><span class="MathJax_Preview">[B, 1]</span><script type="math/tex">[B, 1]</script></span> The estimated <span><span class="MathJax_Preview">\log p(Z_i|Y_i,X_i)</span><script type="math/tex">\log p(Z_i|Y_i,X_i)</script></span>.</dd> +</dl> +<p>When <code>inference=True</code>:</p> +<h2 id="returns_1">Returns</h2> +<dl> +<dt><strong><code>res</code></strong> : <code>dict</code></dt> +<dd>The dict of posterior estimations - <span><span class="MathJax_Preview">p(c_i|Y_i,X_i)</span><script type="math/tex">p(c_i|Y_i,X_i)</script></span>, <span><span class="MathJax_Preview">c</span><script type="math/tex">c</script></span>, <span><span class="MathJax_Preview">E(\tilde{w}_i|Y_i,X_i)</span><script type="math/tex">E(\tilde{w}_i|Y_i,X_i)</script></span>, <span><span class="MathJax_Preview">Var(\tilde{w}_i|Y_i,X_i)</span><script type="math/tex">Var(\tilde{w}_i|Y_i,X_i)</script></span>, <span><span class="MathJax_Preview">D_{JS}</span><script type="math/tex">D_{JS}</script></span>.</dd> +</dl></div> +</dd> +</dl> +</dd> +<dt id="VITAE.model.VariationalAutoEncoder"><code class="flex name class"> +<span>class <span class="ident">VariationalAutoEncoder</span></span> +<span>(</span><span>dim_origin, dimensions, dim_latent, data_type='UMI', has_cov=False, name='autoencoder', **kwargs)</span> +</code></dt> +<dd> +<div class="desc"><p>Combines the encoder, decoder and LatentSpace into an end-to-end model for training and inference.</p> +<h2 id="parameters">Parameters</h2> +<dl> +<dt><strong><code>dim_origin</code></strong> : <code>int</code></dt> +<dd>The output dimension of the decoder.</dd> +<dt><strong><code>dimensions</code></strong> : <code>np.array</code></dt> +<dd>The dimensions of hidden layers of the encoder.</dd> +<dt><strong><code>dim_latent</code></strong> : <code>int</code></dt> +<dd>The latent dimension.</dd> +<dt><strong><code>data_type</code></strong> : <code>str</code>, optional</dt> +<dd><code>'UMI'</code>, <code>'non-UMI'</code>, or <code>'Gaussian'</code>.</dd> +<dt><strong><code>has_cov</code></strong> : <code>boolean</code></dt> +<dd>Whether has covariates or not.</dd> +<dt><strong><code>gamma</code></strong> : <code>float</code>, optional</dt> +<dd>The weights of the MMD loss</dd> +<dt><strong><code>name</code></strong> : <code>str</code>, optional</dt> +<dd>The name of the layer.</dd> +<dt><strong><code>**kwargs</code></strong></dt> +<dd>Extra keyword arguments.</dd> +</dl></div> +<details class="source"> +<summary> +<span>Expand source code</span> +</summary> +<pre><code class="python">class VariationalAutoEncoder(tf.keras.Model): + """ + Combines the encoder, decoder and LatentSpace into an end-to-end model for training and inference. + """ + def __init__(self, dim_origin, dimensions, dim_latent, + data_type = 'UMI', has_cov=False, + name = 'autoencoder', **kwargs): + ''' + Parameters + ---------- + dim_origin : int + The output dimension of the decoder. + dimensions : np.array + The dimensions of hidden layers of the encoder. + dim_latent : int + The latent dimension. + data_type : str, optional + `'UMI'`, `'non-UMI'`, or `'Gaussian'`. + has_cov : boolean + Whether has covariates or not. + gamma : float, optional + The weights of the MMD loss + name : str, optional + The name of the layer. + **kwargs : + Extra keyword arguments. + ''' + super(VariationalAutoEncoder, self).__init__(name = name, **kwargs) + self.data_type = data_type + self.dim_origin = dim_origin + self.dim_latent = dim_latent + self.encoder = Encoder(dimensions, dim_latent) + self.decoder = Decoder(dimensions[::-1], dim_origin, data_type, data_type) + self.has_cov = has_cov + + def init_latent_space(self, n_clusters, mu, log_pi=None): + '''Initialze the latent space. + + Parameters + ---------- + n_clusters : int + The number of vertices in the latent space. + mu : np.array + \([d, k]\) The position matrix. + log_pi : np.array, optional + \([1, K]\) \(\\log\\pi\). + ''' + self.n_states = n_clusters + self.latent_space = LatentSpace(self.n_states, self.dim_latent) + self.latent_space.initialize(mu, log_pi) + self.pilayer = None + + def create_pilayer(self): + self.pilayer = Dense(self.latent_space.n_categories, name = 'pi_layer') + + def call(self, x_normalized, c_score, x = None, scale_factor = 1, + pre_train = False, L=1, alpha=0.0, gamma = 0.0, phi = 1.0, conditions = None, pi_cov = None): + '''Feed forward through encoder, LatentSpace layer and decoder. + + Parameters + ---------- + x_normalized : np.array + \([B, G]\) The preprocessed data. + c_score : np.array + \([B, s]\) The covariates \(X_i\), only used when `has_cov=True`. + x : np.array, optional + \([B, G]\) The original count data \(Y_i\), only used when data_type is not `'Gaussian'`. + scale_factor : np.array, optional + \([B, ]\) The scale factors, only used when data_type is not `'Gaussian'`. + pre_train : boolean, optional + Whether in the pre-training phare or not. + L : int, optional + The number of MC samples. + alpha : float, optional + The penalty parameter for covariates adjustment. + gamma : float, optional + The weight of mmd loss + phi : float, optional + The weight of Jacob norm of the encoder. + conditions: str or list, optional + The conditions of different cells from the selected batch + + Returns + ---------- + losses : float + the loss. + ''' + + if not pre_train and self.latent_space is None: + raise ReferenceError('Have not initialized the latent space.') + + if self.has_cov: + x_normalized = tf.concat([x_normalized, c_score], -1) + else: + x_normalized + _, z_log_var, z = self.encoder(x_normalized, L) + + if gamma == 0: + mmd_loss = 0.0 + else: + mmd_loss = self._get_total_mmd_loss(conditions,z,gamma) + + z_in = tf.concat([z, tf.tile(tf.expand_dims(c_score,1), (1,L,1))], -1) if self.has_cov else z + + x = tf.tile(tf.expand_dims(x, 1), (1,L,1)) + reconstruction_z_loss = self._get_reconstruction_loss(x, z_in, scale_factor, L) + + if self.has_cov and alpha>0.0: + zero_in = tf.concat([tf.zeros([z.shape[0],1,z.shape[2]], dtype=tf.keras.backend.floatx()), + tf.tile(tf.expand_dims(c_score,1), (1,1,1))], -1) + reconstruction_zero_loss = self._get_reconstruction_loss(x, zero_in, scale_factor, 1) + reconstruction_z_loss = (1-alpha)*reconstruction_z_loss + alpha*reconstruction_zero_loss + + self.add_loss(reconstruction_z_loss) + J_norm = self._get_Jacob(x_normalized, L) + self.add_loss((phi * J_norm)) + # gamma weight has been used when call _mmd_loss function. + self.add_loss(mmd_loss) + + if not pre_train: + pi = self.pilayer(pi_cov) if self.pilayer is not None else None + log_p_z = self.latent_space(z, pi, inference=False) + + # - E_q[log p(z)] + self.add_loss(- log_p_z) + + # - Eq[log q(z|x)] + E_qzx = - tf.reduce_mean( + 0.5 * self.dim_latent * + (tf.math.log(tf.constant(2 * np.pi, tf.keras.backend.floatx())) + 1.0) + + 0.5 * tf.reduce_sum(z_log_var, axis=-1) + ) + self.add_loss(E_qzx) + return self.losses + + @tf.function + def _get_reconstruction_loss(self, x, z_in, scale_factor, L): + if self.data_type=='Gaussian': + # Gaussian Log-Likelihood Loss function + nu_z, tau = self.decoder(z_in) + neg_E_Gaus = 0.5 * tf.math.log(tf.clip_by_value(tau, 1e-12, 1e30)) + 0.5 * tf.math.square(x - nu_z) / tau + neg_E_Gaus = tf.reduce_mean(tf.reduce_sum(neg_E_Gaus, axis=-1)) + + return neg_E_Gaus + else: + if self.data_type == 'UMI': + x_hat, r = self.decoder(z_in) + else: + x_hat, r, phi = self.decoder(z_in) + + x_hat = x_hat*tf.expand_dims(scale_factor, -1) + + # Negative Log-Likelihood Loss function + + # Ref for NB & ZINB loss functions: + # https://github.com/gokceneraslan/neuralnet_countmodels/blob/master/Count%20models%20with%20neuralnets.ipynb + # Negative Binomial loss + + neg_E_nb = tf.math.lgamma(r) + tf.math.lgamma(x+1.0) \ + - tf.math.lgamma(x+r) + \ + (r+x) * tf.math.log(1.0 + (x_hat/r)) + \ + x * (tf.math.log(r) - tf.math.log(tf.clip_by_value(x_hat, 1e-12, 1e30))) + + if self.data_type == 'non-UMI': + # Zero-Inflated Negative Binomial loss + nb_case = neg_E_nb - tf.math.log(tf.clip_by_value(1.0-phi, 1e-12, 1e30)) + zero_case = - tf.math.log(tf.clip_by_value( + phi + (1.0-phi) * tf.pow(r * tf.math.reciprocal_no_nan(r + x_hat), r), + 1e-12, 1e30)) + neg_E_nb = tf.where(tf.less(x, 1e-8), zero_case, nb_case) + + neg_E_nb = tf.reduce_mean(tf.reduce_sum(neg_E_nb, axis=-1)) + return neg_E_nb + + def _get_total_mmd_loss(self,conditions,z,gamma): + mmd_loss = 0.0 + conditions = tf.cast(conditions,tf.int32) + n_group = conditions.shape[1] + + for i in range(n_group): + sub_conditions = conditions[:, i] + # 0 means not participant in mmd + z_cond = z[sub_conditions != 0] + sub_conditions = sub_conditions[sub_conditions != 0] + n_sub_group = tf.unique(sub_conditions)[0].shape[0] + real_labels = K.reshape(sub_conditions, (-1,)).numpy() + unique_set = list(set(real_labels)) + reindex_dict = dict(zip(unique_set, range(n_sub_group))) + real_labels = [reindex_dict[x] for x in real_labels] + real_labels = tf.convert_to_tensor(real_labels,dtype=tf.int32) + + if (n_sub_group == 1) | (n_sub_group == 0): + _loss = 0 + else: + _loss = self._mmd_loss(real_labels=real_labels, y_pred=z_cond, gamma=gamma, + n_conditions=n_sub_group, + kernel_method='multi-scale-rbf', + computation_method="general") + mmd_loss = mmd_loss + _loss + return mmd_loss + + # each loop the inputed shape is changed. Can not use @tf.function + # tf graph requires static shape and tensor dtype + def _mmd_loss(self, real_labels, y_pred, gamma, n_conditions, kernel_method='multi-scale-rbf', + computation_method="general"): + conditions_mmd = tf.dynamic_partition(y_pred, real_labels, num_partitions=n_conditions) + loss = 0.0 + if computation_method.isdigit(): + boundary = int(computation_method) + ## every pair of groups will calculate a distance + for i in range(boundary): + for j in range(boundary, n_conditions): + loss += _nan2zero(compute_mmd(conditions_mmd[i], conditions_mmd[j], kernel_method)) + else: + for i in range(len(conditions_mmd)): + for j in range(i): + loss += _nan2zero(compute_mmd(conditions_mmd[i], conditions_mmd[j], kernel_method)) + + # print("The loss is ", loss) + return gamma * loss + + @tf.function + def _get_Jacob(self, x, L): + with tf.GradientTape() as g: + g.watch(x) + z_mean, z_log_var, z = self.encoder(x, L) + # y_mean, y_log_var = self.decoder(z) + ## just jacobian will cause shape (batch,16,batch,64) matrix + J = g.batch_jacobian(z, x) + J_norm = tf.norm(J) + # tf.print(J_norm) + + return J_norm + + def get_z(self, x_normalized, c_score): + '''Get \(q(Z_i|Y_i,X_i)\). + + Parameters + ---------- + x_normalized : int + \([B, G]\) The preprocessed data. + c_score : np.array + \([B, s]\) The covariates \(X_i\), only used when `has_cov=True`. + + Returns + ---------- + z_mean : np.array + \([B, d]\) The latent mean. + ''' + x_normalized = x_normalized if (not self.has_cov or c_score is None) else tf.concat([x_normalized, c_score], -1) + z_mean, _, _ = self.encoder(x_normalized, 1, False) + return z_mean.numpy() + + def get_pc_x(self, test_dataset): + '''Get \(p(c_i|Y_i,X_i)\). + + Parameters + ---------- + test_dataset : tf.Dataset + the dataset object. + + Returns + ---------- + pi_norm : np.array + \([1, K]\) The estimated \(\\pi\). + p_c_x : np.array + \([N, ]\) The estimated \(p(c_i|Y_i,X_i)\). + ''' + if self.latent_space is None: + raise ReferenceError('Have not initialized the latent space.') + + pi_norm = tf.nn.softmax(self.latent_space.pi).numpy() + p_c_x = [] + for x,c_score in test_dataset: + x = tf.concat([x, c_score], -1) if self.has_cov else x + _, _, z = self.encoder(x, 1, False) + _p_c_x = self.latent_space.get_posterior_c(z) + p_c_x.append(_p_c_x) + p_c_x = np.concatenate(p_c_x) + return pi_norm, p_c_x + + def inference(self, test_dataset, L=1): + '''Get \(p(c_i|Y_i,X_i)\). + + Parameters + ---------- + test_dataset : tf.Dataset + The dataset object. + L : int + The number of MC samples. + + Returns + ---------- + pi_norm : np.array + \([1, K]\) The estimated \(\\pi\). + mu : np.array + \([d, k]\) The estimated \(\\mu\). + p_c_x : np.array + \([N, ]\) The estimated \(p(c_i|Y_i,X_i)\). + w_tilde : np.array + \([N, k]\) The estimated \(E(\\tilde{w}_i|Y_i,X_i)\). + var_w_tilde : np.array + \([N, k]\) The estimated \(Var(\\tilde{w}_i|Y_i,X_i)\). + z_mean : np.array + \([N, d]\) The estimated latent mean. + ''' + if self.latent_space is None: + raise ReferenceError('Have not initialized the latent space.') + + print('Computing posterior estimations over mini-batches.') + progbar = Progbar(test_dataset.cardinality().numpy()) + pi_norm = tf.nn.softmax(self.latent_space.pi).numpy() + mu = self.latent_space.mu.numpy() + z_mean = [] + p_c_x = [] + w_tilde = [] + var_w_tilde = [] + for step, (x,c_score, _, _) in enumerate(test_dataset): + x = tf.concat([x, c_score], -1) if self.has_cov else x + _z_mean, _, z = self.encoder(x, L, False) + res = self.latent_space(z, inference=True) + + z_mean.append(_z_mean.numpy()) + p_c_x.append(res['p_c_x']) + w_tilde.append(res['w_tilde']) + var_w_tilde.append(res['var_w_tilde']) + progbar.update(step+1) + + z_mean = np.concatenate(z_mean) + p_c_x = np.concatenate(p_c_x) + w_tilde = np.concatenate(w_tilde) + w_tilde /= np.sum(w_tilde, axis=1, keepdims=True) + var_w_tilde = np.concatenate(var_w_tilde) + return pi_norm, mu, p_c_x, w_tilde, var_w_tilde, z_mean</code></pre> +</details> +<h3>Ancestors</h3> +<ul class="hlist"> +<li>keras.src.engine.training.Model</li> +<li>keras.src.engine.base_layer.Layer</li> +<li>tensorflow.python.module.module.Module</li> +<li>tensorflow.python.trackable.autotrackable.AutoTrackable</li> +<li>tensorflow.python.trackable.base.Trackable</li> +<li>keras.src.utils.version_utils.LayerVersionSelector</li> +<li>keras.src.utils.version_utils.ModelVersionSelector</li> +</ul> +<h3>Methods</h3> +<dl> +<dt id="VITAE.model.VariationalAutoEncoder.init_latent_space"><code class="name flex"> +<span>def <span class="ident">init_latent_space</span></span>(<span>self, n_clusters, mu, log_pi=None)</span> +</code></dt> +<dd> +<div class="desc"><p>Initialze the latent space.</p> +<h2 id="parameters">Parameters</h2> +<dl> +<dt><strong><code>n_clusters</code></strong> : <code>int</code></dt> +<dd>The number of vertices in the latent space.</dd> +<dt><strong><code>mu</code></strong> : <code>np.array</code></dt> +<dd><span><span class="MathJax_Preview">[d, k]</span><script type="math/tex">[d, k]</script></span> The position matrix.</dd> +<dt><strong><code>log_pi</code></strong> : <code>np.array</code>, optional</dt> +<dd><span><span class="MathJax_Preview">[1, K]</span><script type="math/tex">[1, K]</script></span> <span><span class="MathJax_Preview">\log\pi</span><script type="math/tex">\log\pi</script></span>.</dd> +</dl></div> +</dd> +<dt id="VITAE.model.VariationalAutoEncoder.create_pilayer"><code class="name flex"> +<span>def <span class="ident">create_pilayer</span></span>(<span>self)</span> +</code></dt> +<dd> +<div class="desc"></div> +</dd> +<dt id="VITAE.model.VariationalAutoEncoder.call"><code class="name flex"> +<span>def <span class="ident">call</span></span>(<span>self, x_normalized, c_score, x=None, scale_factor=1, pre_train=False, L=1, alpha=0.0, gamma=0.0, phi=1.0, conditions=None, pi_cov=None)</span> +</code></dt> +<dd> +<div class="desc"><p>Feed forward through encoder, LatentSpace layer and decoder.</p> +<h2 id="parameters">Parameters</h2> +<dl> +<dt><strong><code>x_normalized</code></strong> : <code>np.array</code></dt> +<dd><span><span class="MathJax_Preview">[B, G]</span><script type="math/tex">[B, G]</script></span> The preprocessed data.</dd> +<dt><strong><code>c_score</code></strong> : <code>np.array</code></dt> +<dd><span><span class="MathJax_Preview">[B, s]</span><script type="math/tex">[B, s]</script></span> The covariates <span><span class="MathJax_Preview">X_i</span><script type="math/tex">X_i</script></span>, only used when <code>has_cov=True</code>.</dd> +<dt><strong><code>x</code></strong> : <code>np.array</code>, optional</dt> +<dd><span><span class="MathJax_Preview">[B, G]</span><script type="math/tex">[B, G]</script></span> The original count data <span><span class="MathJax_Preview">Y_i</span><script type="math/tex">Y_i</script></span>, only used when data_type is not <code>'Gaussian'</code>.</dd> +<dt><strong><code>scale_factor</code></strong> : <code>np.array</code>, optional</dt> +<dd><span><span class="MathJax_Preview">[B, ]</span><script type="math/tex">[B, ]</script></span> The scale factors, only used when data_type is not <code>'Gaussian'</code>.</dd> +<dt><strong><code>pre_train</code></strong> : <code>boolean</code>, optional</dt> +<dd>Whether in the pre-training phare or not.</dd> +<dt><strong><code>L</code></strong> : <code>int</code>, optional</dt> +<dd>The number of MC samples.</dd> +<dt><strong><code>alpha</code></strong> : <code>float</code>, optional</dt> +<dd>The penalty parameter for covariates adjustment.</dd> +<dt><strong><code>gamma</code></strong> : <code>float</code>, optional</dt> +<dd>The weight of mmd loss</dd> +<dt><strong><code>phi</code></strong> : <code>float</code>, optional</dt> +<dd>The weight of Jacob norm of the encoder.</dd> +<dt><strong><code>conditions</code></strong> : <code>str</code> or <code>list</code>, optional</dt> +<dd>The conditions of different cells from the selected batch</dd> +</dl> +<h2 id="returns">Returns</h2> +<dl> +<dt><strong><code>losses</code></strong> : <code>float</code></dt> +<dd>the loss.</dd> +</dl></div> +</dd> +<dt id="VITAE.model.VariationalAutoEncoder.get_z"><code class="name flex"> +<span>def <span class="ident">get_z</span></span>(<span>self, x_normalized, c_score)</span> +</code></dt> +<dd> +<div class="desc"><p>Get <span><span class="MathJax_Preview">q(Z_i|Y_i,X_i)</span><script type="math/tex">q(Z_i|Y_i,X_i)</script></span>.</p> +<h2 id="parameters">Parameters</h2> +<dl> +<dt><strong><code>x_normalized</code></strong> : <code>int</code></dt> +<dd><span><span class="MathJax_Preview">[B, G]</span><script type="math/tex">[B, G]</script></span> The preprocessed data.</dd> +<dt><strong><code>c_score</code></strong> : <code>np.array</code></dt> +<dd><span><span class="MathJax_Preview">[B, s]</span><script type="math/tex">[B, s]</script></span> The covariates <span><span class="MathJax_Preview">X_i</span><script type="math/tex">X_i</script></span>, only used when <code>has_cov=True</code>.</dd> +</dl> +<h2 id="returns">Returns</h2> +<dl> +<dt><strong><code>z_mean</code></strong> : <code>np.array</code></dt> +<dd><span><span class="MathJax_Preview">[B, d]</span><script type="math/tex">[B, d]</script></span> The latent mean.</dd> +</dl></div> +</dd> +<dt id="VITAE.model.VariationalAutoEncoder.get_pc_x"><code class="name flex"> +<span>def <span class="ident">get_pc_x</span></span>(<span>self, test_dataset)</span> +</code></dt> +<dd> +<div class="desc"><p>Get <span><span class="MathJax_Preview">p(c_i|Y_i,X_i)</span><script type="math/tex">p(c_i|Y_i,X_i)</script></span>.</p> +<h2 id="parameters">Parameters</h2> +<dl> +<dt><strong><code>test_dataset</code></strong> : <code>tf.Dataset</code></dt> +<dd>the dataset object.</dd> +</dl> +<h2 id="returns">Returns</h2> +<dl> +<dt><strong><code>pi_norm</code></strong> : <code>np.array</code></dt> +<dd><span><span class="MathJax_Preview">[1, K]</span><script type="math/tex">[1, K]</script></span> The estimated <span><span class="MathJax_Preview">\pi</span><script type="math/tex">\pi</script></span>.</dd> +<dt><strong><code>p_c_x</code></strong> : <code>np.array</code></dt> +<dd><span><span class="MathJax_Preview">[N, ]</span><script type="math/tex">[N, ]</script></span> The estimated <span><span class="MathJax_Preview">p(c_i|Y_i,X_i)</span><script type="math/tex">p(c_i|Y_i,X_i)</script></span>.</dd> +</dl></div> +</dd> +<dt id="VITAE.model.VariationalAutoEncoder.inference"><code class="name flex"> +<span>def <span class="ident">inference</span></span>(<span>self, test_dataset, L=1)</span> +</code></dt> +<dd> +<div class="desc"><p>Get <span><span class="MathJax_Preview">p(c_i|Y_i,X_i)</span><script type="math/tex">p(c_i|Y_i,X_i)</script></span>.</p> +<h2 id="parameters">Parameters</h2> +<dl> +<dt><strong><code>test_dataset</code></strong> : <code>tf.Dataset</code></dt> +<dd>The dataset object.</dd> +<dt><strong><code>L</code></strong> : <code>int</code></dt> +<dd>The number of MC samples.</dd> +</dl> +<h2 id="returns">Returns</h2> +<dl> +<dt><code>pi_norm +: np.array</code></dt> +<dd><span><span class="MathJax_Preview">[1, K]</span><script type="math/tex">[1, K]</script></span> The estimated <span><span class="MathJax_Preview">\pi</span><script type="math/tex">\pi</script></span>.</dd> +<dt><strong><code>mu</code></strong> : <code>np.array</code></dt> +<dd><span><span class="MathJax_Preview">[d, k]</span><script type="math/tex">[d, k]</script></span> The estimated <span><span class="MathJax_Preview">\mu</span><script type="math/tex">\mu</script></span>.</dd> +<dt><strong><code>p_c_x</code></strong> : <code>np.array</code></dt> +<dd><span><span class="MathJax_Preview">[N, ]</span><script type="math/tex">[N, ]</script></span> The estimated <span><span class="MathJax_Preview">p(c_i|Y_i,X_i)</span><script type="math/tex">p(c_i|Y_i,X_i)</script></span>.</dd> +<dt><strong><code>w_tilde</code></strong> : <code>np.array</code></dt> +<dd><span><span class="MathJax_Preview">[N, k]</span><script type="math/tex">[N, k]</script></span> The estimated <span><span class="MathJax_Preview">E(\tilde{w}_i|Y_i,X_i)</span><script type="math/tex">E(\tilde{w}_i|Y_i,X_i)</script></span>.</dd> +<dt><code>var_w_tilde +: np.array</code></dt> +<dd><span><span class="MathJax_Preview">[N, k]</span><script type="math/tex">[N, k]</script></span> The estimated <span><span class="MathJax_Preview">Var(\tilde{w}_i|Y_i,X_i)</span><script type="math/tex">Var(\tilde{w}_i|Y_i,X_i)</script></span>.</dd> +<dt><strong><code>z_mean</code></strong> : <code>np.array</code></dt> +<dd><span><span class="MathJax_Preview">[N, d]</span><script type="math/tex">[N, d]</script></span> The estimated latent mean.</dd> +</dl></div> +</dd> +</dl> +</dd> +</dl> +</section> +</article> +<nav id="sidebar"> +<div class="toc"> +<ul></ul> +</div> +<ul id="index"> +<li><h3>Super-module</h3> +<ul> +<li><code><a title="VITAE" href="index.html">VITAE</a></code></li> +</ul> +</li> +<li><h3><a href="#header-classes">Classes</a></h3> +<ul> +<li> +<h4><code><a title="VITAE.model.cdf_layer" href="#VITAE.model.cdf_layer">cdf_layer</a></code></h4> +<ul class=""> +<li><code><a title="VITAE.model.cdf_layer.call" href="#VITAE.model.cdf_layer.call">call</a></code></li> +<li><code><a title="VITAE.model.cdf_layer.func" href="#VITAE.model.cdf_layer.func">func</a></code></li> +</ul> +</li> +<li> +<h4><code><a title="VITAE.model.Sampling" href="#VITAE.model.Sampling">Sampling</a></code></h4> +<ul class=""> +<li><code><a title="VITAE.model.Sampling.call" href="#VITAE.model.Sampling.call">call</a></code></li> +</ul> +</li> +<li> +<h4><code><a title="VITAE.model.Encoder" href="#VITAE.model.Encoder">Encoder</a></code></h4> +<ul class=""> +<li><code><a title="VITAE.model.Encoder.call" href="#VITAE.model.Encoder.call">call</a></code></li> +</ul> +</li> +<li> +<h4><code><a title="VITAE.model.Decoder" href="#VITAE.model.Decoder">Decoder</a></code></h4> +<ul class=""> +<li><code><a title="VITAE.model.Decoder.call" href="#VITAE.model.Decoder.call">call</a></code></li> +</ul> +</li> +<li> +<h4><code><a title="VITAE.model.LatentSpace" href="#VITAE.model.LatentSpace">LatentSpace</a></code></h4> +<ul class=""> +<li><code><a title="VITAE.model.LatentSpace.initialize" href="#VITAE.model.LatentSpace.initialize">initialize</a></code></li> +<li><code><a title="VITAE.model.LatentSpace.normalize" href="#VITAE.model.LatentSpace.normalize">normalize</a></code></li> +<li><code><a title="VITAE.model.LatentSpace.get_pz" href="#VITAE.model.LatentSpace.get_pz">get_pz</a></code></li> +<li><code><a title="VITAE.model.LatentSpace.get_posterior_c" href="#VITAE.model.LatentSpace.get_posterior_c">get_posterior_c</a></code></li> +<li><code><a title="VITAE.model.LatentSpace.call" href="#VITAE.model.LatentSpace.call">call</a></code></li> +</ul> +</li> +<li> +<h4><code><a title="VITAE.model.VariationalAutoEncoder" href="#VITAE.model.VariationalAutoEncoder">VariationalAutoEncoder</a></code></h4> +<ul class="two-column"> +<li><code><a title="VITAE.model.VariationalAutoEncoder.init_latent_space" href="#VITAE.model.VariationalAutoEncoder.init_latent_space">init_latent_space</a></code></li> +<li><code><a title="VITAE.model.VariationalAutoEncoder.create_pilayer" href="#VITAE.model.VariationalAutoEncoder.create_pilayer">create_pilayer</a></code></li> +<li><code><a title="VITAE.model.VariationalAutoEncoder.call" href="#VITAE.model.VariationalAutoEncoder.call">call</a></code></li> +<li><code><a title="VITAE.model.VariationalAutoEncoder.get_z" href="#VITAE.model.VariationalAutoEncoder.get_z">get_z</a></code></li> +<li><code><a title="VITAE.model.VariationalAutoEncoder.get_pc_x" href="#VITAE.model.VariationalAutoEncoder.get_pc_x">get_pc_x</a></code></li> +<li><code><a title="VITAE.model.VariationalAutoEncoder.inference" href="#VITAE.model.VariationalAutoEncoder.inference">inference</a></code></li> +</ul> +</li> +</ul> +</li> +</ul> +</nav> +</main> +<footer id="footer"> +<p>Generated by <a href="https://pdoc3.github.io/pdoc" title="pdoc: Python API documentation generator"><cite>pdoc</cite> 0.11.1</a>.</p> +</footer> +</body> +</html>