Diff of /docs/index.html [000000] .. [2c6b19]

Switch to unified view

a b/docs/index.html
1
<!doctype html>
2
<html lang="en">
3
<head>
4
<meta charset="utf-8">
5
<meta name="viewport" content="width=device-width, initial-scale=1, minimum-scale=1">
6
<meta name="generator" content="pdoc3 0.11.1">
7
<title>VITAE API documentation</title>
8
<meta name="description" content="">
9
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/10up-sanitize.css/13.0.0/sanitize.min.css" integrity="sha512-y1dtMcuvtTMJc1yPgEqF0ZjQbhnc/bFhyvIyVNb9Zk5mIGtqVaAB1Ttl28su8AvFMOY0EwRbAe+HCLqj6W7/KA==" crossorigin>
10
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/10up-sanitize.css/13.0.0/typography.min.css" integrity="sha512-Y1DYSb995BAfxobCkKepB1BqJJTPrOp3zPL74AWFugHHmmdcvO+C48WLrUOlhGMc0QG7AE3f7gmvvcrmX2fDoA==" crossorigin>
11
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.9.0/styles/default.min.css" crossorigin>
12
<style>:root{--highlight-color:#fe9}.flex{display:flex !important}body{line-height:1.5em}#content{padding:20px}#sidebar{padding:1.5em;overflow:hidden}#sidebar > *:last-child{margin-bottom:2cm}.http-server-breadcrumbs{font-size:130%;margin:0 0 15px 0}#footer{font-size:.75em;padding:5px 30px;border-top:1px solid #ddd;text-align:right}#footer p{margin:0 0 0 1em;display:inline-block}#footer p:last-child{margin-right:30px}h1,h2,h3,h4,h5{font-weight:300}h1{font-size:2.5em;line-height:1.1em}h2{font-size:1.75em;margin:2em 0 .50em 0}h3{font-size:1.4em;margin:1.6em 0 .7em 0}h4{margin:0;font-size:105%}h1:target,h2:target,h3:target,h4:target,h5:target,h6:target{background:var(--highlight-color);padding:.2em 0}a{color:#058;text-decoration:none;transition:color .2s ease-in-out}a:visited{color:#503}a:hover{color:#b62}.title code{font-weight:bold}h2[id^="header-"]{margin-top:2em}.ident{color:#900;font-weight:bold}pre code{font-size:.8em;line-height:1.4em;padding:1em;display:block}code{background:#f3f3f3;font-family:"DejaVu Sans Mono",monospace;padding:1px 4px;overflow-wrap:break-word}h1 code{background:transparent}pre{border-top:1px solid #ccc;border-bottom:1px solid #ccc;margin:1em 0}#http-server-module-list{display:flex;flex-flow:column}#http-server-module-list div{display:flex}#http-server-module-list dt{min-width:10%}#http-server-module-list p{margin-top:0}.toc ul,#index{list-style-type:none;margin:0;padding:0}#index code{background:transparent}#index h3{border-bottom:1px solid #ddd}#index ul{padding:0}#index h4{margin-top:.6em;font-weight:bold}@media (min-width:200ex){#index .two-column{column-count:2}}@media (min-width:300ex){#index .two-column{column-count:3}}dl{margin-bottom:2em}dl dl:last-child{margin-bottom:4em}dd{margin:0 0 1em 3em}#header-classes + dl > dd{margin-bottom:3em}dd dd{margin-left:2em}dd p{margin:10px 0}.name{background:#eee;font-size:.85em;padding:5px 10px;display:inline-block;min-width:40%}.name:hover{background:#e0e0e0}dt:target 
.name{background:var(--highlight-color)}.name > span:first-child{white-space:nowrap}.name.class > span:nth-child(2){margin-left:.4em}.inherited{color:#999;border-left:5px solid #eee;padding-left:1em}.inheritance em{font-style:normal;font-weight:bold}.desc h2{font-weight:400;font-size:1.25em}.desc h3{font-size:1em}.desc dt code{background:inherit}.source summary,.git-link-div{color:#666;text-align:right;font-weight:400;font-size:.8em;text-transform:uppercase}.source summary > *{white-space:nowrap;cursor:pointer}.git-link{color:inherit;margin-left:1em}.source pre{max-height:500px;overflow:auto;margin:0}.source pre code{font-size:12px;overflow:visible}.hlist{list-style:none}.hlist li{display:inline}.hlist li:after{content:',\2002'}.hlist li:last-child:after{content:none}.hlist .hlist{display:inline;padding-left:1em}img{max-width:100%}td{padding:0 .5em}.admonition{padding:.1em 1em;margin-bottom:1em}.admonition-title{font-weight:bold}.admonition.note,.admonition.info,.admonition.important{background:#aef}.admonition.todo,.admonition.versionadded,.admonition.tip,.admonition.hint{background:#dfd}.admonition.warning,.admonition.versionchanged,.admonition.deprecated{background:#fd4}.admonition.error,.admonition.danger,.admonition.caution{background:lightpink}</style>
13
<style media="screen and (min-width: 700px)">@media screen and (min-width:700px){#sidebar{width:30%;height:100vh;overflow:auto;position:sticky;top:0}#content{width:70%;max-width:100ch;padding:3em 4em;border-left:1px solid #ddd}pre code{font-size:1em}.name{font-size:1em}main{display:flex;flex-direction:row-reverse;justify-content:flex-end}.toc ul ul,#index ul ul{padding-left:1em}.toc > ul > li{margin-top:.5em}}</style>
14
<style media="print">@media print{#sidebar h1{page-break-before:always}.source{display:none}}@media print{*{background:transparent !important;color:#000 !important;box-shadow:none !important;text-shadow:none !important}a[href]:after{content:" (" attr(href) ")";font-size:90%}a[href][title]:after{content:none}abbr[title]:after{content:" (" attr(title) ")"}.ir a:after,a[href^="javascript:"]:after,a[href^="#"]:after{content:""}pre,blockquote{border:1px solid #999;page-break-inside:avoid}thead{display:table-header-group}tr,img{page-break-inside:avoid}img{max-width:100% !important}@page{margin:0.5cm}p,h2,h3{orphans:3;widows:3}h1,h2,h3,h4,h5,h6{page-break-after:avoid}}</style>
15
<script type="text/x-mathjax-config">MathJax.Hub.Config({ tex2jax: { inlineMath: [ ['$','$'], ["\\(","\\)"] ], processEscapes: true } });</script>
16
<script async src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.7/latest.js?config=TeX-AMS_CHTML" integrity="sha256-kZafAc6mZvK3W3v1pHOcUix30OHQN6pU/NO2oFkqZVw=" crossorigin></script>
17
<script defer src="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.9.0/highlight.min.js" integrity="sha512-D9gUyxqja7hBtkWpPWGt9wfbfaMGVt9gnyCvYa+jojwwPHLCzUm5i8rpk7vD7wNee9bA35eYIjobYPaQuKS1MQ==" crossorigin></script>
18
<script>window.addEventListener('DOMContentLoaded', () => {
19
hljs.configure({languages: ['bash', 'css', 'diff', 'graphql', 'ini', 'javascript', 'json', 'plaintext', 'python', 'python-repl', 'rust', 'shell', 'sql', 'typescript', 'xml', 'yaml']});
20
hljs.highlightAll();
21
})</script>
22
</head>
23
<body>
24
<main>
25
<article id="content">
26
<header>
27
<h1 class="title">Package <code>VITAE</code></h1>
28
</header>
29
<section id="section-intro">
30
</section>
31
<section>
32
<h2 class="section-title" id="header-submodules">Sub-modules</h2>
33
<dl>
34
<dt><code class="name"><a title="VITAE.inference" href="inference.html">VITAE.inference</a></code></dt>
35
<dd>
36
<div class="desc"></div>
37
</dd>
38
<dt><code class="name"><a title="VITAE.metric" href="metric.html">VITAE.metric</a></code></dt>
39
<dd>
40
<div class="desc"></div>
41
</dd>
42
<dt><code class="name"><a title="VITAE.model" href="model.html">VITAE.model</a></code></dt>
43
<dd>
44
<div class="desc"></div>
45
</dd>
46
<dt><code class="name"><a title="VITAE.train" href="train.html">VITAE.train</a></code></dt>
47
<dd>
48
<div class="desc"></div>
49
</dd>
50
<dt><code class="name"><a title="VITAE.utils" href="utils.html">VITAE.utils</a></code></dt>
51
<dd>
52
<div class="desc"></div>
53
</dd>
54
</dl>
55
</section>
56
<section>
57
</section>
58
<section>
59
</section>
60
<section>
61
<h2 class="section-title" id="header-classes">Classes</h2>
62
<dl>
63
<dt id="VITAE.VITAE"><code class="flex name class">
64
<span>class <span class="ident">VITAE</span></span>
65
<span>(</span><span>adata: anndata._core.anndata.AnnData, covariates=None, pi_covariates=None, model_type: str = 'Gaussian', npc: int = 64, adata_layer_counts=None, copy_adata: bool = False, hidden_layers=[32], latent_space_dim: int = 16, conditions=None)</span>
66
</code></dt>
67
<dd>
68
<div class="desc"><p>Variational Inference for Trajectory by AutoEncoder.</p>
69
<p>Get input data for model. Data needs to be first processed using scanpy and stored as an AnnData object.
70
The 'UMI' or 'non-UMI' models need the original count matrix, so the count matrix needs to be saved in
71
adata.layers in order to use these models.</p>
72
<h2 id="parameters">Parameters</h2>
73
<dl>
74
<dt><strong><code>adata</code></strong> :&ensp;<code>sc.AnnData</code></dt>
75
<dd>The scanpy AnnData object. adata should already contain adata.var.highly_variable</dd>
76
<dt><strong><code>covariates</code></strong> :&ensp;<code>list</code>, optional</dt>
77
<dd>A list of names of covariate vectors that are stored in adata.obs</dd>
78
<dt><strong><code>pi_covariates</code></strong> :&ensp;<code>list</code>, optional</dt>
79
<dd>A list of names of covariate vectors used as input for pilayer</dd>
80
<dt><strong><code>model_type</code></strong> :&ensp;<code>str</code>, optional</dt>
81
<dd>'UMI', 'non-UMI' and 'Gaussian', default is 'Gaussian'.</dd>
82
<dt><strong><code>npc</code></strong> :&ensp;<code>int</code>, optional</dt>
83
<dd>The number of PCs to use when model_type is 'Gaussian'. The default is 64.</dd>
84
<dt><strong><code>adata_layer_counts</code></strong> :&ensp;<code>str</code>, optional</dt>
85
<dd>the key name of adata.layers that stores the count data if model_type is
86
'UMI' or 'non-UMI'</dd>
87
<dt><strong><code>copy_adata</code></strong> :&ensp;<code>bool</code>, optional. Set to True if we don't want VITAE to modify the original adata. If set to True, self.adata will be an independent copy of the original adata.</dt>
88
<dd>&nbsp;</dd>
89
<dt><strong><code>hidden_layers</code></strong> :&ensp;<code>list</code>, optional</dt>
90
<dd>The list of dimensions of layers of autoencoder between latent space and original space. Default is to have only one hidden layer with 32 nodes</dd>
91
<dt><strong><code>latent_space_dim</code></strong> :&ensp;<code>int</code>, optional</dt>
92
<dd>The dimension of latent space.</dd>
93
<dt><strong><code>gamma</code></strong> :&ensp;<code>float</code>, optional</dt>
94
<dd>The weight of the MMD loss</dd>
95
<dt><strong><code>conditions</code></strong> :&ensp;<code>str</code> or <code>list</code>, optional</dt>
96
<dd>The conditions of different cells</dd>
97
</dl>
98
<h2 id="returns">Returns</h2>
99
<p>None.</p></div>
100
<details class="source">
101
<summary>
102
<span>Expand source code</span>
103
</summary>
104
<pre><code class="python">class VITAE():
105
    &#34;&#34;&#34;
106
    Variational Inference for Trajectory by AutoEncoder.
107
    &#34;&#34;&#34;
108
    def __init__(self, adata: sc.AnnData,
109
               covariates = None, pi_covariates = None,
110
               model_type: str = &#39;Gaussian&#39;,
111
               npc: int = 64,
112
               adata_layer_counts = None,
113
               copy_adata: bool = False,
114
               hidden_layers = [32],
115
               latent_space_dim: int = 16,
116
               conditions = None):
117
        &#39;&#39;&#39;
118
        Get input data for model. Data need to be first processed using scancy and stored as an AnnData object
119
         The &#39;UMI&#39; or &#39;non-UMI&#39; model need the original count matrix, so the count matrix need to be saved in
120
         adata.layers in order to use these models.
121
122
123
        Parameters
124
        ----------
125
        adata : sc.AnnData
126
            The scanpy AnnData object. adata should already contain adata.var.highly_variable
127
        covariates : list, optional
128
            A list of names of covariate vectors that are stored in adata.obs
129
        pi_covariates: list, optional
130
            A list of names of covariate vectors used as input for pilayer
131
        model_type : str, optional
132
            &#39;UMI&#39;, &#39;non-UMI&#39; and &#39;Gaussian&#39;, default is &#39;Gaussian&#39;.
133
        npc : int, optional
134
            The number of PCs to use when model_type is &#39;Gaussian&#39;. The default is 64.
135
        adata_layer_counts: str, optional
136
            the key name of adata.layers that stores the count data if model_type is
137
            &#39;UMI&#39; or &#39;non-UMI&#39;
138
        copy_adata: bool, optional. Set to True if we don&#39;t want VITAE to modify the original adata. If set to True, self.adata will be an independent copy of the original adata. 
139
        hidden_layers : list, optional
140
            The list of dimensions of layers of autoencoder between latent space and original space. Default is to have only one hidden layer with 32 nodes
141
        latent_space_dim : int, optional
142
            The dimension of latent space.
143
        gamme : float, optional
144
            The weight of the MMD loss
145
        conditions : str or list, optional
146
            The conditions of different cells
147
148
149
        Returns
150
        -------
151
        None.
152
153
        &#39;&#39;&#39;
154
        self.dict_method_scname = {
155
            &#39;PCA&#39; : &#39;X_pca&#39;,
156
            &#39;UMAP&#39; : &#39;X_umap&#39;,
157
            &#39;TSNE&#39; : &#39;X_tsne&#39;,
158
            &#39;diffmap&#39; : &#39;X_diffmap&#39;,
159
            &#39;draw_graph&#39; : &#39;X_draw_graph_fa&#39;
160
        }
161
162
        if model_type != &#39;Gaussian&#39;:
163
            if adata_layer_counts is None:
164
                raise ValueError(&#34;need to provide the name in adata.layers that stores the raw count data&#34;)
165
            if &#39;highly_variable&#39; not in adata.var:
166
                raise ValueError(&#34;need to first select highly variable genes using scanpy&#34;)
167
168
        self.model_type = model_type
169
170
        if copy_adata:
171
            self.adata = adata.copy()
172
        else:
173
            self.adata = adata
174
175
        if covariates is not None:
176
            if isinstance(covariates, str):
177
                covariates = [covariates]
178
            covariates = np.array(covariates)
179
            id_cat = (adata.obs[covariates].dtypes == &#39;category&#39;)
180
            # add OneHotEncoder &amp; StandardScaler as class variable if needed
181
            if np.sum(id_cat)&gt;0:
182
                covariates_cat = OneHotEncoder(drop=&#39;if_binary&#39;, handle_unknown=&#39;ignore&#39;
183
                    ).fit_transform(adata.obs[covariates[id_cat]]).toarray()
184
            else:
185
                covariates_cat = np.array([]).reshape(adata.shape[0],0)
186
187
            # temporarily disable StandardScaler
188
            if np.sum(~id_cat)&gt;0:
189
                #covariates_con = StandardScaler().fit_transform(adata.obs[covariates[~id_cat]])
190
                covariates_con = adata.obs[covariates[~id_cat]]
191
            else:
192
                covariates_con = np.array([]).reshape(adata.shape[0],0)
193
194
            self.covariates = np.c_[covariates_cat, covariates_con].astype(tf.keras.backend.floatx())
195
        else:
196
            self.covariates = None
197
198
        if conditions is not None:
199
            ## observations with np.nan will not participant in calculating mmd_loss
200
            if isinstance(conditions, str):
201
                conditions = [conditions]
202
            conditions = np.array(conditions)
203
            if np.any(adata.obs[conditions].dtypes != &#39;category&#39;):
204
                raise ValueError(&#34;Conditions should all be categorical.&#34;)
205
206
            self.conditions = OrdinalEncoder(dtype=int, encoded_missing_value=-1).fit_transform(adata.obs[conditions]) + int(1)
207
        else:
208
            self.conditions = None
209
210
        if pi_covariates is not None:
211
            self.pi_cov = adata.obs[pi_covariates].to_numpy()
212
            if self.pi_cov.ndim == 1:
213
                self.pi_cov = self.pi_cov.reshape(-1, 1)
214
                self.pi_cov = self.pi_cov.astype(tf.keras.backend.floatx())
215
        else:
216
            self.pi_cov = np.zeros((adata.shape[0],1), dtype=tf.keras.backend.floatx())
217
            
218
        self.model_type = model_type
219
        self._adata = sc.AnnData(X = self.adata.X, var = self.adata.var)
220
        self._adata.obs = self.adata.obs
221
        self._adata.uns = self.adata.uns
222
223
224
        if model_type == &#39;Gaussian&#39;:
225
            sc.tl.pca(adata, n_comps = npc)
226
            self.X_input = self.X_output = adata.obsm[&#39;X_pca&#39;]
227
            self.scale_factor = np.ones(self.X_output.shape[0])
228
        else:
229
            print(f&#34;{adata.var.highly_variable.sum()} highly variable genes selected as input&#34;) 
230
            self.X_input = adata.X[:, adata.var.highly_variable]
231
            self.X_output = adata.layers[adata_layer_counts][ :, adata.var.highly_variable]
232
            self.scale_factor = np.sum(self.X_output, axis=1, keepdims=True)/1e4
233
234
        self.dimensions = hidden_layers
235
        self.dim_latent = latent_space_dim
236
237
        self.vae = model.VariationalAutoEncoder(
238
            self.X_output.shape[1], self.dimensions,
239
            self.dim_latent, self.model_type,
240
            False if self.covariates is None else True,
241
            )
242
243
        if hasattr(self, &#39;inferer&#39;):
244
            delattr(self, &#39;inferer&#39;)
245
        
246
247
    def pre_train(self, test_size = 0.1, random_state: int = 0,
248
            learning_rate: float = 1e-3, batch_size: int = 256, L: int = 1, alpha: float = 0.10, gamma: float = 0,
249
            phi : float = 1,num_epoch: int = 200, num_step_per_epoch: Optional[int] = None,
250
            early_stopping_patience: int = 10, early_stopping_tolerance: float = 0.01, 
251
            early_stopping_relative: bool = True, verbose: bool = False,path_to_weights: Optional[str] = None):
252
        &#39;&#39;&#39;Pretrain the model with specified learning rate.
253
254
        Parameters
255
        ----------
256
        test_size : float or int, optional
257
            The proportion or size of the test set.
258
        random_state : int, optional
259
            The random state for data splitting.
260
        learning_rate : float, optional
261
            The initial learning rate for the Adam optimizer.
262
        batch_size : int, optional 
263
            The batch size for pre-training.  Default is 256. Set to 32 if number of cells is small (less than 1000)
264
        L : int, optional 
265
            The number of MC samples.
266
        alpha : float, optional
267
            The value of alpha in [0,1] to encourage covariate adjustment. Not used if there is no covariates.
268
        gamma : float, optional
269
            The weight of the mmd loss if used.
270
        phi : float, optional
271
            The weight of Jocob norm of the encoder.
272
        num_epoch : int, optional 
273
            The maximum number of epochs.
274
        num_step_per_epoch : int, optional 
275
            The number of step per epoch, it will be inferred from number of cells and batch size if it is None.            
276
        early_stopping_patience : int, optional 
277
            The maximum number of epochs if there is no improvement.
278
        early_stopping_tolerance : float, optional 
279
            The minimum change of loss to be considered as an improvement.
280
        early_stopping_relative : bool, optional
281
            Whether monitor the relative change of loss as stopping criteria or not.
282
        path_to_weights : str, optional 
283
            The path of weight file to be saved; not saving weight if None.
284
        conditions : str or list, optional
285
            The conditions of different cells
286
        &#39;&#39;&#39;
287
288
        id_train, id_test = train_test_split(
289
                                np.arange(self.X_input.shape[0]), 
290
                                test_size=test_size, 
291
                                random_state=random_state)
292
        if num_step_per_epoch is None:
293
            num_step_per_epoch = len(id_train)//batch_size+1
294
        self.train_dataset = train.warp_dataset(self.X_input[id_train].astype(tf.keras.backend.floatx()), 
295
                                                None if self.covariates is None else self.covariates[id_train].astype(tf.keras.backend.floatx()),
296
                                                batch_size, 
297
                                                self.X_output[id_train].astype(tf.keras.backend.floatx()), 
298
                                                self.scale_factor[id_train].astype(tf.keras.backend.floatx()),
299
                                                conditions = None if self.conditions is None else self.conditions[id_train].astype(tf.keras.backend.floatx()))
300
        self.test_dataset = train.warp_dataset(self.X_input[id_test], 
301
                                                None if self.covariates is None else self.covariates[id_test].astype(tf.keras.backend.floatx()),
302
                                                batch_size, 
303
                                                self.X_output[id_test].astype(tf.keras.backend.floatx()), 
304
                                                self.scale_factor[id_test].astype(tf.keras.backend.floatx()),
305
                                                conditions = None if self.conditions is None else self.conditions[id_test].astype(tf.keras.backend.floatx()))
306
307
        self.vae = train.pre_train(
308
            self.train_dataset,
309
            self.test_dataset,
310
            self.vae,
311
            learning_rate,                        
312
            L, alpha, gamma, phi,
313
            num_epoch,
314
            num_step_per_epoch,
315
            early_stopping_patience,
316
            early_stopping_tolerance,
317
            early_stopping_relative,
318
            verbose)
319
        
320
        self.update_z()
321
322
        if path_to_weights is not None:
323
            self.save_model(path_to_weights)
324
            
325
326
    def update_z(self):
327
        self.z = self.get_latent_z()        
328
        self._adata_z = sc.AnnData(self.z)
329
        sc.pp.neighbors(self._adata_z)
330
331
            
332
    def get_latent_z(self):
333
        &#39;&#39;&#39; get the posterier mean of current latent space z (encoder output)
334
335
        Returns
336
        ----------
337
        z : np.array
338
            \([N,d]\) The latent means.
339
        &#39;&#39;&#39; 
340
        c = None if self.covariates is None else self.covariates
341
        return self.vae.get_z(self.X_input, c)
342
            
343
    
344
    def visualize_latent(self, method: str = &#34;UMAP&#34;, 
345
                         color = None, **kwargs):
346
        &#39;&#39;&#39;
347
        visualize the current latent space z using the scanpy visualization tools
348
349
        Parameters
350
        ----------
351
        method : str, optional
352
            Visualization method to use. The default is &#34;draw_graph&#34; (the FA plot). Possible choices include &#34;PCA&#34;, &#34;UMAP&#34;, 
353
            &#34;diffmap&#34;, &#34;TSNE&#34; and &#34;draw_graph&#34;
354
        color : TYPE, optional
355
            Keys for annotations of observations/cells or variables/genes, e.g., &#39;ann1&#39; or [&#39;ann1&#39;, &#39;ann2&#39;].
356
            The default is None. Same as scanpy.
357
        **kwargs :  
358
            Extra key-value arguments that can be passed to scanpy plotting functions (scanpy.pl.XX).   
359
360
        Returns
361
        -------
362
        None.
363
364
        &#39;&#39;&#39;
365
          
366
        if method not in [&#39;PCA&#39;, &#39;UMAP&#39;, &#39;TSNE&#39;, &#39;diffmap&#39;, &#39;draw_graph&#39;]:
367
            raise ValueError(&#34;visualization method should be one of &#39;PCA&#39;, &#39;UMAP&#39;, &#39;TSNE&#39;, &#39;diffmap&#39; and &#39;draw_graph&#39;&#34;)
368
        
369
        temp = list(self._adata_z.obsm.keys())
370
        if method == &#39;PCA&#39; and not &#39;X_pca&#39; in temp:
371
            print(&#34;Calculate PCs ...&#34;)
372
            sc.tl.pca(self._adata_z)
373
        elif method == &#39;UMAP&#39; and not &#39;X_umap&#39; in temp:  
374
            print(&#34;Calculate UMAP ...&#34;)
375
            sc.tl.umap(self._adata_z)
376
        elif method == &#39;TSNE&#39; and not &#39;X_tsne&#39; in temp:
377
            print(&#34;Calculate TSNE ...&#34;)
378
            sc.tl.tsne(self._adata_z)
379
        elif method == &#39;diffmap&#39; and not &#39;X_diffmap&#39; in temp:
380
            print(&#34;Calculate diffusion map ...&#34;)
381
            sc.tl.diffmap(self._adata_z)
382
        elif method == &#39;draw_graph&#39; and not &#39;X_draw_graph_fa&#39; in temp:
383
            print(&#34;Calculate FA ...&#34;)
384
            sc.tl.draw_graph(self._adata_z)
385
            
386
387
        self._adata.obs = self.adata.obs.copy()
388
        self._adata.obsp = self._adata_z.obsp
389
#        self._adata.uns = self._adata_z.uns
390
        self._adata.obsm = self._adata_z.obsm
391
    
392
        if method == &#39;PCA&#39;:
393
            axes = sc.pl.pca(self._adata, color = color, **kwargs)
394
        elif method == &#39;UMAP&#39;:            
395
            axes = sc.pl.umap(self._adata, color = color, **kwargs)
396
        elif method == &#39;TSNE&#39;:
397
            axes = sc.pl.tsne(self._adata, color = color, **kwargs)
398
        elif method == &#39;diffmap&#39;:
399
            axes = sc.pl.diffmap(self._adata, color = color, **kwargs)
400
        elif method == &#39;draw_graph&#39;:
401
            axes = sc.pl.draw_graph(self._adata, color = color, **kwargs)
402
        return axes
403
404
405
    def init_latent_space(self, cluster_label = None, log_pi = None, res: float = 1.0, 
406
                          ratio_prune= None, dist = None, dist_thres = 0.5, topk=0, pilayer = False):
407
        &#39;&#39;&#39;Initialize the latent space.
408
409
        Parameters
410
        ----------
411
        cluster_label : str, optional
412
            The name of vector of labels that can be found in self.adata.obs. 
413
            Default is None, which will perform leiden clustering on the pretrained z to get clusters
414
        mu : np.array, optional
415
            \([d,k]\) The value of initial \(\\mu\).
416
        log_pi : np.array, optional
417
            \([1,K]\) The value of initial \(\\log(\\pi)\).
418
        res: 
419
            The resolution of leiden clustering, which is a parameter value controlling the coarseness of the clustering. 
420
            Higher values lead to more clusters. Deafult is 1.
421
        ratio_prune : float, optional
422
            The ratio of edges to be removed before estimating.
423
        topk : int, optional
424
            The number of top k neighbors to keep for each cluster.
425
        &#39;&#39;&#39;   
426
    
427
        
428
        if cluster_label is None:
429
            print(&#34;Perform leiden clustering on the latent space z ...&#34;)
430
            g = get_igraph(self.z)
431
            cluster_labels = leidenalg_igraph(g, res = res)
432
            cluster_labels = cluster_labels.astype(str) 
433
            uni_cluster_labels = np.unique(cluster_labels)
434
        else:
435
            if isinstance(cluster_label,str):
436
                cluster_labels = self.adata.obs[cluster_label].to_numpy()
437
                uni_cluster_labels = np.array(self.adata.obs[cluster_label].cat.categories)
438
            else:
439
                ## if cluster_label is a list
440
                cluster_labels = cluster_label
441
                uni_cluster_labels = np.unique(cluster_labels)
442
443
        n_clusters = len(uni_cluster_labels)
444
445
        if not hasattr(self, &#39;z&#39;):
446
            self.update_z()        
447
        z = self.z
448
        mu = np.zeros((z.shape[1], n_clusters))
449
        for i,l in enumerate(uni_cluster_labels):
450
            mu[:,i] = np.mean(z[cluster_labels==l], axis=0)
451
       
452
        if dist is None:
453
            ### update cluster centers if some cluster centers are too close
454
            clustering = AgglomerativeClustering(
455
                n_clusters=None,
456
                distance_threshold=dist_thres,
457
                linkage=&#39;complete&#39;
458
                ).fit(mu.T/np.sqrt(mu.shape[0]))
459
            n_clusters_new = clustering.n_clusters_
460
            if n_clusters_new &lt; n_clusters:
461
                print(&#34;Merge clusters for cluster centers that are too close ...&#34;)
462
                n_clusters = n_clusters_new
463
                for i in range(n_clusters):    
464
                    temp = uni_cluster_labels[clustering.labels_ == i]
465
                    idx = np.isin(cluster_labels, temp)
466
                    cluster_labels[idx] = &#39;,&#39;.join(temp)
467
                    if np.sum(clustering.labels_==i)&gt;1:
468
                        print(&#39;Merge %s&#39;% &#39;,&#39;.join(temp))
469
                uni_cluster_labels = np.unique(cluster_labels)
470
                mu = np.zeros((z.shape[1], n_clusters))
471
                for i,l in enumerate(uni_cluster_labels):
472
                    mu[:,i] = np.mean(z[cluster_labels==l], axis=0)
473
            
474
        self.adata.obs[&#39;vitae_init_clustering&#39;] = cluster_labels
475
        self.adata.obs[&#39;vitae_init_clustering&#39;] = self.adata.obs[&#39;vitae_init_clustering&#39;].astype(&#39;category&#39;)
476
        print(&#34;Initial clustering labels saved as &#39;vitae_init_clustering&#39; in self.adata.obs.&#34;)
477
   
478
        if (log_pi is None) and (cluster_labels is not None) and (n_clusters&gt;3):                         
479
            n_states = int((n_clusters+1)*n_clusters/2)
480
            
481
            if dist is None:
482
                dist = _comp_dist(z, cluster_labels, mu.T)
483
484
            C = np.triu(np.ones(n_clusters))
485
            C[C&gt;0] = np.arange(n_states)
486
            C = C + C.T - np.diag(np.diag(C))
487
            C = C.astype(int)
488
489
            log_pi = np.zeros((1,n_states))            
490
491
            ## pruning to throw away edges for far-away clusters if there are too many clusters
492
            if ratio_prune is not None:
493
                log_pi[0, C[np.triu(dist)&gt;np.quantile(dist[np.triu_indices(n_clusters, 1)], 1-ratio_prune)]] = - np.inf
494
            else:
495
                log_pi[0, C[np.triu(dist)&gt;np.quantile(dist[np.triu_indices(n_clusters, 1)], 5/n_clusters) * 3]] = - np.inf
496
497
            ## also keep the top k neighbor of clusters
498
            topk = max(0, min(topk, n_clusters-1)) + 1
499
            topk_indices = np.argsort(dist,axis=1)[:,:topk]
500
            for i in range(n_clusters):
501
                log_pi[0, C[i, topk_indices[i]]] = 0
502
503
        self.n_states = n_clusters
504
        self.labels = cluster_labels
505
        
506
        labels_map = pd.DataFrame.from_dict(
507
            {i:label for i,label in enumerate(uni_cluster_labels)}, 
508
            orient=&#39;index&#39;, columns=[&#39;label_names&#39;], dtype=str
509
            )
510
        
511
        self.labels_map = labels_map
512
        self.vae.init_latent_space(self.n_states, mu, log_pi)
513
        self.inferer = Inferer(self.n_states)
514
        self.mu = self.vae.latent_space.mu.numpy()
515
        self.pi = np.triu(np.ones(self.n_states))
516
        self.pi[self.pi &gt; 0] = tf.nn.softmax(self.vae.latent_space.pi).numpy()[0]
517
518
        if pilayer:
519
            self.vae.create_pilayer()
520
521
522
    def update_latent_space(self, dist_thres: float=0.5):
523
        pi = self.pi[np.triu_indices(self.n_states)]
524
        mu = self.mu    
525
        clustering = AgglomerativeClustering(
526
            n_clusters=None,
527
            distance_threshold=dist_thres,
528
            linkage=&#39;complete&#39;
529
            ).fit(mu.T/np.sqrt(mu.shape[0]))
530
        n_clusters = clustering.n_clusters_   
531
532
        if n_clusters&lt;self.n_states:      
533
            print(&#34;Merge clusters for cluster centers that are too close ...&#34;)
534
            mu_new = np.empty((self.dim_latent, n_clusters))
535
            C = np.zeros((self.n_states, self.n_states))
536
            C[np.triu_indices(self.n_states, 0)] = pi
537
            C = np.triu(C, 1) + C.T
538
            C_new = np.zeros((n_clusters, n_clusters))
539
            
540
            uni_cluster_labels = self.labels_map[&#39;label_names&#39;].to_numpy()
541
            returned_order = {}
542
            cluster_labels = self.labels
543
            for i in range(n_clusters):
544
                temp = uni_cluster_labels[clustering.labels_ == i]
545
                idx = np.isin(cluster_labels, temp)
546
                cluster_labels[idx] = &#39;,&#39;.join(temp)
547
                returned_order[i] = &#39;,&#39;.join(temp)
548
                if np.sum(clustering.labels_==i)&gt;1:
549
                    print(&#39;Merge %s&#39;% &#39;,&#39;.join(temp))
550
            uni_cluster_labels = np.unique(cluster_labels) 
551
            for i,l in enumerate(uni_cluster_labels):  ## reorder the merged clusters based on the cluster names
552
                k = np.where(returned_order == l)
553
                mu_new[:, i] = np.mean(mu[:,clustering.labels_==k], axis=-1)
554
                # sum of the aggregated pi&#39;s
555
                C_new[i, i] = np.sum(np.triu(C[clustering.labels_==k,:][:,clustering.labels_==k]))
556
                for j in range(i+1, n_clusters):
557
                    k1 = np.where(returned_order == uni_cluster_labels[j])
558
                    C_new[i, j] = np.sum(C[clustering.labels_== k, :][:, clustering.labels_==k1])
559
560
#            labels_map_new = {}
561
#            for i in range(n_clusters):                       
562
#                # update label map: int-&gt;str
563
#                labels_map_new[i] = self.labels_map.loc[clustering.labels_==i, &#39;label_names&#39;].str.cat(sep=&#39;,&#39;)
564
#                if np.sum(clustering.labels_==i)&gt;1:
565
#                    print(&#39;Merge %s&#39;%labels_map_new[i])
566
#                # mean of the aggregated cluster means
567
#                mu_new[:, i] = np.mean(mu[:,clustering.labels_==i], axis=-1)
568
#                # sum of the aggregated pi&#39;s
569
#                C_new[i, i] = np.sum(np.triu(C[clustering.labels_==i,:][:,clustering.labels_==i]))
570
#                for j in range(i+1, n_clusters):
571
#                    C_new[i, j] = np.sum(C[clustering.labels_== i, :][:, clustering.labels_==j])
572
            C_new = np.triu(C_new,1) + C_new.T
573
574
            pi_new = C_new[np.triu_indices(n_clusters)]
575
            log_pi_new = np.log(pi_new, out=np.ones_like(pi_new)*(-np.inf), where=(pi_new!=0)).reshape((1,-1))
576
            self.n_states = n_clusters
577
            self.labels_map = pd.DataFrame.from_dict(
578
                {i:label for i,label in enumerate(uni_cluster_labels)},
579
                orient=&#39;index&#39;, columns=[&#39;label_names&#39;], dtype=str
580
                )
581
            self.labels = cluster_labels
582
#            self.labels_map = pd.DataFrame.from_dict(
583
#                labels_map_new, orient=&#39;index&#39;, columns=[&#39;label_names&#39;], dtype=str
584
#            )
585
            self.vae.init_latent_space(self.n_states, mu_new, log_pi_new)
586
            self.inferer = Inferer(self.n_states)
587
            self.mu = self.vae.latent_space.mu.numpy()
588
            self.pi = np.triu(np.ones(self.n_states))
589
            self.pi[self.pi &gt; 0] = tf.nn.softmax(self.vae.latent_space.pi).numpy()[0]
590
591
592
593
    def train(self, stratify = False, test_size = 0.1, random_state: int = 0,
594
            learning_rate: float = 1e-3, batch_size: int = 256,
595
            L: int = 1, alpha: float = 0.10, beta: float = 1, gamma: float = 0, phi: float = 1,
596
            num_epoch: int = 200, num_step_per_epoch: Optional[int] =  None,
597
            early_stopping_patience: int = 10, early_stopping_tolerance: float = 0.01, 
598
            early_stopping_relative: bool = True, early_stopping_warmup: int = 0,
599
            path_to_weights: Optional[str] = None,
600
            verbose: bool = False, **kwargs):
601
        &#39;&#39;&#39;Train the model.
602
603
        Parameters
604
        ----------
605
        stratify : np.array, None, or False
606
            If an array is provided, or `stratify=None` and `self.labels` is available, then they will be used to perform stratified shuffle splitting. Otherwise, general shuffle splitting is used. Set to `False` if `self.labels` is not intended for stratified shuffle splitting.
607
        test_size : float or int, optional
608
            The proportion or size of the test set.
609
        random_state : int, optional
610
            The random state for data splitting.
611
        learning_rate : float, optional  
612
            The initial learning rate for the Adam optimizer.
613
        batch_size : int, optional  
614
            The batch size for training. Default is 256. Set to 32 if number of cells is small (less than 1000)
615
        L : int, optional  
616
            The number of MC samples.
617
        alpha : float, optional  
618
            The value of alpha in [0,1] to encourage covariate adjustment. Not used if there is no covariates.
619
        beta : float, optional  
620
            The value of beta in beta-VAE.
621
        gamma : float, optional
622
            The weight of mmd_loss.
623
        phi : float, optional
624
            The weight of Jacob norm of encoder.
625
        num_epoch : int, optional  
626
            The number of epoch.
627
        num_step_per_epoch : int, optional 
628
            The number of step per epoch, it will be inferred from number of cells and batch size if it is None.
629
        early_stopping_patience : int, optional 
630
            The maximum number of epochs if there is no improvement.
631
        early_stopping_tolerance : float, optional 
632
            The minimum change of loss to be considered as an improvement.
633
        early_stopping_relative : bool, optional
634
            Whether monitor the relative change of loss or not.            
635
        early_stopping_warmup : int, optional 
636
            The number of warmup epochs.            
637
        path_to_weights : str, optional 
638
            The path of weight file to be saved; not saving weight if None.
639
        **kwargs :  
640
            Extra key-value arguments for dimension reduction algorithms.        
641
        &#39;&#39;&#39;
642
        if gamma == 0 or self.conditions is None:
643
            conditions = np.array([np.nan] * self.adata.shape[0])
644
        else:
645
            conditions = self.conditions
646
647
        if stratify is None:
648
            stratify = self.labels
649
        elif stratify is False:
650
            stratify = None    
651
        id_train, id_test = train_test_split(
652
                                np.arange(self.X_input.shape[0]), 
653
                                test_size=test_size, 
654
                                stratify=stratify, 
655
                                random_state=random_state)
656
        if num_step_per_epoch is None:
657
            num_step_per_epoch = len(id_train)//batch_size+1
658
        c = None if self.covariates is None else self.covariates.astype(tf.keras.backend.floatx())
659
        self.train_dataset = train.warp_dataset(self.X_input[id_train].astype(tf.keras.backend.floatx()),
660
                                                None if c is None else c[id_train],
661
                                                batch_size, 
662
                                                self.X_output[id_train].astype(tf.keras.backend.floatx()), 
663
                                                self.scale_factor[id_train].astype(tf.keras.backend.floatx()),
664
                                                conditions = conditions[id_train],
665
                                                pi_cov = self.pi_cov[id_train])
666
        self.test_dataset = train.warp_dataset(self.X_input[id_test].astype(tf.keras.backend.floatx()),
667
                                                None if c is None else c[id_test],
668
                                                batch_size, 
669
                                                self.X_output[id_test].astype(tf.keras.backend.floatx()), 
670
                                                self.scale_factor[id_test].astype(tf.keras.backend.floatx()),
671
                                                conditions = conditions[id_test],
672
                                                pi_cov = self.pi_cov[id_test])
673
                                   
674
        self.vae = train.train(
675
            self.train_dataset,
676
            self.test_dataset,
677
            self.vae,
678
            learning_rate,
679
            L,
680
            alpha,
681
            beta,
682
            gamma,
683
            phi,
684
            num_epoch,
685
            num_step_per_epoch,
686
            early_stopping_patience,
687
            early_stopping_tolerance,
688
            early_stopping_relative,
689
            early_stopping_warmup,  
690
            verbose,
691
            **kwargs            
692
            )
693
        
694
        self.update_z()
695
        self.mu = self.vae.latent_space.mu.numpy()
696
        self.pi = np.triu(np.ones(self.n_states))
697
        self.pi[self.pi &gt; 0] = tf.nn.softmax(self.vae.latent_space.pi).numpy()[0]
698
            
699
        if path_to_weights is not None:
700
            self.save_model(path_to_weights)
701
    
702
703
    def output_pi(self, pi_cov):
704
        &#34;&#34;&#34;return a matrix n_states by n_states and a mask for plotting, which can be used to cover the lower triangular (except the diagonals) of a heatmap&#34;&#34;&#34;
705
        p = self.vae.pilayer
706
        pi_cov = tf.expand_dims(tf.constant([pi_cov], dtype=tf.float32), 0)
707
        pi_val = tf.nn.softmax(p(pi_cov)).numpy()[0]
708
        # Create heatmap matrix
709
        n = self.vae.n_states
710
        matrix = np.zeros((n, n))
711
        matrix[np.triu_indices(n)] = pi_val
712
        mask = np.tril(np.ones_like(matrix), k=-1)
713
        return matrix, mask
714
715
716
    def return_pilayer_weights(self):
717
        &#34;&#34;&#34;return parameters of pilayer, which has dimension dim(pi_cov) + 1 by n_categories, the last row is biases&#34;&#34;&#34;
718
        return np.vstack((model.vae.pilayer.weights[0].numpy(), model.vae.pilayer.weights[1].numpy().reshape(1, -1)))
719
720
721
    def posterior_estimation(self, batch_size: int = 32, L: int = 50, **kwargs):
722
        &#39;&#39;&#39;Initialize trajectory inference by computing the posterior estimations.        
723
724
        Parameters
725
        ----------
726
        batch_size : int, optional
727
            The batch size when doing inference.
728
        L : int, optional
729
            The number of MC samples when doing inference.
730
        **kwargs :  
731
            Extra key-value arguments for dimension reduction algorithms.              
732
        &#39;&#39;&#39;
733
        c = None if self.covariates is None else self.covariates.astype(tf.keras.backend.floatx())
734
        self.test_dataset = train.warp_dataset(self.X_input.astype(tf.keras.backend.floatx()), 
735
                                               c,
736
                                               batch_size)
737
        _, _, self.pc_x,\
738
            self.cell_position_posterior,self.cell_position_variance,_ = self.vae.inference(self.test_dataset, L=L)
739
            
740
        uni_cluster_labels = self.labels_map[&#39;label_names&#39;].to_numpy()
741
        self.adata.obs[&#39;vitae_new_clustering&#39;] = uni_cluster_labels[np.argmax(self.cell_position_posterior, 1)]
742
        self.adata.obs[&#39;vitae_new_clustering&#39;] = self.adata.obs[&#39;vitae_new_clustering&#39;].astype(&#39;category&#39;)
743
        print(&#34;New clustering labels saved as &#39;vitae_new_clustering&#39; in self.adata.obs.&#34;)
744
        return None
745
746
747
    def infer_backbone(self, method: str = &#39;modified_map&#39;, thres = 0.5,
748
            no_loop: bool = True, cutoff: float = 0,
749
            visualize: bool = True, color = &#39;vitae_new_clustering&#39;,path_to_fig = None,**kwargs):
750
        &#39;&#39;&#39; Compute edge scores.
751
752
        Parameters
753
        ----------
754
        method : string, optional
755
            &#39;mean&#39;, &#39;modified_mean&#39;, &#39;map&#39;, or &#39;modified_map&#39;.
756
        thres : float, optional
757
            The threshold used for filtering edges \(e_{ij}\) that \((n_{i}+n_{j}+e_{ij})/N&lt;thres\), only applied to mean method.
758
        no_loop : boolean, optional
759
            Whether loops are allowed to exist in the graph. If no_loop is true, will prune the graph to contain only the
760
            maximum spanning tree
761
        cutoff : float, optional
762
            The score threshold for filtering edges with scores less than cutoff.
763
        visualize: boolean
764
            whether plot the current trajectory backbone (undirected graph)
765
766
        Returns
767
        ----------
768
        G : nx.Graph
769
            The weighted graph with weight on each edge indicating its score of existence.
770
        &#39;&#39;&#39;
771
        # build_graph, return graph
772
        self.backbone = self.inferer.build_graphs(self.cell_position_posterior, self.pc_x,
773
                method, thres, no_loop, cutoff)
774
        self.cell_position_projected = self.inferer.modify_wtilde(self.cell_position_posterior, 
775
                np.array(list(self.backbone.edges)))
776
        
777
        uni_cluster_labels = self.labels_map[&#39;label_names&#39;].to_numpy()
778
        temp_dict = {i:label for i,label in enumerate(uni_cluster_labels)}
779
        nx.relabel_nodes(self.backbone, temp_dict)
780
       
781
        self.adata.obs[&#39;vitae_new_clustering&#39;] = uni_cluster_labels[np.argmax(self.cell_position_projected, 1)]
782
        self.adata.obs[&#39;vitae_new_clustering&#39;] = self.adata.obs[&#39;vitae_new_clustering&#39;].astype(&#39;category&#39;)
783
        print(&#34;&#39;vitae_new_clustering&#39; updated based on the projected cell positions.&#34;)
784
785
        self.uncertainty = np.sum((self.cell_position_projected - self.cell_position_posterior)**2, axis=-1) \
786
            + np.sum(self.cell_position_variance, axis=-1)
787
        self.adata.obs[&#39;projection_uncertainty&#39;] = self.uncertainty
788
        print(&#34;Cell projection uncertainties stored as &#39;projection_uncertainty&#39; in self.adata.obs&#34;)
789
        if visualize:
790
            self._adata.obs = self.adata.obs.copy()
791
            self.ax = self.plot_backbone(directed = False,color = color, **kwargs)
792
            if path_to_fig is not None:
793
                self.ax.figure.savefig(path_to_fig)
794
            self.ax.figure.show()
795
        return None
796
797
798
    def select_root(self, days, method: str = &#39;proportion&#39;):
799
        &#39;&#39;&#39;Order the vertices/states based on cells&#39; collection time information to select the root state.      
800
801
        Parameters
802
        ----------
803
        days : np.array 
804
            The day information for selected cells used to determine the root vertex.
805
            The dtype should be &#39;int&#39; or &#39;float&#39;.
806
        method : str, optional
807
            &#39;proportion&#39; or &#39;mean&#39;. 
808
            For &#39;proportion&#39;, the root is the one with maximal proportion of cells from the earliest day.
809
            For &#39;mean&#39;, the root is the one with earliest mean time among cells associated with it.
810
811
        Returns
812
        ----------
813
        root : int 
814
            The root vertex in the inferred trajectory based on given day information.
815
        &#39;&#39;&#39;
816
        ## TODO: change return description
817
        if days is not None and len(days)!=self.X_input.shape[0]:
818
            raise ValueError(&#34;The length of day information ({}) is not &#34;
819
                &#34;consistent with the number of selected cells ({})!&#34;.format(
820
                    len(days), self.X_input.shape[0]))
821
        if not hasattr(self, &#39;cell_position_projected&#39;):
822
            raise ValueError(&#34;Need to call &#39;infer_backbone&#39; first!&#34;)
823
824
        collection_time = np.dot(days, self.cell_position_projected)/np.sum(self.cell_position_projected, axis = 0)
825
        earliest_prop = np.dot(days==np.min(days), self.cell_position_projected)/np.sum(self.cell_position_projected, axis = 0)
826
        
827
        root_info = self.labels_map.copy()
828
        root_info[&#39;mean_collection_time&#39;] = collection_time
829
        root_info[&#39;earliest_time_prop&#39;] = earliest_prop
830
        root_info.sort_values(&#39;mean_collection_time&#39;, inplace=True)
831
        return root_info
832
833
834
    def plot_backbone(self, directed: bool = False, 
835
                      method: str = &#39;UMAP&#39;, color = &#39;vitae_new_clustering&#39;, **kwargs):
836
        &#39;&#39;&#39;Plot the current trajectory backbone (undirected graph).
837
838
        Parameters
839
        ----------
840
        directed : boolean, optional
841
            Whether the backbone is directed or not.
842
        method : str, optional
843
            The dimension reduction method to use. The default is &#34;UMAP&#34;.
844
        color : str, optional
845
            The key for annotations of observations/cells or variables/genes, e.g., &#39;ann1&#39; or [&#39;ann1&#39;, &#39;ann2&#39;].
846
            The default is &#39;vitae_new_clustering&#39;.
847
        **kwargs :
848
            Extra key-value arguments that can be passed to scanpy plotting functions (scanpy.pl.XX).
849
        &#39;&#39;&#39;
850
        if not isinstance(color,str):
851
            raise ValueError(&#39;The color argument should be of type str!&#39;)
852
        ax = self.visualize_latent(method = method, color=color, show=False, **kwargs)
853
        dict_label_num = {j:i for i,j in self.labels_map[&#39;label_names&#39;].to_dict().items()}
854
        uni_cluster_labels = self.adata.obs[&#39;vitae_init_clustering&#39;].cat.categories
855
        cluster_labels = self.adata.obs[&#39;vitae_new_clustering&#39;].to_numpy()
856
        embed_z = self._adata.obsm[self.dict_method_scname[method]]
857
        embed_mu = np.zeros((len(uni_cluster_labels), 2))
858
        for l in uni_cluster_labels:
859
            embed_mu[dict_label_num[l],:] = np.mean(embed_z[cluster_labels==l], axis=0)
860
861
        if directed:
862
            graph = self.directed_backbone
863
        else:
864
            graph = self.backbone
865
        edges = list(graph.edges)
866
        edge_scores = np.array([d[&#39;weight&#39;] for (u,v,d) in graph.edges(data=True)])
867
        if max(edge_scores) - min(edge_scores) == 0:
868
            edge_scores = edge_scores/max(edge_scores)
869
        else:
870
            edge_scores = (edge_scores - min(edge_scores))/(max(edge_scores) - min(edge_scores))*3
871
872
        value_range = np.maximum(np.diff(ax.get_xlim())[0], np.diff(ax.get_ylim())[0])
873
        y_range = np.min(embed_z[:,1]), np.max(embed_z[:,1], axis=0)
874
        for i in range(len(edges)):
875
            points = embed_z[np.sum(self.cell_position_projected[:, edges[i]]&gt;0, axis=-1)==2,:]
876
            points = points[points[:,0].argsort()]
877
            try:
878
                x_smooth, y_smooth = _get_smooth_curve(
879
                    points,
880
                    embed_mu[edges[i], :],
881
                    y_range
882
                    )
883
            except:
884
                x_smooth, y_smooth = embed_mu[edges[i], 0], embed_mu[edges[i], 1]
885
            ax.plot(x_smooth, y_smooth,
886
                &#39;-&#39;,
887
                linewidth= 1 + edge_scores[i],
888
                color=&#34;black&#34;,
889
                alpha=0.8,
890
                path_effects=[pe.Stroke(linewidth=1+edge_scores[i]+1.5,
891
                                        foreground=&#39;white&#39;), pe.Normal()],
892
                zorder=1
893
                )
894
895
            if directed:
896
                delta_x = embed_mu[edges[i][1], 0] - x_smooth[-2]
897
                delta_y = embed_mu[edges[i][1], 1] - y_smooth[-2]
898
                length = np.sqrt(delta_x**2 + delta_y**2) / 50 * value_range
899
                ax.arrow(
900
                        embed_mu[edges[i][1], 0]-delta_x/length,
901
                        embed_mu[edges[i][1], 1]-delta_y/length,
902
                        delta_x/length,
903
                        delta_y/length,
904
                        color=&#39;black&#39;, alpha=1.0,
905
                        shape=&#39;full&#39;, lw=0, length_includes_head=True,
906
                        head_width=np.maximum(0.01*(1 + edge_scores[i]), 0.03) * value_range,
907
                        zorder=2) 
908
        
909
        colors = self._adata.uns[&#39;vitae_new_clustering_colors&#39;]
910
            
911
        for i,l in enumerate(uni_cluster_labels):
912
            ax.scatter(*embed_mu[dict_label_num[l]:dict_label_num[l]+1,:].T, 
913
                       c=[colors[i]], edgecolors=&#39;white&#39;, # linewidths=10,  norm=norm,
914
                       s=250, marker=&#39;*&#39;, label=l)
915
916
        plt.setp(ax, xticks=[], yticks=[])
917
        box = ax.get_position()
918
        ax.set_position([box.x0, box.y0 + box.height * 0.1,
919
                            box.width, box.height * 0.9])
920
        if directed:
921
            ax.legend(loc=&#39;upper center&#39;, bbox_to_anchor=(0.5, -0.05),
922
                fancybox=True, shadow=True, ncol=5)
923
924
        return ax
925
926
927
    def plot_center(self, color = &#34;vitae_new_clustering&#34;, plot_legend = True, legend_add_index = True,
928
                    method: str = &#39;UMAP&#39;,ncol = 2,font_size = &#34;medium&#34;,
929
                    add_egde = False, add_direct = False,**kwargs):
930
        &#39;&#39;&#39;Plot the center of each cluster in the latent space.
931
932
        Parameters
933
        ----------
934
        color : str, optional
935
            The color of the center of each cluster. Default is &#34;vitae_new_clustering&#34;.
936
        plot_legend : bool, optional
937
            Whether to plot the legend. Default is True.
938
        legend_add_index : bool, optional
939
            Whether to add the index of each cluster in the legend. Default is True.
940
        method : str, optional
941
            The dimension reduction method used for visualization. Default is &#39;UMAP&#39;.
942
        ncol : int, optional
943
            The number of columns in the legend. Default is 2.
944
        font_size : str, optional
945
            The font size of the legend. Default is &#34;medium&#34;.
946
        add_egde : bool, optional
947
            Whether to add the edges between the centers of clusters. Default is False.
948
        add_direct : bool, optional
949
            Whether to add the direction of the edges. Default is False.
950
        &#39;&#39;&#39;
951
        if color not in [&#34;vitae_new_clustering&#34;,&#34;vitae_init_clustering&#34;]:
952
            raise ValueError(&#34;Can only plot center of vitae_new_clustering or vitae_init_clustering&#34;)
953
        dict_label_num = {j: i for i, j in self.labels_map[&#39;label_names&#39;].to_dict().items()}
954
        if legend_add_index:
955
            self._adata.obs[&#34;index_&#34;+color] = self._adata.obs[color].map(lambda x: dict_label_num[x])
956
            ax = self.visualize_latent(method=method, color=&#34;index_&#34; + color, show=False, legend_loc=&#34;on data&#34;,
957
                                        legend_fontsize=font_size,**kwargs)
958
            colors = self._adata.uns[&#34;index_&#34; + color + &#39;_colors&#39;]
959
        else:
960
            ax = self.visualize_latent(method=method, color = color, show=False,**kwargs)
961
            colors = self._adata.uns[color + &#39;_colors&#39;]
962
        uni_cluster_labels = self.adata.obs[color].cat.categories
963
        cluster_labels = self.adata.obs[color].to_numpy()
964
        embed_z = self._adata.obsm[self.dict_method_scname[method]]
965
        embed_mu = np.zeros((len(uni_cluster_labels), 2))
966
        for l in uni_cluster_labels:
967
            embed_mu[dict_label_num[l], :] = np.mean(embed_z[cluster_labels == l], axis=0)
968
969
        leg = (self.labels_map.index.astype(str) + &#34; : &#34; + self.labels_map.label_names).values
970
        for i, l in enumerate(uni_cluster_labels):
971
            ax.scatter(*embed_mu[dict_label_num[l]:dict_label_num[l] + 1, :].T,
972
                       c=[colors[i]], edgecolors=&#39;white&#39;, # linewidths=3,
973
                       s=250, marker=&#39;*&#39;, label=leg[i])
974
        if plot_legend:
975
            ax.legend(loc=&#39;center left&#39;, bbox_to_anchor=(1, 0.5), ncol=ncol, markerscale=0.8, frameon=False)
976
        plt.setp(ax, xticks=[], yticks=[])
977
        box = ax.get_position()
978
        ax.set_position([box.x0, box.y0 + box.height * 0.1,
979
                         box.width, box.height * 0.9])
980
        if add_egde:
981
            if add_direct:
982
                graph = self.directed_backbone
983
            else:
984
                graph = self.backbone
985
            edges = list(graph.edges)
986
            edge_scores = np.array([d[&#39;weight&#39;] for (u, v, d) in graph.edges(data=True)])
987
            if max(edge_scores) - min(edge_scores) == 0:
988
                edge_scores = edge_scores / max(edge_scores)
989
            else:
990
                edge_scores = (edge_scores - min(edge_scores)) / (max(edge_scores) - min(edge_scores)) * 3
991
992
            value_range = np.maximum(np.diff(ax.get_xlim())[0], np.diff(ax.get_ylim())[0])
993
            y_range = np.min(embed_z[:, 1]), np.max(embed_z[:, 1], axis=0)
994
            for i in range(len(edges)):
995
                points = embed_z[np.sum(self.cell_position_projected[:, edges[i]] &gt; 0, axis=-1) == 2, :]
996
                points = points[points[:, 0].argsort()]
997
                try:
998
                    x_smooth, y_smooth = _get_smooth_curve(
999
                        points,
1000
                        embed_mu[edges[i], :],
1001
                        y_range
1002
                    )
1003
                except:
1004
                    x_smooth, y_smooth = embed_mu[edges[i], 0], embed_mu[edges[i], 1]
1005
                ax.plot(x_smooth, y_smooth,
1006
                        &#39;-&#39;,
1007
                        linewidth=1 + edge_scores[i],
1008
                        color=&#34;black&#34;,
1009
                        alpha=0.8,
1010
                        path_effects=[pe.Stroke(linewidth=1 + edge_scores[i] + 1.5,
1011
                                                foreground=&#39;white&#39;), pe.Normal()],
1012
                        zorder=1
1013
                        )
1014
1015
                if add_direct:
1016
                    delta_x = embed_mu[edges[i][1], 0] - x_smooth[-2]
1017
                    delta_y = embed_mu[edges[i][1], 1] - y_smooth[-2]
1018
                    length = np.sqrt(delta_x ** 2 + delta_y ** 2) / 50 * value_range
1019
                    ax.arrow(
1020
                        embed_mu[edges[i][1], 0] - delta_x / length,
1021
                        embed_mu[edges[i][1], 1] - delta_y / length,
1022
                        delta_x / length,
1023
                        delta_y / length,
1024
                        color=&#39;black&#39;, alpha=1.0,
1025
                        shape=&#39;full&#39;, lw=0, length_includes_head=True,
1026
                        head_width=np.maximum(0.01 * (1 + edge_scores[i]), 0.03) * value_range,
1027
                        zorder=2)
1028
        self.ax = ax
1029
        self.ax.figure.show()
1030
        return None
1031
1032
1033
    def infer_trajectory(self, root: Union[int,str], digraph = None, color = &#34;pseudotime&#34;,
1034
                         visualize: bool = True, path_to_fig = None,  **kwargs):
1035
        &#39;&#39;&#39;Infer the trajectory.
1036
1037
        Parameters
1038
        ----------
1039
        root : int or string
1040
            The root of the inferred trajectory. Can provide either an int (vertex index) or string (label name)
1041
        digraph : nx.DiGraph, optional
1042
            The directed graph to be used for trajectory inference. If None, the minimum spanning tree of the estimated trajectory backbone will be used.
1043
        cutoff : string, optional
1044
            The threshold for filtering edges with scores less than cutoff.
1045
        visualize: boolean
1046
            Whether plot the current trajectory backbone (directed graph)
1047
        path_to_fig : string, optional  
1048
            The path to save figure, or don&#39;t save if it is None.
1049
        **kwargs : dict, optional
1050
            Other keywords arguments for plotting.
1051
        &#39;&#39;&#39;
1052
        if isinstance(root,str):
1053
            if root not in self.labels_map.values:
1054
                raise ValueError(&#34;Root {} is not in the label names!&#34;.format(root))
1055
            root = self.labels_map[self.labels_map[&#39;label_names&#39;]==root].index[0]
1056
1057
        if digraph is None:
1058
            connected_comps = nx.node_connected_component(self.backbone, root)
1059
            subG = self.backbone.subgraph(connected_comps)
1060
1061
            ## generate directed backbone which contains no loops
1062
            DG = nx.DiGraph(nx.to_directed(self.backbone))
1063
            temp = DG.subgraph(connected_comps)
1064
            DG.remove_edges_from(temp.edges - nx.dfs_edges(DG, root))
1065
            self.directed_backbone = DG
1066
        else:
1067
            if not nx.is_directed_acyclic_graph(digraph):
1068
                raise ValueError(&#34;The graph &#39;digraph&#39; should be a directed acyclic graph.&#34;)
1069
            if set(digraph.nodes) != set(self.backbone.nodes):
1070
                raise ValueError(&#34;The nodes in &#39;digraph&#39; do not match the nodes in &#39;self.backbone&#39;.&#34;)
1071
            self.directed_backbone = digraph
1072
1073
            connected_comps = nx.node_connected_component(digraph, root)
1074
            subG = self.backbone.subgraph(connected_comps)
1075
1076
1077
        if len(subG.edges)&gt;0:
1078
            milestone_net = self.inferer.build_milestone_net(subG, root)
1079
            if self.inferer.no_loop is False and milestone_net.shape[0]&lt;len(self.backbone.edges):
1080
                warnings.warn(&#34;The directed graph shown is a minimum spanning tree of the estimated trajectory backbone to avoid arbitrary assignment of the directions.&#34;)
1081
            self.pseudotime = self.inferer.comp_pseudotime(milestone_net, root, self.cell_position_projected)
1082
        else:
1083
            warnings.warn(&#34;There are no connected states for starting from the giving root.&#34;)
1084
            self.pseudotime = -np.ones(self._adata.shape[0])
1085
1086
        self.adata.obs[&#39;pseudotime&#39;] = self.pseudotime
1087
        print(&#34;Cell projection uncertainties stored as &#39;pseudotime&#39; in self.adata.obs&#34;)
1088
1089
        if visualize:
1090
            self._adata.obs[&#39;pseudotime&#39;] = self.pseudotime
1091
            self.ax = self.plot_backbone(directed = True, color = color, **kwargs)
1092
            if path_to_fig is not None:
1093
                self.ax.figure.savefig(path_to_fig)
1094
            self.ax.figure.show()
1095
1096
        return None
1097
1098
1099
1100
    def differential_expression_test(self, alpha: float = 0.05, cell_subset = None, order: int = 1):
1101
        &#39;&#39;&#39;Differential gene expression test. All (selected and unselected) genes will be tested 
1102
        Only cells in `selected_cell_subset` will be used, which is useful when one needs to
1103
        test differentially expressed genes on a branch of the inferred trajectory.
1104
1105
        Parameters
1106
        ----------
1107
        alpha : float, optional
1108
            The cutoff of p-values.
1109
        cell_subset : np.array, optional
1110
            The subset of cells to be used for testing. If None, all cells will be used.
1111
        order : int, optional
1112
            The maximum order we used for pseudotime in regression.
1113
1114
        Returns
1115
        ----------
1116
        res_df : pandas.DataFrame
1117
            The test results of expressed genes with two columns,
1118
            the estimated coefficients and the adjusted p-values.
1119
        &#39;&#39;&#39;
1120
        if not hasattr(self, &#39;pseudotime&#39;):
1121
            raise ReferenceError(&#34;Pseudotime does not exist! Please run &#39;infer_trajectory&#39; first.&#34;)
1122
        if cell_subset is None:
1123
            cell_subset = np.arange(self.X_input.shape[0])
1124
            print(&#34;All cells are selected.&#34;)
1125
        if order &lt; 1:
1126
            raise  ValueError(&#34;Maximal order of pseudotime in regression must be at least 1.&#34;)
1127
1128
        # Prepare X and Y for regression expression ~ rank(PDT) + covariates
1129
        Y = self.adata.X[cell_subset,:]
1130
#        std_Y = np.std(Y, ddof=1, axis=0, keepdims=True)
1131
#        Y = np.divide(Y-np.mean(Y, axis=0, keepdims=True), std_Y, out=np.empty_like(Y)*np.nan, where=std_Y!=0)
1132
        X = stats.rankdata(self.pseudotime[cell_subset])        
1133
        if order &gt; 1:
1134
            for _order in range(2, order+1):
1135
                X = np.c_[X, X**_order]
1136
        X = ((X-np.mean(X,axis=0, keepdims=True))/np.std(X, ddof=1, axis=0, keepdims=True))
1137
        X = np.c_[np.ones((X.shape[0],1)), X]
1138
        if self.covariates is not None:
1139
            X = np.c_[X, self.covariates[cell_subset, :]]
1140
1141
        res_df = DE_test(Y, X, self.adata.var_names, i_test = np.array(list(range(1,order+1))), alpha = alpha)
1142
        return res_df[res_df.pvalue_adjusted_1 != 0]
1143
1144
1145
 
1146
1147
    def evaluate(self, milestone_net, begin_node_true, grouping = None,
1148
                thres: float = 0.5, no_loop: bool = True, cutoff: Optional[float] = None,
1149
                method: str = &#39;mean&#39;, path: Optional[str] = None):
1150
        &#39;&#39;&#39; Evaluate the model.
1151
1152
        Parameters
1153
        ----------
1154
        milestone_net : pd.DataFrame
1155
            The true milestone network. For real data, milestone_net will be a DataFrame of the graph of nodes.
1156
            Eg.
1157
1158
            from|to
1159
            ---|---
1160
            cluster 1 | cluster 1
1161
            cluster 1 | cluster 2
1162
1163
            For synthetic data, milestone_net will be a DataFrame of the (projected)
1164
            positions of cells. The indexes are the orders of cells in the dataset.
1165
            Eg.
1166
1167
            from|to|w
1168
            ---|---|---
1169
            cluster 1 | cluster 1 | 1
1170
            cluster 1 | cluster 2 | 0.1
1171
        begin_node_true : str or int
1172
            The true begin node of the milestone.
1173
        grouping : np.array, optional
1174
            \([N,]\) The labels. For real data, grouping must be provided.
1175
1176
        Returns
1177
        ----------
1178
        res : pd.DataFrame
1179
            The evaluation result.
1180
        &#39;&#39;&#39;
1181
        if not hasattr(self, &#39;labels_map&#39;):
1182
            raise ValueError(&#34;No given labels for training.&#34;)
1183
1184
        &#39;&#39;&#39;
1185
        # Evaluate for the whole dataset will ignore selected_cell_subset.
1186
        if len(self.selected_cell_subset)!=len(self.cell_names):
1187
            warnings.warn(&#34;Evaluate for the whole dataset.&#34;)
1188
        &#39;&#39;&#39;
1189
        
1190
        # If the begin_node_true, need to encode it by self.le.
1191
        # this dict is for milestone net cause their labels are not merged
1192
        # all keys of label_map_dict are str
1193
        label_map_dict = dict()
1194
        for i in range(self.labels_map.shape[0]):
1195
            label_mapped = self.labels_map.loc[i]
1196
            ## merged cluster index is connected by comma
1197
            for each in label_mapped.values[0].split(&#34;,&#34;):
1198
                label_map_dict[each] = i
1199
        if isinstance(begin_node_true, str):
1200
            begin_node_true = label_map_dict[begin_node_true]
1201
            
1202
        # For generated data, grouping information is already in milestone_net
1203
        if &#39;w&#39; in milestone_net.columns:
1204
            grouping = None
1205
            
1206
        # If milestone_net is provided, transform them to be numeric.
1207
        if milestone_net is not None:
1208
            milestone_net[&#39;from&#39;] = [label_map_dict[x] for x in milestone_net[&#34;from&#34;]]
1209
            milestone_net[&#39;to&#39;] = [label_map_dict[x] for x in milestone_net[&#34;to&#34;]]
1210
1211
        # this dict is for potentially merged clusters.
1212
        label_map_dict_for_merged_cluster = dict(zip(self.labels_map[&#34;label_names&#34;],self.labels_map.index))
1213
        mapped_labels = np.array([label_map_dict_for_merged_cluster[x] for x in self.labels])
1214
        begin_node_pred = int(np.argmin(np.mean((
1215
            self.z[mapped_labels==begin_node_true,:,np.newaxis] -
1216
            self.mu[np.newaxis,:,:])**2, axis=(0,1))))
1217
1218
        if cutoff is None:
1219
            cutoff = 0.01
1220
1221
        G = self.backbone
1222
        w = self.cell_position_projected
1223
        pseudotime = self.pseudotime
1224
1225
        # 1. Topology
1226
        G_pred = nx.Graph()
1227
        G_pred.add_nodes_from(G.nodes)
1228
        G_pred.add_edges_from(G.edges)
1229
        nx.set_node_attributes(G_pred, False, &#39;is_init&#39;)
1230
        G_pred.nodes[begin_node_pred][&#39;is_init&#39;] = True
1231
1232
        G_true = nx.Graph()
1233
        G_true.add_nodes_from(G.nodes)
1234
        # if &#39;grouping&#39; is not provided, assume &#39;milestone_net&#39; contains proportions
1235
        if grouping is None:
1236
            G_true.add_edges_from(list(
1237
                milestone_net[~pd.isna(milestone_net[&#39;w&#39;])].groupby([&#39;from&#39;, &#39;to&#39;]).count().index))
1238
        # otherwise, &#39;milestone_net&#39; indicates edges
1239
        else:
1240
            if milestone_net is not None:             
1241
                G_true.add_edges_from(list(
1242
                    milestone_net.groupby([&#39;from&#39;, &#39;to&#39;]).count().index))
1243
            grouping = [label_map_dict[x] for x in grouping]
1244
            grouping = np.array(grouping)
1245
        G_true.remove_edges_from(nx.selfloop_edges(G_true))
1246
        nx.set_node_attributes(G_true, False, &#39;is_init&#39;)
1247
        G_true.nodes[begin_node_true][&#39;is_init&#39;] = True
1248
        res = topology(G_true, G_pred)
1249
            
1250
        # 2. Milestones assignment
1251
        if grouping is None:
1252
            milestones_true = milestone_net[&#39;from&#39;].values.copy()
1253
            milestones_true[(milestone_net[&#39;from&#39;]!=milestone_net[&#39;to&#39;])
1254
                           &amp;(milestone_net[&#39;w&#39;]&lt;0.5)] = milestone_net[(milestone_net[&#39;from&#39;]!=milestone_net[&#39;to&#39;])
1255
                                                                      &amp;(milestone_net[&#39;w&#39;]&lt;0.5)][&#39;to&#39;].values
1256
        else:
1257
            milestones_true = grouping
1258
        milestones_true = milestones_true
1259
        milestones_pred = np.argmax(w, axis=1)
1260
        res[&#39;ARI&#39;] = (adjusted_rand_score(milestones_true, milestones_pred) + 1)/2
1261
        
1262
        if grouping is None:
1263
            n_samples = len(milestone_net)
1264
            prop = np.zeros((n_samples,n_samples))
1265
            prop[np.arange(n_samples), milestone_net[&#39;to&#39;]] = 1-milestone_net[&#39;w&#39;]
1266
            prop[np.arange(n_samples), milestone_net[&#39;from&#39;]] = np.where(np.isnan(milestone_net[&#39;w&#39;]), 1, milestone_net[&#39;w&#39;])
1267
            res[&#39;GRI&#39;] = get_GRI(prop, w)
1268
        else:
1269
            res[&#39;GRI&#39;] = get_GRI(grouping, w)
1270
        
1271
        # 3. Correlation between geodesic distances / Pseudotime
1272
        if no_loop:
1273
            if grouping is None:
1274
                pseudotime_true = milestone_net[&#39;from&#39;].values + 1 - milestone_net[&#39;w&#39;].values
1275
                pseudotime_true[np.isnan(pseudotime_true)] = milestone_net[pd.isna(milestone_net[&#39;w&#39;])][&#39;from&#39;].values            
1276
            else:
1277
                pseudotime_true = - np.ones(len(grouping))
1278
                nx.set_edge_attributes(G_true, values = 1, name = &#39;weight&#39;)
1279
                connected_comps = nx.node_connected_component(G_true, begin_node_true)
1280
                subG = G_true.subgraph(connected_comps)
1281
                milestone_net_true = self.inferer.build_milestone_net(subG, begin_node_true)
1282
                if len(milestone_net_true)&gt;0:
1283
                    pseudotime_true[grouping==int(milestone_net_true[0,0])] = 0
1284
                    for i in range(len(milestone_net_true)):
1285
                        pseudotime_true[grouping==int(milestone_net_true[i,1])] = milestone_net_true[i,-1]
1286
            pseudotime_true = pseudotime_true[pseudotime&gt;-1]
1287
            pseudotime_pred = pseudotime[pseudotime&gt;-1]
1288
            res[&#39;PDT score&#39;] = (np.corrcoef(pseudotime_true,pseudotime_pred)[0,1]+1)/2
1289
        else:
1290
            res[&#39;PDT score&#39;] = np.nan
1291
            
1292
        # 4. Shape
1293
        # score_cos_theta = 0
1294
        # for (_from,_to) in G.edges:
1295
        #     _z = self.z[(w[:,_from]&gt;0) &amp; (w[:,_to]&gt;0),:]
1296
        #     v_1 = _z - self.mu[:,_from]
1297
        #     v_2 = _z - self.mu[:,_to]
1298
        #     cos_theta = np.sum(v_1*v_2, -1)/(np.linalg.norm(v_1,axis=-1)*np.linalg.norm(v_2,axis=-1)+1e-12)
1299
1300
        #     score_cos_theta += np.sum((1-cos_theta)/2)
1301
1302
        # res[&#39;score_cos_theta&#39;] = score_cos_theta/(np.sum(np.sum(w&gt;0, axis=-1)==2)+1e-12)
1303
        return res
1304
1305
1306
    def save_model(self, path_to_file: str = &#39;model.checkpoint&#39;,save_adata: bool = False):
1307
        &#39;&#39;&#39;Saving model weights.
1308
1309
        Parameters
1310
        ----------
1311
        path_to_file : str, optional
1312
            The path to weight files of pre-trained or trained model
1313
        save_adata : boolean, optional
1314
            Whether to save adata or not.
1315
        &#39;&#39;&#39;
1316
        self.vae.save_weights(path_to_file)
1317
        if hasattr(self, &#39;labels&#39;) and self.labels is not None:
1318
            with open(path_to_file + &#39;.label&#39;, &#39;wb&#39;) as f:
1319
                np.save(f, self.labels)
1320
        with open(path_to_file + &#39;.config&#39;, &#39;wb&#39;) as f:
1321
            self.dim_origin = self.X_input.shape[1]
1322
            np.save(f, np.array([
1323
                self.dim_origin, self.dimensions, self.dim_latent,
1324
                self.model_type, 0 if self.covariates is None else self.covariates.shape[1]], dtype=object))
1325
        if hasattr(self, &#39;inferer&#39;) and hasattr(self, &#39;uncertainty&#39;):
1326
            with open(path_to_file + &#39;.inference&#39;, &#39;wb&#39;) as f:
1327
                np.save(f, np.array([
1328
                    self.pi, self.mu, self.pc_x, self.cell_position_posterior, self.uncertainty,
1329
                    self.z,self.cell_position_variance], dtype=object))
1330
        if save_adata:
1331
            self.adata.write(path_to_file + &#39;.adata.h5ad&#39;)
1332
1333
1334
    def load_model(self, path_to_file: str = &#39;model.checkpoint&#39;, load_labels: bool = False, load_adata: bool = False):
1335
        &#39;&#39;&#39;Load model weights.
1336
1337
        Parameters
1338
        ----------
1339
        path_to_file : str, optional
1340
            The path to weight files of pre-trained or trained model
1341
        load_labels : boolean, optional
1342
            Whether to load clustering labels or not.
1343
            If load_labels is True, then the LatentSpace layer will be initialized based on the model.
1344
            If load_labels is False, then the LatentSpace layer will not be initialized.
1345
        load_adata : boolean, optional
1346
            Whether to load adata or not.
1347
        &#39;&#39;&#39;
1348
        if not os.path.exists(path_to_file + &#39;.config&#39;):
1349
            raise AssertionError(&#39;Config file not exist!&#39;)
1350
        if load_labels and not os.path.exists(path_to_file + &#39;.label&#39;):
1351
            raise AssertionError(&#39;Label file not exist!&#39;)
1352
1353
        with open(path_to_file + &#39;.config&#39;, &#39;rb&#39;) as f:
1354
            [self.dim_origin, self.dimensions,
1355
             self.dim_latent, self.model_type, cov_dim] = np.load(f, allow_pickle=True)
1356
        self.vae = model.VariationalAutoEncoder(
1357
            self.dim_origin, self.dimensions,
1358
            self.dim_latent, self.model_type, False if cov_dim == 0 else True
1359
        )
1360
1361
        if load_labels:
1362
            with open(path_to_file + &#39;.label&#39;, &#39;rb&#39;) as f:
1363
                cluster_labels = np.load(f, allow_pickle=True)
1364
            self.init_latent_space(cluster_labels, dist_thres=0)
1365
            if os.path.exists(path_to_file + &#39;.inference&#39;):
1366
                with open(path_to_file + &#39;.inference&#39;, &#39;rb&#39;) as f:
1367
                    arr = np.load(f, allow_pickle=True)
1368
                    if len(arr) == 8:
1369
                        [self.pi, self.mu, self.pc_x, self.cell_position_posterior, self.uncertainty,
1370
                         self.D_JS, self.z,self.cell_position_variance] = arr
1371
                    else:
1372
                        [self.pi, self.mu, self.pc_x, self.cell_position_posterior, self.uncertainty,
1373
                         self.z,self.cell_position_variance] = arr
1374
                self._adata_z = sc.AnnData(self.z)
1375
                sc.pp.neighbors(self._adata_z)
1376
        ## initialize the weight of encoder and decoder
1377
        self.vae.encoder(np.zeros((1, self.dim_origin + cov_dim)))
1378
        self.vae.decoder(np.expand_dims(np.zeros((1,self.dim_latent + cov_dim)),1))
1379
1380
        self.vae.load_weights(path_to_file)
1381
        self.update_z()
1382
        if load_adata:
1383
            if not os.path.exists(path_to_file + &#39;.adata.h5ad&#39;):
1384
                raise AssertionError(&#39;AnnData file not exist!&#39;)
1385
            self.adata = sc.read_h5ad(path_to_file + &#39;.adata.h5ad&#39;)
1386
            self._adata.obs = self.adata.obs.copy()</code></pre>
1387
</details>
1388
<h3>Methods</h3>
1389
<dl>
1390
<dt id="VITAE.VITAE.pre_train"><code class="name flex">
1391
<span>def <span class="ident">pre_train</span></span>(<span>self, test_size=0.1, random_state: int = 0, learning_rate: float = 0.001, batch_size: int = 256, L: int = 1, alpha: float = 0.1, gamma: float = 0, phi: float = 1, num_epoch: int = 200, num_step_per_epoch: Optional[int] = None, early_stopping_patience: int = 10, early_stopping_tolerance: float = 0.01, early_stopping_relative: bool = True, verbose: bool = False, path_to_weights: Optional[str] = None)</span>
1392
</code></dt>
1393
<dd>
1394
<div class="desc"><p>Pretrain the model with specified learning rate.</p>
1395
<h2 id="parameters">Parameters</h2>
1396
<dl>
1397
<dt><strong><code>test_size</code></strong> :&ensp;<code>float</code> or <code>int</code>, optional</dt>
1398
<dd>The proportion or size of the test set.</dd>
1399
<dt><strong><code>random_state</code></strong> :&ensp;<code>int</code>, optional</dt>
1400
<dd>The random state for data splitting.</dd>
1401
<dt><strong><code>learning_rate</code></strong> :&ensp;<code>float</code>, optional</dt>
1402
<dd>The initial learning rate for the Adam optimizer.</dd>
1403
<dt><strong><code>batch_size</code></strong> :&ensp;<code>int</code>, optional</dt>
1404
<dd>The batch size for pre-training.
1405
Default is 256. Set to 32 if number of cells is small (less than 1000)</dd>
1406
<dt><strong><code>L</code></strong> :&ensp;<code>int</code>, optional</dt>
1407
<dd>The number of MC samples.</dd>
1408
<dt><strong><code>alpha</code></strong> :&ensp;<code>float</code>, optional</dt>
1409
<dd>The value of alpha in [0,1] to encourage covariate adjustment. Not used if there is no covariates.</dd>
1410
<dt><strong><code>gamma</code></strong> :&ensp;<code>float</code>, optional</dt>
1411
<dd>The weight of the mmd loss if used.</dd>
1412
<dt><strong><code>phi</code></strong> :&ensp;<code>float</code>, optional</dt>
1413
<dd>The weight of the Jacobian norm of the encoder.</dd>
1414
<dt><strong><code>num_epoch</code></strong> :&ensp;<code>int</code>, optional</dt>
1415
<dd>The maximum number of epochs.</dd>
1416
<dt><strong><code>num_step_per_epoch</code></strong> :&ensp;<code>int</code>, optional</dt>
1417
<dd>The number of step per epoch, it will be inferred from number of cells and batch size if it is None.</dd>
1418
<dt><strong><code>early_stopping_patience</code></strong> :&ensp;<code>int</code>, optional</dt>
1419
<dd>The maximum number of epochs if there is no improvement.</dd>
1420
<dt><strong><code>early_stopping_tolerance</code></strong> :&ensp;<code>float</code>, optional</dt>
1421
<dd>The minimum change of loss to be considered as an improvement.</dd>
1422
<dt><strong><code>early_stopping_relative</code></strong> :&ensp;<code>bool</code>, optional</dt>
1423
<dd>Whether monitor the relative change of loss as stopping criteria or not.</dd>
1424
<dt><strong><code>path_to_weights</code></strong> :&ensp;<code>str</code>, optional</dt>
1425
<dd>The path of weight file to be saved; not saving weight if None.</dd>
1426
<dt><strong><code>conditions</code></strong> :&ensp;<code>str</code> or <code>list</code>, optional</dt>
1427
<dd>The conditions of different cells</dd>
1428
</dl></div>
1429
</dd>
1430
<dt id="VITAE.VITAE.update_z"><code class="name flex">
1431
<span>def <span class="ident">update_z</span></span>(<span>self)</span>
1432
</code></dt>
1433
<dd>
1434
<div class="desc"></div>
1435
</dd>
1436
<dt id="VITAE.VITAE.get_latent_z"><code class="name flex">
1437
<span>def <span class="ident">get_latent_z</span></span>(<span>self)</span>
1438
</code></dt>
1439
<dd>
1440
<div class="desc"><p>get the posterior mean of the current latent space z (encoder output)</p>
1441
<h2 id="returns">Returns</h2>
1442
<dl>
1443
<dt><strong><code>z</code></strong> :&ensp;<code>np.array</code></dt>
1444
<dd><span><span class="MathJax_Preview">[N,d]</span><script type="math/tex">[N,d]</script></span> The latent means.</dd>
1445
</dl></div>
1446
</dd>
1447
<dt id="VITAE.VITAE.visualize_latent"><code class="name flex">
1448
<span>def <span class="ident">visualize_latent</span></span>(<span>self, method: str = 'UMAP', color=None, **kwargs)</span>
1449
</code></dt>
1450
<dd>
1451
<div class="desc"><p>visualize the current latent space z using the scanpy visualization tools</p>
1452
<h2 id="parameters">Parameters</h2>
1453
<dl>
1454
<dt><strong><code>method</code></strong> :&ensp;<code>str</code>, optional</dt>
1455
<dd>Visualization method to use. The default is "UMAP". Possible choices include "PCA", "UMAP",
1456
"diffmap", "TSNE" and "draw_graph"</dd>
1457
<dt><strong><code>color</code></strong> :&ensp;<code>TYPE</code>, optional</dt>
1458
<dd>Keys for annotations of observations/cells or variables/genes, e.g., 'ann1' or ['ann1', 'ann2'].
1459
The default is None. Same as scanpy.</dd>
1460
<dt><strong><code>**kwargs</code></strong> :&ensp;<code> </code></dt>
1461
<dd>Extra key-value arguments that can be passed to scanpy plotting functions (scanpy.pl.XX).</dd>
1462
</dl>
1463
<h2 id="returns">Returns</h2>
1464
<p>None.</p></div>
1465
</dd>
1466
<dt id="VITAE.VITAE.init_latent_space"><code class="name flex">
1467
<span>def <span class="ident">init_latent_space</span></span>(<span>self, cluster_label=None, log_pi=None, res: float = 1.0, ratio_prune=None, dist=None, dist_thres=0.5, topk=0, pilayer=False)</span>
1468
</code></dt>
1469
<dd>
1470
<div class="desc"><p>Initialize the latent space.</p>
1471
<h2 id="parameters">Parameters</h2>
1472
<dl>
1473
<dt><strong><code>cluster_label</code></strong> :&ensp;<code>str</code>, optional</dt>
1474
<dd>The name of vector of labels that can be found in self.adata.obs.
1475
Default is None, which will perform leiden clustering on the pretrained z to get clusters</dd>
1476
<dt><strong><code>mu</code></strong> :&ensp;<code>np.array</code>, optional</dt>
1477
<dd><span><span class="MathJax_Preview">[d,k]</span><script type="math/tex">[d,k]</script></span> The value of initial <span><span class="MathJax_Preview">\mu</span><script type="math/tex">\mu</script></span>.</dd>
1478
<dt><strong><code>log_pi</code></strong> :&ensp;<code>np.array</code>, optional</dt>
1479
<dd><span><span class="MathJax_Preview">[1,K]</span><script type="math/tex">[1,K]</script></span> The value of initial <span><span class="MathJax_Preview">\log(\pi)</span><script type="math/tex">\log(\pi)</script></span>.</dd>
1480
<dt><strong><code>res</code></strong></dt>
1481
<dd>The resolution of leiden clustering, which is a parameter value controlling the coarseness of the clustering.
1482
Higher values lead to more clusters. Default is 1.</dd>
1483
<dt><strong><code>ratio_prune</code></strong> :&ensp;<code>float</code>, optional</dt>
1484
<dd>The ratio of edges to be removed before estimating.</dd>
1485
<dt><strong><code>topk</code></strong> :&ensp;<code>int</code>, optional</dt>
1486
<dd>The number of top k neighbors to keep for each cluster.</dd>
1487
</dl></div>
1488
</dd>
1489
<dt id="VITAE.VITAE.update_latent_space"><code class="name flex">
1490
<span>def <span class="ident">update_latent_space</span></span>(<span>self, dist_thres: float = 0.5)</span>
1491
</code></dt>
1492
<dd>
1493
<div class="desc"></div>
1494
</dd>
1495
<dt id="VITAE.VITAE.train"><code class="name flex">
1496
<span>def <span class="ident">train</span></span>(<span>self, stratify=False, test_size=0.1, random_state: int = 0, learning_rate: float = 0.001, batch_size: int = 256, L: int = 1, alpha: float = 0.1, beta: float = 1, gamma: float = 0, phi: float = 1, num_epoch: int = 200, num_step_per_epoch: Optional[int] = None, early_stopping_patience: int = 10, early_stopping_tolerance: float = 0.01, early_stopping_relative: bool = True, early_stopping_warmup: int = 0, path_to_weights: Optional[str] = None, verbose: bool = False, **kwargs)</span>
1497
</code></dt>
1498
<dd>
1499
<div class="desc"><p>Train the model.</p>
1500
<h2 id="parameters">Parameters</h2>
1501
<dl>
1502
<dt><strong><code>stratify</code></strong> :&ensp;<code>np.array, None,</code> or <code>False</code></dt>
1503
<dd>If an array is provided, or <code>stratify=None</code> and <code>self.labels</code> is available, then they will be used to perform stratified shuffle splitting. Otherwise, general shuffle splitting is used. Set to <code>False</code> if <code>self.labels</code> is not intended for stratified shuffle splitting.</dd>
1504
<dt><strong><code>test_size</code></strong> :&ensp;<code>float</code> or <code>int</code>, optional</dt>
1505
<dd>The proportion or size of the test set.</dd>
1506
<dt><strong><code>random_state</code></strong> :&ensp;<code>int</code>, optional</dt>
1507
<dd>The random state for data splitting.</dd>
1508
<dt><strong><code>learning_rate</code></strong> :&ensp;<code>float</code>, optional</dt>
1509
<dd>The initial learning rate for the Adam optimizer.</dd>
1510
<dt><strong><code>batch_size</code></strong> :&ensp;<code>int</code>, optional</dt>
1511
<dd>The batch size for training. Default is 256. Set to 32 if number of cells is small (less than 1000)</dd>
1512
<dt><strong><code>L</code></strong> :&ensp;<code>int</code>, optional</dt>
1513
<dd>The number of MC samples.</dd>
1514
<dt><strong><code>alpha</code></strong> :&ensp;<code>float</code>, optional</dt>
1515
<dd>The value of alpha in [0,1] to encourage covariate adjustment. Not used if there is no covariates.</dd>
1516
<dt><strong><code>beta</code></strong> :&ensp;<code>float</code>, optional</dt>
1517
<dd>The value of beta in beta-VAE.</dd>
1518
<dt><strong><code>gamma</code></strong> :&ensp;<code>float</code>, optional</dt>
1519
<dd>The weight of mmd_loss.</dd>
1520
<dt><strong><code>phi</code></strong> :&ensp;<code>float</code>, optional</dt>
1521
<dd>The weight of the Jacobian norm of the encoder.</dd>
1522
<dt><strong><code>num_epoch</code></strong> :&ensp;<code>int</code>, optional</dt>
1523
<dd>The number of epoch.</dd>
1524
<dt><strong><code>num_step_per_epoch</code></strong> :&ensp;<code>int</code>, optional</dt>
1525
<dd>The number of step per epoch, it will be inferred from number of cells and batch size if it is None.</dd>
1526
<dt><strong><code>early_stopping_patience</code></strong> :&ensp;<code>int</code>, optional</dt>
1527
<dd>The maximum number of epochs if there is no improvement.</dd>
1528
<dt><strong><code>early_stopping_tolerance</code></strong> :&ensp;<code>float</code>, optional</dt>
1529
<dd>The minimum change of loss to be considered as an improvement.</dd>
1530
<dt><strong><code>early_stopping_relative</code></strong> :&ensp;<code>bool</code>, optional</dt>
1531
<dd>Whether monitor the relative change of loss or not.</dd>
1532
<dt><strong><code>early_stopping_warmup</code></strong> :&ensp;<code>int</code>, optional</dt>
1533
<dd>The number of warmup epochs.</dd>
1534
<dt><strong><code>path_to_weights</code></strong> :&ensp;<code>str</code>, optional</dt>
1535
<dd>The path of weight file to be saved; not saving weight if None.</dd>
1536
<dt><strong><code>**kwargs</code></strong> :&ensp;<code> </code></dt>
1537
<dd>Extra key-value arguments for dimension reduction algorithms.</dd>
1538
</dl></div>
1539
</dd>
1540
<dt id="VITAE.VITAE.output_pi"><code class="name flex">
1541
<span>def <span class="ident">output_pi</span></span>(<span>self, pi_cov)</span>
1542
</code></dt>
1543
<dd>
1544
<div class="desc"><p>return a matrix n_states by n_states and a mask for plotting, which can be used to cover the lower triangle (except the diagonals) of a heatmap</p></div>
1545
</dd>
1546
<dt id="VITAE.VITAE.return_pilayer_weights"><code class="name flex">
1547
<span>def <span class="ident">return_pilayer_weights</span></span>(<span>self)</span>
1548
</code></dt>
1549
<dd>
1550
<div class="desc"><p>return parameters of pilayer, which has dimension dim(pi_cov) + 1 by n_categories, the last row is biases</p></div>
1551
</dd>
1552
<dt id="VITAE.VITAE.posterior_estimation"><code class="name flex">
1553
<span>def <span class="ident">posterior_estimation</span></span>(<span>self, batch_size: int = 32, L: int = 50, **kwargs)</span>
1554
</code></dt>
1555
<dd>
1556
<div class="desc"><p>Initialize trajectory inference by computing the posterior estimations.
1557
</p>
1558
<h2 id="parameters">Parameters</h2>
1559
<dl>
1560
<dt><strong><code>batch_size</code></strong> :&ensp;<code>int</code>, optional</dt>
1561
<dd>The batch size when doing inference.</dd>
1562
<dt><strong><code>L</code></strong> :&ensp;<code>int</code>, optional</dt>
1563
<dd>The number of MC samples when doing inference.</dd>
1564
<dt><strong><code>**kwargs</code></strong> :&ensp;<code> </code></dt>
1565
<dd>Extra key-value arguments for dimension reduction algorithms.</dd>
1566
</dl></div>
1567
</dd>
1568
<dt id="VITAE.VITAE.infer_backbone"><code class="name flex">
1569
<span>def <span class="ident">infer_backbone</span></span>(<span>self, method: str = 'modified_map', thres=0.5, no_loop: bool = True, cutoff: float = 0, visualize: bool = True, color='vitae_new_clustering', path_to_fig=None, **kwargs)</span>
1570
</code></dt>
1571
<dd>
1572
<div class="desc"><p>Compute edge scores.</p>
1573
<h2 id="parameters">Parameters</h2>
1574
<dl>
1575
<dt><strong><code>method</code></strong> :&ensp;<code>string</code>, optional</dt>
1576
<dd>'mean', 'modified_mean', 'map', or 'modified_map'.</dd>
1577
<dt><strong><code>thres</code></strong> :&ensp;<code>float</code>, optional</dt>
1578
<dd>The threshold used for filtering edges <span><span class="MathJax_Preview">e_{ij}</span><script type="math/tex">e_{ij}</script></span> that <span><span class="MathJax_Preview">(n_{i}+n_{j}+e_{ij})/N&lt;thres</span><script type="math/tex">(n_{i}+n_{j}+e_{ij})/N<thres</script></span>, only applied to mean method.</dd>
1579
<dt><strong><code>no_loop</code></strong> :&ensp;<code>boolean</code>, optional</dt>
1580
<dd>Whether loops are allowed to exist in the graph. If no_loop is true, will prune the graph to contain only the
1581
maximum spanning tree</dd>
1582
<dt><strong><code>cutoff</code></strong> :&ensp;<code>string</code>, optional</dt>
1583
<dd>The score threshold for filtering edges with scores less than cutoff.</dd>
1584
<dt><strong><code>visualize</code></strong> :&ensp;<code>boolean</code></dt>
1585
<dd>whether plot the current trajectory backbone (undirected graph)</dd>
1586
</dl>
1587
<h2 id="returns">Returns</h2>
1588
<dl>
1589
<dt><strong><code>G</code></strong> :&ensp;<code>nx.Graph</code></dt>
1590
<dd>The weighted graph with weight on each edge indicating its score of existence.</dd>
1591
</dl></div>
1592
</dd>
1593
<dt id="VITAE.VITAE.select_root"><code class="name flex">
1594
<span>def <span class="ident">select_root</span></span>(<span>self, days, method: str = 'proportion')</span>
1595
</code></dt>
1596
<dd>
1597
<div class="desc"><p>Order the vertices/states based on cells' collection time information to select the root state.
1598
</p>
1599
<h2 id="parameters">Parameters</h2>
1600
<dl>
1601
<dt><strong><code>days</code></strong> :&ensp;<code>np.array </code></dt>
1602
<dd>The day information for selected cells used to determine the root vertex.
1603
The dtype should be 'int' or 'float'.</dd>
1604
<dt><strong><code>method</code></strong> :&ensp;<code>str</code>, optional</dt>
1605
<dd>'proportion' or 'mean'.
1606
For 'proportion', the root is the one with maximal proportion of cells from the earliest day.
1607
For 'mean', the root is the one with earliest mean time among cells associated with it.</dd>
1608
</dl>
1609
<h2 id="returns">Returns</h2>
1610
<dl>
1611
<dt><strong><code>root</code></strong> :&ensp;<code>int </code></dt>
1612
<dd>The root vertex in the inferred trajectory based on given day information.</dd>
1613
</dl></div>
1614
</dd>
1615
<dt id="VITAE.VITAE.plot_backbone"><code class="name flex">
1616
<span>def <span class="ident">plot_backbone</span></span>(<span>self, directed: bool = False, method: str = 'UMAP', color='vitae_new_clustering', **kwargs)</span>
1617
</code></dt>
1618
<dd>
1619
<div class="desc"><p>Plot the current trajectory backbone (undirected graph).</p>
1620
<h2 id="parameters">Parameters</h2>
1621
<dl>
1622
<dt><strong><code>directed</code></strong> :&ensp;<code>boolean</code>, optional</dt>
1623
<dd>Whether the backbone is directed or not.</dd>
1624
<dt><strong><code>method</code></strong> :&ensp;<code>str</code>, optional</dt>
1625
<dd>The dimension reduction method to use. The default is "UMAP".</dd>
1626
<dt><strong><code>color</code></strong> :&ensp;<code>str</code>, optional</dt>
1627
<dd>The key for annotations of observations/cells or variables/genes, e.g., 'ann1' or ['ann1', 'ann2'].
1628
The default is 'vitae_new_clustering'.</dd>
1629
</dl>
1630
<p>**kwargs :
1631
Extra key-value arguments that can be passed to scanpy plotting functions (scanpy.pl.XX).</p></div>
1632
</dd>
1633
<dt id="VITAE.VITAE.plot_center"><code class="name flex">
1634
<span>def <span class="ident">plot_center</span></span>(<span>self, color='vitae_new_clustering', plot_legend=True, legend_add_index=True, method: str = 'UMAP', ncol=2, font_size='medium', add_egde=False, add_direct=False, **kwargs)</span>
1635
</code></dt>
1636
<dd>
1637
<div class="desc"><p>Plot the center of each cluster in the latent space.</p>
1638
<h2 id="parameters">Parameters</h2>
1639
<dl>
1640
<dt><strong><code>color</code></strong> :&ensp;<code>str</code>, optional</dt>
1641
<dd>The color of the center of each cluster. Default is "vitae_new_clustering".</dd>
1642
<dt><strong><code>plot_legend</code></strong> :&ensp;<code>bool</code>, optional</dt>
1643
<dd>Whether to plot the legend. Default is True.</dd>
1644
<dt><strong><code>legend_add_index</code></strong> :&ensp;<code>bool</code>, optional</dt>
1645
<dd>Whether to add the index of each cluster in the legend. Default is True.</dd>
1646
<dt><strong><code>method</code></strong> :&ensp;<code>str</code>, optional</dt>
1647
<dd>The dimension reduction method used for visualization. Default is 'UMAP'.</dd>
1648
<dt><strong><code>ncol</code></strong> :&ensp;<code>int</code>, optional</dt>
1649
<dd>The number of columns in the legend. Default is 2.</dd>
1650
<dt><strong><code>font_size</code></strong> :&ensp;<code>str</code>, optional</dt>
1651
<dd>The font size of the legend. Default is "medium".</dd>
1652
<dt><strong><code>add_egde</code></strong> :&ensp;<code>bool</code>, optional</dt>
1653
<dd>Whether to add the edges between the centers of clusters. Default is False.</dd>
1654
<dt><strong><code>add_direct</code></strong> :&ensp;<code>bool</code>, optional</dt>
1655
<dd>Whether to add the direction of the edges. Default is False.</dd>
1656
</dl></div>
1657
</dd>
1658
<dt id="VITAE.VITAE.infer_trajectory"><code class="name flex">
1659
<span>def <span class="ident">infer_trajectory</span></span>(<span>self, root: Union[int, str], digraph=None, color='pseudotime', visualize: bool = True, path_to_fig=None, **kwargs)</span>
1660
</code></dt>
1661
<dd>
1662
<div class="desc"><p>Infer the trajectory.</p>
1663
<h2 id="parameters">Parameters</h2>
1664
<dl>
1665
<dt><strong><code>root</code></strong> :&ensp;<code>int</code> or <code>string</code></dt>
1666
<dd>The root of the inferred trajectory. Can provide either an int (vertex index) or string (label name)</dd>
1667
<dt><strong><code>digraph</code></strong> :&ensp;<code>nx.DiGraph</code>, optional</dt>
1668
<dd>The directed graph to be used for trajectory inference. If None, the minimum spanning tree of the estimated trajectory backbone will be used.</dd>
1669
<dt><strong><code>cutoff</code></strong> :&ensp;<code>string</code>, optional</dt>
1670
<dd>The threshold for filtering edges with scores less than cutoff.</dd>
1671
<dt><strong><code>visualize</code></strong> :&ensp;<code>boolean</code></dt>
1672
<dd>Whether to plot the current trajectory backbone (directed graph).</dd>
1673
<dt><strong><code>path_to_fig</code></strong> :&ensp;<code>string</code>, optional</dt>
1674
<dd>The path to save figure, or don't save if it is None.</dd>
1675
<dt><strong><code>**kwargs</code></strong> :&ensp;<code>dict</code>, optional</dt>
1676
<dd>Other keywords arguments for plotting.</dd>
1677
</dl></div>
1678
</dd>
1679
<dt id="VITAE.VITAE.differential_expression_test"><code class="name flex">
1680
<span>def <span class="ident">differential_expression_test</span></span>(<span>self, alpha: float = 0.05, cell_subset=None, order: int = 1)</span>
1681
</code></dt>
1682
<dd>
1683
<div class="desc"><p>Differential gene expression test. All (selected and unselected) genes will be tested.
1684
Only cells in <code>selected_cell_subset</code> will be used, which is useful when one needs to
1685
test differentially expressed genes on a branch of the inferred trajectory.</p>
1686
<h2 id="parameters">Parameters</h2>
1687
<dl>
1688
<dt><strong><code>alpha</code></strong> :&ensp;<code>float</code>, optional</dt>
1689
<dd>The cutoff of p-values.</dd>
1690
<dt><strong><code>cell_subset</code></strong> :&ensp;<code>np.array</code>, optional</dt>
1691
<dd>The subset of cells to be used for testing. If None, all cells will be used.</dd>
1692
<dt><strong><code>order</code></strong> :&ensp;<code>int</code>, optional</dt>
1693
<dd>The maximum order we used for pseudotime in regression.</dd>
1694
</dl>
1695
<h2 id="returns">Returns</h2>
1696
<dl>
1697
<dt><strong><code>res_df</code></strong> :&ensp;<code>pandas.DataFrame</code></dt>
1698
<dd>The test results of expressed genes with two columns,
1699
the estimated coefficients and the adjusted p-values.</dd>
1700
</dl></div>
1701
</dd>
1702
<dt id="VITAE.VITAE.evaluate"><code class="name flex">
1703
<span>def <span class="ident">evaluate</span></span>(<span>self, milestone_net, begin_node_true, grouping=None, thres: float = 0.5, no_loop: bool = True, cutoff: Optional[float] = None, method: str = 'mean', path: Optional[str] = None)</span>
1704
</code></dt>
1705
<dd>
1706
<div class="desc"><p>Evaluate the model.</p>
1707
<h2 id="parameters">Parameters</h2>
1708
<dl>
1709
<dt><strong><code>milestone_net</code></strong> :&ensp;<code>pd.DataFrame</code></dt>
1710
<dd>
1711
<p>The true milestone network. For real data, milestone_net will be a DataFrame of the graph of nodes.
1712
Eg.</p>
1713
<table>
1714
<thead>
1715
<tr>
1716
<th>from</th>
1717
<th>to</th>
1718
</tr>
1719
</thead>
1720
<tbody>
1721
<tr>
1722
<td>cluster 1</td>
1723
<td>cluster 1</td>
1724
</tr>
1725
<tr>
1726
<td>cluster 1</td>
1727
<td>cluster 2</td>
1728
</tr>
1729
</tbody>
1730
</table>
1731
<p>For synthetic data, milestone_net will be a DataFrame of the (projected)
1732
positions of cells. The indexes are the orders of cells in the dataset.
1733
Eg.</p>
1734
<table>
1735
<thead>
1736
<tr>
1737
<th>from</th>
1738
<th>to</th>
1739
<th>w</th>
1740
</tr>
1741
</thead>
1742
<tbody>
1743
<tr>
1744
<td>cluster 1</td>
1745
<td>cluster 1</td>
1746
<td>1</td>
1747
</tr>
1748
<tr>
1749
<td>cluster 1</td>
1750
<td>cluster 2</td>
1751
<td>0.1</td>
1752
</tr>
1753
</tbody>
1754
</table>
1755
</dd>
1756
<dt><strong><code>begin_node_true</code></strong> :&ensp;<code>str</code> or <code>int</code></dt>
1757
<dd>The true begin node of the milestone.</dd>
1758
<dt><strong><code>grouping</code></strong> :&ensp;<code>np.array</code>, optional</dt>
1759
<dd><span><span class="MathJax_Preview">[N,]</span><script type="math/tex">[N,]</script></span> The labels. For real data, grouping must be provided.</dd>
1760
</dl>
1761
<h2 id="returns">Returns</h2>
1762
<dl>
1763
<dt><strong><code>res</code></strong> :&ensp;<code>pd.DataFrame</code></dt>
1764
<dd>The evaluation result.</dd>
1765
</dl></div>
1766
</dd>
1767
<dt id="VITAE.VITAE.save_model"><code class="name flex">
1768
<span>def <span class="ident">save_model</span></span>(<span>self, path_to_file: str = 'model.checkpoint', save_adata: bool = False)</span>
1769
</code></dt>
1770
<dd>
1771
<div class="desc"><p>Saving model weights.</p>
1772
<h2 id="parameters">Parameters</h2>
1773
<dl>
1774
<dt><strong><code>path_to_file</code></strong> :&ensp;<code>str</code>, optional</dt>
1775
<dd>The path to weight files of pre-trained or trained model</dd>
1776
<dt><strong><code>save_adata</code></strong> :&ensp;<code>boolean</code>, optional</dt>
1777
<dd>Whether to save adata or not.</dd>
1778
</dl></div>
1779
</dd>
1780
<dt id="VITAE.VITAE.load_model"><code class="name flex">
1781
<span>def <span class="ident">load_model</span></span>(<span>self, path_to_file: str = 'model.checkpoint', load_labels: bool = False, load_adata: bool = False)</span>
1782
</code></dt>
1783
<dd>
1784
<div class="desc"><p>Load model weights.</p>
1785
<h2 id="parameters">Parameters</h2>
1786
<dl>
1787
<dt><strong><code>path_to_file</code></strong> :&ensp;<code>str</code>, optional</dt>
1788
<dd>The path to weight files of pre-trained or trained model</dd>
1789
<dt><strong><code>load_labels</code></strong> :&ensp;<code>boolean</code>, optional</dt>
1790
<dd>Whether to load clustering labels or not.
1791
If load_labels is True, then the LatentSpace layer will be initialized based on the model.
1792
If load_labels is False, then the LatentSpace layer will not be initialized.</dd>
1793
<dt><strong><code>load_adata</code></strong> :&ensp;<code>boolean</code>, optional</dt>
1794
<dd>Whether to load adata or not.</dd>
1795
</dl></div>
1796
</dd>
1797
</dl>
1798
</dd>
1799
</dl>
1800
</section>
1801
</article>
1802
<nav id="sidebar">
1803
<div class="toc">
1804
<ul></ul>
1805
</div>
1806
<ul id="index">
1807
<li><h3><a href="#header-submodules">Sub-modules</a></h3>
1808
<ul>
1809
<li><code><a title="VITAE.inference" href="inference.html">VITAE.inference</a></code></li>
1810
<li><code><a title="VITAE.metric" href="metric.html">VITAE.metric</a></code></li>
1811
<li><code><a title="VITAE.model" href="model.html">VITAE.model</a></code></li>
1812
<li><code><a title="VITAE.train" href="train.html">VITAE.train</a></code></li>
1813
<li><code><a title="VITAE.utils" href="utils.html">VITAE.utils</a></code></li>
1814
</ul>
1815
</li>
1816
<li><h3><a href="#header-classes">Classes</a></h3>
1817
<ul>
1818
<li>
1819
<h4><code><a title="VITAE.VITAE" href="#VITAE.VITAE">VITAE</a></code></h4>
1820
<ul class="">
1821
<li><code><a title="VITAE.VITAE.pre_train" href="#VITAE.VITAE.pre_train">pre_train</a></code></li>
1822
<li><code><a title="VITAE.VITAE.update_z" href="#VITAE.VITAE.update_z">update_z</a></code></li>
1823
<li><code><a title="VITAE.VITAE.get_latent_z" href="#VITAE.VITAE.get_latent_z">get_latent_z</a></code></li>
1824
<li><code><a title="VITAE.VITAE.visualize_latent" href="#VITAE.VITAE.visualize_latent">visualize_latent</a></code></li>
1825
<li><code><a title="VITAE.VITAE.init_latent_space" href="#VITAE.VITAE.init_latent_space">init_latent_space</a></code></li>
1826
<li><code><a title="VITAE.VITAE.update_latent_space" href="#VITAE.VITAE.update_latent_space">update_latent_space</a></code></li>
1827
<li><code><a title="VITAE.VITAE.train" href="#VITAE.VITAE.train">train</a></code></li>
1828
<li><code><a title="VITAE.VITAE.output_pi" href="#VITAE.VITAE.output_pi">output_pi</a></code></li>
1829
<li><code><a title="VITAE.VITAE.return_pilayer_weights" href="#VITAE.VITAE.return_pilayer_weights">return_pilayer_weights</a></code></li>
1830
<li><code><a title="VITAE.VITAE.posterior_estimation" href="#VITAE.VITAE.posterior_estimation">posterior_estimation</a></code></li>
1831
<li><code><a title="VITAE.VITAE.infer_backbone" href="#VITAE.VITAE.infer_backbone">infer_backbone</a></code></li>
1832
<li><code><a title="VITAE.VITAE.select_root" href="#VITAE.VITAE.select_root">select_root</a></code></li>
1833
<li><code><a title="VITAE.VITAE.plot_backbone" href="#VITAE.VITAE.plot_backbone">plot_backbone</a></code></li>
1834
<li><code><a title="VITAE.VITAE.plot_center" href="#VITAE.VITAE.plot_center">plot_center</a></code></li>
1835
<li><code><a title="VITAE.VITAE.infer_trajectory" href="#VITAE.VITAE.infer_trajectory">infer_trajectory</a></code></li>
1836
<li><code><a title="VITAE.VITAE.differential_expression_test" href="#VITAE.VITAE.differential_expression_test">differential_expression_test</a></code></li>
1837
<li><code><a title="VITAE.VITAE.evaluate" href="#VITAE.VITAE.evaluate">evaluate</a></code></li>
1838
<li><code><a title="VITAE.VITAE.save_model" href="#VITAE.VITAE.save_model">save_model</a></code></li>
1839
<li><code><a title="VITAE.VITAE.load_model" href="#VITAE.VITAE.load_model">load_model</a></code></li>
1840
</ul>
1841
</li>
1842
</ul>
1843
</li>
1844
</ul>
1845
</nav>
1846
</main>
1847
<footer id="footer">
1848
<p>Generated by <a href="https://pdoc3.github.io/pdoc" title="pdoc: Python API documentation generator"><cite>pdoc</cite> 0.11.1</a>.</p>
1849
</footer>
1850
</body>
1851
</html>