--- a +++ b/docs/index.html @@ -0,0 +1,1851 @@ +<!doctype html> +<html lang="en"> +<head> +<meta charset="utf-8"> +<meta name="viewport" content="width=device-width, initial-scale=1, minimum-scale=1"> +<meta name="generator" content="pdoc3 0.11.1"> +<title>VITAE API documentation</title> +<meta name="description" content=""> +<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/10up-sanitize.css/13.0.0/sanitize.min.css" integrity="sha512-y1dtMcuvtTMJc1yPgEqF0ZjQbhnc/bFhyvIyVNb9Zk5mIGtqVaAB1Ttl28su8AvFMOY0EwRbAe+HCLqj6W7/KA==" crossorigin> +<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/10up-sanitize.css/13.0.0/typography.min.css" integrity="sha512-Y1DYSb995BAfxobCkKepB1BqJJTPrOp3zPL74AWFugHHmmdcvO+C48WLrUOlhGMc0QG7AE3f7gmvvcrmX2fDoA==" crossorigin> +<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.9.0/styles/default.min.css" crossorigin> +<style>:root{--highlight-color:#fe9}.flex{display:flex !important}body{line-height:1.5em}#content{padding:20px}#sidebar{padding:1.5em;overflow:hidden}#sidebar > *:last-child{margin-bottom:2cm}.http-server-breadcrumbs{font-size:130%;margin:0 0 15px 0}#footer{font-size:.75em;padding:5px 30px;border-top:1px solid #ddd;text-align:right}#footer p{margin:0 0 0 1em;display:inline-block}#footer p:last-child{margin-right:30px}h1,h2,h3,h4,h5{font-weight:300}h1{font-size:2.5em;line-height:1.1em}h2{font-size:1.75em;margin:2em 0 .50em 0}h3{font-size:1.4em;margin:1.6em 0 .7em 0}h4{margin:0;font-size:105%}h1:target,h2:target,h3:target,h4:target,h5:target,h6:target{background:var(--highlight-color);padding:.2em 0}a{color:#058;text-decoration:none;transition:color .2s ease-in-out}a:visited{color:#503}a:hover{color:#b62}.title code{font-weight:bold}h2[id^="header-"]{margin-top:2em}.ident{color:#900;font-weight:bold}pre code{font-size:.8em;line-height:1.4em;padding:1em;display:block}code{background:#f3f3f3;font-family:"DejaVu Sans Mono",monospace;padding:1px 4px;overflow-wrap:break-word}h1 code{background:transparent}pre{border-top:1px solid #ccc;border-bottom:1px solid #ccc;margin:1em 0}#http-server-module-list{display:flex;flex-flow:column}#http-server-module-list div{display:flex}#http-server-module-list dt{min-width:10%}#http-server-module-list p{margin-top:0}.toc ul,#index{list-style-type:none;margin:0;padding:0}#index code{background:transparent}#index h3{border-bottom:1px solid #ddd}#index ul{padding:0}#index h4{margin-top:.6em;font-weight:bold}@media (min-width:200ex){#index .two-column{column-count:2}}@media (min-width:300ex){#index .two-column{column-count:3}}dl{margin-bottom:2em}dl dl:last-child{margin-bottom:4em}dd{margin:0 0 1em 3em}#header-classes + dl > dd{margin-bottom:3em}dd dd{margin-left:2em}dd p{margin:10px 0}.name{background:#eee;font-size:.85em;padding:5px 10px;display:inline-block;min-width:40%}.name:hover{background:#e0e0e0}dt:target .name{background:var(--highlight-color)}.name > span:first-child{white-space:nowrap}.name.class > span:nth-child(2){margin-left:.4em}.inherited{color:#999;border-left:5px solid #eee;padding-left:1em}.inheritance em{font-style:normal;font-weight:bold}.desc h2{font-weight:400;font-size:1.25em}.desc h3{font-size:1em}.desc dt code{background:inherit}.source summary,.git-link-div{color:#666;text-align:right;font-weight:400;font-size:.8em;text-transform:uppercase}.source summary > *{white-space:nowrap;cursor:pointer}.git-link{color:inherit;margin-left:1em}.source pre{max-height:500px;overflow:auto;margin:0}.source pre 
code{font-size:12px;overflow:visible}.hlist{list-style:none}.hlist li{display:inline}.hlist li:after{content:',\2002'}.hlist li:last-child:after{content:none}.hlist .hlist{display:inline;padding-left:1em}img{max-width:100%}td{padding:0 .5em}.admonition{padding:.1em 1em;margin-bottom:1em}.admonition-title{font-weight:bold}.admonition.note,.admonition.info,.admonition.important{background:#aef}.admonition.todo,.admonition.versionadded,.admonition.tip,.admonition.hint{background:#dfd}.admonition.warning,.admonition.versionchanged,.admonition.deprecated{background:#fd4}.admonition.error,.admonition.danger,.admonition.caution{background:lightpink}</style> +<style media="screen and (min-width: 700px)">@media screen and (min-width:700px){#sidebar{width:30%;height:100vh;overflow:auto;position:sticky;top:0}#content{width:70%;max-width:100ch;padding:3em 4em;border-left:1px solid #ddd}pre code{font-size:1em}.name{font-size:1em}main{display:flex;flex-direction:row-reverse;justify-content:flex-end}.toc ul ul,#index ul ul{padding-left:1em}.toc > ul > li{margin-top:.5em}}</style> +<style media="print">@media print{#sidebar h1{page-break-before:always}.source{display:none}}@media print{*{background:transparent !important;color:#000 !important;box-shadow:none !important;text-shadow:none !important}a[href]:after{content:" (" attr(href) ")";font-size:90%}a[href][title]:after{content:none}abbr[title]:after{content:" (" attr(title) ")"}.ir a:after,a[href^="javascript:"]:after,a[href^="#"]:after{content:""}pre,blockquote{border:1px solid #999;page-break-inside:avoid}thead{display:table-header-group}tr,img{page-break-inside:avoid}img{max-width:100% !important}@page{margin:0.5cm}p,h2,h3{orphans:3;widows:3}h1,h2,h3,h4,h5,h6{page-break-after:avoid}}</style> +<script type="text/x-mathjax-config">MathJax.Hub.Config({ tex2jax: { inlineMath: [ ['$','$'], ["\\(","\\)"] ], processEscapes: true } });</script> +<script async src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.7/latest.js?config=TeX-AMS_CHTML" integrity="sha256-kZafAc6mZvK3W3v1pHOcUix30OHQN6pU/NO2oFkqZVw=" crossorigin></script> +<script defer src="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.9.0/highlight.min.js" integrity="sha512-D9gUyxqja7hBtkWpPWGt9wfbfaMGVt9gnyCvYa+jojwwPHLCzUm5i8rpk7vD7wNee9bA35eYIjobYPaQuKS1MQ==" crossorigin></script> +<script>window.addEventListener('DOMContentLoaded', () => { +hljs.configure({languages: ['bash', 'css', 'diff', 'graphql', 'ini', 'javascript', 'json', 'plaintext', 'python', 'python-repl', 'rust', 'shell', 'sql', 'typescript', 'xml', 'yaml']}); +hljs.highlightAll(); +})</script> +</head> +<body> +<main> +<article id="content"> +<header> +<h1 class="title">Package <code>VITAE</code></h1> +</header> +<section id="section-intro"> +</section> +<section> +<h2 class="section-title" id="header-submodules">Sub-modules</h2> +<dl> +<dt><code class="name"><a title="VITAE.inference" href="inference.html">VITAE.inference</a></code></dt> +<dd> +<div class="desc"></div> +</dd> +<dt><code class="name"><a title="VITAE.metric" href="metric.html">VITAE.metric</a></code></dt> +<dd> +<div class="desc"></div> +</dd> +<dt><code class="name"><a title="VITAE.model" href="model.html">VITAE.model</a></code></dt> +<dd> +<div class="desc"></div> +</dd> +<dt><code class="name"><a title="VITAE.train" href="train.html">VITAE.train</a></code></dt> +<dd> +<div class="desc"></div> +</dd> +<dt><code class="name"><a title="VITAE.utils" href="utils.html">VITAE.utils</a></code></dt> +<dd> +<div class="desc"></div> +</dd> +</dl> +</section> 
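+<section>
+<h2 class="section-title" id="header-example">Example usage</h2>
+<p>A minimal end-to-end sketch of the typical VITAE workflow, using only the methods documented below.
+The file path, the <code>days</code> vector and the root label are hypothetical placeholders;
+all keyword values shown are the documented defaults.</p>
+<pre><code class="python">import scanpy as sc
+from VITAE import VITAE
+
+# Hypothetical preprocessed data: VITAE expects adata.var.highly_variable
+# to be present (select highly variable genes with scanpy first).
+adata = sc.read_h5ad('preprocessed.h5ad')
+
+model = VITAE(adata, model_type='Gaussian', npc=64)
+model.pre_train(learning_rate=1e-3, batch_size=256)   # pretrain the autoencoder
+model.init_latent_space(res=1.0)                      # leiden-based initialization
+model.train(batch_size=256)                           # train the full model
+model.posterior_estimation(L=50)                      # posterior cell positions
+model.infer_backbone(method='modified_map')           # undirected trajectory backbone
+root_info = model.select_root(days)                   # rank candidate roots by collection time
+model.infer_trajectory(root='0')                      # pseudotime from a chosen root
+</code></pre>
+</section>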
+<section>
+</section>
+<section>
+</section>
+<section>
+<h2 class="section-title" id="header-classes">Classes</h2>
+<dl>
+<dt id="VITAE.VITAE"><code class="flex name class">
+<span>class <span class="ident">VITAE</span></span>
+<span>(</span><span>adata: anndata._core.anndata.AnnData, covariates=None, pi_covariates=None, model_type: str = 'Gaussian', npc: int = 64, adata_layer_counts=None, copy_adata: bool = False, hidden_layers=[32], latent_space_dim: int = 16, conditions=None)</span>
+</code></dt>
+<dd>
+<div class="desc"><p>Variational Inference for Trajectory by AutoEncoder.</p>
+<p>Get input data for the model. Data must first be processed using scanpy and stored as an AnnData object.
+The 'UMI' and 'non-UMI' models need the original count matrix, so the count matrix must be saved in
+adata.layers in order to use these models.</p>
+<h2 id="parameters">Parameters</h2>
+<dl>
+<dt><strong><code>adata</code></strong> : <code>sc.AnnData</code></dt>
+<dd>The scanpy AnnData object. adata should already contain adata.var.highly_variable.</dd>
+<dt><strong><code>covariates</code></strong> : <code>list</code>, optional</dt>
+<dd>A list of names of covariate vectors that are stored in adata.obs.</dd>
+<dt><strong><code>pi_covariates</code></strong> : <code>list</code>, optional</dt>
+<dd>A list of names of covariate vectors used as input for the pilayer.</dd>
+<dt><strong><code>model_type</code></strong> : <code>str</code>, optional</dt>
+<dd>'UMI', 'non-UMI' or 'Gaussian'; default is 'Gaussian'.</dd>
+<dt><strong><code>npc</code></strong> : <code>int</code>, optional</dt>
+<dd>The number of PCs to use when model_type is 'Gaussian'. The default is 64.</dd>
+<dt><strong><code>adata_layer_counts</code></strong> : <code>str</code>, optional</dt>
+<dd>The key name of adata.layers that stores the count data if model_type is
+'UMI' or 'non-UMI'.</dd>
+<dt><strong><code>copy_adata</code></strong> : <code>bool</code>, optional</dt>
+<dd>Set to True if we don't want VITAE to modify the original adata. If set to True, self.adata will be an independent copy of the original adata.</dd>
+<dt><strong><code>hidden_layers</code></strong> : <code>list</code>, optional</dt>
+<dd>The list of dimensions of the autoencoder layers between the latent space and the original space. The default is a single hidden layer with 32 nodes.</dd>
+<dt><strong><code>latent_space_dim</code></strong> : <code>int</code>, optional</dt>
+<dd>The dimension of the latent space.</dd>
+<dt><strong><code>gamma</code></strong> : <code>float</code>, optional</dt>
+<dd>The weight of the MMD loss.</dd>
+<dt><strong><code>conditions</code></strong> : <code>str</code> or <code>list</code>, optional</dt>
+<dd>The conditions of different cells.</dd>
+</dl>
+<h2 id="returns">Returns</h2>
+<p>None.</p></div>
+<details class="source">
+<summary>
+<span>Expand source code</span>
+</summary>
+<pre><code class="python">class VITAE():
+    """
+    Variational Inference for Trajectory by AutoEncoder.
+    """
+    def __init__(self, adata: sc.AnnData,
+                 covariates = None, pi_covariates = None,
+                 model_type: str = 'Gaussian',
+                 npc: int = 64,
+                 adata_layer_counts = None,
+                 copy_adata: bool = False,
+                 hidden_layers = [32],
+                 latent_space_dim: int = 16,
+                 conditions = None):
+        '''
+        Get input data for the model. Data must first be processed using scanpy and stored as an AnnData object.
+        The 'UMI' and 'non-UMI' models need the original count matrix, so the count matrix must be saved in
+        adata.layers in order to use these models.
+
+
+        Parameters
+        ----------
+        adata : sc.AnnData
+            The scanpy AnnData object. adata should already contain adata.var.highly_variable.
+        covariates : list, optional
+            A list of names of covariate vectors that are stored in adata.obs.
+        pi_covariates : list, optional
+            A list of names of covariate vectors used as input for the pilayer.
+        model_type : str, optional
+            'UMI', 'non-UMI' or 'Gaussian'; default is 'Gaussian'.
+        npc : int, optional
+            The number of PCs to use when model_type is 'Gaussian'. The default is 64.
+        adata_layer_counts : str, optional
+            The key name of adata.layers that stores the count data if model_type is
+            'UMI' or 'non-UMI'.
+        copy_adata : bool, optional
+            Set to True if we don't want VITAE to modify the original adata. If set to True, self.adata will be an independent copy of the original adata.
+        hidden_layers : list, optional
+            The list of dimensions of the autoencoder layers between the latent space and the original space. The default is a single hidden layer with 32 nodes.
+        latent_space_dim : int, optional
+            The dimension of the latent space.
+        gamma : float, optional
+            The weight of the MMD loss.
+        conditions : str or list, optional
+            The conditions of different cells.
+
+
+        Returns
+        -------
+        None.
+
+        '''
+        self.dict_method_scname = {
+            'PCA' : 'X_pca',
+            'UMAP' : 'X_umap',
+            'TSNE' : 'X_tsne',
+            'diffmap' : 'X_diffmap',
+            'draw_graph' : 'X_draw_graph_fa'
+        }
+
+        if model_type != 'Gaussian':
+            if adata_layer_counts is None:
+                raise ValueError("need to provide the name in adata.layers that stores the raw count data")
+            if 'highly_variable' not in adata.var:
+                raise ValueError("need to first select highly variable genes using scanpy")
+
+        self.model_type = model_type
+
+        if copy_adata:
+            self.adata = adata.copy()
+        else:
+            self.adata = adata
+
+        if covariates is not None:
+            if isinstance(covariates, str):
+                covariates = [covariates]
+            covariates = np.array(covariates)
+            id_cat = (adata.obs[covariates].dtypes == 'category')
+            # add OneHotEncoder & StandardScaler as class variable if needed
+            if np.sum(id_cat)>0:
+                covariates_cat = OneHotEncoder(drop='if_binary', handle_unknown='ignore'
+                    ).fit_transform(adata.obs[covariates[id_cat]]).toarray()
+            else:
+                covariates_cat = np.array([]).reshape(adata.shape[0],0)
+
+            # temporarily disable StandardScaler
+            if np.sum(~id_cat)>0:
+                #covariates_con = StandardScaler().fit_transform(adata.obs[covariates[~id_cat]])
+                covariates_con = adata.obs[covariates[~id_cat]]
+            else:
+                covariates_con = np.array([]).reshape(adata.shape[0],0)
+
+            self.covariates = np.c_[covariates_cat, covariates_con].astype(tf.keras.backend.floatx())
+        else:
+            self.covariates = None
+
+        if conditions is not None:
+            ## observations with np.nan will not participate in calculating mmd_loss
+            if isinstance(conditions, str):
+                conditions = [conditions]
+            conditions = np.array(conditions)
+            if np.any(adata.obs[conditions].dtypes != 'category'):
+                raise ValueError("Conditions should all be categorical.")
+
+            self.conditions = OrdinalEncoder(dtype=int, encoded_missing_value=-1).fit_transform(adata.obs[conditions]) + int(1)
+        else:
+            self.conditions = None
+
+        if pi_covariates is not None:
+            self.pi_cov = adata.obs[pi_covariates].to_numpy()
+            if self.pi_cov.ndim == 1:
+                self.pi_cov = self.pi_cov.reshape(-1, 1)
+            self.pi_cov = self.pi_cov.astype(tf.keras.backend.floatx())
+        else:
+            self.pi_cov = np.zeros((adata.shape[0],1), dtype=tf.keras.backend.floatx())
+
+        self.model_type = model_type
+        self._adata = sc.AnnData(X = self.adata.X, var = self.adata.var)
+        self._adata.obs = self.adata.obs
+        self._adata.uns = self.adata.uns
+
+
+        if model_type == 'Gaussian':
+            sc.tl.pca(adata, n_comps = npc)
+            self.X_input = self.X_output = adata.obsm['X_pca']
+            self.scale_factor = np.ones(self.X_output.shape[0])
+        else:
+            print(f"{adata.var.highly_variable.sum()} highly variable genes selected as input")
+            self.X_input = adata.X[:, adata.var.highly_variable]
+            self.X_output = adata.layers[adata_layer_counts][:, adata.var.highly_variable]
+            self.scale_factor = np.sum(self.X_output, axis=1, keepdims=True)/1e4
+
+        self.dimensions = hidden_layers
+        self.dim_latent = latent_space_dim
+
+        self.vae = model.VariationalAutoEncoder(
+            self.X_output.shape[1], self.dimensions,
+            self.dim_latent, self.model_type,
+            False if self.covariates is None else True,
+            )
+
+        if hasattr(self, 'inferer'):
+            delattr(self, 'inferer')
+
+
+    def pre_train(self, test_size = 0.1, random_state: int = 0,
+                  learning_rate: float = 1e-3, batch_size: int = 256, L: int = 1, alpha: float = 0.10, gamma: float = 0,
+                  phi: float = 1, num_epoch: int = 200, num_step_per_epoch: Optional[int] = None,
+                  early_stopping_patience: int = 10, early_stopping_tolerance: float = 0.01,
+                  early_stopping_relative: bool = True, verbose: bool = False, path_to_weights: Optional[str] = None):
+        '''Pretrain the model with the specified learning rate.
+
+        Parameters
+        ----------
+        test_size : float or int, optional
+            The proportion or size of the test set.
+        random_state : int, optional
+            The random state for data splitting.
+        learning_rate : float, optional
+            The initial learning rate for the Adam optimizer.
+        batch_size : int, optional
+            The batch size for pre-training. Default is 256. Set to 32 if the number of cells is small (less than 1000).
+        L : int, optional
+            The number of MC samples.
+        alpha : float, optional
+            The value of alpha in [0,1] to encourage covariate adjustment. Not used if there are no covariates.
+        gamma : float, optional
+            The weight of the mmd loss if used.
+        phi : float, optional
+            The weight of the Jacobian norm of the encoder.
+        num_epoch : int, optional
+            The maximum number of epochs.
+        num_step_per_epoch : int, optional
+            The number of steps per epoch; it will be inferred from the number of cells and batch size if it is None.
+        early_stopping_patience : int, optional
+            The maximum number of epochs if there is no improvement.
+        early_stopping_tolerance : float, optional
+            The minimum change of loss to be considered as an improvement.
+        early_stopping_relative : bool, optional
+            Whether to monitor the relative change of the loss as the stopping criterion.
+        path_to_weights : str, optional
+            The path of the weight file to be saved; weights are not saved if None.
+        '''
+
+        id_train, id_test = train_test_split(
+                                np.arange(self.X_input.shape[0]),
+                                test_size=test_size,
+                                random_state=random_state)
+        if num_step_per_epoch is None:
+            num_step_per_epoch = len(id_train)//batch_size+1
+        self.train_dataset = train.warp_dataset(self.X_input[id_train].astype(tf.keras.backend.floatx()),
+                                                None if self.covariates is None else self.covariates[id_train].astype(tf.keras.backend.floatx()),
+                                                batch_size,
+                                                self.X_output[id_train].astype(tf.keras.backend.floatx()),
+                                                self.scale_factor[id_train].astype(tf.keras.backend.floatx()),
+                                                conditions = None if self.conditions is None else self.conditions[id_train].astype(tf.keras.backend.floatx()))
+        self.test_dataset = train.warp_dataset(self.X_input[id_test],
+                                               None if self.covariates is None else self.covariates[id_test].astype(tf.keras.backend.floatx()),
+                                               batch_size,
+                                               self.X_output[id_test].astype(tf.keras.backend.floatx()),
+                                               self.scale_factor[id_test].astype(tf.keras.backend.floatx()),
+                                               conditions = None if self.conditions is None else self.conditions[id_test].astype(tf.keras.backend.floatx()))
+
+        self.vae = train.pre_train(
+            self.train_dataset,
+            self.test_dataset,
+            self.vae,
+            learning_rate,
+            L, alpha, gamma, phi,
+            num_epoch,
+            num_step_per_epoch,
+            early_stopping_patience,
+            early_stopping_tolerance,
+            early_stopping_relative,
+            verbose)
+
+        self.update_z()
+
+        if path_to_weights is not None:
+            self.save_model(path_to_weights)
+
+
+    def update_z(self):
+        self.z = self.get_latent_z()
+        self._adata_z = sc.AnnData(self.z)
+        sc.pp.neighbors(self._adata_z)
+
+
+    def get_latent_z(self):
+        ''' Get the posterior mean of the current latent space z (encoder output).
+
+        Returns
+        ----------
+        z : np.array
+            \([N,d]\) The latent means.
+        '''
+        c = None if self.covariates is None else self.covariates
+        return self.vae.get_z(self.X_input, c)
+
+
+    def visualize_latent(self, method: str = "UMAP",
+                         color = None, **kwargs):
+        '''
+        Visualize the current latent space z using the scanpy visualization tools.
+
+        Parameters
+        ----------
+        method : str, optional
+            Visualization method to use. The default is "UMAP". Possible choices include "PCA", "UMAP",
+            "diffmap", "TSNE" and "draw_graph" (the FA plot).
+        color : str or list of str, optional
+            Keys for annotations of observations/cells or variables/genes, e.g., 'ann1' or ['ann1', 'ann2'].
+            The default is None. Same as scanpy.
+        **kwargs :
+            Extra key-value arguments that can be passed to scanpy plotting functions (scanpy.pl.XX).
+
+        Returns
+        -------
+        axes : matplotlib axes
+            The axes returned by the scanpy plotting function.
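+
+        Examples
+        --------
+        A minimal sketch; assumes `model` is a VITAE instance on which
+        `pre_train` (or `train`) has already been called, so that the
+        latent space z exists:
+
+            model.visualize_latent(method='UMAP', color='vitae_init_clustering')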
+
+        '''
+
+        if method not in ['PCA', 'UMAP', 'TSNE', 'diffmap', 'draw_graph']:
+            raise ValueError("visualization method should be one of 'PCA', 'UMAP', 'TSNE', 'diffmap' and 'draw_graph'")
+
+        temp = list(self._adata_z.obsm.keys())
+        if method == 'PCA' and not 'X_pca' in temp:
+            print("Calculate PCs ...")
+            sc.tl.pca(self._adata_z)
+        elif method == 'UMAP' and not 'X_umap' in temp:
+            print("Calculate UMAP ...")
+            sc.tl.umap(self._adata_z)
+        elif method == 'TSNE' and not 'X_tsne' in temp:
+            print("Calculate TSNE ...")
+            sc.tl.tsne(self._adata_z)
+        elif method == 'diffmap' and not 'X_diffmap' in temp:
+            print("Calculate diffusion map ...")
+            sc.tl.diffmap(self._adata_z)
+        elif method == 'draw_graph' and not 'X_draw_graph_fa' in temp:
+            print("Calculate FA ...")
+            sc.tl.draw_graph(self._adata_z)
+
+
+        self._adata.obs = self.adata.obs.copy()
+        self._adata.obsp = self._adata_z.obsp
+#        self._adata.uns = self._adata_z.uns
+        self._adata.obsm = self._adata_z.obsm
+
+        if method == 'PCA':
+            axes = sc.pl.pca(self._adata, color = color, **kwargs)
+        elif method == 'UMAP':
+            axes = sc.pl.umap(self._adata, color = color, **kwargs)
+        elif method == 'TSNE':
+            axes = sc.pl.tsne(self._adata, color = color, **kwargs)
+        elif method == 'diffmap':
+            axes = sc.pl.diffmap(self._adata, color = color, **kwargs)
+        elif method == 'draw_graph':
+            axes = sc.pl.draw_graph(self._adata, color = color, **kwargs)
+        return axes
+
+
+    def init_latent_space(self, cluster_label = None, log_pi = None, res: float = 1.0,
+                          ratio_prune = None, dist = None, dist_thres = 0.5, topk = 0, pilayer = False):
+        '''Initialize the latent space.
+
+        Parameters
+        ----------
+        cluster_label : str, optional
+            The name of the vector of labels that can be found in self.adata.obs.
+            Default is None, which will perform leiden clustering on the pretrained z to get clusters.
+        log_pi : np.array, optional
+            \([1,K]\) The value of initial \(\\log(\\pi)\).
+        res : float, optional
+            The resolution of leiden clustering, which is a parameter value controlling the coarseness of the clustering.
+            Higher values lead to more clusters. Default is 1.
+        ratio_prune : float, optional
+            The ratio of edges to be removed before estimating.
+        dist : np.array, optional
+            \([K,K]\) The pairwise distances between the initial cluster centers. Computed from the latent space if None.
+        dist_thres : float, optional
+            The threshold below which cluster centers are considered too close and merged. Default is 0.5.
+        topk : int, optional
+            The number of top k neighbors to keep for each cluster.
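+        pilayer : bool, optional
+            Whether to create a pilayer (used together with the pi_covariates passed at construction). Default is False.
+
+        Examples
+        --------
+        A sketch, assuming `model` has been pretrained and that self.adata.obs
+        contains a categorical column named 'celltype' (a hypothetical label name):
+
+            # use existing labels
+            model.init_latent_space(cluster_label='celltype')
+            # or cluster the pretrained latent space with leiden
+            model.init_latent_space(res=1.0)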
+ ''' + + + if cluster_label is None: + print("Perform leiden clustering on the latent space z ...") + g = get_igraph(self.z) + cluster_labels = leidenalg_igraph(g, res = res) + cluster_labels = cluster_labels.astype(str) + uni_cluster_labels = np.unique(cluster_labels) + else: + if isinstance(cluster_label,str): + cluster_labels = self.adata.obs[cluster_label].to_numpy() + uni_cluster_labels = np.array(self.adata.obs[cluster_label].cat.categories) + else: + ## if cluster_label is a list + cluster_labels = cluster_label + uni_cluster_labels = np.unique(cluster_labels) + + n_clusters = len(uni_cluster_labels) + + if not hasattr(self, 'z'): + self.update_z() + z = self.z + mu = np.zeros((z.shape[1], n_clusters)) + for i,l in enumerate(uni_cluster_labels): + mu[:,i] = np.mean(z[cluster_labels==l], axis=0) + + if dist is None: + ### update cluster centers if some cluster centers are too close + clustering = AgglomerativeClustering( + n_clusters=None, + distance_threshold=dist_thres, + linkage='complete' + ).fit(mu.T/np.sqrt(mu.shape[0])) + n_clusters_new = clustering.n_clusters_ + if n_clusters_new < n_clusters: + print("Merge clusters for cluster centers that are too close ...") + n_clusters = n_clusters_new + for i in range(n_clusters): + temp = uni_cluster_labels[clustering.labels_ == i] + idx = np.isin(cluster_labels, temp) + cluster_labels[idx] = ','.join(temp) + if np.sum(clustering.labels_==i)>1: + print('Merge %s'% ','.join(temp)) + uni_cluster_labels = np.unique(cluster_labels) + mu = np.zeros((z.shape[1], n_clusters)) + for i,l in enumerate(uni_cluster_labels): + mu[:,i] = np.mean(z[cluster_labels==l], axis=0) + + self.adata.obs['vitae_init_clustering'] = cluster_labels + self.adata.obs['vitae_init_clustering'] = self.adata.obs['vitae_init_clustering'].astype('category') + print("Initial clustering labels saved as 'vitae_init_clustering' in self.adata.obs.") + + if (log_pi is None) and (cluster_labels is not None) and (n_clusters>3): + n_states = int((n_clusters+1)*n_clusters/2) + + if dist is None: + dist = _comp_dist(z, cluster_labels, mu.T) + + C = np.triu(np.ones(n_clusters)) + C[C>0] = np.arange(n_states) + C = C + C.T - np.diag(np.diag(C)) + C = C.astype(int) + + log_pi = np.zeros((1,n_states)) + + ## pruning to throw away edges for far-away clusters if there are too many clusters + if ratio_prune is not None: + log_pi[0, C[np.triu(dist)>np.quantile(dist[np.triu_indices(n_clusters, 1)], 1-ratio_prune)]] = - np.inf + else: + log_pi[0, C[np.triu(dist)>np.quantile(dist[np.triu_indices(n_clusters, 1)], 5/n_clusters) * 3]] = - np.inf + + ## also keep the top k neighbor of clusters + topk = max(0, min(topk, n_clusters-1)) + 1 + topk_indices = np.argsort(dist,axis=1)[:,:topk] + for i in range(n_clusters): + log_pi[0, C[i, topk_indices[i]]] = 0 + + self.n_states = n_clusters + self.labels = cluster_labels + + labels_map = pd.DataFrame.from_dict( + {i:label for i,label in enumerate(uni_cluster_labels)}, + orient='index', columns=['label_names'], dtype=str + ) + + self.labels_map = labels_map + self.vae.init_latent_space(self.n_states, mu, log_pi) + self.inferer = Inferer(self.n_states) + self.mu = self.vae.latent_space.mu.numpy() + self.pi = np.triu(np.ones(self.n_states)) + self.pi[self.pi > 0] = tf.nn.softmax(self.vae.latent_space.pi).numpy()[0] + + if pilayer: + self.vae.create_pilayer() + + + def update_latent_space(self, dist_thres: float=0.5): + pi = self.pi[np.triu_indices(self.n_states)] + mu = self.mu + clustering = AgglomerativeClustering( + n_clusters=None, + 
distance_threshold=dist_thres, + linkage='complete' + ).fit(mu.T/np.sqrt(mu.shape[0])) + n_clusters = clustering.n_clusters_ + + if n_clusters<self.n_states: + print("Merge clusters for cluster centers that are too close ...") + mu_new = np.empty((self.dim_latent, n_clusters)) + C = np.zeros((self.n_states, self.n_states)) + C[np.triu_indices(self.n_states, 0)] = pi + C = np.triu(C, 1) + C.T + C_new = np.zeros((n_clusters, n_clusters)) + + uni_cluster_labels = self.labels_map['label_names'].to_numpy() + returned_order = {} + cluster_labels = self.labels + for i in range(n_clusters): + temp = uni_cluster_labels[clustering.labels_ == i] + idx = np.isin(cluster_labels, temp) + cluster_labels[idx] = ','.join(temp) + returned_order[i] = ','.join(temp) + if np.sum(clustering.labels_==i)>1: + print('Merge %s'% ','.join(temp)) + uni_cluster_labels = np.unique(cluster_labels) + for i,l in enumerate(uni_cluster_labels): ## reorder the merged clusters based on the cluster names + k = np.where(returned_order == l) + mu_new[:, i] = np.mean(mu[:,clustering.labels_==k], axis=-1) + # sum of the aggregated pi's + C_new[i, i] = np.sum(np.triu(C[clustering.labels_==k,:][:,clustering.labels_==k])) + for j in range(i+1, n_clusters): + k1 = np.where(returned_order == uni_cluster_labels[j]) + C_new[i, j] = np.sum(C[clustering.labels_== k, :][:, clustering.labels_==k1]) + +# labels_map_new = {} +# for i in range(n_clusters): +# # update label map: int->str +# labels_map_new[i] = self.labels_map.loc[clustering.labels_==i, 'label_names'].str.cat(sep=',') +# if np.sum(clustering.labels_==i)>1: +# print('Merge %s'%labels_map_new[i]) +# # mean of the aggregated cluster means +# mu_new[:, i] = np.mean(mu[:,clustering.labels_==i], axis=-1) +# # sum of the aggregated pi's +# C_new[i, i] = np.sum(np.triu(C[clustering.labels_==i,:][:,clustering.labels_==i])) +# for j in range(i+1, n_clusters): +# C_new[i, j] = np.sum(C[clustering.labels_== i, :][:, clustering.labels_==j]) + C_new = np.triu(C_new,1) + C_new.T + + pi_new = C_new[np.triu_indices(n_clusters)] + log_pi_new = np.log(pi_new, out=np.ones_like(pi_new)*(-np.inf), where=(pi_new!=0)).reshape((1,-1)) + self.n_states = n_clusters + self.labels_map = pd.DataFrame.from_dict( + {i:label for i,label in enumerate(uni_cluster_labels)}, + orient='index', columns=['label_names'], dtype=str + ) + self.labels = cluster_labels +# self.labels_map = pd.DataFrame.from_dict( +# labels_map_new, orient='index', columns=['label_names'], dtype=str +# ) + self.vae.init_latent_space(self.n_states, mu_new, log_pi_new) + self.inferer = Inferer(self.n_states) + self.mu = self.vae.latent_space.mu.numpy() + self.pi = np.triu(np.ones(self.n_states)) + self.pi[self.pi > 0] = tf.nn.softmax(self.vae.latent_space.pi).numpy()[0] + + + + def train(self, stratify = False, test_size = 0.1, random_state: int = 0, + learning_rate: float = 1e-3, batch_size: int = 256, + L: int = 1, alpha: float = 0.10, beta: float = 1, gamma: float = 0, phi: float = 1, + num_epoch: int = 200, num_step_per_epoch: Optional[int] = None, + early_stopping_patience: int = 10, early_stopping_tolerance: float = 0.01, + early_stopping_relative: bool = True, early_stopping_warmup: int = 0, + path_to_weights: Optional[str] = None, + verbose: bool = False, **kwargs): + '''Train the model. + + Parameters + ---------- + stratify : np.array, None, or False + If an array is provided, or `stratify=None` and `self.labels` is available, then they will be used to perform stratified shuffle splitting. 
Otherwise, general shuffle splitting is used. Set to `False` if `self.labels` is not intended for stratified shuffle splitting.
+        test_size : float or int, optional
+            The proportion or size of the test set.
+        random_state : int, optional
+            The random state for data splitting.
+        learning_rate : float, optional
+            The initial learning rate for the Adam optimizer.
+        batch_size : int, optional
+            The batch size for training. Default is 256. Set to 32 if the number of cells is small (less than 1000).
+        L : int, optional
+            The number of MC samples.
+        alpha : float, optional
+            The value of alpha in [0,1] to encourage covariate adjustment. Not used if there are no covariates.
+        beta : float, optional
+            The value of beta in beta-VAE.
+        gamma : float, optional
+            The weight of mmd_loss.
+        phi : float, optional
+            The weight of the Jacobian norm of the encoder.
+        num_epoch : int, optional
+            The maximum number of epochs.
+        num_step_per_epoch : int, optional
+            The number of steps per epoch; it will be inferred from the number of cells and batch size if it is None.
+        early_stopping_patience : int, optional
+            The maximum number of epochs if there is no improvement.
+        early_stopping_tolerance : float, optional
+            The minimum change of loss to be considered as an improvement.
+        early_stopping_relative : bool, optional
+            Whether to monitor the relative change of the loss or not.
+        early_stopping_warmup : int, optional
+            The number of warmup epochs.
+        path_to_weights : str, optional
+            The path of the weight file to be saved; weights are not saved if None.
+        **kwargs :
+            Extra key-value arguments for dimension reduction algorithms.
+        '''
+        if gamma == 0 or self.conditions is None:
+            conditions = np.array([np.nan] * self.adata.shape[0])
+        else:
+            conditions = self.conditions
+
+        if stratify is None:
+            stratify = self.labels
+        elif stratify is False:
+            stratify = None
+        id_train, id_test = train_test_split(
+                                np.arange(self.X_input.shape[0]),
+                                test_size=test_size,
+                                stratify=stratify,
+                                random_state=random_state)
+        if num_step_per_epoch is None:
+            num_step_per_epoch = len(id_train)//batch_size+1
+        c = None if self.covariates is None else self.covariates.astype(tf.keras.backend.floatx())
+        self.train_dataset = train.warp_dataset(self.X_input[id_train].astype(tf.keras.backend.floatx()),
+                                                None if c is None else c[id_train],
+                                                batch_size,
+                                                self.X_output[id_train].astype(tf.keras.backend.floatx()),
+                                                self.scale_factor[id_train].astype(tf.keras.backend.floatx()),
+                                                conditions = conditions[id_train],
+                                                pi_cov = self.pi_cov[id_train])
+        self.test_dataset = train.warp_dataset(self.X_input[id_test].astype(tf.keras.backend.floatx()),
+                                               None if c is None else c[id_test],
+                                               batch_size,
+                                               self.X_output[id_test].astype(tf.keras.backend.floatx()),
+                                               self.scale_factor[id_test].astype(tf.keras.backend.floatx()),
+                                               conditions = conditions[id_test],
+                                               pi_cov = self.pi_cov[id_test])
+
+        self.vae = train.train(
+            self.train_dataset,
+            self.test_dataset,
+            self.vae,
+            learning_rate,
+            L,
+            alpha,
+            beta,
+            gamma,
+            phi,
+            num_epoch,
+            num_step_per_epoch,
+            early_stopping_patience,
+            early_stopping_tolerance,
+            early_stopping_relative,
+            early_stopping_warmup,
+            verbose,
+            **kwargs
+            )
+
+        self.update_z()
+        self.mu = self.vae.latent_space.mu.numpy()
+        self.pi = np.triu(np.ones(self.n_states))
+        self.pi[self.pi > 0] = tf.nn.softmax(self.vae.latent_space.pi).numpy()[0]
+
+        if path_to_weights is not None:
+            self.save_model(path_to_weights)
+
+
+    def output_pi(self, pi_cov):
+        """Return an n_states-by-n_states matrix and a mask for plotting, which can be used to cover the lower
triangle (except the diagonal) of a heatmap."""
+        p = self.vae.pilayer
+        pi_cov = tf.expand_dims(tf.constant([pi_cov], dtype=tf.float32), 0)
+        pi_val = tf.nn.softmax(p(pi_cov)).numpy()[0]
+        # Create heatmap matrix
+        n = self.vae.n_states
+        matrix = np.zeros((n, n))
+        matrix[np.triu_indices(n)] = pi_val
+        mask = np.tril(np.ones_like(matrix), k=-1)
+        return matrix, mask
+
+
+    def return_pilayer_weights(self):
+        """Return the parameters of the pilayer, which have dimension (dim(pi_cov) + 1) by n_categories; the last row contains the biases."""
+        return np.vstack((self.vae.pilayer.weights[0].numpy(), self.vae.pilayer.weights[1].numpy().reshape(1, -1)))
+
+
+    def posterior_estimation(self, batch_size: int = 32, L: int = 50, **kwargs):
+        '''Initialize trajectory inference by computing the posterior estimations.
+
+        Parameters
+        ----------
+        batch_size : int, optional
+            The batch size when doing inference.
+        L : int, optional
+            The number of MC samples when doing inference.
+        **kwargs :
+            Extra key-value arguments for dimension reduction algorithms.
+        '''
+        c = None if self.covariates is None else self.covariates.astype(tf.keras.backend.floatx())
+        self.test_dataset = train.warp_dataset(self.X_input.astype(tf.keras.backend.floatx()),
+                                               c,
+                                               batch_size)
+        _, _, self.pc_x,\
+            self.cell_position_posterior, self.cell_position_variance, _ = self.vae.inference(self.test_dataset, L=L)
+
+        uni_cluster_labels = self.labels_map['label_names'].to_numpy()
+        self.adata.obs['vitae_new_clustering'] = uni_cluster_labels[np.argmax(self.cell_position_posterior, 1)]
+        self.adata.obs['vitae_new_clustering'] = self.adata.obs['vitae_new_clustering'].astype('category')
+        print("New clustering labels saved as 'vitae_new_clustering' in self.adata.obs.")
+        return None
+
+
+    def infer_backbone(self, method: str = 'modified_map', thres = 0.5,
+                       no_loop: bool = True, cutoff: float = 0,
+                       visualize: bool = True, color = 'vitae_new_clustering', path_to_fig = None, **kwargs):
+        ''' Compute edge scores.
+
+        Parameters
+        ----------
+        method : string, optional
+            'mean', 'modified_mean', 'map', or 'modified_map'.
+        thres : float, optional
+            The threshold used for filtering edges \(e_{ij}\) such that \((n_{i}+n_{j}+e_{ij})/N<thres\); only applied to the mean method.
+        no_loop : boolean, optional
+            Whether loops are allowed to exist in the graph. If no_loop is True, the graph will be pruned to contain only the
+            maximum spanning tree.
+        cutoff : float, optional
+            The score threshold for filtering edges with scores less than cutoff.
+        visualize : boolean
+            Whether to plot the current trajectory backbone (undirected graph).
+        path_to_fig : str, optional
+            The path to save the figure; the figure is not saved if None.
+
+        Returns
+        ----------
+        None. The weighted backbone graph, with each edge weight indicating its score of existence, is stored as self.backbone.
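+
+        Examples
+        --------
+        A sketch; assumes `posterior_estimation` has already been run:
+
+            model.infer_backbone(method='modified_map', thres=0.5, no_loop=True)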
+        '''
+        # build_graph, return graph
+        self.backbone = self.inferer.build_graphs(self.cell_position_posterior, self.pc_x,
+                                                  method, thres, no_loop, cutoff)
+        self.cell_position_projected = self.inferer.modify_wtilde(self.cell_position_posterior,
+                                                                  np.array(list(self.backbone.edges)))
+
+        uni_cluster_labels = self.labels_map['label_names'].to_numpy()
+        temp_dict = {i:label for i,label in enumerate(uni_cluster_labels)}
+        nx.relabel_nodes(self.backbone, temp_dict)
+
+        self.adata.obs['vitae_new_clustering'] = uni_cluster_labels[np.argmax(self.cell_position_projected, 1)]
+        self.adata.obs['vitae_new_clustering'] = self.adata.obs['vitae_new_clustering'].astype('category')
+        print("'vitae_new_clustering' updated based on the projected cell positions.")
+
+        self.uncertainty = np.sum((self.cell_position_projected - self.cell_position_posterior)**2, axis=-1) \
+                           + np.sum(self.cell_position_variance, axis=-1)
+        self.adata.obs['projection_uncertainty'] = self.uncertainty
+        print("Cell projection uncertainties stored as 'projection_uncertainty' in self.adata.obs")
+        if visualize:
+            self._adata.obs = self.adata.obs.copy()
+            self.ax = self.plot_backbone(directed = False, color = color, **kwargs)
+            if path_to_fig is not None:
+                self.ax.figure.savefig(path_to_fig)
+            self.ax.figure.show()
+        return None
+
+
+    def select_root(self, days, method: str = 'proportion'):
+        '''Order the vertices/states based on cells' collection time information to select the root state.
+
+        Parameters
+        ----------
+        days : np.array
+            The day information for the selected cells used to determine the root vertex.
+            The dtype should be 'int' or 'float'.
+        method : str, optional
+            'proportion' or 'mean'.
+            For 'proportion', the root is the one with the maximal proportion of cells from the earliest day.
+            For 'mean', the root is the one with the earliest mean time among cells associated with it.
+
+        Returns
+        ----------
+        root_info : pd.DataFrame
+            The states with their mean collection times and earliest-day proportions, sorted by mean collection time;
+            use it to pick the root vertex for the inferred trajectory.
+        '''
+        if days is not None and len(days)!=self.X_input.shape[0]:
+            raise ValueError("The length of day information ({}) is not "
+                "consistent with the number of selected cells ({})!".format(
+                len(days), self.X_input.shape[0]))
+        if not hasattr(self, 'cell_position_projected'):
+            raise ValueError("Need to call 'infer_backbone' first!")
+
+        collection_time = np.dot(days, self.cell_position_projected)/np.sum(self.cell_position_projected, axis = 0)
+        earliest_prop = np.dot(days==np.min(days), self.cell_position_projected)/np.sum(self.cell_position_projected, axis = 0)
+
+        root_info = self.labels_map.copy()
+        root_info['mean_collection_time'] = collection_time
+        root_info['earliest_time_prop'] = earliest_prop
+        root_info.sort_values('mean_collection_time', inplace=True)
+        return root_info
+
+
+    def plot_backbone(self, directed: bool = False,
+                      method: str = 'UMAP', color = 'vitae_new_clustering', **kwargs):
+        '''Plot the current trajectory backbone (undirected graph).
+
+        Parameters
+        ----------
+        directed : boolean, optional
+            Whether the backbone is directed or not.
+        method : str, optional
+            The dimension reduction method to use. The default is "UMAP".
+        color : str, optional
+            The key for annotations of observations/cells or variables/genes, e.g., 'ann1'.
+            The default is 'vitae_new_clustering'.
+        **kwargs :
+            Extra key-value arguments that can be passed to scanpy plotting functions (scanpy.pl.XX).
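+
+        Examples
+        --------
+        A sketch; assumes `infer_backbone` has been run (and `infer_trajectory`
+        as well if directed=True):
+
+            ax = model.plot_backbone(directed=False, method='UMAP')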
+ ''' + if not isinstance(color,str): + raise ValueError('The color argument should be of type str!') + ax = self.visualize_latent(method = method, color=color, show=False, **kwargs) + dict_label_num = {j:i for i,j in self.labels_map['label_names'].to_dict().items()} + uni_cluster_labels = self.adata.obs['vitae_init_clustering'].cat.categories + cluster_labels = self.adata.obs['vitae_new_clustering'].to_numpy() + embed_z = self._adata.obsm[self.dict_method_scname[method]] + embed_mu = np.zeros((len(uni_cluster_labels), 2)) + for l in uni_cluster_labels: + embed_mu[dict_label_num[l],:] = np.mean(embed_z[cluster_labels==l], axis=0) + + if directed: + graph = self.directed_backbone + else: + graph = self.backbone + edges = list(graph.edges) + edge_scores = np.array([d['weight'] for (u,v,d) in graph.edges(data=True)]) + if max(edge_scores) - min(edge_scores) == 0: + edge_scores = edge_scores/max(edge_scores) + else: + edge_scores = (edge_scores - min(edge_scores))/(max(edge_scores) - min(edge_scores))*3 + + value_range = np.maximum(np.diff(ax.get_xlim())[0], np.diff(ax.get_ylim())[0]) + y_range = np.min(embed_z[:,1]), np.max(embed_z[:,1], axis=0) + for i in range(len(edges)): + points = embed_z[np.sum(self.cell_position_projected[:, edges[i]]>0, axis=-1)==2,:] + points = points[points[:,0].argsort()] + try: + x_smooth, y_smooth = _get_smooth_curve( + points, + embed_mu[edges[i], :], + y_range + ) + except: + x_smooth, y_smooth = embed_mu[edges[i], 0], embed_mu[edges[i], 1] + ax.plot(x_smooth, y_smooth, + '-', + linewidth= 1 + edge_scores[i], + color="black", + alpha=0.8, + path_effects=[pe.Stroke(linewidth=1+edge_scores[i]+1.5, + foreground='white'), pe.Normal()], + zorder=1 + ) + + if directed: + delta_x = embed_mu[edges[i][1], 0] - x_smooth[-2] + delta_y = embed_mu[edges[i][1], 1] - y_smooth[-2] + length = np.sqrt(delta_x**2 + delta_y**2) / 50 * value_range + ax.arrow( + embed_mu[edges[i][1], 0]-delta_x/length, + embed_mu[edges[i][1], 1]-delta_y/length, + delta_x/length, + delta_y/length, + color='black', alpha=1.0, + shape='full', lw=0, length_includes_head=True, + head_width=np.maximum(0.01*(1 + edge_scores[i]), 0.03) * value_range, + zorder=2) + + colors = self._adata.uns['vitae_new_clustering_colors'] + + for i,l in enumerate(uni_cluster_labels): + ax.scatter(*embed_mu[dict_label_num[l]:dict_label_num[l]+1,:].T, + c=[colors[i]], edgecolors='white', # linewidths=10, norm=norm, + s=250, marker='*', label=l) + + plt.setp(ax, xticks=[], yticks=[]) + box = ax.get_position() + ax.set_position([box.x0, box.y0 + box.height * 0.1, + box.width, box.height * 0.9]) + if directed: + ax.legend(loc='upper center', bbox_to_anchor=(0.5, -0.05), + fancybox=True, shadow=True, ncol=5) + + return ax + + + def plot_center(self, color = "vitae_new_clustering", plot_legend = True, legend_add_index = True, + method: str = 'UMAP',ncol = 2,font_size = "medium", + add_egde = False, add_direct = False,**kwargs): + '''Plot the center of each cluster in the latent space. + + Parameters + ---------- + color : str, optional + The color of the center of each cluster. Default is "vitae_new_clustering". + plot_legend : bool, optional + Whether to plot the legend. Default is True. + legend_add_index : bool, optional + Whether to add the index of each cluster in the legend. Default is True. + method : str, optional + The dimension reduction method used for visualization. Default is 'UMAP'. + ncol : int, optional + The number of columns in the legend. Default is 2. + font_size : str, optional + The font size of the legend. 
Default is "medium".
+        add_egde : bool, optional
+            Whether to add the edges between the centers of clusters. Default is False.
+        add_direct : bool, optional
+            Whether to add the direction of the edges. Default is False.
+        '''
+        if color not in ["vitae_new_clustering","vitae_init_clustering"]:
+            raise ValueError("Can only plot center of vitae_new_clustering or vitae_init_clustering")
+        dict_label_num = {j: i for i, j in self.labels_map['label_names'].to_dict().items()}
+        if legend_add_index:
+            self._adata.obs["index_"+color] = self._adata.obs[color].map(lambda x: dict_label_num[x])
+            ax = self.visualize_latent(method=method, color="index_" + color, show=False, legend_loc="on data",
+                                       legend_fontsize=font_size, **kwargs)
+            colors = self._adata.uns["index_" + color + '_colors']
+        else:
+            ax = self.visualize_latent(method=method, color = color, show=False, **kwargs)
+            colors = self._adata.uns[color + '_colors']
+        uni_cluster_labels = self.adata.obs[color].cat.categories
+        cluster_labels = self.adata.obs[color].to_numpy()
+        embed_z = self._adata.obsm[self.dict_method_scname[method]]
+        embed_mu = np.zeros((len(uni_cluster_labels), 2))
+        for l in uni_cluster_labels:
+            embed_mu[dict_label_num[l], :] = np.mean(embed_z[cluster_labels == l], axis=0)
+
+        leg = (self.labels_map.index.astype(str) + " : " + self.labels_map.label_names).values
+        for i, l in enumerate(uni_cluster_labels):
+            ax.scatter(*embed_mu[dict_label_num[l]:dict_label_num[l] + 1, :].T,
+                       c=[colors[i]], edgecolors='white', # linewidths=3,
+                       s=250, marker='*', label=leg[i])
+        if plot_legend:
+            ax.legend(loc='center left', bbox_to_anchor=(1, 0.5), ncol=ncol, markerscale=0.8, frameon=False)
+        plt.setp(ax, xticks=[], yticks=[])
+        box = ax.get_position()
+        ax.set_position([box.x0, box.y0 + box.height * 0.1,
+                         box.width, box.height * 0.9])
+        if add_egde:
+            if add_direct:
+                graph = self.directed_backbone
+            else:
+                graph = self.backbone
+            edges = list(graph.edges)
+            edge_scores = np.array([d['weight'] for (u, v, d) in graph.edges(data=True)])
+            if max(edge_scores) - min(edge_scores) == 0:
+                edge_scores = edge_scores / max(edge_scores)
+            else:
+                edge_scores = (edge_scores - min(edge_scores)) / (max(edge_scores) - min(edge_scores)) * 3
+
+            value_range = np.maximum(np.diff(ax.get_xlim())[0], np.diff(ax.get_ylim())[0])
+            y_range = np.min(embed_z[:, 1]), np.max(embed_z[:, 1], axis=0)
+            for i in range(len(edges)):
+                points = embed_z[np.sum(self.cell_position_projected[:, edges[i]] > 0, axis=-1) == 2, :]
+                points = points[points[:, 0].argsort()]
+                try:
+                    x_smooth, y_smooth = _get_smooth_curve(
+                        points,
+                        embed_mu[edges[i], :],
+                        y_range
+                        )
+                except:
+                    x_smooth, y_smooth = embed_mu[edges[i], 0], embed_mu[edges[i], 1]
+                ax.plot(x_smooth, y_smooth,
+                        '-',
+                        linewidth=1 + edge_scores[i],
+                        color="black",
+                        alpha=0.8,
+                        path_effects=[pe.Stroke(linewidth=1 + edge_scores[i] + 1.5,
+                                                foreground='white'), pe.Normal()],
+                        zorder=1
+                        )
+
+                if add_direct:
+                    delta_x = embed_mu[edges[i][1], 0] - x_smooth[-2]
+                    delta_y = embed_mu[edges[i][1], 1] - y_smooth[-2]
+                    length = np.sqrt(delta_x ** 2 + delta_y ** 2) / 50 * value_range
+                    ax.arrow(
+                        embed_mu[edges[i][1], 0] - delta_x / length,
+                        embed_mu[edges[i][1], 1] - delta_y / length,
+                        delta_x / length,
+                        delta_y / length,
+                        color='black', alpha=1.0,
+                        shape='full', lw=0, length_includes_head=True,
+                        head_width=np.maximum(0.01 * (1 + edge_scores[i]), 0.03) * value_range,
+                        zorder=2)
+        self.ax = ax
+        self.ax.figure.show()
+        return None
+
+
+    def infer_trajectory(self, root: Union[int,str], digraph = None, color = "pseudotime",
+                         visualize: bool = True, path_to_fig = None, **kwargs):
+        '''Infer the trajectory.
+
+        Parameters
+        ----------
+        root : int or string
+            The root of the inferred trajectory. Can be either an int (vertex index) or a string (label name).
+        digraph : nx.DiGraph, optional
+            The directed graph to be used for trajectory inference. If None, the minimum spanning tree of the estimated trajectory backbone will be used.
+        color : str, optional
+            The key used to color the cells in the plot. The default is "pseudotime".
+        visualize : boolean
+            Whether to plot the current trajectory backbone (directed graph).
+        path_to_fig : string, optional
+            The path to save the figure; the figure is not saved if None.
+        **kwargs : dict, optional
+            Other keyword arguments for plotting.
+        '''
+        if isinstance(root, str):
+            if root not in self.labels_map.values:
+                raise ValueError("Root {} is not in the label names!".format(root))
+            root = self.labels_map[self.labels_map['label_names']==root].index[0]
+
+        if digraph is None:
+            connected_comps = nx.node_connected_component(self.backbone, root)
+            subG = self.backbone.subgraph(connected_comps)
+
+            ## generate directed backbone which contains no loops
+            DG = nx.DiGraph(nx.to_directed(self.backbone))
+            temp = DG.subgraph(connected_comps)
+            DG.remove_edges_from(temp.edges - nx.dfs_edges(DG, root))
+            self.directed_backbone = DG
+        else:
+            if not nx.is_directed_acyclic_graph(digraph):
+                raise ValueError("The graph 'digraph' should be a directed acyclic graph.")
+            if set(digraph.nodes) != set(self.backbone.nodes):
+                raise ValueError("The nodes in 'digraph' do not match the nodes in 'self.backbone'.")
+            self.directed_backbone = digraph
+
+            connected_comps = nx.node_connected_component(digraph, root)
+            subG = self.backbone.subgraph(connected_comps)
+
+
+        if len(subG.edges)>0:
+            milestone_net = self.inferer.build_milestone_net(subG, root)
+            if self.inferer.no_loop is False and milestone_net.shape[0]<len(self.backbone.edges):
+                warnings.warn("The directed graph shown is a minimum spanning tree of the estimated trajectory backbone to avoid arbitrary assignment of the directions.")
+            self.pseudotime = self.inferer.comp_pseudotime(milestone_net, root, self.cell_position_projected)
+        else:
+            warnings.warn("There are no connected states when starting from the given root.")
+            self.pseudotime = -np.ones(self._adata.shape[0])
+
+        self.adata.obs['pseudotime'] = self.pseudotime
+        print("Pseudotime stored as 'pseudotime' in self.adata.obs")
+
+        if visualize:
+            self._adata.obs['pseudotime'] = self.pseudotime
+            self.ax = self.plot_backbone(directed = True, color = color, **kwargs)
+            if path_to_fig is not None:
+                self.ax.figure.savefig(path_to_fig)
+            self.ax.figure.show()
+
+        return None
+
+
+
+    def differential_expression_test(self, alpha: float = 0.05, cell_subset = None, order: int = 1):
+        '''Differential gene expression test. All (selected and unselected) genes will be tested.
+        Only cells in `cell_subset` will be used, which is useful when one needs to
+        test differentially expressed genes on a branch of the inferred trajectory.
+
+        Parameters
+        ----------
+        alpha : float, optional
+            The cutoff of p-values.
+        cell_subset : np.array, optional
+            The subset of cells to be used for testing. If None, all cells will be used.
+        order : int, optional
+            The maximum order of pseudotime used in the regression.
+
+        Returns
+        ----------
+        res_df : pandas.DataFrame
+            The test results of expressed genes with two columns,
+            the estimated coefficients and the adjusted p-values.
+        '''
+        if not hasattr(self, 'pseudotime'):
+            raise ReferenceError("Pseudotime does not exist! Please run 'infer_trajectory' first.")
+        if cell_subset is None:
+            cell_subset = np.arange(self.X_input.shape[0])
+            print("All cells are selected.")
+        if order < 1:
+            raise ValueError("Maximal order of pseudotime in regression must be at least 1.")
+
+        # Prepare X and Y for regression expression ~ rank(PDT) + covariates
+        Y = self.adata.X[cell_subset,:]
+#        std_Y = np.std(Y, ddof=1, axis=0, keepdims=True)
+#        Y = np.divide(Y-np.mean(Y, axis=0, keepdims=True), std_Y, out=np.empty_like(Y)*np.nan, where=std_Y!=0)
+        X = stats.rankdata(self.pseudotime[cell_subset])
+        if order > 1:
+            for _order in range(2, order+1):
+                X = np.c_[X, X**_order]
+        X = ((X-np.mean(X,axis=0, keepdims=True))/np.std(X, ddof=1, axis=0, keepdims=True))
+        X = np.c_[np.ones((X.shape[0],1)), X]
+        if self.covariates is not None:
+            X = np.c_[X, self.covariates[cell_subset, :]]
+
+        res_df = DE_test(Y, X, self.adata.var_names, i_test = np.array(list(range(1,order+1))), alpha = alpha)
+        return res_df[res_df.pvalue_adjusted_1 != 0]
+
+
+
+
+    def evaluate(self, milestone_net, begin_node_true, grouping = None,
+                 thres: float = 0.5, no_loop: bool = True, cutoff: Optional[float] = None,
+                 method: str = 'mean', path: Optional[str] = None):
+        ''' Evaluate the model.
+
+        Parameters
+        ----------
+        milestone_net : pd.DataFrame
+            The true milestone network. For real data, milestone_net will be a DataFrame of the graph of nodes.
+            E.g.
+
+            from|to
+            ---|---
+            cluster 1 | cluster 1
+            cluster 1 | cluster 2
+
+            For synthetic data, milestone_net will be a DataFrame of the (projected)
+            positions of cells. The indexes are the orders of cells in the dataset.
+            E.g.
+
+            from|to|w
+            ---|---|---
+            cluster 1 | cluster 1 | 1
+            cluster 1 | cluster 2 | 0.1
+        begin_node_true : str or int
+            The true begin node of the milestone.
+        grouping : np.array, optional
+            \([N,]\) The labels. For real data, grouping must be provided.
+
+        Returns
+        ----------
+        res : pd.DataFrame
+            The evaluation result.
+        '''
+        if not hasattr(self, 'labels_map'):
+            raise ValueError("No given labels for training.")
+
+        '''
+        # Evaluating the whole dataset will ignore selected_cell_subset.
+        if len(self.selected_cell_subset)!=len(self.cell_names):
+            warnings.warn("Evaluate for the whole dataset.")
+        '''
+
+        # If begin_node_true is a string, encode it using the label map.
+        # This dict is for the milestone net, because its labels are not merged
+        # (all keys of label_map_dict are str).
+        label_map_dict = dict()
+        for i in range(self.labels_map.shape[0]):
+            label_mapped = self.labels_map.loc[i]
+            ## merged cluster indices are connected by commas
+            for each in label_mapped.values[0].split(","):
+                label_map_dict[each] = i
+        if isinstance(begin_node_true, str):
+            begin_node_true = label_map_dict[begin_node_true]
+
+        # For generated data, grouping information is already in milestone_net
+        if 'w' in milestone_net.columns:
+            grouping = None
+
+        # If milestone_net is provided, transform them to be numeric.
+        if milestone_net is not None:
+            milestone_net['from'] = [label_map_dict[x] for x in milestone_net["from"]]
+            milestone_net['to'] = [label_map_dict[x] for x in milestone_net["to"]]
+
+        # this dict is for potentially merged clusters.
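+        # Example: if original clusters '1' and '2' were merged, labels_map holds
+        # the merged name '1,2'; label_map_dict (above) maps '1' and '2' separately,
+        # while the dict below maps the merged name '1,2' to its state index.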
+        label_map_dict_for_merged_cluster = dict(zip(self.labels_map["label_names"],self.labels_map.index))
+        mapped_labels = np.array([label_map_dict_for_merged_cluster[x] for x in self.labels])
+        begin_node_pred = int(np.argmin(np.mean((
+            self.z[mapped_labels==begin_node_true,:,np.newaxis] -
+            self.mu[np.newaxis,:,:])**2, axis=(0,1))))
+
+        if cutoff is None:
+            cutoff = 0.01
+
+        G = self.backbone
+        w = self.cell_position_projected
+        pseudotime = self.pseudotime
+
+        # 1. Topology
+        G_pred = nx.Graph()
+        G_pred.add_nodes_from(G.nodes)
+        G_pred.add_edges_from(G.edges)
+        nx.set_node_attributes(G_pred, False, 'is_init')
+        G_pred.nodes[begin_node_pred]['is_init'] = True
+
+        G_true = nx.Graph()
+        G_true.add_nodes_from(G.nodes)
+        # if 'grouping' is not provided, assume 'milestone_net' contains proportions
+        if grouping is None:
+            G_true.add_edges_from(list(
+                milestone_net[~pd.isna(milestone_net['w'])].groupby(['from', 'to']).count().index))
+        # otherwise, 'milestone_net' indicates edges
+        else:
+            if milestone_net is not None:
+                G_true.add_edges_from(list(
+                    milestone_net.groupby(['from', 'to']).count().index))
+            grouping = [label_map_dict[x] for x in grouping]
+            grouping = np.array(grouping)
+        G_true.remove_edges_from(nx.selfloop_edges(G_true))
+        nx.set_node_attributes(G_true, False, 'is_init')
+        G_true.nodes[begin_node_true]['is_init'] = True
+        res = topology(G_true, G_pred)
+
+        # 2. Milestones assignment
+        if grouping is None:
+            milestones_true = milestone_net['from'].values.copy()
+            milestones_true[(milestone_net['from']!=milestone_net['to'])
+                           &(milestone_net['w']<0.5)] = milestone_net[(milestone_net['from']!=milestone_net['to'])
+                                                                      &(milestone_net['w']<0.5)]['to'].values
+        else:
+            milestones_true = grouping
+        milestones_pred = np.argmax(w, axis=1)
+        res['ARI'] = (adjusted_rand_score(milestones_true, milestones_pred) + 1)/2
+
+        if grouping is None:
+            n_samples = len(milestone_net)
+            prop = np.zeros((n_samples,n_samples))
+            prop[np.arange(n_samples), milestone_net['to']] = 1-milestone_net['w']
+            prop[np.arange(n_samples), milestone_net['from']] = np.where(np.isnan(milestone_net['w']), 1, milestone_net['w'])
+            res['GRI'] = get_GRI(prop, w)
+        else:
+            res['GRI'] = get_GRI(grouping, w)
+
+        # 3. Correlation between geodesic distances / Pseudotime
+        if no_loop:
+            if grouping is None:
+                pseudotime_true = milestone_net['from'].values + 1 - milestone_net['w'].values
+                pseudotime_true[np.isnan(pseudotime_true)] = milestone_net[pd.isna(milestone_net['w'])]['from'].values
+            else:
+                pseudotime_true = - np.ones(len(grouping))
+                nx.set_edge_attributes(G_true, values = 1, name = 'weight')
+                connected_comps = nx.node_connected_component(G_true, begin_node_true)
+                subG = G_true.subgraph(connected_comps)
+                milestone_net_true = self.inferer.build_milestone_net(subG, begin_node_true)
+                if len(milestone_net_true)>0:
+                    pseudotime_true[grouping==int(milestone_net_true[0,0])] = 0
+                    for i in range(len(milestone_net_true)):
+                        pseudotime_true[grouping==int(milestone_net_true[i,1])] = milestone_net_true[i,-1]
+            pseudotime_true = pseudotime_true[pseudotime>-1]
+            pseudotime_pred = pseudotime[pseudotime>-1]
+            res['PDT score'] = (np.corrcoef(pseudotime_true,pseudotime_pred)[0,1]+1)/2
+        else:
+            res['PDT score'] = np.nan
+
+        # 4. Shape
+        # score_cos_theta = 0
+        # for (_from,_to) in G.edges:
+        #     _z = self.z[(w[:,_from]>0) & (w[:,_to]>0),:]
+        #     v_1 = _z - self.mu[:,_from]
+        #     v_2 = _z - self.mu[:,_to]
+        #     cos_theta = np.sum(v_1*v_2, -1)/(np.linalg.norm(v_1,axis=-1)*np.linalg.norm(v_2,axis=-1)+1e-12)
+
+        #     score_cos_theta += np.sum((1-cos_theta)/2)
+
+        # res['score_cos_theta'] = score_cos_theta/(np.sum(np.sum(w>0, axis=-1)==2)+1e-12)
+        return res
+
+
+    def save_model(self, path_to_file: str = 'model.checkpoint', save_adata: bool = False):
+        '''Save the model weights.
+
+        Parameters
+        ----------
+        path_to_file : str, optional
+            The path to the weight files of the pre-trained or trained model.
+        save_adata : boolean, optional
+            Whether to save adata or not.
+        '''
+        self.vae.save_weights(path_to_file)
+        if hasattr(self, 'labels') and self.labels is not None:
+            with open(path_to_file + '.label', 'wb') as f:
+                np.save(f, self.labels)
+        with open(path_to_file + '.config', 'wb') as f:
+            self.dim_origin = self.X_input.shape[1]
+            np.save(f, np.array([
+                self.dim_origin, self.dimensions, self.dim_latent,
+                self.model_type, 0 if self.covariates is None else self.covariates.shape[1]], dtype=object))
+        if hasattr(self, 'inferer') and hasattr(self, 'uncertainty'):
+            with open(path_to_file + '.inference', 'wb') as f:
+                np.save(f, np.array([
+                    self.pi, self.mu, self.pc_x, self.cell_position_posterior, self.uncertainty,
+                    self.z, self.cell_position_variance], dtype=object))
+        if save_adata:
+            self.adata.write(path_to_file + '.adata.h5ad')
+
+
+    def load_model(self, path_to_file: str = 'model.checkpoint', load_labels: bool = False, load_adata: bool = False):
+        '''Load model weights.
+
+        Parameters
+        ----------
+        path_to_file : str, optional
+            The path to the weight files of the pre-trained or trained model.
+        load_labels : boolean, optional
+            Whether to load clustering labels or not.
+            If load_labels is True, then the LatentSpace layer will be initialized based on the model.
+            If load_labels is False, then the LatentSpace layer will not be initialized.
+        load_adata : boolean, optional
+            Whether to load adata or not.
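+
+        Examples
+        --------
+        A sketch; the file name is a placeholder for weights previously written
+        by `save_model`:
+
+            model.load_model('model.checkpoint', load_labels=True)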
+        '''
+        if not os.path.exists(path_to_file + '.config'):
+            raise AssertionError('Config file does not exist!')
+        if load_labels and not os.path.exists(path_to_file + '.label'):
+            raise AssertionError('Label file does not exist!')
+
+        with open(path_to_file + '.config', 'rb') as f:
+            [self.dim_origin, self.dimensions,
+             self.dim_latent, self.model_type, cov_dim] = np.load(f, allow_pickle=True)
+        self.vae = model.VariationalAutoEncoder(
+            self.dim_origin, self.dimensions,
+            self.dim_latent, self.model_type, False if cov_dim == 0 else True
+        )
+
+        if load_labels:
+            with open(path_to_file + '.label', 'rb') as f:
+                cluster_labels = np.load(f, allow_pickle=True)
+            self.init_latent_space(cluster_labels, dist_thres=0)
+            if os.path.exists(path_to_file + '.inference'):
+                with open(path_to_file + '.inference', 'rb') as f:
+                    arr = np.load(f, allow_pickle=True)
+                    if len(arr) == 8:
+                        [self.pi, self.mu, self.pc_x, self.cell_position_posterior, self.uncertainty,
+                         self.D_JS, self.z, self.cell_position_variance] = arr
+                    else:
+                        [self.pi, self.mu, self.pc_x, self.cell_position_posterior, self.uncertainty,
+                         self.z, self.cell_position_variance] = arr
+                self._adata_z = sc.AnnData(self.z)
+                sc.pp.neighbors(self._adata_z)
+        # initialize the weights of the encoder and decoder
+        self.vae.encoder(np.zeros((1, self.dim_origin + cov_dim)))
+        self.vae.decoder(np.expand_dims(np.zeros((1, self.dim_latent + cov_dim)), 1))
+
+        self.vae.load_weights(path_to_file)
+        self.update_z()
+        if load_adata:
+            if not os.path.exists(path_to_file + '.adata.h5ad'):
+                raise AssertionError('AnnData file does not exist!')
+            self.adata = sc.read_h5ad(path_to_file + '.adata.h5ad')
+            self._adata.obs = self.adata.obs.copy()</code></pre>
+</details>
+<h3>Methods</h3>
+<dl>
+<dt id="VITAE.VITAE.pre_train"><code class="name flex">
+<span>def <span class="ident">pre_train</span></span>(<span>self, test_size=0.1, random_state: int = 0, learning_rate: float = 0.001, batch_size: int = 256, L: int = 1, alpha: float = 0.1, gamma: float = 0, phi: float = 1, num_epoch: int = 200, num_step_per_epoch: Optional[int] = None, early_stopping_patience: int = 10, early_stopping_tolerance: float = 0.01, early_stopping_relative: bool = True, verbose: bool = False, path_to_weights: Optional[str] = None)</span>
+</code></dt>
+<dd>
+<div class="desc"><p>Pretrain the model with the specified learning rate.</p>
+<h2 id="parameters">Parameters</h2>
+<dl>
+<dt><strong><code>test_size</code></strong> : <code>float</code> or <code>int</code>, optional</dt>
+<dd>The proportion or size of the test set.</dd>
+<dt><strong><code>random_state</code></strong> : <code>int</code>, optional</dt>
+<dd>The random state for data splitting.</dd>
+<dt><strong><code>learning_rate</code></strong> : <code>float</code>, optional</dt>
+<dd>The initial learning rate for the Adam optimizer.</dd>
+<dt><strong><code>batch_size</code></strong> : <code>int</code>, optional</dt>
+<dd>The batch size for pre-training.
+Default is 256. Consider 32 if the number of cells is small (fewer than 1,000).</dd>
+<dt><strong><code>L</code></strong> : <code>int</code>, optional</dt>
+<dd>The number of MC samples.</dd>
+<dt><strong><code>alpha</code></strong> : <code>float</code>, optional</dt>
+<dd>The value of alpha in [0,1] to encourage covariate adjustment. Not used if there are no covariates.</dd>
+<dt><strong><code>gamma</code></strong> : <code>float</code>, optional</dt>
+<dd>The weight of the MMD loss, if used.</dd>
+<dt><strong><code>phi</code></strong> : <code>float</code>, optional</dt>
+<dd>The weight of the Jacobian norm of the encoder.</dd>
+<dt><strong><code>num_epoch</code></strong> : <code>int</code>, optional</dt>
+<dd>The maximum number of epochs.</dd>
+<dt><strong><code>num_step_per_epoch</code></strong> : <code>int</code>, optional</dt>
+<dd>The number of steps per epoch; if None, it will be inferred from the number of cells and the batch size.</dd>
+<dt><strong><code>early_stopping_patience</code></strong> : <code>int</code>, optional</dt>
+<dd>The maximum number of epochs without improvement.</dd>
+<dt><strong><code>early_stopping_tolerance</code></strong> : <code>float</code>, optional</dt>
+<dd>The minimum change of loss to be considered as an improvement.</dd>
+<dt><strong><code>early_stopping_relative</code></strong> : <code>bool</code>, optional</dt>
+<dd>Whether to monitor the relative change of loss as the stopping criterion.</dd>
+<dt><strong><code>path_to_weights</code></strong> : <code>str</code>, optional</dt>
+<dd>The path of the weight file to be saved; weights are not saved if None.</dd>
+<dt><strong><code>conditions</code></strong> : <code>str</code> or <code>list</code>, optional</dt>
+<dd>The conditions of the different cells.</dd>
+</dl></div>
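+<p>A minimal usage sketch (assuming <code>model</code> is an existing <code>VITAE</code> instance built from an AnnData object; the parameter values here are illustrative, not prescriptive):</p>
+<pre><code class="language-python"># Pretrain the VAE before initializing the latent space.
+model.pre_train(
+    learning_rate=1e-3,                      # initial Adam learning rate
+    batch_size=32,                           # smaller batches for small datasets (&lt;1,000 cells)
+    num_epoch=200,                           # upper bound; early stopping may end training sooner
+    early_stopping_patience=10,              # stop after 10 epochs without improvement
+    path_to_weights='pretrain.checkpoint',   # optional: persist the pretrained weights
+)
+</code></pre>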
+</dd>
+<dt id="VITAE.VITAE.update_z"><code class="name flex">
+<span>def <span class="ident">update_z</span></span>(<span>self)</span>
+</code></dt>
+<dd>
+<div class="desc"></div>
+</dd>
+<dt id="VITAE.VITAE.get_latent_z"><code class="name flex">
+<span>def <span class="ident">get_latent_z</span></span>(<span>self)</span>
+</code></dt>
+<dd>
+<div class="desc"><p>Get the posterior mean of the current latent space z (the encoder output).</p>
+<h2 id="returns">Returns</h2>
+<dl>
+<dt><strong><code>z</code></strong> : <code>np.array</code></dt>
+<dd><span><span class="MathJax_Preview">[N,d]</span><script type="math/tex">[N,d]</script></span> The latent means.</dd>
+</dl></div>
+</dd>
+<dt id="VITAE.VITAE.visualize_latent"><code class="name flex">
+<span>def <span class="ident">visualize_latent</span></span>(<span>self, method: str = 'UMAP', color=None, **kwargs)</span>
+</code></dt>
+<dd>
+<div class="desc"><p>Visualize the current latent space z using the scanpy visualization tools.</p>
+<h2 id="parameters">Parameters</h2>
+<dl>
+<dt><strong><code>method</code></strong> : <code>str</code>, optional</dt>
+<dd>Visualization method to use. The default is "UMAP". Possible choices include "PCA", "UMAP",
+"diffmap", "TSNE", and "draw_graph" (the FA plot).</dd>
+<dt><strong><code>color</code></strong> : <code>str</code> or <code>list</code>, optional</dt>
+<dd>Keys for annotations of observations/cells or variables/genes, e.g., 'ann1' or ['ann1', 'ann2'].
+The default is None. Same as scanpy.</dd>
+<dt><strong><code>**kwargs</code></strong> : <code> </code></dt>
+<dd>Extra key-value arguments that can be passed to scanpy plotting functions (scanpy.pl.XX).</dd>
+</dl>
+<h2 id="returns">Returns</h2>
+<p>None.</p></div>
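+<p>A short sketch of inspecting the pretrained latent space; <code>'cell_type'</code> is a hypothetical key in <code>adata.obs</code>:</p>
+<pre><code class="language-python">z = model.get_latent_z()   # np.array of shape [N, d]
+print(z.shape)
+
+# Extra keyword arguments are forwarded to the scanpy plotting function.
+model.visualize_latent(method='UMAP', color='cell_type', legend_loc='on data')
+</code></pre>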
+</dd>
+<dt id="VITAE.VITAE.init_latent_space"><code class="name flex">
+<span>def <span class="ident">init_latent_space</span></span>(<span>self, cluster_label=None, log_pi=None, res: float = 1.0, ratio_prune=None, dist=None, dist_thres=0.5, topk=0, pilayer=False)</span>
+</code></dt>
+<dd>
+<div class="desc"><p>Initialize the latent space.</p>
+<h2 id="parameters">Parameters</h2>
+<dl>
+<dt><strong><code>cluster_label</code></strong> : <code>str</code>, optional</dt>
+<dd>The name of a vector of labels that can be found in self.adata.obs.
+Default is None, in which case Leiden clustering is performed on the pretrained z to obtain clusters.</dd>
+<dt><strong><code>mu</code></strong> : <code>np.array</code>, optional</dt>
+<dd><span><span class="MathJax_Preview">[d,k]</span><script type="math/tex">[d,k]</script></span> The value of the initial <span><span class="MathJax_Preview">\mu</span><script type="math/tex">\mu</script></span>.</dd>
+<dt><strong><code>log_pi</code></strong> : <code>np.array</code>, optional</dt>
+<dd><span><span class="MathJax_Preview">[1,K]</span><script type="math/tex">[1,K]</script></span> The value of the initial <span><span class="MathJax_Preview">\log(\pi)</span><script type="math/tex">\log(\pi)</script></span>.</dd>
+<dt><strong><code>res</code></strong></dt>
+<dd>The resolution of Leiden clustering, a parameter controlling the coarseness of the clustering.
+Higher values lead to more clusters. Default is 1.</dd>
+<dt><strong><code>ratio_prune</code></strong> : <code>float</code>, optional</dt>
+<dd>The ratio of edges to be removed before estimating.</dd>
+<dt><strong><code>topk</code></strong> : <code>int</code>, optional</dt>
+<dd>The number of top k neighbors to keep for each cluster.</dd>
+</dl></div>
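+<p>Two common ways to initialize the latent space, sketched under the assumption that <code>'cell_type'</code> is a label column in <code>adata.obs</code>:</p>
+<pre><code class="language-python"># Option 1: initialize the states from existing labels.
+model.init_latent_space(cluster_label='cell_type')
+
+# Option 2: let Leiden clustering define the states;
+# a higher `res` yields more (finer) clusters.
+model.init_latent_space(res=1.0, dist_thres=0.5)
+</code></pre>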
+</dd>
+<dt id="VITAE.VITAE.update_latent_space"><code class="name flex">
+<span>def <span class="ident">update_latent_space</span></span>(<span>self, dist_thres: float = 0.5)</span>
+</code></dt>
+<dd>
+<div class="desc"></div>
+</dd>
+<dt id="VITAE.VITAE.train"><code class="name flex">
+<span>def <span class="ident">train</span></span>(<span>self, stratify=False, test_size=0.1, random_state: int = 0, learning_rate: float = 0.001, batch_size: int = 256, L: int = 1, alpha: float = 0.1, beta: float = 1, gamma: float = 0, phi: float = 1, num_epoch: int = 200, num_step_per_epoch: Optional[int] = None, early_stopping_patience: int = 10, early_stopping_tolerance: float = 0.01, early_stopping_relative: bool = True, early_stopping_warmup: int = 0, path_to_weights: Optional[str] = None, verbose: bool = False, **kwargs)</span>
+</code></dt>
+<dd>
+<div class="desc"><p>Train the model.</p>
+<h2 id="parameters">Parameters</h2>
+<dl>
+<dt><strong><code>stratify</code></strong> : <code>np.array, None,</code> or <code>False</code></dt>
+<dd>If an array is provided, or if <code>stratify=None</code> and <code>self.labels</code> is available, then they will be used to perform stratified shuffle splitting. Otherwise, general shuffle splitting is used.
+Set to <code>False</code> if <code>self.labels</code> is not intended for stratified shuffle splitting.</dd>
+<dt><strong><code>test_size</code></strong> : <code>float</code> or <code>int</code>, optional</dt>
+<dd>The proportion or size of the test set.</dd>
+<dt><strong><code>random_state</code></strong> : <code>int</code>, optional</dt>
+<dd>The random state for data splitting.</dd>
+<dt><strong><code>learning_rate</code></strong> : <code>float</code>, optional</dt>
+<dd>The initial learning rate for the Adam optimizer.</dd>
+<dt><strong><code>batch_size</code></strong> : <code>int</code>, optional</dt>
+<dd>The batch size for training. Default is 256. Consider 32 if the number of cells is small (fewer than 1,000).</dd>
+<dt><strong><code>L</code></strong> : <code>int</code>, optional</dt>
+<dd>The number of MC samples.</dd>
+<dt><strong><code>alpha</code></strong> : <code>float</code>, optional</dt>
+<dd>The value of alpha in [0,1] to encourage covariate adjustment. Not used if there are no covariates.</dd>
+<dt><strong><code>beta</code></strong> : <code>float</code>, optional</dt>
+<dd>The value of beta in beta-VAE.</dd>
+<dt><strong><code>gamma</code></strong> : <code>float</code>, optional</dt>
+<dd>The weight of the MMD loss.</dd>
+<dt><strong><code>phi</code></strong> : <code>float</code>, optional</dt>
+<dd>The weight of the Jacobian norm of the encoder.</dd>
+<dt><strong><code>num_epoch</code></strong> : <code>int</code>, optional</dt>
+<dd>The maximum number of epochs.</dd>
+<dt><strong><code>num_step_per_epoch</code></strong> : <code>int</code>, optional</dt>
+<dd>The number of steps per epoch; if None, it will be inferred from the number of cells and the batch size.</dd>
+<dt><strong><code>early_stopping_patience</code></strong> : <code>int</code>, optional</dt>
+<dd>The maximum number of epochs without improvement.</dd>
+<dt><strong><code>early_stopping_tolerance</code></strong> : <code>float</code>, optional</dt>
+<dd>The minimum change of loss to be considered as an improvement.</dd>
+<dt><strong><code>early_stopping_relative</code></strong> : <code>bool</code>, optional</dt>
+<dd>Whether to monitor the relative change of loss.</dd>
+<dt><strong><code>early_stopping_warmup</code></strong> : <code>int</code>, optional</dt>
+<dd>The number of warmup epochs.</dd>
+<dt><strong><code>path_to_weights</code></strong> : <code>str</code>, optional</dt>
+<dd>The path of the weight file to be saved; weights are not saved if None.</dd>
+<dt><strong><code>**kwargs</code></strong> : <code> </code></dt>
+<dd>Extra key-value arguments for dimension reduction algorithms.</dd>
+</dl></div>
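+<p>A sketch of the main training call after the latent space has been initialized (values are illustrative):</p>
+<pre><code class="language-python">model.train(
+    beta=1,                        # beta-VAE weight on the KL term
+    learning_rate=1e-3,
+    early_stopping_patience=10,
+    early_stopping_warmup=10,      # ignore the first epochs when monitoring the loss
+    path_to_weights='train.checkpoint',
+)
+</code></pre>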
+</dd>
+<dt id="VITAE.VITAE.output_pi"><code class="name flex">
+<span>def <span class="ident">output_pi</span></span>(<span>self, pi_cov)</span>
+</code></dt>
+<dd>
+<div class="desc"><p>Return an n_states-by-n_states matrix and a mask for plotting, which can be used to cover the lower triangle (except the diagonal) of a heatmap.</p></div>
+</dd>
+<dt id="VITAE.VITAE.return_pilayer_weights"><code class="name flex">
+<span>def <span class="ident">return_pilayer_weights</span></span>(<span>self)</span>
+</code></dt>
+<dd>
+<div class="desc"><p>Return the parameters of the pilayer, which has dimension (dim(pi_cov) + 1) by n_categories; the last row contains the biases.</p></div>
+</dd>
+<dt id="VITAE.VITAE.posterior_estimation"><code class="name flex">
+<span>def <span class="ident">posterior_estimation</span></span>(<span>self, batch_size: int = 32, L: int = 50, **kwargs)</span>
+</code></dt>
+<dd>
+<div class="desc"><p>Initialize trajectory inference by computing the posterior estimations.</p>
+<h2 id="parameters">Parameters</h2>
+<dl>
+<dt><strong><code>batch_size</code></strong> : <code>int</code>, optional</dt>
+<dd>The batch size when doing inference.</dd>
+<dt><strong><code>L</code></strong> : <code>int</code>, optional</dt>
+<dd>The number of MC samples when doing inference.</dd>
+<dt><strong><code>**kwargs</code></strong> : <code> </code></dt>
+<dd>Extra key-value arguments for dimension reduction algorithms.</dd>
+</dl></div>
+</dd>
+<dt id="VITAE.VITAE.infer_backbone"><code class="name flex">
+<span>def <span class="ident">infer_backbone</span></span>(<span>self, method: str = 'modified_map', thres=0.5, no_loop: bool = True, cutoff: float = 0, visualize: bool = True, color='vitae_new_clustering', path_to_fig=None, **kwargs)</span>
+</code></dt>
+<dd>
+<div class="desc"><p>Compute edge scores.</p>
+<h2 id="parameters">Parameters</h2>
+<dl>
+<dt><strong><code>method</code></strong> : <code>string</code>, optional</dt>
+<dd>'mean', 'modified_mean', 'map', or 'modified_map'.</dd>
+<dt><strong><code>thres</code></strong> : <code>float</code>, optional</dt>
+<dd>The threshold used for filtering out edges <span><span class="MathJax_Preview">e_{ij}</span><script type="math/tex">e_{ij}</script></span> such that <span><span class="MathJax_Preview">(n_{i}+n_{j}+e_{ij})/N<thres</span><script type="math/tex">(n_{i}+n_{j}+e_{ij})/N<thres</script></span>; only applied to the mean method.</dd>
+<dt><strong><code>no_loop</code></strong> : <code>boolean</code>, optional</dt>
+<dd>Whether loops are allowed to exist in the graph. If no_loop is true, the graph will be pruned to contain only the
+maximum spanning tree.</dd>
+<dt><strong><code>cutoff</code></strong> : <code>float</code>, optional</dt>
+<dd>The score threshold for filtering edges with scores less than cutoff.</dd>
+<dt><strong><code>visualize</code></strong> : <code>boolean</code></dt>
+<dd>Whether to plot the current trajectory backbone (undirected graph).</dd>
+</dl>
+<h2 id="returns">Returns</h2>
+<dl>
+<dt><strong><code>G</code></strong> : <code>nx.Graph</code></dt>
+<dd>The weighted graph with a weight on each edge indicating its score of existence.</dd>
+</dl></div>
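+<p>A sketch of the inference steps after training (the method and cutoff choices are illustrative):</p>
+<pre><code class="language-python"># Compute posterior estimates of the cell positions (L MC samples).
+model.posterior_estimation(batch_size=32, L=50)
+
+# Score the candidate edges and keep a tree-shaped backbone.
+backbone = model.infer_backbone(method='modified_map', no_loop=True, cutoff=0)
+</code></pre>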
+</dd>
+<dt id="VITAE.VITAE.select_root"><code class="name flex">
+<span>def <span class="ident">select_root</span></span>(<span>self, days, method: str = 'proportion')</span>
+</code></dt>
+<dd>
+<div class="desc"><p>Order the vertices/states based on the cells' collection time information to select the root state.</p>
+<h2 id="parameters">Parameters</h2>
+<dl>
+<dt><strong><code>days</code></strong> : <code>np.array </code></dt>
+<dd>The day information for the selected cells used to determine the root vertex.
+The dtype should be 'int' or 'float'.</dd>
+<dt><strong><code>method</code></strong> : <code>str</code>, optional</dt>
+<dd>'proportion' or 'mean'.
+For 'proportion', the root is the one with the maximal proportion of cells from the earliest day.
+For 'mean', the root is the one with the earliest mean time among the cells associated with it.</dd>
+</dl>
+<h2 id="returns">Returns</h2>
+<dl>
+<dt><strong><code>root</code></strong> : <code>int </code></dt>
+<dd>The root vertex in the inferred trajectory based on the given day information.</dd>
+</dl></div>
+</dd>
+<dt id="VITAE.VITAE.plot_backbone"><code class="name flex">
+<span>def <span class="ident">plot_backbone</span></span>(<span>self, directed: bool = False, method: str = 'UMAP', color='vitae_new_clustering', **kwargs)</span>
+</code></dt>
+<dd>
+<div class="desc"><p>Plot the current trajectory backbone (undirected graph).</p>
+<h2 id="parameters">Parameters</h2>
+<dl>
+<dt><strong><code>directed</code></strong> : <code>boolean</code>, optional</dt>
+<dd>Whether the backbone is directed or not.</dd>
+<dt><strong><code>method</code></strong> : <code>str</code>, optional</dt>
+<dd>The dimension reduction method to use. The default is "UMAP".</dd>
+<dt><strong><code>color</code></strong> : <code>str</code>, optional</dt>
+<dd>The key for annotations of observations/cells or variables/genes, e.g., 'ann1' or ['ann1', 'ann2'].
+The default is 'vitae_new_clustering'.</dd>
+<dt><strong><code>**kwargs</code></strong> : <code> </code></dt>
+<dd>Extra key-value arguments that can be passed to scanpy plotting functions (scanpy.pl.XX).</dd>
+</dl></div>
+</dd>
+<dt id="VITAE.VITAE.plot_center"><code class="name flex">
+<span>def <span class="ident">plot_center</span></span>(<span>self, color='vitae_new_clustering', plot_legend=True, legend_add_index=True, method: str = 'UMAP', ncol=2, font_size='medium', add_egde=False, add_direct=False, **kwargs)</span>
+</code></dt>
+<dd>
+<div class="desc"><p>Plot the center of each cluster in the latent space.</p>
+<h2 id="parameters">Parameters</h2>
+<dl>
+<dt><strong><code>color</code></strong> : <code>str</code>, optional</dt>
+<dd>The color of the center of each cluster. Default is "vitae_new_clustering".</dd>
+<dt><strong><code>plot_legend</code></strong> : <code>bool</code>, optional</dt>
+<dd>Whether to plot the legend. Default is True.</dd>
+<dt><strong><code>legend_add_index</code></strong> : <code>bool</code>, optional</dt>
+<dd>Whether to add the index of each cluster in the legend. Default is True.</dd>
+<dt><strong><code>method</code></strong> : <code>str</code>, optional</dt>
+<dd>The dimension reduction method used for visualization. Default is 'UMAP'.</dd>
+<dt><strong><code>ncol</code></strong> : <code>int</code>, optional</dt>
+<dd>The number of columns in the legend. Default is 2.</dd>
+<dt><strong><code>font_size</code></strong> : <code>str</code>, optional</dt>
+<dd>The font size of the legend. Default is "medium".</dd>
+<dt><strong><code>add_egde</code></strong> : <code>bool</code>, optional</dt>
+<dd>Whether to add the edges between the centers of clusters. Default is False.</dd>
+<dt><strong><code>add_direct</code></strong> : <code>bool</code>, optional</dt>
+<dd>Whether to add the direction of the edges. Default is False.</dd>
+</dl></div>
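+<p>A sketch of picking the root state from collection-time metadata and visualizing the cluster centers; <code>day_info</code> is a hypothetical numeric array aligned with the cells, and note that <code>add_egde</code> is the parameter's actual spelling:</p>
+<pre><code class="language-python">root = model.select_root(day_info, method='proportion')
+
+# Plot the cluster centers with the backbone edges overlaid.
+model.plot_center(add_egde=True, method='UMAP')
+</code></pre>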
+</dd>
+<dt id="VITAE.VITAE.infer_trajectory"><code class="name flex">
+<span>def <span class="ident">infer_trajectory</span></span>(<span>self, root: Union[int, str], digraph=None, color='pseudotime', visualize: bool = True, path_to_fig=None, **kwargs)</span>
+</code></dt>
+<dd>
+<div class="desc"><p>Infer the trajectory.</p>
+<h2 id="parameters">Parameters</h2>
+<dl>
+<dt><strong><code>root</code></strong> : <code>int</code> or <code>string</code></dt>
+<dd>The root of the inferred trajectory. Either an int (vertex index) or a string (label name) can be provided.</dd>
+<dt><strong><code>digraph</code></strong> : <code>nx.DiGraph</code>, optional</dt>
+<dd>The directed graph to be used for trajectory inference. If None, the minimum spanning tree of the estimated trajectory backbone will be used.</dd>
+<dt><strong><code>cutoff</code></strong> : <code>float</code>, optional</dt>
+<dd>The threshold for filtering edges with scores less than cutoff.</dd>
+<dt><strong><code>visualize</code></strong> : <code>boolean</code></dt>
+<dd>Whether to plot the current trajectory backbone (directed graph).</dd>
+<dt><strong><code>path_to_fig</code></strong> : <code>string</code>, optional</dt>
+<dd>The path to save the figure; the figure is not saved if None.</dd>
+<dt><strong><code>**kwargs</code></strong> : <code>dict</code>, optional</dt>
+<dd>Other keyword arguments for plotting.</dd>
+</dl></div>
+</dd>
+<dt id="VITAE.VITAE.differential_expression_test"><code class="name flex">
+<span>def <span class="ident">differential_expression_test</span></span>(<span>self, alpha: float = 0.05, cell_subset=None, order: int = 1)</span>
+</code></dt>
+<dd>
+<div class="desc"><p>Differential gene expression test. All (selected and unselected) genes will be tested.
+Only cells in <code>cell_subset</code> will be used, which is useful when one needs to
+test differentially expressed genes on a branch of the inferred trajectory.</p>
+<h2 id="parameters">Parameters</h2>
+<dl>
+<dt><strong><code>alpha</code></strong> : <code>float</code>, optional</dt>
+<dd>The cutoff of p-values.</dd>
+<dt><strong><code>cell_subset</code></strong> : <code>np.array</code>, optional</dt>
+<dd>The subset of cells to be used for testing. If None, all cells will be used.</dd>
+<dt><strong><code>order</code></strong> : <code>int</code>, optional</dt>
+<dd>The maximum order of pseudotime used in the regression.</dd>
+</dl>
+<h2 id="returns">Returns</h2>
+<dl>
+<dt><strong><code>res_df</code></strong> : <code>pandas.DataFrame</code></dt>
+<dd>The test results for the expressed genes, with two columns:
+the estimated coefficients and the adjusted p-values.</dd>
+</dl></div>
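+<p>A sketch of inferring the trajectory and then testing for genes that change along it (the root value is illustrative):</p>
+<pre><code class="language-python">model.infer_trajectory(root, color='pseudotime')   # root from select_root(), or a label name
+
+# Test all genes against pseudotime (linear term only).
+res_df = model.differential_expression_test(alpha=0.05, order=1)
+print(res_df.head())
+</code></pre>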
+</dd>
+<dt id="VITAE.VITAE.evaluate"><code class="name flex">
+<span>def <span class="ident">evaluate</span></span>(<span>self, milestone_net, begin_node_true, grouping=None, thres: float = 0.5, no_loop: bool = True, cutoff: Optional[float] = None, method: str = 'mean', path: Optional[str] = None)</span>
+</code></dt>
+<dd>
+<div class="desc"><p>Evaluate the model.</p>
+<h2 id="parameters">Parameters</h2>
+<dl>
+<dt><strong><code>milestone_net</code></strong> : <code>pd.DataFrame</code></dt>
+<dd>
+<p>The true milestone network. For real data, milestone_net is a DataFrame of the graph of nodes, e.g.:</p>
+<table>
+<thead>
+<tr>
+<th>from</th>
+<th>to</th>
+</tr>
+</thead>
+<tbody>
+<tr>
+<td>cluster 1</td>
+<td>cluster 1</td>
+</tr>
+<tr>
+<td>cluster 1</td>
+<td>cluster 2</td>
+</tr>
+</tbody>
+</table>
+<p>For synthetic data, milestone_net is a DataFrame of the (projected)
+positions of cells. The indexes are the orders of cells in the dataset, e.g.:</p>
+<table>
+<thead>
+<tr>
+<th>from</th>
+<th>to</th>
+<th>w</th>
+</tr>
+</thead>
+<tbody>
+<tr>
+<td>cluster 1</td>
+<td>cluster 1</td>
+<td>1</td>
+</tr>
+<tr>
+<td>cluster 1</td>
+<td>cluster 2</td>
+<td>0.1</td>
+</tr>
+</tbody>
+</table>
+</dd>
+<dt><strong><code>begin_node_true</code></strong> : <code>str</code> or <code>int</code></dt>
+<dd>The true begin node of the milestone network.</dd>
+<dt><strong><code>grouping</code></strong> : <code>np.array</code>, optional</dt>
+<dd><span><span class="MathJax_Preview">[N,]</span><script type="math/tex">[N,]</script></span> The labels. For real data, grouping must be provided.</dd>
+</dl>
+<h2 id="returns">Returns</h2>
+<dl>
+<dt><strong><code>res</code></strong> : <code>pd.DataFrame</code></dt>
+<dd>The evaluation result.</dd>
+</dl></div>
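+<p>A sketch of evaluating against a known ground truth (both inputs, <code>milestone_net</code> and <code>true_labels</code>, are hypothetical):</p>
+<pre><code class="language-python"># Synthetic data: milestone_net has columns 'from', 'to', 'w',
+# with one row per cell in dataset order.
+res = model.evaluate(milestone_net, begin_node_true=0, no_loop=True)
+
+# Real data: supply the true per-cell labels via `grouping`.
+# res = model.evaluate(milestone_net, begin_node_true='cluster 1', grouping=true_labels)
+</code></pre>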
+</dd>
+<dt id="VITAE.VITAE.save_model"><code class="name flex">
+<span>def <span class="ident">save_model</span></span>(<span>self, path_to_file: str = 'model.checkpoint', save_adata: bool = False)</span>
+</code></dt>
+<dd>
+<div class="desc"><p>Save model weights.</p>
+<h2 id="parameters">Parameters</h2>
+<dl>
+<dt><strong><code>path_to_file</code></strong> : <code>str</code>, optional</dt>
+<dd>The path to weight files of the pre-trained or trained model.</dd>
+<dt><strong><code>save_adata</code></strong> : <code>boolean</code>, optional</dt>
+<dd>Whether to save adata or not.</dd>
+</dl></div>
+</dd>
+<dt id="VITAE.VITAE.load_model"><code class="name flex">
+<span>def <span class="ident">load_model</span></span>(<span>self, path_to_file: str = 'model.checkpoint', load_labels: bool = False, load_adata: bool = False)</span>
+</code></dt>
+<dd>
+<div class="desc"><p>Load model weights.</p>
+<h2 id="parameters">Parameters</h2>
+<dl>
+<dt><strong><code>path_to_file</code></strong> : <code>str</code>, optional</dt>
+<dd>The path to weight files of the pre-trained or trained model.</dd>
+<dt><strong><code>load_labels</code></strong> : <code>boolean</code>, optional</dt>
+<dd>Whether to load clustering labels or not.
+If load_labels is True, then the LatentSpace layer will be initialized based on the model.
+If load_labels is False, then the LatentSpace layer will not be initialized.</dd>
+<dt><strong><code>load_adata</code></strong> : <code>boolean</code>, optional</dt>
+<dd>Whether to load adata or not.</dd>
+</dl></div>
+</dd>
+</dl>
+</dd>
+</dl>
+</section>
+</article>
+<nav id="sidebar">
+<div class="toc">
+<ul></ul>
+</div>
+<ul id="index">
+<li><h3><a href="#header-submodules">Sub-modules</a></h3>
+<ul>
+<li><code><a title="VITAE.inference" href="inference.html">VITAE.inference</a></code></li>
+<li><code><a title="VITAE.metric" href="metric.html">VITAE.metric</a></code></li>
+<li><code><a title="VITAE.model" href="model.html">VITAE.model</a></code></li>
+<li><code><a title="VITAE.train" href="train.html">VITAE.train</a></code></li>
+<li><code><a title="VITAE.utils" href="utils.html">VITAE.utils</a></code></li>
+</ul>
+</li>
+<li><h3><a href="#header-classes">Classes</a></h3>
+<ul>
+<li>
+<h4><code><a title="VITAE.VITAE" href="#VITAE.VITAE">VITAE</a></code></h4>
+<ul class="">
+<li><code><a title="VITAE.VITAE.pre_train" href="#VITAE.VITAE.pre_train">pre_train</a></code></li>
+<li><code><a title="VITAE.VITAE.update_z" href="#VITAE.VITAE.update_z">update_z</a></code></li>
+<li><code><a title="VITAE.VITAE.get_latent_z" href="#VITAE.VITAE.get_latent_z">get_latent_z</a></code></li>
+<li><code><a title="VITAE.VITAE.visualize_latent" href="#VITAE.VITAE.visualize_latent">visualize_latent</a></code></li>
+<li><code><a title="VITAE.VITAE.init_latent_space" href="#VITAE.VITAE.init_latent_space">init_latent_space</a></code></li>
+<li><code><a title="VITAE.VITAE.update_latent_space" href="#VITAE.VITAE.update_latent_space">update_latent_space</a></code></li>
+<li><code><a title="VITAE.VITAE.train" href="#VITAE.VITAE.train">train</a></code></li>
+<li><code><a title="VITAE.VITAE.output_pi" href="#VITAE.VITAE.output_pi">output_pi</a></code></li>
+<li><code><a title="VITAE.VITAE.return_pilayer_weights" href="#VITAE.VITAE.return_pilayer_weights">return_pilayer_weights</a></code></li>
+<li><code><a title="VITAE.VITAE.posterior_estimation" href="#VITAE.VITAE.posterior_estimation">posterior_estimation</a></code></li>
+<li><code><a title="VITAE.VITAE.infer_backbone" href="#VITAE.VITAE.infer_backbone">infer_backbone</a></code></li>
+<li><code><a title="VITAE.VITAE.select_root" href="#VITAE.VITAE.select_root">select_root</a></code></li>
+<li><code><a title="VITAE.VITAE.plot_backbone" href="#VITAE.VITAE.plot_backbone">plot_backbone</a></code></li>
+<li><code><a title="VITAE.VITAE.plot_center" href="#VITAE.VITAE.plot_center">plot_center</a></code></li>
+<li><code><a title="VITAE.VITAE.infer_trajectory" href="#VITAE.VITAE.infer_trajectory">infer_trajectory</a></code></li>
+<li><code><a title="VITAE.VITAE.differential_expression_test" href="#VITAE.VITAE.differential_expression_test">differential_expression_test</a></code></li>
+<li><code><a title="VITAE.VITAE.evaluate" href="#VITAE.VITAE.evaluate">evaluate</a></code></li>
+<li><code><a title="VITAE.VITAE.save_model" href="#VITAE.VITAE.save_model">save_model</a></code></li>
+<li><code><a title="VITAE.VITAE.load_model" href="#VITAE.VITAE.load_model">load_model</a></code></li>
+</ul>
+</li>
+</ul>
+</li>
+</ul>
+</nav>
+</main>
+<footer id="footer">
+<p>Generated by <a href="https://pdoc3.github.io/pdoc" title="pdoc: Python API documentation generator"><cite>pdoc</cite> 0.11.1</a>.</p>
+</footer>
+</body>
+</html>