Diff of /docs/index.html [000000] .. [2c6b19]

--- a
+++ b/docs/index.html
@@ -0,0 +1,1851 @@
+<!doctype html>
+<html lang="en">
+<head>
+<meta charset="utf-8">
+<meta name="viewport" content="width=device-width, initial-scale=1, minimum-scale=1">
+<meta name="generator" content="pdoc3 0.11.1">
+<title>VITAE API documentation</title>
+<meta name="description" content="">
+<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/10up-sanitize.css/13.0.0/sanitize.min.css" integrity="sha512-y1dtMcuvtTMJc1yPgEqF0ZjQbhnc/bFhyvIyVNb9Zk5mIGtqVaAB1Ttl28su8AvFMOY0EwRbAe+HCLqj6W7/KA==" crossorigin>
+<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/10up-sanitize.css/13.0.0/typography.min.css" integrity="sha512-Y1DYSb995BAfxobCkKepB1BqJJTPrOp3zPL74AWFugHHmmdcvO+C48WLrUOlhGMc0QG7AE3f7gmvvcrmX2fDoA==" crossorigin>
+<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.9.0/styles/default.min.css" crossorigin>
+<style>:root{--highlight-color:#fe9}.flex{display:flex !important}body{line-height:1.5em}#content{padding:20px}#sidebar{padding:1.5em;overflow:hidden}#sidebar > *:last-child{margin-bottom:2cm}.http-server-breadcrumbs{font-size:130%;margin:0 0 15px 0}#footer{font-size:.75em;padding:5px 30px;border-top:1px solid #ddd;text-align:right}#footer p{margin:0 0 0 1em;display:inline-block}#footer p:last-child{margin-right:30px}h1,h2,h3,h4,h5{font-weight:300}h1{font-size:2.5em;line-height:1.1em}h2{font-size:1.75em;margin:2em 0 .50em 0}h3{font-size:1.4em;margin:1.6em 0 .7em 0}h4{margin:0;font-size:105%}h1:target,h2:target,h3:target,h4:target,h5:target,h6:target{background:var(--highlight-color);padding:.2em 0}a{color:#058;text-decoration:none;transition:color .2s ease-in-out}a:visited{color:#503}a:hover{color:#b62}.title code{font-weight:bold}h2[id^="header-"]{margin-top:2em}.ident{color:#900;font-weight:bold}pre code{font-size:.8em;line-height:1.4em;padding:1em;display:block}code{background:#f3f3f3;font-family:"DejaVu Sans Mono",monospace;padding:1px 4px;overflow-wrap:break-word}h1 code{background:transparent}pre{border-top:1px solid #ccc;border-bottom:1px solid #ccc;margin:1em 0}#http-server-module-list{display:flex;flex-flow:column}#http-server-module-list div{display:flex}#http-server-module-list dt{min-width:10%}#http-server-module-list p{margin-top:0}.toc ul,#index{list-style-type:none;margin:0;padding:0}#index code{background:transparent}#index h3{border-bottom:1px solid #ddd}#index ul{padding:0}#index h4{margin-top:.6em;font-weight:bold}@media (min-width:200ex){#index .two-column{column-count:2}}@media (min-width:300ex){#index .two-column{column-count:3}}dl{margin-bottom:2em}dl dl:last-child{margin-bottom:4em}dd{margin:0 0 1em 3em}#header-classes + dl > dd{margin-bottom:3em}dd dd{margin-left:2em}dd p{margin:10px 0}.name{background:#eee;font-size:.85em;padding:5px 10px;display:inline-block;min-width:40%}.name:hover{background:#e0e0e0}dt:target .name{background:var(--highlight-color)}.name > span:first-child{white-space:nowrap}.name.class > span:nth-child(2){margin-left:.4em}.inherited{color:#999;border-left:5px solid #eee;padding-left:1em}.inheritance em{font-style:normal;font-weight:bold}.desc h2{font-weight:400;font-size:1.25em}.desc h3{font-size:1em}.desc dt code{background:inherit}.source summary,.git-link-div{color:#666;text-align:right;font-weight:400;font-size:.8em;text-transform:uppercase}.source summary > *{white-space:nowrap;cursor:pointer}.git-link{color:inherit;margin-left:1em}.source pre{max-height:500px;overflow:auto;margin:0}.source pre code{font-size:12px;overflow:visible}.hlist{list-style:none}.hlist li{display:inline}.hlist li:after{content:',\2002'}.hlist li:last-child:after{content:none}.hlist .hlist{display:inline;padding-left:1em}img{max-width:100%}td{padding:0 .5em}.admonition{padding:.1em 1em;margin-bottom:1em}.admonition-title{font-weight:bold}.admonition.note,.admonition.info,.admonition.important{background:#aef}.admonition.todo,.admonition.versionadded,.admonition.tip,.admonition.hint{background:#dfd}.admonition.warning,.admonition.versionchanged,.admonition.deprecated{background:#fd4}.admonition.error,.admonition.danger,.admonition.caution{background:lightpink}</style>
+<style media="screen and (min-width: 700px)">@media screen and (min-width:700px){#sidebar{width:30%;height:100vh;overflow:auto;position:sticky;top:0}#content{width:70%;max-width:100ch;padding:3em 4em;border-left:1px solid #ddd}pre code{font-size:1em}.name{font-size:1em}main{display:flex;flex-direction:row-reverse;justify-content:flex-end}.toc ul ul,#index ul ul{padding-left:1em}.toc > ul > li{margin-top:.5em}}</style>
+<style media="print">@media print{#sidebar h1{page-break-before:always}.source{display:none}}@media print{*{background:transparent !important;color:#000 !important;box-shadow:none !important;text-shadow:none !important}a[href]:after{content:" (" attr(href) ")";font-size:90%}a[href][title]:after{content:none}abbr[title]:after{content:" (" attr(title) ")"}.ir a:after,a[href^="javascript:"]:after,a[href^="#"]:after{content:""}pre,blockquote{border:1px solid #999;page-break-inside:avoid}thead{display:table-header-group}tr,img{page-break-inside:avoid}img{max-width:100% !important}@page{margin:0.5cm}p,h2,h3{orphans:3;widows:3}h1,h2,h3,h4,h5,h6{page-break-after:avoid}}</style>
+<script type="text/x-mathjax-config">MathJax.Hub.Config({ tex2jax: { inlineMath: [ ['$','$'], ["\\(","\\)"] ], processEscapes: true } });</script>
+<script async src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.7/latest.js?config=TeX-AMS_CHTML" integrity="sha256-kZafAc6mZvK3W3v1pHOcUix30OHQN6pU/NO2oFkqZVw=" crossorigin></script>
+<script defer src="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.9.0/highlight.min.js" integrity="sha512-D9gUyxqja7hBtkWpPWGt9wfbfaMGVt9gnyCvYa+jojwwPHLCzUm5i8rpk7vD7wNee9bA35eYIjobYPaQuKS1MQ==" crossorigin></script>
+<script>window.addEventListener('DOMContentLoaded', () => {
+hljs.configure({languages: ['bash', 'css', 'diff', 'graphql', 'ini', 'javascript', 'json', 'plaintext', 'python', 'python-repl', 'rust', 'shell', 'sql', 'typescript', 'xml', 'yaml']});
+hljs.highlightAll();
+})</script>
+</head>
+<body>
+<main>
+<article id="content">
+<header>
+<h1 class="title">Package <code>VITAE</code></h1>
+</header>
+<section id="section-intro">
+</section>
+<section>
+<h2 class="section-title" id="header-submodules">Sub-modules</h2>
+<dl>
+<dt><code class="name"><a title="VITAE.inference" href="inference.html">VITAE.inference</a></code></dt>
+<dd>
+<div class="desc"></div>
+</dd>
+<dt><code class="name"><a title="VITAE.metric" href="metric.html">VITAE.metric</a></code></dt>
+<dd>
+<div class="desc"></div>
+</dd>
+<dt><code class="name"><a title="VITAE.model" href="model.html">VITAE.model</a></code></dt>
+<dd>
+<div class="desc"></div>
+</dd>
+<dt><code class="name"><a title="VITAE.train" href="train.html">VITAE.train</a></code></dt>
+<dd>
+<div class="desc"></div>
+</dd>
+<dt><code class="name"><a title="VITAE.utils" href="utils.html">VITAE.utils</a></code></dt>
+<dd>
+<div class="desc"></div>
+</dd>
+</dl>
+</section>
+<section>
+</section>
+<section>
+</section>
+<section>
+<h2 class="section-title" id="header-classes">Classes</h2>
+<dl>
+<dt id="VITAE.VITAE"><code class="flex name class">
+<span>class <span class="ident">VITAE</span></span>
+<span>(</span><span>adata: anndata._core.anndata.AnnData, covariates=None, pi_covariates=None, model_type: str = 'Gaussian', npc: int = 64, adata_layer_counts=None, copy_adata: bool = False, hidden_layers=[32], latent_space_dim: int = 16, conditions=None)</span>
+</code></dt>
+<dd>
+<div class="desc"><p>Variational Inference for Trajectory by AutoEncoder.</p>
+<p>Get the input data for the model. The data must first be processed with scanpy and stored as an AnnData object.
+The 'UMI' and 'non-UMI' models require the original count matrix, so the counts must be saved in
+adata.layers in order to use these models.</p>
+<h2 id="parameters">Parameters</h2>
+<dl>
+<dt><strong><code>adata</code></strong> :&ensp;<code>sc.AnnData</code></dt>
+<dd>The scanpy AnnData object. adata should already contain adata.var.highly_variable</dd>
+<dt><strong><code>covariates</code></strong> :&ensp;<code>list</code>, optional</dt>
+<dd>A list of names of covariate vectors that are stored in adata.obs</dd>
+<dt><strong><code>pi_covariates</code></strong> :&ensp;<code>list</code>, optional</dt>
+<dd>A list of names of covariate vectors used as input for pilayer</dd>
+<dt><strong><code>model_type</code></strong> :&ensp;<code>str</code>, optional</dt>
+<dd>'UMI', 'non-UMI' and 'Gaussian', default is 'Gaussian'.</dd>
+<dt><strong><code>npc</code></strong> :&ensp;<code>int</code>, optional</dt>
+<dd>The number of PCs to use when model_type is 'Gaussian'. The default is 64.</dd>
+<dt><strong><code>adata_layer_counts</code></strong> :&ensp;<code>str</code>, optional</dt>
+<dd>the key name of adata.layers that stores the count data if model_type is
+'UMI' or 'non-UMI'</dd>
+<dt><strong><code>copy_adata</code></strong> :&ensp;<code>bool</code>, optional</dt>
+<dd>Set to True if we don't want VITAE to modify the original adata. If set to True, self.adata will be an independent copy of the original adata.</dd>
+<dt><strong><code>hidden_layers</code></strong> :&ensp;<code>list</code>, optional</dt>
+<dd>The list of dimensions of layers of autoencoder between latent space and original space. Default is to have only one hidden layer with 32 nodes</dd>
+<dt><strong><code>latent_space_dim</code></strong> :&ensp;<code>int</code>, optional</dt>
+<dd>The dimension of latent space.</dd>
+<dt><strong><code>conditions</code></strong> :&ensp;<code>str</code> or <code>list</code>, optional</dt>
+<dd>The conditions of different cells</dd>
+</dl>
+<h2 id="returns">Returns</h2>
+<p>None.</p></div>
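+<p>A minimal usage sketch (assumptions: <code>adata</code> is a scanpy-preprocessed AnnData, and the layer name <code>counts</code> is illustrative):</p>
+<pre><code class="python">import scanpy as sc
+from VITAE import VITAE
+
+# flag highly variable genes first (required by the non-Gaussian models)
+sc.pp.highly_variable_genes(adata, n_top_genes=2000)
+
+# Gaussian model on PCs of the preprocessed data
+model = VITAE(adata, model_type=&#39;Gaussian&#39;, npc=64, latent_space_dim=16)
+
+# UMI model: raw counts must be stored in adata.layers
+model = VITAE(adata, model_type=&#39;UMI&#39;, adata_layer_counts=&#39;counts&#39;)
+</code></pre>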
+<details class="source">
+<summary>
+<span>Expand source code</span>
+</summary>
+<pre><code class="python">class VITAE():
+    &#34;&#34;&#34;
+    Variational Inference for Trajectory by AutoEncoder.
+    &#34;&#34;&#34;
+    def __init__(self, adata: sc.AnnData,
+               covariates = None, pi_covariates = None,
+               model_type: str = &#39;Gaussian&#39;,
+               npc: int = 64,
+               adata_layer_counts = None,
+               copy_adata: bool = False,
+               hidden_layers = [32],
+               latent_space_dim: int = 16,
+               conditions = None):
+        &#39;&#39;&#39;
+        Get the input data for the model. The data must first be processed with scanpy and stored as an
+        AnnData object. The &#39;UMI&#39; and &#39;non-UMI&#39; models require the original count matrix, so the counts
+        must be saved in adata.layers in order to use these models.
+
+
+        Parameters
+        ----------
+        adata : sc.AnnData
+            The scanpy AnnData object. adata should already contain adata.var.highly_variable
+        covariates : list, optional
+            A list of names of covariate vectors that are stored in adata.obs
+        pi_covariates: list, optional
+            A list of names of covariate vectors used as input for pilayer
+        model_type : str, optional
+            &#39;UMI&#39;, &#39;non-UMI&#39; and &#39;Gaussian&#39;, default is &#39;Gaussian&#39;.
+        npc : int, optional
+            The number of PCs to use when model_type is &#39;Gaussian&#39;. The default is 64.
+        adata_layer_counts: str, optional
+            the key name of adata.layers that stores the count data if model_type is
+            &#39;UMI&#39; or &#39;non-UMI&#39;
+        copy_adata : bool, optional
+            Set to True if we don&#39;t want VITAE to modify the original adata. If set to True, self.adata
+            will be an independent copy of the original adata.
+        hidden_layers : list, optional
+            The list of dimensions of layers of autoencoder between latent space and original space. Default is to have only one hidden layer with 32 nodes
+        latent_space_dim : int, optional
+            The dimension of latent space.
+        conditions : str or list, optional
+            The conditions of different cells
+
+
+        Returns
+        -------
+        None.
+
+        &#39;&#39;&#39;
+        self.dict_method_scname = {
+            &#39;PCA&#39; : &#39;X_pca&#39;,
+            &#39;UMAP&#39; : &#39;X_umap&#39;,
+            &#39;TSNE&#39; : &#39;X_tsne&#39;,
+            &#39;diffmap&#39; : &#39;X_diffmap&#39;,
+            &#39;draw_graph&#39; : &#39;X_draw_graph_fa&#39;
+        }
+
+        if model_type != &#39;Gaussian&#39;:
+            if adata_layer_counts is None:
+                raise ValueError(&#34;need to provide the name in adata.layers that stores the raw count data&#34;)
+            if &#39;highly_variable&#39; not in adata.var:
+                raise ValueError(&#34;need to first select highly variable genes using scanpy&#34;)
+
+        self.model_type = model_type
+
+        if copy_adata:
+            self.adata = adata.copy()
+        else:
+            self.adata = adata
+
+        if covariates is not None:
+            if isinstance(covariates, str):
+                covariates = [covariates]
+            covariates = np.array(covariates)
+            id_cat = (adata.obs[covariates].dtypes == &#39;category&#39;)
+            # add OneHotEncoder &amp; StandardScaler as class variable if needed
+            if np.sum(id_cat)&gt;0:
+                covariates_cat = OneHotEncoder(drop=&#39;if_binary&#39;, handle_unknown=&#39;ignore&#39;
+                    ).fit_transform(adata.obs[covariates[id_cat]]).toarray()
+            else:
+                covariates_cat = np.array([]).reshape(adata.shape[0],0)
+
+            # temporarily disable StandardScaler
+            if np.sum(~id_cat)&gt;0:
+                #covariates_con = StandardScaler().fit_transform(adata.obs[covariates[~id_cat]])
+                covariates_con = adata.obs[covariates[~id_cat]]
+            else:
+                covariates_con = np.array([]).reshape(adata.shape[0],0)
+
+            self.covariates = np.c_[covariates_cat, covariates_con].astype(tf.keras.backend.floatx())
+        else:
+            self.covariates = None
+
+        if conditions is not None:
+            ## observations with np.nan will not participate in calculating mmd_loss
+            if isinstance(conditions, str):
+                conditions = [conditions]
+            conditions = np.array(conditions)
+            if np.any(adata.obs[conditions].dtypes != &#39;category&#39;):
+                raise ValueError(&#34;Conditions should all be categorical.&#34;)
+
+            self.conditions = OrdinalEncoder(dtype=int, encoded_missing_value=-1).fit_transform(adata.obs[conditions]) + int(1)
+        else:
+            self.conditions = None
+
+        if pi_covariates is not None:
+            self.pi_cov = adata.obs[pi_covariates].to_numpy()
+            if self.pi_cov.ndim == 1:
+                self.pi_cov = self.pi_cov.reshape(-1, 1)
+                self.pi_cov = self.pi_cov.astype(tf.keras.backend.floatx())
+        else:
+            self.pi_cov = np.zeros((adata.shape[0],1), dtype=tf.keras.backend.floatx())
+            
+        self.model_type = model_type
+        self._adata = sc.AnnData(X = self.adata.X, var = self.adata.var)
+        self._adata.obs = self.adata.obs
+        self._adata.uns = self.adata.uns
+
+
+        if model_type == &#39;Gaussian&#39;:
+            sc.tl.pca(adata, n_comps = npc)
+            self.X_input = self.X_output = adata.obsm[&#39;X_pca&#39;]
+            self.scale_factor = np.ones(self.X_output.shape[0])
+        else:
+            print(f&#34;{adata.var.highly_variable.sum()} highly variable genes selected as input&#34;) 
+            self.X_input = adata.X[:, adata.var.highly_variable]
+            self.X_output = adata.layers[adata_layer_counts][ :, adata.var.highly_variable]
+            self.scale_factor = np.sum(self.X_output, axis=1, keepdims=True)/1e4
+
+        self.dimensions = hidden_layers
+        self.dim_latent = latent_space_dim
+
+        self.vae = model.VariationalAutoEncoder(
+            self.X_output.shape[1], self.dimensions,
+            self.dim_latent, self.model_type,
+            False if self.covariates is None else True,
+            )
+
+        if hasattr(self, &#39;inferer&#39;):
+            delattr(self, &#39;inferer&#39;)
+        
+
+    def pre_train(self, test_size = 0.1, random_state: int = 0,
+            learning_rate: float = 1e-3, batch_size: int = 256, L: int = 1, alpha: float = 0.10, gamma: float = 0,
+            phi : float = 1,num_epoch: int = 200, num_step_per_epoch: Optional[int] = None,
+            early_stopping_patience: int = 10, early_stopping_tolerance: float = 0.01, 
+            early_stopping_relative: bool = True, verbose: bool = False,path_to_weights: Optional[str] = None):
+        &#39;&#39;&#39;Pretrain the model with specified learning rate.
+
+        Parameters
+        ----------
+        test_size : float or int, optional
+            The proportion or size of the test set.
+        random_state : int, optional
+            The random state for data splitting.
+        learning_rate : float, optional
+            The initial learning rate for the Adam optimizer.
+        batch_size : int, optional 
+            The batch size for pre-training. Default is 256. Set to 32 if the number of cells is small (less than 1,000).
+        L : int, optional 
+            The number of MC samples.
+        alpha : float, optional
+            The value of alpha in [0,1] to encourage covariate adjustment. Not used if there are no covariates.
+        gamma : float, optional
+            The weight of the mmd loss if used.
+        phi : float, optional
+            The weight of the Jacobian norm of the encoder.
+        num_epoch : int, optional 
+            The maximum number of epochs.
+        num_step_per_epoch : int, optional 
+            The number of steps per epoch; it will be inferred from the number of cells and batch size if None.            
+        early_stopping_patience : int, optional 
+            The maximum number of epochs if there is no improvement.
+        early_stopping_tolerance : float, optional 
+            The minimum change of loss to be considered as an improvement.
+        early_stopping_relative : bool, optional
+            Whether to monitor the relative change of loss as the stopping criterion.
+        path_to_weights : str, optional 
+            The path of weight file to be saved; not saving weight if None.
+        verbose : bool, optional
+            Whether to print progress during training.
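+
+        Examples
+        --------
+        A minimal sketch, assuming `model` is a VITAE instance (argument values are illustrative):
+
+        &gt;&gt;&gt; model.pre_train(batch_size=256, num_epoch=200, verbose=True)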
+        &#39;&#39;&#39;
+
+        id_train, id_test = train_test_split(
+                                np.arange(self.X_input.shape[0]), 
+                                test_size=test_size, 
+                                random_state=random_state)
+        if num_step_per_epoch is None:
+            num_step_per_epoch = len(id_train)//batch_size+1
+        self.train_dataset = train.warp_dataset(self.X_input[id_train].astype(tf.keras.backend.floatx()), 
+                                                None if self.covariates is None else self.covariates[id_train].astype(tf.keras.backend.floatx()),
+                                                batch_size, 
+                                                self.X_output[id_train].astype(tf.keras.backend.floatx()), 
+                                                self.scale_factor[id_train].astype(tf.keras.backend.floatx()),
+                                                conditions = None if self.conditions is None else self.conditions[id_train].astype(tf.keras.backend.floatx()))
+        self.test_dataset = train.warp_dataset(self.X_input[id_test], 
+                                                None if self.covariates is None else self.covariates[id_test].astype(tf.keras.backend.floatx()),
+                                                batch_size, 
+                                                self.X_output[id_test].astype(tf.keras.backend.floatx()), 
+                                                self.scale_factor[id_test].astype(tf.keras.backend.floatx()),
+                                                conditions = None if self.conditions is None else self.conditions[id_test].astype(tf.keras.backend.floatx()))
+
+        self.vae = train.pre_train(
+            self.train_dataset,
+            self.test_dataset,
+            self.vae,
+            learning_rate,                        
+            L, alpha, gamma, phi,
+            num_epoch,
+            num_step_per_epoch,
+            early_stopping_patience,
+            early_stopping_tolerance,
+            early_stopping_relative,
+            verbose)
+        
+        self.update_z()
+
+        if path_to_weights is not None:
+            self.save_model(path_to_weights)
+            
+
+    def update_z(self):
+        self.z = self.get_latent_z()        
+        self._adata_z = sc.AnnData(self.z)
+        sc.pp.neighbors(self._adata_z)
+
+            
+    def get_latent_z(self):
+        &#39;&#39;&#39; Get the posterior mean of the current latent space z (encoder output).
+
+        Returns
+        ----------
+        z : np.array
+            \([N,d]\) The latent means.
+        &#39;&#39;&#39; 
+        c = None if self.covariates is None else self.covariates
+        return self.vae.get_z(self.X_input, c)
+            
+    
+    def visualize_latent(self, method: str = &#34;UMAP&#34;, 
+                         color = None, **kwargs):
+        &#39;&#39;&#39;
+        Visualize the current latent space z using the scanpy visualization tools.
+
+        Parameters
+        ----------
+        method : str, optional
+            Visualization method to use. The default is &#34;UMAP&#34;. Possible choices include &#34;PCA&#34;, &#34;UMAP&#34;, 
+            &#34;diffmap&#34;, &#34;TSNE&#34; and &#34;draw_graph&#34; (the FA plot).
+        color : str or list of str, optional
+            Keys for annotations of observations/cells or variables/genes, e.g., &#39;ann1&#39; or [&#39;ann1&#39;, &#39;ann2&#39;].
+            The default is None. Same as scanpy.
+        **kwargs :  
+            Extra key-value arguments that can be passed to scanpy plotting functions (scanpy.pl.XX).   
+
+        Returns
+        -------
+        axes : matplotlib.axes.Axes
+            The axes returned by the corresponding scanpy plotting function.
+
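+        Examples
+        --------
+        A minimal sketch, run after pretraining or training `model`:
+
+        &gt;&gt;&gt; model.visualize_latent(method=&#39;UMAP&#39;, color=&#39;vitae_init_clustering&#39;)
+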
+        &#39;&#39;&#39;
+          
+        if method not in [&#39;PCA&#39;, &#39;UMAP&#39;, &#39;TSNE&#39;, &#39;diffmap&#39;, &#39;draw_graph&#39;]:
+            raise ValueError(&#34;visualization method should be one of &#39;PCA&#39;, &#39;UMAP&#39;, &#39;TSNE&#39;, &#39;diffmap&#39; and &#39;draw_graph&#39;&#34;)
+        
+        temp = list(self._adata_z.obsm.keys())
+        if method == &#39;PCA&#39; and not &#39;X_pca&#39; in temp:
+            print(&#34;Calculate PCs ...&#34;)
+            sc.tl.pca(self._adata_z)
+        elif method == &#39;UMAP&#39; and not &#39;X_umap&#39; in temp:  
+            print(&#34;Calculate UMAP ...&#34;)
+            sc.tl.umap(self._adata_z)
+        elif method == &#39;TSNE&#39; and not &#39;X_tsne&#39; in temp:
+            print(&#34;Calculate TSNE ...&#34;)
+            sc.tl.tsne(self._adata_z)
+        elif method == &#39;diffmap&#39; and not &#39;X_diffmap&#39; in temp:
+            print(&#34;Calculate diffusion map ...&#34;)
+            sc.tl.diffmap(self._adata_z)
+        elif method == &#39;draw_graph&#39; and not &#39;X_draw_graph_fa&#39; in temp:
+            print(&#34;Calculate FA ...&#34;)
+            sc.tl.draw_graph(self._adata_z)
+            
+
+        self._adata.obs = self.adata.obs.copy()
+        self._adata.obsp = self._adata_z.obsp
+#        self._adata.uns = self._adata_z.uns
+        self._adata.obsm = self._adata_z.obsm
+    
+        if method == &#39;PCA&#39;:
+            axes = sc.pl.pca(self._adata, color = color, **kwargs)
+        elif method == &#39;UMAP&#39;:            
+            axes = sc.pl.umap(self._adata, color = color, **kwargs)
+        elif method == &#39;TSNE&#39;:
+            axes = sc.pl.tsne(self._adata, color = color, **kwargs)
+        elif method == &#39;diffmap&#39;:
+            axes = sc.pl.diffmap(self._adata, color = color, **kwargs)
+        elif method == &#39;draw_graph&#39;:
+            axes = sc.pl.draw_graph(self._adata, color = color, **kwargs)
+        return axes
+
+
+    def init_latent_space(self, cluster_label = None, log_pi = None, res: float = 1.0, 
+                          ratio_prune= None, dist = None, dist_thres = 0.5, topk=0, pilayer = False):
+        &#39;&#39;&#39;Initialize the latent space.
+
+        Parameters
+        ----------
+        cluster_label : str, optional
+            The name of vector of labels that can be found in self.adata.obs. 
+            Default is None, which will perform leiden clustering on the pretrained z to get clusters
+        log_pi : np.array, optional
+            \([1,K]\) The value of initial \(\\log(\\pi)\).
+        res : float, optional
+            The resolution of Leiden clustering, which is a parameter value controlling the coarseness of the clustering. 
+            Higher values lead to more clusters. Default is 1.
+        ratio_prune : float, optional
+            The ratio of edges to be removed before estimating.
+        dist : np.array, optional
+            \([K,K]\) A precomputed distance matrix between cluster centers. Computed from the latent space if None.
+        dist_thres : float, optional
+            The distance threshold for merging cluster centers that are too close. Default is 0.5.
+        topk : int, optional
+            The number of top k neighbors to keep for each cluster.
+        pilayer : bool, optional
+            Whether to create a pi layer that takes `pi_covariates` as input. Default is False.
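+
+        Examples
+        --------
+        A minimal sketch; the label column name is illustrative:
+
+        &gt;&gt;&gt; model.init_latent_space(cluster_label=&#39;celltype&#39;)
+        &gt;&gt;&gt; # or let VITAE run Leiden clustering on the pretrained latent space
+        &gt;&gt;&gt; model.init_latent_space(res=1.0)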
+        &#39;&#39;&#39;   
+    
+        
+        if cluster_label is None:
+            print(&#34;Perform leiden clustering on the latent space z ...&#34;)
+            g = get_igraph(self.z)
+            cluster_labels = leidenalg_igraph(g, res = res)
+            cluster_labels = cluster_labels.astype(str) 
+            uni_cluster_labels = np.unique(cluster_labels)
+        else:
+            if isinstance(cluster_label,str):
+                cluster_labels = self.adata.obs[cluster_label].to_numpy()
+                uni_cluster_labels = np.array(self.adata.obs[cluster_label].cat.categories)
+            else:
+                ## if cluster_label is a list
+                cluster_labels = cluster_label
+                uni_cluster_labels = np.unique(cluster_labels)
+
+        n_clusters = len(uni_cluster_labels)
+
+        if not hasattr(self, &#39;z&#39;):
+            self.update_z()        
+        z = self.z
+        mu = np.zeros((z.shape[1], n_clusters))
+        for i,l in enumerate(uni_cluster_labels):
+            mu[:,i] = np.mean(z[cluster_labels==l], axis=0)
+       
+        if dist is None:
+            ### update cluster centers if some cluster centers are too close
+            clustering = AgglomerativeClustering(
+                n_clusters=None,
+                distance_threshold=dist_thres,
+                linkage=&#39;complete&#39;
+                ).fit(mu.T/np.sqrt(mu.shape[0]))
+            n_clusters_new = clustering.n_clusters_
+            if n_clusters_new &lt; n_clusters:
+                print(&#34;Merge clusters for cluster centers that are too close ...&#34;)
+                n_clusters = n_clusters_new
+                for i in range(n_clusters):    
+                    temp = uni_cluster_labels[clustering.labels_ == i]
+                    idx = np.isin(cluster_labels, temp)
+                    cluster_labels[idx] = &#39;,&#39;.join(temp)
+                    if np.sum(clustering.labels_==i)&gt;1:
+                        print(&#39;Merge %s&#39;% &#39;,&#39;.join(temp))
+                uni_cluster_labels = np.unique(cluster_labels)
+                mu = np.zeros((z.shape[1], n_clusters))
+                for i,l in enumerate(uni_cluster_labels):
+                    mu[:,i] = np.mean(z[cluster_labels==l], axis=0)
+            
+        self.adata.obs[&#39;vitae_init_clustering&#39;] = cluster_labels
+        self.adata.obs[&#39;vitae_init_clustering&#39;] = self.adata.obs[&#39;vitae_init_clustering&#39;].astype(&#39;category&#39;)
+        print(&#34;Initial clustering labels saved as &#39;vitae_init_clustering&#39; in self.adata.obs.&#34;)
+   
+        if (log_pi is None) and (cluster_labels is not None) and (n_clusters&gt;3):                         
+            n_states = int((n_clusters+1)*n_clusters/2)
+            
+            if dist is None:
+                dist = _comp_dist(z, cluster_labels, mu.T)
+
+            C = np.triu(np.ones(n_clusters))
+            C[C&gt;0] = np.arange(n_states)
+            C = C + C.T - np.diag(np.diag(C))
+            C = C.astype(int)
+
+            log_pi = np.zeros((1,n_states))            
+
+            ## pruning to throw away edges for far-away clusters if there are too many clusters
+            if ratio_prune is not None:
+                log_pi[0, C[np.triu(dist)&gt;np.quantile(dist[np.triu_indices(n_clusters, 1)], 1-ratio_prune)]] = - np.inf
+            else:
+                log_pi[0, C[np.triu(dist)&gt;np.quantile(dist[np.triu_indices(n_clusters, 1)], 5/n_clusters) * 3]] = - np.inf
+
+            ## also keep the top k neighbor of clusters
+            topk = max(0, min(topk, n_clusters-1)) + 1
+            topk_indices = np.argsort(dist,axis=1)[:,:topk]
+            for i in range(n_clusters):
+                log_pi[0, C[i, topk_indices[i]]] = 0
+
+        self.n_states = n_clusters
+        self.labels = cluster_labels
+        
+        labels_map = pd.DataFrame.from_dict(
+            {i:label for i,label in enumerate(uni_cluster_labels)}, 
+            orient=&#39;index&#39;, columns=[&#39;label_names&#39;], dtype=str
+            )
+        
+        self.labels_map = labels_map
+        self.vae.init_latent_space(self.n_states, mu, log_pi)
+        self.inferer = Inferer(self.n_states)
+        self.mu = self.vae.latent_space.mu.numpy()
+        self.pi = np.triu(np.ones(self.n_states))
+        self.pi[self.pi &gt; 0] = tf.nn.softmax(self.vae.latent_space.pi).numpy()[0]
+
+        if pilayer:
+            self.vae.create_pilayer()
+
+
+    def update_latent_space(self, dist_thres: float=0.5):
+        pi = self.pi[np.triu_indices(self.n_states)]
+        mu = self.mu    
+        clustering = AgglomerativeClustering(
+            n_clusters=None,
+            distance_threshold=dist_thres,
+            linkage=&#39;complete&#39;
+            ).fit(mu.T/np.sqrt(mu.shape[0]))
+        n_clusters = clustering.n_clusters_   
+
+        if n_clusters&lt;self.n_states:      
+            print(&#34;Merge clusters for cluster centers that are too close ...&#34;)
+            mu_new = np.empty((self.dim_latent, n_clusters))
+            C = np.zeros((self.n_states, self.n_states))
+            C[np.triu_indices(self.n_states, 0)] = pi
+            C = np.triu(C, 1) + C.T
+            C_new = np.zeros((n_clusters, n_clusters))
+            
+            uni_cluster_labels = self.labels_map[&#39;label_names&#39;].to_numpy()
+            returned_order = {}
+            cluster_labels = self.labels
+            for i in range(n_clusters):
+                temp = uni_cluster_labels[clustering.labels_ == i]
+                idx = np.isin(cluster_labels, temp)
+                cluster_labels[idx] = &#39;,&#39;.join(temp)
+                returned_order[i] = &#39;,&#39;.join(temp)
+                if np.sum(clustering.labels_==i)&gt;1:
+                    print(&#39;Merge %s&#39;% &#39;,&#39;.join(temp))
+            uni_cluster_labels = np.unique(cluster_labels) 
+            for i,l in enumerate(uni_cluster_labels):  ## reorder the merged clusters based on the cluster names
+                # look up the original (pre-reorder) merged-cluster index for label l
+                k = [key for key, val in returned_order.items() if val == l][0]
+                mu_new[:, i] = np.mean(mu[:, clustering.labels_==k], axis=-1)
+                # sum of the aggregated pi&#39;s
+                C_new[i, i] = np.sum(np.triu(C[clustering.labels_==k, :][:, clustering.labels_==k]))
+                for j in range(i+1, n_clusters):
+                    k1 = [key for key, val in returned_order.items() if val == uni_cluster_labels[j]][0]
+                    C_new[i, j] = np.sum(C[clustering.labels_==k, :][:, clustering.labels_==k1])
+
+            C_new = np.triu(C_new,1) + C_new.T
+
+            pi_new = C_new[np.triu_indices(n_clusters)]
+            log_pi_new = np.log(pi_new, out=np.ones_like(pi_new)*(-np.inf), where=(pi_new!=0)).reshape((1,-1))
+            self.n_states = n_clusters
+            self.labels_map = pd.DataFrame.from_dict(
+                {i:label for i,label in enumerate(uni_cluster_labels)},
+                orient=&#39;index&#39;, columns=[&#39;label_names&#39;], dtype=str
+                )
+            self.labels = cluster_labels
+            self.vae.init_latent_space(self.n_states, mu_new, log_pi_new)
+            self.inferer = Inferer(self.n_states)
+            self.mu = self.vae.latent_space.mu.numpy()
+            self.pi = np.triu(np.ones(self.n_states))
+            self.pi[self.pi &gt; 0] = tf.nn.softmax(self.vae.latent_space.pi).numpy()[0]
+
+
+
+    def train(self, stratify = False, test_size = 0.1, random_state: int = 0,
+            learning_rate: float = 1e-3, batch_size: int = 256,
+            L: int = 1, alpha: float = 0.10, beta: float = 1, gamma: float = 0, phi: float = 1,
+            num_epoch: int = 200, num_step_per_epoch: Optional[int] =  None,
+            early_stopping_patience: int = 10, early_stopping_tolerance: float = 0.01, 
+            early_stopping_relative: bool = True, early_stopping_warmup: int = 0,
+            path_to_weights: Optional[str] = None,
+            verbose: bool = False, **kwargs):
+        &#39;&#39;&#39;Train the model.
+
+        Parameters
+        ----------
+        stratify : np.array, None, or False
+            If an array is provided, or `stratify=None` and `self.labels` is available, then they will be used to perform stratified shuffle splitting. Otherwise, general shuffle splitting is used. Set to `False` if `self.labels` is not intended for stratified shuffle splitting.
+        test_size : float or int, optional
+            The proportion or size of the test set.
+        random_state : int, optional
+            The random state for data splitting.
+        learning_rate : float, optional  
+            The initial learning rate for the Adam optimizer.
+        batch_size : int, optional  
+            The batch size for training. Default is 256. Set to 32 if the number of cells is small (less than 1,000).
+        L : int, optional  
+            The number of MC samples.
+        alpha : float, optional  
+            The value of alpha in [0,1] to encourage covariate adjustment. Not used if there are no covariates.
+        beta : float, optional  
+            The value of beta in beta-VAE.
+        gamma : float, optional
+            The weight of the MMD loss.
+        phi : float, optional
+            The weight of the Jacobian norm of the encoder.
+        num_epoch : int, optional  
+            The maximum number of epochs.
+        num_step_per_epoch : int, optional 
+            The number of steps per epoch; it will be inferred from the number of cells and batch size if None.
+        early_stopping_patience : int, optional 
+            The maximum number of epochs if there is no improvement.
+        early_stopping_tolerance : float, optional 
+            The minimum change of loss to be considered as an improvement.
+        early_stopping_relative : bool, optional
+            Whether to monitor the relative change of loss or not.            
+        early_stopping_warmup : int, optional 
+            The number of warmup epochs.            
+        path_to_weights : str, optional 
+            The path of weight file to be saved; not saving weight if None.
+        **kwargs :  
+            Extra key-value arguments for dimension reduction algorithms.        
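+
+        Examples
+        --------
+        A minimal sketch, assuming `init_latent_space` has already been called on `model`:
+
+        &gt;&gt;&gt; model.train(beta=1, num_epoch=200)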
+        &#39;&#39;&#39;
+        if gamma == 0 or self.conditions is None:
+            conditions = np.array([np.nan] * self.adata.shape[0])
+        else:
+            conditions = self.conditions
+
+        if stratify is None:
+            stratify = self.labels
+        elif stratify is False:
+            stratify = None    
+        id_train, id_test = train_test_split(
+                                np.arange(self.X_input.shape[0]), 
+                                test_size=test_size, 
+                                stratify=stratify, 
+                                random_state=random_state)
+        if num_step_per_epoch is None:
+            num_step_per_epoch = len(id_train)//batch_size+1
+        c = None if self.covariates is None else self.covariates.astype(tf.keras.backend.floatx())
+        self.train_dataset = train.warp_dataset(self.X_input[id_train].astype(tf.keras.backend.floatx()),
+                                                None if c is None else c[id_train],
+                                                batch_size, 
+                                                self.X_output[id_train].astype(tf.keras.backend.floatx()), 
+                                                self.scale_factor[id_train].astype(tf.keras.backend.floatx()),
+                                                conditions = conditions[id_train],
+                                                pi_cov = self.pi_cov[id_train])
+        self.test_dataset = train.warp_dataset(self.X_input[id_test].astype(tf.keras.backend.floatx()),
+                                                None if c is None else c[id_test],
+                                                batch_size, 
+                                                self.X_output[id_test].astype(tf.keras.backend.floatx()), 
+                                                self.scale_factor[id_test].astype(tf.keras.backend.floatx()),
+                                                conditions = conditions[id_test],
+                                                pi_cov = self.pi_cov[id_test])
+                                   
+        self.vae = train.train(
+            self.train_dataset,
+            self.test_dataset,
+            self.vae,
+            learning_rate,
+            L,
+            alpha,
+            beta,
+            gamma,
+            phi,
+            num_epoch,
+            num_step_per_epoch,
+            early_stopping_patience,
+            early_stopping_tolerance,
+            early_stopping_relative,
+            early_stopping_warmup,  
+            verbose,
+            **kwargs            
+            )
+        
+        self.update_z()
+        self.mu = self.vae.latent_space.mu.numpy()
+        self.pi = np.triu(np.ones(self.n_states))
+        self.pi[self.pi &gt; 0] = tf.nn.softmax(self.vae.latent_space.pi).numpy()[0]
+            
+        if path_to_weights is not None:
+            self.save_model(path_to_weights)
+    
+
+    def output_pi(self, pi_cov):
+        &#34;&#34;&#34;Return an n_states by n_states matrix and a plotting mask that can be used to cover the lower triangle (excluding the diagonal) of a heatmap.&#34;&#34;&#34;
+        p = self.vae.pilayer
+        pi_cov = tf.expand_dims(tf.constant([pi_cov], dtype=tf.float32), 0)
+        pi_val = tf.nn.softmax(p(pi_cov)).numpy()[0]
+        # Create heatmap matrix
+        n = self.vae.n_states
+        matrix = np.zeros((n, n))
+        matrix[np.triu_indices(n)] = pi_val
+        mask = np.tril(np.ones_like(matrix), k=-1)
+        return matrix, mask
+
+
+    def return_pilayer_weights(self):
+        &#34;&#34;&#34;Return the parameters of the pi layer, which have dimension (dim(pi_cov) + 1) by n_categories; the last row contains the biases.&#34;&#34;&#34;
+        return np.vstack((self.vae.pilayer.weights[0].numpy(), self.vae.pilayer.weights[1].numpy().reshape(1, -1)))
+
+
+    def posterior_estimation(self, batch_size: int = 32, L: int = 50, **kwargs):
+        &#39;&#39;&#39;Initialize trajectory inference by computing the posterior estimations.        
+
+        Parameters
+        ----------
+        batch_size : int, optional
+            The batch size when doing inference.
+        L : int, optional
+            The number of MC samples when doing inference.
+        **kwargs :  
+            Extra key-value arguments for dimension reduction algorithms.              
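+
+        Examples
+        --------
+        A minimal sketch, run after training:
+
+        &gt;&gt;&gt; model.posterior_estimation(batch_size=32, L=50)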
+        &#39;&#39;&#39;
+        c = None if self.covariates is None else self.covariates.astype(tf.keras.backend.floatx())
+        self.test_dataset = train.warp_dataset(self.X_input.astype(tf.keras.backend.floatx()), 
+                                               c,
+                                               batch_size)
+        _, _, self.pc_x,\
+            self.cell_position_posterior,self.cell_position_variance,_ = self.vae.inference(self.test_dataset, L=L)
+            
+        uni_cluster_labels = self.labels_map[&#39;label_names&#39;].to_numpy()
+        self.adata.obs[&#39;vitae_new_clustering&#39;] = uni_cluster_labels[np.argmax(self.cell_position_posterior, 1)]
+        self.adata.obs[&#39;vitae_new_clustering&#39;] = self.adata.obs[&#39;vitae_new_clustering&#39;].astype(&#39;category&#39;)
+        print(&#34;New clustering labels saved as &#39;vitae_new_clustering&#39; in self.adata.obs.&#34;)
+        return None
+
+
+    def infer_backbone(self, method: str = &#39;modified_map&#39;, thres = 0.5,
+            no_loop: bool = True, cutoff: float = 0,
+            visualize: bool = True, color = &#39;vitae_new_clustering&#39;,path_to_fig = None,**kwargs):
+        &#39;&#39;&#39; Compute edge scores.
+
+        Parameters
+        ----------
+        method : string, optional
+            &#39;mean&#39;, &#39;modified_mean&#39;, &#39;map&#39;, or &#39;modified_map&#39;.
+        thres : float, optional
+            The threshold used for filtering edges \(e_{ij}\) that \((n_{i}+n_{j}+e_{ij})/N&lt;thres\), only applied to mean method.
+        no_loop : boolean, optional
+            Whether loops are allowed to exist in the graph. If no_loop is true, the graph will be pruned to contain only the
+            maximum spanning tree.
+        cutoff : float, optional
+            The score threshold for filtering edges with scores less than cutoff.
+        visualize : boolean
+            Whether to plot the current trajectory backbone (undirected graph).
+
+        Returns
+        ----------
+        None. The weighted backbone graph, with the weight on each edge indicating its
+        score of existence, is stored as `self.backbone`.
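+
+        Examples
+        --------
+        A minimal sketch, run after `posterior_estimation`:
+
+        &gt;&gt;&gt; model.infer_backbone(method=&#39;modified_map&#39;, no_loop=True, cutoff=0)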
+        &#39;&#39;&#39;
+        # build_graph, return graph
+        self.backbone = self.inferer.build_graphs(self.cell_position_posterior, self.pc_x,
+                method, thres, no_loop, cutoff)
+        self.cell_position_projected = self.inferer.modify_wtilde(self.cell_position_posterior, 
+                np.array(list(self.backbone.edges)))
+        
+        uni_cluster_labels = self.labels_map[&#39;label_names&#39;].to_numpy()
+        temp_dict = {i:label for i,label in enumerate(uni_cluster_labels)}
+        nx.relabel_nodes(self.backbone, temp_dict)
+       
+        self.adata.obs[&#39;vitae_new_clustering&#39;] = uni_cluster_labels[np.argmax(self.cell_position_projected, 1)]
+        self.adata.obs[&#39;vitae_new_clustering&#39;] = self.adata.obs[&#39;vitae_new_clustering&#39;].astype(&#39;category&#39;)
+        print(&#34;&#39;vitae_new_clustering&#39; updated based on the projected cell positions.&#34;)
+
+        self.uncertainty = np.sum((self.cell_position_projected - self.cell_position_posterior)**2, axis=-1) \
+            + np.sum(self.cell_position_variance, axis=-1)
+        self.adata.obs[&#39;projection_uncertainty&#39;] = self.uncertainty
+        print(&#34;Cell projection uncertainties stored as &#39;projection_uncertainty&#39; in self.adata.obs&#34;)
+        if visualize:
+            self._adata.obs = self.adata.obs.copy()
+            self.ax = self.plot_backbone(directed = False,color = color, **kwargs)
+            if path_to_fig is not None:
+                self.ax.figure.savefig(path_to_fig)
+            self.ax.figure.show()
+        return None
+
+
+    def select_root(self, days, method: str = &#39;proportion&#39;):
+        &#39;&#39;&#39;Order the vertices/states based on cells&#39; collection time information to select the root state.      
+
+        Parameters
+        ----------
+        days : np.array 
+            The day information for selected cells used to determine the root vertex.
+            The dtype should be &#39;int&#39; or &#39;float&#39;.
+        method : str, optional
+            &#39;proportion&#39; or &#39;mean&#39;. 
+            For &#39;proportion&#39;, the root is the one with maximal proportion of cells from the earliest day.
+            For &#39;mean&#39;, the root is the one with earliest mean time among cells associated with it.
+
+        Returns
+        ----------
+        root_info : pd.DataFrame
+            The cluster labels with their mean collection times and proportions of cells
+            from the earliest day, sorted by mean collection time.
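+
+        Examples
+        --------
+        A minimal sketch; the &#39;day&#39; column in adata.obs is illustrative:
+
+        &gt;&gt;&gt; root_info = model.select_root(model.adata.obs[&#39;day&#39;].to_numpy(), method=&#39;proportion&#39;)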
+        &#39;&#39;&#39;
+        if days is not None and len(days)!=self.X_input.shape[0]:
+            raise ValueError(&#34;The length of day information ({}) is not &#34;
+                &#34;consistent with the number of selected cells ({})!&#34;.format(
+                    len(days), self.X_input.shape[0]))
+        if not hasattr(self, &#39;cell_position_projected&#39;):
+            raise ValueError(&#34;Need to call &#39;infer_backbone&#39; first!&#34;)
+
+        collection_time = np.dot(days, self.cell_position_projected)/np.sum(self.cell_position_projected, axis = 0)
+        earliest_prop = np.dot(days==np.min(days), self.cell_position_projected)/np.sum(self.cell_position_projected, axis = 0)
+        
+        root_info = self.labels_map.copy()
+        root_info[&#39;mean_collection_time&#39;] = collection_time
+        root_info[&#39;earliest_time_prop&#39;] = earliest_prop
+        root_info.sort_values(&#39;mean_collection_time&#39;, inplace=True)
+        return root_info
+
+
+    def plot_backbone(self, directed: bool = False, 
+                      method: str = &#39;UMAP&#39;, color = &#39;vitae_new_clustering&#39;, **kwargs):
+        &#39;&#39;&#39;Plot the current trajectory backbone (undirected graph).
+
+        Parameters
+        ----------
+        directed : boolean, optional
+            Whether the backbone is directed or not.
+        method : str, optional
+            The dimension reduction method to use. The default is &#34;UMAP&#34;.
+        color : str, optional
+            The key for annotations of observations/cells or variables/genes, e.g., &#39;ann1&#39; or [&#39;ann1&#39;, &#39;ann2&#39;].
+            The default is &#39;vitae_new_clustering&#39;.
+        **kwargs :
+            Extra key-value arguments that can be passed to scanpy plotting functions (scanpy.pl.XX).
+        &#39;&#39;&#39;
+        if not isinstance(color,str):
+            raise ValueError(&#39;The color argument should be of type str!&#39;)
+        ax = self.visualize_latent(method = method, color=color, show=False, **kwargs)
+        dict_label_num = {j:i for i,j in self.labels_map[&#39;label_names&#39;].to_dict().items()}
+        uni_cluster_labels = self.adata.obs[&#39;vitae_init_clustering&#39;].cat.categories
+        cluster_labels = self.adata.obs[&#39;vitae_new_clustering&#39;].to_numpy()
+        embed_z = self._adata.obsm[self.dict_method_scname[method]]
+        embed_mu = np.zeros((len(uni_cluster_labels), 2))
+        for l in uni_cluster_labels:
+            embed_mu[dict_label_num[l],:] = np.mean(embed_z[cluster_labels==l], axis=0)
+
+        if directed:
+            graph = self.directed_backbone
+        else:
+            graph = self.backbone
+        edges = list(graph.edges)
+        edge_scores = np.array([d[&#39;weight&#39;] for (u,v,d) in graph.edges(data=True)])
+        if max(edge_scores) - min(edge_scores) == 0:
+            edge_scores = edge_scores/max(edge_scores)
+        else:
+            edge_scores = (edge_scores - min(edge_scores))/(max(edge_scores) - min(edge_scores))*3
+
+        value_range = np.maximum(np.diff(ax.get_xlim())[0], np.diff(ax.get_ylim())[0])
+        y_range = np.min(embed_z[:,1]), np.max(embed_z[:,1], axis=0)
+        for i in range(len(edges)):
+            points = embed_z[np.sum(self.cell_position_projected[:, edges[i]]&gt;0, axis=-1)==2,:]
+            points = points[points[:,0].argsort()]
+            try:
+                x_smooth, y_smooth = _get_smooth_curve(
+                    points,
+                    embed_mu[edges[i], :],
+                    y_range
+                    )
+            except Exception:
+                x_smooth, y_smooth = embed_mu[edges[i], 0], embed_mu[edges[i], 1]
+            ax.plot(x_smooth, y_smooth,
+                &#39;-&#39;,
+                linewidth= 1 + edge_scores[i],
+                color=&#34;black&#34;,
+                alpha=0.8,
+                path_effects=[pe.Stroke(linewidth=1+edge_scores[i]+1.5,
+                                        foreground=&#39;white&#39;), pe.Normal()],
+                zorder=1
+                )
+
+            if directed:
+                delta_x = embed_mu[edges[i][1], 0] - x_smooth[-2]
+                delta_y = embed_mu[edges[i][1], 1] - y_smooth[-2]
+                length = np.sqrt(delta_x**2 + delta_y**2) / 50 * value_range
+                ax.arrow(
+                        embed_mu[edges[i][1], 0]-delta_x/length,
+                        embed_mu[edges[i][1], 1]-delta_y/length,
+                        delta_x/length,
+                        delta_y/length,
+                        color=&#39;black&#39;, alpha=1.0,
+                        shape=&#39;full&#39;, lw=0, length_includes_head=True,
+                        head_width=np.maximum(0.01*(1 + edge_scores[i]), 0.03) * value_range,
+                        zorder=2) 
+        
+        colors = self._adata.uns[&#39;vitae_new_clustering_colors&#39;]
+            
+        for i,l in enumerate(uni_cluster_labels):
+            ax.scatter(*embed_mu[dict_label_num[l]:dict_label_num[l]+1,:].T, 
+                       c=[colors[i]], edgecolors=&#39;white&#39;, # linewidths=10,  norm=norm,
+                       s=250, marker=&#39;*&#39;, label=l)
+
+        plt.setp(ax, xticks=[], yticks=[])
+        box = ax.get_position()
+        ax.set_position([box.x0, box.y0 + box.height * 0.1,
+                            box.width, box.height * 0.9])
+        if directed:
+            ax.legend(loc=&#39;upper center&#39;, bbox_to_anchor=(0.5, -0.05),
+                fancybox=True, shadow=True, ncol=5)
+
+        return ax
+
+
+    def plot_center(self, color = &#34;vitae_new_clustering&#34;, plot_legend = True, legend_add_index = True,
+                    method: str = &#39;UMAP&#39;,ncol = 2,font_size = &#34;medium&#34;,
+                    add_egde = False, add_direct = False,**kwargs):
+        &#39;&#39;&#39;Plot the center of each cluster in the latent space.
+
+        Parameters
+        ----------
+        color : str, optional
+            The clustering key to color by; must be &#34;vitae_new_clustering&#34; or &#34;vitae_init_clustering&#34;. Default is &#34;vitae_new_clustering&#34;.
+        plot_legend : bool, optional
+            Whether to plot the legend. Default is True.
+        legend_add_index : bool, optional
+            Whether to add the index of each cluster in the legend. Default is True.
+        method : str, optional
+            The dimension reduction method used for visualization. Default is &#39;UMAP&#39;.
+        ncol : int, optional
+            The number of columns in the legend. Default is 2.
+        font_size : str, optional
+            The font size of the legend. Default is &#34;medium&#34;.
+        add_egde : bool, optional
+            Whether to add the edges between the centers of clusters. Default is False.
+        add_direct : bool, optional
+            Whether to add the direction of the edges. Default is False.
+        &#39;&#39;&#39;
+        if color not in [&#34;vitae_new_clustering&#34;,&#34;vitae_init_clustering&#34;]:
+            raise ValueError(&#34;Can only plot center of vitae_new_clustering or vitae_init_clustering&#34;)
+        dict_label_num = {j: i for i, j in self.labels_map[&#39;label_names&#39;].to_dict().items()}
+        if legend_add_index:
+            self._adata.obs[&#34;index_&#34;+color] = self._adata.obs[color].map(lambda x: dict_label_num[x])
+            ax = self.visualize_latent(method=method, color=&#34;index_&#34; + color, show=False, legend_loc=&#34;on data&#34;,
+                                        legend_fontsize=font_size,**kwargs)
+            colors = self._adata.uns[&#34;index_&#34; + color + &#39;_colors&#39;]
+        else:
+            ax = self.visualize_latent(method=method, color = color, show=False,**kwargs)
+            colors = self._adata.uns[color + &#39;_colors&#39;]
+        uni_cluster_labels = self.adata.obs[color].cat.categories
+        cluster_labels = self.adata.obs[color].to_numpy()
+        embed_z = self._adata.obsm[self.dict_method_scname[method]]
+        embed_mu = np.zeros((len(uni_cluster_labels), 2))
+        for l in uni_cluster_labels:
+            embed_mu[dict_label_num[l], :] = np.mean(embed_z[cluster_labels == l], axis=0)
+
+        leg = (self.labels_map.index.astype(str) + &#34; : &#34; + self.labels_map.label_names).values
+        for i, l in enumerate(uni_cluster_labels):
+            ax.scatter(*embed_mu[dict_label_num[l]:dict_label_num[l] + 1, :].T,
+                       c=[colors[i]], edgecolors=&#39;white&#39;, # linewidths=3,
+                       s=250, marker=&#39;*&#39;, label=leg[i])
+        if plot_legend:
+            ax.legend(loc=&#39;center left&#39;, bbox_to_anchor=(1, 0.5), ncol=ncol, markerscale=0.8, frameon=False)
+        plt.setp(ax, xticks=[], yticks=[])
+        box = ax.get_position()
+        ax.set_position([box.x0, box.y0 + box.height * 0.1,
+                         box.width, box.height * 0.9])
+        if add_egde:
+            if add_direct:
+                graph = self.directed_backbone
+            else:
+                graph = self.backbone
+            edges = list(graph.edges)
+            edge_scores = np.array([d[&#39;weight&#39;] for (u, v, d) in graph.edges(data=True)])
+            if max(edge_scores) - min(edge_scores) == 0:
+                edge_scores = edge_scores / max(edge_scores)
+            else:
+                edge_scores = (edge_scores - min(edge_scores)) / (max(edge_scores) - min(edge_scores)) * 3
+
+            value_range = np.maximum(np.diff(ax.get_xlim())[0], np.diff(ax.get_ylim())[0])
+            y_range = np.min(embed_z[:, 1]), np.max(embed_z[:, 1], axis=0)
+            for i in range(len(edges)):
+                points = embed_z[np.sum(self.cell_position_projected[:, edges[i]] &gt; 0, axis=-1) == 2, :]
+                points = points[points[:, 0].argsort()]
+                try:
+                    x_smooth, y_smooth = _get_smooth_curve(
+                        points,
+                        embed_mu[edges[i], :],
+                        y_range
+                    )
+                except Exception:
+                    x_smooth, y_smooth = embed_mu[edges[i], 0], embed_mu[edges[i], 1]
+                ax.plot(x_smooth, y_smooth,
+                        &#39;-&#39;,
+                        linewidth=1 + edge_scores[i],
+                        color=&#34;black&#34;,
+                        alpha=0.8,
+                        path_effects=[pe.Stroke(linewidth=1 + edge_scores[i] + 1.5,
+                                                foreground=&#39;white&#39;), pe.Normal()],
+                        zorder=1
+                        )
+
+                if add_direct:
+                    delta_x = embed_mu[edges[i][1], 0] - x_smooth[-2]
+                    delta_y = embed_mu[edges[i][1], 1] - y_smooth[-2]
+                    length = np.sqrt(delta_x ** 2 + delta_y ** 2) / 50 * value_range
+                    ax.arrow(
+                        embed_mu[edges[i][1], 0] - delta_x / length,
+                        embed_mu[edges[i][1], 1] - delta_y / length,
+                        delta_x / length,
+                        delta_y / length,
+                        color=&#39;black&#39;, alpha=1.0,
+                        shape=&#39;full&#39;, lw=0, length_includes_head=True,
+                        head_width=np.maximum(0.01 * (1 + edge_scores[i]), 0.03) * value_range,
+                        zorder=2)
+        self.ax = ax
+        self.ax.figure.show()
+        return None
+
+
+    def infer_trajectory(self, root: Union[int,str], digraph = None, color = &#34;pseudotime&#34;,
+                         visualize: bool = True, path_to_fig = None,  **kwargs):
+        &#39;&#39;&#39;Infer the trajectory.
+
+        Parameters
+        ----------
+        root : int or string
+            The root of the inferred trajectory. Can be either an int (vertex index) or a string (label name).
+        digraph : nx.DiGraph, optional
+            The directed graph to be used for trajectory inference. If None, the minimum spanning tree of the estimated trajectory backbone will be used.
+        color : string, optional
+            The key for annotations used to color the plot. Default is &#34;pseudotime&#34;.
+        visualize : boolean
+            Whether to plot the current trajectory backbone (directed graph).
+        path_to_fig : string, optional
+            The path to save the figure; the figure is not saved if it is None.
+        **kwargs : dict, optional
+            Other keyword arguments for plotting.
+        &#39;&#39;&#39;
+        if isinstance(root,str):
+            if root not in self.labels_map.values:
+                raise ValueError(&#34;Root {} is not in the label names!&#34;.format(root))
+            root = self.labels_map[self.labels_map[&#39;label_names&#39;]==root].index[0]
+
+        if digraph is None:
+            connected_comps = nx.node_connected_component(self.backbone, root)
+            subG = self.backbone.subgraph(connected_comps)
+
+            ## generate directed backbone which contains no loops
+            DG = nx.DiGraph(nx.to_directed(self.backbone))
+            temp = DG.subgraph(connected_comps)
+            DG.remove_edges_from(temp.edges - nx.dfs_edges(DG, root))
+            self.directed_backbone = DG
+        else:
+            if not nx.is_directed_acyclic_graph(digraph):
+                raise ValueError(&#34;The graph &#39;digraph&#39; should be a directed acyclic graph.&#34;)
+            if set(digraph.nodes) != set(self.backbone.nodes):
+                raise ValueError(&#34;The nodes in &#39;digraph&#39; do not match the nodes in &#39;self.backbone&#39;.&#34;)
+            self.directed_backbone = digraph
+
+            connected_comps = nx.node_connected_component(digraph, root)
+            subG = self.backbone.subgraph(connected_comps)
+
+
+        if len(subG.edges)&gt;0:
+            milestone_net = self.inferer.build_milestone_net(subG, root)
+            if self.inferer.no_loop is False and milestone_net.shape[0]&lt;len(self.backbone.edges):
+                warnings.warn(&#34;The directed graph shown is a minimum spanning tree of the estimated trajectory backbone to avoid arbitrary assignment of the directions.&#34;)
+            self.pseudotime = self.inferer.comp_pseudotime(milestone_net, root, self.cell_position_projected)
+        else:
+            warnings.warn(&#34;There are no connected states when starting from the given root.&#34;)
+            self.pseudotime = -np.ones(self._adata.shape[0])
+
+        self.adata.obs[&#39;pseudotime&#39;] = self.pseudotime
+        print(&#34;Trajectory pseudotime stored as &#39;pseudotime&#39; in self.adata.obs&#34;)
+
+        if visualize:
+            self._adata.obs[&#39;pseudotime&#39;] = self.pseudotime
+            self.ax = self.plot_backbone(directed = True, color = color, **kwargs)
+            if path_to_fig is not None:
+                self.ax.figure.savefig(path_to_fig)
+            self.ax.figure.show()
+
+        return None
+
+
+
+    def differential_expression_test(self, alpha: float = 0.05, cell_subset = None, order: int = 1):
+        &#39;&#39;&#39;Differential gene expression test. All (selected and unselected) genes will be tested.
+        Only cells in `cell_subset` will be used, which is useful when one needs to
+        test differentially expressed genes on a branch of the inferred trajectory.
+
+        Parameters
+        ----------
+        alpha : float, optional
+            The cutoff of p-values.
+        cell_subset : np.array, optional
+            The subset of cells to be used for testing. If None, all cells will be used.
+        order : int, optional
+            The maximum order of pseudotime used in the regression.
+
+        Returns
+        ----------
+        res_df : pandas.DataFrame
+            The test results of expressed genes with two columns,
+            the estimated coefficients and the adjusted p-values.
+        &#39;&#39;&#39;
+        if not hasattr(self, &#39;pseudotime&#39;):
+            raise ReferenceError(&#34;Pseudotime does not exist! Please run &#39;infer_trajectory&#39; first.&#34;)
+        if cell_subset is None:
+            cell_subset = np.arange(self.X_input.shape[0])
+            print(&#34;All cells are selected.&#34;)
+        if order &lt; 1:
+            raise ValueError(&#34;The maximum order of pseudotime in regression must be at least 1.&#34;)
+
+        # Prepare X and Y for regression expression ~ rank(PDT) + covariates
+        Y = self.adata.X[cell_subset,:]
+#        std_Y = np.std(Y, ddof=1, axis=0, keepdims=True)
+#        Y = np.divide(Y-np.mean(Y, axis=0, keepdims=True), std_Y, out=np.empty_like(Y)*np.nan, where=std_Y!=0)
+        X = stats.rankdata(self.pseudotime[cell_subset])
+        if order &gt; 1:
+            base = X.copy()
+            for _order in range(2, order+1):
+                # higher-order terms are powers of the ranked pseudotime itself
+                X = np.c_[X, base**_order]
+        X = ((X-np.mean(X,axis=0, keepdims=True))/np.std(X, ddof=1, axis=0, keepdims=True))
+        X = np.c_[np.ones((X.shape[0],1)), X]
+        if self.covariates is not None:
+            X = np.c_[X, self.covariates[cell_subset, :]]
+
+        res_df = DE_test(Y, X, self.adata.var_names, i_test = np.array(list(range(1,order+1))), alpha = alpha)
+        return res_df[res_df.pvalue_adjusted_1 != 0]
+
+
+ 
+
+    def evaluate(self, milestone_net, begin_node_true, grouping = None,
+                thres: float = 0.5, no_loop: bool = True, cutoff: Optional[float] = None,
+                method: str = &#39;mean&#39;, path: Optional[str] = None):
+        &#39;&#39;&#39;Evaluate the model.
+
+        Parameters
+        ----------
+        milestone_net : pd.DataFrame
+            The true milestone network. For real data, milestone_net will be a DataFrame of the graph of nodes.
+            E.g.
+
+            from|to
+            ---|---
+            cluster 1 | cluster 1
+            cluster 1 | cluster 2
+
+            For synthetic data, milestone_net will be a DataFrame of the (projected)
+            positions of cells. The indexes are the orders of cells in the dataset.
+            E.g.
+
+            from|to|w
+            ---|---|---
+            cluster 1 | cluster 1 | 1
+            cluster 1 | cluster 2 | 0.1
+        begin_node_true : str or int
+            The true begin node of the milestone.
+        grouping : np.array, optional
+            \([N,]\) The labels. For real data, grouping must be provided.
+
+        Returns
+        ----------
+        res : pd.DataFrame
+            The evaluation result.
+        &#39;&#39;&#39;
+        if not hasattr(self, &#39;labels_map&#39;):
+            raise ValueError(&#34;No labels were given for training.&#34;)
+
+        &#39;&#39;&#39;
+        # Evaluate for the whole dataset will ignore selected_cell_subset.
+        if len(self.selected_cell_subset)!=len(self.cell_names):
+            warnings.warn(&#34;Evaluate for the whole dataset.&#34;)
+        &#39;&#39;&#39;
+        
+        # If begin_node_true is a string, encode it with the label mapping below.
+        # This dict is for milestone_net because its labels are not merged;
+        # all keys of label_map_dict are str.
+        label_map_dict = dict()
+        for i in range(self.labels_map.shape[0]):
+            label_mapped = self.labels_map.loc[i]
+            ## merged cluster indices are joined by commas
+            for each in label_mapped.values[0].split(&#34;,&#34;):
+                label_map_dict[each] = i
+        if isinstance(begin_node_true, str):
+            begin_node_true = label_map_dict[begin_node_true]
+            
+        # For generated data, grouping information is already in milestone_net
+        if &#39;w&#39; in milestone_net.columns:
+            grouping = None
+            
+        # If milestone_net is provided, transform them to be numeric.
+        if milestone_net is not None:
+            milestone_net[&#39;from&#39;] = [label_map_dict[x] for x in milestone_net[&#34;from&#34;]]
+            milestone_net[&#39;to&#39;] = [label_map_dict[x] for x in milestone_net[&#34;to&#34;]]
+
+        # this dict is for potentially merged clusters.
+        label_map_dict_for_merged_cluster = dict(zip(self.labels_map[&#34;label_names&#34;],self.labels_map.index))
+        mapped_labels = np.array([label_map_dict_for_merged_cluster[x] for x in self.labels])
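+        # Predict the begin node as the vertex whose center mu is closest, in mean
+        # squared distance, to the latent positions of cells with the true begin label.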
+        begin_node_pred = int(np.argmin(np.mean((
+            self.z[mapped_labels==begin_node_true,:,np.newaxis] -
+            self.mu[np.newaxis,:,:])**2, axis=(0,1))))
+
+        if cutoff is None:
+            cutoff = 0.01
+
+        G = self.backbone
+        w = self.cell_position_projected
+        pseudotime = self.pseudotime
+
+        # 1. Topology
+        G_pred = nx.Graph()
+        G_pred.add_nodes_from(G.nodes)
+        G_pred.add_edges_from(G.edges)
+        nx.set_node_attributes(G_pred, False, &#39;is_init&#39;)
+        G_pred.nodes[begin_node_pred][&#39;is_init&#39;] = True
+
+        G_true = nx.Graph()
+        G_true.add_nodes_from(G.nodes)
+        # if &#39;grouping&#39; is not provided, assume &#39;milestone_net&#39; contains proportions
+        if grouping is None:
+            G_true.add_edges_from(list(
+                milestone_net[~pd.isna(milestone_net[&#39;w&#39;])].groupby([&#39;from&#39;, &#39;to&#39;]).count().index))
+        # otherwise, &#39;milestone_net&#39; indicates edges
+        else:
+            if milestone_net is not None:             
+                G_true.add_edges_from(list(
+                    milestone_net.groupby([&#39;from&#39;, &#39;to&#39;]).count().index))
+            grouping = [label_map_dict[x] for x in grouping]
+            grouping = np.array(grouping)
+        G_true.remove_edges_from(nx.selfloop_edges(G_true))
+        nx.set_node_attributes(G_true, False, &#39;is_init&#39;)
+        G_true.nodes[begin_node_true][&#39;is_init&#39;] = True
+        res = topology(G_true, G_pred)
+            
+        # 2. Milestones assignment
+        if grouping is None:
+            milestones_true = milestone_net[&#39;from&#39;].values.copy()
+            milestones_true[(milestone_net[&#39;from&#39;]!=milestone_net[&#39;to&#39;])
+                           &amp;(milestone_net[&#39;w&#39;]&lt;0.5)] = milestone_net[(milestone_net[&#39;from&#39;]!=milestone_net[&#39;to&#39;])
+                                                                      &amp;(milestone_net[&#39;w&#39;]&lt;0.5)][&#39;to&#39;].values
+        else:
+            milestones_true = grouping
+        milestones_pred = np.argmax(w, axis=1)
+        res[&#39;ARI&#39;] = (adjusted_rand_score(milestones_true, milestones_pred) + 1)/2
+        
+        if grouping is None:
+            n_samples = len(milestone_net)
+            prop = np.zeros((n_samples,n_samples))
+            prop[np.arange(n_samples), milestone_net[&#39;to&#39;]] = 1-milestone_net[&#39;w&#39;]
+            prop[np.arange(n_samples), milestone_net[&#39;from&#39;]] = np.where(np.isnan(milestone_net[&#39;w&#39;]), 1, milestone_net[&#39;w&#39;])
+            res[&#39;GRI&#39;] = get_GRI(prop, w)
+        else:
+            res[&#39;GRI&#39;] = get_GRI(grouping, w)
+        
+        # 3. Correlation between geodesic distances / Pseudotime
+        if no_loop:
+            if grouping is None:
+                pseudotime_true = milestone_net[&#39;from&#39;].values + 1 - milestone_net[&#39;w&#39;].values
+                pseudotime_true[np.isnan(pseudotime_true)] = milestone_net[pd.isna(milestone_net[&#39;w&#39;])][&#39;from&#39;].values            
+            else:
+                pseudotime_true = - np.ones(len(grouping))
+                nx.set_edge_attributes(G_true, values = 1, name = &#39;weight&#39;)
+                connected_comps = nx.node_connected_component(G_true, begin_node_true)
+                subG = G_true.subgraph(connected_comps)
+                milestone_net_true = self.inferer.build_milestone_net(subG, begin_node_true)
+                if len(milestone_net_true)&gt;0:
+                    pseudotime_true[grouping==int(milestone_net_true[0,0])] = 0
+                    for i in range(len(milestone_net_true)):
+                        pseudotime_true[grouping==int(milestone_net_true[i,1])] = milestone_net_true[i,-1]
+            pseudotime_true = pseudotime_true[pseudotime&gt;-1]
+            pseudotime_pred = pseudotime[pseudotime&gt;-1]
+            res[&#39;PDT score&#39;] = (np.corrcoef(pseudotime_true,pseudotime_pred)[0,1]+1)/2
+        else:
+            res[&#39;PDT score&#39;] = np.nan
+            
+        # 4. Shape
+        # score_cos_theta = 0
+        # for (_from,_to) in G.edges:
+        #     _z = self.z[(w[:,_from]&gt;0) &amp; (w[:,_to]&gt;0),:]
+        #     v_1 = _z - self.mu[:,_from]
+        #     v_2 = _z - self.mu[:,_to]
+        #     cos_theta = np.sum(v_1*v_2, -1)/(np.linalg.norm(v_1,axis=-1)*np.linalg.norm(v_2,axis=-1)+1e-12)
+
+        #     score_cos_theta += np.sum((1-cos_theta)/2)
+
+        # res[&#39;score_cos_theta&#39;] = score_cos_theta/(np.sum(np.sum(w&gt;0, axis=-1)==2)+1e-12)
+        return res
+
+
+    def save_model(self, path_to_file: str = &#39;model.checkpoint&#39;,save_adata: bool = False):
+        &#39;&#39;&#39;Save model weights.
+
+        Parameters
+        ----------
+        path_to_file : str, optional
+            The path to the weight files of the pre-trained or trained model.
+        save_adata : boolean, optional
+            Whether to save adata or not.
+        &#39;&#39;&#39;
+        self.vae.save_weights(path_to_file)
+        if hasattr(self, &#39;labels&#39;) and self.labels is not None:
+            with open(path_to_file + &#39;.label&#39;, &#39;wb&#39;) as f:
+                np.save(f, self.labels)
+        with open(path_to_file + &#39;.config&#39;, &#39;wb&#39;) as f:
+            self.dim_origin = self.X_input.shape[1]
+            np.save(f, np.array([
+                self.dim_origin, self.dimensions, self.dim_latent,
+                self.model_type, 0 if self.covariates is None else self.covariates.shape[1]], dtype=object))
+        if hasattr(self, &#39;inferer&#39;) and hasattr(self, &#39;uncertainty&#39;):
+            with open(path_to_file + &#39;.inference&#39;, &#39;wb&#39;) as f:
+                np.save(f, np.array([
+                    self.pi, self.mu, self.pc_x, self.cell_position_posterior, self.uncertainty,
+                    self.z,self.cell_position_variance], dtype=object))
+        if save_adata:
+            self.adata.write(path_to_file + &#39;.adata.h5ad&#39;)
+
+
+    def load_model(self, path_to_file: str = &#39;model.checkpoint&#39;, load_labels: bool = False, load_adata: bool = False):
+        &#39;&#39;&#39;Load model weights.
+
+        Parameters
+        ----------
+        path_to_file : str, optional
+            The path to the weight files of the pre-trained or trained model.
+        load_labels : boolean, optional
+            Whether to load clustering labels or not.
+            If load_labels is True, then the LatentSpace layer will be initialized based on the model.
+            If load_labels is False, then the LatentSpace layer will not be initialized.
+        load_adata : boolean, optional
+            Whether to load adata or not.
+        &#39;&#39;&#39;
+        if not os.path.exists(path_to_file + &#39;.config&#39;):
+            raise AssertionError(&#39;Config file does not exist!&#39;)
+        if load_labels and not os.path.exists(path_to_file + &#39;.label&#39;):
+            raise AssertionError(&#39;Label file does not exist!&#39;)
+
+        with open(path_to_file + &#39;.config&#39;, &#39;rb&#39;) as f:
+            [self.dim_origin, self.dimensions,
+             self.dim_latent, self.model_type, cov_dim] = np.load(f, allow_pickle=True)
+        self.vae = model.VariationalAutoEncoder(
+            self.dim_origin, self.dimensions,
+            self.dim_latent, self.model_type, False if cov_dim == 0 else True
+        )
+
+        if load_labels:
+            with open(path_to_file + &#39;.label&#39;, &#39;rb&#39;) as f:
+                cluster_labels = np.load(f, allow_pickle=True)
+            self.init_latent_space(cluster_labels, dist_thres=0)
+            if os.path.exists(path_to_file + &#39;.inference&#39;):
+                with open(path_to_file + &#39;.inference&#39;, &#39;rb&#39;) as f:
+                    arr = np.load(f, allow_pickle=True)
+                    if len(arr) == 8:
+                        [self.pi, self.mu, self.pc_x, self.cell_position_posterior, self.uncertainty,
+                         self.D_JS, self.z,self.cell_position_variance] = arr
+                    else:
+                        [self.pi, self.mu, self.pc_x, self.cell_position_posterior, self.uncertainty,
+                         self.z,self.cell_position_variance] = arr
+                self._adata_z = sc.AnnData(self.z)
+                sc.pp.neighbors(self._adata_z)
+        ## initialize the weight of encoder and decoder
+        self.vae.encoder(np.zeros((1, self.dim_origin + cov_dim)))
+        self.vae.decoder(np.expand_dims(np.zeros((1,self.dim_latent + cov_dim)),1))
+
+        self.vae.load_weights(path_to_file)
+        self.update_z()
+        if load_adata:
+            if not os.path.exists(path_to_file + &#39;.adata.h5ad&#39;):
+                raise AssertionError(&#39;AnnData file does not exist!&#39;)
+            self.adata = sc.read_h5ad(path_to_file + &#39;.adata.h5ad&#39;)
+            self._adata.obs = self.adata.obs.copy()</code></pre>
+</details>
+<h3>Methods</h3>
+<dl>
+<dt id="VITAE.VITAE.pre_train"><code class="name flex">
+<span>def <span class="ident">pre_train</span></span>(<span>self, test_size=0.1, random_state: int = 0, learning_rate: float = 0.001, batch_size: int = 256, L: int = 1, alpha: float = 0.1, gamma: float = 0, phi: float = 1, num_epoch: int = 200, num_step_per_epoch: Optional[int] = None, early_stopping_patience: int = 10, early_stopping_tolerance: float = 0.01, early_stopping_relative: bool = True, verbose: bool = False, path_to_weights: Optional[str] = None)</span>
+</code></dt>
+<dd>
+<div class="desc"><p>Pretrain the model with specified learning rate.</p>
+<h2 id="parameters">Parameters</h2>
+<dl>
+<dt><strong><code>test_size</code></strong> :&ensp;<code>float</code> or <code>int</code>, optional</dt>
+<dd>The proportion or size of the test set.</dd>
+<dt><strong><code>random_state</code></strong> :&ensp;<code>int</code>, optional</dt>
+<dd>The random state for data splitting.</dd>
+<dt><strong><code>learning_rate</code></strong> :&ensp;<code>float</code>, optional</dt>
+<dd>The initial learning rate for the Adam optimizer.</dd>
+<dt><strong><code>batch_size</code></strong> :&ensp;<code>int</code>, optional</dt>
+<dd>The batch size for pre-training.
+Default is 256; consider 32 if the number of cells is small (less than 1000).</dd>
+<dt><strong><code>L</code></strong> :&ensp;<code>int</code>, optional</dt>
+<dd>The number of MC samples.</dd>
+<dt><strong><code>alpha</code></strong> :&ensp;<code>float</code>, optional</dt>
+<dd>The value of alpha in [0,1] to encourage covariate adjustment. Not used if there are no covariates.</dd>
+<dt><strong><code>gamma</code></strong> :&ensp;<code>float</code>, optional</dt>
+<dd>The weight of the mmd loss if used.</dd>
+<dt><strong><code>phi</code></strong> :&ensp;<code>float</code>, optional</dt>
+<dd>The weight of the Jacobian norm of the encoder.</dd>
+<dt><strong><code>num_epoch</code></strong> :&ensp;<code>int</code>, optional</dt>
+<dd>The maximum number of epochs.</dd>
+<dt><strong><code>num_step_per_epoch</code></strong> :&ensp;<code>int</code>, optional</dt>
+<dd>The number of steps per epoch; it will be inferred from the number of cells and batch size if None.</dd>
+<dt><strong><code>early_stopping_patience</code></strong> :&ensp;<code>int</code>, optional</dt>
+<dd>The maximum number of epochs if there is no improvement.</dd>
+<dt><strong><code>early_stopping_tolerance</code></strong> :&ensp;<code>float</code>, optional</dt>
+<dd>The minimum change of loss to be considered as an improvement.</dd>
+<dt><strong><code>early_stopping_relative</code></strong> :&ensp;<code>bool</code>, optional</dt>
+<dd>Whether to monitor the relative change of loss as the stopping criterion.</dd>
+<dt><strong><code>path_to_weights</code></strong> :&ensp;<code>str</code>, optional</dt>
+<dd>The path of weight file to be saved; not saving weight if None.</dd>
+</dl></div>
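+<p>A minimal usage sketch (assuming <code>model</code> is a <code>VITAE</code> instance that has already been set up with data; the checkpoint file name is hypothetical):</p>
+<pre><code class="language-python">model.pre_train(learning_rate=1e-3, batch_size=256, num_epoch=200,
+                early_stopping_patience=10,
+                path_to_weights='pretrained.checkpoint')</code></pre>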
+</dd>
+<dt id="VITAE.VITAE.update_z"><code class="name flex">
+<span>def <span class="ident">update_z</span></span>(<span>self)</span>
+</code></dt>
+<dd>
+<div class="desc"></div>
+</dd>
+<dt id="VITAE.VITAE.get_latent_z"><code class="name flex">
+<span>def <span class="ident">get_latent_z</span></span>(<span>self)</span>
+</code></dt>
+<dd>
+<div class="desc"><p>get the posterier mean of current latent space z (encoder output)</p>
+<h2 id="returns">Returns</h2>
+<dl>
+<dt><strong><code>z</code></strong> :&ensp;<code>np.array</code></dt>
+<dd><span><span class="MathJax_Preview">[N,d]</span><script type="math/tex">[N,d]</script></span> The latent means.</dd>
+</dl></div>
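+<p>For example (assuming <code>model</code> has been pre-trained or trained):</p>
+<pre><code class="language-python">z = model.get_latent_z()  # np.array of shape [N, d]
+print(z.shape)</code></pre>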
+</dd>
+<dt id="VITAE.VITAE.visualize_latent"><code class="name flex">
+<span>def <span class="ident">visualize_latent</span></span>(<span>self, method: str = 'UMAP', color=None, **kwargs)</span>
+</code></dt>
+<dd>
+<div class="desc"><p>visualize the current latent space z using the scanpy visualization tools</p>
+<h2 id="parameters">Parameters</h2>
+<dl>
+<dt><strong><code>method</code></strong> :&ensp;<code>str</code>, optional</dt>
+<dd>Visualization method to use. The default is "UMAP". Possible choices include "PCA", "UMAP",
+"diffmap", "TSNE" and "draw_graph" (the FA plot).</dd>
+<dt><strong><code>color</code></strong> :&ensp;<code>str</code> or <code>list</code>, optional</dt>
+<dd>Keys for annotations of observations/cells or variables/genes, e.g., 'ann1' or ['ann1', 'ann2'].
+The default is None. Same as scanpy.</dd>
+<dt><strong><code>**kwargs</code></strong> :&ensp;<code> </code></dt>
+<dd>Extra key-value arguments that can be passed to scanpy plotting functions (scanpy.pl.XX).</dd>
+</dl>
+<h2 id="returns">Returns</h2>
+<p>None.</p></div>
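+<p>For example, to color a UMAP of the latent space by an annotation (assuming the key below exists in <code>model.adata.obs</code>):</p>
+<pre><code class="language-python">model.visualize_latent(method='UMAP', color='vitae_new_clustering')</code></pre>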
+</dd>
+<dt id="VITAE.VITAE.init_latent_space"><code class="name flex">
+<span>def <span class="ident">init_latent_space</span></span>(<span>self, cluster_label=None, log_pi=None, res: float = 1.0, ratio_prune=None, dist=None, dist_thres=0.5, topk=0, pilayer=False)</span>
+</code></dt>
+<dd>
+<div class="desc"><p>Initialize the latent space.</p>
+<h2 id="parameters">Parameters</h2>
+<dl>
+<dt><strong><code>cluster_label</code></strong> :&ensp;<code>str</code>, optional</dt>
+<dd>The name of the vector of labels that can be found in self.adata.obs.
+Default is None, which will perform Leiden clustering on the pretrained z to get clusters.</dd>
+<dt><strong><code>log_pi</code></strong> :&ensp;<code>np.array</code>, optional</dt>
+<dd><span><span class="MathJax_Preview">[1,K]</span><script type="math/tex">[1,K]</script></span> The value of initial <span><span class="MathJax_Preview">\log(\pi)</span><script type="math/tex">\log(\pi)</script></span>.</dd>
+<dt><strong><code>res</code></strong> :&ensp;<code>float</code>, optional</dt>
+<dd>The resolution of Leiden clustering, which is a parameter value controlling the coarseness of the clustering.
+Higher values lead to more clusters. Default is 1.</dd>
+<dt><strong><code>ratio_prune</code></strong> :&ensp;<code>float</code>, optional</dt>
+<dd>The ratio of edges to be removed before estimating.</dd>
+<dt><strong><code>topk</code></strong> :&ensp;<code>int</code>, optional</dt>
+<dd>The number of top k neighbors to keep for each cluster.</dd>
+</dl></div>
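+<p>Two hedged usage sketches: one initializing from an existing annotation column (the column name 'cell_type' is hypothetical), and one letting Leiden clustering on the pretrained z define the states:</p>
+<pre><code class="language-python"># initialize states from a label column in model.adata.obs,
+# merging clusters whose distance is below dist_thres
+model.init_latent_space(cluster_label='cell_type', dist_thres=0.5)
+
+# or cluster the pretrained latent space with Leiden at resolution 1.0
+model.init_latent_space(res=1.0)</code></pre>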
+</dd>
+<dt id="VITAE.VITAE.update_latent_space"><code class="name flex">
+<span>def <span class="ident">update_latent_space</span></span>(<span>self, dist_thres: float = 0.5)</span>
+</code></dt>
+<dd>
+<div class="desc"></div>
+</dd>
+<dt id="VITAE.VITAE.train"><code class="name flex">
+<span>def <span class="ident">train</span></span>(<span>self, stratify=False, test_size=0.1, random_state: int = 0, learning_rate: float = 0.001, batch_size: int = 256, L: int = 1, alpha: float = 0.1, beta: float = 1, gamma: float = 0, phi: float = 1, num_epoch: int = 200, num_step_per_epoch: Optional[int] = None, early_stopping_patience: int = 10, early_stopping_tolerance: float = 0.01, early_stopping_relative: bool = True, early_stopping_warmup: int = 0, path_to_weights: Optional[str] = None, verbose: bool = False, **kwargs)</span>
+</code></dt>
+<dd>
+<div class="desc"><p>Train the model.</p>
+<h2 id="parameters">Parameters</h2>
+<dl>
+<dt><strong><code>stratify</code></strong> :&ensp;<code>np.array, None,</code> or <code>False</code></dt>
+<dd>If an array is provided, or <code>stratify=None</code> and <code>self.labels</code> is available, then they will be used to perform stratified shuffle splitting. Otherwise, general shuffle splitting is used. Set to <code>False</code> if <code>self.labels</code> is not intended for stratified shuffle splitting.</dd>
+<dt><strong><code>test_size</code></strong> :&ensp;<code>float</code> or <code>int</code>, optional</dt>
+<dd>The proportion or size of the test set.</dd>
+<dt><strong><code>random_state</code></strong> :&ensp;<code>int</code>, optional</dt>
+<dd>The random state for data splitting.</dd>
+<dt><strong><code>learning_rate</code></strong> :&ensp;<code>float</code>, optional</dt>
+<dd>The initial learning rate for the Adam optimizer.</dd>
+<dt><strong><code>batch_size</code></strong> :&ensp;<code>int</code>, optional</dt>
+<dd>The batch size for training. Default is 256; consider 32 if the number of cells is small (less than 1000).</dd>
+<dt><strong><code>L</code></strong> :&ensp;<code>int</code>, optional</dt>
+<dd>The number of MC samples.</dd>
+<dt><strong><code>alpha</code></strong> :&ensp;<code>float</code>, optional</dt>
+<dd>The value of alpha in [0,1] to encourage covariate adjustment. Not used if there are no covariates.</dd>
+<dt><strong><code>beta</code></strong> :&ensp;<code>float</code>, optional</dt>
+<dd>The value of beta in beta-VAE.</dd>
+<dt><strong><code>gamma</code></strong> :&ensp;<code>float</code>, optional</dt>
+<dd>The weight of the MMD loss if used.</dd>
+<dt><strong><code>phi</code></strong> :&ensp;<code>float</code>, optional</dt>
+<dd>The weight of the Jacobian norm of the encoder.</dd>
+<dt><strong><code>num_epoch</code></strong> :&ensp;<code>int</code>, optional</dt>
+<dd>The maximum number of epochs.</dd>
+<dt><strong><code>num_step_per_epoch</code></strong> :&ensp;<code>int</code>, optional</dt>
+<dd>The number of steps per epoch; it will be inferred from the number of cells and batch size if None.</dd>
+<dt><strong><code>early_stopping_patience</code></strong> :&ensp;<code>int</code>, optional</dt>
+<dd>The maximum number of epochs if there is no improvement.</dd>
+<dt><strong><code>early_stopping_tolerance</code></strong> :&ensp;<code>float</code>, optional</dt>
+<dd>The minimum change of loss to be considered as an improvement.</dd>
+<dt><strong><code>early_stopping_relative</code></strong> :&ensp;<code>bool</code>, optional</dt>
+<dd>Whether to monitor the relative change of loss or not.</dd>
+<dt><strong><code>early_stopping_warmup</code></strong> :&ensp;<code>int</code>, optional</dt>
+<dd>The number of warmup epochs.</dd>
+<dt><strong><code>path_to_weights</code></strong> :&ensp;<code>str</code>, optional</dt>
+<dd>The path of weight file to be saved; not saving weight if None.</dd>
+<dt><strong><code>**kwargs</code></strong> :&ensp;<code> </code></dt>
+<dd>Extra key-value arguments for dimension reduction algorithms.</dd>
+</dl></div>
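+<p>A typical call after <code>init_latent_space</code> (the values shown are the documented defaults except <code>verbose</code>):</p>
+<pre><code class="language-python">model.train(beta=1, learning_rate=1e-3, batch_size=256,
+            num_epoch=200, early_stopping_patience=10, verbose=True)</code></pre>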
+</dd>
+<dt id="VITAE.VITAE.output_pi"><code class="name flex">
+<span>def <span class="ident">output_pi</span></span>(<span>self, pi_cov)</span>
+</code></dt>
+<dd>
+<div class="desc"><p>return a matrix n_states by n_states and a mask for plotting, which can be used to cover the lower triangular(except the diagnoals) of a heatmap</p></div>
+</dd>
+<dt id="VITAE.VITAE.return_pilayer_weights"><code class="name flex">
+<span>def <span class="ident">return_pilayer_weights</span></span>(<span>self)</span>
+</code></dt>
+<dd>
+<div class="desc"><p>return parameters of pilayer, which has dimension dim(pi_cov) + 1 by n_categories, the last row is biases</p></div>
+</dd>
+<dt id="VITAE.VITAE.posterior_estimation"><code class="name flex">
+<span>def <span class="ident">posterior_estimation</span></span>(<span>self, batch_size: int = 32, L: int = 50, **kwargs)</span>
+</code></dt>
+<dd>
+<div class="desc"><p>Initialize trajectory inference by computing the posterior estimations.
+</p>
+<h2 id="parameters">Parameters</h2>
+<dl>
+<dt><strong><code>batch_size</code></strong> :&ensp;<code>int</code>, optional</dt>
+<dd>The batch size when doing inference.</dd>
+<dt><strong><code>L</code></strong> :&ensp;<code>int</code>, optional</dt>
+<dd>The number of MC samples when doing inference.</dd>
+<dt><strong><code>**kwargs</code></strong> :&ensp;<code> </code></dt>
+<dd>Extra key-value arguments for dimension reduction algorithms.</dd>
+</dl></div>
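+<p>For example, using the documented defaults:</p>
+<pre><code class="language-python">model.posterior_estimation(batch_size=32, L=50)</code></pre>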
+</dd>
+<dt id="VITAE.VITAE.infer_backbone"><code class="name flex">
+<span>def <span class="ident">infer_backbone</span></span>(<span>self, method: str = 'modified_map', thres=0.5, no_loop: bool = True, cutoff: float = 0, visualize: bool = True, color='vitae_new_clustering', path_to_fig=None, **kwargs)</span>
+</code></dt>
+<dd>
+<div class="desc"><p>Compute edge scores.</p>
+<h2 id="parameters">Parameters</h2>
+<dl>
+<dt><strong><code>method</code></strong> :&ensp;<code>string</code>, optional</dt>
+<dd>'mean', 'modified_mean', 'map', or 'modified_map'.</dd>
+<dt><strong><code>thres</code></strong> :&ensp;<code>float</code>, optional</dt>
+<dd>The threshold used for filtering edges <span><span class="MathJax_Preview">e_{ij}</span><script type="math/tex">e_{ij}</script></span> that <span><span class="MathJax_Preview">(n_{i}+n_{j}+e_{ij})/N&lt;thres</span><script type="math/tex">(n_{i}+n_{j}+e_{ij})/N<thres</script></span>, only applied to mean method.</dd>
+<dt><strong><code>no_loop</code></strong> :&ensp;<code>boolean</code>, optional</dt>
+<dd>Whether loops are allowed to exist in the graph. If no_loop is True, the graph will be pruned to contain only the
+maximum spanning tree.</dd>
+<dt><strong><code>cutoff</code></strong> :&ensp;<code>float</code>, optional</dt>
+<dd>The score threshold for filtering edges with scores less than cutoff.</dd>
+<dt><strong><code>visualize</code></strong> :&ensp;<code>boolean</code></dt>
+<dd>Whether to plot the current trajectory backbone (undirected graph).</dd>
+</dl>
+<h2 id="returns">Returns</h2>
+<dl>
+<dt><strong><code>G</code></strong> :&ensp;<code>nx.Graph</code></dt>
+<dd>The weighted graph with weight on each edge indicating its score of existence.</dd>
+</dl></div>
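+<p>A sketch of a typical call (the values shown are the documented defaults):</p>
+<pre><code class="language-python">G = model.infer_backbone(method='modified_map', no_loop=True, cutoff=0,
+                         visualize=True, color='vitae_new_clustering')</code></pre>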
+</dd>
+<dt id="VITAE.VITAE.select_root"><code class="name flex">
+<span>def <span class="ident">select_root</span></span>(<span>self, days, method: str = 'proportion')</span>
+</code></dt>
+<dd>
+<div class="desc"><p>Order the vertices/states based on cells' collection time information to select the root state.
+</p>
+<h2 id="parameters">Parameters</h2>
+<dl>
+<dt><strong><code>days</code></strong> :&ensp;<code>np.array</code></dt>
+<dd>The day information for selected cells used to determine the root vertex.
+The dtype should be 'int' or 'float'.</dd>
+<dt><strong><code>method</code></strong> :&ensp;<code>str</code>, optional</dt>
+<dd>'proportion' or 'mean'.
+For 'proportion', the root is the one with the maximal proportion of cells from the earliest day.
+For 'mean', the root is the one with the earliest mean time among cells associated with it.</dd>
+</dl>
+<h2 id="returns">Returns</h2>
+<dl>
+<dt><strong><code>root</code></strong> :&ensp;<code>int </code></dt>
+<dd>The root vertex in the inferred trajectory based on given day information.</dd>
+</dl></div>
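+<p>For example, with collection-time information stored in an obs column (the column name 'day' is hypothetical):</p>
+<pre><code class="language-python">import numpy as np
+
+days = model.adata.obs['day'].to_numpy(dtype=float)
+root = model.select_root(days, method='proportion')</code></pre>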
+</dd>
+<dt id="VITAE.VITAE.plot_backbone"><code class="name flex">
+<span>def <span class="ident">plot_backbone</span></span>(<span>self, directed: bool = False, method: str = 'UMAP', color='vitae_new_clustering', **kwargs)</span>
+</code></dt>
+<dd>
+<div class="desc"><p>Plot the current trajectory backbone (undirected graph).</p>
+<h2 id="parameters">Parameters</h2>
+<dl>
+<dt><strong><code>directed</code></strong> :&ensp;<code>boolean</code>, optional</dt>
+<dd>Whether the backbone is directed or not.</dd>
+<dt><strong><code>method</code></strong> :&ensp;<code>str</code>, optional</dt>
+<dd>The dimension reduction method to use. The default is "UMAP".</dd>
+<dt><strong><code>color</code></strong> :&ensp;<code>str</code>, optional</dt>
+<dd>The key for annotations of observations/cells or variables/genes, e.g., 'ann1' or ['ann1', 'ann2'].
+The default is 'vitae_new_clustering'.</dd>
+<dt><strong><code>**kwargs</code></strong> :&ensp;<code> </code></dt>
+<dd>Extra key-value arguments that can be passed to scanpy plotting functions (scanpy.pl.XX).</dd>
+</dl></div>
+</dd>
+<dt id="VITAE.VITAE.plot_center"><code class="name flex">
+<span>def <span class="ident">plot_center</span></span>(<span>self, color='vitae_new_clustering', plot_legend=True, legend_add_index=True, method: str = 'UMAP', ncol=2, font_size='medium', add_egde=False, add_direct=False, **kwargs)</span>
+</code></dt>
+<dd>
+<div class="desc"><p>Plot the center of each cluster in the latent space.</p>
+<h2 id="parameters">Parameters</h2>
+<dl>
+<dt><strong><code>color</code></strong> :&ensp;<code>str</code>, optional</dt>
+<dd>The key for annotations used to color cells and cluster centers. Default is "vitae_new_clustering".</dd>
+<dt><strong><code>plot_legend</code></strong> :&ensp;<code>bool</code>, optional</dt>
+<dd>Whether to plot the legend. Default is True.</dd>
+<dt><strong><code>legend_add_index</code></strong> :&ensp;<code>bool</code>, optional</dt>
+<dd>Whether to add the index of each cluster in the legend. Default is True.</dd>
+<dt><strong><code>method</code></strong> :&ensp;<code>str</code>, optional</dt>
+<dd>The dimension reduction method used for visualization. Default is 'UMAP'.</dd>
+<dt><strong><code>ncol</code></strong> :&ensp;<code>int</code>, optional</dt>
+<dd>The number of columns in the legend. Default is 2.</dd>
+<dt><strong><code>font_size</code></strong> :&ensp;<code>str</code>, optional</dt>
+<dd>The font size of the legend. Default is "medium".</dd>
+<dt><strong><code>add_egde</code></strong> :&ensp;<code>bool</code>, optional</dt>
+<dd>Whether to add the edges between the centers of clusters. Default is False.</dd>
+<dt><strong><code>add_direct</code></strong> :&ensp;<code>bool</code>, optional</dt>
+<dd>Whether to add the direction of the edges. Default is False.</dd>
+</dl></div>
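+<p>For example, overlaying cluster centers and backbone edges on the UMAP (note the parameter is spelled <code>add_egde</code> in this API):</p>
+<pre><code class="language-python">model.plot_center(color='vitae_new_clustering', method='UMAP',
+                  add_egde=True, add_direct=False)</code></pre>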
+</dd>
+<dt id="VITAE.VITAE.infer_trajectory"><code class="name flex">
+<span>def <span class="ident">infer_trajectory</span></span>(<span>self, root: Union[int, str], digraph=None, color='pseudotime', visualize: bool = True, path_to_fig=None, **kwargs)</span>
+</code></dt>
+<dd>
+<div class="desc"><p>Infer the trajectory.</p>
+<h2 id="parameters">Parameters</h2>
+<dl>
+<dt><strong><code>root</code></strong> :&ensp;<code>int</code> or <code>string</code></dt>
+<dd>The root of the inferred trajectory. Can be either an int (vertex index) or a string (label name).</dd>
+<dt><strong><code>digraph</code></strong> :&ensp;<code>nx.DiGraph</code>, optional</dt>
+<dd>The directed graph to be used for trajectory inference. If None, the minimum spanning tree of the estimated trajectory backbone will be used.</dd>
+<dt><strong><code>color</code></strong> :&ensp;<code>string</code>, optional</dt>
+<dd>The key for annotations used to color the plot. Default is 'pseudotime'.</dd>
+<dt><strong><code>visualize</code></strong> :&ensp;<code>boolean</code></dt>
+<dd>Whether to plot the current trajectory backbone (directed graph).</dd>
+<dt><strong><code>path_to_fig</code></strong> :&ensp;<code>string</code>, optional</dt>
+<dd>The path to save the figure; the figure is not saved if it is None.</dd>
+<dt><strong><code>**kwargs</code></strong> :&ensp;<code>dict</code>, optional</dt>
+<dd>Other keyword arguments for plotting.</dd>
+</dl></div>
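+<p>For example, starting from a root selected with <code>select_root</code> (or a label name; the file name is hypothetical):</p>
+<pre><code class="language-python">model.infer_trajectory(root=root, color='pseudotime',
+                       path_to_fig='trajectory.png')</code></pre>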
+</dd>
+<dt id="VITAE.VITAE.differential_expression_test"><code class="name flex">
+<span>def <span class="ident">differential_expression_test</span></span>(<span>self, alpha: float = 0.05, cell_subset=None, order: int = 1)</span>
+</code></dt>
+<dd>
+<div class="desc"><p>Differentially gene expression test. All (selected and unselected) genes will be tested
+Only cells in <code>selected_cell_subset</code> will be used, which is useful when one need to
+test differentially expressed genes on a branch of the inferred trajectory.</p>
+<h2 id="parameters">Parameters</h2>
+<dl>
+<dt><strong><code>alpha</code></strong> :&ensp;<code>float</code>, optional</dt>
+<dd>The cutoff of p-values.</dd>
+<dt><strong><code>cell_subset</code></strong> :&ensp;<code>np.array</code>, optional</dt>
+<dd>The subset of cells to be used for testing. If None, all cells will be used.</dd>
+<dt><strong><code>order</code></strong> :&ensp;<code>int</code>, optional</dt>
+<dd>The maximum order of pseudotime used in the regression.</dd>
+</dl>
+<h2 id="returns">Returns</h2>
+<dl>
+<dt><strong><code>res_df</code></strong> :&ensp;<code>pandas.DataFrame</code></dt>
+<dd>The test results of expressed genes with two columns,
+the estimated coefficients and the adjusted p-values.</dd>
+</dl></div>
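+<p>A sketch restricting the test to cells assigned to the inferred trajectory (cells off the trajectory have pseudotime -1, as set by <code>infer_trajectory</code>):</p>
+<pre><code class="language-python">import numpy as np
+
+cell_subset = np.where(model.pseudotime &gt; -1)[0]
+res_df = model.differential_expression_test(alpha=0.05,
+                                            cell_subset=cell_subset,
+                                            order=1)</code></pre>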
+</dd>
+<dt id="VITAE.VITAE.evaluate"><code class="name flex">
+<span>def <span class="ident">evaluate</span></span>(<span>self, milestone_net, begin_node_true, grouping=None, thres: float = 0.5, no_loop: bool = True, cutoff: Optional[float] = None, method: str = 'mean', path: Optional[str] = None)</span>
+</code></dt>
+<dd>
+<div class="desc"><p>Evaluate the model.</p>
+<h2 id="parameters">Parameters</h2>
+<dl>
+<dt><strong><code>milestone_net</code></strong> :&ensp;<code>pd.DataFrame</code></dt>
+<dd>
+<p>The true milestone network. For real data, milestone_net will be a DataFrame of the graph of nodes.
+Eg.</p>
+<table>
+<thead>
+<tr>
+<th>from</th>
+<th>to</th>
+</tr>
+</thead>
+<tbody>
+<tr>
+<td>cluster 1</td>
+<td>cluster 1</td>
+</tr>
+<tr>
+<td>cluster 1</td>
+<td>cluster 2</td>
+</tr>
+</tbody>
+</table>
+<p>For synthetic data, milestone_net will be a DataFrame of the (projected)
+positions of cells. The indexes are the orders of cells in the dataset.
+Eg.</p>
+<table>
+<thead>
+<tr>
+<th>from</th>
+<th>to</th>
+<th>w</th>
+</tr>
+</thead>
+<tbody>
+<tr>
+<td>cluster 1</td>
+<td>cluster 1</td>
+<td>1</td>
+</tr>
+<tr>
+<td>cluster 1</td>
+<td>cluster 2</td>
+<td>0.1</td>
+</tr>
+</tbody>
+</table>
+</dd>
+<dt><strong><code>begin_node_true</code></strong> :&ensp;<code>str</code> or <code>int</code></dt>
+<dd>The true begin node of the milestone.</dd>
+<dt><strong><code>grouping</code></strong> :&ensp;<code>np.array</code>, optional</dt>
+<dd><span><span class="MathJax_Preview">[N,]</span><script type="math/tex">[N,]</script></span> The labels. For real data, grouping must be provided.</dd>
+</dl>
+<h2 id="returns">Returns</h2>
+<dl>
+<dt><strong><code>res</code></strong> :&ensp;<code>pd.DataFrame</code></dt>
+<dd>The evaluation result.</dd>
+</dl></div>
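+<p>A hedged sketch for real data, where the true milestone network and per-cell labels must be provided (all names below are hypothetical):</p>
+<pre><code class="language-python">import numpy as np
+import pandas as pd
+
+milestone_net = pd.DataFrame({'from': ['cluster 1', 'cluster 1'],
+                              'to':   ['cluster 1', 'cluster 2']})
+grouping = np.array(['cluster 1', 'cluster 2'])  # one label per cell
+res = model.evaluate(milestone_net, begin_node_true='cluster 1',
+                     grouping=grouping)</code></pre>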
+</dd>
+<dt id="VITAE.VITAE.save_model"><code class="name flex">
+<span>def <span class="ident">save_model</span></span>(<span>self, path_to_file: str = 'model.checkpoint', save_adata: bool = False)</span>
+</code></dt>
+<dd>
+<div class="desc"><p>Saving model weights.</p>
+<h2 id="parameters">Parameters</h2>
+<dl>
+<dt><strong><code>path_to_file</code></strong> :&ensp;<code>str</code>, optional</dt>
+<dd>The path to the weight files of the pre-trained or trained model.</dd>
+<dt><strong><code>save_adata</code></strong> :&ensp;<code>boolean</code>, optional</dt>
+<dd>Whether to save adata or not.</dd>
+</dl></div>
+</dd>
+<dt id="VITAE.VITAE.load_model"><code class="name flex">
+<span>def <span class="ident">load_model</span></span>(<span>self, path_to_file: str = 'model.checkpoint', load_labels: bool = False, load_adata: bool = False)</span>
+</code></dt>
+<dd>
+<div class="desc"><p>Load model weights.</p>
+<h2 id="parameters">Parameters</h2>
+<dl>
+<dt><strong><code>path_to_file</code></strong> :&ensp;<code>str</code>, optional</dt>
+<dd>The path to the weight files of the pre-trained or trained model.</dd>
+<dt><strong><code>load_labels</code></strong> :&ensp;<code>boolean</code>, optional</dt>
+<dd>Whether to load clustering labels or not.
+If load_labels is True, then the LatentSpace layer will be initialized based on the model.
+If load_labels is False, then the LatentSpace layer will not be initialized.</dd>
+<dt><strong><code>load_adata</code></strong> :&ensp;<code>boolean</code>, optional</dt>
+<dd>Whether to load adata or not.</dd>
+</dl></div>
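+<p>A save/load round trip (the file name is the documented default):</p>
+<pre><code class="language-python">model.save_model('model.checkpoint', save_adata=True)
+
+# later, after re-creating the model object with the same data:
+model.load_model('model.checkpoint', load_labels=True, load_adata=True)</code></pre>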
+</dd>
+</dl>
+</dd>
+</dl>
+</section>
+</article>
+<nav id="sidebar">
+<div class="toc">
+<ul></ul>
+</div>
+<ul id="index">
+<li><h3><a href="#header-submodules">Sub-modules</a></h3>
+<ul>
+<li><code><a title="VITAE.inference" href="inference.html">VITAE.inference</a></code></li>
+<li><code><a title="VITAE.metric" href="metric.html">VITAE.metric</a></code></li>
+<li><code><a title="VITAE.model" href="model.html">VITAE.model</a></code></li>
+<li><code><a title="VITAE.train" href="train.html">VITAE.train</a></code></li>
+<li><code><a title="VITAE.utils" href="utils.html">VITAE.utils</a></code></li>
+</ul>
+</li>
+<li><h3><a href="#header-classes">Classes</a></h3>
+<ul>
+<li>
+<h4><code><a title="VITAE.VITAE" href="#VITAE.VITAE">VITAE</a></code></h4>
+<ul class="">
+<li><code><a title="VITAE.VITAE.pre_train" href="#VITAE.VITAE.pre_train">pre_train</a></code></li>
+<li><code><a title="VITAE.VITAE.update_z" href="#VITAE.VITAE.update_z">update_z</a></code></li>
+<li><code><a title="VITAE.VITAE.get_latent_z" href="#VITAE.VITAE.get_latent_z">get_latent_z</a></code></li>
+<li><code><a title="VITAE.VITAE.visualize_latent" href="#VITAE.VITAE.visualize_latent">visualize_latent</a></code></li>
+<li><code><a title="VITAE.VITAE.init_latent_space" href="#VITAE.VITAE.init_latent_space">init_latent_space</a></code></li>
+<li><code><a title="VITAE.VITAE.update_latent_space" href="#VITAE.VITAE.update_latent_space">update_latent_space</a></code></li>
+<li><code><a title="VITAE.VITAE.train" href="#VITAE.VITAE.train">train</a></code></li>
+<li><code><a title="VITAE.VITAE.output_pi" href="#VITAE.VITAE.output_pi">output_pi</a></code></li>
+<li><code><a title="VITAE.VITAE.return_pilayer_weights" href="#VITAE.VITAE.return_pilayer_weights">return_pilayer_weights</a></code></li>
+<li><code><a title="VITAE.VITAE.posterior_estimation" href="#VITAE.VITAE.posterior_estimation">posterior_estimation</a></code></li>
+<li><code><a title="VITAE.VITAE.infer_backbone" href="#VITAE.VITAE.infer_backbone">infer_backbone</a></code></li>
+<li><code><a title="VITAE.VITAE.select_root" href="#VITAE.VITAE.select_root">select_root</a></code></li>
+<li><code><a title="VITAE.VITAE.plot_backbone" href="#VITAE.VITAE.plot_backbone">plot_backbone</a></code></li>
+<li><code><a title="VITAE.VITAE.plot_center" href="#VITAE.VITAE.plot_center">plot_center</a></code></li>
+<li><code><a title="VITAE.VITAE.infer_trajectory" href="#VITAE.VITAE.infer_trajectory">infer_trajectory</a></code></li>
+<li><code><a title="VITAE.VITAE.differential_expression_test" href="#VITAE.VITAE.differential_expression_test">differential_expression_test</a></code></li>
+<li><code><a title="VITAE.VITAE.evaluate" href="#VITAE.VITAE.evaluate">evaluate</a></code></li>
+<li><code><a title="VITAE.VITAE.save_model" href="#VITAE.VITAE.save_model">save_model</a></code></li>
+<li><code><a title="VITAE.VITAE.load_model" href="#VITAE.VITAE.load_model">load_model</a></code></li>
+</ul>
+</li>
+</ul>
+</li>
+</ul>
+</nav>
+</main>
+<footer id="footer">
+<p>Generated by <a href="https://pdoc3.github.io/pdoc" title="pdoc: Python API documentation generator"><cite>pdoc</cite> 0.11.1</a>.</p>
+</footer>
+</body>
+</html>