VITAE/VITAE.py
from typing import Optional, Union
import warnings
import os

import numpy as np
import pandas as pd
from scipy import stats

import VITAE.model as model
import VITAE.train as train
from VITAE.inference import Inferer
from VITAE.utils import get_igraph, leidenalg_igraph, \
    DE_test, _comp_dist, _get_smooth_curve
from VITAE.metric import topology, get_GRI
import tensorflow as tf

from sklearn.metrics.cluster import adjusted_rand_score
from sklearn.model_selection import train_test_split
from sklearn.cluster import AgglomerativeClustering
from sklearn.preprocessing import OneHotEncoder, StandardScaler, OrdinalEncoder

import scanpy as sc
import networkx as nx
import matplotlib.pyplot as plt
import matplotlib.patheffects as pe

class VITAE():
    """
    Variational Inference for Trajectory by AutoEncoder.
    """
    def __init__(self, adata: sc.AnnData,
                 covariates = None, pi_covariates = None,
                 model_type: str = 'Gaussian',
                 npc: int = 64,
                 adata_layer_counts = None,
                 copy_adata: bool = False,
                 hidden_layers = [32],
                 latent_space_dim: int = 16,
                 conditions = None):
        '''
        Get input data for the model. The data need to be processed with scanpy first and stored as an AnnData object.
        The 'UMI' and 'non-UMI' models require the original count matrix, so the count matrix needs to be saved in
        adata.layers in order to use these models.

        Parameters
        ----------
        adata : sc.AnnData
            The scanpy AnnData object. adata should already contain adata.var.highly_variable.
        covariates : list, optional
            A list of names of covariate vectors that are stored in adata.obs.
        pi_covariates : list, optional
            A list of names of covariate vectors used as input for the pilayer.
        model_type : str, optional
            'UMI', 'non-UMI' or 'Gaussian'; the default is 'Gaussian'.
        npc : int, optional
            The number of PCs to use when model_type is 'Gaussian'. The default is 64.
        adata_layer_counts : str, optional
            The key of adata.layers that stores the count data if model_type is
            'UMI' or 'non-UMI'.
        copy_adata : bool, optional
            Set to True if we don't want VITAE to modify the original adata. If set to True, self.adata will be an independent copy of the original adata.
        hidden_layers : list, optional
            The list of dimensions of the autoencoder layers between the latent space and the original space. The default is a single hidden layer with 32 nodes.
        latent_space_dim : int, optional
            The dimension of the latent space.
        conditions : str or list, optional
            The conditions of different cells.

        Returns
        -------
        None.

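        Examples
        --------
        A minimal sketch of constructing the model; the import path and the
        data file name are assumptions for illustration, and `adata` must
        already be preprocessed with scanpy::

            import scanpy as sc
            from VITAE import VITAE

            adata = sc.read_h5ad('data.h5ad')
            sc.pp.highly_variable_genes(adata)
            model = VITAE(adata, model_type='Gaussian', npc=64)
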
        '''
        self.dict_method_scname = {
            'PCA' : 'X_pca',
            'UMAP' : 'X_umap',
            'TSNE' : 'X_tsne',
            'diffmap' : 'X_diffmap',
            'draw_graph' : 'X_draw_graph_fa'
        }

        if model_type != 'Gaussian':
            if adata_layer_counts is None:
                raise ValueError("need to provide the name in adata.layers that stores the raw count data")
            if 'highly_variable' not in adata.var:
                raise ValueError("need to first select highly variable genes using scanpy")

        self.model_type = model_type

        if copy_adata:
            self.adata = adata.copy()
        else:
            self.adata = adata

        if covariates is not None:
            if isinstance(covariates, str):
                covariates = [covariates]
            covariates = np.array(covariates)
            id_cat = (adata.obs[covariates].dtypes == 'category')
            # add OneHotEncoder & StandardScaler as class variable if needed
            if np.sum(id_cat)>0:
                covariates_cat = OneHotEncoder(drop='if_binary', handle_unknown='ignore'
                    ).fit_transform(adata.obs[covariates[id_cat]]).toarray()
            else:
                covariates_cat = np.array([]).reshape(adata.shape[0],0)

            # temporarily disable StandardScaler
            if np.sum(~id_cat)>0:
                #covariates_con = StandardScaler().fit_transform(adata.obs[covariates[~id_cat]])
                covariates_con = adata.obs[covariates[~id_cat]]
            else:
                covariates_con = np.array([]).reshape(adata.shape[0],0)

            self.covariates = np.c_[covariates_cat, covariates_con].astype(tf.keras.backend.floatx())
        else:
            self.covariates = None

        if conditions is not None:
            ## observations with np.nan will not participate in calculating mmd_loss
            if isinstance(conditions, str):
                conditions = [conditions]
            conditions = np.array(conditions)
            if np.any(adata.obs[conditions].dtypes != 'category'):
                raise ValueError("Conditions should all be categorical.")

            self.conditions = OrdinalEncoder(dtype=int, encoded_missing_value=-1).fit_transform(adata.obs[conditions]) + int(1)
        else:
            self.conditions = None

        if pi_covariates is not None:
            self.pi_cov = adata.obs[pi_covariates].to_numpy()
            if self.pi_cov.ndim == 1:
                self.pi_cov = self.pi_cov.reshape(-1, 1)
            self.pi_cov = self.pi_cov.astype(tf.keras.backend.floatx())
        else:
            self.pi_cov = np.zeros((adata.shape[0],1), dtype=tf.keras.backend.floatx())

        self.model_type = model_type
        self._adata = sc.AnnData(X = self.adata.X, var = self.adata.var)
        self._adata.obs = self.adata.obs
        self._adata.uns = self.adata.uns


        if model_type == 'Gaussian':
            sc.tl.pca(adata, n_comps = npc)
            self.X_input = self.X_output = adata.obsm['X_pca']
            self.scale_factor = np.ones(self.X_output.shape[0])
        else:
            print(f"{adata.var.highly_variable.sum()} highly variable genes selected as input")
            self.X_input = adata.X[:, adata.var.highly_variable]
            self.X_output = adata.layers[adata_layer_counts][:, adata.var.highly_variable]
            self.scale_factor = np.sum(self.X_output, axis=1, keepdims=True)/1e4

        self.dimensions = hidden_layers
        self.dim_latent = latent_space_dim

        self.vae = model.VariationalAutoEncoder(
            self.X_output.shape[1], self.dimensions,
            self.dim_latent, self.model_type,
            False if self.covariates is None else True,
        )

        if hasattr(self, 'inferer'):
            delattr(self, 'inferer')


    def pre_train(self, test_size = 0.1, random_state: int = 0,
                  learning_rate: float = 1e-3, batch_size: int = 256, L: int = 1, alpha: float = 0.10, gamma: float = 0,
                  phi: float = 1, num_epoch: int = 200, num_step_per_epoch: Optional[int] = None,
                  early_stopping_patience: int = 10, early_stopping_tolerance: float = 0.01,
                  early_stopping_relative: bool = True, verbose: bool = False, path_to_weights: Optional[str] = None):
        '''Pretrain the model with the specified learning rate.

        Parameters
        ----------
        test_size : float or int, optional
            The proportion or size of the test set.
        random_state : int, optional
            The random state for data splitting.
        learning_rate : float, optional
            The initial learning rate for the Adam optimizer.
        batch_size : int, optional
            The batch size for pre-training. Default is 256. Set to 32 if the number of cells is small (less than 1000).
        L : int, optional
            The number of MC samples.
        alpha : float, optional
            The value of alpha in [0,1] to encourage covariate adjustment. Not used if there are no covariates.
        gamma : float, optional
            The weight of the MMD loss if used.
        phi : float, optional
            The weight of the Jacobian norm of the encoder.
        num_epoch : int, optional
            The maximum number of epochs.
        num_step_per_epoch : int, optional
            The number of steps per epoch; inferred from the number of cells and the batch size if None.
        early_stopping_patience : int, optional
            The maximum number of epochs if there is no improvement.
        early_stopping_tolerance : float, optional
            The minimum change of loss to be considered as an improvement.
        early_stopping_relative : bool, optional
            Whether to monitor the relative change of loss as the stopping criterion or not.
        verbose : bool, optional
            Whether to print the training progress.
        path_to_weights : str, optional
            The path of the weight file to be saved; weights are not saved if None.

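        Examples
        --------
        A minimal sketch, assuming `model` is a `VITAE` instance built from a
        preprocessed AnnData; the argument values below are illustrative::

            model.pre_train(learning_rate=1e-3, batch_size=256, num_epoch=200)
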
        '''

        id_train, id_test = train_test_split(
                                np.arange(self.X_input.shape[0]),
                                test_size=test_size,
                                random_state=random_state)
        if num_step_per_epoch is None:
            num_step_per_epoch = len(id_train)//batch_size+1
        self.train_dataset = train.warp_dataset(self.X_input[id_train].astype(tf.keras.backend.floatx()),
                                                None if self.covariates is None else self.covariates[id_train].astype(tf.keras.backend.floatx()),
                                                batch_size,
                                                self.X_output[id_train].astype(tf.keras.backend.floatx()),
                                                self.scale_factor[id_train].astype(tf.keras.backend.floatx()),
                                                conditions = None if self.conditions is None else self.conditions[id_train].astype(tf.keras.backend.floatx()))
        self.test_dataset = train.warp_dataset(self.X_input[id_test],
                                               None if self.covariates is None else self.covariates[id_test].astype(tf.keras.backend.floatx()),
                                               batch_size,
                                               self.X_output[id_test].astype(tf.keras.backend.floatx()),
                                               self.scale_factor[id_test].astype(tf.keras.backend.floatx()),
                                               conditions = None if self.conditions is None else self.conditions[id_test].astype(tf.keras.backend.floatx()))

        self.vae = train.pre_train(
            self.train_dataset,
            self.test_dataset,
            self.vae,
            learning_rate,
            L, alpha, gamma, phi,
            num_epoch,
            num_step_per_epoch,
            early_stopping_patience,
            early_stopping_tolerance,
            early_stopping_relative,
            verbose)

        self.update_z()

        if path_to_weights is not None:
            self.save_model(path_to_weights)


    def update_z(self):
        self.z = self.get_latent_z()
        self._adata_z = sc.AnnData(self.z)
        sc.pp.neighbors(self._adata_z)


    def get_latent_z(self):
        '''Get the posterior mean of the current latent space z (encoder output).

        Returns
        ----------
        z : np.array
            \([N,d]\) The latent means.
        '''
        c = None if self.covariates is None else self.covariates
        return self.vae.get_z(self.X_input, c)


    def visualize_latent(self, method: str = "UMAP",
                         color = None, **kwargs):
        '''
        Visualize the current latent space z using the scanpy visualization tools.

        Parameters
        ----------
        method : str, optional
            Visualization method to use. The default is "UMAP". Possible choices include "PCA", "UMAP",
            "diffmap", "TSNE" and "draw_graph" (the FA plot).
        color : str or list of str, optional
            Keys for annotations of observations/cells or variables/genes, e.g., 'ann1' or ['ann1', 'ann2'].
            The default is None. Same as scanpy.
        **kwargs :
            Extra key-value arguments that can be passed to scanpy plotting functions (scanpy.pl.XX).

        Returns
        -------
        axes
            The axes object returned by the corresponding scanpy plotting function.

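        Examples
        --------
        A minimal sketch, assuming the model has been pre-trained so that the
        latent space exists; 'celltype' is a hypothetical adata.obs column::

            model.visualize_latent(method='UMAP', color='celltype')
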
        '''

        if method not in ['PCA', 'UMAP', 'TSNE', 'diffmap', 'draw_graph']:
            raise ValueError("visualization method should be one of 'PCA', 'UMAP', 'TSNE', 'diffmap' and 'draw_graph'")

        temp = list(self._adata_z.obsm.keys())
        if method == 'PCA' and not 'X_pca' in temp:
            print("Calculate PCs ...")
            sc.tl.pca(self._adata_z)
        elif method == 'UMAP' and not 'X_umap' in temp:
            print("Calculate UMAP ...")
            sc.tl.umap(self._adata_z)
        elif method == 'TSNE' and not 'X_tsne' in temp:
            print("Calculate TSNE ...")
            sc.tl.tsne(self._adata_z)
        elif method == 'diffmap' and not 'X_diffmap' in temp:
            print("Calculate diffusion map ...")
            sc.tl.diffmap(self._adata_z)
        elif method == 'draw_graph' and not 'X_draw_graph_fa' in temp:
            print("Calculate FA ...")
            sc.tl.draw_graph(self._adata_z)


        self._adata.obs = self.adata.obs.copy()
        self._adata.obsp = self._adata_z.obsp
        # self._adata.uns = self._adata_z.uns
        self._adata.obsm = self._adata_z.obsm

        if method == 'PCA':
            axes = sc.pl.pca(self._adata, color = color, **kwargs)
        elif method == 'UMAP':
            axes = sc.pl.umap(self._adata, color = color, **kwargs)
        elif method == 'TSNE':
            axes = sc.pl.tsne(self._adata, color = color, **kwargs)
        elif method == 'diffmap':
            axes = sc.pl.diffmap(self._adata, color = color, **kwargs)
        elif method == 'draw_graph':
            axes = sc.pl.draw_graph(self._adata, color = color, **kwargs)
        return axes


    def init_latent_space(self, cluster_label = None, log_pi = None, res: float = 1.0,
                          ratio_prune = None, dist = None, dist_thres = 0.5, topk = 0, pilayer = False):
        '''Initialize the latent space.

        Parameters
        ----------
        cluster_label : str, optional
            The name of the vector of labels that can be found in self.adata.obs.
            Default is None, which will perform leiden clustering on the pretrained z to get clusters.
        log_pi : np.array, optional
            \([1,K]\) The value of initial \(\\log(\\pi)\).
        res : float, optional
            The resolution of leiden clustering, which is a parameter value controlling the coarseness of the clustering.
            Higher values lead to more clusters. Default is 1.
        ratio_prune : float, optional
            The ratio of edges to be removed before estimating.
        dist : np.array, optional
            The pairwise distances between clusters; computed from the latent space if None.
        dist_thres : float, optional
            The distance threshold below which cluster centers are merged. Default is 0.5.
        topk : int, optional
            The number of top k neighbors to keep for each cluster.
        pilayer : bool, optional
            Whether to create a pilayer in the model.

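        Examples
        --------
        A minimal sketch, run after pre-training; the 'celltype' column is
        hypothetical::

            # cluster the pretrained latent space with leiden
            model.init_latent_space(res=1.0)
            # or initialize from known labels stored in adata.obs
            model.init_latent_space(cluster_label='celltype')
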
        '''


        if cluster_label is None:
            print("Perform leiden clustering on the latent space z ...")
            g = get_igraph(self.z)
            cluster_labels = leidenalg_igraph(g, res = res)
            cluster_labels = cluster_labels.astype(str)
            uni_cluster_labels = np.unique(cluster_labels)
        else:
            if isinstance(cluster_label, str):
                cluster_labels = self.adata.obs[cluster_label].to_numpy()
                uni_cluster_labels = np.array(self.adata.obs[cluster_label].cat.categories)
            else:
                ## if cluster_label is a list
                cluster_labels = cluster_label
                uni_cluster_labels = np.unique(cluster_labels)

        n_clusters = len(uni_cluster_labels)

        if not hasattr(self, 'z'):
            self.update_z()
        z = self.z
        mu = np.zeros((z.shape[1], n_clusters))
        for i,l in enumerate(uni_cluster_labels):
            mu[:,i] = np.mean(z[cluster_labels==l], axis=0)

        if dist is None:
            ### update cluster centers if some cluster centers are too close
            clustering = AgglomerativeClustering(
                n_clusters=None,
                distance_threshold=dist_thres,
                linkage='complete'
                ).fit(mu.T/np.sqrt(mu.shape[0]))
            n_clusters_new = clustering.n_clusters_
            if n_clusters_new < n_clusters:
                print("Merge clusters for cluster centers that are too close ...")
                n_clusters = n_clusters_new
                for i in range(n_clusters):
                    temp = uni_cluster_labels[clustering.labels_ == i]
                    idx = np.isin(cluster_labels, temp)
                    cluster_labels[idx] = ','.join(temp)
                    if np.sum(clustering.labels_==i)>1:
                        print('Merge %s'% ','.join(temp))
                uni_cluster_labels = np.unique(cluster_labels)
                mu = np.zeros((z.shape[1], n_clusters))
                for i,l in enumerate(uni_cluster_labels):
                    mu[:,i] = np.mean(z[cluster_labels==l], axis=0)

        self.adata.obs['vitae_init_clustering'] = cluster_labels
        self.adata.obs['vitae_init_clustering'] = self.adata.obs['vitae_init_clustering'].astype('category')
        print("Initial clustering labels saved as 'vitae_init_clustering' in self.adata.obs.")

        if (log_pi is None) and (cluster_labels is not None) and (n_clusters>3):
            n_states = int((n_clusters+1)*n_clusters/2)

            if dist is None:
                dist = _comp_dist(z, cluster_labels, mu.T)

            C = np.triu(np.ones(n_clusters))
            C[C>0] = np.arange(n_states)
            C = C + C.T - np.diag(np.diag(C))
            C = C.astype(int)

            log_pi = np.zeros((1,n_states))

            ## pruning to throw away edges for far-away clusters if there are too many clusters
            if ratio_prune is not None:
                log_pi[0, C[np.triu(dist)>np.quantile(dist[np.triu_indices(n_clusters, 1)], 1-ratio_prune)]] = - np.inf
            else:
                log_pi[0, C[np.triu(dist)>np.quantile(dist[np.triu_indices(n_clusters, 1)], 5/n_clusters) * 3]] = - np.inf

            ## also keep the top k neighbors of each cluster
            topk = max(0, min(topk, n_clusters-1)) + 1
            topk_indices = np.argsort(dist,axis=1)[:,:topk]
            for i in range(n_clusters):
                log_pi[0, C[i, topk_indices[i]]] = 0

        self.n_states = n_clusters
        self.labels = cluster_labels

        labels_map = pd.DataFrame.from_dict(
            {i:label for i,label in enumerate(uni_cluster_labels)},
            orient='index', columns=['label_names'], dtype=str
        )

        self.labels_map = labels_map
        self.vae.init_latent_space(self.n_states, mu, log_pi)
        self.inferer = Inferer(self.n_states)
        self.mu = self.vae.latent_space.mu.numpy()
        self.pi = np.triu(np.ones(self.n_states))
        self.pi[self.pi > 0] = tf.nn.softmax(self.vae.latent_space.pi).numpy()[0]

        if pilayer:
            self.vae.create_pilayer()


    def update_latent_space(self, dist_thres: float = 0.5):
        pi = self.pi[np.triu_indices(self.n_states)]
        mu = self.mu
        clustering = AgglomerativeClustering(
            n_clusters=None,
            distance_threshold=dist_thres,
            linkage='complete'
            ).fit(mu.T/np.sqrt(mu.shape[0]))
        n_clusters = clustering.n_clusters_

        if n_clusters<self.n_states:
            print("Merge clusters for cluster centers that are too close ...")
            mu_new = np.empty((self.dim_latent, n_clusters))
            C = np.zeros((self.n_states, self.n_states))
            C[np.triu_indices(self.n_states, 0)] = pi
            C = np.triu(C, 1) + C.T
            C_new = np.zeros((n_clusters, n_clusters))

            uni_cluster_labels = self.labels_map['label_names'].to_numpy()
            returned_order = {}
            cluster_labels = self.labels
            for i in range(n_clusters):
                temp = uni_cluster_labels[clustering.labels_ == i]
                idx = np.isin(cluster_labels, temp)
                cluster_labels[idx] = ','.join(temp)
                returned_order[i] = ','.join(temp)
                if np.sum(clustering.labels_==i)>1:
                    print('Merge %s'% ','.join(temp))
            uni_cluster_labels = np.unique(cluster_labels)
            # map each merged label name back to its cluster index from the merge step
            order_lookup = {label: idx for idx, label in returned_order.items()}
            for i,l in enumerate(uni_cluster_labels): ## reorder the merged clusters based on the cluster names
                k = order_lookup[l]
                mu_new[:, i] = np.mean(mu[:,clustering.labels_==k], axis=-1)
                # sum of the aggregated pi's
                C_new[i, i] = np.sum(np.triu(C[clustering.labels_==k,:][:,clustering.labels_==k]))
                for j in range(i+1, n_clusters):
                    k1 = order_lookup[uni_cluster_labels[j]]
                    C_new[i, j] = np.sum(C[clustering.labels_==k, :][:, clustering.labels_==k1])

            # labels_map_new = {}
            # for i in range(n_clusters):
            #     # update label map: int->str
            #     labels_map_new[i] = self.labels_map.loc[clustering.labels_==i, 'label_names'].str.cat(sep=',')
            #     if np.sum(clustering.labels_==i)>1:
            #         print('Merge %s'%labels_map_new[i])
            #     # mean of the aggregated cluster means
            #     mu_new[:, i] = np.mean(mu[:,clustering.labels_==i], axis=-1)
            #     # sum of the aggregated pi's
            #     C_new[i, i] = np.sum(np.triu(C[clustering.labels_==i,:][:,clustering.labels_==i]))
            #     for j in range(i+1, n_clusters):
            #         C_new[i, j] = np.sum(C[clustering.labels_== i, :][:, clustering.labels_==j])
            C_new = np.triu(C_new,1) + C_new.T

            pi_new = C_new[np.triu_indices(n_clusters)]
            log_pi_new = np.log(pi_new, out=np.ones_like(pi_new)*(-np.inf), where=(pi_new!=0)).reshape((1,-1))
            self.n_states = n_clusters
            self.labels_map = pd.DataFrame.from_dict(
                {i:label for i,label in enumerate(uni_cluster_labels)},
                orient='index', columns=['label_names'], dtype=str
            )
            self.labels = cluster_labels
            # self.labels_map = pd.DataFrame.from_dict(
            #     labels_map_new, orient='index', columns=['label_names'], dtype=str
            # )
            self.vae.init_latent_space(self.n_states, mu_new, log_pi_new)
            self.inferer = Inferer(self.n_states)
            self.mu = self.vae.latent_space.mu.numpy()
            self.pi = np.triu(np.ones(self.n_states))
            self.pi[self.pi > 0] = tf.nn.softmax(self.vae.latent_space.pi).numpy()[0]



    def train(self, stratify = False, test_size = 0.1, random_state: int = 0,
              learning_rate: float = 1e-3, batch_size: int = 256,
              L: int = 1, alpha: float = 0.10, beta: float = 1, gamma: float = 0, phi: float = 1,
              num_epoch: int = 200, num_step_per_epoch: Optional[int] = None,
              early_stopping_patience: int = 10, early_stopping_tolerance: float = 0.01,
              early_stopping_relative: bool = True, early_stopping_warmup: int = 0,
              path_to_weights: Optional[str] = None,
              verbose: bool = False, **kwargs):
        '''Train the model.

        Parameters
        ----------
        stratify : np.array, None, or False
            If an array is provided, or `stratify=None` and `self.labels` is available, then they will be used to perform stratified shuffle splitting. Otherwise, general shuffle splitting is used. Set to `False` if `self.labels` is not intended for stratified shuffle splitting.
        test_size : float or int, optional
            The proportion or size of the test set.
        random_state : int, optional
            The random state for data splitting.
        learning_rate : float, optional
            The initial learning rate for the Adam optimizer.
        batch_size : int, optional
            The batch size for training. Default is 256. Set to 32 if the number of cells is small (less than 1000).
        L : int, optional
            The number of MC samples.
        alpha : float, optional
            The value of alpha in [0,1] to encourage covariate adjustment. Not used if there are no covariates.
        beta : float, optional
            The value of beta in beta-VAE.
        gamma : float, optional
            The weight of the MMD loss.
        phi : float, optional
            The weight of the Jacobian norm of the encoder.
        num_epoch : int, optional
            The maximum number of epochs.
        num_step_per_epoch : int, optional
            The number of steps per epoch; inferred from the number of cells and the batch size if None.
        early_stopping_patience : int, optional
            The maximum number of epochs if there is no improvement.
        early_stopping_tolerance : float, optional
            The minimum change of loss to be considered as an improvement.
        early_stopping_relative : bool, optional
            Whether to monitor the relative change of loss or not.
        early_stopping_warmup : int, optional
            The number of warmup epochs.
        path_to_weights : str, optional
            The path of the weight file to be saved; weights are not saved if None.
        **kwargs :
            Extra key-value arguments for dimension reduction algorithms.

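        Examples
        --------
        A minimal sketch, run after `init_latent_space`; the argument values
        are illustrative::

            model.train(beta=1, num_epoch=200, early_stopping_patience=10)
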
        '''
        if gamma == 0 or self.conditions is None:
            conditions = np.array([np.nan] * self.adata.shape[0])
        else:
            conditions = self.conditions

        if stratify is None:
            stratify = self.labels
        elif stratify is False:
            stratify = None
        id_train, id_test = train_test_split(
                                np.arange(self.X_input.shape[0]),
                                test_size=test_size,
                                stratify=stratify,
                                random_state=random_state)
        if num_step_per_epoch is None:
            num_step_per_epoch = len(id_train)//batch_size+1
        c = None if self.covariates is None else self.covariates.astype(tf.keras.backend.floatx())
        self.train_dataset = train.warp_dataset(self.X_input[id_train].astype(tf.keras.backend.floatx()),
                                                None if c is None else c[id_train],
                                                batch_size,
                                                self.X_output[id_train].astype(tf.keras.backend.floatx()),
                                                self.scale_factor[id_train].astype(tf.keras.backend.floatx()),
                                                conditions = conditions[id_train],
                                                pi_cov = self.pi_cov[id_train])
        self.test_dataset = train.warp_dataset(self.X_input[id_test].astype(tf.keras.backend.floatx()),
                                               None if c is None else c[id_test],
                                               batch_size,
                                               self.X_output[id_test].astype(tf.keras.backend.floatx()),
                                               self.scale_factor[id_test].astype(tf.keras.backend.floatx()),
                                               conditions = conditions[id_test],
                                               pi_cov = self.pi_cov[id_test])

        self.vae = train.train(
            self.train_dataset,
            self.test_dataset,
            self.vae,
            learning_rate,
            L,
            alpha,
            beta,
            gamma,
            phi,
            num_epoch,
            num_step_per_epoch,
            early_stopping_patience,
            early_stopping_tolerance,
            early_stopping_relative,
            early_stopping_warmup,
            verbose,
            **kwargs
        )

        self.update_z()
        self.mu = self.vae.latent_space.mu.numpy()
        self.pi = np.triu(np.ones(self.n_states))
        self.pi[self.pi > 0] = tf.nn.softmax(self.vae.latent_space.pi).numpy()[0]

        if path_to_weights is not None:
            self.save_model(path_to_weights)


    def output_pi(self, pi_cov):
        """Return an n_states-by-n_states matrix of edge probabilities and a mask
        for plotting; the mask can be used to cover the lower triangle (except the
        diagonal) of a heatmap.
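
        A minimal sketch, assuming a pilayer was created via
        init_latent_space(..., pilayer=True) and the model was trained; the
        covariate value 0.0 is illustrative::

            matrix, mask = model.output_pi(0.0)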
        """
        p = self.vae.pilayer
        pi_cov = tf.expand_dims(tf.constant([pi_cov], dtype=tf.float32), 0)
        pi_val = tf.nn.softmax(p(pi_cov)).numpy()[0]
        # Create heatmap matrix
        n = self.vae.n_states
        matrix = np.zeros((n, n))
        matrix[np.triu_indices(n)] = pi_val
        mask = np.tril(np.ones_like(matrix), k=-1)
        return matrix, mask


    def return_pilayer_weights(self):
        """Return parameters of the pilayer, an array of dimension (dim(pi_cov) + 1) by n_categories, whose last row holds the biases."""
        # stack the pilayer kernel weights and its biases (biases in the last row)
        return np.vstack((self.vae.pilayer.weights[0].numpy(), self.vae.pilayer.weights[1].numpy().reshape(1, -1)))


    def posterior_estimation(self, batch_size: int = 32, L: int = 50, **kwargs):
        '''Initialize trajectory inference by computing the posterior estimations.

        Parameters
        ----------
        batch_size : int, optional
            The batch size when doing inference.
        L : int, optional
            The number of MC samples when doing inference.
        **kwargs :
            Extra key-value arguments for dimension reduction algorithms.

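        Examples
        --------
        A minimal sketch, run after training::

            model.posterior_estimation(batch_size=32, L=50)
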
        '''
        c = None if self.covariates is None else self.covariates.astype(tf.keras.backend.floatx())
        self.test_dataset = train.warp_dataset(self.X_input.astype(tf.keras.backend.floatx()),
                                               c,
                                               batch_size)
        _, _, self.pc_x,\
            self.cell_position_posterior, self.cell_position_variance, _ = self.vae.inference(self.test_dataset, L=L)

        uni_cluster_labels = self.labels_map['label_names'].to_numpy()
        self.adata.obs['vitae_new_clustering'] = uni_cluster_labels[np.argmax(self.cell_position_posterior, 1)]
        self.adata.obs['vitae_new_clustering'] = self.adata.obs['vitae_new_clustering'].astype('category')
        print("New clustering labels saved as 'vitae_new_clustering' in self.adata.obs.")
        return None


    def infer_backbone(self, method: str = 'modified_map', thres = 0.5,
                       no_loop: bool = True, cutoff: float = 0,
                       visualize: bool = True, color = 'vitae_new_clustering', path_to_fig = None, **kwargs):
        ''' Compute edge scores.

        Parameters
        ----------
        method : string, optional
            'mean', 'modified_mean', 'map', or 'modified_map'.
        thres : float, optional
            The threshold used for filtering edges \(e_{ij}\) with \((n_{i}+n_{j}+e_{ij})/N<thres\); only applied to the mean method.
        no_loop : boolean, optional
            Whether loops are allowed to exist in the graph. If no_loop is true, the graph will be pruned to contain only the
            maximum spanning tree.
        cutoff : float, optional
            The score threshold for filtering edges with scores less than cutoff.
        visualize : boolean, optional
            Whether to plot the current trajectory backbone (undirected graph).
        color : str, optional
            The key in adata.obs used to color cells in the plot. Default is 'vitae_new_clustering'.
        path_to_fig : str, optional
            The path to save the figure; the figure is not saved if None.

        Returns
        ----------
        None. The weighted backbone graph, with the weight on each edge indicating
        its score of existence, is stored as self.backbone.

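        Examples
        --------
        A minimal sketch, run after `posterior_estimation`::

            model.infer_backbone(method='modified_map', no_loop=True)
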
        '''
        # build_graph, return graph
        self.backbone = self.inferer.build_graphs(self.cell_position_posterior, self.pc_x,
                                                  method, thres, no_loop, cutoff)
        self.cell_position_projected = self.inferer.modify_wtilde(self.cell_position_posterior,
                                                                  np.array(list(self.backbone.edges)))

        uni_cluster_labels = self.labels_map['label_names'].to_numpy()
        temp_dict = {i:label for i,label in enumerate(uni_cluster_labels)}
        # nx.relabel_nodes returns a relabeled copy; self.backbone itself keeps
        # its integer node ids, which the plotting code relies on
        nx.relabel_nodes(self.backbone, temp_dict)

        self.adata.obs['vitae_new_clustering'] = uni_cluster_labels[np.argmax(self.cell_position_projected, 1)]
        self.adata.obs['vitae_new_clustering'] = self.adata.obs['vitae_new_clustering'].astype('category')
        print("'vitae_new_clustering' updated based on the projected cell positions.")

        self.uncertainty = np.sum((self.cell_position_projected - self.cell_position_posterior)**2, axis=-1) \
                           + np.sum(self.cell_position_variance, axis=-1)
        self.adata.obs['projection_uncertainty'] = self.uncertainty
        print("Cell projection uncertainties stored as 'projection_uncertainty' in self.adata.obs")
        if visualize:
            self._adata.obs = self.adata.obs.copy()
            self.ax = self.plot_backbone(directed = False, color = color, **kwargs)
            if path_to_fig is not None:
                self.ax.figure.savefig(path_to_fig)
            self.ax.figure.show()
        return None


    def select_root(self, days, method: str = 'proportion'):
        '''Order the vertices/states based on cells' collection time information to select the root state.

        Parameters
        ----------
        days : np.array
            The day information for the selected cells used to determine the root vertex.
            The dtype should be 'int' or 'float'.
        method : str, optional
            'proportion' or 'mean'.
            For 'proportion', the root is the state with the maximal proportion of cells from the earliest day.
            For 'mean', the root is the state with the earliest mean time among cells associated with it.

        Returns
        ----------
        root_info : pd.DataFrame
            The states sorted by mean collection time, along with the proportion of
            earliest-day cells in each state, for choosing the root vertex.

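        Examples
        --------
        A minimal sketch, assuming collection days are stored in a hypothetical
        adata.obs column 'day'::

            root_info = model.select_root(model.adata.obs['day'].to_numpy())
            print(root_info)
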
        '''
        if days is not None and len(days)!=self.X_input.shape[0]:
            raise ValueError("The length of day information ({}) is not "
                             "consistent with the number of selected cells ({})!".format(
                             len(days), self.X_input.shape[0]))
        if not hasattr(self, 'cell_position_projected'):
            raise ValueError("Need to call 'infer_backbone' first!")

        collection_time = np.dot(days, self.cell_position_projected)/np.sum(self.cell_position_projected, axis = 0)
        earliest_prop = np.dot(days==np.min(days), self.cell_position_projected)/np.sum(self.cell_position_projected, axis = 0)

        root_info = self.labels_map.copy()
        root_info['mean_collection_time'] = collection_time
        root_info['earliest_time_prop'] = earliest_prop
        root_info.sort_values('mean_collection_time', inplace=True)
        return root_info


    def plot_backbone(self, directed: bool = False,
                      method: str = 'UMAP', color = 'vitae_new_clustering', **kwargs):
        '''Plot the current trajectory backbone (undirected or directed graph).

        Parameters
        ----------
        directed : boolean, optional
            Whether the backbone is directed or not.
        method : str, optional
            The dimension reduction method to use. The default is "UMAP".
        color : str, optional
            The key for annotations of observations/cells or variables/genes, e.g., 'ann1'.
            The default is 'vitae_new_clustering'.
        **kwargs :
            Extra key-value arguments that can be passed to scanpy plotting functions (scanpy.pl.XX).

        Returns
        ----------
        ax : matplotlib.axes.Axes
            The plotting axes.

        '''
        if not isinstance(color, str):
            raise ValueError('The color argument should be of type str!')
        ax = self.visualize_latent(method = method, color = color, show = False, **kwargs)
        dict_label_num = {j:i for i,j in self.labels_map['label_names'].to_dict().items()}
        uni_cluster_labels = self.adata.obs['vitae_init_clustering'].cat.categories
        cluster_labels = self.adata.obs['vitae_new_clustering'].to_numpy()
        embed_z = self._adata.obsm[self.dict_method_scname[method]]
        embed_mu = np.zeros((len(uni_cluster_labels), 2))
        for l in uni_cluster_labels:
            embed_mu[dict_label_num[l],:] = np.mean(embed_z[cluster_labels==l], axis=0)

        if directed:
            graph = self.directed_backbone
        else:
            graph = self.backbone
        edges = list(graph.edges)
        edge_scores = np.array([d['weight'] for (u,v,d) in graph.edges(data=True)])
        if max(edge_scores) - min(edge_scores) == 0:
            edge_scores = edge_scores/max(edge_scores)
        else:
            edge_scores = (edge_scores - min(edge_scores))/(max(edge_scores) - min(edge_scores))*3

        value_range = np.maximum(np.diff(ax.get_xlim())[0], np.diff(ax.get_ylim())[0])
        y_range = np.min(embed_z[:,1]), np.max(embed_z[:,1], axis=0)
        for i in range(len(edges)):
            points = embed_z[np.sum(self.cell_position_projected[:, edges[i]]>0, axis=-1)==2,:]
            points = points[points[:,0].argsort()]
            try:
                x_smooth, y_smooth = _get_smooth_curve(
                    points,
                    embed_mu[edges[i], :],
                    y_range
                )
            except:
                x_smooth, y_smooth = embed_mu[edges[i], 0], embed_mu[edges[i], 1]
            ax.plot(x_smooth, y_smooth,
                    '-',
                    linewidth= 1 + edge_scores[i],
                    color="black",
                    alpha=0.8,
                    path_effects=[pe.Stroke(linewidth=1+edge_scores[i]+1.5,
                                            foreground='white'), pe.Normal()],
                    zorder=1
                    )

            if directed:
                delta_x = embed_mu[edges[i][1], 0] - x_smooth[-2]
                delta_y = embed_mu[edges[i][1], 1] - y_smooth[-2]
                length = np.sqrt(delta_x**2 + delta_y**2) / 50 * value_range
                ax.arrow(
                    embed_mu[edges[i][1], 0]-delta_x/length,
                    embed_mu[edges[i][1], 1]-delta_y/length,
                    delta_x/length,
                    delta_y/length,
                    color='black', alpha=1.0,
                    shape='full', lw=0, length_includes_head=True,
                    head_width=np.maximum(0.01*(1 + edge_scores[i]), 0.03) * value_range,
                    zorder=2)

        colors = self._adata.uns['vitae_new_clustering_colors']

        for i,l in enumerate(uni_cluster_labels):
            ax.scatter(*embed_mu[dict_label_num[l]:dict_label_num[l]+1,:].T,
                       c=[colors[i]], edgecolors='white', # linewidths=10, norm=norm,
                       s=250, marker='*', label=l)

        plt.setp(ax, xticks=[], yticks=[])
        box = ax.get_position()
        ax.set_position([box.x0, box.y0 + box.height * 0.1,
                         box.width, box.height * 0.9])
        if directed:
            ax.legend(loc='upper center', bbox_to_anchor=(0.5, -0.05),
                      fancybox=True, shadow=True, ncol=5)

        return ax


    def plot_center(self, color = "vitae_new_clustering", plot_legend = True, legend_add_index = True,
                    method: str = 'UMAP', ncol = 2, font_size = "medium",
                    add_egde = False, add_direct = False, **kwargs):
        '''Plot the center of each cluster in the latent space.

        Parameters
        ----------
        color : str, optional
            The clustering whose centers are plotted. Default is "vitae_new_clustering".
        plot_legend : bool, optional
            Whether to plot the legend. Default is True.
        legend_add_index : bool, optional
            Whether to add the index of each cluster in the legend. Default is True.
        method : str, optional
            The dimension reduction method used for visualization. Default is 'UMAP'.
        ncol : int, optional
            The number of columns in the legend. Default is 2.
        font_size : str, optional
            The font size of the legend. Default is "medium".
        add_egde : bool, optional
            Whether to add the edges between the centers of clusters. Default is False.
        add_direct : bool, optional
            Whether to add the directions of the edges. Default is False.

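        Examples
        --------
        A minimal sketch, run after `infer_backbone` so that the backbone
        edges exist::

            model.plot_center(color='vitae_new_clustering', add_egde=True)
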
        '''
        if color not in ["vitae_new_clustering", "vitae_init_clustering"]:
            raise ValueError("Can only plot center of vitae_new_clustering or vitae_init_clustering")
        dict_label_num = {j: i for i, j in self.labels_map['label_names'].to_dict().items()}
        if legend_add_index:
            self._adata.obs["index_"+color] = self._adata.obs[color].map(lambda x: dict_label_num[x])
            ax = self.visualize_latent(method=method, color="index_" + color, show=False, legend_loc="on data",
                                       legend_fontsize=font_size, **kwargs)
            colors = self._adata.uns["index_" + color + '_colors']
        else:
            ax = self.visualize_latent(method=method, color = color, show=False, **kwargs)
            colors = self._adata.uns[color + '_colors']
        uni_cluster_labels = self.adata.obs[color].cat.categories
        cluster_labels = self.adata.obs[color].to_numpy()
        embed_z = self._adata.obsm[self.dict_method_scname[method]]
        embed_mu = np.zeros((len(uni_cluster_labels), 2))
        for l in uni_cluster_labels:
            embed_mu[dict_label_num[l], :] = np.mean(embed_z[cluster_labels == l], axis=0)

        leg = (self.labels_map.index.astype(str) + " : " + self.labels_map.label_names).values
        for i, l in enumerate(uni_cluster_labels):
            ax.scatter(*embed_mu[dict_label_num[l]:dict_label_num[l] + 1, :].T,
                       c=[colors[i]], edgecolors='white', # linewidths=3,
                       s=250, marker='*', label=leg[i])
        if plot_legend:
            ax.legend(loc='center left', bbox_to_anchor=(1, 0.5), ncol=ncol, markerscale=0.8, frameon=False)
        plt.setp(ax, xticks=[], yticks=[])
        box = ax.get_position()
        ax.set_position([box.x0, box.y0 + box.height * 0.1,
                         box.width, box.height * 0.9])
        if add_egde:
            if add_direct:
                graph = self.directed_backbone
            else:
                graph = self.backbone
            edges = list(graph.edges)
            edge_scores = np.array([d['weight'] for (u, v, d) in graph.edges(data=True)])
            if max(edge_scores) - min(edge_scores) == 0:
                edge_scores = edge_scores / max(edge_scores)
            else:
                edge_scores = (edge_scores - min(edge_scores)) / (max(edge_scores) - min(edge_scores)) * 3

            value_range = np.maximum(np.diff(ax.get_xlim())[0], np.diff(ax.get_ylim())[0])
            y_range = np.min(embed_z[:, 1]), np.max(embed_z[:, 1], axis=0)
            for i in range(len(edges)):
                points = embed_z[np.sum(self.cell_position_projected[:, edges[i]] > 0, axis=-1) == 2, :]
                points = points[points[:, 0].argsort()]
                try:
                    x_smooth, y_smooth = _get_smooth_curve(
                        points,
                        embed_mu[edges[i], :],
                        y_range
                    )
                except:
                    x_smooth, y_smooth = embed_mu[edges[i], 0], embed_mu[edges[i], 1]
                ax.plot(x_smooth, y_smooth,
                        '-',
                        linewidth=1 + edge_scores[i],
                        color="black",
                        alpha=0.8,
                        path_effects=[pe.Stroke(linewidth=1 + edge_scores[i] + 1.5,
                                                foreground='white'), pe.Normal()],
                        zorder=1
                        )

                if add_direct:
                    delta_x = embed_mu[edges[i][1], 0] - x_smooth[-2]
                    delta_y = embed_mu[edges[i][1], 1] - y_smooth[-2]
                    length = np.sqrt(delta_x ** 2 + delta_y ** 2) / 50 * value_range
                    ax.arrow(
                        embed_mu[edges[i][1], 0] - delta_x / length,
                        embed_mu[edges[i][1], 1] - delta_y / length,
                        delta_x / length,
                        delta_y / length,
                        color='black', alpha=1.0,
                        shape='full', lw=0, length_includes_head=True,
                        head_width=np.maximum(0.01 * (1 + edge_scores[i]), 0.03) * value_range,
                        zorder=2)
        self.ax = ax
        self.ax.figure.show()
        return None


    def infer_trajectory(self, root: Union[int,str], digraph = None, color = "pseudotime",
                         visualize: bool = True, path_to_fig = None, **kwargs):
        '''Infer the trajectory.

        Parameters
        ----------
        root : int or string
            The root of the inferred trajectory. Can be either an int (vertex index) or a string (label name).
        digraph : nx.DiGraph, optional
            The directed graph to be used for trajectory inference. If None, the minimum spanning tree of the estimated trajectory backbone will be used.
        color : str, optional
            The key used to color cells in the plot. Default is "pseudotime".
        visualize : boolean, optional
            Whether to plot the current trajectory backbone (directed graph).
        path_to_fig : string, optional
            The path to save the figure; the figure is not saved if None.
        **kwargs : dict, optional
            Other keyword arguments for plotting.

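        Examples
        --------
        A minimal sketch; the root label '0' is hypothetical and should be one
        of the label names reported by `select_root`::

            model.infer_trajectory(root='0', color='pseudotime')
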
        '''
        if isinstance(root, str):
            if root not in self.labels_map.values:
                raise ValueError("Root {} is not in the label names!".format(root))
            root = self.labels_map[self.labels_map['label_names']==root].index[0]

        if digraph is None:
            connected_comps = nx.node_connected_component(self.backbone, root)
            subG = self.backbone.subgraph(connected_comps)

            ## generate directed backbone which contains no loops
            DG = nx.DiGraph(nx.to_directed(self.backbone))
            temp = DG.subgraph(connected_comps)
            DG.remove_edges_from(temp.edges - nx.dfs_edges(DG, root))
            self.directed_backbone = DG
        else:
            if not nx.is_directed_acyclic_graph(digraph):
                raise ValueError("The graph 'digraph' should be a directed acyclic graph.")
            if set(digraph.nodes) != set(self.backbone.nodes):
                raise ValueError("The nodes in 'digraph' do not match the nodes in 'self.backbone'.")
            self.directed_backbone = digraph

            # node_connected_component requires an undirected graph
            connected_comps = nx.node_connected_component(digraph.to_undirected(), root)
            subG = self.backbone.subgraph(connected_comps)


        if len(subG.edges)>0:
            milestone_net = self.inferer.build_milestone_net(subG, root)
            if self.inferer.no_loop is False and milestone_net.shape[0]<len(self.backbone.edges):
                warnings.warn("The directed graph shown is a minimum spanning tree of the estimated trajectory backbone to avoid arbitrary assignment of the directions.")
            self.pseudotime = self.inferer.comp_pseudotime(milestone_net, root, self.cell_position_projected)
        else:
            warnings.warn("There are no states connected to the given root.")
            self.pseudotime = -np.ones(self._adata.shape[0])

        self.adata.obs['pseudotime'] = self.pseudotime
        print("Pseudotime stored as 'pseudotime' in self.adata.obs")

        if visualize:
            self._adata.obs['pseudotime'] = self.pseudotime
            self.ax = self.plot_backbone(directed = True, color = color, **kwargs)
            if path_to_fig is not None:
                self.ax.figure.savefig(path_to_fig)
            self.ax.figure.show()

        return None


    def differential_expression_test(self, alpha: float = 0.05, cell_subset = None, order: int = 1):
        '''Differential gene expression test. All (selected and unselected) genes will be tested.
        Only cells in `cell_subset` will be used, which is useful when one needs to
        test differentially expressed genes on a branch of the inferred trajectory.

        Parameters
        ----------
        alpha : float, optional
            The cutoff of p-values.
        cell_subset : np.array, optional
            The subset of cells to be used for testing. If None, all cells will be used.
        order : int, optional
            The maximum order of pseudotime used in the regression.

        Returns
        ----------
        res_df : pandas.DataFrame
            The test results of expressed genes with two columns:
            the estimated coefficients and the adjusted p-values.

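        Examples
        --------
        A minimal sketch, run after `infer_trajectory`::

            res_df = model.differential_expression_test(alpha=0.05, order=1)
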
        '''
        if not hasattr(self, 'pseudotime'):
            raise ReferenceError("Pseudotime does not exist! Please run 'infer_trajectory' first.")
        if cell_subset is None:
            cell_subset = np.arange(self.X_input.shape[0])
            print("All cells are selected.")
        if order < 1:
            raise ValueError("Maximal order of pseudotime in regression must be at least 1.")

        # Prepare X and Y for the regression expression ~ rank(PDT) + covariates
        Y = self.adata.X[cell_subset,:]
        # std_Y = np.std(Y, ddof=1, axis=0, keepdims=True)
        # Y = np.divide(Y-np.mean(Y, axis=0, keepdims=True), std_Y, out=np.empty_like(Y)*np.nan, where=std_Y!=0)
        X = stats.rankdata(self.pseudotime[cell_subset])
        rank_pdt = X.copy()
        if order > 1:
            for _order in range(2, order+1):
                # append higher-order terms of the ranked pseudotime
                X = np.c_[X, rank_pdt**_order]
        X = ((X-np.mean(X,axis=0, keepdims=True))/np.std(X, ddof=1, axis=0, keepdims=True))
        X = np.c_[np.ones((X.shape[0],1)), X]
        if self.covariates is not None:
            X = np.c_[X, self.covariates[cell_subset, :]]

        res_df = DE_test(Y, X, self.adata.var_names, i_test = np.array(list(range(1,order+1))), alpha = alpha)
        return res_df[res_df.pvalue_adjusted_1 != 0]



    def evaluate(self, milestone_net, begin_node_true, grouping = None,
                 thres: float = 0.5, no_loop: bool = True, cutoff: Optional[float] = None,
                 method: str = 'mean', path: Optional[str] = None):
        ''' Evaluate the model.

        Parameters
        ----------
        milestone_net : pd.DataFrame
            The true milestone network. For real data, milestone_net will be a DataFrame of the graph of nodes.
            E.g.

            from|to
            ---|---
            cluster 1 | cluster 1
            cluster 1 | cluster 2

            For synthetic data, milestone_net will be a DataFrame of the (projected)
            positions of cells. The indexes are the orders of cells in the dataset.
            E.g.

            from|to|w
            ---|---|---
            cluster 1 | cluster 1 | 1
            cluster 1 | cluster 2 | 0.1
        begin_node_true : str or int
            The true begin node of the milestone network.
        grouping : np.array, optional
            \([N,]\) The labels. For real data, grouping must be provided.

        Returns
        ----------
        res : pd.DataFrame
            The evaluation result.
        '''
        if not hasattr(self, 'labels_map'):
            raise ValueError("No given labels for training.")

        '''
        # Evaluate for the whole dataset will ignore selected_cell_subset.
        if len(self.selected_cell_subset)!=len(self.cell_names):
            warnings.warn("Evaluate for the whole dataset.")
        '''

        # If begin_node_true is a string, encode it via the label map.
        # This dict is for milestone_net because its labels are not merged;
        # all keys of label_map_dict are str.
        label_map_dict = dict()
        for i in range(self.labels_map.shape[0]):
            label_mapped = self.labels_map.loc[i]
            ## merged cluster indexes are connected by commas
            for each in label_mapped.values[0].split(","):
                label_map_dict[each] = i
        if isinstance(begin_node_true, str):
            begin_node_true = label_map_dict[begin_node_true]

        # For generated data, grouping information is already in milestone_net
        if 'w' in milestone_net.columns:
            grouping = None

        # If milestone_net is provided, transform its labels to be numeric.
        if milestone_net is not None:
            milestone_net['from'] = [label_map_dict[x] for x in milestone_net["from"]]
            milestone_net['to'] = [label_map_dict[x] for x in milestone_net["to"]]

        # this dict is for potentially merged clusters.
        label_map_dict_for_merged_cluster = dict(zip(self.labels_map["label_names"], self.labels_map.index))
        mapped_labels = np.array([label_map_dict_for_merged_cluster[x] for x in self.labels])
        begin_node_pred = int(np.argmin(np.mean((
            self.z[mapped_labels==begin_node_true,:,np.newaxis] -
            self.mu[np.newaxis,:,:])**2, axis=(0,1))))

        if cutoff is None:
            cutoff = 0.01

        G = self.backbone
        w = self.cell_position_projected
        pseudotime = self.pseudotime

        # 1. Topology
        G_pred = nx.Graph()
        G_pred.add_nodes_from(G.nodes)
        G_pred.add_edges_from(G.edges)
        nx.set_node_attributes(G_pred, False, 'is_init')
        G_pred.nodes[begin_node_pred]['is_init'] = True

        G_true = nx.Graph()
        G_true.add_nodes_from(G.nodes)
        # if 'grouping' is not provided, assume 'milestone_net' contains proportions
        if grouping is None:
            G_true.add_edges_from(list(
                milestone_net[~pd.isna(milestone_net['w'])].groupby(['from', 'to']).count().index))
        # otherwise, 'milestone_net' indicates edges
        else:
            if milestone_net is not None:
                G_true.add_edges_from(list(
                    milestone_net.groupby(['from', 'to']).count().index))
            grouping = [label_map_dict[x] for x in grouping]
            grouping = np.array(grouping)
        G_true.remove_edges_from(nx.selfloop_edges(G_true))
        nx.set_node_attributes(G_true, False, 'is_init')
        G_true.nodes[begin_node_true]['is_init'] = True
        res = topology(G_true, G_pred)

        # 2. Milestones assignment
        if grouping is None:
            milestones_true = milestone_net['from'].values.copy()
            milestones_true[(milestone_net['from']!=milestone_net['to'])
                            &(milestone_net['w']<0.5)] = milestone_net[(milestone_net['from']!=milestone_net['to'])
                                                                       &(milestone_net['w']<0.5)]['to'].values
        else:
            milestones_true = grouping
        milestones_pred = np.argmax(w, axis=1)
        res['ARI'] = (adjusted_rand_score(milestones_true, milestones_pred) + 1)/2

        if grouping is None:
            n_samples = len(milestone_net)
            prop = np.zeros((n_samples,n_samples))
            prop[np.arange(n_samples), milestone_net['to']] = 1-milestone_net['w']
            prop[np.arange(n_samples), milestone_net['from']] = np.where(np.isnan(milestone_net['w']), 1, milestone_net['w'])
            res['GRI'] = get_GRI(prop, w)
        else:
            res['GRI'] = get_GRI(grouping, w)

        # 3. Correlation between geodesic distances / Pseudotime
        if no_loop:
            if grouping is None:
                pseudotime_true = milestone_net['from'].values + 1 - milestone_net['w'].values
                pseudotime_true[np.isnan(pseudotime_true)] = milestone_net[pd.isna(milestone_net['w'])]['from'].values
            else:
                pseudotime_true = - np.ones(len(grouping))
                nx.set_edge_attributes(G_true, values = 1, name = 'weight')
                connected_comps = nx.node_connected_component(G_true, begin_node_true)
                subG = G_true.subgraph(connected_comps)
                milestone_net_true = self.inferer.build_milestone_net(subG, begin_node_true)
                if len(milestone_net_true)>0:
                    pseudotime_true[grouping==int(milestone_net_true[0,0])] = 0
                    for i in range(len(milestone_net_true)):
                        pseudotime_true[grouping==int(milestone_net_true[i,1])] = milestone_net_true[i,-1]
            pseudotime_true = pseudotime_true[pseudotime>-1]
            pseudotime_pred = pseudotime[pseudotime>-1]
            res['PDT score'] = (np.corrcoef(pseudotime_true,pseudotime_pred)[0,1]+1)/2
        else:
            res['PDT score'] = np.nan

        # 4. Shape
        # score_cos_theta = 0
        # for (_from,_to) in G.edges:
        #     _z = self.z[(w[:,_from]>0) & (w[:,_to]>0),:]
        #     v_1 = _z - self.mu[:,_from]
        #     v_2 = _z - self.mu[:,_to]
        #     cos_theta = np.sum(v_1*v_2, -1)/(np.linalg.norm(v_1,axis=-1)*np.linalg.norm(v_2,axis=-1)+1e-12)

        #     score_cos_theta += np.sum((1-cos_theta)/2)

        # res['score_cos_theta'] = score_cos_theta/(np.sum(np.sum(w>0, axis=-1)==2)+1e-12)
        return res


    def save_model(self, path_to_file: str = 'model.checkpoint', save_adata: bool = False):
        '''Save model weights.

        Parameters
        ----------
        path_to_file : str, optional
            The path to the weight files of the pre-trained or trained model.
        save_adata : boolean, optional
            Whether to save adata or not.

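        Examples
        --------
        A minimal sketch; the checkpoint path is illustrative::

            model.save_model('model.checkpoint', save_adata=True)
            # later, restore into a compatible VITAE instance
            model.load_model('model.checkpoint', load_labels=True, load_adata=True)
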
        '''
        self.vae.save_weights(path_to_file)
        if hasattr(self, 'labels') and self.labels is not None:
            with open(path_to_file + '.label', 'wb') as f:
                np.save(f, self.labels)
        with open(path_to_file + '.config', 'wb') as f:
            self.dim_origin = self.X_input.shape[1]
            np.save(f, np.array([
                self.dim_origin, self.dimensions, self.dim_latent,
                self.model_type, 0 if self.covariates is None else self.covariates.shape[1]], dtype=object))
        if hasattr(self, 'inferer') and hasattr(self, 'uncertainty'):
            with open(path_to_file + '.inference', 'wb') as f:
                np.save(f, np.array([
                    self.pi, self.mu, self.pc_x, self.cell_position_posterior, self.uncertainty,
                    self.z, self.cell_position_variance], dtype=object))
        if save_adata:
            self.adata.write(path_to_file + '.adata.h5ad')


    def load_model(self, path_to_file: str = 'model.checkpoint', load_labels: bool = False, load_adata: bool = False):
        '''Load model weights.

        Parameters
        ----------
        path_to_file : str, optional
            The path to the weight files of the pre-trained or trained model.
        load_labels : boolean, optional
            Whether to load clustering labels or not.
            If load_labels is True, then the LatentSpace layer will be initialized based on the model.
            If load_labels is False, then the LatentSpace layer will not be initialized.
        load_adata : boolean, optional
            Whether to load adata or not.
        '''
        if not os.path.exists(path_to_file + '.config'):
            raise AssertionError('Config file does not exist!')
        if load_labels and not os.path.exists(path_to_file + '.label'):
            raise AssertionError('Label file does not exist!')

        with open(path_to_file + '.config', 'rb') as f:
            [self.dim_origin, self.dimensions,
             self.dim_latent, self.model_type, cov_dim] = np.load(f, allow_pickle=True)
        self.vae = model.VariationalAutoEncoder(
            self.dim_origin, self.dimensions,
            self.dim_latent, self.model_type, False if cov_dim == 0 else True
        )

        if load_labels:
            with open(path_to_file + '.label', 'rb') as f:
                cluster_labels = np.load(f, allow_pickle=True)
            self.init_latent_space(cluster_labels, dist_thres=0)
            if os.path.exists(path_to_file + '.inference'):
                with open(path_to_file + '.inference', 'rb') as f:
                    arr = np.load(f, allow_pickle=True)
                    if len(arr) == 8:
                        [self.pi, self.mu, self.pc_x, self.cell_position_posterior, self.uncertainty,
                         self.D_JS, self.z, self.cell_position_variance] = arr
                    else:
                        [self.pi, self.mu, self.pc_x, self.cell_position_posterior, self.uncertainty,
                         self.z, self.cell_position_variance] = arr
                self._adata_z = sc.AnnData(self.z)
                sc.pp.neighbors(self._adata_z)
        ## initialize the weights of the encoder and decoder
        self.vae.encoder(np.zeros((1, self.dim_origin + cov_dim)))
        self.vae.decoder(np.expand_dims(np.zeros((1, self.dim_latent + cov_dim)), 1))

        self.vae.load_weights(path_to_file)
        self.update_z()
        if load_adata:
            if not os.path.exists(path_to_file + '.adata.h5ad'):
                raise AssertionError('AnnData file does not exist!')
            self.adata = sc.read_h5ad(path_to_file + '.adata.h5ad')
            self._adata.obs = self.adata.obs.copy()