<!doctype html>
|
|
2 |
<html lang="en"> |
|
|
3 |
<head> |
|
|
4 |
<meta charset="utf-8"> |
|
|
5 |
<meta name="viewport" content="width=device-width, initial-scale=1, minimum-scale=1"> |
|
|
6 |
<meta name="generator" content="pdoc3 0.11.1"> |
|
|
7 |
<title>VITAE.model API documentation</title> |
|
|
8 |
<meta name="description" content=""> |
|
|
9 |
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/10up-sanitize.css/13.0.0/sanitize.min.css" integrity="sha512-y1dtMcuvtTMJc1yPgEqF0ZjQbhnc/bFhyvIyVNb9Zk5mIGtqVaAB1Ttl28su8AvFMOY0EwRbAe+HCLqj6W7/KA==" crossorigin> |
|
|
10 |
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/10up-sanitize.css/13.0.0/typography.min.css" integrity="sha512-Y1DYSb995BAfxobCkKepB1BqJJTPrOp3zPL74AWFugHHmmdcvO+C48WLrUOlhGMc0QG7AE3f7gmvvcrmX2fDoA==" crossorigin> |
|
|
11 |
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.9.0/styles/default.min.css" crossorigin> |
|
|
12 |
<style>:root{--highlight-color:#fe9}.flex{display:flex !important}body{line-height:1.5em}#content{padding:20px}#sidebar{padding:1.5em;overflow:hidden}#sidebar > *:last-child{margin-bottom:2cm}.http-server-breadcrumbs{font-size:130%;margin:0 0 15px 0}#footer{font-size:.75em;padding:5px 30px;border-top:1px solid #ddd;text-align:right}#footer p{margin:0 0 0 1em;display:inline-block}#footer p:last-child{margin-right:30px}h1,h2,h3,h4,h5{font-weight:300}h1{font-size:2.5em;line-height:1.1em}h2{font-size:1.75em;margin:2em 0 .50em 0}h3{font-size:1.4em;margin:1.6em 0 .7em 0}h4{margin:0;font-size:105%}h1:target,h2:target,h3:target,h4:target,h5:target,h6:target{background:var(--highlight-color);padding:.2em 0}a{color:#058;text-decoration:none;transition:color .2s ease-in-out}a:visited{color:#503}a:hover{color:#b62}.title code{font-weight:bold}h2[id^="header-"]{margin-top:2em}.ident{color:#900;font-weight:bold}pre code{font-size:.8em;line-height:1.4em;padding:1em;display:block}code{background:#f3f3f3;font-family:"DejaVu Sans Mono",monospace;padding:1px 4px;overflow-wrap:break-word}h1 code{background:transparent}pre{border-top:1px solid #ccc;border-bottom:1px solid #ccc;margin:1em 0}#http-server-module-list{display:flex;flex-flow:column}#http-server-module-list div{display:flex}#http-server-module-list dt{min-width:10%}#http-server-module-list p{margin-top:0}.toc ul,#index{list-style-type:none;margin:0;padding:0}#index code{background:transparent}#index h3{border-bottom:1px solid #ddd}#index ul{padding:0}#index h4{margin-top:.6em;font-weight:bold}@media (min-width:200ex){#index .two-column{column-count:2}}@media (min-width:300ex){#index .two-column{column-count:3}}dl{margin-bottom:2em}dl dl:last-child{margin-bottom:4em}dd{margin:0 0 1em 3em}#header-classes + dl > dd{margin-bottom:3em}dd dd{margin-left:2em}dd p{margin:10px 0}.name{background:#eee;font-size:.85em;padding:5px 10px;display:inline-block;min-width:40%}.name:hover{background:#e0e0e0}dt:target .name{background:var(--highlight-color)}.name > span:first-child{white-space:nowrap}.name.class > span:nth-child(2){margin-left:.4em}.inherited{color:#999;border-left:5px solid #eee;padding-left:1em}.inheritance em{font-style:normal;font-weight:bold}.desc h2{font-weight:400;font-size:1.25em}.desc h3{font-size:1em}.desc dt code{background:inherit}.source summary,.git-link-div{color:#666;text-align:right;font-weight:400;font-size:.8em;text-transform:uppercase}.source summary > *{white-space:nowrap;cursor:pointer}.git-link{color:inherit;margin-left:1em}.source pre{max-height:500px;overflow:auto;margin:0}.source pre code{font-size:12px;overflow:visible}.hlist{list-style:none}.hlist li{display:inline}.hlist li:after{content:',\2002'}.hlist li:last-child:after{content:none}.hlist .hlist{display:inline;padding-left:1em}img{max-width:100%}td{padding:0 .5em}.admonition{padding:.1em 1em;margin-bottom:1em}.admonition-title{font-weight:bold}.admonition.note,.admonition.info,.admonition.important{background:#aef}.admonition.todo,.admonition.versionadded,.admonition.tip,.admonition.hint{background:#dfd}.admonition.warning,.admonition.versionchanged,.admonition.deprecated{background:#fd4}.admonition.error,.admonition.danger,.admonition.caution{background:lightpink}</style> |
|
|
13 |
<style media="screen and (min-width: 700px)">@media screen and (min-width:700px){#sidebar{width:30%;height:100vh;overflow:auto;position:sticky;top:0}#content{width:70%;max-width:100ch;padding:3em 4em;border-left:1px solid #ddd}pre code{font-size:1em}.name{font-size:1em}main{display:flex;flex-direction:row-reverse;justify-content:flex-end}.toc ul ul,#index ul ul{padding-left:1em}.toc > ul > li{margin-top:.5em}}</style> |
|
|
14 |
<style media="print">@media print{#sidebar h1{page-break-before:always}.source{display:none}}@media print{*{background:transparent !important;color:#000 !important;box-shadow:none !important;text-shadow:none !important}a[href]:after{content:" (" attr(href) ")";font-size:90%}a[href][title]:after{content:none}abbr[title]:after{content:" (" attr(title) ")"}.ir a:after,a[href^="javascript:"]:after,a[href^="#"]:after{content:""}pre,blockquote{border:1px solid #999;page-break-inside:avoid}thead{display:table-header-group}tr,img{page-break-inside:avoid}img{max-width:100% !important}@page{margin:0.5cm}p,h2,h3{orphans:3;widows:3}h1,h2,h3,h4,h5,h6{page-break-after:avoid}}</style> |
|
|
15 |
<script type="text/x-mathjax-config">MathJax.Hub.Config({ tex2jax: { inlineMath: [ ['$','$'], ["\\(","\\)"] ], processEscapes: true } });</script> |
|
|
16 |
<script async src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.7/latest.js?config=TeX-AMS_CHTML" integrity="sha256-kZafAc6mZvK3W3v1pHOcUix30OHQN6pU/NO2oFkqZVw=" crossorigin></script> |
|
|
17 |
<script defer src="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.9.0/highlight.min.js" integrity="sha512-D9gUyxqja7hBtkWpPWGt9wfbfaMGVt9gnyCvYa+jojwwPHLCzUm5i8rpk7vD7wNee9bA35eYIjobYPaQuKS1MQ==" crossorigin></script> |
|
|
18 |
<script>window.addEventListener('DOMContentLoaded', () => { |
|
|
19 |
hljs.configure({languages: ['bash', 'css', 'diff', 'graphql', 'ini', 'javascript', 'json', 'plaintext', 'python', 'python-repl', 'rust', 'shell', 'sql', 'typescript', 'xml', 'yaml']}); |
|
|
20 |
hljs.highlightAll(); |
|
|
21 |
})</script> |
|
|
22 |
</head> |
|
|
23 |
<body> |
|
|
24 |
<main> |
|
|
25 |
<article id="content"> |
|
|
26 |
<header> |
|
|
27 |
<h1 class="title">Module <code>VITAE.model</code></h1> |
|
|
28 |
</header> |
|
|
29 |
<section id="section-intro"> |
|
|
30 |
</section> |
|
|
31 |
<section> |
|
|
32 |
</section> |
|
|
33 |
<section> |
|
|
34 |
</section> |
|
|
35 |
<section> |
|
|
36 |
</section> |
|
|
37 |
<section> |
|
|
38 |
<h2 class="section-title" id="header-classes">Classes</h2> |
|
|
39 |
<dl> |
|
|
40 |
<dt id="VITAE.model.cdf_layer"><code class="flex name class"> |
|
|
41 |
<span>class <span class="ident">cdf_layer</span></span> |
|
|
42 |
</code></dt> |
|
|
43 |
<dd> |
|
|
44 |
<div class="desc"><p>The Normal cdf layer with custom gradients.</p></div> |
|
|
45 |
<details class="source"> |
|
|
46 |
<summary> |
|
|
47 |
<span>Expand source code</span> |
|
|
48 |
</summary> |
|
|
49 |
<pre><code class="python">class cdf_layer(Layer): |
|
|
50 |
''' |
|
|
51 |
The Normal cdf layer with custom gradients. |
|
|
52 |
''' |
|
|
53 |
def __init__(self): |
|
|
54 |
''' |
|
|
55 |
''' |
|
|
56 |
super(cdf_layer, self).__init__() |
|
|
57 |
|
|
|
58 |
@tf.function |
|
|
59 |
def call(self, x): |
|
|
60 |
return self.func(x) |
|
|
61 |
|
|
|
62 |
@tf.custom_gradient |
|
|
63 |
def func(self, x): |
|
|
64 |
'''Return cdf(x) and pdf(x). |
|
|
65 |
|
|
|
66 |
Parameters |
|
|
67 |
---------- |
|
|
68 |
x : tf.Tensor |
|
|
69 |
The input tensor. |
|
|
70 |
|
|
|
71 |
Returns |
|
|
72 |
---------- |
|
|
73 |
f : tf.Tensor |
|
|
74 |
cdf(x). |
|
|
75 |
        grad : Callable
            The gradient function, which scales the upstream gradient by pdf(x).
|
|
77 |
''' |
|
|
78 |
dist = tfp.distributions.Normal( |
|
|
79 |
loc = tf.constant(0.0, tf.keras.backend.floatx()), |
|
|
80 |
scale = tf.constant(1.0, tf.keras.backend.floatx()), |
|
|
81 |
allow_nan_stats=False) |
|
|
82 |
f = dist.cdf(x) |
|
|
83 |
def grad(dy): |
|
|
84 |
gradient = dist.prob(x) |
|
|
85 |
return dy * gradient |
|
|
86 |
return f, grad</code></pre> |
|
|
87 |
</details> |
|
|
88 |
<h3>Ancestors</h3> |
|
|
89 |
<ul class="hlist"> |
|
|
90 |
<li>keras.src.engine.base_layer.Layer</li> |
|
|
91 |
<li>tensorflow.python.module.module.Module</li> |
|
|
92 |
<li>tensorflow.python.trackable.autotrackable.AutoTrackable</li> |
|
|
93 |
<li>tensorflow.python.trackable.base.Trackable</li> |
|
|
94 |
<li>keras.src.utils.version_utils.LayerVersionSelector</li> |
|
|
95 |
</ul> |
|
|
96 |
<h3>Methods</h3> |
|
|
97 |
<dl> |
|
|
98 |
<dt id="VITAE.model.cdf_layer.call"><code class="name flex"> |
|
|
99 |
<span>def <span class="ident">call</span></span>(<span>self, x)</span> |
|
|
100 |
</code></dt> |
|
|
101 |
<dd> |
|
|
102 |
<div class="desc"></div> |
|
|
103 |
</dd> |
|
|
104 |
<dt id="VITAE.model.cdf_layer.func"><code class="name flex"> |
|
|
105 |
<span>def <span class="ident">func</span></span>(<span>self, x)</span> |
|
|
106 |
</code></dt> |
|
|
107 |
<dd> |
|
|
108 |
<div class="desc"><p>Return cdf(x) and pdf(x).</p> |
|
|
109 |
<h2 id="parameters">Parameters</h2> |
|
|
110 |
<dl> |
|
|
111 |
<dt><strong><code>x</code></strong> : <code>tf.Tensor</code></dt> |
|
|
112 |
<dd>The input tensor.</dd> |
|
|
113 |
</dl> |
|
|
114 |
<h2 id="returns">Returns</h2> |
|
|
115 |
<dl> |
|
|
116 |
<dt><strong><code>f</code></strong> : <code>tf.Tensor</code></dt> |
|
|
117 |
<dd>cdf(x).</dd> |
|
|
118 |
<dt><strong><code>grad</code></strong> : <code>Callable</code></dt>
<dd>The gradient function, which scales the upstream gradient by pdf(x).</dd>
|
|
120 |
</dl></div> |
|
|
121 |
</dd> |
|
|
122 |
</dl> |
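<p>A minimal usage sketch (illustrative, not part of the generated docs): the forward pass returns the standard Normal CDF, while the custom gradient substitutes the PDF.</p>
<pre><code class="python">import tensorflow as tf
from VITAE.model import cdf_layer

layer = cdf_layer()
x = tf.constant([-1.0, 0.0, 1.0])
with tf.GradientTape() as tape:
    tape.watch(x)
    f = layer(x)                # cdf(x): approx [0.159, 0.500, 0.841]
grad = tape.gradient(f, x)      # custom gradient pdf(x): approx [0.242, 0.399, 0.242]
</code></pre>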
|
|
123 |
</dd> |
|
|
124 |
<dt id="VITAE.model.Sampling"><code class="flex name class"> |
|
|
125 |
<span>class <span class="ident">Sampling</span></span> |
|
|
126 |
<span>(</span><span>seed=0, **kwargs)</span> |
|
|
127 |
</code></dt> |
|
|
128 |
<dd> |
|
|
129 |
<div class="desc"><p>Sampling latent variable <span><span class="MathJax_Preview">z</span><script type="math/tex">z</script></span> from <span><span class="MathJax_Preview">N(\mu_z, \log \sigma_z^2</span><script type="math/tex">N(\mu_z, \log \sigma_z^2</script></span>). |
|
|
130 |
<br> |
|
|
131 |
Used in Encoder.</p></div> |
|
|
132 |
<details class="source"> |
|
|
133 |
<summary> |
|
|
134 |
<span>Expand source code</span> |
|
|
135 |
</summary> |
|
|
136 |
<pre><code class="python">class Sampling(Layer): |
|
|
137 |
"""Sampling latent variable \(z\) from \(N(\\mu_z, \\log \\sigma_z^2\)). |
|
|
138 |
Used in Encoder. |
|
|
139 |
""" |
|
|
140 |
def __init__(self, seed=0, **kwargs): |
|
|
141 |
super(Sampling, self).__init__(**kwargs) |
|
|
142 |
self.seed = seed |
|
|
143 |
|
|
|
144 |
@tf.function |
|
|
145 |
def call(self, z_mean, z_log_var): |
|
|
146 |
        '''Sample \(z\) from \(N(\\mu_z, \\sigma_z^2)\) via the reparameterization trick.
|
|
147 |
|
|
|
148 |
Parameters |
|
|
149 |
---------- |
|
|
150 |
z_mean : tf.Tensor |
|
|
151 |
\([B, L, d]\) The mean of \(z\). |
|
|
152 |
z_log_var : tf.Tensor |
|
|
153 |
\([B, L, d]\) The log-variance of \(z\). |
|
|
154 |
|
|
|
155 |
Returns |
|
|
156 |
---------- |
|
|
157 |
z : tf.Tensor |
|
|
158 |
\([B, L, d]\) The sampled \(z\). |
|
|
159 |
''' |
|
|
160 |
# seed = tfp.util.SeedStream(self.seed, salt="random_normal") |
|
|
161 |
# epsilon = tf.random.normal(shape = tf.shape(z_mean), seed=seed(), dtype=tf.keras.backend.floatx()) |
|
|
162 |
epsilon = tf.random.normal(shape = tf.shape(z_mean), dtype=tf.keras.backend.floatx()) |
|
|
163 |
z = z_mean + tf.exp(0.5 * z_log_var) * epsilon |
|
|
164 |
z = tf.clip_by_value(z, -1e6, 1e6) |
|
|
165 |
return z</code></pre> |
|
|
166 |
</details> |
|
|
167 |
<h3>Ancestors</h3> |
|
|
168 |
<ul class="hlist"> |
|
|
169 |
<li>keras.src.engine.base_layer.Layer</li> |
|
|
170 |
<li>tensorflow.python.module.module.Module</li> |
|
|
171 |
<li>tensorflow.python.trackable.autotrackable.AutoTrackable</li> |
|
|
172 |
<li>tensorflow.python.trackable.base.Trackable</li> |
|
|
173 |
<li>keras.src.utils.version_utils.LayerVersionSelector</li> |
|
|
174 |
</ul> |
|
|
175 |
<h3>Methods</h3> |
|
|
176 |
<dl> |
|
|
177 |
<dt id="VITAE.model.Sampling.call"><code class="name flex"> |
|
|
178 |
<span>def <span class="ident">call</span></span>(<span>self, z_mean, z_log_var)</span> |
|
|
179 |
</code></dt> |
|
|
180 |
<dd> |
|
|
181 |
<div class="desc"><p>Return cdf(x) and pdf(x).</p> |
|
|
182 |
<h2 id="parameters">Parameters</h2> |
|
|
183 |
<dl> |
|
|
184 |
<dt><strong><code>z_mean</code></strong> : <code>tf.Tensor</code></dt> |
|
|
185 |
<dd><span><span class="MathJax_Preview">[B, L, d]</span><script type="math/tex">[B, L, d]</script></span> The mean of <span><span class="MathJax_Preview">z</span><script type="math/tex">z</script></span>.</dd> |
|
|
186 |
<dt><strong><code>z_log_var</code></strong> : <code>tf.Tensor</code></dt> |
|
|
187 |
<dd><span><span class="MathJax_Preview">[B, L, d]</span><script type="math/tex">[B, L, d]</script></span> The log-variance of <span><span class="MathJax_Preview">z</span><script type="math/tex">z</script></span>.</dd> |
|
|
188 |
</dl> |
|
|
189 |
<h2 id="returns">Returns</h2> |
|
|
190 |
<dl> |
|
|
191 |
<dt><strong><code>z</code></strong> : <code>tf.Tensor</code></dt> |
|
|
192 |
<dd><span><span class="MathJax_Preview">[B, L, d]</span><script type="math/tex">[B, L, d]</script></span> The sampled <span><span class="MathJax_Preview">z</span><script type="math/tex">z</script></span>.</dd> |
|
|
193 |
</dl></div> |
|
|
194 |
</dd> |
|
|
195 |
</dl> |
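<p>A minimal sketch of the reparameterization step (shapes are illustrative):</p>
<pre><code class="python">import tensorflow as tf
from VITAE.model import Sampling

B, L, d = 4, 3, 2                    # batch size, MC samples, latent dimension
z_mean = tf.zeros([B, L, d])
z_log_var = tf.zeros([B, L, d])      # log-variance 0, i.e. unit variance
z = Sampling()(z_mean, z_log_var)    # z = mu + exp(0.5 * log_var) * epsilon
print(z.shape)                       # (4, 3, 2)
</code></pre>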
|
|
196 |
</dd> |
|
|
197 |
<dt id="VITAE.model.Encoder"><code class="flex name class"> |
|
|
198 |
<span>class <span class="ident">Encoder</span></span> |
|
|
199 |
<span>(</span><span>dimensions, dim_latent, name='encoder', **kwargs)</span> |
|
|
200 |
</code></dt> |
|
|
201 |
<dd> |
|
|
202 |
<div class="desc"><p>Encoder, model <span><span class="MathJax_Preview">p(Z_i|Y_i,X_i)</span><script type="math/tex">p(Z_i|Y_i,X_i)</script></span>.</p> |
|
|
203 |
<h2 id="parameters">Parameters</h2> |
|
|
204 |
<dl> |
|
|
205 |
<dt><strong><code>dimensions</code></strong> : <code>np.array</code></dt> |
|
|
206 |
<dd>The dimensions of hidden layers of the encoder.</dd> |
|
|
207 |
<dt><strong><code>dim_latent</code></strong> : <code>int</code></dt> |
|
|
208 |
<dd>The latent dimension of the encoder.</dd> |
|
|
209 |
<dt><strong><code>name</code></strong> : <code>str</code>, optional</dt> |
|
|
210 |
<dd>The name of the layer.</dd> |
|
|
211 |
<dt><strong><code>**kwargs</code></strong></dt> |
|
|
212 |
<dd>Extra keyword arguments.</dd> |
|
|
213 |
</dl></div> |
|
|
214 |
<details class="source"> |
|
|
215 |
<summary> |
|
|
216 |
<span>Expand source code</span> |
|
|
217 |
</summary> |
|
|
218 |
<pre><code class="python">class Encoder(Layer): |
|
|
219 |
''' |
|
|
220 |
Encoder, model \(p(Z_i|Y_i,X_i)\). |
|
|
221 |
''' |
|
|
222 |
def __init__(self, dimensions, dim_latent, name='encoder', **kwargs): |
|
|
223 |
''' |
|
|
224 |
Parameters |
|
|
225 |
---------- |
|
|
226 |
dimensions : np.array |
|
|
227 |
The dimensions of hidden layers of the encoder. |
|
|
228 |
dim_latent : int |
|
|
229 |
The latent dimension of the encoder. |
|
|
230 |
name : str, optional |
|
|
231 |
The name of the layer. |
|
|
232 |
**kwargs : |
|
|
233 |
Extra keyword arguments. |
|
|
234 |
''' |
|
|
235 |
super(Encoder, self).__init__(name = name, **kwargs) |
|
|
236 |
self.dense_layers = [Dense(dim, activation = tf.nn.leaky_relu, |
|
|
237 |
name = 'encoder_%i'%(i+1)) \ |
|
|
238 |
for (i, dim) in enumerate(dimensions)] |
|
|
239 |
self.batch_norm_layers = [BatchNormalization(center=False) \ |
|
|
240 |
for _ in range(len((dimensions)))] |
|
|
241 |
self.batch_norm_layers.append(BatchNormalization(center=False)) |
|
|
242 |
self.latent_mean = Dense(dim_latent, name = 'latent_mean') |
|
|
243 |
self.latent_log_var = Dense(dim_latent, name = 'latent_log_var') |
|
|
244 |
self.sampling = Sampling() |
|
|
245 |
|
|
|
246 |
@tf.function |
|
|
247 |
def call(self, x, L=1, is_training=True): |
|
|
248 |
'''Encode the inputs and get the latent variables. |
|
|
249 |
|
|
|
250 |
Parameters |
|
|
251 |
---------- |
|
|
252 |
x : tf.Tensor |
|
|
253 |
            \([B, G]\) The input.
|
|
254 |
L : int, optional |
|
|
255 |
The number of MC samples. |
|
|
256 |
is_training : boolean, optional |
|
|
257 |
Whether in the training or inference mode. |
|
|
258 |
|
|
|
259 |
Returns |
|
|
260 |
---------- |
|
|
261 |
z_mean : tf.Tensor |
|
|
262 |
            \([B, d]\) The mean of \(z\).
|
|
263 |
z_log_var : tf.Tensor |
|
|
264 |
            \([B, d]\) The log-variance of \(z\).
|
|
265 |
z : tf.Tensor |
|
|
266 |
\([B, L, d]\) The sampled \(z\). |
|
|
267 |
''' |
|
|
268 |
for dense, bn in zip(self.dense_layers, self.batch_norm_layers): |
|
|
269 |
x = dense(x) |
|
|
270 |
x = bn(x, training=is_training) |
|
|
271 |
z_mean = self.batch_norm_layers[-1](self.latent_mean(x), training=is_training) |
|
|
272 |
z_log_var = self.latent_log_var(x) |
|
|
273 |
_z_mean = tf.tile(tf.expand_dims(z_mean, 1), (1,L,1)) |
|
|
274 |
_z_log_var = tf.tile(tf.expand_dims(z_log_var, 1), (1,L,1)) |
|
|
275 |
z = self.sampling(_z_mean, _z_log_var) |
|
|
276 |
return z_mean, z_log_var, z</code></pre> |
|
|
277 |
</details> |
|
|
278 |
<h3>Ancestors</h3> |
|
|
279 |
<ul class="hlist"> |
|
|
280 |
<li>keras.src.engine.base_layer.Layer</li> |
|
|
281 |
<li>tensorflow.python.module.module.Module</li> |
|
|
282 |
<li>tensorflow.python.trackable.autotrackable.AutoTrackable</li> |
|
|
283 |
<li>tensorflow.python.trackable.base.Trackable</li> |
|
|
284 |
<li>keras.src.utils.version_utils.LayerVersionSelector</li> |
|
|
285 |
</ul> |
|
|
286 |
<h3>Methods</h3> |
|
|
287 |
<dl> |
|
|
288 |
<dt id="VITAE.model.Encoder.call"><code class="name flex"> |
|
|
289 |
<span>def <span class="ident">call</span></span>(<span>self, x, L=1, is_training=True)</span> |
|
|
290 |
</code></dt> |
|
|
291 |
<dd> |
|
|
292 |
<div class="desc"><p>Encode the inputs and get the latent variables.</p> |
|
|
293 |
<h2 id="parameters">Parameters</h2> |
|
|
294 |
<dl> |
|
|
295 |
<dt><strong><code>x</code></strong> : <code>tf.Tensor</code></dt> |
|
|
296 |
<dd><span><span class="MathJax_Preview">[B, L, d]</span><script type="math/tex">[B, L, d]</script></span> The input.</dd> |
|
|
297 |
<dt><strong><code>L</code></strong> : <code>int</code>, optional</dt> |
|
|
298 |
<dd>The number of MC samples.</dd> |
|
|
299 |
<dt><strong><code>is_training</code></strong> : <code>boolean</code>, optional</dt> |
|
|
300 |
<dd>Whether in the training or inference mode.</dd> |
|
|
301 |
</dl> |
|
|
302 |
<h2 id="returns">Returns</h2> |
|
|
303 |
<dl> |
|
|
304 |
<dt><strong><code>z_mean</code></strong> : <code>tf.Tensor</code></dt> |
|
|
305 |
<dd><span><span class="MathJax_Preview">[B, L, d]</span><script type="math/tex">[B, L, d]</script></span> The mean of <span><span class="MathJax_Preview">z</span><script type="math/tex">z</script></span>.</dd> |
|
|
306 |
<dt><strong><code>z_log_var</code></strong> : <code>tf.Tensor</code></dt> |
|
|
307 |
<dd><span><span class="MathJax_Preview">[B, L, d]</span><script type="math/tex">[B, L, d]</script></span> The log-variance of <span><span class="MathJax_Preview">z</span><script type="math/tex">z</script></span>.</dd> |
|
|
308 |
<dt><strong><code>z</code></strong> : <code>tf.Tensor</code></dt> |
|
|
309 |
<dd><span><span class="MathJax_Preview">[B, L, d]</span><script type="math/tex">[B, L, d]</script></span> The sampled <span><span class="MathJax_Preview">z</span><script type="math/tex">z</script></span>.</dd> |
|
|
310 |
</dl></div> |
|
|
311 |
</dd> |
|
|
312 |
</dl> |
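<p>A minimal usage sketch (hypothetical dimensions; note that <code>z_mean</code> and <code>z_log_var</code> carry no MC-sample axis, while <code>z</code> does):</p>
<pre><code class="python">import numpy as np
import tensorflow as tf
from VITAE.model import Encoder

encoder = Encoder(dimensions=np.array([32, 16]), dim_latent=2)
x = tf.random.normal([8, 64])                       # [B, G] preprocessed inputs
z_mean, z_log_var, z = encoder(x, L=5, is_training=False)
print(z_mean.shape, z.shape)                        # (8, 2) (8, 5, 2)
</code></pre>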
|
|
313 |
</dd> |
|
|
314 |
<dt id="VITAE.model.Decoder"><code class="flex name class"> |
|
|
315 |
<span>class <span class="ident">Decoder</span></span> |
|
|
316 |
<span>(</span><span>dimensions, dim_origin, data_type='UMI', name='decoder', **kwargs)</span> |
|
|
317 |
</code></dt> |
|
|
318 |
<dd> |
|
|
319 |
<div class="desc"><p>Decoder, model <span><span class="MathJax_Preview">p(Y_i|Z_i,X_i)</span><script type="math/tex">p(Y_i|Z_i,X_i)</script></span>.</p> |
|
|
320 |
<h2 id="parameters">Parameters</h2> |
|
|
321 |
<dl> |
|
|
322 |
<dt><strong><code>dimensions</code></strong> : <code>np.array</code></dt> |
|
|
323 |
<dd>The dimensions of hidden layers of the decoder.</dd>
|
|
324 |
<dt><strong><code>dim_origin</code></strong> : <code>int</code></dt> |
|
|
325 |
<dd>The output dimension of the decoder.</dd> |
|
|
326 |
<dt><strong><code>data_type</code></strong> : <code>str</code>, optional</dt> |
|
|
327 |
<dd><code>'UMI'</code>, <code>'non-UMI'</code>, or <code>'Gaussian'</code>.</dd> |
|
|
328 |
<dt><strong><code>name</code></strong> : <code>str</code>, optional</dt> |
|
|
329 |
<dd>The name of the layer.</dd> |
|
|
330 |
</dl></div> |
|
|
331 |
<details class="source"> |
|
|
332 |
<summary> |
|
|
333 |
<span>Expand source code</span> |
|
|
334 |
</summary> |
|
|
335 |
<pre><code class="python">class Decoder(Layer): |
|
|
336 |
''' |
|
|
337 |
Decoder, model \(p(Y_i|Z_i,X_i)\). |
|
|
338 |
''' |
|
|
339 |
def __init__(self, dimensions, dim_origin, data_type = 'UMI', |
|
|
340 |
name = 'decoder', **kwargs): |
|
|
341 |
''' |
|
|
342 |
Parameters |
|
|
343 |
---------- |
|
|
344 |
dimensions : np.array |
|
|
345 |
            The dimensions of hidden layers of the decoder.
|
|
346 |
dim_origin : int |
|
|
347 |
The output dimension of the decoder. |
|
|
348 |
data_type : str, optional |
|
|
349 |
`'UMI'`, `'non-UMI'`, or `'Gaussian'`. |
|
|
350 |
name : str, optional |
|
|
351 |
The name of the layer. |
|
|
352 |
''' |
|
|
353 |
super(Decoder, self).__init__(name = name, **kwargs) |
|
|
354 |
self.data_type = data_type |
|
|
355 |
self.dense_layers = [Dense(dim, activation = tf.nn.leaky_relu, |
|
|
356 |
name = 'decoder_%i'%(i+1)) \ |
|
|
357 |
for (i,dim) in enumerate(dimensions)] |
|
|
358 |
self.batch_norm_layers = [BatchNormalization(center=False) \ |
|
|
359 |
for _ in range(len((dimensions)))] |
|
|
360 |
|
|
|
361 |
if data_type=='Gaussian': |
|
|
362 |
self.nu_z = Dense(dim_origin, name = 'nu_z') |
|
|
363 |
# common variance |
|
|
364 |
self.log_tau = tf.Variable(tf.zeros([1, dim_origin], dtype=tf.keras.backend.floatx()), |
|
|
365 |
constraint = lambda t: tf.clip_by_value(t,-30.,6.), |
|
|
366 |
name = "log_tau") |
|
|
367 |
else: |
|
|
368 |
self.log_lambda_z = Dense(dim_origin, name = 'log_lambda_z') |
|
|
369 |
|
|
|
370 |
# dispersion parameter |
|
|
371 |
self.log_r = tf.Variable(tf.zeros([1, dim_origin], dtype=tf.keras.backend.floatx()), |
|
|
372 |
constraint = lambda t: tf.clip_by_value(t,-30.,6.), |
|
|
373 |
name = "log_r") |
|
|
374 |
|
|
|
375 |
if self.data_type == 'non-UMI': |
|
|
376 |
self.phi = Dense(dim_origin, activation = 'sigmoid', name = "phi") |
|
|
377 |
|
|
|
378 |
@tf.function |
|
|
379 |
def call(self, z, is_training=True): |
|
|
380 |
'''Decode the latent variables and get the reconstructions. |
|
|
381 |
|
|
|
382 |
Parameters |
|
|
383 |
---------- |
|
|
384 |
z : tf.Tensor |
|
|
385 |
            \([B, L, d]\) The sampled \(z\).
|
|
386 |
is_training : boolean, optional |
|
|
387 |
            Whether in the training or inference mode.
|
|
388 |
|
|
|
389 |
When `data_type=='Gaussian'`: |
|
|
390 |
|
|
|
391 |
Returns |
|
|
392 |
---------- |
|
|
393 |
nu_z : tf.Tensor |
|
|
394 |
\([B, L, G]\) The mean of \(Y_i|Z_i,X_i\). |
|
|
395 |
tau : tf.Tensor |
|
|
396 |
\([1, G]\) The variance of \(Y_i|Z_i,X_i\). |
|
|
397 |
|
|
|
398 |
When `data_type=='UMI'`: |
|
|
399 |
|
|
|
400 |
Returns |
|
|
401 |
---------- |
|
|
402 |
lambda_z : tf.Tensor |
|
|
403 |
\([B, L, G]\) The mean of \(Y_i|Z_i,X_i\). |
|
|
404 |
r : tf.Tensor |
|
|
405 |
\([1, G]\) The dispersion parameters of \(Y_i|Z_i,X_i\). |
|
|
406 |
|
|
|
407 |
When `data_type=='non-UMI'`: |
|
|
408 |
|
|
|
409 |
Returns |
|
|
410 |
---------- |
|
|
411 |
lambda_z : tf.Tensor |
|
|
412 |
\([B, L, G]\) The mean of \(Y_i|Z_i,X_i\). |
|
|
413 |
r : tf.Tensor |
|
|
414 |
\([1, G]\) The dispersion parameters of \(Y_i|Z_i,X_i\). |
|
|
415 |
phi_z : tf.Tensor |
|
|
416 |
            \([1, G]\) The zero-inflation parameters of \(Y_i|Z_i,X_i\).
|
|
417 |
''' |
|
|
418 |
for dense, bn in zip(self.dense_layers, self.batch_norm_layers): |
|
|
419 |
z = dense(z) |
|
|
420 |
z = bn(z, training=is_training) |
|
|
421 |
if self.data_type=='Gaussian': |
|
|
422 |
nu_z = self.nu_z(z) |
|
|
423 |
tau = tf.exp(self.log_tau) |
|
|
424 |
return nu_z, tau |
|
|
425 |
else: |
|
|
426 |
lambda_z = tf.math.exp( |
|
|
427 |
tf.clip_by_value(self.log_lambda_z(z), -30., 6.) |
|
|
428 |
) |
|
|
429 |
r = tf.exp(self.log_r) |
|
|
430 |
if self.data_type=='UMI': |
|
|
431 |
return lambda_z, r |
|
|
432 |
else: |
|
|
433 |
return lambda_z, r, self.phi(z)</code></pre> |
|
|
434 |
</details> |
|
|
435 |
<h3>Ancestors</h3> |
|
|
436 |
<ul class="hlist"> |
|
|
437 |
<li>keras.src.engine.base_layer.Layer</li> |
|
|
438 |
<li>tensorflow.python.module.module.Module</li> |
|
|
439 |
<li>tensorflow.python.trackable.autotrackable.AutoTrackable</li> |
|
|
440 |
<li>tensorflow.python.trackable.base.Trackable</li> |
|
|
441 |
<li>keras.src.utils.version_utils.LayerVersionSelector</li> |
|
|
442 |
</ul> |
|
|
443 |
<h3>Methods</h3> |
|
|
444 |
<dl> |
|
|
445 |
<dt id="VITAE.model.Decoder.call"><code class="name flex"> |
|
|
446 |
<span>def <span class="ident">call</span></span>(<span>self, z, is_training=True)</span> |
|
|
447 |
</code></dt> |
|
|
448 |
<dd> |
|
|
449 |
<div class="desc"><p>Decode the latent variables and get the reconstructions.</p> |
|
|
450 |
<h2 id="parameters">Parameters</h2> |
|
|
451 |
<dl> |
|
|
452 |
<dt><strong><code>z</code></strong> : <code>tf.Tensor</code></dt> |
|
|
453 |
<dd><span><span class="MathJax_Preview">[B, L, d]</span><script type="math/tex">[B, L, d]</script></span> the sampled <span><span class="MathJax_Preview">z</span><script type="math/tex">z</script></span>.</dd> |
|
|
454 |
<dt><strong><code>is_training</code></strong> : <code>boolean</code>, optional</dt> |
|
|
455 |
<dd>Whether in the training or inference mode.</dd>
|
|
456 |
</dl> |
|
|
457 |
<p>When <code>data_type=='Gaussian'</code>:</p> |
|
|
458 |
<h2 id="returns">Returns</h2> |
|
|
459 |
<dl> |
|
|
460 |
<dt><strong><code>nu_z</code></strong> : <code>tf.Tensor</code></dt> |
|
|
461 |
<dd><span><span class="MathJax_Preview">[B, L, G]</span><script type="math/tex">[B, L, G]</script></span> The mean of <span><span class="MathJax_Preview">Y_i|Z_i,X_i</span><script type="math/tex">Y_i|Z_i,X_i</script></span>.</dd> |
|
|
462 |
<dt><strong><code>tau</code></strong> : <code>tf.Tensor</code></dt> |
|
|
463 |
<dd><span><span class="MathJax_Preview">[1, G]</span><script type="math/tex">[1, G]</script></span> The variance of <span><span class="MathJax_Preview">Y_i|Z_i,X_i</span><script type="math/tex">Y_i|Z_i,X_i</script></span>.</dd> |
|
|
464 |
</dl> |
|
|
465 |
<p>When <code>data_type=='UMI'</code>:</p> |
|
|
466 |
<h2 id="returns_1">Returns</h2> |
|
|
467 |
<dl> |
|
|
468 |
<dt><strong><code>lambda_z</code></strong> : <code>tf.Tensor</code></dt> |
|
|
469 |
<dd><span><span class="MathJax_Preview">[B, L, G]</span><script type="math/tex">[B, L, G]</script></span> The mean of <span><span class="MathJax_Preview">Y_i|Z_i,X_i</span><script type="math/tex">Y_i|Z_i,X_i</script></span>.</dd> |
|
|
470 |
<dt><strong><code>r</code></strong> : <code>tf.Tensor</code></dt> |
|
|
471 |
<dd><span><span class="MathJax_Preview">[1, G]</span><script type="math/tex">[1, G]</script></span> The dispersion parameters of <span><span class="MathJax_Preview">Y_i|Z_i,X_i</span><script type="math/tex">Y_i|Z_i,X_i</script></span>.</dd> |
|
|
472 |
</dl> |
|
|
473 |
<p>When <code>data_type=='non-UMI'</code>:</p> |
|
|
474 |
<h2 id="returns_2">Returns</h2> |
|
|
475 |
<dl> |
|
|
476 |
<dt><strong><code>lambda_z</code></strong> : <code>tf.Tensor</code></dt> |
|
|
477 |
<dd><span><span class="MathJax_Preview">[B, L, G]</span><script type="math/tex">[B, L, G]</script></span> The mean of <span><span class="MathJax_Preview">Y_i|Z_i,X_i</span><script type="math/tex">Y_i|Z_i,X_i</script></span>.</dd> |
|
|
478 |
<dt><strong><code>r</code></strong> : <code>tf.Tensor</code></dt> |
|
|
479 |
<dd><span><span class="MathJax_Preview">[1, G]</span><script type="math/tex">[1, G]</script></span> The dispersion parameters of <span><span class="MathJax_Preview">Y_i|Z_i,X_i</span><script type="math/tex">Y_i|Z_i,X_i</script></span>.</dd> |
|
|
480 |
<dt><strong><code>phi_z</code></strong> : <code>tf.Tensor</code></dt> |
|
|
481 |
<dd><span><span class="MathJax_Preview">[1, G]</span><script type="math/tex">[1, G]</script></span> The zero inflated parameters of <span><span class="MathJax_Preview">Y_i|Z_i,X_i</span><script type="math/tex">Y_i|Z_i,X_i</script></span>.</dd> |
|
|
482 |
</dl></div> |
|
|
483 |
</dd> |
|
|
484 |
</dl> |
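<p>A minimal sketch of the <code>'UMI'</code> branch (hypothetical dimensions):</p>
<pre><code class="python">import numpy as np
import tensorflow as tf
from VITAE.model import Decoder

decoder = Decoder(dimensions=np.array([16, 32]), dim_origin=64, data_type='UMI')
z = tf.random.normal([8, 5, 2])              # [B, L, d] sampled latent variables
lambda_z, r = decoder(z, is_training=False)  # mean and dispersion of the NB model
print(lambda_z.shape, r.shape)               # (8, 5, 64) (1, 64)
</code></pre>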
|
|
485 |
</dd> |
|
|
486 |
<dt id="VITAE.model.LatentSpace"><code class="flex name class"> |
|
|
487 |
<span>class <span class="ident">LatentSpace</span></span> |
|
|
488 |
<span>(</span><span>n_clusters, dim_latent, name='LatentSpace', seed=0, **kwargs)</span> |
|
|
489 |
</code></dt> |
|
|
490 |
<dd> |
|
|
491 |
<div class="desc"><p>Layer for the Latent Space.</p> |
|
|
492 |
<h2 id="parameters">Parameters</h2> |
|
|
493 |
<dl> |
|
|
494 |
<dt><strong><code>n_clusters</code></strong> : <code>int</code></dt> |
|
|
495 |
<dd>The number of vertices in the latent space.</dd> |
|
|
496 |
<dt><strong><code>dim_latent</code></strong> : <code>int</code></dt> |
|
|
497 |
<dd>The latent dimension.</dd> |
|
|
498 |
<dt><strong><code>seed</code></strong> : <code>int</code>, optional</dt>
<dd>The random seed for initializing the cluster positions.</dd>
|
|
500 |
<dt><strong><code>name</code></strong> : <code>str</code>, optional</dt> |
|
|
501 |
<dd>The name of the layer.</dd> |
|
|
502 |
<dt><strong><code>**kwargs</code></strong></dt> |
|
|
503 |
<dd>Extra keyword arguments.</dd> |
|
|
504 |
</dl></div> |
|
|
505 |
<details class="source"> |
|
|
506 |
<summary> |
|
|
507 |
<span>Expand source code</span> |
|
|
508 |
</summary> |
|
|
509 |
<pre><code class="python">class LatentSpace(Layer): |
|
|
510 |
''' |
|
|
511 |
Layer for the Latent Space. |
|
|
512 |
''' |
|
|
513 |
def __init__(self, n_clusters, dim_latent, |
|
|
514 |
name = 'LatentSpace', seed=0, **kwargs): |
|
|
515 |
''' |
|
|
516 |
Parameters |
|
|
517 |
---------- |
|
|
518 |
n_clusters : int |
|
|
519 |
The number of vertices in the latent space. |
|
|
520 |
dim_latent : int |
|
|
521 |
The latent dimension. |
|
|
522 |
        seed : int, optional
            The random seed for initializing the cluster positions.
|
|
524 |
name : str, optional |
|
|
525 |
The name of the layer. |
|
|
526 |
**kwargs : |
|
|
527 |
Extra keyword arguments. |
|
|
528 |
''' |
|
|
529 |
super(LatentSpace, self).__init__(name=name, **kwargs) |
|
|
530 |
self.dim_latent = dim_latent |
|
|
531 |
self.n_states = n_clusters |
|
|
532 |
self.n_categories = int(n_clusters*(n_clusters+1)/2) |
|
|
533 |
|
|
|
534 |
# nonzero indexes |
|
|
535 |
# A = [0,0,...,0 , 1,1,...,1, ...] |
|
|
536 |
# B = [0,1,...,k-1, 1,2,...,k-1, ...] |
|
|
537 |
self.A, self.B = np.nonzero(np.triu(np.ones(n_clusters))) |
|
|
538 |
self.A = tf.convert_to_tensor(self.A, tf.int32) |
|
|
539 |
self.B = tf.convert_to_tensor(self.B, tf.int32) |
|
|
540 |
self.clusters_ind = tf.boolean_mask( |
|
|
541 |
tf.range(0,self.n_categories,1), self.A==self.B) |
|
|
542 |
|
|
|
543 |
# [pi_1, ... , pi_K] in R^(n_categories) |
|
|
544 |
self.pi = tf.Variable(tf.ones([1, self.n_categories], dtype=tf.keras.backend.floatx()) / self.n_categories, |
|
|
545 |
name = 'pi') |
|
|
546 |
|
|
|
547 |
# [mu_1, ... , mu_K] in R^(dim_latent * n_clusters) |
|
|
548 |
self.mu = tf.Variable(tf.random.uniform([self.dim_latent, self.n_states], |
|
|
549 |
minval = -1, maxval = 1, seed=seed, dtype=tf.keras.backend.floatx()), |
|
|
550 |
name = 'mu') |
|
|
551 |
self.cdf_layer = cdf_layer() |
|
|
552 |
|
|
|
553 |
def initialize(self, mu, log_pi): |
|
|
554 |
'''Initialize the latent space. |
|
|
555 |
|
|
|
556 |
Parameters |
|
|
557 |
---------- |
|
|
558 |
mu : np.array |
|
|
559 |
\([d, k]\) The position matrix. |
|
|
560 |
log_pi : np.array |
|
|
561 |
\([1, K]\) \(\\log\\pi\). |
|
|
562 |
''' |
|
|
563 |
# Initialize parameters of the latent space |
|
|
564 |
if mu is not None: |
|
|
565 |
self.mu.assign(mu) |
|
|
566 |
if log_pi is not None: |
|
|
567 |
self.pi.assign(log_pi) |
|
|
568 |
|
|
|
569 |
def normalize(self): |
|
|
570 |
'''Normalize \(\\pi\). |
|
|
571 |
''' |
|
|
572 |
self.pi = tf.nn.softmax(self.pi) |
|
|
573 |
|
|
|
574 |
@tf.function |
|
|
575 |
def _get_normal_params(self, z, pi): |
|
|
576 |
batch_size = tf.shape(z)[0] |
|
|
577 |
L = tf.shape(z)[1] |
|
|
578 |
|
|
|
579 |
# [batch_size, L, n_categories] |
|
|
580 |
if pi is None: |
|
|
581 |
# [batch_size, L, n_states] |
|
|
582 |
temp_pi = tf.tile( |
|
|
583 |
tf.expand_dims(tf.nn.softmax(self.pi), 1), |
|
|
584 |
(batch_size,L,1)) |
|
|
585 |
else: |
|
|
586 |
temp_pi = tf.expand_dims(tf.nn.softmax(pi), 1) |
|
|
587 |
|
|
|
588 |
# [batch_size, L, d, n_categories] |
|
|
589 |
alpha_zc = tf.expand_dims(tf.expand_dims( |
|
|
590 |
tf.gather(self.mu, self.B, axis=1) - tf.gather(self.mu, self.A, axis=1), 0), 0) |
|
|
591 |
beta_zc = tf.expand_dims(z,-1) - \ |
|
|
592 |
tf.expand_dims(tf.expand_dims( |
|
|
593 |
tf.gather(self.mu, self.B, axis=1), 0), 0) |
|
|
594 |
|
|
|
595 |
# [batch_size, L, n_categories] |
|
|
596 |
inv_sig = tf.reduce_sum(alpha_zc * alpha_zc, axis=2) |
|
|
597 |
nu = - tf.reduce_sum(alpha_zc * beta_zc, axis=2) * tf.math.reciprocal_no_nan(inv_sig) |
|
|
598 |
_t = - tf.reduce_sum(beta_zc * beta_zc, axis=2) + nu**2*inv_sig |
|
|
599 |
return temp_pi, beta_zc, inv_sig, nu, _t |
|
|
600 |
|
|
|
601 |
@tf.function |
|
|
602 |
def _get_pz(self, temp_pi, inv_sig, beta_zc, log_p_z_c_L): |
|
|
603 |
# [batch_size, L, n_categories] |
|
|
604 |
log_p_zc_L = - 0.5 * self.dim_latent * tf.math.log(tf.constant(2 * np.pi, tf.keras.backend.floatx())) + \ |
|
|
605 |
tf.math.log(temp_pi) + \ |
|
|
606 |
tf.where(inv_sig==0, |
|
|
607 |
- 0.5 * tf.reduce_sum(beta_zc**2, axis=2), |
|
|
608 |
log_p_z_c_L) |
|
|
609 |
|
|
|
610 |
# [batch_size, L, 1] |
|
|
611 |
log_p_z_L = tf.reduce_logsumexp(log_p_zc_L, axis=-1, keepdims=True) |
|
|
612 |
|
|
|
613 |
# [1, ] |
|
|
614 |
log_p_z = tf.reduce_mean(log_p_z_L) |
|
|
615 |
return log_p_zc_L, log_p_z_L, log_p_z |
|
|
616 |
|
|
|
617 |
@tf.function |
|
|
618 |
def _get_posterior_c(self, log_p_zc_L, log_p_z_L): |
|
|
619 |
L = tf.shape(log_p_z_L)[1] |
|
|
620 |
|
|
|
621 |
# log_p_c_x - predicted probability distribution |
|
|
622 |
# [batch_size, n_categories] |
|
|
623 |
log_p_c_x = tf.reduce_logsumexp( |
|
|
624 |
log_p_zc_L - log_p_z_L, |
|
|
625 |
axis=1) - tf.math.log(tf.cast(L, tf.keras.backend.floatx())) |
|
|
626 |
return log_p_c_x |
|
|
627 |
|
|
|
628 |
@tf.function |
|
|
629 |
def _get_inference(self, z, log_p_z_L, temp_pi, inv_sig, nu, beta_zc, log_eta0, eta1, eta2): |
|
|
630 |
batch_size = tf.shape(z)[0] |
|
|
631 |
L = tf.shape(z)[1] |
|
|
632 |
dist = tfp.distributions.Normal( |
|
|
633 |
loc = tf.constant(0.0, tf.keras.backend.floatx()), |
|
|
634 |
scale = tf.constant(1.0, tf.keras.backend.floatx()), |
|
|
635 |
allow_nan_stats=False) |
|
|
636 |
|
|
|
637 |
# [batch_size, L, n_categories, n_clusters] |
|
|
638 |
inv_sig = tf.expand_dims(inv_sig, -1) |
|
|
639 |
_sig = tf.tile(tf.clip_by_value(tf.math.reciprocal_no_nan(inv_sig), 1e-12, 1e30), (1,1,1,self.n_states)) |
|
|
640 |
log_eta0 = tf.tile(tf.expand_dims(log_eta0, -1), (1,1,1,self.n_states)) |
|
|
641 |
eta1 = tf.tile(tf.expand_dims(eta1, -1), (1,1,1,self.n_states)) |
|
|
642 |
eta2 = tf.tile(tf.expand_dims(eta2, -1), (1,1,1,self.n_states)) |
|
|
643 |
nu = tf.tile(tf.expand_dims(nu, -1), (1,1,1,1)) |
|
|
644 |
A = tf.tile(tf.expand_dims(tf.expand_dims( |
|
|
645 |
tf.one_hot(self.A, self.n_states, dtype=tf.keras.backend.floatx()), |
|
|
646 |
0),0), (batch_size,L,1,1)) |
|
|
647 |
B = tf.tile(tf.expand_dims(tf.expand_dims( |
|
|
648 |
tf.one_hot(self.B, self.n_states, dtype=tf.keras.backend.floatx()), |
|
|
649 |
0),0), (batch_size,L,1,1)) |
|
|
650 |
temp_pi = tf.expand_dims(temp_pi, -1) |
|
|
651 |
|
|
|
652 |
# w_tilde [batch_size, L, n_clusters] |
|
|
653 |
w_tilde = log_eta0 + tf.math.log( |
|
|
654 |
tf.clip_by_value( |
|
|
655 |
(dist.cdf(eta1) - dist.cdf(eta2)) * (nu * A + (1-nu) * B) - |
|
|
656 |
(dist.prob(eta1) - dist.prob(eta2)) * tf.math.sqrt(_sig) * (A - B), |
|
|
657 |
0.0, 1e30) |
|
|
658 |
) |
|
|
659 |
w_tilde = - 0.5 * self.dim_latent * tf.math.log(tf.constant(2 * np.pi, tf.keras.backend.floatx())) + \ |
|
|
660 |
tf.math.log(temp_pi) + \ |
|
|
661 |
tf.where(inv_sig==0, |
|
|
662 |
tf.where(B==1, - 0.5 * tf.expand_dims(tf.reduce_sum(beta_zc**2, axis=2), -1), -np.inf), |
|
|
663 |
w_tilde) |
|
|
664 |
w_tilde = tf.exp(tf.reduce_logsumexp(w_tilde, 2) - log_p_z_L) |
|
|
665 |
|
|
|
666 |
# tf.debugging.assert_greater_equal( |
|
|
667 |
# tf.reduce_sum(w_tilde, -1), tf.ones([batch_size, L], dtype=tf.keras.backend.floatx())*0.99, |
|
|
668 |
# message='Wrong w_tilde', summarize=None, name=None |
|
|
669 |
# ) |
|
|
670 |
|
|
|
671 |
# var_w_tilde [batch_size, L, n_clusters] |
|
|
672 |
var_w_tilde = log_eta0 + tf.math.log( |
|
|
673 |
tf.clip_by_value( |
|
|
674 |
(dist.cdf(eta1) - dist.cdf(eta2)) * ((_sig + nu**2) * (A+B) + (1-2*nu) * B) - |
|
|
675 |
(dist.prob(eta1) - dist.prob(eta2)) * tf.math.sqrt(_sig) * (nu *(A+B)-B )*2 - |
|
|
676 |
(eta1*dist.prob(eta1) - eta2*dist.prob(eta2)) * _sig *(A+B), |
|
|
677 |
0.0, 1e30) |
|
|
678 |
) |
|
|
679 |
var_w_tilde = - 0.5 * self.dim_latent * tf.math.log(tf.constant(2 * np.pi, tf.keras.backend.floatx())) + \ |
|
|
680 |
tf.math.log(temp_pi) + \ |
|
|
681 |
tf.where(inv_sig==0, |
|
|
682 |
tf.where(B==1, - 0.5 * tf.expand_dims(tf.reduce_sum(beta_zc**2, axis=2), -1), -np.inf), |
|
|
683 |
var_w_tilde) |
|
|
684 |
var_w_tilde = tf.exp(tf.reduce_logsumexp(var_w_tilde, 2) - log_p_z_L) - w_tilde**2 |
|
|
685 |
|
|
|
686 |
|
|
|
687 |
w_tilde = tf.reduce_mean(w_tilde, 1) |
|
|
688 |
var_w_tilde = tf.reduce_mean(var_w_tilde, 1) |
|
|
689 |
return w_tilde, var_w_tilde |
|
|
690 |
|
|
|
691 |
def get_pz(self, z, eps, pi): |
|
|
692 |
'''Get \(\\log p(Z_i|Y_i,X_i)\). |
|
|
693 |
|
|
|
694 |
Parameters |
|
|
695 |
---------- |
|
|
696 |
        z : tf.Tensor
            \([B, L, d]\) The latent variables.
        eps : float
            The lower bound used when clipping probabilities.
        pi : tf.Tensor, optional
            \([B, K]\) Unnormalized logits of \(\\pi\); if `None`, the layer's own \(\\pi\) is used.
|
|
698 |
|
|
|
699 |
Returns |
|
|
700 |
---------- |
|
|
701 |
temp_pi : tf.Tensor |
|
|
702 |
\([B, L, K]\) \(\\pi\). |
|
|
703 |
inv_sig : tf.Tensor |
|
|
704 |
\([B, L, K]\) \(\\sigma_{Z_ic_i}^{-1}\). |
|
|
705 |
nu : tf.Tensor |
|
|
706 |
\([B, L, K]\) \(\\nu_{Z_ic_i}\). |
|
|
707 |
beta_zc : tf.Tensor |
|
|
708 |
\([B, L, d, K]\) \(\\beta_{Z_ic_i}\). |
|
|
709 |
log_eta0 : tf.Tensor |
|
|
710 |
\([B, L, K]\) \(\\log\\eta_{Z_ic_i,0}\). |
|
|
711 |
eta1 : tf.Tensor |
|
|
712 |
\([B, L, K]\) \(\\eta_{Z_ic_i,1}\). |
|
|
713 |
eta2 : tf.Tensor |
|
|
714 |
\([B, L, K]\) \(\\eta_{Z_ic_i,2}\). |
|
|
715 |
log_p_zc_L : tf.Tensor |
|
|
716 |
\([B, L, K]\) \(\\log p(Z_i,c_i|Y_i,X_i)\). |
|
|
717 |
log_p_z_L : tf.Tensor |
|
|
718 |
\([B, L]\) \(\\log p(Z_i|Y_i,X_i)\). |
|
|
719 |
log_p_z : tf.Tensor |
|
|
720 |
            The estimated \(\\log p(Z_i|Y_i,X_i)\), averaged over the batch.
|
|
721 |
''' |
|
|
722 |
temp_pi, beta_zc, inv_sig, nu, _t = self._get_normal_params(z, pi) |
|
|
723 |
temp_pi = tf.clip_by_value(temp_pi, eps, 1.0) |
|
|
724 |
|
|
|
725 |
log_eta0 = 0.5 * (tf.math.log(tf.constant(2 * np.pi, tf.keras.backend.floatx())) - \ |
|
|
726 |
tf.math.log(tf.clip_by_value(inv_sig, 1e-12, 1e30)) + _t) |
|
|
727 |
eta1 = (1-nu) * tf.math.sqrt(tf.clip_by_value(inv_sig, 1e-12, 1e30)) |
|
|
728 |
eta2 = -nu * tf.math.sqrt(tf.clip_by_value(inv_sig, 1e-12, 1e30)) |
|
|
729 |
|
|
|
730 |
log_p_z_c_L = log_eta0 + tf.math.log(tf.clip_by_value( |
|
|
731 |
self.cdf_layer(eta1) - self.cdf_layer(eta2), |
|
|
732 |
eps, 1e30)) |
|
|
733 |
|
|
|
734 |
log_p_zc_L, log_p_z_L, log_p_z = self._get_pz(temp_pi, inv_sig, beta_zc, log_p_z_c_L) |
|
|
735 |
return temp_pi, inv_sig, nu, beta_zc, log_eta0, eta1, eta2, log_p_zc_L, log_p_z_L, log_p_z |
|
|
736 |
|
|
|
737 |
def get_posterior_c(self, z): |
|
|
738 |
'''Get \(p(c_i|Y_i,X_i)\). |
|
|
739 |
|
|
|
740 |
Parameters |
|
|
741 |
---------- |
|
|
742 |
z : tf.Tensor |
|
|
743 |
\([B, L, d]\) The latent variables. |
|
|
744 |
|
|
|
745 |
Returns |
|
|
746 |
---------- |
|
|
747 |
p_c_x : np.array |
|
|
748 |
\([B, K]\) \(p(c_i|Y_i,X_i)\). |
|
|
749 |
''' |
|
|
750 |
        _,_,_,_,_,_,_, log_p_zc_L, log_p_z_L, _ = self.get_pz(z, 1e-16, None)  # get_pz requires eps and pi; use the default numerical floor
|
|
751 |
log_p_c_x = self._get_posterior_c(log_p_zc_L, log_p_z_L) |
|
|
752 |
p_c_x = tf.exp(log_p_c_x).numpy() |
|
|
753 |
return p_c_x |
|
|
754 |
|
|
|
755 |
def call(self, z, pi=None, inference=False): |
|
|
756 |
'''Get posterior estimations. |
|
|
757 |
|
|
|
758 |
Parameters |
|
|
759 |
---------- |
|
|
760 |
        z : tf.Tensor
            \([B, L, d]\) The latent variables.
        pi : tf.Tensor, optional
            \([B, K]\) Unnormalized logits of \(\\pi\); if `None`, the layer's own \(\\pi\) is used.
        inference : boolean
            Whether in training or inference mode.
|
|
764 |
|
|
|
765 |
When `inference=False`: |
|
|
766 |
|
|
|
767 |
Returns |
|
|
768 |
---------- |
|
|
769 |
        log_p_z : tf.Tensor
            The estimated \(\\log p(Z_i|Y_i,X_i)\), averaged over the batch.
|
|
771 |
|
|
|
772 |
When `inference=True`: |
|
|
773 |
|
|
|
774 |
Returns |
|
|
775 |
---------- |
|
|
776 |
res : dict |
|
|
777 |
            The dict of posterior estimations: `p_c_x` \(p(c_i|Y_i,X_i)\), `w_tilde` \(E(\\tilde{w}_i|Y_i,X_i)\), and `var_w_tilde` \(Var(\\tilde{w}_i|Y_i,X_i)\).
|
|
778 |
''' |
|
|
779 |
eps = 1e-16 if not inference else 0. |
|
|
780 |
temp_pi, inv_sig, nu, beta_zc, log_eta0, eta1, eta2, log_p_zc_L, log_p_z_L, log_p_z = self.get_pz(z, eps, pi) |
|
|
781 |
|
|
|
782 |
if not inference: |
|
|
783 |
return log_p_z |
|
|
784 |
else: |
|
|
785 |
log_p_c_x = self._get_posterior_c(log_p_zc_L, log_p_z_L) |
|
|
786 |
w_tilde, var_w_tilde = self._get_inference(z, log_p_z_L, temp_pi, inv_sig, nu, beta_zc, log_eta0, eta1, eta2) |
|
|
787 |
|
|
|
788 |
res = {} |
|
|
789 |
res['p_c_x'] = tf.exp(log_p_c_x).numpy() |
|
|
790 |
res['w_tilde'] = w_tilde.numpy() |
|
|
791 |
res['var_w_tilde'] = var_w_tilde.numpy() |
|
|
792 |
return res</code></pre> |
|
|
793 |
</details> |
|
|
794 |
<h3>Ancestors</h3> |
|
|
795 |
<ul class="hlist"> |
|
|
796 |
<li>keras.src.engine.base_layer.Layer</li> |
|
|
797 |
<li>tensorflow.python.module.module.Module</li> |
|
|
798 |
<li>tensorflow.python.trackable.autotrackable.AutoTrackable</li> |
|
|
799 |
<li>tensorflow.python.trackable.base.Trackable</li> |
|
|
800 |
<li>keras.src.utils.version_utils.LayerVersionSelector</li> |
|
|
801 |
</ul> |
|
|
802 |
<h3>Methods</h3> |
|
|
803 |
<dl> |
|
|
804 |
<dt id="VITAE.model.LatentSpace.initialize"><code class="name flex"> |
|
|
805 |
<span>def <span class="ident">initialize</span></span>(<span>self, mu, log_pi)</span> |
|
|
806 |
</code></dt> |
|
|
807 |
<dd> |
|
|
808 |
<div class="desc"><p>Initialize the latent space.</p> |
|
|
809 |
<h2 id="parameters">Parameters</h2> |
|
|
810 |
<dl> |
|
|
811 |
<dt><strong><code>mu</code></strong> : <code>np.array</code></dt> |
|
|
812 |
<dd><span><span class="MathJax_Preview">[d, k]</span><script type="math/tex">[d, k]</script></span> The position matrix.</dd> |
|
|
813 |
<dt><strong><code>log_pi</code></strong> : <code>np.array</code></dt> |
|
|
814 |
<dd><span><span class="MathJax_Preview">[1, K]</span><script type="math/tex">[1, K]</script></span> <span><span class="MathJax_Preview">\log\pi</span><script type="math/tex">\log\pi</script></span>.</dd> |
|
|
815 |
</dl></div> |
|
|
816 |
</dd> |
|
|
817 |
<dt id="VITAE.model.LatentSpace.normalize"><code class="name flex"> |
|
|
818 |
<span>def <span class="ident">normalize</span></span>(<span>self)</span> |
|
|
819 |
</code></dt> |
|
|
820 |
<dd> |
|
|
821 |
<div class="desc"><p>Normalize <span><span class="MathJax_Preview">\pi</span><script type="math/tex">\pi</script></span>.</p></div> |
|
|
822 |
</dd> |
|
|
823 |
<dt id="VITAE.model.LatentSpace.get_pz"><code class="name flex"> |
|
|
824 |
<span>def <span class="ident">get_pz</span></span>(<span>self, z, eps, pi)</span> |
|
|
825 |
</code></dt> |
|
|
826 |
<dd> |
|
|
827 |
<div class="desc"><p>Get <span><span class="MathJax_Preview">\log p(Z_i|Y_i,X_i)</span><script type="math/tex">\log p(Z_i|Y_i,X_i)</script></span>.</p> |
|
|
828 |
<h2 id="parameters">Parameters</h2> |
|
|
829 |
<dl> |
|
|
830 |
<dt><strong><code>z</code></strong> : <code>tf.Tensor</code></dt> |
|
|
831 |
<dd><span><span class="MathJax_Preview">[B, L, d]</span><script type="math/tex">[B, L, d]</script></span> The latent variables.</dd> |
|
|
832 |
</dl> |
|
|
833 |
<h2 id="returns">Returns</h2> |
|
|
834 |
<dl> |
|
|
835 |
<dt><strong><code>temp_pi</code></strong> : <code>tf.Tensor</code></dt> |
|
|
836 |
<dd><span><span class="MathJax_Preview">[B, L, K]</span><script type="math/tex">[B, L, K]</script></span> <span><span class="MathJax_Preview">\pi</span><script type="math/tex">\pi</script></span>.</dd> |
|
|
837 |
<dt><strong><code>inv_sig</code></strong> : <code>tf.Tensor</code></dt> |
|
|
838 |
<dd><span><span class="MathJax_Preview">[B, L, K]</span><script type="math/tex">[B, L, K]</script></span> <span><span class="MathJax_Preview">\sigma_{Z_ic_i}^{-1}</span><script type="math/tex">\sigma_{Z_ic_i}^{-1}</script></span>.</dd> |
|
|
839 |
<dt><strong><code>nu</code></strong> : <code>tf.Tensor</code></dt> |
|
|
840 |
<dd><span><span class="MathJax_Preview">[B, L, K]</span><script type="math/tex">[B, L, K]</script></span> <span><span class="MathJax_Preview">\nu_{Z_ic_i}</span><script type="math/tex">\nu_{Z_ic_i}</script></span>.</dd> |
|
|
841 |
<dt><strong><code>beta_zc</code></strong> : <code>tf.Tensor</code></dt> |
|
|
842 |
<dd><span><span class="MathJax_Preview">[B, L, d, K]</span><script type="math/tex">[B, L, d, K]</script></span> <span><span class="MathJax_Preview">\beta_{Z_ic_i}</span><script type="math/tex">\beta_{Z_ic_i}</script></span>.</dd> |
|
|
843 |
<dt><strong><code>log_eta0</code></strong> : <code>tf.Tensor</code></dt> |
|
|
844 |
<dd><span><span class="MathJax_Preview">[B, L, K]</span><script type="math/tex">[B, L, K]</script></span> <span><span class="MathJax_Preview">\log\eta_{Z_ic_i,0}</span><script type="math/tex">\log\eta_{Z_ic_i,0}</script></span>.</dd> |
|
|
845 |
<dt><strong><code>eta1</code></strong> : <code>tf.Tensor</code></dt> |
|
|
846 |
<dd><span><span class="MathJax_Preview">[B, L, K]</span><script type="math/tex">[B, L, K]</script></span> <span><span class="MathJax_Preview">\eta_{Z_ic_i,1}</span><script type="math/tex">\eta_{Z_ic_i,1}</script></span>.</dd> |
|
|
847 |
<dt><strong><code>eta2</code></strong> : <code>tf.Tensor</code></dt> |
|
|
848 |
<dd><span><span class="MathJax_Preview">[B, L, K]</span><script type="math/tex">[B, L, K]</script></span> <span><span class="MathJax_Preview">\eta_{Z_ic_i,2}</span><script type="math/tex">\eta_{Z_ic_i,2}</script></span>.</dd> |
|
|
849 |
<dt><strong><code>log_p_zc_L</code></strong> : <code>tf.Tensor</code></dt> |
|
|
850 |
<dd><span><span class="MathJax_Preview">[B, L, K]</span><script type="math/tex">[B, L, K]</script></span> <span><span class="MathJax_Preview">\log p(Z_i,c_i|Y_i,X_i)</span><script type="math/tex">\log p(Z_i,c_i|Y_i,X_i)</script></span>.</dd> |
|
|
851 |
<dt><strong><code>log_p_z_L</code></strong> : <code>tf.Tensor</code></dt> |
|
|
852 |
<dd><span><span class="MathJax_Preview">[B, L]</span><script type="math/tex">[B, L]</script></span> <span><span class="MathJax_Preview">\log p(Z_i|Y_i,X_i)</span><script type="math/tex">\log p(Z_i|Y_i,X_i)</script></span>.</dd> |
|
|
853 |
<dt><strong><code>log_p_z</code></strong> : <code>tf.Tensor</code></dt> |
|
|
854 |
<dd><span><span class="MathJax_Preview">[B, 1]</span><script type="math/tex">[B, 1]</script></span> The estimated <span><span class="MathJax_Preview">\log p(Z_i|Y_i,X_i)</span><script type="math/tex">\log p(Z_i|Y_i,X_i)</script></span>.</dd> |
|
|
855 |
</dl></div> |
|
|
856 |
</dd> |
|
|
857 |
<dt id="VITAE.model.LatentSpace.get_posterior_c"><code class="name flex"> |
|
|
858 |
<span>def <span class="ident">get_posterior_c</span></span>(<span>self, z)</span> |
|
|
859 |
</code></dt> |
|
|
860 |
<dd> |
|
|
861 |
<div class="desc"><p>Get <span><span class="MathJax_Preview">p(c_i|Y_i,X_i)</span><script type="math/tex">p(c_i|Y_i,X_i)</script></span>.</p> |
|
|
862 |
<h2 id="parameters">Parameters</h2> |
|
|
863 |
<dl> |
|
|
864 |
<dt><strong><code>z</code></strong> : <code>tf.Tensor</code></dt> |
|
|
865 |
<dd><span><span class="MathJax_Preview">[B, L, d]</span><script type="math/tex">[B, L, d]</script></span> The latent variables.</dd> |
|
|
866 |
</dl> |
|
|
867 |
<h2 id="returns">Returns</h2> |
|
|
868 |
<dl> |
|
|
869 |
<dt><strong><code>p_c_x</code></strong> : <code>np.array</code></dt> |
|
|
870 |
<dd><span><span class="MathJax_Preview">[B, K]</span><script type="math/tex">[B, K]</script></span> <span><span class="MathJax_Preview">p(c_i|Y_i,X_i)</span><script type="math/tex">p(c_i|Y_i,X_i)</script></span>.</dd> |
|
|
871 |
</dl></div> |
|
|
872 |
</dd> |
|
|
873 |
<dt id="VITAE.model.LatentSpace.call"><code class="name flex"> |
|
|
874 |
<span>def <span class="ident">call</span></span>(<span>self, z, pi=None, inference=False)</span> |
|
|
875 |
</code></dt> |
|
|
876 |
<dd> |
|
|
877 |
<div class="desc"><p>Get posterior estimations.</p> |
|
|
878 |
<h2 id="parameters">Parameters</h2> |
|
|
879 |
<dl> |
|
|
880 |
<dt><strong><code>z</code></strong> : <code>tf.Tensor</code></dt>
<dd><span><span class="MathJax_Preview">[B, L, d]</span><script type="math/tex">[B, L, d]</script></span> The latent variables.</dd>
<dt><strong><code>pi</code></strong> : <code>tf.Tensor</code>, optional</dt>
<dd><span><span class="MathJax_Preview">[B, K]</span><script type="math/tex">[B, K]</script></span> Unnormalized logits of <span><span class="MathJax_Preview">\pi</span><script type="math/tex">\pi</script></span>; if <code>None</code>, the layer's own <span><span class="MathJax_Preview">\pi</span><script type="math/tex">\pi</script></span> is used.</dd>
<dt><strong><code>inference</code></strong> : <code>boolean</code></dt>
<dd>Whether in training or inference mode.</dd>
|
|
884 |
</dl> |
|
|
885 |
<p>When <code>inference=False</code>:</p> |
|
|
886 |
<h2 id="returns">Returns</h2> |
|
|
887 |
<dl> |
|
|
888 |
<dt><strong><code>log_p_z</code></strong> : <code>tf.Tensor</code></dt>
<dd>The estimated <span><span class="MathJax_Preview">\log p(Z_i|Y_i,X_i)</span><script type="math/tex">\log p(Z_i|Y_i,X_i)</script></span>, averaged over the batch.</dd>
|
|
890 |
</dl> |
|
|
891 |
<p>When <code>inference=True</code>:</p> |
|
|
892 |
<h2 id="returns_1">Returns</h2> |
|
|
893 |
<dl> |
|
|
894 |
<dt><strong><code>res</code></strong> : <code>dict</code></dt> |
|
|
895 |
<dd>The dict of posterior estimations: <code>p_c_x</code> <span><span class="MathJax_Preview">p(c_i|Y_i,X_i)</span><script type="math/tex">p(c_i|Y_i,X_i)</script></span>, <code>w_tilde</code> <span><span class="MathJax_Preview">E(\tilde{w}_i|Y_i,X_i)</span><script type="math/tex">E(\tilde{w}_i|Y_i,X_i)</script></span>, and <code>var_w_tilde</code> <span><span class="MathJax_Preview">Var(\tilde{w}_i|Y_i,X_i)</span><script type="math/tex">Var(\tilde{w}_i|Y_i,X_i)</script></span>.</dd>
|
|
896 |
</dl></div> |
|
|
897 |
</dd> |
|
|
898 |
</dl> |
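<p>A minimal sketch of the training-mode and inference-mode calls (hypothetical positions and shapes):</p>
<pre><code class="python">import numpy as np
import tensorflow as tf
from VITAE.model import LatentSpace

n_clusters, d = 3, 2
latent = LatentSpace(n_clusters, dim_latent=d)
mu = np.random.uniform(-1, 1, (d, n_clusters)).astype(np.float32)
latent.initialize(mu, log_pi=None)

z = tf.random.normal([8, 5, d])       # [B, L, d] latent samples
log_p_z = latent(z)                   # training mode: scalar log p(z)
res = latent(z, inference=True)       # posterior estimations
print(res['p_c_x'].shape)             # (8, 6); K = 3*(3+1)/2 categories
</code></pre>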
|
|
899 |
</dd> |
|
|
900 |
<dt id="VITAE.model.VariationalAutoEncoder"><code class="flex name class"> |
|
|
901 |
<span>class <span class="ident">VariationalAutoEncoder</span></span> |
|
|
902 |
<span>(</span><span>dim_origin, dimensions, dim_latent, data_type='UMI', has_cov=False, name='autoencoder', **kwargs)</span> |
|
|
903 |
</code></dt> |
|
|
904 |
<dd> |
|
|
905 |
<div class="desc"><p>Combines the encoder, decoder and LatentSpace into an end-to-end model for training and inference.</p> |
|
|
906 |
<h2 id="parameters">Parameters</h2> |
|
|
907 |
<dl> |
|
|
908 |
<dt><strong><code>dim_origin</code></strong> : <code>int</code></dt> |
|
|
909 |
<dd>The output dimension of the decoder.</dd> |
|
|
910 |
<dt><strong><code>dimensions</code></strong> : <code>np.array</code></dt> |
|
|
911 |
<dd>The dimensions of hidden layers of the encoder.</dd> |
|
|
912 |
<dt><strong><code>dim_latent</code></strong> : <code>int</code></dt> |
|
|
913 |
<dd>The latent dimension.</dd> |
|
|
914 |
<dt><strong><code>data_type</code></strong> : <code>str</code>, optional</dt> |
|
|
915 |
<dd><code>'UMI'</code>, <code>'non-UMI'</code>, or <code>'Gaussian'</code>.</dd> |
|
|
916 |
<dt><strong><code>has_cov</code></strong> : <code>boolean</code></dt>
<dd>Whether covariates are included.</dd>
|
|
920 |
<dt><strong><code>name</code></strong> : <code>str</code>, optional</dt> |
|
|
921 |
<dd>The name of the layer.</dd> |
|
|
922 |
<dt><strong><code>**kwargs</code></strong></dt> |
|
|
923 |
<dd>Extra keyword arguments.</dd> |
|
|
924 |
</dl></div> |
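<p>A rough end-to-end sketch of a pre-training forward pass on toy data (hypothetical shapes; assumes the full <code>VITAE.model</code> module, including its internal regularizers, is importable):</p>
<pre><code class="python">import numpy as np
from VITAE.model import VariationalAutoEncoder

G = 64                                           # number of genes
vae = VariationalAutoEncoder(dim_origin=G,
                             dimensions=np.array([32, 16]),
                             dim_latent=2)
x = np.random.poisson(1.0, (8, G)).astype(np.float32)   # toy UMI counts
losses = vae(np.log1p(x), c_score=None, x=x,
             scale_factor=np.ones((8, 1), dtype=np.float32),
             pre_train=True, L=1)                # pre-training skips the latent space
</code></pre>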
|
|
925 |
<details class="source"> |
|
|
926 |
<summary> |
|
|
927 |
<span>Expand source code</span> |
|
|
928 |
</summary> |
|
|
929 |
<pre><code class="python">class VariationalAutoEncoder(tf.keras.Model): |
|
|
930 |
""" |
|
|
931 |
Combines the encoder, decoder and LatentSpace into an end-to-end model for training and inference. |
|
|
932 |
""" |
|
|
933 |
def __init__(self, dim_origin, dimensions, dim_latent, |
|
|
934 |
data_type = 'UMI', has_cov=False, |
|
|
935 |
name = 'autoencoder', **kwargs): |
|
|
936 |
''' |
|
|
937 |
Parameters |
|
|
938 |
---------- |
|
|
939 |
dim_origin : int |
|
|
940 |
The output dimension of the decoder. |
|
|
941 |
dimensions : np.array |
|
|
942 |
The dimensions of hidden layers of the encoder. |
|
|
943 |
dim_latent : int |
|
|
944 |
The latent dimension. |
|
|
945 |
data_type : str, optional |
|
|
946 |
`'UMI'`, `'non-UMI'`, or `'Gaussian'`. |
|
|
947 |
        has_cov : boolean
            Whether covariates are included.
|
|
951 |
name : str, optional |
|
|
952 |
The name of the layer. |
|
|
953 |
**kwargs : |
|
|
954 |
Extra keyword arguments. |
|
|
955 |
''' |
|
|
956 |
super(VariationalAutoEncoder, self).__init__(name = name, **kwargs) |
|
|
957 |
self.data_type = data_type |
|
|
958 |
self.dim_origin = dim_origin |
|
|
959 |
self.dim_latent = dim_latent |
|
|
960 |
self.encoder = Encoder(dimensions, dim_latent) |
|
|
961 |
        self.decoder = Decoder(dimensions[::-1], dim_origin, data_type)
|
|
962 |
self.has_cov = has_cov |
|
|
963 |
|
|
|
964 |
def init_latent_space(self, n_clusters, mu, log_pi=None): |
|
|
965 |
        '''Initialize the latent space.
|
|
966 |
|
|
|
967 |
Parameters |
|
|
968 |
---------- |
|
|
969 |
n_clusters : int |
|
|
970 |
The number of vertices in the latent space. |
|
|
971 |
mu : np.array |
|
|
972 |
\([d, k]\) The position matrix. |
|
|
973 |
log_pi : np.array, optional |
|
|
974 |
\([1, K]\) \(\\log\\pi\). |
|
|
975 |
''' |
|
|
976 |
self.n_states = n_clusters |
|
|
977 |
self.latent_space = LatentSpace(self.n_states, self.dim_latent) |
|
|
978 |
self.latent_space.initialize(mu, log_pi) |
|
|
979 |
self.pilayer = None |
|
|
980 |
|
|
|
981 |
def create_pilayer(self): |
|
|
982 |
self.pilayer = Dense(self.latent_space.n_categories, name = 'pi_layer') |
|
|
983 |
|
|
|
984 |
def call(self, x_normalized, c_score, x = None, scale_factor = 1, |
|
|
985 |
pre_train = False, L=1, alpha=0.0, gamma = 0.0, phi = 1.0, conditions = None, pi_cov = None): |
|
|
986 |
        '''Feed forward through the encoder, LatentSpace layer, and decoder.
|
|
987 |
|
|
|
988 |
Parameters |
|
|
989 |
---------- |
|
|
990 |
x_normalized : np.array |
|
|
991 |
\([B, G]\) The preprocessed data. |
|
|
992 |
c_score : np.array |
|
|
993 |
\([B, s]\) The covariates \(X_i\), only used when `has_cov=True`. |
|
|
994 |
x : np.array, optional |
|
|
995 |
\([B, G]\) The original count data \(Y_i\), only used when data_type is not `'Gaussian'`. |
|
|
996 |
scale_factor : np.array, optional |
|
|
997 |
\([B, ]\) The scale factors, only used when data_type is not `'Gaussian'`. |
|
|
998 |
pre_train : boolean, optional |
|
|
999 |
            Whether in the pre-training phase or not.
|
|
1000 |
L : int, optional |
|
|
1001 |
The number of MC samples. |
|
|
1002 |
alpha : float, optional |
|
|
1003 |
The penalty parameter for covariates adjustment. |
|
|
1004 |
gamma : float, optional |
|
|
1005 |
            The weight of the MMD loss.
|
|
1006 |
phi : float, optional |
|
|
1007 |
            The weight of the Jacobian norm of the encoder.
|
|
1008 |
        conditions : str or list, optional
            The conditions of different cells from the selected batch.
        pi_cov : tf.Tensor, optional
            The covariates fed to `pilayer` to produce the logits of \(\\pi\).
|
|
1010 |
|
|
|
1011 |
Returns |
|
|
1012 |
---------- |
|
|
1013 |
        losses : list of tf.Tensor
            The accumulated losses.
|
|
1015 |
''' |
|
|
1016 |
|
|
|
1017 |
if not pre_train and self.latent_space is None: |
|
|
1018 |
raise ReferenceError('Have not initialized the latent space.') |
|
|
1019 |
|
|
|
1020 |
        if self.has_cov:
            x_normalized = tf.concat([x_normalized, c_score], -1)
|
|
1024 |
_, z_log_var, z = self.encoder(x_normalized, L) |
|
|
1025 |
|
|
|
1026 |
if gamma == 0: |
|
|
1027 |
mmd_loss = 0.0 |
|
|
1028 |
else: |
|
|
1029 |
mmd_loss = self._get_total_mmd_loss(conditions,z,gamma) |
|
|
1030 |
|
|
|
1031 |
z_in = tf.concat([z, tf.tile(tf.expand_dims(c_score,1), (1,L,1))], -1) if self.has_cov else z |
|
|
1032 |
|
|
|
1033 |
x = tf.tile(tf.expand_dims(x, 1), (1,L,1)) |
|
|
1034 |
reconstruction_z_loss = self._get_reconstruction_loss(x, z_in, scale_factor, L) |
|
|
1035 |
|
|
|
1036 |
if self.has_cov and alpha>0.0: |
|
|
1037 |
zero_in = tf.concat([tf.zeros([z.shape[0],1,z.shape[2]], dtype=tf.keras.backend.floatx()), |
|
|
1038 |
tf.tile(tf.expand_dims(c_score,1), (1,1,1))], -1) |
|
|
1039 |
reconstruction_zero_loss = self._get_reconstruction_loss(x, zero_in, scale_factor, 1) |
|
|
1040 |
reconstruction_z_loss = (1-alpha)*reconstruction_z_loss + alpha*reconstruction_zero_loss |
|
|
1041 |
|
|
|
1042 |
self.add_loss(reconstruction_z_loss) |
|
|
1043 |
J_norm = self._get_Jacob(x_normalized, L) |
|
|
1044 |
self.add_loss((phi * J_norm)) |
|
|
1045 |
# gamma weight has been used when call _mmd_loss function. |
|
|
1046 |
self.add_loss(mmd_loss) |
|
|
1047 |
|
|
|
1048 |
if not pre_train: |
|
|
1049 |
pi = self.pilayer(pi_cov) if self.pilayer is not None else None |
|
|
1050 |
log_p_z = self.latent_space(z, pi, inference=False) |
|
|
1051 |
|
|
|
1052 |
# - E_q[log p(z)] |
|
|
1053 |
self.add_loss(- log_p_z) |
|
|
1054 |
|
|
|
1055 |
# - Eq[log q(z|x)] |
|
|
1056 |
E_qzx = - tf.reduce_mean( |
|
|
1057 |
0.5 * self.dim_latent * |
|
|
1058 |
(tf.math.log(tf.constant(2 * np.pi, tf.keras.backend.floatx())) + 1.0) + |
|
|
1059 |
0.5 * tf.reduce_sum(z_log_var, axis=-1) |
|
|
1060 |
) |
|
|
1061 |
self.add_loss(E_qzx) |
|
|
1062 |
return self.losses |
|
|
1063 |
|
|
|
    @tf.function
    def _get_reconstruction_loss(self, x, z_in, scale_factor, L):
        if self.data_type == 'Gaussian':
            # Gaussian log-likelihood loss
            nu_z, tau = self.decoder(z_in)
            neg_E_Gaus = 0.5 * tf.math.log(tf.clip_by_value(tau, 1e-12, 1e30)) + 0.5 * tf.math.square(x - nu_z) / tau
            neg_E_Gaus = tf.reduce_mean(tf.reduce_sum(neg_E_Gaus, axis=-1))

            return neg_E_Gaus
        else:
            if self.data_type == 'UMI':
                x_hat, r = self.decoder(z_in)
            else:
                x_hat, r, phi = self.decoder(z_in)

            x_hat = x_hat * tf.expand_dims(scale_factor, -1)

            # Negative log-likelihood loss.
            # Ref for NB & ZINB loss functions:
            # https://github.com/gokceneraslan/neuralnet_countmodels/blob/master/Count%20models%20with%20neuralnets.ipynb

            # Negative Binomial loss
            neg_E_nb = tf.math.lgamma(r) + tf.math.lgamma(x+1.0) \
                       - tf.math.lgamma(x+r) + \
                       (r+x) * tf.math.log(1.0 + (x_hat/r)) + \
                       x * (tf.math.log(r) - tf.math.log(tf.clip_by_value(x_hat, 1e-12, 1e30)))

            if self.data_type == 'non-UMI':
                # Zero-Inflated Negative Binomial loss
                nb_case = neg_E_nb - tf.math.log(tf.clip_by_value(1.0-phi, 1e-12, 1e30))
                zero_case = - tf.math.log(tf.clip_by_value(
                    phi + (1.0-phi) * tf.pow(r * tf.math.reciprocal_no_nan(r + x_hat), r),
                    1e-12, 1e30))
                neg_E_nb = tf.where(tf.less(x, 1e-8), zero_case, nb_case)

            neg_E_nb = tf.reduce_mean(tf.reduce_sum(neg_E_nb, axis=-1))
            return neg_E_nb

    def _get_total_mmd_loss(self, conditions, z, gamma):
        mmd_loss = 0.0
        conditions = tf.cast(conditions, tf.int32)
        n_group = conditions.shape[1]

        for i in range(n_group):
            sub_conditions = conditions[:, i]
            # A label of 0 means the cell does not participate in the MMD loss.
            z_cond = z[sub_conditions != 0]
            sub_conditions = sub_conditions[sub_conditions != 0]
            n_sub_group = tf.unique(sub_conditions)[0].shape[0]
            real_labels = K.reshape(sub_conditions, (-1,)).numpy()
            unique_set = list(set(real_labels))
            reindex_dict = dict(zip(unique_set, range(n_sub_group)))
            real_labels = [reindex_dict[x] for x in real_labels]
            real_labels = tf.convert_to_tensor(real_labels, dtype=tf.int32)

            if n_sub_group <= 1:
                _loss = 0.0
            else:
                _loss = self._mmd_loss(real_labels=real_labels, y_pred=z_cond, gamma=gamma,
                                       n_conditions=n_sub_group,
                                       kernel_method='multi-scale-rbf',
                                       computation_method="general")
            mmd_loss = mmd_loss + _loss
        return mmd_loss

    # The input shape changes on every call, so @tf.function cannot be used:
    # a tf graph requires static shapes and tensor dtypes.
    def _mmd_loss(self, real_labels, y_pred, gamma, n_conditions, kernel_method='multi-scale-rbf',
                  computation_method="general"):
        conditions_mmd = tf.dynamic_partition(y_pred, real_labels, num_partitions=n_conditions)
        loss = 0.0
        if computation_method.isdigit():
            boundary = int(computation_method)
            # Every pair of groups across the boundary contributes a distance.
            for i in range(boundary):
                for j in range(boundary, n_conditions):
                    loss += _nan2zero(compute_mmd(conditions_mmd[i], conditions_mmd[j], kernel_method))
        else:
            # Every pair of groups contributes a distance.
            for i in range(len(conditions_mmd)):
                for j in range(i):
                    loss += _nan2zero(compute_mmd(conditions_mmd[i], conditions_mmd[j], kernel_method))

        return gamma * loss

    @tf.function
    def _get_Jacob(self, x, L):
        with tf.GradientTape() as g:
            g.watch(x)
            z_mean, z_log_var, z = self.encoder(x, L)
        # A plain jacobian would produce a (batch, d, batch, G)-shaped tensor;
        # batch_jacobian keeps only the per-sample Jacobians.
        J = g.batch_jacobian(z, x)
        J_norm = tf.norm(J)

        return J_norm

    def get_z(self, x_normalized, c_score):
        '''Get \(q(Z_i|Y_i,X_i)\).

        Parameters
        ----------
        x_normalized : np.array
            \([B, G]\) The preprocessed data.
        c_score : np.array
            \([B, s]\) The covariates \(X_i\), only used when `has_cov=True`.

        Returns
        ----------
        z_mean : np.array
            \([B, d]\) The latent mean.
        '''
        x_normalized = x_normalized if (not self.has_cov or c_score is None) else tf.concat([x_normalized, c_score], -1)
        z_mean, _, _ = self.encoder(x_normalized, 1, False)
        return z_mean.numpy()

    def get_pc_x(self, test_dataset):
        '''Get \(p(c_i|Y_i,X_i)\).

        Parameters
        ----------
        test_dataset : tf.Dataset
            The dataset object.

        Returns
        ----------
        pi_norm : np.array
            \([1, K]\) The estimated \(\\pi\).
        p_c_x : np.array
            \([N, ]\) The estimated \(p(c_i|Y_i,X_i)\).
        '''
        if self.latent_space is None:
            raise ReferenceError('Have not initialized the latent space.')

        pi_norm = tf.nn.softmax(self.latent_space.pi).numpy()
        p_c_x = []
        for x, c_score in test_dataset:
            x = tf.concat([x, c_score], -1) if self.has_cov else x
            _, _, z = self.encoder(x, 1, False)
            _p_c_x = self.latent_space.get_posterior_c(z)
            p_c_x.append(_p_c_x)
        p_c_x = np.concatenate(p_c_x)
        return pi_norm, p_c_x

    def inference(self, test_dataset, L=1):
        '''Compute posterior estimates over the dataset.

        Parameters
        ----------
        test_dataset : tf.Dataset
            The dataset object.
        L : int
            The number of MC samples.

        Returns
        ----------
        pi_norm : np.array
            \([1, K]\) The estimated \(\\pi\).
        mu : np.array
            \([d, k]\) The estimated \(\\mu\).
        p_c_x : np.array
            \([N, ]\) The estimated \(p(c_i|Y_i,X_i)\).
        w_tilde : np.array
            \([N, k]\) The estimated \(E(\\tilde{w}_i|Y_i,X_i)\).
        var_w_tilde : np.array
            \([N, k]\) The estimated \(Var(\\tilde{w}_i|Y_i,X_i)\).
        z_mean : np.array
            \([N, d]\) The estimated latent mean.
        '''
        if self.latent_space is None:
            raise ReferenceError('Have not initialized the latent space.')

        print('Computing posterior estimations over mini-batches.')
        progbar = Progbar(test_dataset.cardinality().numpy())
        pi_norm = tf.nn.softmax(self.latent_space.pi).numpy()
        mu = self.latent_space.mu.numpy()
        z_mean = []
        p_c_x = []
        w_tilde = []
        var_w_tilde = []
        for step, (x, c_score, _, _) in enumerate(test_dataset):
            x = tf.concat([x, c_score], -1) if self.has_cov else x
            _z_mean, _, z = self.encoder(x, L, False)
            res = self.latent_space(z, inference=True)

            z_mean.append(_z_mean.numpy())
            p_c_x.append(res['p_c_x'])
            w_tilde.append(res['w_tilde'])
            var_w_tilde.append(res['var_w_tilde'])
            progbar.update(step+1)

        z_mean = np.concatenate(z_mean)
        p_c_x = np.concatenate(p_c_x)
        w_tilde = np.concatenate(w_tilde)
        w_tilde /= np.sum(w_tilde, axis=1, keepdims=True)
        var_w_tilde = np.concatenate(var_w_tilde)
        return pi_norm, mu, p_c_x, w_tilde, var_w_tilde, z_mean</code></pre>
</details>
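<p>A minimal construction sketch (not part of the package itself): the arguments follow the documented signature, but the gene count and layer sizes below are illustrative assumptions only.</p>
<pre><code>import numpy as np
from VITAE.model import VariationalAutoEncoder

# Hypothetical configuration: a 2000-gene UMI count matrix, two hidden
# layers of 128 and 64 units, and a 16-dimensional latent space.
model = VariationalAutoEncoder(dim_origin=2000, dimensions=np.array([128, 64]),
                               dim_latent=16, data_type='UMI', has_cov=False)
</code></pre>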
<h3>Ancestors</h3>
<ul class="hlist">
<li>keras.src.engine.training.Model</li>
<li>keras.src.engine.base_layer.Layer</li>
<li>tensorflow.python.module.module.Module</li>
<li>tensorflow.python.trackable.autotrackable.AutoTrackable</li>
<li>tensorflow.python.trackable.base.Trackable</li>
<li>keras.src.utils.version_utils.LayerVersionSelector</li>
<li>keras.src.utils.version_utils.ModelVersionSelector</li>
</ul>
<h3>Methods</h3>
<dl>
<dt id="VITAE.model.VariationalAutoEncoder.init_latent_space"><code class="name flex">
<span>def <span class="ident">init_latent_space</span></span>(<span>self, n_clusters, mu, log_pi=None)</span>
</code></dt>
<dd>
<div class="desc"><p>Initialize the latent space.</p>
<h2 id="parameters">Parameters</h2>
<dl>
<dt><strong><code>n_clusters</code></strong> : <code>int</code></dt>
<dd>The number of vertices in the latent space.</dd>
<dt><strong><code>mu</code></strong> : <code>np.array</code></dt>
<dd><span><span class="MathJax_Preview">[d, k]</span><script type="math/tex">[d, k]</script></span> The position matrix.</dd>
<dt><strong><code>log_pi</code></strong> : <code>np.array</code>, optional</dt>
<dd><span><span class="MathJax_Preview">[1, K]</span><script type="math/tex">[1, K]</script></span> <span><span class="MathJax_Preview">\log\pi</span><script type="math/tex">\log\pi</script></span>.</dd>
</dl></div>
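<p>A minimal usage sketch, assuming a model built as above: the cluster count and the entries of <code>mu</code> are made-up values, and the number of rows of <code>mu</code> must equal <code>dim_latent</code>.</p>
<pre><code>import numpy as np

mu = np.random.randn(16, 3)                   # [d, k]: one column per vertex, d = dim_latent
model.init_latent_space(n_clusters=3, mu=mu)  # log_pi is optional and defaults to None
</code></pre>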
</dd>
<dt id="VITAE.model.VariationalAutoEncoder.create_pilayer"><code class="name flex">
<span>def <span class="ident">create_pilayer</span></span>(<span>self)</span>
</code></dt>
<dd>
<div class="desc"><p>Create the <code>Dense</code> layer that maps <code>pi_cov</code> to the logits of the mixture weights <span><span class="MathJax_Preview">\pi</span><script type="math/tex">\pi</script></span>.</p></div>
</dd>
<dt id="VITAE.model.VariationalAutoEncoder.call"><code class="name flex">
<span>def <span class="ident">call</span></span>(<span>self, x_normalized, c_score, x=None, scale_factor=1, pre_train=False, L=1, alpha=0.0, gamma=0.0, phi=1.0, conditions=None, pi_cov=None)</span>
</code></dt>
<dd>
<div class="desc"><p>Feed forward through the encoder, LatentSpace layer, and decoder.</p>
<h2 id="parameters">Parameters</h2>
<dl>
<dt><strong><code>x_normalized</code></strong> : <code>np.array</code></dt>
<dd><span><span class="MathJax_Preview">[B, G]</span><script type="math/tex">[B, G]</script></span> The preprocessed data.</dd>
<dt><strong><code>c_score</code></strong> : <code>np.array</code></dt>
<dd><span><span class="MathJax_Preview">[B, s]</span><script type="math/tex">[B, s]</script></span> The covariates <span><span class="MathJax_Preview">X_i</span><script type="math/tex">X_i</script></span>, only used when <code>has_cov=True</code>.</dd>
<dt><strong><code>x</code></strong> : <code>np.array</code>, optional</dt>
<dd><span><span class="MathJax_Preview">[B, G]</span><script type="math/tex">[B, G]</script></span> The original count data <span><span class="MathJax_Preview">Y_i</span><script type="math/tex">Y_i</script></span>, only used when data_type is not <code>'Gaussian'</code>.</dd>
<dt><strong><code>scale_factor</code></strong> : <code>np.array</code>, optional</dt>
<dd><span><span class="MathJax_Preview">[B, ]</span><script type="math/tex">[B, ]</script></span> The scale factors, only used when data_type is not <code>'Gaussian'</code>.</dd>
<dt><strong><code>pre_train</code></strong> : <code>boolean</code>, optional</dt>
<dd>Whether in the pre-training phase or not.</dd>
<dt><strong><code>L</code></strong> : <code>int</code>, optional</dt>
<dd>The number of MC samples.</dd>
<dt><strong><code>alpha</code></strong> : <code>float</code>, optional</dt>
<dd>The penalty parameter for covariates adjustment.</dd>
<dt><strong><code>gamma</code></strong> : <code>float</code>, optional</dt>
<dd>The weight of the MMD loss.</dd>
<dt><strong><code>phi</code></strong> : <code>float</code>, optional</dt>
<dd>The weight of the Jacobian norm of the encoder.</dd>
<dt><strong><code>conditions</code></strong> : <code>str</code> or <code>list</code>, optional</dt>
<dd>The conditions of different cells from the selected batch.</dd>
</dl>
<h2 id="returns">Returns</h2>
<dl>
<dt><strong><code>losses</code></strong> : <code>float</code></dt>
<dd>The loss.</dd>
</dl></div>
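<p>A training-step sketch under stated assumptions: the tensors are random stand-ins for a real mini-batch, the model was built with <code>has_cov=False</code>, and the optimizer choice is illustrative, not the package's own training loop.</p>
<pre><code>import tensorflow as tf

optimizer = tf.keras.optimizers.Adam(1e-3)
x_norm = tf.random.normal([32, 2000])             # [B, G] preprocessed batch
x_count = tf.random.uniform([32, 2000], 0., 10.)  # [B, G] counts for the UMI likelihood
scale = tf.ones([32])                             # [B, ] scale factors

with tf.GradientTape() as tape:
    losses = model(x_norm, None, x=x_count, scale_factor=scale, pre_train=True, L=1)
    loss = tf.add_n(losses)                       # `call` returns the list of added losses
grads = tape.gradient(loss, model.trainable_weights)
optimizer.apply_gradients(zip(grads, model.trainable_weights))
</code></pre>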
</dd>
<dt id="VITAE.model.VariationalAutoEncoder.get_z"><code class="name flex">
<span>def <span class="ident">get_z</span></span>(<span>self, x_normalized, c_score)</span>
</code></dt>
<dd>
<div class="desc"><p>Get <span><span class="MathJax_Preview">q(Z_i|Y_i,X_i)</span><script type="math/tex">q(Z_i|Y_i,X_i)</script></span>.</p>
<h2 id="parameters">Parameters</h2>
<dl>
<dt><strong><code>x_normalized</code></strong> : <code>np.array</code></dt>
<dd><span><span class="MathJax_Preview">[B, G]</span><script type="math/tex">[B, G]</script></span> The preprocessed data.</dd>
<dt><strong><code>c_score</code></strong> : <code>np.array</code></dt>
<dd><span><span class="MathJax_Preview">[B, s]</span><script type="math/tex">[B, s]</script></span> The covariates <span><span class="MathJax_Preview">X_i</span><script type="math/tex">X_i</script></span>, only used when <code>has_cov=True</code>.</dd>
</dl>
<h2 id="returns">Returns</h2>
<dl>
<dt><strong><code>z_mean</code></strong> : <code>np.array</code></dt>
<dd><span><span class="MathJax_Preview">[B, d]</span><script type="math/tex">[B, d]</script></span> The latent mean.</dd>
</dl></div>
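<p>A short sketch, assuming a model built with <code>has_cov=False</code>, <code>dim_origin=2000</code>, and <code>dim_latent=16</code>; the input is a random placeholder.</p>
<pre><code>import tensorflow as tf

x_norm = tf.random.normal([32, 2000])  # [B, G]
z_mean = model.get_z(x_norm, None)     # c_score may be None when has_cov=False
print(z_mean.shape)                    # (32, 16), i.e. [B, d]
</code></pre>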
</dd>
<dt id="VITAE.model.VariationalAutoEncoder.get_pc_x"><code class="name flex">
<span>def <span class="ident">get_pc_x</span></span>(<span>self, test_dataset)</span>
</code></dt>
<dd>
<div class="desc"><p>Get <span><span class="MathJax_Preview">p(c_i|Y_i,X_i)</span><script type="math/tex">p(c_i|Y_i,X_i)</script></span>.</p>
<h2 id="parameters">Parameters</h2>
<dl>
<dt><strong><code>test_dataset</code></strong> : <code>tf.Dataset</code></dt>
<dd>The dataset object.</dd>
</dl>
<h2 id="returns">Returns</h2>
<dl>
<dt><strong><code>pi_norm</code></strong> : <code>np.array</code></dt>
<dd><span><span class="MathJax_Preview">[1, K]</span><script type="math/tex">[1, K]</script></span> The estimated <span><span class="MathJax_Preview">\pi</span><script type="math/tex">\pi</script></span>.</dd>
<dt><strong><code>p_c_x</code></strong> : <code>np.array</code></dt>
<dd><span><span class="MathJax_Preview">[N, ]</span><script type="math/tex">[N, ]</script></span> The estimated <span><span class="MathJax_Preview">p(c_i|Y_i,X_i)</span><script type="math/tex">p(c_i|Y_i,X_i)</script></span>.</dd>
</dl></div>
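<p>A hedged sketch of the expected dataset layout: the loop in the source unpacks <code>(x, c_score)</code> pairs, so the example builds a batched <code>tf.data.Dataset</code> of that shape from placeholder tensors.</p>
<pre><code>import tensorflow as tf

ds = tf.data.Dataset.from_tensor_slices(
    (tf.random.normal([100, 2000]), tf.zeros([100, 1]))  # (x, c_score) pairs
).batch(32)
pi_norm, p_c_x = model.get_pc_x(ds)  # requires init_latent_space to have been called
</code></pre>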
</dd>
<dt id="VITAE.model.VariationalAutoEncoder.inference"><code class="name flex">
<span>def <span class="ident">inference</span></span>(<span>self, test_dataset, L=1)</span>
</code></dt>
<dd>
<div class="desc"><p>Compute posterior estimates over the dataset.</p>
<h2 id="parameters">Parameters</h2>
<dl>
<dt><strong><code>test_dataset</code></strong> : <code>tf.Dataset</code></dt>
<dd>The dataset object.</dd>
<dt><strong><code>L</code></strong> : <code>int</code></dt>
<dd>The number of MC samples.</dd>
</dl>
<h2 id="returns">Returns</h2>
<dl>
<dt><strong><code>pi_norm</code></strong> : <code>np.array</code></dt>
<dd><span><span class="MathJax_Preview">[1, K]</span><script type="math/tex">[1, K]</script></span> The estimated <span><span class="MathJax_Preview">\pi</span><script type="math/tex">\pi</script></span>.</dd>
<dt><strong><code>mu</code></strong> : <code>np.array</code></dt>
<dd><span><span class="MathJax_Preview">[d, k]</span><script type="math/tex">[d, k]</script></span> The estimated <span><span class="MathJax_Preview">\mu</span><script type="math/tex">\mu</script></span>.</dd>
<dt><strong><code>p_c_x</code></strong> : <code>np.array</code></dt>
<dd><span><span class="MathJax_Preview">[N, ]</span><script type="math/tex">[N, ]</script></span> The estimated <span><span class="MathJax_Preview">p(c_i|Y_i,X_i)</span><script type="math/tex">p(c_i|Y_i,X_i)</script></span>.</dd>
<dt><strong><code>w_tilde</code></strong> : <code>np.array</code></dt>
<dd><span><span class="MathJax_Preview">[N, k]</span><script type="math/tex">[N, k]</script></span> The estimated <span><span class="MathJax_Preview">E(\tilde{w}_i|Y_i,X_i)</span><script type="math/tex">E(\tilde{w}_i|Y_i,X_i)</script></span>.</dd>
<dt><strong><code>var_w_tilde</code></strong> : <code>np.array</code></dt>
<dd><span><span class="MathJax_Preview">[N, k]</span><script type="math/tex">[N, k]</script></span> The estimated <span><span class="MathJax_Preview">Var(\tilde{w}_i|Y_i,X_i)</span><script type="math/tex">Var(\tilde{w}_i|Y_i,X_i)</script></span>.</dd>
<dt><strong><code>z_mean</code></strong> : <code>np.array</code></dt>
<dd><span><span class="MathJax_Preview">[N, d]</span><script type="math/tex">[N, d]</script></span> The estimated latent mean.</dd>
</dl></div>
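<p>A sketch under the same placeholder assumptions; the loop in the source unpacks 4-tuples <code>(x, c_score, _, _)</code> per batch, so two dummy slots are included.</p>
<pre><code>import tensorflow as tf

n = 100
ds = tf.data.Dataset.from_tensor_slices(
    (tf.random.normal([n, 2000]), tf.zeros([n, 1]),
     tf.zeros([n]), tf.zeros([n]))  # (x, c_score, _, _) as unpacked per batch
).batch(32)
pi_norm, mu, p_c_x, w_tilde, var_w_tilde, z_mean = model.inference(ds, L=10)
</code></pre>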
</dd>
</dl>
</dd>
</dl>
</section>
</article>
<nav id="sidebar">
<div class="toc">
<ul></ul>
</div>
<ul id="index">
<li><h3>Super-module</h3>
<ul>
<li><code><a title="VITAE" href="index.html">VITAE</a></code></li>
</ul>
</li>
<li><h3><a href="#header-classes">Classes</a></h3>
<ul>
<li>
<h4><code><a title="VITAE.model.cdf_layer" href="#VITAE.model.cdf_layer">cdf_layer</a></code></h4>
<ul class="">
<li><code><a title="VITAE.model.cdf_layer.call" href="#VITAE.model.cdf_layer.call">call</a></code></li>
<li><code><a title="VITAE.model.cdf_layer.func" href="#VITAE.model.cdf_layer.func">func</a></code></li>
</ul>
</li>
<li>
<h4><code><a title="VITAE.model.Sampling" href="#VITAE.model.Sampling">Sampling</a></code></h4>
<ul class="">
<li><code><a title="VITAE.model.Sampling.call" href="#VITAE.model.Sampling.call">call</a></code></li>
</ul>
</li>
<li>
<h4><code><a title="VITAE.model.Encoder" href="#VITAE.model.Encoder">Encoder</a></code></h4>
<ul class="">
<li><code><a title="VITAE.model.Encoder.call" href="#VITAE.model.Encoder.call">call</a></code></li>
</ul>
</li>
<li>
<h4><code><a title="VITAE.model.Decoder" href="#VITAE.model.Decoder">Decoder</a></code></h4>
<ul class="">
<li><code><a title="VITAE.model.Decoder.call" href="#VITAE.model.Decoder.call">call</a></code></li>
</ul>
</li>
<li>
<h4><code><a title="VITAE.model.LatentSpace" href="#VITAE.model.LatentSpace">LatentSpace</a></code></h4>
<ul class="">
<li><code><a title="VITAE.model.LatentSpace.initialize" href="#VITAE.model.LatentSpace.initialize">initialize</a></code></li>
<li><code><a title="VITAE.model.LatentSpace.normalize" href="#VITAE.model.LatentSpace.normalize">normalize</a></code></li>
<li><code><a title="VITAE.model.LatentSpace.get_pz" href="#VITAE.model.LatentSpace.get_pz">get_pz</a></code></li>
<li><code><a title="VITAE.model.LatentSpace.get_posterior_c" href="#VITAE.model.LatentSpace.get_posterior_c">get_posterior_c</a></code></li>
<li><code><a title="VITAE.model.LatentSpace.call" href="#VITAE.model.LatentSpace.call">call</a></code></li>
</ul>
</li>
<li>
<h4><code><a title="VITAE.model.VariationalAutoEncoder" href="#VITAE.model.VariationalAutoEncoder">VariationalAutoEncoder</a></code></h4>
<ul class="two-column">
<li><code><a title="VITAE.model.VariationalAutoEncoder.init_latent_space" href="#VITAE.model.VariationalAutoEncoder.init_latent_space">init_latent_space</a></code></li>
<li><code><a title="VITAE.model.VariationalAutoEncoder.create_pilayer" href="#VITAE.model.VariationalAutoEncoder.create_pilayer">create_pilayer</a></code></li>
<li><code><a title="VITAE.model.VariationalAutoEncoder.call" href="#VITAE.model.VariationalAutoEncoder.call">call</a></code></li>
<li><code><a title="VITAE.model.VariationalAutoEncoder.get_z" href="#VITAE.model.VariationalAutoEncoder.get_z">get_z</a></code></li>
<li><code><a title="VITAE.model.VariationalAutoEncoder.get_pc_x" href="#VITAE.model.VariationalAutoEncoder.get_pc_x">get_pc_x</a></code></li>
<li><code><a title="VITAE.model.VariationalAutoEncoder.inference" href="#VITAE.model.VariationalAutoEncoder.inference">inference</a></code></li>
</ul>
</li>
</ul>
</li>
</ul>
</nav>
</main>
<footer id="footer">
<p>Generated by <a href="https://pdoc3.github.io/pdoc" title="pdoc: Python API documentation generator"><cite>pdoc</cite> 0.11.1</a>.</p>
</footer>
</body>
</html>