<!doctype html>
<html lang="en">
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1, minimum-scale=1">
<meta name="generator" content="pdoc3 0.11.1">
<title>VITAE API documentation</title>
<meta name="description" content="">
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/10up-sanitize.css/13.0.0/sanitize.min.css" integrity="sha512-y1dtMcuvtTMJc1yPgEqF0ZjQbhnc/bFhyvIyVNb9Zk5mIGtqVaAB1Ttl28su8AvFMOY0EwRbAe+HCLqj6W7/KA==" crossorigin>
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/10up-sanitize.css/13.0.0/typography.min.css" integrity="sha512-Y1DYSb995BAfxobCkKepB1BqJJTPrOp3zPL74AWFugHHmmdcvO+C48WLrUOlhGMc0QG7AE3f7gmvvcrmX2fDoA==" crossorigin>
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.9.0/styles/default.min.css" crossorigin>
<style>:root{--highlight-color:#fe9}.flex{display:flex !important}body{line-height:1.5em}#content{padding:20px}#sidebar{padding:1.5em;overflow:hidden}#sidebar > *:last-child{margin-bottom:2cm}.http-server-breadcrumbs{font-size:130%;margin:0 0 15px 0}#footer{font-size:.75em;padding:5px 30px;border-top:1px solid #ddd;text-align:right}#footer p{margin:0 0 0 1em;display:inline-block}#footer p:last-child{margin-right:30px}h1,h2,h3,h4,h5{font-weight:300}h1{font-size:2.5em;line-height:1.1em}h2{font-size:1.75em;margin:2em 0 .50em 0}h3{font-size:1.4em;margin:1.6em 0 .7em 0}h4{margin:0;font-size:105%}h1:target,h2:target,h3:target,h4:target,h5:target,h6:target{background:var(--highlight-color);padding:.2em 0}a{color:#058;text-decoration:none;transition:color .2s ease-in-out}a:visited{color:#503}a:hover{color:#b62}.title code{font-weight:bold}h2[id^="header-"]{margin-top:2em}.ident{color:#900;font-weight:bold}pre code{font-size:.8em;line-height:1.4em;padding:1em;display:block}code{background:#f3f3f3;font-family:"DejaVu Sans Mono",monospace;padding:1px 4px;overflow-wrap:break-word}h1 code{background:transparent}pre{border-top:1px solid #ccc;border-bottom:1px solid #ccc;margin:1em 0}#http-server-module-list{display:flex;flex-flow:column}#http-server-module-list div{display:flex}#http-server-module-list dt{min-width:10%}#http-server-module-list p{margin-top:0}.toc ul,#index{list-style-type:none;margin:0;padding:0}#index code{background:transparent}#index h3{border-bottom:1px solid #ddd}#index ul{padding:0}#index h4{margin-top:.6em;font-weight:bold}@media (min-width:200ex){#index .two-column{column-count:2}}@media (min-width:300ex){#index .two-column{column-count:3}}dl{margin-bottom:2em}dl dl:last-child{margin-bottom:4em}dd{margin:0 0 1em 3em}#header-classes + dl > dd{margin-bottom:3em}dd dd{margin-left:2em}dd p{margin:10px 0}.name{background:#eee;font-size:.85em;padding:5px 10px;display:inline-block;min-width:40%}.name:hover{background:#e0e0e0}dt:target .name{background:var(--highlight-color)}.name > span:first-child{white-space:nowrap}.name.class > span:nth-child(2){margin-left:.4em}.inherited{color:#999;border-left:5px solid #eee;padding-left:1em}.inheritance em{font-style:normal;font-weight:bold}.desc h2{font-weight:400;font-size:1.25em}.desc h3{font-size:1em}.desc dt code{background:inherit}.source summary,.git-link-div{color:#666;text-align:right;font-weight:400;font-size:.8em;text-transform:uppercase}.source summary > *{white-space:nowrap;cursor:pointer}.git-link{color:inherit;margin-left:1em}.source pre{max-height:500px;overflow:auto;margin:0}.source pre code{font-size:12px;overflow:visible}.hlist{list-style:none}.hlist li{display:inline}.hlist li:after{content:',\2002'}.hlist li:last-child:after{content:none}.hlist .hlist{display:inline;padding-left:1em}img{max-width:100%}td{padding:0 .5em}.admonition{padding:.1em 1em;margin-bottom:1em}.admonition-title{font-weight:bold}.admonition.note,.admonition.info,.admonition.important{background:#aef}.admonition.todo,.admonition.versionadded,.admonition.tip,.admonition.hint{background:#dfd}.admonition.warning,.admonition.versionchanged,.admonition.deprecated{background:#fd4}.admonition.error,.admonition.danger,.admonition.caution{background:lightpink}</style>
<style media="screen and (min-width: 700px)">@media screen and (min-width:700px){#sidebar{width:30%;height:100vh;overflow:auto;position:sticky;top:0}#content{width:70%;max-width:100ch;padding:3em 4em;border-left:1px solid #ddd}pre code{font-size:1em}.name{font-size:1em}main{display:flex;flex-direction:row-reverse;justify-content:flex-end}.toc ul ul,#index ul ul{padding-left:1em}.toc > ul > li{margin-top:.5em}}</style>
<style media="print">@media print{#sidebar h1{page-break-before:always}.source{display:none}}@media print{*{background:transparent !important;color:#000 !important;box-shadow:none !important;text-shadow:none !important}a[href]:after{content:" (" attr(href) ")";font-size:90%}a[href][title]:after{content:none}abbr[title]:after{content:" (" attr(title) ")"}.ir a:after,a[href^="javascript:"]:after,a[href^="#"]:after{content:""}pre,blockquote{border:1px solid #999;page-break-inside:avoid}thead{display:table-header-group}tr,img{page-break-inside:avoid}img{max-width:100% !important}@page{margin:0.5cm}p,h2,h3{orphans:3;widows:3}h1,h2,h3,h4,h5,h6{page-break-after:avoid}}</style>
<script type="text/x-mathjax-config">MathJax.Hub.Config({ tex2jax: { inlineMath: [ ['$','$'], ["\\(","\\)"] ], processEscapes: true } });</script>
<script async src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.7/latest.js?config=TeX-AMS_CHTML" integrity="sha256-kZafAc6mZvK3W3v1pHOcUix30OHQN6pU/NO2oFkqZVw=" crossorigin></script>
<script defer src="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.9.0/highlight.min.js" integrity="sha512-D9gUyxqja7hBtkWpPWGt9wfbfaMGVt9gnyCvYa+jojwwPHLCzUm5i8rpk7vD7wNee9bA35eYIjobYPaQuKS1MQ==" crossorigin></script>
<script>window.addEventListener('DOMContentLoaded', () => {
hljs.configure({languages: ['bash', 'css', 'diff', 'graphql', 'ini', 'javascript', 'json', 'plaintext', 'python', 'python-repl', 'rust', 'shell', 'sql', 'typescript', 'xml', 'yaml']});
hljs.highlightAll();
})</script>
</head>
<body>
<main>
<article id="content">
<header>
<h1 class="title">Package <code>VITAE</code></h1>
</header>
<section id="section-intro">
</section>
<section>
<h2 class="section-title" id="header-submodules">Sub-modules</h2>
<dl>
<dt><code class="name"><a title="VITAE.inference" href="inference.html">VITAE.inference</a></code></dt>
<dd>
<div class="desc"></div>
</dd>
<dt><code class="name"><a title="VITAE.metric" href="metric.html">VITAE.metric</a></code></dt>
<dd>
<div class="desc"></div>
</dd>
<dt><code class="name"><a title="VITAE.model" href="model.html">VITAE.model</a></code></dt>
<dd>
<div class="desc"></div>
</dd>
<dt><code class="name"><a title="VITAE.train" href="train.html">VITAE.train</a></code></dt>
<dd>
<div class="desc"></div>
</dd>
<dt><code class="name"><a title="VITAE.utils" href="utils.html">VITAE.utils</a></code></dt>
<dd>
<div class="desc"></div>
</dd>
</dl>
</section>
<section>
</section>
<section>
</section>
<section>
<h2 class="section-title" id="header-classes">Classes</h2>
<dl>
<dt id="VITAE.VITAE"><code class="flex name class">
<span>class <span class="ident">VITAE</span></span>
<span>(</span><span>adata: anndata._core.anndata.AnnData, covariates=None, pi_covariates=None, model_type: str = 'Gaussian', npc: int = 64, adata_layer_counts=None, copy_adata: bool = False, hidden_layers=[32], latent_space_dim: int = 16, conditions=None)</span>
</code></dt>
<dd>
<div class="desc"><p>Variational Inference for Trajectory by AutoEncoder.</p>
<p>Prepare input data for the model. The data must first be processed using scanpy and stored as an AnnData object.
The 'UMI' and 'non-UMI' models require the original count matrix, so the counts must be saved in
adata.layers in order to use these models.</p>
<h2 id="parameters">Parameters</h2>
<dl>
<dt><strong><code>adata</code></strong> : <code>sc.AnnData</code></dt>
<dd>The scanpy AnnData object. adata should already contain adata.var.highly_variable</dd>
<dt><strong><code>covariates</code></strong> : <code>list</code>, optional</dt>
<dd>A list of names of covariate vectors that are stored in adata.obs</dd>
<dt><strong><code>pi_covariates</code></strong> : <code>list</code>, optional</dt>
<dd>A list of names of covariate vectors used as input for pilayer</dd>
<dt><strong><code>model_type</code></strong> : <code>str</code>, optional</dt>
<dd>One of 'UMI', 'non-UMI', or 'Gaussian'; default is 'Gaussian'.</dd>
<dt><strong><code>npc</code></strong> : <code>int</code>, optional</dt>
<dd>The number of PCs to use when model_type is 'Gaussian'. The default is 64.</dd>
<dt><strong><code>adata_layer_counts</code></strong> : <code>str</code>, optional</dt>
<dd>The key in adata.layers that stores the count data when model_type is
'UMI' or 'non-UMI'.</dd>
<dt><strong><code>copy_adata</code></strong> : <code>bool</code>, optional</dt>
<dd>Set to True if VITAE should not modify the original adata. If set to True, self.adata will be an independent copy of the original adata.</dd>
<dt><strong><code>hidden_layers</code></strong> : <code>list</code>, optional</dt>
<dd>The list of dimensions of the autoencoder layers between the latent space and the original space. The default is a single hidden layer with 32 nodes.</dd>
<dt><strong><code>latent_space_dim</code></strong> : <code>int</code>, optional</dt>
<dd>The dimension of latent space.</dd>
<dt><strong><code>gamma</code></strong> : <code>float</code>, optional</dt>
<dd>The weight of the MMD loss</dd>
<dt><strong><code>conditions</code></strong> : <code>str</code> or <code>list</code>, optional</dt>
<dd>The conditions of different cells</dd>
</dl>
<h2 id="returns">Returns</h2>
<p>None.</p></div>
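<p>A minimal construction sketch (the file name and the <code>'counts'</code> layer key are illustrative assumptions, not part of the API):</p>
<pre><code class="python">import scanpy as sc
from VITAE import VITAE

adata = sc.read_h5ad('data.h5ad')      # hypothetical preprocessed dataset
sc.pp.highly_variable_genes(adata)     # VITAE expects adata.var.highly_variable

# Default Gaussian model fitted on 64 principal components
vt = VITAE(adata, model_type='Gaussian', npc=64)

# For count-based models, raw counts must be stored in adata.layers:
# vt = VITAE(adata, model_type='UMI', adata_layer_counts='counts')</code></pre>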
<details class="source">
<summary>
<span>Expand source code</span>
</summary>
<pre><code class="python">class VITAE():
"""
Variational Inference for Trajectory by AutoEncoder.
"""
def __init__(self, adata: sc.AnnData,
covariates = None, pi_covariates = None,
model_type: str = 'Gaussian',
npc: int = 64,
adata_layer_counts = None,
copy_adata: bool = False,
hidden_layers = [32],
latent_space_dim: int = 16,
conditions = None):
'''
Get input data for the model. The data must first be processed using scanpy and stored as an AnnData object.
The 'UMI' and 'non-UMI' models require the original count matrix, so the counts must be saved in
adata.layers in order to use these models.
Parameters
----------
adata : sc.AnnData
The scanpy AnnData object. adata should already contain adata.var.highly_variable
covariates : list, optional
A list of names of covariate vectors that are stored in adata.obs
pi_covariates: list, optional
A list of names of covariate vectors used as input for pilayer
model_type : str, optional
One of 'UMI', 'non-UMI', or 'Gaussian'; default is 'Gaussian'.
npc : int, optional
The number of PCs to use when model_type is 'Gaussian'. The default is 64.
adata_layer_counts: str, optional
The key in adata.layers that stores the count data when model_type is
'UMI' or 'non-UMI'.
copy_adata : bool, optional
Set to True if we don't want VITAE to modify the original adata. If set to True, self.adata will be an independent copy of the original adata.
hidden_layers : list, optional
The list of dimensions of the autoencoder layers between the latent space and the original space. The default is a single hidden layer with 32 nodes.
latent_space_dim : int, optional
The dimension of latent space.
gamma : float, optional
The weight of the MMD loss
conditions : str or list, optional
The conditions of different cells
Returns
-------
None.
'''
self.dict_method_scname = {
'PCA' : 'X_pca',
'UMAP' : 'X_umap',
'TSNE' : 'X_tsne',
'diffmap' : 'X_diffmap',
'draw_graph' : 'X_draw_graph_fa'
}
if model_type != 'Gaussian':
if adata_layer_counts is None:
raise ValueError("need to provide the name in adata.layers that stores the raw count data")
if 'highly_variable' not in adata.var:
raise ValueError("need to first select highly variable genes using scanpy")
self.model_type = model_type
if copy_adata:
self.adata = adata.copy()
else:
self.adata = adata
if covariates is not None:
if isinstance(covariates, str):
covariates = [covariates]
covariates = np.array(covariates)
id_cat = (adata.obs[covariates].dtypes == 'category')
# add OneHotEncoder & StandardScaler as class variable if needed
if np.sum(id_cat)>0:
covariates_cat = OneHotEncoder(drop='if_binary', handle_unknown='ignore'
).fit_transform(adata.obs[covariates[id_cat]]).toarray()
else:
covariates_cat = np.array([]).reshape(adata.shape[0],0)
# temporarily disable StandardScaler
if np.sum(~id_cat)>0:
#covariates_con = StandardScaler().fit_transform(adata.obs[covariates[~id_cat]])
covariates_con = adata.obs[covariates[~id_cat]]
else:
covariates_con = np.array([]).reshape(adata.shape[0],0)
self.covariates = np.c_[covariates_cat, covariates_con].astype(tf.keras.backend.floatx())
else:
self.covariates = None
if conditions is not None:
## observations with np.nan will not participate in calculating mmd_loss
if isinstance(conditions, str):
conditions = [conditions]
conditions = np.array(conditions)
if np.any(adata.obs[conditions].dtypes != 'category'):
raise ValueError("Conditions should all be categorical.")
self.conditions = OrdinalEncoder(dtype=int, encoded_missing_value=-1).fit_transform(adata.obs[conditions]) + int(1)
else:
self.conditions = None
if pi_covariates is not None:
self.pi_cov = adata.obs[pi_covariates].to_numpy()
if self.pi_cov.ndim == 1:
self.pi_cov = self.pi_cov.reshape(-1, 1)
self.pi_cov = self.pi_cov.astype(tf.keras.backend.floatx())
else:
self.pi_cov = np.zeros((adata.shape[0],1), dtype=tf.keras.backend.floatx())
self.model_type = model_type
self._adata = sc.AnnData(X = self.adata.X, var = self.adata.var)
self._adata.obs = self.adata.obs
self._adata.uns = self.adata.uns
if model_type == 'Gaussian':
sc.tl.pca(adata, n_comps = npc)
self.X_input = self.X_output = adata.obsm['X_pca']
self.scale_factor = np.ones(self.X_output.shape[0])
else:
print(f"{adata.var.highly_variable.sum()} highly variable genes selected as input")
self.X_input = adata.X[:, adata.var.highly_variable]
self.X_output = adata.layers[adata_layer_counts][ :, adata.var.highly_variable]
self.scale_factor = np.sum(self.X_output, axis=1, keepdims=True)/1e4
self.dimensions = hidden_layers
self.dim_latent = latent_space_dim
self.vae = model.VariationalAutoEncoder(
self.X_output.shape[1], self.dimensions,
self.dim_latent, self.model_type,
False if self.covariates is None else True,
)
if hasattr(self, 'inferer'):
delattr(self, 'inferer')
def pre_train(self, test_size = 0.1, random_state: int = 0,
learning_rate: float = 1e-3, batch_size: int = 256, L: int = 1, alpha: float = 0.10, gamma: float = 0,
phi : float = 1,num_epoch: int = 200, num_step_per_epoch: Optional[int] = None,
early_stopping_patience: int = 10, early_stopping_tolerance: float = 0.01,
early_stopping_relative: bool = True, verbose: bool = False,path_to_weights: Optional[str] = None):
'''Pretrain the model with specified learning rate.
Parameters
----------
test_size : float or int, optional
The proportion or size of the test set.
random_state : int, optional
The random state for data splitting.
learning_rate : float, optional
The initial learning rate for the Adam optimizer.
batch_size : int, optional
The batch size for pre-training. Default is 256; set to 32 if the number of cells is small (fewer than 1000).
L : int, optional
The number of MC samples.
alpha : float, optional
The value of alpha in [0,1] to encourage covariate adjustment. Not used if there are no covariates.
gamma : float, optional
The weight of the mmd loss if used.
phi : float, optional
The weight of the Jacobian norm of the encoder.
num_epoch : int, optional
The maximum number of epochs.
num_step_per_epoch : int, optional
The number of steps per epoch; it will be inferred from the number of cells and batch size if None.
early_stopping_patience : int, optional
The maximum number of epochs if there is no improvement.
early_stopping_tolerance : float, optional
The minimum change of loss to be considered as an improvement.
early_stopping_relative : bool, optional
Whether to monitor the relative change of loss as the stopping criterion.
path_to_weights : str, optional
The path of weight file to be saved; not saving weight if None.
'''
id_train, id_test = train_test_split(
np.arange(self.X_input.shape[0]),
test_size=test_size,
random_state=random_state)
if num_step_per_epoch is None:
num_step_per_epoch = len(id_train)//batch_size+1
self.train_dataset = train.warp_dataset(self.X_input[id_train].astype(tf.keras.backend.floatx()),
None if self.covariates is None else self.covariates[id_train].astype(tf.keras.backend.floatx()),
batch_size,
self.X_output[id_train].astype(tf.keras.backend.floatx()),
self.scale_factor[id_train].astype(tf.keras.backend.floatx()),
conditions = None if self.conditions is None else self.conditions[id_train].astype(tf.keras.backend.floatx()))
self.test_dataset = train.warp_dataset(self.X_input[id_test],
None if self.covariates is None else self.covariates[id_test].astype(tf.keras.backend.floatx()),
batch_size,
self.X_output[id_test].astype(tf.keras.backend.floatx()),
self.scale_factor[id_test].astype(tf.keras.backend.floatx()),
conditions = None if self.conditions is None else self.conditions[id_test].astype(tf.keras.backend.floatx()))
self.vae = train.pre_train(
self.train_dataset,
self.test_dataset,
self.vae,
learning_rate,
L, alpha, gamma, phi,
num_epoch,
num_step_per_epoch,
early_stopping_patience,
early_stopping_tolerance,
early_stopping_relative,
verbose)
self.update_z()
if path_to_weights is not None:
self.save_model(path_to_weights)
def update_z(self):
self.z = self.get_latent_z()
self._adata_z = sc.AnnData(self.z)
sc.pp.neighbors(self._adata_z)
def get_latent_z(self):
''' Get the posterior mean of the current latent space z (encoder output).
Returns
----------
z : np.array
\([N,d]\) The latent means.
'''
c = None if self.covariates is None else self.covariates
return self.vae.get_z(self.X_input, c)
def visualize_latent(self, method: str = "UMAP",
color = None, **kwargs):
'''
Visualize the current latent space z using the scanpy visualization tools.
Parameters
----------
method : str, optional
Visualization method to use. The default is "UMAP". Possible choices include "PCA", "UMAP",
"diffmap", "TSNE" and "draw_graph" (the FA plot).
color : str or list, optional
Keys for annotations of observations/cells or variables/genes, e.g., 'ann1' or ['ann1', 'ann2'].
The default is None. Same as scanpy.
**kwargs :
Extra key-value arguments that can be passed to scanpy plotting functions (scanpy.pl.XX).
Returns
-------
None.
'''
if method not in ['PCA', 'UMAP', 'TSNE', 'diffmap', 'draw_graph']:
raise ValueError("visualization method should be one of 'PCA', 'UMAP', 'TSNE', 'diffmap' and 'draw_graph'")
temp = list(self._adata_z.obsm.keys())
if method == 'PCA' and not 'X_pca' in temp:
print("Calculate PCs ...")
sc.tl.pca(self._adata_z)
elif method == 'UMAP' and not 'X_umap' in temp:
print("Calculate UMAP ...")
sc.tl.umap(self._adata_z)
elif method == 'TSNE' and not 'X_tsne' in temp:
print("Calculate TSNE ...")
sc.tl.tsne(self._adata_z)
elif method == 'diffmap' and not 'X_diffmap' in temp:
print("Calculate diffusion map ...")
sc.tl.diffmap(self._adata_z)
elif method == 'draw_graph' and not 'X_draw_graph_fa' in temp:
print("Calculate FA ...")
sc.tl.draw_graph(self._adata_z)
self._adata.obs = self.adata.obs.copy()
self._adata.obsp = self._adata_z.obsp
# self._adata.uns = self._adata_z.uns
self._adata.obsm = self._adata_z.obsm
if method == 'PCA':
axes = sc.pl.pca(self._adata, color = color, **kwargs)
elif method == 'UMAP':
axes = sc.pl.umap(self._adata, color = color, **kwargs)
elif method == 'TSNE':
axes = sc.pl.tsne(self._adata, color = color, **kwargs)
elif method == 'diffmap':
axes = sc.pl.diffmap(self._adata, color = color, **kwargs)
elif method == 'draw_graph':
axes = sc.pl.draw_graph(self._adata, color = color, **kwargs)
return axes
def init_latent_space(self, cluster_label = None, log_pi = None, res: float = 1.0,
ratio_prune= None, dist = None, dist_thres = 0.5, topk=0, pilayer = False):
'''Initialize the latent space.
Parameters
----------
cluster_label : str, optional
The name of the vector of labels that can be found in self.adata.obs.
Default is None, in which case leiden clustering is performed on the pretrained z to obtain clusters.
log_pi : np.array, optional
\([1,K]\) The value of initial \(\\log(\\pi)\).
res : float, optional
The resolution of leiden clustering, which is a parameter value controlling the coarseness of the clustering.
Higher values lead to more clusters. Default is 1.
ratio_prune : float, optional
The ratio of edges to be removed before estimating.
topk : int, optional
The number of top k neighbors to keep for each cluster.
'''
if cluster_label is None:
print("Perform leiden clustering on the latent space z ...")
g = get_igraph(self.z)
cluster_labels = leidenalg_igraph(g, res = res)
cluster_labels = cluster_labels.astype(str)
uni_cluster_labels = np.unique(cluster_labels)
else:
if isinstance(cluster_label,str):
cluster_labels = self.adata.obs[cluster_label].to_numpy()
uni_cluster_labels = np.array(self.adata.obs[cluster_label].cat.categories)
else:
## if cluster_label is a list
cluster_labels = cluster_label
uni_cluster_labels = np.unique(cluster_labels)
n_clusters = len(uni_cluster_labels)
if not hasattr(self, 'z'):
self.update_z()
z = self.z
mu = np.zeros((z.shape[1], n_clusters))
for i,l in enumerate(uni_cluster_labels):
mu[:,i] = np.mean(z[cluster_labels==l], axis=0)
if dist is None:
### update cluster centers if some cluster centers are too close
clustering = AgglomerativeClustering(
n_clusters=None,
distance_threshold=dist_thres,
linkage='complete'
).fit(mu.T/np.sqrt(mu.shape[0]))
n_clusters_new = clustering.n_clusters_
if n_clusters_new < n_clusters:
print("Merge clusters for cluster centers that are too close ...")
n_clusters = n_clusters_new
for i in range(n_clusters):
temp = uni_cluster_labels[clustering.labels_ == i]
idx = np.isin(cluster_labels, temp)
cluster_labels[idx] = ','.join(temp)
if np.sum(clustering.labels_==i)>1:
print('Merge %s'% ','.join(temp))
uni_cluster_labels = np.unique(cluster_labels)
mu = np.zeros((z.shape[1], n_clusters))
for i,l in enumerate(uni_cluster_labels):
mu[:,i] = np.mean(z[cluster_labels==l], axis=0)
self.adata.obs['vitae_init_clustering'] = cluster_labels
self.adata.obs['vitae_init_clustering'] = self.adata.obs['vitae_init_clustering'].astype('category')
print("Initial clustering labels saved as 'vitae_init_clustering' in self.adata.obs.")
if (log_pi is None) and (cluster_labels is not None) and (n_clusters>3):
n_states = int((n_clusters+1)*n_clusters/2)
if dist is None:
dist = _comp_dist(z, cluster_labels, mu.T)
C = np.triu(np.ones(n_clusters))
C[C>0] = np.arange(n_states)
C = C + C.T - np.diag(np.diag(C))
C = C.astype(int)
log_pi = np.zeros((1,n_states))
## pruning to throw away edges for far-away clusters if there are too many clusters
if ratio_prune is not None:
log_pi[0, C[np.triu(dist)>np.quantile(dist[np.triu_indices(n_clusters, 1)], 1-ratio_prune)]] = - np.inf
else:
log_pi[0, C[np.triu(dist)>np.quantile(dist[np.triu_indices(n_clusters, 1)], 5/n_clusters) * 3]] = - np.inf
## also keep the top k neighbor of clusters
topk = max(0, min(topk, n_clusters-1)) + 1
topk_indices = np.argsort(dist,axis=1)[:,:topk]
for i in range(n_clusters):
log_pi[0, C[i, topk_indices[i]]] = 0
self.n_states = n_clusters
self.labels = cluster_labels
labels_map = pd.DataFrame.from_dict(
{i:label for i,label in enumerate(uni_cluster_labels)},
orient='index', columns=['label_names'], dtype=str
)
self.labels_map = labels_map
self.vae.init_latent_space(self.n_states, mu, log_pi)
self.inferer = Inferer(self.n_states)
self.mu = self.vae.latent_space.mu.numpy()
self.pi = np.triu(np.ones(self.n_states))
self.pi[self.pi > 0] = tf.nn.softmax(self.vae.latent_space.pi).numpy()[0]
if pilayer:
self.vae.create_pilayer()
def update_latent_space(self, dist_thres: float=0.5):
pi = self.pi[np.triu_indices(self.n_states)]
mu = self.mu
clustering = AgglomerativeClustering(
n_clusters=None,
distance_threshold=dist_thres,
linkage='complete'
).fit(mu.T/np.sqrt(mu.shape[0]))
n_clusters = clustering.n_clusters_
if n_clusters<self.n_states:
print("Merge clusters for cluster centers that are too close ...")
mu_new = np.empty((self.dim_latent, n_clusters))
C = np.zeros((self.n_states, self.n_states))
C[np.triu_indices(self.n_states, 0)] = pi
C = np.triu(C, 1) + C.T
C_new = np.zeros((n_clusters, n_clusters))
uni_cluster_labels = self.labels_map['label_names'].to_numpy()
returned_order = {}
cluster_labels = self.labels
for i in range(n_clusters):
temp = uni_cluster_labels[clustering.labels_ == i]
idx = np.isin(cluster_labels, temp)
cluster_labels[idx] = ','.join(temp)
returned_order[i] = ','.join(temp)
if np.sum(clustering.labels_==i)>1:
print('Merge %s'% ','.join(temp))
uni_cluster_labels = np.unique(cluster_labels)
for i,l in enumerate(uni_cluster_labels): ## reorder the merged clusters based on the cluster names
k = [ki for ki,vi in returned_order.items() if vi == l][0]
mu_new[:, i] = np.mean(mu[:,clustering.labels_==k], axis=-1)
# sum of the aggregated pi's
C_new[i, i] = np.sum(np.triu(C[clustering.labels_==k,:][:,clustering.labels_==k]))
for j in range(i+1, n_clusters):
k1 = [ki for ki,vi in returned_order.items() if vi == uni_cluster_labels[j]][0]
C_new[i, j] = np.sum(C[clustering.labels_== k, :][:, clustering.labels_==k1])
# labels_map_new = {}
# for i in range(n_clusters):
# # update label map: int->str
# labels_map_new[i] = self.labels_map.loc[clustering.labels_==i, 'label_names'].str.cat(sep=',')
# if np.sum(clustering.labels_==i)>1:
# print('Merge %s'%labels_map_new[i])
# # mean of the aggregated cluster means
# mu_new[:, i] = np.mean(mu[:,clustering.labels_==i], axis=-1)
# # sum of the aggregated pi's
# C_new[i, i] = np.sum(np.triu(C[clustering.labels_==i,:][:,clustering.labels_==i]))
# for j in range(i+1, n_clusters):
# C_new[i, j] = np.sum(C[clustering.labels_== i, :][:, clustering.labels_==j])
C_new = np.triu(C_new,1) + C_new.T
pi_new = C_new[np.triu_indices(n_clusters)]
log_pi_new = np.log(pi_new, out=np.ones_like(pi_new)*(-np.inf), where=(pi_new!=0)).reshape((1,-1))
self.n_states = n_clusters
self.labels_map = pd.DataFrame.from_dict(
{i:label for i,label in enumerate(uni_cluster_labels)},
orient='index', columns=['label_names'], dtype=str
)
self.labels = cluster_labels
# self.labels_map = pd.DataFrame.from_dict(
# labels_map_new, orient='index', columns=['label_names'], dtype=str
# )
self.vae.init_latent_space(self.n_states, mu_new, log_pi_new)
self.inferer = Inferer(self.n_states)
self.mu = self.vae.latent_space.mu.numpy()
self.pi = np.triu(np.ones(self.n_states))
self.pi[self.pi > 0] = tf.nn.softmax(self.vae.latent_space.pi).numpy()[0]
def train(self, stratify = False, test_size = 0.1, random_state: int = 0,
learning_rate: float = 1e-3, batch_size: int = 256,
L: int = 1, alpha: float = 0.10, beta: float = 1, gamma: float = 0, phi: float = 1,
num_epoch: int = 200, num_step_per_epoch: Optional[int] = None,
early_stopping_patience: int = 10, early_stopping_tolerance: float = 0.01,
early_stopping_relative: bool = True, early_stopping_warmup: int = 0,
path_to_weights: Optional[str] = None,
verbose: bool = False, **kwargs):
'''Train the model.
Parameters
----------
stratify : np.array, None, or False
If an array is provided, or `stratify=None` and `self.labels` is available, then they will be used to perform stratified shuffle splitting. Otherwise, general shuffle splitting is used. Set to `False` if `self.labels` is not intended for stratified shuffle splitting.
test_size : float or int, optional
The proportion or size of the test set.
random_state : int, optional
The random state for data splitting.
learning_rate : float, optional
The initial learning rate for the Adam optimizer.
batch_size : int, optional
The batch size for training. Default is 256; set to 32 if the number of cells is small (fewer than 1000).
L : int, optional
The number of MC samples.
alpha : float, optional
The value of alpha in [0,1] to encourage covariate adjustment. Not used if there are no covariates.
beta : float, optional
The value of beta in beta-VAE.
gamma : float, optional
The weight of mmd_loss.
phi : float, optional
The weight of the Jacobian norm of the encoder.
num_epoch : int, optional
The maximum number of epochs.
num_step_per_epoch : int, optional
The number of steps per epoch; it will be inferred from the number of cells and batch size if None.
early_stopping_patience : int, optional
The maximum number of epochs if there is no improvement.
early_stopping_tolerance : float, optional
The minimum change of loss to be considered as an improvement.
early_stopping_relative : bool, optional
Whether to monitor the relative change of loss.
early_stopping_warmup : int, optional
The number of warmup epochs.
path_to_weights : str, optional
The path of weight file to be saved; not saving weight if None.
**kwargs :
Extra key-value arguments for dimension reduction algorithms.
'''
if gamma == 0 or self.conditions is None:
conditions = np.array([np.nan] * self.adata.shape[0])
else:
conditions = self.conditions
if stratify is None:
stratify = self.labels
elif stratify is False:
stratify = None
id_train, id_test = train_test_split(
np.arange(self.X_input.shape[0]),
test_size=test_size,
stratify=stratify,
random_state=random_state)
if num_step_per_epoch is None:
num_step_per_epoch = len(id_train)//batch_size+1
c = None if self.covariates is None else self.covariates.astype(tf.keras.backend.floatx())
self.train_dataset = train.warp_dataset(self.X_input[id_train].astype(tf.keras.backend.floatx()),
None if c is None else c[id_train],
batch_size,
self.X_output[id_train].astype(tf.keras.backend.floatx()),
self.scale_factor[id_train].astype(tf.keras.backend.floatx()),
conditions = conditions[id_train],
pi_cov = self.pi_cov[id_train])
self.test_dataset = train.warp_dataset(self.X_input[id_test].astype(tf.keras.backend.floatx()),
None if c is None else c[id_test],
batch_size,
self.X_output[id_test].astype(tf.keras.backend.floatx()),
self.scale_factor[id_test].astype(tf.keras.backend.floatx()),
conditions = conditions[id_test],
pi_cov = self.pi_cov[id_test])
self.vae = train.train(
self.train_dataset,
self.test_dataset,
self.vae,
learning_rate,
L,
alpha,
beta,
gamma,
phi,
num_epoch,
num_step_per_epoch,
early_stopping_patience,
early_stopping_tolerance,
early_stopping_relative,
early_stopping_warmup,
verbose,
**kwargs
)
self.update_z()
self.mu = self.vae.latent_space.mu.numpy()
self.pi = np.triu(np.ones(self.n_states))
self.pi[self.pi > 0] = tf.nn.softmax(self.vae.latent_space.pi).numpy()[0]
if path_to_weights is not None:
self.save_model(path_to_weights)
def output_pi(self, pi_cov):
"""return a matrix n_states by n_states and a mask for plotting, which can be used to cover the lower triangular(except the diagnoals) of a heatmap"""
p = self.vae.pilayer
pi_cov = tf.expand_dims(tf.constant([pi_cov], dtype=tf.float32), 0)
pi_val = tf.nn.softmax(p(pi_cov)).numpy()[0]
# Create heatmap matrix
n = self.vae.n_states
matrix = np.zeros((n, n))
matrix[np.triu_indices(n)] = pi_val
mask = np.tril(np.ones_like(matrix), k=-1)
return matrix, mask
def return_pilayer_weights(self):
"""return parameters of pilayer, which has dimension dim(pi_cov) + 1 by n_categories, the last row is biases"""
return np.vstack((self.vae.pilayer.weights[0].numpy(), self.vae.pilayer.weights[1].numpy().reshape(1, -1)))
def posterior_estimation(self, batch_size: int = 32, L: int = 50, **kwargs):
'''Initialize trajectory inference by computing the posterior estimations.
Parameters
----------
batch_size : int, optional
The batch size when doing inference.
L : int, optional
The number of MC samples when doing inference.
**kwargs :
Extra key-value arguments for dimension reduction algorithms.
'''
c = None if self.covariates is None else self.covariates.astype(tf.keras.backend.floatx())
self.test_dataset = train.warp_dataset(self.X_input.astype(tf.keras.backend.floatx()),
c,
batch_size)
_, _, self.pc_x,\
self.cell_position_posterior,self.cell_position_variance,_ = self.vae.inference(self.test_dataset, L=L)
uni_cluster_labels = self.labels_map['label_names'].to_numpy()
self.adata.obs['vitae_new_clustering'] = uni_cluster_labels[np.argmax(self.cell_position_posterior, 1)]
self.adata.obs['vitae_new_clustering'] = self.adata.obs['vitae_new_clustering'].astype('category')
print("New clustering labels saved as 'vitae_new_clustering' in self.adata.obs.")
return None
def infer_backbone(self, method: str = 'modified_map', thres = 0.5,
no_loop: bool = True, cutoff: float = 0,
visualize: bool = True, color = 'vitae_new_clustering',path_to_fig = None,**kwargs):
''' Compute edge scores.
Parameters
----------
method : string, optional
'mean', 'modified_mean', 'map', or 'modified_map'.
thres : float, optional
The threshold for filtering edges \(e_{ij}\) such that \((n_{i}+n_{j}+e_{ij})/N < thres\); only applied to the 'mean' method.
no_loop : boolean, optional
Whether loops are allowed to exist in the graph. If no_loop is true, the graph will be pruned to contain only the
maximum spanning tree.
cutoff : float, optional
The score threshold for filtering edges with scores less than cutoff.
visualize: boolean
Whether to plot the current trajectory backbone (undirected graph).
Returns
----------
None. The weighted graph, with each edge weight indicating its score of existence, is stored as self.backbone.
'''
# build_graph, return graph
self.backbone = self.inferer.build_graphs(self.cell_position_posterior, self.pc_x,
method, thres, no_loop, cutoff)
self.cell_position_projected = self.inferer.modify_wtilde(self.cell_position_posterior,
np.array(list(self.backbone.edges)))
uni_cluster_labels = self.labels_map['label_names'].to_numpy()
temp_dict = {i:label for i,label in enumerate(uni_cluster_labels)}
nx.relabel_nodes(self.backbone, temp_dict)
self.adata.obs['vitae_new_clustering'] = uni_cluster_labels[np.argmax(self.cell_position_projected, 1)]
self.adata.obs['vitae_new_clustering'] = self.adata.obs['vitae_new_clustering'].astype('category')
print("'vitae_new_clustering' updated based on the projected cell positions.")
self.uncertainty = np.sum((self.cell_position_projected - self.cell_position_posterior)**2, axis=-1) \
+ np.sum(self.cell_position_variance, axis=-1)
self.adata.obs['projection_uncertainty'] = self.uncertainty
print("Cell projection uncertainties stored as 'projection_uncertainty' in self.adata.obs")
if visualize:
self._adata.obs = self.adata.obs.copy()
self.ax = self.plot_backbone(directed = False,color = color, **kwargs)
if path_to_fig is not None:
self.ax.figure.savefig(path_to_fig)
self.ax.figure.show()
return None
def select_root(self, days, method: str = 'proportion'):
'''Order the vertices/states based on cells' collection time information to select the root state.
Parameters
----------
days : np.array
The day information for selected cells used to determine the root vertex.
The dtype should be 'int' or 'float'.
method : str, optional
'proportion' (default) or 'mean'.
For 'proportion', the root is the one with the maximal proportion of cells from the earliest day.
For 'mean', the root is the one with the earliest mean collection time among cells associated with it.
Returns
----------
root_info : pd.DataFrame
The cluster labels with mean collection time and the proportion of cells from the earliest day, sorted by mean collection time, to help select the root vertex.
'''
if days is not None and len(days)!=self.X_input.shape[0]:
raise ValueError("The length of day information ({}) is not "
"consistent with the number of selected cells ({})!".format(
len(days), self.X_input.shape[0]))
if not hasattr(self, 'cell_position_projected'):
raise ValueError("Need to call 'infer_backbone' first!")
collection_time = np.dot(days, self.cell_position_projected)/np.sum(self.cell_position_projected, axis = 0)
earliest_prop = np.dot(days==np.min(days), self.cell_position_projected)/np.sum(self.cell_position_projected, axis = 0)
root_info = self.labels_map.copy()
root_info['mean_collection_time'] = collection_time
root_info['earliest_time_prop'] = earliest_prop
root_info.sort_values('mean_collection_time', inplace=True)
return root_info
def plot_backbone(self, directed: bool = False,
method: str = 'UMAP', color = 'vitae_new_clustering', **kwargs):
'''Plot the current trajectory backbone (undirected graph).
Parameters
----------
directed : boolean, optional
Whether the backbone is directed or not.
method : str, optional
The dimension reduction method to use. The default is "UMAP".
color : str, optional
The key for annotations of observations/cells or variables/genes, e.g., 'ann1' or ['ann1', 'ann2'].
The default is 'vitae_new_clustering'.
**kwargs :
Extra key-value arguments that can be passed to scanpy plotting functions (scanpy.pl.XX).
'''
if not isinstance(color,str):
raise ValueError('The color argument should be of type str!')
ax = self.visualize_latent(method = method, color=color, show=False, **kwargs)
dict_label_num = {j:i for i,j in self.labels_map['label_names'].to_dict().items()}
uni_cluster_labels = self.adata.obs['vitae_init_clustering'].cat.categories
cluster_labels = self.adata.obs['vitae_new_clustering'].to_numpy()
embed_z = self._adata.obsm[self.dict_method_scname[method]]
embed_mu = np.zeros((len(uni_cluster_labels), 2))
for l in uni_cluster_labels:
embed_mu[dict_label_num[l],:] = np.mean(embed_z[cluster_labels==l], axis=0)
if directed:
graph = self.directed_backbone
else:
graph = self.backbone
edges = list(graph.edges)
edge_scores = np.array([d['weight'] for (u,v,d) in graph.edges(data=True)])
if max(edge_scores) - min(edge_scores) == 0:
edge_scores = edge_scores/max(edge_scores)
else:
edge_scores = (edge_scores - min(edge_scores))/(max(edge_scores) - min(edge_scores))*3
value_range = np.maximum(np.diff(ax.get_xlim())[0], np.diff(ax.get_ylim())[0])
y_range = np.min(embed_z[:,1]), np.max(embed_z[:,1], axis=0)
for i in range(len(edges)):
points = embed_z[np.sum(self.cell_position_projected[:, edges[i]]>0, axis=-1)==2,:]
points = points[points[:,0].argsort()]
try:
x_smooth, y_smooth = _get_smooth_curve(
points,
embed_mu[edges[i], :],
y_range
)
except:
x_smooth, y_smooth = embed_mu[edges[i], 0], embed_mu[edges[i], 1]
ax.plot(x_smooth, y_smooth,
'-',
linewidth= 1 + edge_scores[i],
color="black",
alpha=0.8,
path_effects=[pe.Stroke(linewidth=1+edge_scores[i]+1.5,
foreground='white'), pe.Normal()],
zorder=1
)
if directed:
delta_x = embed_mu[edges[i][1], 0] - x_smooth[-2]
delta_y = embed_mu[edges[i][1], 1] - y_smooth[-2]
length = np.sqrt(delta_x**2 + delta_y**2) / 50 * value_range
ax.arrow(
embed_mu[edges[i][1], 0]-delta_x/length,
embed_mu[edges[i][1], 1]-delta_y/length,
delta_x/length,
delta_y/length,
color='black', alpha=1.0,
shape='full', lw=0, length_includes_head=True,
head_width=np.maximum(0.01*(1 + edge_scores[i]), 0.03) * value_range,
zorder=2)
colors = self._adata.uns['vitae_new_clustering_colors']
for i,l in enumerate(uni_cluster_labels):
ax.scatter(*embed_mu[dict_label_num[l]:dict_label_num[l]+1,:].T,
c=[colors[i]], edgecolors='white', # linewidths=10, norm=norm,
s=250, marker='*', label=l)
plt.setp(ax, xticks=[], yticks=[])
box = ax.get_position()
ax.set_position([box.x0, box.y0 + box.height * 0.1,
box.width, box.height * 0.9])
if directed:
ax.legend(loc='upper center', bbox_to_anchor=(0.5, -0.05),
fancybox=True, shadow=True, ncol=5)
return ax
def plot_center(self, color = "vitae_new_clustering", plot_legend = True, legend_add_index = True,
method: str = 'UMAP',ncol = 2,font_size = "medium",
add_egde = False, add_direct = False,**kwargs):
'''Plot the center of each cluster in the latent space.
Parameters
----------
color : str, optional
The clustering annotation to plot, either "vitae_new_clustering" or "vitae_init_clustering". Default is "vitae_new_clustering".
plot_legend : bool, optional
Whether to plot the legend. Default is True.
legend_add_index : bool, optional
Whether to add the index of each cluster in the legend. Default is True.
method : str, optional
The dimension reduction method used for visualization. Default is 'UMAP'.
ncol : int, optional
The number of columns in the legend. Default is 2.
font_size : str, optional
The font size of the legend. Default is "medium".
add_egde : bool, optional
Whether to add the edges between the centers of clusters. Default is False.
add_direct : bool, optional
Whether to add the direction of the edges. Default is False.
'''
if color not in ["vitae_new_clustering","vitae_init_clustering"]:
raise ValueError("Can only plot center of vitae_new_clustering or vitae_init_clustering")
dict_label_num = {j: i for i, j in self.labels_map['label_names'].to_dict().items()}
if legend_add_index:
self._adata.obs["index_"+color] = self._adata.obs[color].map(lambda x: dict_label_num[x])
ax = self.visualize_latent(method=method, color="index_" + color, show=False, legend_loc="on data",
legend_fontsize=font_size,**kwargs)
colors = self._adata.uns["index_" + color + '_colors']
else:
ax = self.visualize_latent(method=method, color = color, show=False,**kwargs)
colors = self._adata.uns[color + '_colors']
uni_cluster_labels = self.adata.obs[color].cat.categories
cluster_labels = self.adata.obs[color].to_numpy()
embed_z = self._adata.obsm[self.dict_method_scname[method]]
embed_mu = np.zeros((len(uni_cluster_labels), 2))
for l in uni_cluster_labels:
embed_mu[dict_label_num[l], :] = np.mean(embed_z[cluster_labels == l], axis=0)
leg = (self.labels_map.index.astype(str) + " : " + self.labels_map.label_names).values
for i, l in enumerate(uni_cluster_labels):
ax.scatter(*embed_mu[dict_label_num[l]:dict_label_num[l] + 1, :].T,
c=[colors[i]], edgecolors='white', # linewidths=3,
s=250, marker='*', label=leg[i])
if plot_legend:
ax.legend(loc='center left', bbox_to_anchor=(1, 0.5), ncol=ncol, markerscale=0.8, frameon=False)
plt.setp(ax, xticks=[], yticks=[])
box = ax.get_position()
ax.set_position([box.x0, box.y0 + box.height * 0.1,
box.width, box.height * 0.9])
if add_egde:
if add_direct:
graph = self.directed_backbone
else:
graph = self.backbone
edges = list(graph.edges)
edge_scores = np.array([d['weight'] for (u, v, d) in graph.edges(data=True)])
if max(edge_scores) - min(edge_scores) == 0:
edge_scores = edge_scores / max(edge_scores)
else:
edge_scores = (edge_scores - min(edge_scores)) / (max(edge_scores) - min(edge_scores)) * 3
value_range = np.maximum(np.diff(ax.get_xlim())[0], np.diff(ax.get_ylim())[0])
y_range = np.min(embed_z[:, 1]), np.max(embed_z[:, 1], axis=0)
for i in range(len(edges)):
points = embed_z[np.sum(self.cell_position_projected[:, edges[i]] > 0, axis=-1) == 2, :]
points = points[points[:, 0].argsort()]
try:
x_smooth, y_smooth = _get_smooth_curve(
points,
embed_mu[edges[i], :],
y_range
)
except:
x_smooth, y_smooth = embed_mu[edges[i], 0], embed_mu[edges[i], 1]
ax.plot(x_smooth, y_smooth,
'-',
linewidth=1 + edge_scores[i],
color="black",
alpha=0.8,
path_effects=[pe.Stroke(linewidth=1 + edge_scores[i] + 1.5,
foreground='white'), pe.Normal()],
zorder=1
)
if add_direct:
delta_x = embed_mu[edges[i][1], 0] - x_smooth[-2]
delta_y = embed_mu[edges[i][1], 1] - y_smooth[-2]
length = np.sqrt(delta_x ** 2 + delta_y ** 2) / 50 * value_range
ax.arrow(
embed_mu[edges[i][1], 0] - delta_x / length,
embed_mu[edges[i][1], 1] - delta_y / length,
delta_x / length,
delta_y / length,
color='black', alpha=1.0,
shape='full', lw=0, length_includes_head=True,
head_width=np.maximum(0.01 * (1 + edge_scores[i]), 0.03) * value_range,
zorder=2)
self.ax = ax
self.ax.figure.show()
return None
def infer_trajectory(self, root: Union[int,str], digraph = None, color = "pseudotime",
visualize: bool = True, path_to_fig = None, **kwargs):
'''Infer the trajectory.
Parameters
----------
root : int or string
The root of the inferred trajectory. Can provide either an int (vertex index) or string (label name)
digraph : nx.DiGraph, optional
The directed graph to be used for trajectory inference. If None, the minimum spanning tree of the estimated trajectory backbone will be used.
visualize: boolean
Whether to plot the current trajectory backbone (directed graph).
path_to_fig : string, optional
The path to save figure, or don't save if it is None.
**kwargs : dict, optional
Other keywords arguments for plotting.
'''
if isinstance(root,str):
if root not in self.labels_map.values:
raise ValueError("Root {} is not in the label names!".format(root))
root = self.labels_map[self.labels_map['label_names']==root].index[0]
if digraph is None:
connected_comps = nx.node_connected_component(self.backbone, root)
subG = self.backbone.subgraph(connected_comps)
## generate directed backbone which contains no loops
DG = nx.DiGraph(nx.to_directed(self.backbone))
temp = DG.subgraph(connected_comps)
DG.remove_edges_from(temp.edges - nx.dfs_edges(DG, root))
self.directed_backbone = DG
else:
if not nx.is_directed_acyclic_graph(digraph):
raise ValueError("The graph 'digraph' should be a directed acyclic graph.")
if set(digraph.nodes) != set(self.backbone.nodes):
raise ValueError("The nodes in 'digraph' do not match the nodes in 'self.backbone'.")
self.directed_backbone = digraph
connected_comps = nx.descendants(digraph, root) | {root}  # nodes reachable from the root in the directed graph
subG = self.backbone.subgraph(connected_comps)
if len(subG.edges)>0:
milestone_net = self.inferer.build_milestone_net(subG, root)
if self.inferer.no_loop is False and milestone_net.shape[0]<len(self.backbone.edges):
warnings.warn("The directed graph shown is a minimum spanning tree of the estimated trajectory backbone to avoid arbitrary assignment of the directions.")
self.pseudotime = self.inferer.comp_pseudotime(milestone_net, root, self.cell_position_projected)
else:
warnings.warn("There are no connected states for starting from the giving root.")
self.pseudotime = -np.ones(self._adata.shape[0])
self.adata.obs['pseudotime'] = self.pseudotime
print("Cell projection uncertainties stored as 'pseudotime' in self.adata.obs")
if visualize:
self._adata.obs['pseudotime'] = self.pseudotime
self.ax = self.plot_backbone(directed = True, color = color, **kwargs)
if path_to_fig is not None:
self.ax.figure.savefig(path_to_fig)
self.ax.figure.show()
return None
def differential_expression_test(self, alpha: float = 0.05, cell_subset = None, order: int = 1):
'''Differential gene expression test. All (selected and unselected) genes will be tested.
Only cells in `cell_subset` will be used, which is useful when one needs to
test differentially expressed genes on a branch of the inferred trajectory.
Parameters
----------
alpha : float, optional
The cutoff of p-values.
cell_subset : np.array, optional
The subset of cells to be used for testing. If None, all cells will be used.
order : int, optional
The maximum order of pseudotime used in the regression.
Returns
----------
res_df : pandas.DataFrame
The test results of expressed genes with two columns,
the estimated coefficients and the adjusted p-values.
'''
if not hasattr(self, 'pseudotime'):
raise ReferenceError("Pseudotime does not exist! Please run 'infer_trajectory' first.")
if cell_subset is None:
cell_subset = np.arange(self.X_input.shape[0])
print("All cells are selected.")
if order < 1:
raise ValueError("Maximal order of pseudotime in regression must be at least 1.")
# Prepare X and Y for regression expression ~ rank(PDT) + covariates
Y = self.adata.X[cell_subset,:]
# std_Y = np.std(Y, ddof=1, axis=0, keepdims=True)
# Y = np.divide(Y-np.mean(Y, axis=0, keepdims=True), std_Y, out=np.empty_like(Y)*np.nan, where=std_Y!=0)
X = stats.rankdata(self.pseudotime[cell_subset])
if order > 1:
for _order in range(2, order+1):
X = np.c_[X, X**_order]
X = ((X-np.mean(X,axis=0, keepdims=True))/np.std(X, ddof=1, axis=0, keepdims=True))
X = np.c_[np.ones((X.shape[0],1)), X]
if self.covariates is not None:
X = np.c_[X, self.covariates[cell_subset, :]]
res_df = DE_test(Y, X, self.adata.var_names, i_test = np.array(list(range(1,order+1))), alpha = alpha)
return res_df[res_df.pvalue_adjusted_1 != 0]
def evaluate(self, milestone_net, begin_node_true, grouping = None,
thres: float = 0.5, no_loop: bool = True, cutoff: Optional[float] = None,
method: str = 'mean', path: Optional[str] = None):
''' Evaluate the model.
Parameters
----------
milestone_net : pd.DataFrame
The true milestone network. For real data, milestone_net will be a DataFrame of the graph of nodes.
Eg.
from|to
---|---
cluster 1 | cluster 1
cluster 1 | cluster 2
For synthetic data, milestone_net will be a DataFrame of the (projected)
positions of cells. The indexes are the orders of cells in the dataset.
Eg.
from|to|w
---|---|---
cluster 1 | cluster 1 | 1
cluster 1 | cluster 2 | 0.1
begin_node_true : str or int
The true begin node of the milestone.
grouping : np.array, optional
\([N,]\) The labels. For real data, grouping must be provided.
Returns
----------
res : pd.DataFrame
The evaluation result.
'''
if not hasattr(self, 'labels_map'):
raise ValueError("No given labels for training.")
'''
# Evaluate for the whole dataset will ignore selected_cell_subset.
if len(self.selected_cell_subset)!=len(self.cell_names):
warnings.warn("Evaluate for the whole dataset.")
'''
# If begin_node_true is a string, encode it using the label map.
# this dict is for milestone net cause their labels are not merged
# all keys of label_map_dict are str
label_map_dict = dict()
for i in range(self.labels_map.shape[0]):
label_mapped = self.labels_map.loc[i]
## merged cluster index is connected by comma
for each in label_mapped.values[0].split(","):
label_map_dict[each] = i
if isinstance(begin_node_true, str):
begin_node_true = label_map_dict[begin_node_true]
# For generated data, grouping information is already in milestone_net
if 'w' in milestone_net.columns:
grouping = None
# If milestone_net is provided, transform them to be numeric.
if milestone_net is not None:
milestone_net['from'] = [label_map_dict[x] for x in milestone_net["from"]]
milestone_net['to'] = [label_map_dict[x] for x in milestone_net["to"]]
# this dict is for potentially merged clusters.
label_map_dict_for_merged_cluster = dict(zip(self.labels_map["label_names"],self.labels_map.index))
mapped_labels = np.array([label_map_dict_for_merged_cluster[x] for x in self.labels])
begin_node_pred = int(np.argmin(np.mean((
self.z[mapped_labels==begin_node_true,:,np.newaxis] -
self.mu[np.newaxis,:,:])**2, axis=(0,1))))
if cutoff is None:
cutoff = 0.01
G = self.backbone
w = self.cell_position_projected
pseudotime = self.pseudotime
# 1. Topology
G_pred = nx.Graph()
G_pred.add_nodes_from(G.nodes)
G_pred.add_edges_from(G.edges)
nx.set_node_attributes(G_pred, False, 'is_init')
G_pred.nodes[begin_node_pred]['is_init'] = True
G_true = nx.Graph()
G_true.add_nodes_from(G.nodes)
# if 'grouping' is not provided, assume 'milestone_net' contains proportions
if grouping is None:
G_true.add_edges_from(list(
milestone_net[~pd.isna(milestone_net['w'])].groupby(['from', 'to']).count().index))
# otherwise, 'milestone_net' indicates edges
else:
if milestone_net is not None:
G_true.add_edges_from(list(
milestone_net.groupby(['from', 'to']).count().index))
grouping = [label_map_dict[x] for x in grouping]
grouping = np.array(grouping)
G_true.remove_edges_from(nx.selfloop_edges(G_true))
nx.set_node_attributes(G_true, False, 'is_init')
G_true.nodes[begin_node_true]['is_init'] = True
res = topology(G_true, G_pred)
# 2. Milestones assignment
if grouping is None:
milestones_true = milestone_net['from'].values.copy()
milestones_true[(milestone_net['from']!=milestone_net['to'])
&(milestone_net['w']<0.5)] = milestone_net[(milestone_net['from']!=milestone_net['to'])
&(milestone_net['w']<0.5)]['to'].values
else:
milestones_true = grouping
milestones_pred = np.argmax(w, axis=1)
res['ARI'] = (adjusted_rand_score(milestones_true, milestones_pred) + 1)/2
if grouping is None:
n_samples = len(milestone_net)
prop = np.zeros((n_samples,n_samples))
prop[np.arange(n_samples), milestone_net['to']] = 1-milestone_net['w']
prop[np.arange(n_samples), milestone_net['from']] = np.where(np.isnan(milestone_net['w']), 1, milestone_net['w'])
res['GRI'] = get_GRI(prop, w)
else:
res['GRI'] = get_GRI(grouping, w)
# 3. Correlation between geodesic distances / Pseudotime
if no_loop:
if grouping is None:
pseudotime_true = milestone_net['from'].values + 1 - milestone_net['w'].values
pseudotime_true[np.isnan(pseudotime_true)] = milestone_net[pd.isna(milestone_net['w'])]['from'].values
else:
pseudotime_true = - np.ones(len(grouping))
nx.set_edge_attributes(G_true, values = 1, name = 'weight')
connected_comps = nx.node_connected_component(G_true, begin_node_true)
subG = G_true.subgraph(connected_comps)
milestone_net_true = self.inferer.build_milestone_net(subG, begin_node_true)
if len(milestone_net_true)>0:
pseudotime_true[grouping==int(milestone_net_true[0,0])] = 0
for i in range(len(milestone_net_true)):
pseudotime_true[grouping==int(milestone_net_true[i,1])] = milestone_net_true[i,-1]
pseudotime_true = pseudotime_true[pseudotime>-1]
pseudotime_pred = pseudotime[pseudotime>-1]
res['PDT score'] = (np.corrcoef(pseudotime_true,pseudotime_pred)[0,1]+1)/2
else:
res['PDT score'] = np.nan
# 4. Shape
# score_cos_theta = 0
# for (_from,_to) in G.edges:
# _z = self.z[(w[:,_from]>0) & (w[:,_to]>0),:]
# v_1 = _z - self.mu[:,_from]
# v_2 = _z - self.mu[:,_to]
# cos_theta = np.sum(v_1*v_2, -1)/(np.linalg.norm(v_1,axis=-1)*np.linalg.norm(v_2,axis=-1)+1e-12)
# score_cos_theta += np.sum((1-cos_theta)/2)
# res['score_cos_theta'] = score_cos_theta/(np.sum(np.sum(w>0, axis=-1)==2)+1e-12)
return res
def save_model(self, path_to_file: str = 'model.checkpoint',save_adata: bool = False):
'''Saving model weights.
Parameters
----------
path_to_file : str, optional
The path to weight files of pre-trained or trained model
save_adata : boolean, optional
Whether to save adata or not.
'''
self.vae.save_weights(path_to_file)
if hasattr(self, 'labels') and self.labels is not None:
with open(path_to_file + '.label', 'wb') as f:
np.save(f, self.labels)
with open(path_to_file + '.config', 'wb') as f:
self.dim_origin = self.X_input.shape[1]
np.save(f, np.array([
self.dim_origin, self.dimensions, self.dim_latent,
self.model_type, 0 if self.covariates is None else self.covariates.shape[1]], dtype=object))
if hasattr(self, 'inferer') and hasattr(self, 'uncertainty'):
with open(path_to_file + '.inference', 'wb') as f:
np.save(f, np.array([
self.pi, self.mu, self.pc_x, self.cell_position_posterior, self.uncertainty,
self.z,self.cell_position_variance], dtype=object))
if save_adata:
self.adata.write(path_to_file + '.adata.h5ad')
def load_model(self, path_to_file: str = 'model.checkpoint', load_labels: bool = False, load_adata: bool = False):
'''Load model weights.
Parameters
----------
path_to_file : str, optional
The path to weight files of the pre-trained or trained model.
load_labels : boolean, optional
Whether to load clustering labels or not.
If load_labels is True, then the LatentSpace layer will be initialized based on the model.
If load_labels is False, then the LatentSpace layer will not be initialized.
load_adata : boolean, optional
Whether to load adata or not.
'''
if not os.path.exists(path_to_file + '.config'):
raise AssertionError('Config file does not exist!')
if load_labels and not os.path.exists(path_to_file + '.label'):
raise AssertionError('Label file does not exist!')
with open(path_to_file + '.config', 'rb') as f:
[self.dim_origin, self.dimensions,
self.dim_latent, self.model_type, cov_dim] = np.load(f, allow_pickle=True)
self.vae = model.VariationalAutoEncoder(
self.dim_origin, self.dimensions,
self.dim_latent, self.model_type, False if cov_dim == 0 else True
)
if load_labels:
with open(path_to_file + '.label', 'rb') as f:
cluster_labels = np.load(f, allow_pickle=True)
self.init_latent_space(cluster_labels, dist_thres=0)
if os.path.exists(path_to_file + '.inference'):
with open(path_to_file + '.inference', 'rb') as f:
arr = np.load(f, allow_pickle=True)
if len(arr) == 8:
[self.pi, self.mu, self.pc_x, self.cell_position_posterior, self.uncertainty,
self.D_JS, self.z,self.cell_position_variance] = arr
else:
[self.pi, self.mu, self.pc_x, self.cell_position_posterior, self.uncertainty,
self.z,self.cell_position_variance] = arr
self._adata_z = sc.AnnData(self.z)
sc.pp.neighbors(self._adata_z)
## initialize the weight of encoder and decoder
self.vae.encoder(np.zeros((1, self.dim_origin + cov_dim)))
self.vae.decoder(np.expand_dims(np.zeros((1,self.dim_latent + cov_dim)),1))
self.vae.load_weights(path_to_file)
self.update_z()
if load_adata:
if not os.path.exists(path_to_file + '.adata.h5ad'):
raise AssertionError('AnnData file does not exist!')
self.adata = sc.read_h5ad(path_to_file + '.adata.h5ad')
self._adata.obs = self.adata.obs.copy()</code></pre>
</details>
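<p>A sketch of a typical end-to-end workflow built from the methods above (the root label <code>'stem'</code> and all hyperparameter values are illustrative assumptions, not recommendations):</p>
<pre><code class="python"># Pretrain the autoencoder, initialize the latent space, then train the full model
vt.pre_train(learning_rate=1e-3, batch_size=256)
vt.init_latent_space(res=1.0)             # leiden clustering on the pretrained z
vt.train(beta=1, learning_rate=1e-3)

# Posterior estimation and trajectory inference
vt.posterior_estimation(L=50)
vt.infer_backbone(method='modified_map')
vt.infer_trajectory(root='stem')          # root as a label name or vertex index</code></pre>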
<h3>Methods</h3>
<dl>
<dt id="VITAE.VITAE.pre_train"><code class="name flex">
<span>def <span class="ident">pre_train</span></span>(<span>self, test_size=0.1, random_state: int = 0, learning_rate: float = 0.001, batch_size: int = 256, L: int = 1, alpha: float = 0.1, gamma: float = 0, phi: float = 1, num_epoch: int = 200, num_step_per_epoch: Optional[int] = None, early_stopping_patience: int = 10, early_stopping_tolerance: float = 0.01, early_stopping_relative: bool = True, verbose: bool = False, path_to_weights: Optional[str] = None)</span>
</code></dt>
<dd>
<div class="desc"><p>Pretrain the model with the specified learning rate.</p>
<h2 id="parameters">Parameters</h2>
<dl>
<dt><strong><code>test_size</code></strong> : <code>float</code> or <code>int</code>, optional</dt>
<dd>The proportion or size of the test set.</dd>
<dt><strong><code>random_state</code></strong> : <code>int</code>, optional</dt>
<dd>The random state for data splitting.</dd>
<dt><strong><code>learning_rate</code></strong> : <code>float</code>, optional</dt>
<dd>The initial learning rate for the Adam optimizer.</dd>
<dt><strong><code>batch_size</code></strong> : <code>int</code>, optional</dt>
<dd>The batch size for pre-training.
Default is 256. Consider setting it to 32 if the number of cells is small (fewer than 1,000).</dd>
<dt><strong><code>L</code></strong> : <code>int</code>, optional</dt>
<dd>The number of MC samples.</dd>
<dt><strong><code>alpha</code></strong> : <code>float</code>, optional</dt>
<dd>The value of alpha in [0,1] to encourage covariate adjustment. Not used if there are no covariates.</dd>
<dt><strong><code>gamma</code></strong> : <code>float</code>, optional</dt>
<dd>The weight of the MMD loss, if used.</dd>
<dt><strong><code>phi</code></strong> : <code>float</code>, optional</dt>
<dd>The weight of the Jacobian norm of the encoder.</dd>
<dt><strong><code>num_epoch</code></strong> : <code>int</code>, optional</dt>
<dd>The maximum number of epochs.</dd>
<dt><strong><code>num_step_per_epoch</code></strong> : <code>int</code>, optional</dt>
<dd>The number of steps per epoch; if None, it is inferred from the number of cells and the batch size.</dd>
<dt><strong><code>early_stopping_patience</code></strong> : <code>int</code>, optional</dt>
<dd>The maximum number of epochs if there is no improvement.</dd>
<dt><strong><code>early_stopping_tolerance</code></strong> : <code>float</code>, optional</dt>
<dd>The minimum change of loss to be considered as an improvement.</dd>
<dt><strong><code>early_stopping_relative</code></strong> : <code>bool</code>, optional</dt>
<dd>Whether to monitor the relative change of the loss as the stopping criterion.</dd>
<dt><strong><code>path_to_weights</code></strong> : <code>str</code>, optional</dt>
<dd>The path of weight file to be saved; not saving weight if None.</dd>
<dt><strong><code>conditions</code></strong> : <code>str</code> or <code>list</code>, optional</dt>
<dd>The conditions of different cells.</dd>
</dl></div>
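<p>A minimal usage sketch (an illustration, not from the package docs; it assumes a <code>VITAE</code> instance named <code>model</code> has already been constructed):</p>
<pre><code class="language-python"># Hypothetical usage; `model` is an already-constructed VITAE object.
model.pre_train(
    learning_rate=1e-3,          # initial Adam learning rate
    batch_size=256,              # consider 32 for fewer than 1,000 cells
    num_epoch=200,               # upper bound; early stopping may end training sooner
    early_stopping_patience=10,  # stop after 10 epochs without improvement
)
</code></pre>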
</dd>
<dt id="VITAE.VITAE.update_z"><code class="name flex">
<span>def <span class="ident">update_z</span></span>(<span>self)</span>
</code></dt>
<dd>
<div class="desc"></div>
</dd>
<dt id="VITAE.VITAE.get_latent_z"><code class="name flex">
<span>def <span class="ident">get_latent_z</span></span>(<span>self)</span>
</code></dt>
<dd>
<div class="desc"><p>Get the posterior mean of the current latent space z (the encoder output).</p>
<h2 id="returns">Returns</h2>
<dl>
<dt><strong><code>z</code></strong> : <code>np.array</code></dt>
<dd><span><span class="MathJax_Preview">[N,d]</span><script type="math/tex">[N,d]</script></span> The latent means.</dd>
</dl></div>
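<p>For example (a sketch assuming a pretrained or trained instance <code>model</code>):</p>
<pre><code class="language-python">z = model.get_latent_z()  # np.array of shape [N, d]
print(z.shape)            # (number of cells, latent dimension)
</code></pre>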
</dd>
<dt id="VITAE.VITAE.visualize_latent"><code class="name flex">
<span>def <span class="ident">visualize_latent</span></span>(<span>self, method: str = 'UMAP', color=None, **kwargs)</span>
</code></dt>
<dd>
<div class="desc"><p>Visualize the current latent space z using the scanpy visualization tools.</p>
<h2 id="parameters">Parameters</h2>
<dl>
<dt><strong><code>method</code></strong> : <code>str</code>, optional</dt>
<dd>Visualization method to use. The default is "UMAP". Possible choices include "PCA", "UMAP",
"diffmap", "TSNE", and "draw_graph" (the FA plot).</dd>
<dt><strong><code>color</code></strong> : <code>str</code> or <code>list</code> of <code>str</code>, optional</dt>
<dd>Keys for annotations of observations/cells or variables/genes, e.g., 'ann1' or ['ann1', 'ann2'].
The default is None. Same as scanpy.</dd>
<dt><strong><code>**kwargs</code></strong> : <code> </code></dt>
<dd>Extra key-value arguments that can be passed to scanpy plotting functions (scanpy.pl.XX).</dd>
</dl>
<h2 id="returns">Returns</h2>
<p>None.</p></div>
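<p>A sketch of typical calls (the annotation keys are hypothetical; any key in <code>adata.obs</code> works, as in scanpy):</p>
<pre><code class="language-python">model.visualize_latent(method='UMAP', color='vitae_new_clustering')
model.visualize_latent(method='TSNE', color=['ann1', 'ann2'])  # 'ann1'/'ann2' are placeholder keys
</code></pre>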
</dd>
<dt id="VITAE.VITAE.init_latent_space"><code class="name flex">
<span>def <span class="ident">init_latent_space</span></span>(<span>self, cluster_label=None, log_pi=None, res: float = 1.0, ratio_prune=None, dist=None, dist_thres=0.5, topk=0, pilayer=False)</span>
</code></dt>
<dd>
<div class="desc"><p>Initialize the latent space.</p>
<h2 id="parameters">Parameters</h2>
<dl>
<dt><strong><code>cluster_label</code></strong> : <code>str</code>, optional</dt>
<dd>The name of the vector of labels that can be found in self.adata.obs.
Default is None, in which case Leiden clustering is performed on the pretrained z to obtain clusters.</dd>
<dt><strong><code>mu</code></strong> : <code>np.array</code>, optional</dt>
<dd><span><span class="MathJax_Preview">[d,k]</span><script type="math/tex">[d,k]</script></span> The value of initial <span><span class="MathJax_Preview">\mu</span><script type="math/tex">\mu</script></span>.</dd>
<dt><strong><code>log_pi</code></strong> : <code>np.array</code>, optional</dt>
<dd><span><span class="MathJax_Preview">[1,K]</span><script type="math/tex">[1,K]</script></span> The value of initial <span><span class="MathJax_Preview">\log(\pi)</span><script type="math/tex">\log(\pi)</script></span>.</dd>
<dt><strong><code>res</code></strong></dt>
<dd>The resolution of Leiden clustering, a parameter controlling the coarseness of the clustering.
Higher values lead to more clusters. Default is 1.</dd>
<dt><strong><code>ratio_prune</code></strong> : <code>float</code>, optional</dt>
<dd>The ratio of edges to be removed before estimating.</dd>
<dt><strong><code>topk</code></strong> : <code>int</code>, optional</dt>
<dd>The number of top k neighbors to keep for each cluster.</dd>
</dl></div>
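<p>Two typical ways to call it (a sketch; <code>'leiden'</code> is a hypothetical column in <code>self.adata.obs</code>):</p>
<pre><code class="language-python"># Use existing labels stored in adata.obs:
model.init_latent_space(cluster_label='leiden')

# Or keep cluster_label=None (the default) to run Leiden clustering on the pretrained z:
model.init_latent_space(res=1.0)
</code></pre>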
</dd>
<dt id="VITAE.VITAE.update_latent_space"><code class="name flex">
<span>def <span class="ident">update_latent_space</span></span>(<span>self, dist_thres: float = 0.5)</span>
</code></dt>
<dd>
<div class="desc"></div>
</dd>
<dt id="VITAE.VITAE.train"><code class="name flex">
<span>def <span class="ident">train</span></span>(<span>self, stratify=False, test_size=0.1, random_state: int = 0, learning_rate: float = 0.001, batch_size: int = 256, L: int = 1, alpha: float = 0.1, beta: float = 1, gamma: float = 0, phi: float = 1, num_epoch: int = 200, num_step_per_epoch: Optional[int] = None, early_stopping_patience: int = 10, early_stopping_tolerance: float = 0.01, early_stopping_relative: bool = True, early_stopping_warmup: int = 0, path_to_weights: Optional[str] = None, verbose: bool = False, **kwargs)</span>
</code></dt>
<dd>
<div class="desc"><p>Train the model.</p>
<h2 id="parameters">Parameters</h2>
<dl>
<dt><strong><code>stratify</code></strong> : <code>np.array, None,</code> or <code>False</code></dt>
<dd>If an array is provided, or <code>stratify=None</code> and <code>self.labels</code> is available, then they will be used to perform stratified shuffle splitting. Otherwise, general shuffle splitting is used. Set to <code>False</code> if <code>self.labels</code> is not intended for stratified shuffle splitting.</dd>
<dt><strong><code>test_size</code></strong> : <code>float</code> or <code>int</code>, optional</dt>
<dd>The proportion or size of the test set.</dd>
<dt><strong><code>random_state</code></strong> : <code>int</code>, optional</dt>
<dd>The random state for data splitting.</dd>
<dt><strong><code>learning_rate</code></strong> : <code>float</code>, optional</dt>
<dd>The initial learning rate for the Adam optimizer.</dd>
<dt><strong><code>batch_size</code></strong> : <code>int</code>, optional</dt>
<dd>The batch size for training. Default is 256. Consider setting it to 32 if the number of cells is small (fewer than 1,000).</dd>
<dt><strong><code>L</code></strong> : <code>int</code>, optional</dt>
<dd>The number of MC samples.</dd>
<dt><strong><code>alpha</code></strong> : <code>float</code>, optional</dt>
<dd>The value of alpha in [0,1] to encourage covariate adjustment. Not used if there are no covariates.</dd>
<dt><strong><code>beta</code></strong> : <code>float</code>, optional</dt>
<dd>The value of beta in beta-VAE.</dd>
<dt><strong><code>gamma</code></strong> : <code>float</code>, optional</dt>
<dd>The weight of the MMD loss.</dd>
<dt><strong><code>phi</code></strong> : <code>float</code>, optional</dt>
<dd>The weight of the Jacobian norm of the encoder.</dd>
<dt><strong><code>num_epoch</code></strong> : <code>int</code>, optional</dt>
<dd>The maximum number of epochs.</dd>
<dt><strong><code>num_step_per_epoch</code></strong> : <code>int</code>, optional</dt>
<dd>The number of steps per epoch; if None, it is inferred from the number of cells and the batch size.</dd>
<dt><strong><code>early_stopping_patience</code></strong> : <code>int</code>, optional</dt>
<dd>The maximum number of epochs if there is no improvement.</dd>
<dt><strong><code>early_stopping_tolerance</code></strong> : <code>float</code>, optional</dt>
<dd>The minimum change of loss to be considered as an improvement.</dd>
<dt><strong><code>early_stopping_relative</code></strong> : <code>bool</code>, optional</dt>
<dd>Whether to monitor the relative change of the loss or not.</dd>
<dt><strong><code>early_stopping_warmup</code></strong> : <code>int</code>, optional</dt>
<dd>The number of warmup epochs.</dd>
<dt><strong><code>path_to_weights</code></strong> : <code>str</code>, optional</dt>
<dd>The path of weight file to be saved; not saving weight if None.</dd>
<dt><strong><code>**kwargs</code></strong> : <code> </code></dt>
<dd>Extra key-value arguments for dimension reduction algorithms.</dd>
</dl></div>
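<p>A minimal sketch, continuing from <code>pre_train</code> and <code>init_latent_space</code>:</p>
<pre><code class="language-python">model.train(
    beta=1,                    # weight of the KL term, as in beta-VAE
    num_epoch=200,
    early_stopping_patience=10,
    early_stopping_warmup=0,   # no warmup epochs before early stopping applies
)
</code></pre>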
</dd>
<dt id="VITAE.VITAE.output_pi"><code class="name flex">
<span>def <span class="ident">output_pi</span></span>(<span>self, pi_cov)</span>
</code></dt>
<dd>
<div class="desc"><p>Return an n_states-by-n_states matrix and a mask for plotting, which can be used to cover the lower triangle (excluding the diagonal) of a heatmap.</p></div>
</dd>
<dt id="VITAE.VITAE.return_pilayer_weights"><code class="name flex">
<span>def <span class="ident">return_pilayer_weights</span></span>(<span>self)</span>
</code></dt>
<dd>
<div class="desc"><p>Return the parameters of the pilayer, which have dimension (dim(pi_cov) + 1)-by-n_categories; the last row contains the biases.</p></div>
</dd>
<dt id="VITAE.VITAE.posterior_estimation"><code class="name flex">
<span>def <span class="ident">posterior_estimation</span></span>(<span>self, batch_size: int = 32, L: int = 50, **kwargs)</span>
</code></dt>
<dd>
<div class="desc"><p>Initialize trajectory inference by computing the posterior estimations.
</p>
<h2 id="parameters">Parameters</h2>
<dl>
<dt><strong><code>batch_size</code></strong> : <code>int</code>, optional</dt>
<dd>The batch size when doing inference.</dd>
<dt><strong><code>L</code></strong> : <code>int</code>, optional</dt>
<dd>The number of MC samples when doing inference.</dd>
<dt><strong><code>**kwargs</code></strong> : <code> </code></dt>
<dd>Extra key-value arguments for dimension reduction algorithms.</dd>
</dl></div>
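<p>For example (a sketch; a larger <code>L</code> gives more stable posterior estimates at a higher computational cost):</p>
<pre><code class="language-python">model.posterior_estimation(batch_size=32, L=50)
</code></pre>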
</dd>
<dt id="VITAE.VITAE.infer_backbone"><code class="name flex">
<span>def <span class="ident">infer_backbone</span></span>(<span>self, method: str = 'modified_map', thres=0.5, no_loop: bool = True, cutoff: float = 0, visualize: bool = True, color='vitae_new_clustering', path_to_fig=None, **kwargs)</span>
</code></dt>
<dd>
<div class="desc"><p>Compute edge scores.</p>
<h2 id="parameters">Parameters</h2>
<dl>
<dt><strong><code>method</code></strong> : <code>string</code>, optional</dt>
<dd>'mean', 'modified_mean', 'map', or 'modified_map'.</dd>
<dt><strong><code>thres</code></strong> : <code>float</code>, optional</dt>
<dd>The threshold used for filtering edges <span><span class="MathJax_Preview">e_{ij}</span><script type="math/tex">e_{ij}</script></span> that <span><span class="MathJax_Preview">(n_{i}+n_{j}+e_{ij})/N<thres</span><script type="math/tex">(n_{i}+n_{j}+e_{ij})/N<thres</script></span>, only applied to mean method.</dd>
<dt><strong><code>no_loop</code></strong> : <code>boolean</code>, optional</dt>
<dd>Whether loops are allowed to exist in the graph. If no_loop is true, the graph will be pruned to contain only the
maximum spanning tree.</dd>
<dt><strong><code>cutoff</code></strong> : <code>float</code>, optional</dt>
<dd>The score threshold for filtering edges with scores less than cutoff.</dd>
<dt><strong><code>visualize</code></strong> : <code>boolean</code></dt>
<dd>Whether to plot the current trajectory backbone (an undirected graph).</dd>
</dl>
<h2 id="returns">Returns</h2>
<dl>
<dt><strong><code>G</code></strong> : <code>nx.Graph</code></dt>
<dd>The weighted graph with weight on each edge indicating its score of existence.</dd>
</dl></div>
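<p>A sketch of a typical call; the returned <code>nx.Graph</code> can be inspected directly:</p>
<pre><code class="language-python">G = model.infer_backbone(method='modified_map', no_loop=True, cutoff=0)
print(G.edges(data=True))  # each edge carries its estimated existence score
</code></pre>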
</dd>
<dt id="VITAE.VITAE.select_root"><code class="name flex">
<span>def <span class="ident">select_root</span></span>(<span>self, days, method: str = 'proportion')</span>
</code></dt>
<dd>
<div class="desc"><p>Order the vertices/states based on cells' collection time information to select the root state.
</p>
<h2 id="parameters">Parameters</h2>
<dl>
<dt><strong><code>days</code></strong> : <code>np.array</code></dt>
<dd>The day information for selected cells used to determine the root vertex.
The dtype should be 'int' or 'float'.</dd>
<dt><strong><code>method</code></strong> : <code>str</code>, optional</dt>
<dd>'proportion' or 'mean'.
For 'proportion', the root is the one with the maximal proportion of cells from the earliest day.
For 'mean', the root is the one with the earliest mean time among cells associated with it.</dd>
</dl>
<h2 id="returns">Returns</h2>
<dl>
<dt><strong><code>root</code></strong> : <code>int </code></dt>
<dd>The root vertex in the inferred trajectory based on given day information.</dd>
</dl></div>
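<p>For instance (a sketch; <code>'day'</code> is a hypothetical numeric column in <code>adata.obs</code> aligned with the selected cells):</p>
<pre><code class="language-python">import numpy as np

days = np.asarray(model.adata.obs['day'], dtype=float)
root = model.select_root(days, method='proportion')
</code></pre>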
</dd>
<dt id="VITAE.VITAE.plot_backbone"><code class="name flex">
<span>def <span class="ident">plot_backbone</span></span>(<span>self, directed: bool = False, method: str = 'UMAP', color='vitae_new_clustering', **kwargs)</span>
</code></dt>
<dd>
<div class="desc"><p>Plot the current trajectory backbone (undirected graph).</p>
<h2 id="parameters">Parameters</h2>
<dl>
<dt><strong><code>directed</code></strong> : <code>boolean</code>, optional</dt>
<dd>Whether the backbone is directed or not.</dd>
<dt><strong><code>method</code></strong> : <code>str</code>, optional</dt>
<dd>The dimension reduction method to use. The default is "UMAP".</dd>
<dt><strong><code>color</code></strong> : <code>str</code>, optional</dt>
<dd>The key for annotations of observations/cells or variables/genes, e.g., 'ann1' or ['ann1', 'ann2'].
The default is 'vitae_new_clustering'.</dd>
<dt><strong><code>**kwargs</code></strong> : <code> </code></dt>
<dd>Extra key-value arguments that can be passed to scanpy plotting functions (scanpy.pl.XX).</dd>
</dl></div>
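<p>For example (a sketch):</p>
<pre><code class="language-python">model.plot_backbone(directed=False, method='UMAP', color='vitae_new_clustering')
</code></pre>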
</dd>
<dt id="VITAE.VITAE.plot_center"><code class="name flex">
<span>def <span class="ident">plot_center</span></span>(<span>self, color='vitae_new_clustering', plot_legend=True, legend_add_index=True, method: str = 'UMAP', ncol=2, font_size='medium', add_egde=False, add_direct=False, **kwargs)</span>
</code></dt>
<dd>
<div class="desc"><p>Plot the center of each cluster in the latent space.</p>
<h2 id="parameters">Parameters</h2>
<dl>
<dt><strong><code>color</code></strong> : <code>str</code>, optional</dt>
<dd>The color of the center of each cluster. Default is "vitae_new_clustering".</dd>
<dt><strong><code>plot_legend</code></strong> : <code>bool</code>, optional</dt>
<dd>Whether to plot the legend. Default is True.</dd>
<dt><strong><code>legend_add_index</code></strong> : <code>bool</code>, optional</dt>
<dd>Whether to add the index of each cluster in the legend. Default is True.</dd>
<dt><strong><code>method</code></strong> : <code>str</code>, optional</dt>
<dd>The dimension reduction method used for visualization. Default is 'UMAP'.</dd>
<dt><strong><code>ncol</code></strong> : <code>int</code>, optional</dt>
<dd>The number of columns in the legend. Default is 2.</dd>
<dt><strong><code>font_size</code></strong> : <code>str</code>, optional</dt>
<dd>The font size of the legend. Default is "medium".</dd>
<dt><strong><code>add_egde</code></strong> : <code>bool</code>, optional</dt>
<dd>Whether to add the edges between the centers of clusters. Default is False.</dd>
<dt><strong><code>add_direct</code></strong> : <code>bool</code>, optional</dt>
<dd>Whether to add the direction of the edges. Default is False.</dd>
</dl></div>
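<p>A sketch of a typical call (note the parameter is spelled <code>add_egde</code> in the signature):</p>
<pre><code class="language-python">model.plot_center(
    color='vitae_new_clustering',
    add_egde=True,    # draw edges between cluster centers
    add_direct=True,  # add the direction of each edge
)
</code></pre>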
</dd>
<dt id="VITAE.VITAE.infer_trajectory"><code class="name flex">
<span>def <span class="ident">infer_trajectory</span></span>(<span>self, root: Union[int, str], digraph=None, color='pseudotime', visualize: bool = True, path_to_fig=None, **kwargs)</span>
</code></dt>
<dd>
<div class="desc"><p>Infer the trajectory.</p>
<h2 id="parameters">Parameters</h2>
<dl>
<dt><strong><code>root</code></strong> : <code>int</code> or <code>string</code></dt>
<dd>The root of the inferred trajectory. Either an int (vertex index) or a string (label name).</dd>
<dt><strong><code>digraph</code></strong> : <code>nx.DiGraph</code>, optional</dt>
<dd>The directed graph to be used for trajectory inference. If None, the minimum spanning tree of the estimated trajectory backbone will be used.</dd>
<dt><strong><code>cutoff</code></strong> : <code>float</code>, optional</dt>
<dd>The threshold for filtering edges with scores less than cutoff.</dd>
<dt><strong><code>visualize</code></strong> : <code>boolean</code></dt>
<dd>Whether to plot the current trajectory backbone (a directed graph).</dd>
<dt><strong><code>path_to_fig</code></strong> : <code>string</code>, optional</dt>
<dd>The path to save the figure; the figure is not saved if None.</dd>
<dt><strong><code>**kwargs</code></strong> : <code>dict</code>, optional</dt>
<dd>Other keywords arguments for plotting.</dd>
</dl></div>
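<p>A sketch combining root selection and trajectory inference (<code>days</code> as in <code>select_root</code> above):</p>
<pre><code class="language-python">root = model.select_root(days, method='proportion')
model.infer_trajectory(root, color='pseudotime')
</code></pre>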
</dd>
<dt id="VITAE.VITAE.differential_expression_test"><code class="name flex">
<span>def <span class="ident">differential_expression_test</span></span>(<span>self, alpha: float = 0.05, cell_subset=None, order: int = 1)</span>
</code></dt>
<dd>
<div class="desc"><p>Differential gene expression test. All (selected and unselected) genes will be tested.
Only cells in <code>selected_cell_subset</code> will be used, which is useful when one needs to
test differentially expressed genes on a branch of the inferred trajectory.</p>
<h2 id="parameters">Parameters</h2>
<dl>
<dt><strong><code>alpha</code></strong> : <code>float</code>, optional</dt>
<dd>The cutoff of p-values.</dd>
<dt><strong><code>cell_subset</code></strong> : <code>np.array</code>, optional</dt>
<dd>The subset of cells to be used for testing. If None, all cells will be used.</dd>
<dt><strong><code>order</code></strong> : <code>int</code>, optional</dt>
<dd>The maximum order used for pseudotime in the regression.</dd>
</dl>
<h2 id="returns">Returns</h2>
<dl>
<dt><strong><code>res_df</code></strong> : <code>pandas.DataFrame</code></dt>
<dd>The test results of expressed genes with two columns,
the estimated coefficients and the adjusted p-values.</dd>
</dl></div>
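<p>For example (a sketch; run after trajectory inference so that pseudotime is available):</p>
<pre><code class="language-python">res_df = model.differential_expression_test(alpha=0.05, order=1)
print(res_df.head())  # estimated coefficients and adjusted p-values
</code></pre>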
</dd>
<dt id="VITAE.VITAE.evaluate"><code class="name flex">
<span>def <span class="ident">evaluate</span></span>(<span>self, milestone_net, begin_node_true, grouping=None, thres: float = 0.5, no_loop: bool = True, cutoff: Optional[float] = None, method: str = 'mean', path: Optional[str] = None)</span>
</code></dt>
<dd>
<div class="desc"><p>Evaluate the model.</p>
<h2 id="parameters">Parameters</h2>
<dl>
<dt><strong><code>milestone_net</code></strong> : <code>pd.DataFrame</code></dt>
<dd>
<p>The true milestone network. For real data, milestone_net will be a DataFrame of the graph of nodes,
e.g.:</p>
<table>
<thead>
<tr>
<th>from</th>
<th>to</th>
</tr>
</thead>
<tbody>
<tr>
<td>cluster 1</td>
<td>cluster 1</td>
</tr>
<tr>
<td>cluster 1</td>
<td>cluster 2</td>
</tr>
</tbody>
</table>
<p>For synthetic data, milestone_net will be a DataFrame of the (projected)
positions of cells. The indexes are the orders of cells in the dataset,
e.g.:</p>
<table>
<thead>
<tr>
<th>from</th>
<th>to</th>
<th>w</th>
</tr>
</thead>
<tbody>
<tr>
<td>cluster 1</td>
<td>cluster 1</td>
<td>1</td>
</tr>
<tr>
<td>cluster 1</td>
<td>cluster 2</td>
<td>0.1</td>
</tr>
</tbody>
</table>
</dd>
<dt><strong><code>begin_node_true</code></strong> : <code>str</code> or <code>int</code></dt>
<dd>The true begin node of the milestone.</dd>
<dt><strong><code>grouping</code></strong> : <code>np.array</code>, optional</dt>
<dd><span><span class="MathJax_Preview">[N,]</span><script type="math/tex">[N,]</script></span> The labels. For real data, grouping must be provided.</dd>
</dl>
<h2 id="returns">Returns</h2>
<dl>
<dt><strong><code>res</code></strong> : <code>pd.DataFrame</code></dt>
<dd>The evaluation result.</dd>
</dl></div>
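<p>A sketch for real data (the cluster names and the <code>'cell_type'</code> column are hypothetical; <code>grouping</code> must be provided for real data):</p>
<pre><code class="language-python">import numpy as np
import pandas as pd

milestone_net = pd.DataFrame({
    'from': ['cluster 1', 'cluster 1'],
    'to':   ['cluster 1', 'cluster 2'],
})
grouping = np.asarray(model.adata.obs['cell_type'])
res = model.evaluate(milestone_net, begin_node_true='cluster 1', grouping=grouping)
</code></pre>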
</dd>
<dt id="VITAE.VITAE.save_model"><code class="name flex">
<span>def <span class="ident">save_model</span></span>(<span>self, path_to_file: str = 'model.checkpoint', save_adata: bool = False)</span>
</code></dt>
<dd>
<div class="desc"><p>Save model weights.</p>
<h2 id="parameters">Parameters</h2>
<dl>
<dt><strong><code>path_to_file</code></strong> : <code>str</code>, optional</dt>
<dd>The path to weight files of the pre-trained or trained model.</dd>
<dt><strong><code>save_adata</code></strong> : <code>boolean</code>, optional</dt>
<dd>Whether to save adata or not.</dd>
</dl></div>
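<p>For example (a sketch; per the source above, companion <code>.label</code>, <code>.config</code>, <code>.inference</code>, and <code>.adata.h5ad</code> files are written alongside the checkpoint as applicable):</p>
<pre><code class="language-python">model.save_model('model.checkpoint', save_adata=True)
</code></pre>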
</dd>
<dt id="VITAE.VITAE.load_model"><code class="name flex">
<span>def <span class="ident">load_model</span></span>(<span>self, path_to_file: str = 'model.checkpoint', load_labels: bool = False, load_adata: bool = False)</span>
</code></dt>
<dd>
<div class="desc"><p>Load model weights.</p>
<h2 id="parameters">Parameters</h2>
<dl>
<dt><strong><code>path_to_file</code></strong> : <code>str</code>, optional</dt>
<dd>The path to weight files of the pre-trained or trained model.</dd>
<dt><strong><code>load_labels</code></strong> : <code>boolean</code>, optional</dt>
<dd>Whether to load clustering labels or not.
If load_labels is True, then the LatentSpace layer will be initialized based on the model.
If load_labels is False, then the LatentSpace layer will not be initialized.</dd>
<dt><strong><code>load_adata</code></strong> : <code>boolean</code>, optional</dt>
<dd>Whether to load adata or not.</dd>
</dl></div>
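<p>For example (a sketch restoring a previously saved checkpoint):</p>
<pre><code class="language-python">model.load_model('model.checkpoint', load_labels=True, load_adata=True)
</code></pre>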
</dd>
</dl>
</dd>
</dl>
</section>
</article>
<nav id="sidebar">
<div class="toc">
<ul></ul>
</div>
<ul id="index">
<li><h3><a href="#header-submodules">Sub-modules</a></h3>
<ul>
<li><code><a title="VITAE.inference" href="inference.html">VITAE.inference</a></code></li>
<li><code><a title="VITAE.metric" href="metric.html">VITAE.metric</a></code></li>
<li><code><a title="VITAE.model" href="model.html">VITAE.model</a></code></li>
<li><code><a title="VITAE.train" href="train.html">VITAE.train</a></code></li>
<li><code><a title="VITAE.utils" href="utils.html">VITAE.utils</a></code></li>
</ul>
</li>
<li><h3><a href="#header-classes">Classes</a></h3>
<ul>
<li>
<h4><code><a title="VITAE.VITAE" href="#VITAE.VITAE">VITAE</a></code></h4>
<ul class="">
<li><code><a title="VITAE.VITAE.pre_train" href="#VITAE.VITAE.pre_train">pre_train</a></code></li>
<li><code><a title="VITAE.VITAE.update_z" href="#VITAE.VITAE.update_z">update_z</a></code></li>
<li><code><a title="VITAE.VITAE.get_latent_z" href="#VITAE.VITAE.get_latent_z">get_latent_z</a></code></li>
<li><code><a title="VITAE.VITAE.visualize_latent" href="#VITAE.VITAE.visualize_latent">visualize_latent</a></code></li>
<li><code><a title="VITAE.VITAE.init_latent_space" href="#VITAE.VITAE.init_latent_space">init_latent_space</a></code></li>
<li><code><a title="VITAE.VITAE.update_latent_space" href="#VITAE.VITAE.update_latent_space">update_latent_space</a></code></li>
<li><code><a title="VITAE.VITAE.train" href="#VITAE.VITAE.train">train</a></code></li>
<li><code><a title="VITAE.VITAE.output_pi" href="#VITAE.VITAE.output_pi">output_pi</a></code></li>
<li><code><a title="VITAE.VITAE.return_pilayer_weights" href="#VITAE.VITAE.return_pilayer_weights">return_pilayer_weights</a></code></li>
<li><code><a title="VITAE.VITAE.posterior_estimation" href="#VITAE.VITAE.posterior_estimation">posterior_estimation</a></code></li>
<li><code><a title="VITAE.VITAE.infer_backbone" href="#VITAE.VITAE.infer_backbone">infer_backbone</a></code></li>
<li><code><a title="VITAE.VITAE.select_root" href="#VITAE.VITAE.select_root">select_root</a></code></li>
<li><code><a title="VITAE.VITAE.plot_backbone" href="#VITAE.VITAE.plot_backbone">plot_backbone</a></code></li>
<li><code><a title="VITAE.VITAE.plot_center" href="#VITAE.VITAE.plot_center">plot_center</a></code></li>
<li><code><a title="VITAE.VITAE.infer_trajectory" href="#VITAE.VITAE.infer_trajectory">infer_trajectory</a></code></li>
<li><code><a title="VITAE.VITAE.differential_expression_test" href="#VITAE.VITAE.differential_expression_test">differential_expression_test</a></code></li>
<li><code><a title="VITAE.VITAE.evaluate" href="#VITAE.VITAE.evaluate">evaluate</a></code></li>
<li><code><a title="VITAE.VITAE.save_model" href="#VITAE.VITAE.save_model">save_model</a></code></li>
<li><code><a title="VITAE.VITAE.load_model" href="#VITAE.VITAE.load_model">load_model</a></code></li>
</ul>
</li>
</ul>
</li>
</ul>
</nav>
</main>
<footer id="footer">
<p>Generated by <a href="https://pdoc3.github.io/pdoc" title="pdoc: Python API documentation generator"><cite>pdoc</cite> 0.11.1</a>.</p>
</footer>
</body>
</html>