|
a |
|
b/docs/index.html |
|
|
1 |
<!doctype html> |
|
|
2 |
<html lang="en"> |
|
|
3 |
<head> |
|
|
4 |
<meta charset="utf-8"> |
|
|
5 |
<meta name="viewport" content="width=device-width, initial-scale=1, minimum-scale=1"> |
|
|
6 |
<meta name="generator" content="pdoc3 0.11.1"> |
|
|
7 |
<title>VITAE API documentation</title> |
|
|
8 |
<meta name="description" content=""> |
|
|
9 |
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/10up-sanitize.css/13.0.0/sanitize.min.css" integrity="sha512-y1dtMcuvtTMJc1yPgEqF0ZjQbhnc/bFhyvIyVNb9Zk5mIGtqVaAB1Ttl28su8AvFMOY0EwRbAe+HCLqj6W7/KA==" crossorigin> |
|
|
10 |
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/10up-sanitize.css/13.0.0/typography.min.css" integrity="sha512-Y1DYSb995BAfxobCkKepB1BqJJTPrOp3zPL74AWFugHHmmdcvO+C48WLrUOlhGMc0QG7AE3f7gmvvcrmX2fDoA==" crossorigin> |
|
|
11 |
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.9.0/styles/default.min.css" crossorigin> |
|
|
12 |
<style>:root{--highlight-color:#fe9}.flex{display:flex !important}body{line-height:1.5em}#content{padding:20px}#sidebar{padding:1.5em;overflow:hidden}#sidebar > *:last-child{margin-bottom:2cm}.http-server-breadcrumbs{font-size:130%;margin:0 0 15px 0}#footer{font-size:.75em;padding:5px 30px;border-top:1px solid #ddd;text-align:right}#footer p{margin:0 0 0 1em;display:inline-block}#footer p:last-child{margin-right:30px}h1,h2,h3,h4,h5{font-weight:300}h1{font-size:2.5em;line-height:1.1em}h2{font-size:1.75em;margin:2em 0 .50em 0}h3{font-size:1.4em;margin:1.6em 0 .7em 0}h4{margin:0;font-size:105%}h1:target,h2:target,h3:target,h4:target,h5:target,h6:target{background:var(--highlight-color);padding:.2em 0}a{color:#058;text-decoration:none;transition:color .2s ease-in-out}a:visited{color:#503}a:hover{color:#b62}.title code{font-weight:bold}h2[id^="header-"]{margin-top:2em}.ident{color:#900;font-weight:bold}pre code{font-size:.8em;line-height:1.4em;padding:1em;display:block}code{background:#f3f3f3;font-family:"DejaVu Sans Mono",monospace;padding:1px 4px;overflow-wrap:break-word}h1 code{background:transparent}pre{border-top:1px solid #ccc;border-bottom:1px solid #ccc;margin:1em 0}#http-server-module-list{display:flex;flex-flow:column}#http-server-module-list div{display:flex}#http-server-module-list dt{min-width:10%}#http-server-module-list p{margin-top:0}.toc ul,#index{list-style-type:none;margin:0;padding:0}#index code{background:transparent}#index h3{border-bottom:1px solid #ddd}#index ul{padding:0}#index h4{margin-top:.6em;font-weight:bold}@media (min-width:200ex){#index .two-column{column-count:2}}@media (min-width:300ex){#index .two-column{column-count:3}}dl{margin-bottom:2em}dl dl:last-child{margin-bottom:4em}dd{margin:0 0 1em 3em}#header-classes + dl > dd{margin-bottom:3em}dd dd{margin-left:2em}dd p{margin:10px 0}.name{background:#eee;font-size:.85em;padding:5px 10px;display:inline-block;min-width:40%}.name:hover{background:#e0e0e0}dt:target 
.name{background:var(--highlight-color)}.name > span:first-child{white-space:nowrap}.name.class > span:nth-child(2){margin-left:.4em}.inherited{color:#999;border-left:5px solid #eee;padding-left:1em}.inheritance em{font-style:normal;font-weight:bold}.desc h2{font-weight:400;font-size:1.25em}.desc h3{font-size:1em}.desc dt code{background:inherit}.source summary,.git-link-div{color:#666;text-align:right;font-weight:400;font-size:.8em;text-transform:uppercase}.source summary > *{white-space:nowrap;cursor:pointer}.git-link{color:inherit;margin-left:1em}.source pre{max-height:500px;overflow:auto;margin:0}.source pre code{font-size:12px;overflow:visible}.hlist{list-style:none}.hlist li{display:inline}.hlist li:after{content:',\2002'}.hlist li:last-child:after{content:none}.hlist .hlist{display:inline;padding-left:1em}img{max-width:100%}td{padding:0 .5em}.admonition{padding:.1em 1em;margin-bottom:1em}.admonition-title{font-weight:bold}.admonition.note,.admonition.info,.admonition.important{background:#aef}.admonition.todo,.admonition.versionadded,.admonition.tip,.admonition.hint{background:#dfd}.admonition.warning,.admonition.versionchanged,.admonition.deprecated{background:#fd4}.admonition.error,.admonition.danger,.admonition.caution{background:lightpink}</style> |
|
|
13 |
<style media="screen and (min-width: 700px)">@media screen and (min-width:700px){#sidebar{width:30%;height:100vh;overflow:auto;position:sticky;top:0}#content{width:70%;max-width:100ch;padding:3em 4em;border-left:1px solid #ddd}pre code{font-size:1em}.name{font-size:1em}main{display:flex;flex-direction:row-reverse;justify-content:flex-end}.toc ul ul,#index ul ul{padding-left:1em}.toc > ul > li{margin-top:.5em}}</style> |
|
|
14 |
<style media="print">@media print{#sidebar h1{page-break-before:always}.source{display:none}}@media print{*{background:transparent !important;color:#000 !important;box-shadow:none !important;text-shadow:none !important}a[href]:after{content:" (" attr(href) ")";font-size:90%}a[href][title]:after{content:none}abbr[title]:after{content:" (" attr(title) ")"}.ir a:after,a[href^="javascript:"]:after,a[href^="#"]:after{content:""}pre,blockquote{border:1px solid #999;page-break-inside:avoid}thead{display:table-header-group}tr,img{page-break-inside:avoid}img{max-width:100% !important}@page{margin:0.5cm}p,h2,h3{orphans:3;widows:3}h1,h2,h3,h4,h5,h6{page-break-after:avoid}}</style> |
|
|
15 |
<script type="text/x-mathjax-config">MathJax.Hub.Config({ tex2jax: { inlineMath: [ ['$','$'], ["\\(","\\)"] ], processEscapes: true } });</script> |
|
|
16 |
<script async src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.7/latest.js?config=TeX-AMS_CHTML" integrity="sha256-kZafAc6mZvK3W3v1pHOcUix30OHQN6pU/NO2oFkqZVw=" crossorigin></script> |
|
|
17 |
<script defer src="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.9.0/highlight.min.js" integrity="sha512-D9gUyxqja7hBtkWpPWGt9wfbfaMGVt9gnyCvYa+jojwwPHLCzUm5i8rpk7vD7wNee9bA35eYIjobYPaQuKS1MQ==" crossorigin></script> |
|
|
18 |
<script>window.addEventListener('DOMContentLoaded', () => { |
// Restrict highlight.js auto-detection to the language set pdoc emits, then highlight every code block on the page. |
|
|
19 |
hljs.configure({languages: ['bash', 'css', 'diff', 'graphql', 'ini', 'javascript', 'json', 'plaintext', 'python', 'python-repl', 'rust', 'shell', 'sql', 'typescript', 'xml', 'yaml']}); |
|
|
20 |
hljs.highlightAll(); |
|
|
21 |
})</script> |
|
|
22 |
</head> |
|
|
23 |
<body> |
|
|
24 |
<main> |
|
|
25 |
<article id="content"> |
|
|
26 |
<header> |
|
|
27 |
<h1 class="title">Package <code>VITAE</code></h1> |
|
|
28 |
</header> |
|
|
29 |
<section id="section-intro"> |
|
|
30 |
</section> |
|
|
31 |
<section> |
|
|
32 |
<h2 class="section-title" id="header-submodules">Sub-modules</h2> |
|
|
33 |
<dl> |
|
|
34 |
<dt><code class="name"><a title="VITAE.inference" href="inference.html">VITAE.inference</a></code></dt> |
|
|
35 |
<dd> |
|
|
36 |
<div class="desc"></div> |
|
|
37 |
</dd> |
|
|
38 |
<dt><code class="name"><a title="VITAE.metric" href="metric.html">VITAE.metric</a></code></dt> |
|
|
39 |
<dd> |
|
|
40 |
<div class="desc"></div> |
|
|
41 |
</dd> |
|
|
42 |
<dt><code class="name"><a title="VITAE.model" href="model.html">VITAE.model</a></code></dt> |
|
|
43 |
<dd> |
|
|
44 |
<div class="desc"></div> |
|
|
45 |
</dd> |
|
|
46 |
<dt><code class="name"><a title="VITAE.train" href="train.html">VITAE.train</a></code></dt> |
|
|
47 |
<dd> |
|
|
48 |
<div class="desc"></div> |
|
|
49 |
</dd> |
|
|
50 |
<dt><code class="name"><a title="VITAE.utils" href="utils.html">VITAE.utils</a></code></dt> |
|
|
51 |
<dd> |
|
|
52 |
<div class="desc"></div> |
|
|
53 |
</dd> |
|
|
54 |
</dl> |
|
|
55 |
</section> |
|
|
56 |
<section> |
|
|
57 |
</section> |
|
|
58 |
<section> |
|
|
59 |
</section> |
|
|
60 |
<section> |
|
|
61 |
<h2 class="section-title" id="header-classes">Classes</h2> |
|
|
62 |
<dl> |
|
|
63 |
<dt id="VITAE.VITAE"><code class="flex name class"> |
|
|
64 |
<span>class <span class="ident">VITAE</span></span> |
|
|
65 |
<span>(</span><span>adata: anndata._core.anndata.AnnData, covariates=None, pi_covariates=None, model_type: str = 'Gaussian', npc: int = 64, adata_layer_counts=None, copy_adata: bool = False, hidden_layers=[32], latent_space_dim: int = 16, conditions=None)</span> |
|
|
66 |
</code></dt> |
|
|
67 |
<dd> |
|
|
68 |
<div class="desc"><p>Variational Inference for Trajectory by AutoEncoder.</p> |
|
|
69 |
<p>Get input data for model. Data needs to be first processed using scanpy and stored as an AnnData object |
|
|
70 |
The 'UMI' or 'non-UMI' model needs the original count matrix, so the count matrix needs to be saved in |
|
|
71 |
adata.layers in order to use these models.</p> |
|
|
72 |
<h2 id="parameters">Parameters</h2> |
|
|
73 |
<dl> |
|
|
74 |
<dt><strong><code>adata</code></strong> : <code>sc.AnnData</code></dt> |
|
|
75 |
<dd>The scanpy AnnData object. adata should already contain adata.var.highly_variable</dd> |
|
|
76 |
<dt><strong><code>covariates</code></strong> : <code>list</code>, optional</dt> |
|
|
77 |
<dd>A list of names of covariate vectors that are stored in adata.obs</dd> |
|
|
78 |
<dt><strong><code>pi_covariates</code></strong> : <code>list</code>, optional</dt> |
|
|
79 |
<dd>A list of names of covariate vectors used as input for pilayer</dd> |
|
|
80 |
<dt><strong><code>model_type</code></strong> : <code>str</code>, optional</dt> |
|
|
81 |
<dd>'UMI', 'non-UMI' and 'Gaussian', default is 'Gaussian'.</dd> |
|
|
82 |
<dt><strong><code>npc</code></strong> : <code>int</code>, optional</dt> |
|
|
83 |
<dd>The number of PCs to use when model_type is 'Gaussian'. The default is 64.</dd> |
|
|
84 |
<dt><strong><code>adata_layer_counts</code></strong> : <code>str</code>, optional</dt> |
|
|
85 |
<dd>the key name of adata.layers that stores the count data if model_type is |
|
|
86 |
'UMI' or 'non-UMI'</dd> |
|
|
87 |
<dt><strong><code>copy_adata</code></strong> : <code>bool</code>, optional</dt> |
|
|
88 |
<dd>Set to True if we don't want VITAE to modify the original adata. If set to True, self.adata will be an independent copy of the original adata.</dd> |
|
|
89 |
<dt><strong><code>hidden_layers</code></strong> : <code>list</code>, optional</dt> |
|
|
90 |
<dd>The list of dimensions of layers of autoencoder between latent space and original space. Default is to have only one hidden layer with 32 nodes</dd> |
|
|
91 |
<dt><strong><code>latent_space_dim</code></strong> : <code>int</code>, optional</dt> |
|
|
92 |
<dd>The dimension of latent space.</dd> |
|
|
93 |
<dt><strong><code>gamma</code></strong> : <code>float</code>, optional</dt> |
|
|
94 |
<dd>The weight of the MMD loss</dd> |
|
|
95 |
<dt><strong><code>conditions</code></strong> : <code>str</code> or <code>list</code>, optional</dt> |
|
|
96 |
<dd>The conditions of different cells</dd> |
|
|
97 |
</dl> |
|
|
98 |
<h2 id="returns">Returns</h2> |
|
|
99 |
<p>None.</p></div> |
|
|
100 |
<details class="source"> |
|
|
101 |
<summary> |
|
|
102 |
<span>Expand source code</span> |
|
|
103 |
</summary> |
|
|
104 |
<pre><code class="python">class VITAE(): |
|
|
105 |
""" |
|
|
106 |
Variational Inference for Trajectory by AutoEncoder. |
|
|
107 |
""" |
|
|
108 |
def __init__(self, adata: sc.AnnData, |
|
|
109 |
covariates = None, pi_covariates = None, |
|
|
110 |
model_type: str = 'Gaussian', |
|
|
111 |
npc: int = 64, |
|
|
112 |
adata_layer_counts = None, |
|
|
113 |
copy_adata: bool = False, |
|
|
114 |
hidden_layers = [32], |
|
|
115 |
latent_space_dim: int = 16, |
|
|
116 |
conditions = None): |
|
|
117 |
''' |
|
|
118 |
Get input data for model. Data needs to be first processed using scanpy and stored as an AnnData object |
|
|
119 |
The 'UMI' or 'non-UMI' model needs the original count matrix, so the count matrix needs to be saved in |
|
|
120 |
adata.layers in order to use these models. |
|
|
121 |
|
|
|
122 |
|
|
|
123 |
Parameters |
|
|
124 |
---------- |
|
|
125 |
adata : sc.AnnData |
|
|
126 |
The scanpy AnnData object. adata should already contain adata.var.highly_variable |
|
|
127 |
covariates : list, optional |
|
|
128 |
A list of names of covariate vectors that are stored in adata.obs |
|
|
129 |
pi_covariates: list, optional |
|
|
130 |
A list of names of covariate vectors used as input for pilayer |
|
|
131 |
model_type : str, optional |
|
|
132 |
'UMI', 'non-UMI' and 'Gaussian', default is 'Gaussian'. |
|
|
133 |
npc : int, optional |
|
|
134 |
The number of PCs to use when model_type is 'Gaussian'. The default is 64. |
|
|
135 |
adata_layer_counts: str, optional |
|
|
136 |
the key name of adata.layers that stores the count data if model_type is |
|
|
137 |
'UMI' or 'non-UMI' |
|
|
138 |
copy_adata: bool, optional. Set to True if we don't want VITAE to modify the original adata. If set to True, self.adata will be an independent copy of the original adata. |
|
|
139 |
hidden_layers : list, optional |
|
|
140 |
The list of dimensions of layers of autoencoder between latent space and original space. Default is to have only one hidden layer with 32 nodes |
|
|
141 |
latent_space_dim : int, optional |
|
|
142 |
The dimension of latent space. |
|
|
143 |
gamma : float, optional |
|
|
144 |
The weight of the MMD loss |
|
|
145 |
conditions : str or list, optional |
|
|
146 |
The conditions of different cells |
|
|
147 |
|
|
|
148 |
|
|
|
149 |
Returns |
|
|
150 |
------- |
|
|
151 |
None. |
|
|
152 |
|
|
|
153 |
''' |
|
|
154 |
self.dict_method_scname = { |
|
|
155 |
'PCA' : 'X_pca', |
|
|
156 |
'UMAP' : 'X_umap', |
|
|
157 |
'TSNE' : 'X_tsne', |
|
|
158 |
'diffmap' : 'X_diffmap', |
|
|
159 |
'draw_graph' : 'X_draw_graph_fa' |
|
|
160 |
} |
|
|
161 |
|
|
|
162 |
if model_type != 'Gaussian': |
|
|
163 |
if adata_layer_counts is None: |
|
|
164 |
raise ValueError("need to provide the name in adata.layers that stores the raw count data") |
|
|
165 |
if 'highly_variable' not in adata.var: |
|
|
166 |
raise ValueError("need to first select highly variable genes using scanpy") |
|
|
167 |
|
|
|
168 |
self.model_type = model_type |
|
|
169 |
|
|
|
170 |
if copy_adata: |
|
|
171 |
self.adata = adata.copy() |
|
|
172 |
else: |
|
|
173 |
self.adata = adata |
|
|
174 |
|
|
|
175 |
if covariates is not None: |
|
|
176 |
if isinstance(covariates, str): |
|
|
177 |
covariates = [covariates] |
|
|
178 |
covariates = np.array(covariates) |
|
|
179 |
id_cat = (adata.obs[covariates].dtypes == 'category') |
|
|
180 |
# add OneHotEncoder & StandardScaler as class variable if needed |
|
|
181 |
if np.sum(id_cat)>0: |
|
|
182 |
covariates_cat = OneHotEncoder(drop='if_binary', handle_unknown='ignore' |
|
|
183 |
).fit_transform(adata.obs[covariates[id_cat]]).toarray() |
|
|
184 |
else: |
|
|
185 |
covariates_cat = np.array([]).reshape(adata.shape[0],0) |
|
|
186 |
|
|
|
187 |
# temporarily disable StandardScaler |
|
|
188 |
if np.sum(~id_cat)>0: |
|
|
189 |
#covariates_con = StandardScaler().fit_transform(adata.obs[covariates[~id_cat]]) |
|
|
190 |
covariates_con = adata.obs[covariates[~id_cat]] |
|
|
191 |
else: |
|
|
192 |
covariates_con = np.array([]).reshape(adata.shape[0],0) |
|
|
193 |
|
|
|
194 |
self.covariates = np.c_[covariates_cat, covariates_con].astype(tf.keras.backend.floatx()) |
|
|
195 |
else: |
|
|
196 |
self.covariates = None |
|
|
197 |
|
|
|
198 |
if conditions is not None: |
|
|
199 |
## observations with np.nan will not participate in calculating mmd_loss |
|
|
200 |
if isinstance(conditions, str): |
|
|
201 |
conditions = [conditions] |
|
|
202 |
conditions = np.array(conditions) |
|
|
203 |
if np.any(adata.obs[conditions].dtypes != 'category'): |
|
|
204 |
raise ValueError("Conditions should all be categorical.") |
|
|
205 |
|
|
|
206 |
self.conditions = OrdinalEncoder(dtype=int, encoded_missing_value=-1).fit_transform(adata.obs[conditions]) + int(1) |
|
|
207 |
else: |
|
|
208 |
self.conditions = None |
|
|
209 |
|
|
|
210 |
if pi_covariates is not None: |
|
|
211 |
self.pi_cov = adata.obs[pi_covariates].to_numpy() |
|
|
212 |
if self.pi_cov.ndim == 1: |
|
|
213 |
self.pi_cov = self.pi_cov.reshape(-1, 1) |
|
|
214 |
self.pi_cov = self.pi_cov.astype(tf.keras.backend.floatx()) |
|
|
215 |
else: |
|
|
216 |
self.pi_cov = np.zeros((adata.shape[0],1), dtype=tf.keras.backend.floatx()) |
|
|
217 |
|
|
|
218 |
self.model_type = model_type |
|
|
219 |
self._adata = sc.AnnData(X = self.adata.X, var = self.adata.var) |
|
|
220 |
self._adata.obs = self.adata.obs |
|
|
221 |
self._adata.uns = self.adata.uns |
|
|
222 |
|
|
|
223 |
|
|
|
224 |
if model_type == 'Gaussian': |
|
|
225 |
sc.tl.pca(adata, n_comps = npc) |
|
|
226 |
self.X_input = self.X_output = adata.obsm['X_pca'] |
|
|
227 |
self.scale_factor = np.ones(self.X_output.shape[0]) |
|
|
228 |
else: |
|
|
229 |
print(f"{adata.var.highly_variable.sum()} highly variable genes selected as input") |
|
|
230 |
self.X_input = adata.X[:, adata.var.highly_variable] |
|
|
231 |
self.X_output = adata.layers[adata_layer_counts][ :, adata.var.highly_variable] |
|
|
232 |
self.scale_factor = np.sum(self.X_output, axis=1, keepdims=True)/1e4 |
|
|
233 |
|
|
|
234 |
self.dimensions = hidden_layers |
|
|
235 |
self.dim_latent = latent_space_dim |
|
|
236 |
|
|
|
237 |
self.vae = model.VariationalAutoEncoder( |
|
|
238 |
self.X_output.shape[1], self.dimensions, |
|
|
239 |
self.dim_latent, self.model_type, |
|
|
240 |
False if self.covariates is None else True, |
|
|
241 |
) |
|
|
242 |
|
|
|
243 |
if hasattr(self, 'inferer'): |
|
|
244 |
delattr(self, 'inferer') |
|
|
245 |
|
|
|
246 |
|
|
|
247 |
def pre_train(self, test_size = 0.1, random_state: int = 0, |
|
|
248 |
learning_rate: float = 1e-3, batch_size: int = 256, L: int = 1, alpha: float = 0.10, gamma: float = 0, |
|
|
249 |
phi : float = 1,num_epoch: int = 200, num_step_per_epoch: Optional[int] = None, |
|
|
250 |
early_stopping_patience: int = 10, early_stopping_tolerance: float = 0.01, |
|
|
251 |
early_stopping_relative: bool = True, verbose: bool = False,path_to_weights: Optional[str] = None): |
|
|
252 |
'''Pretrain the model with specified learning rate. |
|
|
253 |
|
|
|
254 |
Parameters |
|
|
255 |
---------- |
|
|
256 |
test_size : float or int, optional |
|
|
257 |
The proportion or size of the test set. |
|
|
258 |
random_state : int, optional |
|
|
259 |
The random state for data splitting. |
|
|
260 |
learning_rate : float, optional |
|
|
261 |
The initial learning rate for the Adam optimizer. |
|
|
262 |
batch_size : int, optional |
|
|
263 |
The batch size for pre-training. Default is 256. Set to 32 if number of cells is small (less than 1000) |
|
|
264 |
L : int, optional |
|
|
265 |
The number of MC samples. |
|
|
266 |
alpha : float, optional |
|
|
267 |
The value of alpha in [0,1] to encourage covariate adjustment. Not used if there is no covariates. |
|
|
268 |
gamma : float, optional |
|
|
269 |
The weight of the mmd loss if used. |
|
|
270 |
phi : float, optional |
|
|
271 |
The weight of the Jacobian norm of the encoder. |
|
|
272 |
num_epoch : int, optional |
|
|
273 |
The maximum number of epochs. |
|
|
274 |
num_step_per_epoch : int, optional |
|
|
275 |
The number of step per epoch, it will be inferred from number of cells and batch size if it is None. |
|
|
276 |
early_stopping_patience : int, optional |
|
|
277 |
The maximum number of epochs if there is no improvement. |
|
|
278 |
early_stopping_tolerance : float, optional |
|
|
279 |
The minimum change of loss to be considered as an improvement. |
|
|
280 |
early_stopping_relative : bool, optional |
|
|
281 |
Whether monitor the relative change of loss as stopping criteria or not. |
|
|
282 |
path_to_weights : str, optional |
|
|
283 |
The path of weight file to be saved; not saving weight if None. |
|
|
284 |
conditions : str or list, optional |
|
|
285 |
The conditions of different cells |
|
|
286 |
''' |
|
|
287 |
|
|
|
288 |
id_train, id_test = train_test_split( |
|
|
289 |
np.arange(self.X_input.shape[0]), |
|
|
290 |
test_size=test_size, |
|
|
291 |
random_state=random_state) |
|
|
292 |
if num_step_per_epoch is None: |
|
|
293 |
num_step_per_epoch = len(id_train)//batch_size+1 |
|
|
294 |
self.train_dataset = train.warp_dataset(self.X_input[id_train].astype(tf.keras.backend.floatx()), |
|
|
295 |
None if self.covariates is None else self.covariates[id_train].astype(tf.keras.backend.floatx()), |
|
|
296 |
batch_size, |
|
|
297 |
self.X_output[id_train].astype(tf.keras.backend.floatx()), |
|
|
298 |
self.scale_factor[id_train].astype(tf.keras.backend.floatx()), |
|
|
299 |
conditions = None if self.conditions is None else self.conditions[id_train].astype(tf.keras.backend.floatx())) |
|
|
300 |
self.test_dataset = train.warp_dataset(self.X_input[id_test], |
|
|
301 |
None if self.covariates is None else self.covariates[id_test].astype(tf.keras.backend.floatx()), |
|
|
302 |
batch_size, |
|
|
303 |
self.X_output[id_test].astype(tf.keras.backend.floatx()), |
|
|
304 |
self.scale_factor[id_test].astype(tf.keras.backend.floatx()), |
|
|
305 |
conditions = None if self.conditions is None else self.conditions[id_test].astype(tf.keras.backend.floatx())) |
|
|
306 |
|
|
|
307 |
self.vae = train.pre_train( |
|
|
308 |
self.train_dataset, |
|
|
309 |
self.test_dataset, |
|
|
310 |
self.vae, |
|
|
311 |
learning_rate, |
|
|
312 |
L, alpha, gamma, phi, |
|
|
313 |
num_epoch, |
|
|
314 |
num_step_per_epoch, |
|
|
315 |
early_stopping_patience, |
|
|
316 |
early_stopping_tolerance, |
|
|
317 |
early_stopping_relative, |
|
|
318 |
verbose) |
|
|
319 |
|
|
|
320 |
self.update_z() |
|
|
321 |
|
|
|
322 |
if path_to_weights is not None: |
|
|
323 |
self.save_model(path_to_weights) |
|
|
324 |
|
|
|
325 |
|
|
|
326 |
def update_z(self): |
|
|
327 |
self.z = self.get_latent_z() |
|
|
328 |
self._adata_z = sc.AnnData(self.z) |
|
|
329 |
sc.pp.neighbors(self._adata_z) |
|
|
330 |
|
|
|
331 |
|
|
|
332 |
def get_latent_z(self): |
|
|
333 |
''' get the posterior mean of current latent space z (encoder output) |
|
|
334 |
|
|
|
335 |
Returns |
|
|
336 |
---------- |
|
|
337 |
z : np.array |
|
|
338 |
\([N,d]\) The latent means. |
|
|
339 |
''' |
|
|
340 |
c = None if self.covariates is None else self.covariates |
|
|
341 |
return self.vae.get_z(self.X_input, c) |
|
|
342 |
|
|
|
343 |
|
|
|
344 |
def visualize_latent(self, method: str = "UMAP", |
|
|
345 |
color = None, **kwargs): |
|
|
346 |
''' |
|
|
347 |
visualize the current latent space z using the scanpy visualization tools |
|
|
348 |
|
|
|
349 |
Parameters |
|
|
350 |
---------- |
|
|
351 |
method : str, optional |
|
|
352 |
Visualization method to use. The default is "draw_graph" (the FA plot). Possible choices include "PCA", "UMAP", |
|
|
353 |
"diffmap", "TSNE" and "draw_graph" |
|
|
354 |
color : TYPE, optional |
|
|
355 |
Keys for annotations of observations/cells or variables/genes, e.g., 'ann1' or ['ann1', 'ann2']. |
|
|
356 |
The default is None. Same as scanpy. |
|
|
357 |
**kwargs : |
|
|
358 |
Extra key-value arguments that can be passed to scanpy plotting functions (scanpy.pl.XX). |
|
|
359 |
|
|
|
360 |
Returns |
|
|
361 |
------- |
|
|
362 |
None. |
|
|
363 |
|
|
|
364 |
''' |
|
|
365 |
|
|
|
366 |
if method not in ['PCA', 'UMAP', 'TSNE', 'diffmap', 'draw_graph']: |
|
|
367 |
raise ValueError("visualization method should be one of 'PCA', 'UMAP', 'TSNE', 'diffmap' and 'draw_graph'") |
|
|
368 |
|
|
|
369 |
temp = list(self._adata_z.obsm.keys()) |
|
|
370 |
if method == 'PCA' and not 'X_pca' in temp: |
|
|
371 |
print("Calculate PCs ...") |
|
|
372 |
sc.tl.pca(self._adata_z) |
|
|
373 |
elif method == 'UMAP' and not 'X_umap' in temp: |
|
|
374 |
print("Calculate UMAP ...") |
|
|
375 |
sc.tl.umap(self._adata_z) |
|
|
376 |
elif method == 'TSNE' and not 'X_tsne' in temp: |
|
|
377 |
print("Calculate TSNE ...") |
|
|
378 |
sc.tl.tsne(self._adata_z) |
|
|
379 |
elif method == 'diffmap' and not 'X_diffmap' in temp: |
|
|
380 |
print("Calculate diffusion map ...") |
|
|
381 |
sc.tl.diffmap(self._adata_z) |
|
|
382 |
elif method == 'draw_graph' and not 'X_draw_graph_fa' in temp: |
|
|
383 |
print("Calculate FA ...") |
|
|
384 |
sc.tl.draw_graph(self._adata_z) |
|
|
385 |
|
|
|
386 |
|
|
|
387 |
self._adata.obs = self.adata.obs.copy() |
|
|
388 |
self._adata.obsp = self._adata_z.obsp |
|
|
389 |
# self._adata.uns = self._adata_z.uns |
|
|
390 |
self._adata.obsm = self._adata_z.obsm |
|
|
391 |
|
|
|
392 |
if method == 'PCA': |
|
|
393 |
axes = sc.pl.pca(self._adata, color = color, **kwargs) |
|
|
394 |
elif method == 'UMAP': |
|
|
395 |
axes = sc.pl.umap(self._adata, color = color, **kwargs) |
|
|
396 |
elif method == 'TSNE': |
|
|
397 |
axes = sc.pl.tsne(self._adata, color = color, **kwargs) |
|
|
398 |
elif method == 'diffmap': |
|
|
399 |
axes = sc.pl.diffmap(self._adata, color = color, **kwargs) |
|
|
400 |
elif method == 'draw_graph': |
|
|
401 |
axes = sc.pl.draw_graph(self._adata, color = color, **kwargs) |
|
|
402 |
return axes |
|
|
403 |
|
|
|
404 |
|
|
|
405 |
def init_latent_space(self, cluster_label = None, log_pi = None, res: float = 1.0, |
|
|
406 |
ratio_prune= None, dist = None, dist_thres = 0.5, topk=0, pilayer = False): |
|
|
407 |
'''Initialize the latent space. |
|
|
408 |
|
|
|
409 |
Parameters |
|
|
410 |
---------- |
|
|
411 |
cluster_label : str, optional |
|
|
412 |
The name of vector of labels that can be found in self.adata.obs. |
|
|
413 |
Default is None, which will perform leiden clustering on the pretrained z to get clusters |
|
|
414 |
mu : np.array, optional |
|
|
415 |
\([d,k]\) The value of initial \(\\mu\). |
|
|
416 |
log_pi : np.array, optional |
|
|
417 |
\([1,K]\) The value of initial \(\\log(\\pi)\). |
|
|
418 |
res: |
|
|
419 |
The resolution of leiden clustering, which is a parameter value controlling the coarseness of the clustering. |
|
|
420 |
Higher values lead to more clusters. Default is 1. |
|
|
421 |
ratio_prune : float, optional |
|
|
422 |
The ratio of edges to be removed before estimating. |
|
|
423 |
topk : int, optional |
|
|
424 |
The number of top k neighbors to keep for each cluster. |
|
|
425 |
''' |
|
|
426 |
|
|
|
427 |
|
|
|
428 |
if cluster_label is None: |
|
|
429 |
print("Perform leiden clustering on the latent space z ...") |
|
|
430 |
g = get_igraph(self.z) |
|
|
431 |
cluster_labels = leidenalg_igraph(g, res = res) |
|
|
432 |
cluster_labels = cluster_labels.astype(str) |
|
|
433 |
uni_cluster_labels = np.unique(cluster_labels) |
|
|
434 |
else: |
|
|
435 |
if isinstance(cluster_label,str): |
|
|
436 |
cluster_labels = self.adata.obs[cluster_label].to_numpy() |
|
|
437 |
uni_cluster_labels = np.array(self.adata.obs[cluster_label].cat.categories) |
|
|
438 |
else: |
|
|
439 |
## if cluster_label is a list |
|
|
440 |
cluster_labels = cluster_label |
|
|
441 |
uni_cluster_labels = np.unique(cluster_labels) |
|
|
442 |
|
|
|
443 |
n_clusters = len(uni_cluster_labels) |
|
|
444 |
|
|
|
445 |
if not hasattr(self, 'z'): |
|
|
446 |
self.update_z() |
|
|
447 |
z = self.z |
|
|
448 |
mu = np.zeros((z.shape[1], n_clusters)) |
|
|
449 |
for i,l in enumerate(uni_cluster_labels): |
|
|
450 |
mu[:,i] = np.mean(z[cluster_labels==l], axis=0) |
|
|
451 |
|
|
|
452 |
if dist is None: |
|
|
453 |
### update cluster centers if some cluster centers are too close |
|
|
454 |
clustering = AgglomerativeClustering( |
|
|
455 |
n_clusters=None, |
|
|
456 |
distance_threshold=dist_thres, |
|
|
457 |
linkage='complete' |
|
|
458 |
).fit(mu.T/np.sqrt(mu.shape[0])) |
|
|
459 |
n_clusters_new = clustering.n_clusters_ |
|
|
460 |
if n_clusters_new < n_clusters: |
|
|
461 |
print("Merge clusters for cluster centers that are too close ...") |
|
|
462 |
n_clusters = n_clusters_new |
|
|
463 |
for i in range(n_clusters): |
|
|
464 |
temp = uni_cluster_labels[clustering.labels_ == i] |
|
|
465 |
idx = np.isin(cluster_labels, temp) |
|
|
466 |
cluster_labels[idx] = ','.join(temp) |
|
|
467 |
if np.sum(clustering.labels_==i)>1: |
|
|
468 |
print('Merge %s'% ','.join(temp)) |
|
|
469 |
uni_cluster_labels = np.unique(cluster_labels) |
|
|
470 |
mu = np.zeros((z.shape[1], n_clusters)) |
|
|
471 |
for i,l in enumerate(uni_cluster_labels): |
|
|
472 |
mu[:,i] = np.mean(z[cluster_labels==l], axis=0) |
|
|
473 |
|
|
|
474 |
self.adata.obs['vitae_init_clustering'] = cluster_labels |
|
|
475 |
self.adata.obs['vitae_init_clustering'] = self.adata.obs['vitae_init_clustering'].astype('category') |
|
|
476 |
print("Initial clustering labels saved as 'vitae_init_clustering' in self.adata.obs.") |
|
|
477 |
|
|
|
478 |
if (log_pi is None) and (cluster_labels is not None) and (n_clusters>3): |
|
|
479 |
n_states = int((n_clusters+1)*n_clusters/2) |
|
|
480 |
|
|
|
481 |
if dist is None: |
|
|
482 |
dist = _comp_dist(z, cluster_labels, mu.T) |
|
|
483 |
|
|
|
484 |
C = np.triu(np.ones(n_clusters)) |
|
|
485 |
C[C>0] = np.arange(n_states) |
|
|
486 |
C = C + C.T - np.diag(np.diag(C)) |
|
|
487 |
C = C.astype(int) |
|
|
488 |
|
|
|
489 |
log_pi = np.zeros((1,n_states)) |
|
|
490 |
|
|
|
491 |
## pruning to throw away edges for far-away clusters if there are too many clusters |
|
|
492 |
if ratio_prune is not None: |
|
|
493 |
log_pi[0, C[np.triu(dist)>np.quantile(dist[np.triu_indices(n_clusters, 1)], 1-ratio_prune)]] = - np.inf |
|
|
494 |
else: |
|
|
495 |
log_pi[0, C[np.triu(dist)>np.quantile(dist[np.triu_indices(n_clusters, 1)], 5/n_clusters) * 3]] = - np.inf |
|
|
496 |
|
|
|
497 |
## also keep the top k neighbor of clusters |
|
|
498 |
topk = max(0, min(topk, n_clusters-1)) + 1 |
|
|
499 |
topk_indices = np.argsort(dist,axis=1)[:,:topk] |
|
|
500 |
for i in range(n_clusters): |
|
|
501 |
log_pi[0, C[i, topk_indices[i]]] = 0 |
|
|
502 |
|
|
|
503 |
self.n_states = n_clusters |
|
|
504 |
self.labels = cluster_labels |
|
|
505 |
|
|
|
506 |
labels_map = pd.DataFrame.from_dict( |
|
|
507 |
{i:label for i,label in enumerate(uni_cluster_labels)}, |
|
|
508 |
orient='index', columns=['label_names'], dtype=str |
|
|
509 |
) |
|
|
510 |
|
|
|
511 |
self.labels_map = labels_map |
|
|
512 |
self.vae.init_latent_space(self.n_states, mu, log_pi) |
|
|
513 |
self.inferer = Inferer(self.n_states) |
|
|
514 |
self.mu = self.vae.latent_space.mu.numpy() |
|
|
515 |
self.pi = np.triu(np.ones(self.n_states)) |
|
|
516 |
self.pi[self.pi > 0] = tf.nn.softmax(self.vae.latent_space.pi).numpy()[0] |
|
|
517 |
|
|
|
518 |
if pilayer: |
|
|
519 |
self.vae.create_pilayer() |
|
|
520 |
|
|
|
521 |
|
|
|
522 |
def update_latent_space(self, dist_thres: float=0.5):
    '''Merge latent-space states whose cluster centers are too close.

    Complete-linkage agglomerative clustering is run on the (scaled)
    cluster centers; if it produces fewer clusters than the current
    number of states, the affected states are merged: their means are
    averaged, their edge weights (pi) summed, and the label map, VAE
    latent space and inferer are re-initialized accordingly.

    Parameters
    ----------
    dist_thres : float, optional
        Distance threshold for `AgglomerativeClustering`; centers
        farther apart than this are never merged.
    '''
    pi = self.pi[np.triu_indices(self.n_states)]
    mu = self.mu
    # Scale by sqrt(latent dim) so dist_thres is comparable across models
    # with different latent dimensions.
    clustering = AgglomerativeClustering(
        n_clusters=None,
        distance_threshold=dist_thres,
        linkage='complete'
    ).fit(mu.T/np.sqrt(mu.shape[0]))
    n_clusters = clustering.n_clusters_

    if n_clusters < self.n_states:
        print("Merge clusters for cluster centers that are too close ...")
        mu_new = np.empty((self.dim_latent, n_clusters))
        # Unpack the upper-triangular pi vector back into a symmetric matrix.
        C = np.zeros((self.n_states, self.n_states))
        C[np.triu_indices(self.n_states, 0)] = pi
        C = np.triu(C, 1) + C.T
        C_new = np.zeros((n_clusters, n_clusters))

        uni_cluster_labels = self.labels_map['label_names'].to_numpy()
        returned_order = {}
        cluster_labels = self.labels
        for i in range(n_clusters):
            temp = uni_cluster_labels[clustering.labels_ == i]
            idx = np.isin(cluster_labels, temp)
            # Merged states get a comma-joined compound label.
            cluster_labels[idx] = ','.join(temp)
            returned_order[i] = ','.join(temp)
            if np.sum(clustering.labels_ == i) > 1:
                print('Merge %s' % ','.join(temp))
        uni_cluster_labels = np.unique(cluster_labels)

        # Reorder the merged clusters based on the (sorted) cluster names.
        # BUG FIX: the original code did `np.where(returned_order == l)`,
        # comparing a dict against a string — that is always False, so the
        # lookup produced empty masks (NaN means / zero pi) whenever a merge
        # occurred. Invert the dict to map label name -> clustering index.
        label_to_cluster = {name: idx for idx, name in returned_order.items()}
        for i, l in enumerate(uni_cluster_labels):
            k = label_to_cluster[l]
            # Mean of the aggregated cluster means.
            mu_new[:, i] = np.mean(mu[:, clustering.labels_ == k], axis=-1)
            # Sum of the aggregated pi's within the merged cluster.
            C_new[i, i] = np.sum(np.triu(C[clustering.labels_ == k, :][:, clustering.labels_ == k]))
            for j in range(i+1, n_clusters):
                k1 = label_to_cluster[uni_cluster_labels[j]]
                C_new[i, j] = np.sum(C[clustering.labels_ == k, :][:, clustering.labels_ == k1])
        C_new = np.triu(C_new, 1) + C_new.T

        pi_new = C_new[np.triu_indices(n_clusters)]
        # Map pi == 0 to -inf explicitly instead of letting np.log warn.
        log_pi_new = np.log(pi_new, out=np.ones_like(pi_new)*(-np.inf), where=(pi_new != 0)).reshape((1, -1))
        self.n_states = n_clusters
        self.labels_map = pd.DataFrame.from_dict(
            {i: label for i, label in enumerate(uni_cluster_labels)},
            orient='index', columns=['label_names'], dtype=str
        )
        self.labels = cluster_labels
        self.vae.init_latent_space(self.n_states, mu_new, log_pi_new)
        self.inferer = Inferer(self.n_states)
        self.mu = self.vae.latent_space.mu.numpy()
        self.pi = np.triu(np.ones(self.n_states))
        self.pi[self.pi > 0] = tf.nn.softmax(self.vae.latent_space.pi).numpy()[0]
|
|
590 |
|
|
|
591 |
|
|
|
592 |
|
|
|
593 |
def train(self, stratify = False, test_size = 0.1, random_state: int = 0,
        learning_rate: float = 1e-3, batch_size: int = 256,
        L: int = 1, alpha: float = 0.10, beta: float = 1, gamma: float = 0, phi: float = 1,
        num_epoch: int = 200, num_step_per_epoch: Optional[int] = None,
        early_stopping_patience: int = 10, early_stopping_tolerance: float = 0.01,
        early_stopping_relative: bool = True, early_stopping_warmup: int = 0,
        path_to_weights: Optional[str] = None,
        verbose: bool = False, **kwargs):
    '''Train the model.

    Parameters
    ----------
    stratify : np.array, None, or False
        If an array is provided, or `stratify=None` and `self.labels` is available, then they will be used to perform stratified shuffle splitting. Otherwise, general shuffle splitting is used. Set to `False` if `self.labels` is not intended for stratified shuffle splitting.
    test_size : float or int, optional
        The proportion or size of the test set.
    random_state : int, optional
        The random state for data splitting.
    learning_rate : float, optional
        The initial learning rate for the Adam optimizer.
    batch_size : int, optional
        The batch size for training. Default is 256. Set to 32 if number of cells is small (less than 1000)
    L : int, optional
        The number of MC samples.
    alpha : float, optional
        The value of alpha in [0,1] to encourage covariate adjustment. Not used if there is no covariates.
    beta : float, optional
        The value of beta in beta-VAE.
    gamma : float, optional
        The weight of mmd_loss.
    phi : float, optional
        The weight of Jacob norm of encoder.
    num_epoch : int, optional
        The number of epoch.
    num_step_per_epoch : int, optional
        The number of step per epoch, it will be inferred from number of cells and batch size if it is None.
    early_stopping_patience : int, optional
        The maximum number of epochs if there is no improvement.
    early_stopping_tolerance : float, optional
        The minimum change of loss to be considered as an improvement.
    early_stopping_relative : bool, optional
        Whether monitor the relative change of loss or not.
    early_stopping_warmup : int, optional
        The number of warmup epochs.
    path_to_weights : str, optional
        The path of weight file to be saved; not saving weight if None.
    **kwargs :
        Extra key-value arguments for dimension reduction algorithms.
    '''
    # With MMD disabled (gamma == 0) or no condition info, pass NaN
    # conditions so the MMD term has nothing to act on.
    if gamma == 0 or self.conditions is None:
        conditions = np.array([np.nan] * self.adata.shape[0])
    else:
        conditions = self.conditions

    # stratify=None -> stratify on stored labels; stratify=False -> plain split.
    if stratify is None:
        stratify = self.labels
    elif stratify is False:
        stratify = None
    # Split cell indices (not the data itself) into train/test sets.
    id_train, id_test = train_test_split(
        np.arange(self.X_input.shape[0]),
        test_size=test_size,
        stratify=stratify,
        random_state=random_state)
    if num_step_per_epoch is None:
        num_step_per_epoch = len(id_train)//batch_size+1
    # Cast covariates once to the Keras float dtype; subset per split below.
    c = None if self.covariates is None else self.covariates.astype(tf.keras.backend.floatx())
    self.train_dataset = train.warp_dataset(self.X_input[id_train].astype(tf.keras.backend.floatx()),
                                            None if c is None else c[id_train],
                                            batch_size,
                                            self.X_output[id_train].astype(tf.keras.backend.floatx()),
                                            self.scale_factor[id_train].astype(tf.keras.backend.floatx()),
                                            conditions = conditions[id_train],
                                            pi_cov = self.pi_cov[id_train])
    self.test_dataset = train.warp_dataset(self.X_input[id_test].astype(tf.keras.backend.floatx()),
                                           None if c is None else c[id_test],
                                           batch_size,
                                           self.X_output[id_test].astype(tf.keras.backend.floatx()),
                                           self.scale_factor[id_test].astype(tf.keras.backend.floatx()),
                                           conditions = conditions[id_test],
                                           pi_cov = self.pi_cov[id_test])

    # Positional arguments must match train.train's signature exactly.
    self.vae = train.train(
        self.train_dataset,
        self.test_dataset,
        self.vae,
        learning_rate,
        L,
        alpha,
        beta,
        gamma,
        phi,
        num_epoch,
        num_step_per_epoch,
        early_stopping_patience,
        early_stopping_tolerance,
        early_stopping_relative,
        early_stopping_warmup,
        verbose,
        **kwargs
    )

    # Refresh cached latent quantities from the (re)trained VAE.
    self.update_z()
    self.mu = self.vae.latent_space.mu.numpy()
    self.pi = np.triu(np.ones(self.n_states))
    self.pi[self.pi > 0] = tf.nn.softmax(self.vae.latent_space.pi).numpy()[0]

    if path_to_weights is not None:
        self.save_model(path_to_weights)
|
|
701 |
|
|
|
702 |
|
|
|
703 |
def output_pi(self, pi_cov):
    """Evaluate the pi-layer at a covariate value.

    Returns an (n_states, n_states) matrix whose upper triangle holds the
    softmax edge probabilities, together with a mask covering the strictly
    lower triangle (useful to hide it in a heatmap plot).
    """
    pilayer = self.vae.pilayer
    cov_tensor = tf.expand_dims(tf.constant([pi_cov], dtype=tf.float32), 0)
    probs = tf.nn.softmax(pilayer(cov_tensor)).numpy()[0]
    # Scatter the flat probability vector into the upper triangle.
    n_states = self.vae.n_states
    heatmap = np.zeros((n_states, n_states))
    heatmap[np.triu_indices(n_states)] = probs
    lower_mask = np.tril(np.ones_like(heatmap), k=-1)
    return heatmap, lower_mask
|
|
714 |
|
|
|
715 |
|
|
|
716 |
def return_pilayer_weights(self):
    """Return the pi-layer parameters as one stacked array.

    The result has shape (dim(pi_cov) + 1, n_categories); the last row
    holds the biases.
    """
    # BUG FIX: the original referenced an undefined global `model`
    # (`model.vae.pilayer...`), raising NameError; it must use `self`.
    kernel = self.vae.pilayer.weights[0].numpy()
    bias_row = self.vae.pilayer.weights[1].numpy().reshape(1, -1)
    return np.vstack((kernel, bias_row))
|
|
719 |
|
|
|
720 |
|
|
|
721 |
def posterior_estimation(self, batch_size: int = 32, L: int = 50, **kwargs):
    '''Initialize trajectory inference by computing the posterior estimations.

    Parameters
    ----------
    batch_size : int, optional
        The batch size when doing inference.
    L : int, optional
        The number of MC samples when doing inference.
    **kwargs :
        Extra key-value arguments for dimension reduction algorithms.
    '''
    float_type = tf.keras.backend.floatx()
    if self.covariates is None:
        covariates = None
    else:
        covariates = self.covariates.astype(float_type)
    # Wrap the full dataset (no train/test split) for inference.
    self.test_dataset = train.warp_dataset(self.X_input.astype(float_type),
                                           covariates,
                                           batch_size)
    inference_out = self.vae.inference(self.test_dataset, L=L)
    _, _, self.pc_x, self.cell_position_posterior, self.cell_position_variance, _ = inference_out

    # Assign each cell to the state with maximal posterior probability.
    label_names = self.labels_map['label_names'].to_numpy()
    new_clusters = label_names[np.argmax(self.cell_position_posterior, 1)]
    self.adata.obs['vitae_new_clustering'] = new_clusters
    self.adata.obs['vitae_new_clustering'] = self.adata.obs['vitae_new_clustering'].astype('category')
    print("New clustering labels saved as 'vitae_new_clustering' in self.adata.obs.")
    return None
|
|
745 |
|
|
|
746 |
|
|
|
747 |
def infer_backbone(self, method: str = 'modified_map', thres = 0.5,
        no_loop: bool = True, cutoff: float = 0,
        visualize: bool = True, color = 'vitae_new_clustering',path_to_fig = None,**kwargs):
    ''' Compute edge scores.

    Parameters
    ----------
    method : string, optional
        'mean', 'modified_mean', 'map', or 'modified_map'.
    thres : float, optional
        The threshold used for filtering edges \(e_{ij}\) that \((n_{i}+n_{j}+e_{ij})/N<thres\), only applied to mean method.
    no_loop : boolean, optional
        Whether loops are allowed to exist in the graph. If no_loop is true, will prune the graph to contain only the
        maximum spanning true
    cutoff : string, optional
        The score threshold for filtering edges with scores less than cutoff.
    visualize: boolean
        whether plot the current trajectory backbone (undirected graph)

    Returns
    ----------
    G : nx.Graph
        The weighted graph with weight on each edge indicating its score of existence.
    '''
    # build_graph, return graph
    self.backbone = self.inferer.build_graphs(self.cell_position_posterior, self.pc_x,
                                              method, thres, no_loop, cutoff)
    # Project each cell onto the retained backbone edges.
    self.cell_position_projected = self.inferer.modify_wtilde(self.cell_position_posterior,
                                                              np.array(list(self.backbone.edges)))

    uni_cluster_labels = self.labels_map['label_names'].to_numpy()
    temp_dict = {i:label for i,label in enumerate(uni_cluster_labels)}
    # NOTE(review): nx.relabel_nodes returns a NEW graph by default; the
    # result is discarded here, so this line is a no-op and backbone nodes
    # stay integer ids. Downstream code indexes arrays with those integer
    # node ids, so "fixing" this would likely break plotting — confirm intent.
    nx.relabel_nodes(self.backbone, temp_dict)

    self.adata.obs['vitae_new_clustering'] = uni_cluster_labels[np.argmax(self.cell_position_projected, 1)]
    self.adata.obs['vitae_new_clustering'] = self.adata.obs['vitae_new_clustering'].astype('category')
    print("'vitae_new_clustering' updated based on the projected cell positions.")

    # Uncertainty = squared projection shift plus posterior variance per cell.
    self.uncertainty = np.sum((self.cell_position_projected - self.cell_position_posterior)**2, axis=-1) \
        + np.sum(self.cell_position_variance, axis=-1)
    self.adata.obs['projection_uncertainty'] = self.uncertainty
    print("Cell projection uncertainties stored as 'projection_uncertainty' in self.adata.obs")
    if visualize:
        self._adata.obs = self.adata.obs.copy()
        self.ax = self.plot_backbone(directed = False,color = color, **kwargs)
        if path_to_fig is not None:
            self.ax.figure.savefig(path_to_fig)
        self.ax.figure.show()
    return None
|
|
796 |
|
|
|
797 |
|
|
|
798 |
def select_root(self, days, method: str = 'proportion'):
    '''Order the vertices/states based on cells' collection time information to select the root state.

    Parameters
    ----------
    days : np.array
        The day information for selected cells used to determine the root vertex.
        The dtype should be 'int' or 'float'.
    method : str, optional
        'proportion' or 'mean'.
        For 'proportion', the root is the one with maximal proportion of cells from the earliest day.
        For 'mean', the root is the one with earliest mean time among cells associated with it.
        NOTE(review): `method` is not referenced in this body — both
        statistics are always computed and returned; confirm whether the
        caller is expected to pick the column itself.

    Returns
    ----------
    root_info : pd.DataFrame
        Per-state table (one row per label) with 'mean_collection_time'
        and 'earliest_time_prop' columns, sorted by mean collection time.
        The root candidate is the first row (for 'mean') or the row with
        maximal 'earliest_time_prop' (for 'proportion').
    '''
    if days is not None and len(days)!=self.X_input.shape[0]:
        raise ValueError("The length of day information ({}) is not "
            "consistent with the number of selected cells ({})!".format(
            len(days), self.X_input.shape[0]))
    if not hasattr(self, 'cell_position_projected'):
        raise ValueError("Need to call 'infer_backbone' first!")

    # Posterior-weighted mean collection time per state.
    collection_time = np.dot(days, self.cell_position_projected)/np.sum(self.cell_position_projected, axis = 0)
    # Posterior-weighted proportion of earliest-day cells per state.
    earliest_prop = np.dot(days==np.min(days), self.cell_position_projected)/np.sum(self.cell_position_projected, axis = 0)

    root_info = self.labels_map.copy()
    root_info['mean_collection_time'] = collection_time
    root_info['earliest_time_prop'] = earliest_prop
    root_info.sort_values('mean_collection_time', inplace=True)
    return root_info
|
|
832 |
|
|
|
833 |
|
|
|
834 |
def plot_backbone(self, directed: bool = False,
        method: str = 'UMAP', color = 'vitae_new_clustering', **kwargs):
    '''Plot the current trajectory backbone (undirected graph).

    Parameters
    ----------
    directed : boolean, optional
        Whether the backbone is directed or not.
    method : str, optional
        The dimension reduction method to use. The default is "UMAP".
    color : str, optional
        The key for annotations of observations/cells or variables/genes, e.g., 'ann1' or ['ann1', 'ann2'].
        The default is 'vitae_new_clustering'.
    **kwargs :
        Extra key-value arguments that can be passed to scanpy plotting functions (scanpy.pl.XX).
    '''
    if not isinstance(color,str):
        raise ValueError('The color argument should be of type str!')
    ax = self.visualize_latent(method = method, color=color, show=False, **kwargs)
    # Map label name -> integer state index.
    dict_label_num = {j:i for i,j in self.labels_map['label_names'].to_dict().items()}
    # NOTE(review): categories come from 'vitae_init_clustering' while the
    # per-cell labels below come from 'vitae_new_clustering' — confirm these
    # stay in sync after merging, otherwise dict_label_num[l] may KeyError.
    uni_cluster_labels = self.adata.obs['vitae_init_clustering'].cat.categories
    cluster_labels = self.adata.obs['vitae_new_clustering'].to_numpy()
    embed_z = self._adata.obsm[self.dict_method_scname[method]]
    # Cluster centers in the 2-D embedding (mean position of member cells).
    embed_mu = np.zeros((len(uni_cluster_labels), 2))
    for l in uni_cluster_labels:
        embed_mu[dict_label_num[l],:] = np.mean(embed_z[cluster_labels==l], axis=0)

    if directed:
        graph = self.directed_backbone
    else:
        graph = self.backbone
    edges = list(graph.edges)
    edge_scores = np.array([d['weight'] for (u,v,d) in graph.edges(data=True)])
    # Rescale edge scores to [0, 3] for line widths; degenerate case when
    # all scores are equal.
    # NOTE(review): if all scores are equal AND zero, this divides 0/0 —
    # confirm scores are strictly positive here.
    if max(edge_scores) - min(edge_scores) == 0:
        edge_scores = edge_scores/max(edge_scores)
    else:
        edge_scores = (edge_scores - min(edge_scores))/(max(edge_scores) - min(edge_scores))*3

    value_range = np.maximum(np.diff(ax.get_xlim())[0], np.diff(ax.get_ylim())[0])
    y_range = np.min(embed_z[:,1]), np.max(embed_z[:,1], axis=0)
    for i in range(len(edges)):
        # Cells positioned on this edge (positive weight on both endpoints).
        points = embed_z[np.sum(self.cell_position_projected[:, edges[i]]>0, axis=-1)==2,:]
        points = points[points[:,0].argsort()]
        try:
            x_smooth, y_smooth = _get_smooth_curve(
                points,
                embed_mu[edges[i], :],
                y_range
            )
        except:
            # Fall back to a straight segment between the two centers.
            x_smooth, y_smooth = embed_mu[edges[i], 0], embed_mu[edges[i], 1]
        ax.plot(x_smooth, y_smooth,
            '-',
            linewidth= 1 + edge_scores[i],
            color="black",
            alpha=0.8,
            path_effects=[pe.Stroke(linewidth=1+edge_scores[i]+1.5,
                                    foreground='white'), pe.Normal()],
            zorder=1
            )

        if directed:
            # Draw an arrowhead near the target endpoint, scaled to the plot.
            delta_x = embed_mu[edges[i][1], 0] - x_smooth[-2]
            delta_y = embed_mu[edges[i][1], 1] - y_smooth[-2]
            length = np.sqrt(delta_x**2 + delta_y**2) / 50 * value_range
            ax.arrow(
                embed_mu[edges[i][1], 0]-delta_x/length,
                embed_mu[edges[i][1], 1]-delta_y/length,
                delta_x/length,
                delta_y/length,
                color='black', alpha=1.0,
                shape='full', lw=0, length_includes_head=True,
                head_width=np.maximum(0.01*(1 + edge_scores[i]), 0.03) * value_range,
                zorder=2)

    colors = self._adata.uns['vitae_new_clustering_colors']

    # Star markers at cluster centers.
    for i,l in enumerate(uni_cluster_labels):
        ax.scatter(*embed_mu[dict_label_num[l]:dict_label_num[l]+1,:].T,
                   c=[colors[i]], edgecolors='white', # linewidths=10, norm=norm,
                   s=250, marker='*', label=l)

    plt.setp(ax, xticks=[], yticks=[])
    # Shrink the axes to leave room for a legend below the plot.
    box = ax.get_position()
    ax.set_position([box.x0, box.y0 + box.height * 0.1,
                     box.width, box.height * 0.9])
    if directed:
        ax.legend(loc='upper center', bbox_to_anchor=(0.5, -0.05),
              fancybox=True, shadow=True, ncol=5)

    return ax
|
|
925 |
|
|
|
926 |
|
|
|
927 |
def plot_center(self, color = "vitae_new_clustering", plot_legend = True, legend_add_index = True,
        method: str = 'UMAP',ncol = 2,font_size = "medium",
        add_egde = False, add_direct = False,**kwargs):
    '''Plot the center of each cluster in the latent space.

    Parameters
    ----------
    color : str, optional
        The color of the center of each cluster. Default is "vitae_new_clustering".
    plot_legend : bool, optional
        Whether to plot the legend. Default is True.
    legend_add_index : bool, optional
        Whether to add the index of each cluster in the legend. Default is True.
    method : str, optional
        The dimension reduction method used for visualization. Default is 'UMAP'.
    ncol : int, optional
        The number of columns in the legend. Default is 2.
    font_size : str, optional
        The font size of the legend. Default is "medium".
    add_egde : bool, optional
        Whether to add the edges between the centers of clusters. Default is False.
    add_direct : bool, optional
        Whether to add the direction of the edges. Default is False.
    '''
    if color not in ["vitae_new_clustering","vitae_init_clustering"]:
        raise ValueError("Can only plot center of vitae_new_clustering or vitae_init_clustering")
    # Map label name -> integer state index.
    dict_label_num = {j: i for i, j in self.labels_map['label_names'].to_dict().items()}
    if legend_add_index:
        # Plot cells colored by the numeric index so indices show on the plot.
        self._adata.obs["index_"+color] = self._adata.obs[color].map(lambda x: dict_label_num[x])
        ax = self.visualize_latent(method=method, color="index_" + color, show=False, legend_loc="on data",
                                   legend_fontsize=font_size,**kwargs)
        colors = self._adata.uns["index_" + color + '_colors']
    else:
        ax = self.visualize_latent(method=method, color = color, show=False,**kwargs)
        colors = self._adata.uns[color + '_colors']
    uni_cluster_labels = self.adata.obs[color].cat.categories
    cluster_labels = self.adata.obs[color].to_numpy()
    embed_z = self._adata.obsm[self.dict_method_scname[method]]
    # Cluster centers in the 2-D embedding (mean position of member cells).
    embed_mu = np.zeros((len(uni_cluster_labels), 2))
    for l in uni_cluster_labels:
        embed_mu[dict_label_num[l], :] = np.mean(embed_z[cluster_labels == l], axis=0)

    # Legend entries like "3 : label_name".
    leg = (self.labels_map.index.astype(str) + " : " + self.labels_map.label_names).values
    for i, l in enumerate(uni_cluster_labels):
        ax.scatter(*embed_mu[dict_label_num[l]:dict_label_num[l] + 1, :].T,
                   c=[colors[i]], edgecolors='white', # linewidths=3,
                   s=250, marker='*', label=leg[i])
    if plot_legend:
        ax.legend(loc='center left', bbox_to_anchor=(1, 0.5), ncol=ncol, markerscale=0.8, frameon=False)
    plt.setp(ax, xticks=[], yticks=[])
    # Shrink the axes to leave room for the legend.
    box = ax.get_position()
    ax.set_position([box.x0, box.y0 + box.height * 0.1,
                     box.width, box.height * 0.9])
    if add_egde:
        if add_direct:
            graph = self.directed_backbone
        else:
            graph = self.backbone
        edges = list(graph.edges)
        edge_scores = np.array([d['weight'] for (u, v, d) in graph.edges(data=True)])
        # Rescale edge scores to [0, 3] for line widths.
        # NOTE(review): 0/0 if all scores are equal and zero — confirm scores
        # are strictly positive here (same pattern as plot_backbone).
        if max(edge_scores) - min(edge_scores) == 0:
            edge_scores = edge_scores / max(edge_scores)
        else:
            edge_scores = (edge_scores - min(edge_scores)) / (max(edge_scores) - min(edge_scores)) * 3

        value_range = np.maximum(np.diff(ax.get_xlim())[0], np.diff(ax.get_ylim())[0])
        y_range = np.min(embed_z[:, 1]), np.max(embed_z[:, 1], axis=0)
        for i in range(len(edges)):
            # Cells positioned on this edge (positive weight on both endpoints).
            points = embed_z[np.sum(self.cell_position_projected[:, edges[i]] > 0, axis=-1) == 2, :]
            points = points[points[:, 0].argsort()]
            try:
                x_smooth, y_smooth = _get_smooth_curve(
                    points,
                    embed_mu[edges[i], :],
                    y_range
                )
            except:
                # Fall back to a straight segment between the two centers.
                x_smooth, y_smooth = embed_mu[edges[i], 0], embed_mu[edges[i], 1]
            ax.plot(x_smooth, y_smooth,
                    '-',
                    linewidth=1 + edge_scores[i],
                    color="black",
                    alpha=0.8,
                    path_effects=[pe.Stroke(linewidth=1 + edge_scores[i] + 1.5,
                                            foreground='white'), pe.Normal()],
                    zorder=1
                    )

            if add_direct:
                # Arrowhead near the target endpoint, scaled to the plot size.
                delta_x = embed_mu[edges[i][1], 0] - x_smooth[-2]
                delta_y = embed_mu[edges[i][1], 1] - y_smooth[-2]
                length = np.sqrt(delta_x ** 2 + delta_y ** 2) / 50 * value_range
                ax.arrow(
                    embed_mu[edges[i][1], 0] - delta_x / length,
                    embed_mu[edges[i][1], 1] - delta_y / length,
                    delta_x / length,
                    delta_y / length,
                    color='black', alpha=1.0,
                    shape='full', lw=0, length_includes_head=True,
                    head_width=np.maximum(0.01 * (1 + edge_scores[i]), 0.03) * value_range,
                    zorder=2)
    self.ax = ax
    self.ax.figure.show()
    return None
|
|
1031 |
|
|
|
1032 |
|
|
|
1033 |
def infer_trajectory(self, root: Union[int,str], digraph = None, color = "pseudotime",
        visualize: bool = True, path_to_fig = None, **kwargs):
    '''Infer the trajectory.

    Parameters
    ----------
    root : int or string
        The root of the inferred trajectory. Can provide either an int (vertex index) or string (label name)
    digraph : nx.DiGraph, optional
        The directed graph to be used for trajectory inference. If None, the minimum spanning tree of the estimated trajectory backbone will be used.
    color : string, optional
        The annotation key used to color the visualization. Default is "pseudotime".
    visualize: boolean
        Whether plot the current trajectory backbone (directed graph)
    path_to_fig : string, optional
        The path to save figure, or don't save if it is None.
    **kwargs : dict, optional
        Other keywords arguments for plotting.
    '''
    # Translate a label-name root into its integer vertex index.
    if isinstance(root,str):
        if root not in self.labels_map.values:
            raise ValueError("Root {} is not in the label names!".format(root))
        root = self.labels_map[self.labels_map['label_names']==root].index[0]

    if digraph is None:
        # Restrict to the component reachable from the root.
        connected_comps = nx.node_connected_component(self.backbone, root)
        subG = self.backbone.subgraph(connected_comps)

        ## generate directed backbone which contains no loops
        # Direct edges away from the root via DFS; drop all other directions.
        DG = nx.DiGraph(nx.to_directed(self.backbone))
        temp = DG.subgraph(connected_comps)
        DG.remove_edges_from(temp.edges - nx.dfs_edges(DG, root))
        self.directed_backbone = DG
    else:
        if not nx.is_directed_acyclic_graph(digraph):
            raise ValueError("The graph 'digraph' should be a directed acyclic graph.")
        if set(digraph.nodes) != set(self.backbone.nodes):
            raise ValueError("The nodes in 'digraph' do not match the nodes in 'self.backbone'.")
        self.directed_backbone = digraph

        # NOTE(review): nx.node_connected_component is defined for undirected
        # graphs only; calling it on a DiGraph may raise — confirm whether
        # this should use self.backbone (as in the branch above) instead.
        connected_comps = nx.node_connected_component(digraph, root)
        subG = self.backbone.subgraph(connected_comps)


    if len(subG.edges)>0:
        milestone_net = self.inferer.build_milestone_net(subG, root)
        if self.inferer.no_loop is False and milestone_net.shape[0]<len(self.backbone.edges):
            warnings.warn("The directed graph shown is a minimum spanning tree of the estimated trajectory backbone to avoid arbitrary assignment of the directions.")
        self.pseudotime = self.inferer.comp_pseudotime(milestone_net, root, self.cell_position_projected)
    else:
        # Root is isolated: mark all cells with pseudotime -1.
        warnings.warn("There are no connected states for starting from the giving root.")
        self.pseudotime = -np.ones(self._adata.shape[0])

    self.adata.obs['pseudotime'] = self.pseudotime
    print("Cell projection uncertainties stored as 'pseudotime' in self.adata.obs")

    if visualize:
        self._adata.obs['pseudotime'] = self.pseudotime
        self.ax = self.plot_backbone(directed = True, color = color, **kwargs)
        if path_to_fig is not None:
            self.ax.figure.savefig(path_to_fig)
        self.ax.figure.show()

    return None
|
|
1097 |
|
|
|
1098 |
|
|
|
1099 |
|
|
|
1100 |
def differential_expression_test(self, alpha: float = 0.05, cell_subset = None, order: int = 1):
    '''Differentially gene expression test. All (selected and unselected) genes will be tested
    Only cells in `selected_cell_subset` will be used, which is useful when one need to
    test differentially expressed genes on a branch of the inferred trajectory.

    Parameters
    ----------
    alpha : float, optional
        The cutoff of p-values.
    cell_subset : np.array, optional
        The subset of cells to be used for testing. If None, all cells will be used.
    order : int, optional
        The maxium order we used for pseudotime in regression.

    Returns
    ----------
    res_df : pandas.DataFrame
        The test results of expressed genes with two columns,
        the estimated coefficients and the adjusted p-values.
    '''
    if not hasattr(self, 'pseudotime'):
        raise ReferenceError("Pseudotime does not exist! Please run 'infer_trajectory' first.")
    if cell_subset is None:
        cell_subset = np.arange(self.X_input.shape[0])
        print("All cells are selected.")
    if order < 1:
        raise ValueError("Maximal order of pseudotime in regression must be at least 1.")

    # Prepare X and Y for regression expression ~ rank(PDT) + covariates
    Y = self.adata.X[cell_subset,:]
    ranks = stats.rankdata(self.pseudotime[cell_subset])
    # BUG FIX: the original loop did `X = np.c_[X, X**_order]`, raising the
    # whole already-augmented matrix to each power, so for order >= 3 the
    # design matrix contained wrong columns (e.g. (x^2)^3 = x^6 instead of
    # x^3). Build each power from the original rank vector instead.
    X = np.column_stack([ranks**p for p in range(1, order+1)])
    # Standardize each pseudotime column, then prepend an intercept.
    X = ((X-np.mean(X,axis=0, keepdims=True))/np.std(X, ddof=1, axis=0, keepdims=True))
    X = np.c_[np.ones((X.shape[0],1)), X]
    if self.covariates is not None:
        X = np.c_[X, self.covariates[cell_subset, :]]

    # Test the pseudotime columns (indices 1..order; 0 is the intercept).
    res_df = DE_test(Y, X, self.adata.var_names, i_test = np.array(list(range(1,order+1))), alpha = alpha)
    return res_df[res_df.pvalue_adjusted_1 != 0]
|
|
1143 |
|
|
|
1144 |
|
|
|
1145 |
|
|
|
1146 |
|
|
|
1147 |
def evaluate(self, milestone_net, begin_node_true, grouping = None, |
|
|
1148 |
thres: float = 0.5, no_loop: bool = True, cutoff: Optional[float] = None, |
|
|
1149 |
method: str = 'mean', path: Optional[str] = None): |
|
|
1150 |
''' Evaluate the model. |
|
|
1151 |
|
|
|
1152 |
Parameters |
|
|
1153 |
---------- |
|
|
1154 |
milestone_net : pd.DataFrame |
|
|
1155 |
The true milestone network. For real data, milestone_net will be a DataFrame of the graph of nodes. |
|
|
1156 |
Eg. |
|
|
1157 |
|
|
|
1158 |
from|to |
|
|
1159 |
---|--- |
|
|
1160 |
cluster 1 | cluster 1 |
|
|
1161 |
cluster 1 | cluster 2 |
|
|
1162 |
|
|
|
1163 |
For synthetic data, milestone_net will be a DataFrame of the (projected) |
|
|
1164 |
positions of cells. The indexes are the orders of cells in the dataset. |
|
|
1165 |
Eg. |
|
|
1166 |
|
|
|
1167 |
from|to|w |
|
|
1168 |
---|---|--- |
|
|
1169 |
cluster 1 | cluster 1 | 1 |
|
|
1170 |
cluster 1 | cluster 2 | 0.1 |
|
|
1171 |
begin_node_true : str or int |
|
|
1172 |
The true begin node of the milestone. |
|
|
1173 |
grouping : np.array, optional |
|
|
1174 |
\([N,]\) The labels. For real data, grouping must be provided. |
|
|
1175 |
|
|
|
1176 |
Returns |
|
|
1177 |
---------- |
|
|
1178 |
res : pd.DataFrame |
|
|
1179 |
The evaluation result. |
|
|
1180 |
''' |
|
|
1181 |
if not hasattr(self, 'labels_map'): |
|
|
1182 |
raise ValueError("No given labels for training.") |
|
|
1183 |
|
|
|
1184 |
''' |
|
|
1185 |
# Evaluate for the whole dataset will ignore selected_cell_subset. |
|
|
1186 |
if len(self.selected_cell_subset)!=len(self.cell_names): |
|
|
1187 |
warnings.warn("Evaluate for the whole dataset.") |
|
|
1188 |
''' |
|
|
1189 |
|
|
|
1190 |
# If the begin_node_true, need to encode it by self.le. |
|
|
1191 |
# this dict is for milestone net cause their labels are not merged |
|
|
1192 |
# all keys of label_map_dict are str |
|
|
1193 |
label_map_dict = dict() |
|
|
1194 |
for i in range(self.labels_map.shape[0]): |
|
|
1195 |
label_mapped = self.labels_map.loc[i] |
|
|
1196 |
## merged cluster index is connected by comma |
|
|
1197 |
for each in label_mapped.values[0].split(","): |
|
|
1198 |
label_map_dict[each] = i |
|
|
1199 |
if isinstance(begin_node_true, str): |
|
|
1200 |
begin_node_true = label_map_dict[begin_node_true] |
|
|
1201 |
|
|
|
1202 |
# For generated data, grouping information is already in milestone_net |
|
|
1203 |
if 'w' in milestone_net.columns: |
|
|
1204 |
grouping = None |
|
|
1205 |
|
|
|
1206 |
# If milestone_net is provided, transform them to be numeric. |
|
|
1207 |
if milestone_net is not None: |
|
|
1208 |
milestone_net['from'] = [label_map_dict[x] for x in milestone_net["from"]] |
|
|
1209 |
milestone_net['to'] = [label_map_dict[x] for x in milestone_net["to"]] |
|
|
1210 |
|
|
|
1211 |
# this dict is for potentially merged clusters. |
|
|
1212 |
label_map_dict_for_merged_cluster = dict(zip(self.labels_map["label_names"],self.labels_map.index)) |
|
|
1213 |
mapped_labels = np.array([label_map_dict_for_merged_cluster[x] for x in self.labels]) |
|
|
1214 |
begin_node_pred = int(np.argmin(np.mean(( |
|
|
1215 |
self.z[mapped_labels==begin_node_true,:,np.newaxis] - |
|
|
1216 |
self.mu[np.newaxis,:,:])**2, axis=(0,1)))) |
|
|
1217 |
|
|
|
1218 |
if cutoff is None: |
|
|
1219 |
cutoff = 0.01 |
|
|
1220 |
|
|
|
1221 |
G = self.backbone |
|
|
1222 |
w = self.cell_position_projected |
|
|
1223 |
pseudotime = self.pseudotime |
|
|
1224 |
|
|
|
1225 |
# 1. Topology |
|
|
1226 |
G_pred = nx.Graph() |
|
|
1227 |
G_pred.add_nodes_from(G.nodes) |
|
|
1228 |
G_pred.add_edges_from(G.edges) |
|
|
1229 |
nx.set_node_attributes(G_pred, False, 'is_init') |
|
|
1230 |
G_pred.nodes[begin_node_pred]['is_init'] = True |
|
|
1231 |
|
|
|
1232 |
G_true = nx.Graph() |
|
|
1233 |
G_true.add_nodes_from(G.nodes) |
|
|
1234 |
# if 'grouping' is not provided, assume 'milestone_net' contains proportions |
|
|
1235 |
if grouping is None: |
|
|
1236 |
G_true.add_edges_from(list( |
|
|
1237 |
milestone_net[~pd.isna(milestone_net['w'])].groupby(['from', 'to']).count().index)) |
|
|
1238 |
# otherwise, 'milestone_net' indicates edges |
|
|
1239 |
else: |
|
|
1240 |
if milestone_net is not None: |
|
|
1241 |
G_true.add_edges_from(list( |
|
|
1242 |
milestone_net.groupby(['from', 'to']).count().index)) |
|
|
1243 |
grouping = [label_map_dict[x] for x in grouping] |
|
|
1244 |
grouping = np.array(grouping) |
|
|
1245 |
G_true.remove_edges_from(nx.selfloop_edges(G_true)) |
|
|
1246 |
nx.set_node_attributes(G_true, False, 'is_init') |
|
|
1247 |
G_true.nodes[begin_node_true]['is_init'] = True |
|
|
1248 |
res = topology(G_true, G_pred) |
|
|
1249 |
|
|
|
1250 |
# 2. Milestones assignment |
|
|
1251 |
if grouping is None: |
|
|
1252 |
milestones_true = milestone_net['from'].values.copy() |
|
|
1253 |
milestones_true[(milestone_net['from']!=milestone_net['to']) |
|
|
1254 |
&(milestone_net['w']<0.5)] = milestone_net[(milestone_net['from']!=milestone_net['to']) |
|
|
1255 |
&(milestone_net['w']<0.5)]['to'].values |
|
|
1256 |
else: |
|
|
1257 |
milestones_true = grouping |
|
|
1258 |
milestones_true = milestones_true |
|
|
1259 |
milestones_pred = np.argmax(w, axis=1) |
|
|
1260 |
res['ARI'] = (adjusted_rand_score(milestones_true, milestones_pred) + 1)/2 |
|
|
1261 |
|
|
|
1262 |
if grouping is None: |
|
|
1263 |
n_samples = len(milestone_net) |
|
|
1264 |
prop = np.zeros((n_samples,n_samples)) |
|
|
1265 |
prop[np.arange(n_samples), milestone_net['to']] = 1-milestone_net['w'] |
|
|
1266 |
prop[np.arange(n_samples), milestone_net['from']] = np.where(np.isnan(milestone_net['w']), 1, milestone_net['w']) |
|
|
1267 |
res['GRI'] = get_GRI(prop, w) |
|
|
1268 |
else: |
|
|
1269 |
res['GRI'] = get_GRI(grouping, w) |
|
|
1270 |
|
|
|
1271 |
# 3. Correlation between geodesic distances / Pseudotime |
|
|
1272 |
if no_loop: |
|
|
1273 |
if grouping is None: |
|
|
1274 |
pseudotime_true = milestone_net['from'].values + 1 - milestone_net['w'].values |
|
|
1275 |
pseudotime_true[np.isnan(pseudotime_true)] = milestone_net[pd.isna(milestone_net['w'])]['from'].values |
|
|
1276 |
else: |
|
|
1277 |
pseudotime_true = - np.ones(len(grouping)) |
|
|
1278 |
nx.set_edge_attributes(G_true, values = 1, name = 'weight') |
|
|
1279 |
connected_comps = nx.node_connected_component(G_true, begin_node_true) |
|
|
1280 |
subG = G_true.subgraph(connected_comps) |
|
|
1281 |
milestone_net_true = self.inferer.build_milestone_net(subG, begin_node_true) |
|
|
1282 |
if len(milestone_net_true)>0: |
|
|
1283 |
pseudotime_true[grouping==int(milestone_net_true[0,0])] = 0 |
|
|
1284 |
for i in range(len(milestone_net_true)): |
|
|
1285 |
pseudotime_true[grouping==int(milestone_net_true[i,1])] = milestone_net_true[i,-1] |
|
|
1286 |
pseudotime_true = pseudotime_true[pseudotime>-1] |
|
|
1287 |
pseudotime_pred = pseudotime[pseudotime>-1] |
|
|
1288 |
res['PDT score'] = (np.corrcoef(pseudotime_true,pseudotime_pred)[0,1]+1)/2 |
|
|
1289 |
else: |
|
|
1290 |
res['PDT score'] = np.nan |
|
|
1291 |
|
|
|
1292 |
# 4. Shape |
|
|
1293 |
# score_cos_theta = 0 |
|
|
1294 |
# for (_from,_to) in G.edges: |
|
|
1295 |
# _z = self.z[(w[:,_from]>0) & (w[:,_to]>0),:] |
|
|
1296 |
# v_1 = _z - self.mu[:,_from] |
|
|
1297 |
# v_2 = _z - self.mu[:,_to] |
|
|
1298 |
# cos_theta = np.sum(v_1*v_2, -1)/(np.linalg.norm(v_1,axis=-1)*np.linalg.norm(v_2,axis=-1)+1e-12) |
|
|
1299 |
|
|
|
1300 |
# score_cos_theta += np.sum((1-cos_theta)/2) |
|
|
1301 |
|
|
|
1302 |
# res['score_cos_theta'] = score_cos_theta/(np.sum(np.sum(w>0, axis=-1)==2)+1e-12) |
|
|
1303 |
return res |
|
|
1304 |
|
|
|
1305 |
|
|
|
1306 |
def save_model(self, path_to_file: str = 'model.checkpoint',save_adata: bool = False): |
|
|
1307 |
'''Saving model weights. |
|
|
1308 |
|
|
|
1309 |
Parameters |
|
|
1310 |
---------- |
|
|
1311 |
path_to_file : str, optional |
|
|
1312 |
The path to weight files of pre-trained or trained model |
|
|
1313 |
save_adata : boolean, optional |
|
|
1314 |
Whether to save adata or not. |
|
|
1315 |
''' |
|
|
1316 |
self.vae.save_weights(path_to_file) |
|
|
1317 |
if hasattr(self, 'labels') and self.labels is not None: |
|
|
1318 |
with open(path_to_file + '.label', 'wb') as f: |
|
|
1319 |
np.save(f, self.labels) |
|
|
1320 |
with open(path_to_file + '.config', 'wb') as f: |
|
|
1321 |
self.dim_origin = self.X_input.shape[1] |
|
|
1322 |
np.save(f, np.array([ |
|
|
1323 |
self.dim_origin, self.dimensions, self.dim_latent, |
|
|
1324 |
self.model_type, 0 if self.covariates is None else self.covariates.shape[1]], dtype=object)) |
|
|
1325 |
if hasattr(self, 'inferer') and hasattr(self, 'uncertainty'): |
|
|
1326 |
with open(path_to_file + '.inference', 'wb') as f: |
|
|
1327 |
np.save(f, np.array([ |
|
|
1328 |
self.pi, self.mu, self.pc_x, self.cell_position_posterior, self.uncertainty, |
|
|
1329 |
self.z,self.cell_position_variance], dtype=object)) |
|
|
1330 |
if save_adata: |
|
|
1331 |
self.adata.write(path_to_file + '.adata.h5ad') |
|
|
1332 |
|
|
|
1333 |
|
|
|
1334 |
def load_model(self, path_to_file: str = 'model.checkpoint', load_labels: bool = False, load_adata: bool = False): |
|
|
1335 |
'''Load model weights. |
|
|
1336 |
|
|
|
1337 |
Parameters |
|
|
1338 |
---------- |
|
|
1339 |
path_to_file : str, optional |
|
|
1340 |
The path to weight files of pre trained or trained model |
|
|
1341 |
load_labels : boolean, optional |
|
|
1342 |
Whether to load clustering labels or not. |
|
|
1343 |
If load_labels is True, then the LatentSpace layer will be initialized based on the model. |
|
|
1344 |
If load_labels is False, then the LatentSpace layer will not be initialized. |
|
|
1345 |
load_adata : boolean, optional |
|
|
1346 |
Whether to load adata or not. |
|
|
1347 |
''' |
|
|
1348 |
if not os.path.exists(path_to_file + '.config'): |
|
|
1349 |
raise AssertionError('Config file not exist!') |
|
|
1350 |
if load_labels and not os.path.exists(path_to_file + '.label'): |
|
|
1351 |
raise AssertionError('Label file not exist!') |
|
|
1352 |
|
|
|
1353 |
with open(path_to_file + '.config', 'rb') as f: |
|
|
1354 |
[self.dim_origin, self.dimensions, |
|
|
1355 |
self.dim_latent, self.model_type, cov_dim] = np.load(f, allow_pickle=True) |
|
|
1356 |
self.vae = model.VariationalAutoEncoder( |
|
|
1357 |
self.dim_origin, self.dimensions, |
|
|
1358 |
self.dim_latent, self.model_type, False if cov_dim == 0 else True |
|
|
1359 |
) |
|
|
1360 |
|
|
|
1361 |
if load_labels: |
|
|
1362 |
with open(path_to_file + '.label', 'rb') as f: |
|
|
1363 |
cluster_labels = np.load(f, allow_pickle=True) |
|
|
1364 |
self.init_latent_space(cluster_labels, dist_thres=0) |
|
|
1365 |
if os.path.exists(path_to_file + '.inference'): |
|
|
1366 |
with open(path_to_file + '.inference', 'rb') as f: |
|
|
1367 |
arr = np.load(f, allow_pickle=True) |
|
|
1368 |
if len(arr) == 8: |
|
|
1369 |
[self.pi, self.mu, self.pc_x, self.cell_position_posterior, self.uncertainty, |
|
|
1370 |
self.D_JS, self.z,self.cell_position_variance] = arr |
|
|
1371 |
else: |
|
|
1372 |
[self.pi, self.mu, self.pc_x, self.cell_position_posterior, self.uncertainty, |
|
|
1373 |
self.z,self.cell_position_variance] = arr |
|
|
1374 |
self._adata_z = sc.AnnData(self.z) |
|
|
1375 |
sc.pp.neighbors(self._adata_z) |
|
|
1376 |
## initialize the weight of encoder and decoder |
|
|
1377 |
self.vae.encoder(np.zeros((1, self.dim_origin + cov_dim))) |
|
|
1378 |
self.vae.decoder(np.expand_dims(np.zeros((1,self.dim_latent + cov_dim)),1)) |
|
|
1379 |
|
|
|
1380 |
self.vae.load_weights(path_to_file) |
|
|
1381 |
self.update_z() |
|
|
1382 |
if load_adata: |
|
|
1383 |
if not os.path.exists(path_to_file + '.adata.h5ad'): |
|
|
1384 |
raise AssertionError('AnnData file not exist!') |
|
|
1385 |
self.adata = sc.read_h5ad(path_to_file + '.adata.h5ad') |
|
|
1386 |
self._adata.obs = self.adata.obs.copy()</code></pre> |
|
|
1387 |
</details> |
|
|
1388 |
<h3>Methods</h3> |
|
|
1389 |
<dl> |
|
|
1390 |
<dt id="VITAE.VITAE.pre_train"><code class="name flex"> |
|
|
1391 |
<span>def <span class="ident">pre_train</span></span>(<span>self, test_size=0.1, random_state: int = 0, learning_rate: float = 0.001, batch_size: int = 256, L: int = 1, alpha: float = 0.1, gamma: float = 0, phi: float = 1, num_epoch: int = 200, num_step_per_epoch: Optional[int] = None, early_stopping_patience: int = 10, early_stopping_tolerance: float = 0.01, early_stopping_relative: bool = True, verbose: bool = False, path_to_weights: Optional[str] = None)</span> |
|
|
1392 |
</code></dt> |
|
|
1393 |
<dd> |
|
|
1394 |
<div class="desc"><p>Pretrain the model with specified learning rate.</p> |
|
|
1395 |
<h2 id="parameters">Parameters</h2> |
|
|
1396 |
<dl> |
|
|
1397 |
<dt><strong><code>test_size</code></strong> : <code>float</code> or <code>int</code>, optional</dt> |
|
|
1398 |
<dd>The proportion or size of the test set.</dd> |
|
|
1399 |
<dt><strong><code>random_state</code></strong> : <code>int</code>, optional</dt> |
|
|
1400 |
<dd>The random state for data splitting.</dd> |
|
|
1401 |
<dt><strong><code>learning_rate</code></strong> : <code>float</code>, optional</dt> |
|
|
1402 |
<dd>The initial learning rate for the Adam optimizer.</dd> |
|
|
1403 |
<dt><strong><code>batch_size</code></strong> : <code>int</code>, optional</dt> |
|
|
1404 |
<dd>The batch size for pre-training. |
|
|
1405 |
Default is 256. Set to 32 if number of cells is small (less than 1000)</dd> |
|
|
1406 |
<dt><strong><code>L</code></strong> : <code>int</code>, optional</dt> |
|
|
1407 |
<dd>The number of MC samples.</dd> |
|
|
1408 |
<dt><strong><code>alpha</code></strong> : <code>float</code>, optional</dt> |
|
|
1409 |
<dd>The value of alpha in [0,1] to encourage covariate adjustment. Not used if there is no covariates.</dd> |
|
|
1410 |
<dt><strong><code>gamma</code></strong> : <code>float</code>, optional</dt> |
|
|
1411 |
<dd>The weight of the mmd loss if used.</dd> |
|
|
1412 |
<dt><strong><code>phi</code></strong> : <code>float</code>, optional</dt> |
|
|
1413 |
<dd>The weight of the Jacobian norm of the encoder.</dd> |
|
|
1414 |
<dt><strong><code>num_epoch</code></strong> : <code>int</code>, optional</dt> |
|
|
1415 |
<dd>The maximum number of epochs.</dd> |
|
|
1416 |
<dt><strong><code>num_step_per_epoch</code></strong> : <code>int</code>, optional</dt> |
|
|
1417 |
<dd>The number of step per epoch, it will be inferred from number of cells and batch size if it is None.</dd> |
|
|
1418 |
<dt><strong><code>early_stopping_patience</code></strong> : <code>int</code>, optional</dt> |
|
|
1419 |
<dd>The maximum number of epochs if there is no improvement.</dd> |
|
|
1420 |
<dt><strong><code>early_stopping_tolerance</code></strong> : <code>float</code>, optional</dt> |
|
|
1421 |
<dd>The minimum change of loss to be considered as an improvement.</dd> |
|
|
1422 |
<dt><strong><code>early_stopping_relative</code></strong> : <code>bool</code>, optional</dt> |
|
|
1423 |
<dd>Whether monitor the relative change of loss as stopping criteria or not.</dd> |
|
|
1424 |
<dt><strong><code>path_to_weights</code></strong> : <code>str</code>, optional</dt> |
|
|
1425 |
<dd>The path of weight file to be saved; not saving weight if None.</dd> |
|
|
1426 |
<dt><strong><code>conditions</code></strong> : <code>str</code> or <code>list</code>, optional</dt> |
|
|
1427 |
<dd>The conditions of different cells</dd> |
|
|
1428 |
</dl></div> |
|
|
1429 |
</dd> |
|
|
1430 |
<dt id="VITAE.VITAE.update_z"><code class="name flex"> |
|
|
1431 |
<span>def <span class="ident">update_z</span></span>(<span>self)</span> |
|
|
1432 |
</code></dt> |
|
|
1433 |
<dd> |
|
|
1434 |
<div class="desc"></div> |
|
|
1435 |
</dd> |
|
|
1436 |
<dt id="VITAE.VITAE.get_latent_z"><code class="name flex"> |
|
|
1437 |
<span>def <span class="ident">get_latent_z</span></span>(<span>self)</span> |
|
|
1438 |
</code></dt> |
|
|
1439 |
<dd> |
|
|
1440 |
<div class="desc"><p>get the posterior mean of the current latent space z (encoder output)</p> |
|
|
1441 |
<h2 id="returns">Returns</h2> |
|
|
1442 |
<dl> |
|
|
1443 |
<dt><strong><code>z</code></strong> : <code>np.array</code></dt> |
|
|
1444 |
<dd><span><span class="MathJax_Preview">[N,d]</span><script type="math/tex">[N,d]</script></span> The latent means.</dd> |
|
|
1445 |
</dl></div> |
|
|
1446 |
</dd> |
|
|
1447 |
<dt id="VITAE.VITAE.visualize_latent"><code class="name flex"> |
|
|
1448 |
<span>def <span class="ident">visualize_latent</span></span>(<span>self, method: str = 'UMAP', color=None, **kwargs)</span> |
|
|
1449 |
</code></dt> |
|
|
1450 |
<dd> |
|
|
1451 |
<div class="desc"><p>visualize the current latent space z using the scanpy visualization tools</p> |
|
|
1452 |
<h2 id="parameters">Parameters</h2> |
|
|
1453 |
<dl> |
|
|
1454 |
<dt><strong><code>method</code></strong> : <code>str</code>, optional</dt> |
|
|
1455 |
<dd>Visualization method to use. The default is "UMAP". Possible choices include "PCA", "UMAP", |
|
|
1456 |
"diffmap", "TSNE" and "draw_graph"</dd> |
|
|
1457 |
<dt><strong><code>color</code></strong> : <code>TYPE</code>, optional</dt> |
|
|
1458 |
<dd>Keys for annotations of observations/cells or variables/genes, e.g., 'ann1' or ['ann1', 'ann2']. |
|
|
1459 |
The default is None. Same as scanpy.</dd> |
|
|
1460 |
<dt><strong><code>**kwargs</code></strong> : <code> </code></dt> |
|
|
1461 |
<dd>Extra key-value arguments that can be passed to scanpy plotting functions (scanpy.pl.XX).</dd> |
|
|
1462 |
</dl> |
|
|
1463 |
<h2 id="returns">Returns</h2> |
|
|
1464 |
<p>None.</p></div> |
|
|
1465 |
</dd> |
|
|
1466 |
<dt id="VITAE.VITAE.init_latent_space"><code class="name flex"> |
|
|
1467 |
<span>def <span class="ident">init_latent_space</span></span>(<span>self, cluster_label=None, log_pi=None, res: float = 1.0, ratio_prune=None, dist=None, dist_thres=0.5, topk=0, pilayer=False)</span> |
|
|
1468 |
</code></dt> |
|
|
1469 |
<dd> |
|
|
1470 |
<div class="desc"><p>Initialize the latent space.</p> |
|
|
1471 |
<h2 id="parameters">Parameters</h2> |
|
|
1472 |
<dl> |
|
|
1473 |
<dt><strong><code>cluster_label</code></strong> : <code>str</code>, optional</dt> |
|
|
1474 |
<dd>The name of vector of labels that can be found in self.adata.obs. |
|
|
1475 |
Default is None, which will perform leiden clustering on the pretrained z to get clusters</dd> |
|
|
1476 |
<dt><strong><code>mu</code></strong> : <code>np.array</code>, optional</dt> |
|
|
1477 |
<dd><span><span class="MathJax_Preview">[d,k]</span><script type="math/tex">[d,k]</script></span> The value of initial <span><span class="MathJax_Preview">\mu</span><script type="math/tex">\mu</script></span>.</dd> |
|
|
1478 |
<dt><strong><code>log_pi</code></strong> : <code>np.array</code>, optional</dt> |
|
|
1479 |
<dd><span><span class="MathJax_Preview">[1,K]</span><script type="math/tex">[1,K]</script></span> The value of initial <span><span class="MathJax_Preview">\log(\pi)</span><script type="math/tex">\log(\pi)</script></span>.</dd> |
|
|
1480 |
<dt><strong><code>res</code></strong></dt> |
|
|
1481 |
<dd>The resolution of leiden clustering, which is a parameter value controlling the coarseness of the clustering. |
|
|
1482 |
Higher values lead to more clusters. Default is 1.</dd> |
|
|
1483 |
<dt><strong><code>ratio_prune</code></strong> : <code>float</code>, optional</dt> |
|
|
1484 |
<dd>The ratio of edges to be removed before estimating.</dd> |
|
|
1485 |
<dt><strong><code>topk</code></strong> : <code>int</code>, optional</dt> |
|
|
1486 |
<dd>The number of top k neighbors to keep for each cluster.</dd> |
|
|
1487 |
</dl></div> |
|
|
1488 |
</dd> |
|
|
1489 |
<dt id="VITAE.VITAE.update_latent_space"><code class="name flex"> |
|
|
1490 |
<span>def <span class="ident">update_latent_space</span></span>(<span>self, dist_thres: float = 0.5)</span> |
|
|
1491 |
</code></dt> |
|
|
1492 |
<dd> |
|
|
1493 |
<div class="desc"></div> |
|
|
1494 |
</dd> |
|
|
1495 |
<dt id="VITAE.VITAE.train"><code class="name flex"> |
|
|
1496 |
<span>def <span class="ident">train</span></span>(<span>self, stratify=False, test_size=0.1, random_state: int = 0, learning_rate: float = 0.001, batch_size: int = 256, L: int = 1, alpha: float = 0.1, beta: float = 1, gamma: float = 0, phi: float = 1, num_epoch: int = 200, num_step_per_epoch: Optional[int] = None, early_stopping_patience: int = 10, early_stopping_tolerance: float = 0.01, early_stopping_relative: bool = True, early_stopping_warmup: int = 0, path_to_weights: Optional[str] = None, verbose: bool = False, **kwargs)</span> |
|
|
1497 |
</code></dt> |
|
|
1498 |
<dd> |
|
|
1499 |
<div class="desc"><p>Train the model.</p> |
|
|
1500 |
<h2 id="parameters">Parameters</h2> |
|
|
1501 |
<dl> |
|
|
1502 |
<dt><strong><code>stratify</code></strong> : <code>np.array, None,</code> or <code>False</code></dt> |
|
|
1503 |
<dd>If an array is provided, or <code>stratify=None</code> and <code>self.labels</code> is available, then they will be used to perform stratified shuffle splitting. Otherwise, general shuffle splitting is used. Set to <code>False</code> if <code>self.labels</code> is not intended for stratified shuffle splitting.</dd> |
|
|
1504 |
<dt><strong><code>test_size</code></strong> : <code>float</code> or <code>int</code>, optional</dt> |
|
|
1505 |
<dd>The proportion or size of the test set.</dd> |
|
|
1506 |
<dt><strong><code>random_state</code></strong> : <code>int</code>, optional</dt> |
|
|
1507 |
<dd>The random state for data splitting.</dd> |
|
|
1508 |
<dt><strong><code>learning_rate</code></strong> : <code>float</code>, optional</dt> |
|
|
1509 |
<dd>The initial learning rate for the Adam optimizer.</dd> |
|
|
1510 |
<dt><strong><code>batch_size</code></strong> : <code>int</code>, optional</dt> |
|
|
1511 |
<dd>The batch size for training. Default is 256. Set to 32 if number of cells is small (less than 1000)</dd> |
|
|
1512 |
<dt><strong><code>L</code></strong> : <code>int</code>, optional</dt> |
|
|
1513 |
<dd>The number of MC samples.</dd> |
|
|
1514 |
<dt><strong><code>alpha</code></strong> : <code>float</code>, optional</dt> |
|
|
1515 |
<dd>The value of alpha in [0,1] to encourage covariate adjustment. Not used if there is no covariates.</dd> |
|
|
1516 |
<dt><strong><code>beta</code></strong> : <code>float</code>, optional</dt> |
|
|
1517 |
<dd>The value of beta in beta-VAE.</dd> |
|
|
1518 |
<dt><strong><code>gamma</code></strong> : <code>float</code>, optional</dt> |
|
|
1519 |
<dd>The weight of mmd_loss.</dd> |
|
|
1520 |
<dt><strong><code>phi</code></strong> : <code>float</code>, optional</dt> |
|
|
1521 |
<dd>The weight of the Jacobian norm of the encoder.</dd> |
|
|
1522 |
<dt><strong><code>num_epoch</code></strong> : <code>int</code>, optional</dt> |
|
|
1523 |
<dd>The number of epoch.</dd> |
|
|
1524 |
<dt><strong><code>num_step_per_epoch</code></strong> : <code>int</code>, optional</dt> |
|
|
1525 |
<dd>The number of step per epoch, it will be inferred from number of cells and batch size if it is None.</dd> |
|
|
1526 |
<dt><strong><code>early_stopping_patience</code></strong> : <code>int</code>, optional</dt> |
|
|
1527 |
<dd>The maximum number of epochs if there is no improvement.</dd> |
|
|
1528 |
<dt><strong><code>early_stopping_tolerance</code></strong> : <code>float</code>, optional</dt> |
|
|
1529 |
<dd>The minimum change of loss to be considered as an improvement.</dd> |
|
|
1530 |
<dt><strong><code>early_stopping_relative</code></strong> : <code>bool</code>, optional</dt> |
|
|
1531 |
<dd>Whether monitor the relative change of loss or not.</dd> |
|
|
1532 |
<dt><strong><code>early_stopping_warmup</code></strong> : <code>int</code>, optional</dt> |
|
|
1533 |
<dd>The number of warmup epochs.</dd> |
|
|
1534 |
<dt><strong><code>path_to_weights</code></strong> : <code>str</code>, optional</dt> |
|
|
1535 |
<dd>The path of weight file to be saved; not saving weight if None.</dd> |
|
|
1536 |
<dt><strong><code>**kwargs</code></strong> : <code> </code></dt> |
|
|
1537 |
<dd>Extra key-value arguments for dimension reduction algorithms.</dd> |
|
|
1538 |
</dl></div> |
|
|
1539 |
</dd> |
|
|
1540 |
<dt id="VITAE.VITAE.output_pi"><code class="name flex"> |
|
|
1541 |
<span>def <span class="ident">output_pi</span></span>(<span>self, pi_cov)</span> |
|
|
1542 |
</code></dt> |
|
|
1543 |
<dd> |
|
|
1544 |
<div class="desc"><p>return a matrix n_states by n_states and a mask for plotting, which can be used to cover the lower triangle (except the diagonals) of a heatmap</p></div> |
|
|
1545 |
</dd> |
|
|
1546 |
<dt id="VITAE.VITAE.return_pilayer_weights"><code class="name flex"> |
|
|
1547 |
<span>def <span class="ident">return_pilayer_weights</span></span>(<span>self)</span> |
|
|
1548 |
</code></dt> |
|
|
1549 |
<dd> |
|
|
1550 |
<div class="desc"><p>return parameters of pilayer, which has dimension dim(pi_cov) + 1 by n_categories, the last row is biases</p></div> |
|
|
1551 |
</dd> |
|
|
1552 |
<dt id="VITAE.VITAE.posterior_estimation"><code class="name flex"> |
|
|
1553 |
<span>def <span class="ident">posterior_estimation</span></span>(<span>self, batch_size: int = 32, L: int = 50, **kwargs)</span> |
|
|
1554 |
</code></dt> |
|
|
1555 |
<dd> |
|
|
1556 |
<div class="desc"><p>Initialize trajectory inference by computing the posterior estimations. |
|
|
1557 |
</p> |
|
|
1558 |
<h2 id="parameters">Parameters</h2> |
|
|
1559 |
<dl> |
|
|
1560 |
<dt><strong><code>batch_size</code></strong> : <code>int</code>, optional</dt> |
|
|
1561 |
<dd>The batch size when doing inference.</dd> |
|
|
1562 |
<dt><strong><code>L</code></strong> : <code>int</code>, optional</dt> |
|
|
1563 |
<dd>The number of MC samples when doing inference.</dd> |
|
|
1564 |
<dt><strong><code>**kwargs</code></strong> : <code> </code></dt> |
|
|
1565 |
<dd>Extra key-value arguments for dimension reduction algorithms.</dd> |
|
|
1566 |
</dl></div> |
|
|
1567 |
</dd> |
|
|
1568 |
<dt id="VITAE.VITAE.infer_backbone"><code class="name flex"> |
|
|
1569 |
<span>def <span class="ident">infer_backbone</span></span>(<span>self, method: str = 'modified_map', thres=0.5, no_loop: bool = True, cutoff: float = 0, visualize: bool = True, color='vitae_new_clustering', path_to_fig=None, **kwargs)</span> |
|
|
1570 |
</code></dt> |
|
|
1571 |
<dd> |
|
|
1572 |
<div class="desc"><p>Compute edge scores.</p> |
|
|
1573 |
<h2 id="parameters">Parameters</h2> |
|
|
1574 |
<dl> |
|
|
1575 |
<dt><strong><code>method</code></strong> : <code>string</code>, optional</dt> |
|
|
1576 |
<dd>'mean', 'modified_mean', 'map', or 'modified_map'.</dd> |
|
|
1577 |
<dt><strong><code>thres</code></strong> : <code>float</code>, optional</dt> |
|
|
1578 |
<dd>The threshold used for filtering edges <span><span class="MathJax_Preview">e_{ij}</span><script type="math/tex">e_{ij}</script></span> that <span><span class="MathJax_Preview">(n_{i}+n_{j}+e_{ij})/N<thres</span><script type="math/tex">(n_{i}+n_{j}+e_{ij})/N<thres</script></span>, only applied to mean method.</dd> |
|
|
1579 |
<dt><strong><code>no_loop</code></strong> : <code>boolean</code>, optional</dt> |
|
|
1580 |
<dd>Whether loops are allowed to exist in the graph. If no_loop is true, will prune the graph to contain only the |
|
|
1581 |
maximum spanning tree</dd> |
|
|
1582 |
<dt><strong><code>cutoff</code></strong> : <code>string</code>, optional</dt> |
|
|
1583 |
<dd>The score threshold for filtering edges with scores less than cutoff.</dd> |
|
|
1584 |
<dt><strong><code>visualize</code></strong> : <code>boolean</code></dt> |
|
|
1585 |
<dd>whether plot the current trajectory backbone (undirected graph)</dd> |
|
|
1586 |
</dl> |
|
|
1587 |
<h2 id="returns">Returns</h2> |
|
|
1588 |
<dl> |
|
|
1589 |
<dt><strong><code>G</code></strong> : <code>nx.Graph</code></dt> |
|
|
1590 |
<dd>The weighted graph with weight on each edge indicating its score of existence.</dd> |
|
|
1591 |
</dl></div> |
|
|
1592 |
</dd> |
|
|
1593 |
<dt id="VITAE.VITAE.select_root"><code class="name flex"> |
|
|
1594 |
<span>def <span class="ident">select_root</span></span>(<span>self, days, method: str = 'proportion')</span> |
|
|
1595 |
</code></dt> |
|
|
1596 |
<dd> |
|
|
1597 |
<div class="desc"><p>Order the vertices/states based on cells' collection time information to select the root state. |
|
|
1598 |
</p> |
|
|
1599 |
<h2 id="parameters">Parameters</h2> |
|
|
1600 |
<dl> |
|
|
1601 |
<dt><strong><code>days</code></strong> : <code>np.array </code></dt> |
|
|
1602 |
<dd>The day information for selected cells used to determine the root vertex. |
|
|
1603 |
The dtype should be 'int' or 'float'.</dd> |
|
|
1604 |
<dt><strong><code>method</code></strong> : <code>str</code>, optional</dt> |
|
|
1605 |
<dd>'proportion' or 'mean'. |
|
|
1606 |
For 'proportion', the root is the one with maximal proportion of cells from the earliest day. |
|
|
1607 |
For 'mean', the root is the one with earliest mean time among cells associated with it.</dd> |
|
|
1608 |
</dl> |
|
|
1609 |
<h2 id="returns">Returns</h2> |
|
|
1610 |
<dl> |
|
|
1611 |
<dt><strong><code>root</code></strong> : <code>int </code></dt> |
|
|
1612 |
<dd>The root vertex in the inferred trajectory based on given day information.</dd> |
|
|
1613 |
</dl></div> |
|
|
1614 |
</dd> |
|
|
1615 |
<dt id="VITAE.VITAE.plot_backbone"><code class="name flex"> |
|
|
1616 |
<span>def <span class="ident">plot_backbone</span></span>(<span>self, directed: bool = False, method: str = 'UMAP', color='vitae_new_clustering', **kwargs)</span> |
|
|
1617 |
</code></dt> |
|
|
1618 |
<dd> |
|
|
1619 |
<div class="desc"><p>Plot the current trajectory backbone (undirected graph).</p> |
|
|
1620 |
<h2 id="parameters">Parameters</h2> |
|
|
1621 |
<dl> |
|
|
1622 |
<dt><strong><code>directed</code></strong> : <code>boolean</code>, optional</dt> |
|
|
1623 |
<dd>Whether the backbone is directed or not.</dd> |
|
|
1624 |
<dt><strong><code>method</code></strong> : <code>str</code>, optional</dt> |
|
|
1625 |
<dd>The dimension reduction method to use. The default is "UMAP".</dd> |
|
|
1626 |
<dt><strong><code>color</code></strong> : <code>str</code>, optional</dt> |
|
|
1627 |
<dd>The key for annotations of observations/cells or variables/genes, e.g., 'ann1' or ['ann1', 'ann2']. |
|
|
1628 |
The default is 'vitae_new_clustering'.</dd> |
|
|
1629 |
</dl> |
|
|
1630 |
<p>**kwargs : |
|
|
1631 |
Extra key-value arguments that can be passed to scanpy plotting functions (scanpy.pl.XX).</p></div> |
|
|
1632 |
</dd> |
|
|
1633 |
<dt id="VITAE.VITAE.plot_center"><code class="name flex"> |
|
|
1634 |
<span>def <span class="ident">plot_center</span></span>(<span>self, color='vitae_new_clustering', plot_legend=True, legend_add_index=True, method: str = 'UMAP', ncol=2, font_size='medium', add_egde=False, add_direct=False, **kwargs)</span> |
|
|
1635 |
</code></dt> |
|
|
1636 |
<dd> |
|
|
1637 |
<div class="desc"><p>Plot the center of each cluster in the latent space.</p> |
|
|
1638 |
<h2 id="parameters">Parameters</h2> |
|
|
1639 |
<dl> |
|
|
1640 |
<dt><strong><code>color</code></strong> : <code>str</code>, optional</dt> |
|
|
1641 |
<dd>The color of the center of each cluster. Default is "vitae_new_clustering".</dd> |
|
|
1642 |
<dt><strong><code>plot_legend</code></strong> : <code>bool</code>, optional</dt> |
|
|
1643 |
<dd>Whether to plot the legend. Default is True.</dd> |
|
|
1644 |
<dt><strong><code>legend_add_index</code></strong> : <code>bool</code>, optional</dt> |
|
|
1645 |
<dd>Whether to add the index of each cluster in the legend. Default is True.</dd> |
|
|
1646 |
<dt><strong><code>method</code></strong> : <code>str</code>, optional</dt> |
|
|
1647 |
<dd>The dimension reduction method used for visualization. Default is 'UMAP'.</dd> |
|
|
1648 |
<dt><strong><code>ncol</code></strong> : <code>int</code>, optional</dt> |
|
|
1649 |
<dd>The number of columns in the legend. Default is 2.</dd> |
|
|
1650 |
<dt><strong><code>font_size</code></strong> : <code>str</code>, optional</dt> |
|
|
1651 |
<dd>The font size of the legend. Default is "medium".</dd> |
|
|
1652 |
<dt><strong><code>add_egde</code></strong> : <code>bool</code>, optional</dt> |
|
|
1653 |
<dd>Whether to add the edges between the centers of clusters. Default is False.</dd> |
|
|
1654 |
<dt><strong><code>add_direct</code></strong> : <code>bool</code>, optional</dt> |
|
|
1655 |
<dd>Whether to add the direction of the edges. Default is False.</dd> |
|
|
1656 |
</dl></div> |
|
|
1657 |
</dd> |
|
|
1658 |
<dt id="VITAE.VITAE.infer_trajectory"><code class="name flex"> |
|
|
1659 |
<span>def <span class="ident">infer_trajectory</span></span>(<span>self, root: Union[int, str], digraph=None, color='pseudotime', visualize: bool = True, path_to_fig=None, **kwargs)</span> |
|
|
1660 |
</code></dt> |
|
|
1661 |
<dd> |
|
|
1662 |
<div class="desc"><p>Infer the trajectory.</p> |
|
|
1663 |
<h2 id="parameters">Parameters</h2> |
|
|
1664 |
<dl> |
|
|
1665 |
<dt><strong><code>root</code></strong> : <code>int</code> or <code>string</code></dt> |
|
|
1666 |
<dd>The root of the inferred trajectory. Can provide either an int (vertex index) or string (label name).</dd>
|
|
1667 |
<dt><strong><code>digraph</code></strong> : <code>nx.DiGraph</code>, optional</dt> |
|
|
1668 |
<dd>The directed graph to be used for trajectory inference. If None, the minimum spanning tree of the estimated trajectory backbone will be used.</dd> |
|
|
1669 |
<dt><strong><code>cutoff</code></strong> : <code>string</code>, optional</dt> |
|
|
1670 |
<dd>The threshold for filtering edges with scores less than cutoff.</dd> |
|
|
1671 |
<dt><strong><code>visualize</code></strong> : <code>boolean</code></dt> |
|
|
1672 |
<dd>Whether to plot the current trajectory backbone (directed graph).</dd>
|
|
1673 |
<dt><strong><code>path_to_fig</code></strong> : <code>string</code>, optional</dt> |
|
|
1674 |
<dd>The path to save figure, or don't save if it is None.</dd> |
|
|
1675 |
<dt><strong><code>**kwargs</code></strong> : <code>dict</code>, optional</dt> |
|
|
1676 |
<dd>Other keywords arguments for plotting.</dd> |
|
|
1677 |
</dl></div> |
|
|
1678 |
</dd> |
|
|
1679 |
<dt id="VITAE.VITAE.differential_expression_test"><code class="name flex"> |
|
|
1680 |
<span>def <span class="ident">differential_expression_test</span></span>(<span>self, alpha: float = 0.05, cell_subset=None, order: int = 1)</span> |
|
|
1681 |
</code></dt> |
|
|
1682 |
<dd> |
|
|
1683 |
<div class="desc"><p>Differential gene expression test. All (selected and unselected) genes will be tested.
|
|
1684 |
Only cells in <code>selected_cell_subset</code> will be used, which is useful when one needs to
|
|
1685 |
test differentially expressed genes on a branch of the inferred trajectory.</p> |
|
|
1686 |
<h2 id="parameters">Parameters</h2> |
|
|
1687 |
<dl> |
|
|
1688 |
<dt><strong><code>alpha</code></strong> : <code>float</code>, optional</dt> |
|
|
1689 |
<dd>The cutoff of p-values.</dd> |
|
|
1690 |
<dt><strong><code>cell_subset</code></strong> : <code>np.array</code>, optional</dt> |
|
|
1691 |
<dd>The subset of cells to be used for testing. If None, all cells will be used.</dd> |
|
|
1692 |
<dt><strong><code>order</code></strong> : <code>int</code>, optional</dt> |
|
|
1693 |
<dd>The maximum order used for pseudotime in regression.</dd>
|
|
1694 |
</dl> |
|
|
1695 |
<h2 id="returns">Returns</h2> |
|
|
1696 |
<dl> |
|
|
1697 |
<dt><strong><code>res_df</code></strong> : <code>pandas.DataFrame</code></dt> |
|
|
1698 |
<dd>The test results of expressed genes with two columns, |
|
|
1699 |
the estimated coefficients and the adjusted p-values.</dd> |
|
|
1700 |
</dl></div> |
|
|
1701 |
</dd> |
|
|
1702 |
<dt id="VITAE.VITAE.evaluate"><code class="name flex"> |
|
|
1703 |
<span>def <span class="ident">evaluate</span></span>(<span>self, milestone_net, begin_node_true, grouping=None, thres: float = 0.5, no_loop: bool = True, cutoff: Optional[float] = None, method: str = 'mean', path: Optional[str] = None)</span> |
|
|
1704 |
</code></dt> |
|
|
1705 |
<dd> |
|
|
1706 |
<div class="desc"><p>Evaluate the model.</p> |
|
|
1707 |
<h2 id="parameters">Parameters</h2> |
|
|
1708 |
<dl> |
|
|
1709 |
<dt><strong><code>milestone_net</code></strong> : <code>pd.DataFrame</code></dt> |
|
|
1710 |
<dd> |
|
|
1711 |
<p>The true milestone network. For real data, milestone_net will be a DataFrame of the graph of nodes. |
|
|
1712 |
E.g.</p>
|
|
1713 |
<table> |
|
|
1714 |
<thead> |
|
|
1715 |
<tr> |
|
|
1716 |
<th>from</th> |
|
|
1717 |
<th>to</th> |
|
|
1718 |
</tr> |
|
|
1719 |
</thead> |
|
|
1720 |
<tbody> |
|
|
1721 |
<tr> |
|
|
1722 |
<td>cluster 1</td> |
|
|
1723 |
<td>cluster 1</td> |
|
|
1724 |
</tr> |
|
|
1725 |
<tr> |
|
|
1726 |
<td>cluster 1</td> |
|
|
1727 |
<td>cluster 2</td> |
|
|
1728 |
</tr> |
|
|
1729 |
</tbody> |
|
|
1730 |
</table> |
|
|
1731 |
<p>For synthetic data, milestone_net will be a DataFrame of the (projected) |
|
|
1732 |
positions of cells. The indexes are the orders of cells in the dataset. |
|
|
1733 |
E.g.</p>
|
|
1734 |
<table> |
|
|
1735 |
<thead> |
|
|
1736 |
<tr> |
|
|
1737 |
<th>from</th> |
|
|
1738 |
<th>to</th> |
|
|
1739 |
<th>w</th> |
|
|
1740 |
</tr> |
|
|
1741 |
</thead> |
|
|
1742 |
<tbody> |
|
|
1743 |
<tr> |
|
|
1744 |
<td>cluster 1</td> |
|
|
1745 |
<td>cluster 1</td> |
|
|
1746 |
<td>1</td> |
|
|
1747 |
</tr> |
|
|
1748 |
<tr> |
|
|
1749 |
<td>cluster 1</td> |
|
|
1750 |
<td>cluster 2</td> |
|
|
1751 |
<td>0.1</td> |
|
|
1752 |
</tr> |
|
|
1753 |
</tbody> |
|
|
1754 |
</table> |
|
|
1755 |
</dd> |
|
|
1756 |
<dt><strong><code>begin_node_true</code></strong> : <code>str</code> or <code>int</code></dt> |
|
|
1757 |
<dd>The true begin node of the milestone.</dd> |
|
|
1758 |
<dt><strong><code>grouping</code></strong> : <code>np.array</code>, optional</dt> |
|
|
1759 |
<dd><span><span class="MathJax_Preview">[N,]</span><script type="math/tex">[N,]</script></span> The labels. For real data, grouping must be provided.</dd> |
|
|
1760 |
</dl> |
|
|
1761 |
<h2 id="returns">Returns</h2> |
|
|
1762 |
<dl> |
|
|
1763 |
<dt><strong><code>res</code></strong> : <code>pd.DataFrame</code></dt> |
|
|
1764 |
<dd>The evaluation result.</dd> |
|
|
1765 |
</dl></div> |
|
|
1766 |
</dd> |
|
|
1767 |
<dt id="VITAE.VITAE.save_model"><code class="name flex"> |
|
|
1768 |
<span>def <span class="ident">save_model</span></span>(<span>self, path_to_file: str = 'model.checkpoint', save_adata: bool = False)</span> |
|
|
1769 |
</code></dt> |
|
|
1770 |
<dd> |
|
|
1771 |
<div class="desc"><p>Saving model weights.</p> |
|
|
1772 |
<h2 id="parameters">Parameters</h2> |
|
|
1773 |
<dl> |
|
|
1774 |
<dt><strong><code>path_to_file</code></strong> : <code>str</code>, optional</dt> |
|
|
1775 |
<dd>The path to weight files of the pre-trained or trained model.</dd>
|
|
1776 |
<dt><strong><code>save_adata</code></strong> : <code>boolean</code>, optional</dt> |
|
|
1777 |
<dd>Whether to save adata or not.</dd> |
|
|
1778 |
</dl></div> |
|
|
1779 |
</dd> |
|
|
1780 |
<dt id="VITAE.VITAE.load_model"><code class="name flex"> |
|
|
1781 |
<span>def <span class="ident">load_model</span></span>(<span>self, path_to_file: str = 'model.checkpoint', load_labels: bool = False, load_adata: bool = False)</span> |
|
|
1782 |
</code></dt> |
|
|
1783 |
<dd> |
|
|
1784 |
<div class="desc"><p>Load model weights.</p> |
|
|
1785 |
<h2 id="parameters">Parameters</h2> |
|
|
1786 |
<dl> |
|
|
1787 |
<dt><strong><code>path_to_file</code></strong> : <code>str</code>, optional</dt> |
|
|
1788 |
<dd>The path to weight files of the pre-trained or trained model.</dd>
|
|
1789 |
<dt><strong><code>load_labels</code></strong> : <code>boolean</code>, optional</dt> |
|
|
1790 |
<dd>Whether to load clustering labels or not. |
|
|
1791 |
If load_labels is True, then the LatentSpace layer will be initialized based on the model.
|
|
1792 |
If load_labels is False, then the LatentSpace layer will not be initialized.</dd> |
|
|
1793 |
<dt><strong><code>load_adata</code></strong> : <code>boolean</code>, optional</dt> |
|
|
1794 |
<dd>Whether to load adata or not.</dd> |
|
|
1795 |
</dl></div> |
|
|
1796 |
</dd> |
|
|
1797 |
</dl> |
|
|
1798 |
</dd> |
|
|
1799 |
</dl> |
|
|
1800 |
</section> |
|
|
1801 |
</article> |
|
|
1802 |
<nav id="sidebar"> |
|
|
1803 |
<div class="toc"> |
|
|
1804 |
<ul></ul> |
|
|
1805 |
</div> |
|
|
1806 |
<ul id="index"> |
|
|
1807 |
<li><h3><a href="#header-submodules">Sub-modules</a></h3> |
|
|
1808 |
<ul> |
|
|
1809 |
<li><code><a title="VITAE.inference" href="inference.html">VITAE.inference</a></code></li> |
|
|
1810 |
<li><code><a title="VITAE.metric" href="metric.html">VITAE.metric</a></code></li> |
|
|
1811 |
<li><code><a title="VITAE.model" href="model.html">VITAE.model</a></code></li> |
|
|
1812 |
<li><code><a title="VITAE.train" href="train.html">VITAE.train</a></code></li> |
|
|
1813 |
<li><code><a title="VITAE.utils" href="utils.html">VITAE.utils</a></code></li> |
|
|
1814 |
</ul> |
|
|
1815 |
</li> |
|
|
1816 |
<li><h3><a href="#header-classes">Classes</a></h3> |
|
|
1817 |
<ul> |
|
|
1818 |
<li> |
|
|
1819 |
<h4><code><a title="VITAE.VITAE" href="#VITAE.VITAE">VITAE</a></code></h4> |
|
|
1820 |
<ul class=""> |
|
|
1821 |
<li><code><a title="VITAE.VITAE.pre_train" href="#VITAE.VITAE.pre_train">pre_train</a></code></li> |
|
|
1822 |
<li><code><a title="VITAE.VITAE.update_z" href="#VITAE.VITAE.update_z">update_z</a></code></li> |
|
|
1823 |
<li><code><a title="VITAE.VITAE.get_latent_z" href="#VITAE.VITAE.get_latent_z">get_latent_z</a></code></li> |
|
|
1824 |
<li><code><a title="VITAE.VITAE.visualize_latent" href="#VITAE.VITAE.visualize_latent">visualize_latent</a></code></li> |
|
|
1825 |
<li><code><a title="VITAE.VITAE.init_latent_space" href="#VITAE.VITAE.init_latent_space">init_latent_space</a></code></li> |
|
|
1826 |
<li><code><a title="VITAE.VITAE.update_latent_space" href="#VITAE.VITAE.update_latent_space">update_latent_space</a></code></li> |
|
|
1827 |
<li><code><a title="VITAE.VITAE.train" href="#VITAE.VITAE.train">train</a></code></li> |
|
|
1828 |
<li><code><a title="VITAE.VITAE.output_pi" href="#VITAE.VITAE.output_pi">output_pi</a></code></li> |
|
|
1829 |
<li><code><a title="VITAE.VITAE.return_pilayer_weights" href="#VITAE.VITAE.return_pilayer_weights">return_pilayer_weights</a></code></li> |
|
|
1830 |
<li><code><a title="VITAE.VITAE.posterior_estimation" href="#VITAE.VITAE.posterior_estimation">posterior_estimation</a></code></li> |
|
|
1831 |
<li><code><a title="VITAE.VITAE.infer_backbone" href="#VITAE.VITAE.infer_backbone">infer_backbone</a></code></li> |
|
|
1832 |
<li><code><a title="VITAE.VITAE.select_root" href="#VITAE.VITAE.select_root">select_root</a></code></li> |
|
|
1833 |
<li><code><a title="VITAE.VITAE.plot_backbone" href="#VITAE.VITAE.plot_backbone">plot_backbone</a></code></li> |
|
|
1834 |
<li><code><a title="VITAE.VITAE.plot_center" href="#VITAE.VITAE.plot_center">plot_center</a></code></li> |
|
|
1835 |
<li><code><a title="VITAE.VITAE.infer_trajectory" href="#VITAE.VITAE.infer_trajectory">infer_trajectory</a></code></li> |
|
|
1836 |
<li><code><a title="VITAE.VITAE.differential_expression_test" href="#VITAE.VITAE.differential_expression_test">differential_expression_test</a></code></li> |
|
|
1837 |
<li><code><a title="VITAE.VITAE.evaluate" href="#VITAE.VITAE.evaluate">evaluate</a></code></li> |
|
|
1838 |
<li><code><a title="VITAE.VITAE.save_model" href="#VITAE.VITAE.save_model">save_model</a></code></li> |
|
|
1839 |
<li><code><a title="VITAE.VITAE.load_model" href="#VITAE.VITAE.load_model">load_model</a></code></li> |
|
|
1840 |
</ul> |
|
|
1841 |
</li> |
|
|
1842 |
</ul> |
|
|
1843 |
</li> |
|
|
1844 |
</ul> |
|
|
1845 |
</nav> |
|
|
1846 |
</main> |
|
|
1847 |
<footer id="footer"> |
|
|
1848 |
<p>Generated by <a href="https://pdoc3.github.io/pdoc" title="pdoc: Python API documentation generator"><cite>pdoc</cite> 0.11.1</a>.</p> |
|
|
1849 |
</footer> |
|
|
1850 |
</body> |
|
|
1851 |
</html> |