Diff of /VITAE/VITAE.py [000000] .. [2c6b19]

b/VITAE/VITAE.py
from typing import Optional, Union
import warnings
import os

import numpy as np
import pandas as pd
from scipy import stats

import VITAE.model as model
import VITAE.train as train
from VITAE.inference import Inferer
from VITAE.utils import get_igraph, leidenalg_igraph, \
    DE_test, _comp_dist, _get_smooth_curve
from VITAE.metric import topology, get_GRI
import tensorflow as tf

from sklearn.metrics.cluster import adjusted_rand_score
from sklearn.model_selection import train_test_split
from sklearn.cluster import AgglomerativeClustering
from sklearn.preprocessing import OneHotEncoder, StandardScaler, OrdinalEncoder

import scanpy as sc
import networkx as nx
import matplotlib.pyplot as plt
import matplotlib.patheffects as pe

class VITAE():
    """
    Variational Inference for Trajectory by AutoEncoder.
    """
    def __init__(self, adata: sc.AnnData,
                 covariates = None, pi_covariates = None,
                 model_type: str = 'Gaussian',
                 npc: int = 64,
                 adata_layer_counts = None,
                 copy_adata: bool = False,
                 hidden_layers = [32],
                 latent_space_dim: int = 16,
                 conditions = None):
        '''
        Get input data for the model. The data need to be processed first using scanpy
        and stored as an AnnData object. The 'UMI' and 'non-UMI' models require the
        original count matrix, so the count matrix needs to be saved in
        adata.layers in order to use these models.

        Parameters
        ----------
        adata : sc.AnnData
            The scanpy AnnData object. adata should already contain adata.var.highly_variable.
        covariates : list, optional
            A list of names of covariate vectors that are stored in adata.obs.
        pi_covariates : list, optional
            A list of names of covariate vectors used as input for the pilayer.
        model_type : str, optional
            'UMI', 'non-UMI' or 'Gaussian'. The default is 'Gaussian'.
        npc : int, optional
            The number of PCs to use when model_type is 'Gaussian'. The default is 64.
        adata_layer_counts : str, optional
            The key name of adata.layers that stores the count data if model_type is
            'UMI' or 'non-UMI'.
        copy_adata : bool, optional
            Set to True if we don't want VITAE to modify the original adata. If set to True,
            self.adata will be an independent copy of the original adata.
        hidden_layers : list, optional
            The list of dimensions of the autoencoder layers between the latent space and the
            original space. The default is a single hidden layer with 32 nodes.
        latent_space_dim : int, optional
            The dimension of the latent space.
        conditions : str or list, optional
            The conditions of different cells.

        Returns
        -------
        None.
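
        Examples
        --------
        A minimal usage sketch (hypothetical file name; assumes the class is importable as
        `from VITAE import VITAE` and that `adata` has been preprocessed with scanpy)::

            import scanpy as sc
            from VITAE import VITAE

            adata = sc.read_h5ad('data.h5ad')   # hypothetical path
            sc.pp.highly_variable_genes(adata)
            model = VITAE(adata, model_type='Gaussian', npc=64)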
        '''
        self.dict_method_scname = {
            'PCA' : 'X_pca',
            'UMAP' : 'X_umap',
            'TSNE' : 'X_tsne',
            'diffmap' : 'X_diffmap',
            'draw_graph' : 'X_draw_graph_fa'
        }

        if model_type != 'Gaussian':
            if adata_layer_counts is None:
                raise ValueError("need to provide the name in adata.layers that stores the raw count data")
            if 'highly_variable' not in adata.var:
                raise ValueError("need to first select highly variable genes using scanpy")

        self.model_type = model_type

        if copy_adata:
            self.adata = adata.copy()
        else:
            self.adata = adata

        if covariates is not None:
            if isinstance(covariates, str):
                covariates = [covariates]
            covariates = np.array(covariates)
            id_cat = (adata.obs[covariates].dtypes == 'category')
            # add OneHotEncoder & StandardScaler as class variable if needed
            if np.sum(id_cat)>0:
                covariates_cat = OneHotEncoder(drop='if_binary', handle_unknown='ignore'
                    ).fit_transform(adata.obs[covariates[id_cat]]).toarray()
            else:
                covariates_cat = np.array([]).reshape(adata.shape[0],0)

            # temporarily disable StandardScaler
            if np.sum(~id_cat)>0:
                #covariates_con = StandardScaler().fit_transform(adata.obs[covariates[~id_cat]])
                covariates_con = adata.obs[covariates[~id_cat]]
            else:
                covariates_con = np.array([]).reshape(adata.shape[0],0)

            self.covariates = np.c_[covariates_cat, covariates_con].astype(tf.keras.backend.floatx())
        else:
            self.covariates = None

        if conditions is not None:
            ## observations with np.nan will not participate in calculating mmd_loss
            if isinstance(conditions, str):
                conditions = [conditions]
            conditions = np.array(conditions)
            if np.any(adata.obs[conditions].dtypes != 'category'):
                raise ValueError("Conditions should all be categorical.")

            self.conditions = OrdinalEncoder(dtype=int, encoded_missing_value=-1).fit_transform(adata.obs[conditions]) + int(1)
        else:
            self.conditions = None

        if pi_covariates is not None:
            self.pi_cov = adata.obs[pi_covariates].to_numpy()
            if self.pi_cov.ndim == 1:
                self.pi_cov = self.pi_cov.reshape(-1, 1)
                self.pi_cov = self.pi_cov.astype(tf.keras.backend.floatx())
        else:
            self.pi_cov = np.zeros((adata.shape[0],1), dtype=tf.keras.backend.floatx())

        self.model_type = model_type
        self._adata = sc.AnnData(X = self.adata.X, var = self.adata.var)
        self._adata.obs = self.adata.obs
        self._adata.uns = self.adata.uns

        if model_type == 'Gaussian':
            sc.tl.pca(adata, n_comps = npc)
            self.X_input = self.X_output = adata.obsm['X_pca']
            self.scale_factor = np.ones(self.X_output.shape[0])
        else:
            print(f"{adata.var.highly_variable.sum()} highly variable genes selected as input")
            self.X_input = adata.X[:, adata.var.highly_variable]
            self.X_output = adata.layers[adata_layer_counts][:, adata.var.highly_variable]
            self.scale_factor = np.sum(self.X_output, axis=1, keepdims=True)/1e4

        self.dimensions = hidden_layers
        self.dim_latent = latent_space_dim

        self.vae = model.VariationalAutoEncoder(
            self.X_output.shape[1], self.dimensions,
            self.dim_latent, self.model_type,
            False if self.covariates is None else True,
            )

        if hasattr(self, 'inferer'):
            delattr(self, 'inferer')

    def pre_train(self, test_size = 0.1, random_state: int = 0,
            learning_rate: float = 1e-3, batch_size: int = 256, L: int = 1, alpha: float = 0.10, gamma: float = 0,
            phi: float = 1, num_epoch: int = 200, num_step_per_epoch: Optional[int] = None,
            early_stopping_patience: int = 10, early_stopping_tolerance: float = 0.01,
            early_stopping_relative: bool = True, verbose: bool = False, path_to_weights: Optional[str] = None):
        '''Pretrain the model with the specified learning rate.

        Parameters
        ----------
        test_size : float or int, optional
            The proportion or size of the test set.
        random_state : int, optional
            The random state for data splitting.
        learning_rate : float, optional
            The initial learning rate for the Adam optimizer.
        batch_size : int, optional
            The batch size for pre-training. The default is 256. Set to 32 if the number of cells is small (less than 1000).
        L : int, optional
            The number of MC samples.
        alpha : float, optional
            The value of alpha in [0,1] to encourage covariate adjustment. Not used if there are no covariates.
        gamma : float, optional
            The weight of the MMD loss if used.
        phi : float, optional
            The weight of the Jacobian norm of the encoder.
        num_epoch : int, optional
            The maximum number of epochs.
        num_step_per_epoch : int, optional
            The number of steps per epoch; it will be inferred from the number of cells and the batch size if it is None.
        early_stopping_patience : int, optional
            The maximum number of epochs if there is no improvement.
        early_stopping_tolerance : float, optional
            The minimum change of loss to be considered as an improvement.
        early_stopping_relative : bool, optional
            Whether to monitor the relative change of loss as the stopping criterion or not.
        verbose : bool, optional
            Whether to print the training progress.
        path_to_weights : str, optional
            The path of the weight file to be saved; not saving the weights if None.
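
        Examples
        --------
        A minimal sketch (hypothetical settings; assumes `model` is a VITAE instance
        constructed as above)::

            model.pre_train(learning_rate=1e-3, batch_size=256, num_epoch=200)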
        '''

        id_train, id_test = train_test_split(
                                np.arange(self.X_input.shape[0]),
                                test_size=test_size,
                                random_state=random_state)
        if num_step_per_epoch is None:
            num_step_per_epoch = len(id_train)//batch_size+1
        self.train_dataset = train.warp_dataset(self.X_input[id_train].astype(tf.keras.backend.floatx()),
                                                None if self.covariates is None else self.covariates[id_train].astype(tf.keras.backend.floatx()),
                                                batch_size,
                                                self.X_output[id_train].astype(tf.keras.backend.floatx()),
                                                self.scale_factor[id_train].astype(tf.keras.backend.floatx()),
                                                conditions = None if self.conditions is None else self.conditions[id_train].astype(tf.keras.backend.floatx()))
        self.test_dataset = train.warp_dataset(self.X_input[id_test],
                                                None if self.covariates is None else self.covariates[id_test].astype(tf.keras.backend.floatx()),
                                                batch_size,
                                                self.X_output[id_test].astype(tf.keras.backend.floatx()),
                                                self.scale_factor[id_test].astype(tf.keras.backend.floatx()),
                                                conditions = None if self.conditions is None else self.conditions[id_test].astype(tf.keras.backend.floatx()))

        self.vae = train.pre_train(
            self.train_dataset,
            self.test_dataset,
            self.vae,
            learning_rate,
            L, alpha, gamma, phi,
            num_epoch,
            num_step_per_epoch,
            early_stopping_patience,
            early_stopping_tolerance,
            early_stopping_relative,
            verbose)

        self.update_z()

        if path_to_weights is not None:
            self.save_model(path_to_weights)

    def update_z(self):
        self.z = self.get_latent_z()
        self._adata_z = sc.AnnData(self.z)
        sc.pp.neighbors(self._adata_z)


    def get_latent_z(self):
        '''Get the posterior mean of the current latent space z (encoder output).

        Returns
        ----------
        z : np.array
            \([N,d]\) The latent means.
        '''
        c = None if self.covariates is None else self.covariates
        return self.vae.get_z(self.X_input, c)

    def visualize_latent(self, method: str = "UMAP",
                         color = None, **kwargs):
        '''
        Visualize the current latent space z using the scanpy visualization tools.

        Parameters
        ----------
        method : str, optional
            Visualization method to use. The default is "UMAP". Possible choices include "PCA", "UMAP",
            "diffmap", "TSNE" and "draw_graph" (the FA plot).
        color : str or list, optional
            Keys for annotations of observations/cells or variables/genes, e.g., 'ann1' or ['ann1', 'ann2'].
            The default is None. Same as scanpy.
        **kwargs :
            Extra key-value arguments that can be passed to scanpy plotting functions (scanpy.pl.XX).

        Returns
        -------
        axes
            The axes object(s) returned by the corresponding scanpy plotting function.
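
        Examples
        --------
        A minimal sketch (assumes `pre_train` or `train` has been run; `'ann1'` is a
        hypothetical key in `adata.obs`)::

            model.visualize_latent(method='UMAP', color='ann1')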
        '''

        if method not in ['PCA', 'UMAP', 'TSNE', 'diffmap', 'draw_graph']:
            raise ValueError("visualization method should be one of 'PCA', 'UMAP', 'TSNE', 'diffmap' and 'draw_graph'")

        temp = list(self._adata_z.obsm.keys())
        if method == 'PCA' and not 'X_pca' in temp:
            print("Calculate PCs ...")
            sc.tl.pca(self._adata_z)
        elif method == 'UMAP' and not 'X_umap' in temp:
            print("Calculate UMAP ...")
            sc.tl.umap(self._adata_z)
        elif method == 'TSNE' and not 'X_tsne' in temp:
            print("Calculate TSNE ...")
            sc.tl.tsne(self._adata_z)
        elif method == 'diffmap' and not 'X_diffmap' in temp:
            print("Calculate diffusion map ...")
            sc.tl.diffmap(self._adata_z)
        elif method == 'draw_graph' and not 'X_draw_graph_fa' in temp:
            print("Calculate FA ...")
            sc.tl.draw_graph(self._adata_z)

        self._adata.obs = self.adata.obs.copy()
        self._adata.obsp = self._adata_z.obsp
#        self._adata.uns = self._adata_z.uns
        self._adata.obsm = self._adata_z.obsm

        if method == 'PCA':
            axes = sc.pl.pca(self._adata, color = color, **kwargs)
        elif method == 'UMAP':
            axes = sc.pl.umap(self._adata, color = color, **kwargs)
        elif method == 'TSNE':
            axes = sc.pl.tsne(self._adata, color = color, **kwargs)
        elif method == 'diffmap':
            axes = sc.pl.diffmap(self._adata, color = color, **kwargs)
        elif method == 'draw_graph':
            axes = sc.pl.draw_graph(self._adata, color = color, **kwargs)
        return axes

    def init_latent_space(self, cluster_label = None, log_pi = None, res: float = 1.0,
                          ratio_prune = None, dist = None, dist_thres = 0.5, topk = 0, pilayer = False):
        '''Initialize the latent space.

        Parameters
        ----------
        cluster_label : str, optional
            The name of the vector of labels that can be found in self.adata.obs.
            Default is None, which will perform leiden clustering on the pretrained z to get clusters.
        log_pi : np.array, optional
            \([1,K]\) The value of initial \(\\log(\\pi)\).
        res : float, optional
            The resolution of leiden clustering, which is a parameter value controlling the coarseness of the clustering.
            Higher values lead to more clusters. Default is 1.
        ratio_prune : float, optional
            The ratio of edges to be removed before estimating.
        dist : np.array, optional
            The pairwise distances between cluster centers; computed from the latent space if None.
        dist_thres : float, optional
            The distance threshold below which cluster centers are merged.
        topk : int, optional
            The number of top k neighbors to keep for each cluster.
        pilayer : bool, optional
            Whether to create a pilayer, which takes `pi_covariates` as input.
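
        Examples
        --------
        A minimal sketch (either let leiden clustering produce the initial clusters, or
        pass a categorical key from `adata.obs`)::

            model.init_latent_space(res=1.0)                      # leiden on the latent space
            # model.init_latent_space(cluster_label='celltype')   # hypothetical obs key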
        '''

        if cluster_label is None:
            print("Perform leiden clustering on the latent space z ...")
            g = get_igraph(self.z)
            cluster_labels = leidenalg_igraph(g, res = res)
            cluster_labels = cluster_labels.astype(str)
            uni_cluster_labels = np.unique(cluster_labels)
        else:
            if isinstance(cluster_label, str):
                cluster_labels = self.adata.obs[cluster_label].to_numpy()
                uni_cluster_labels = np.array(self.adata.obs[cluster_label].cat.categories)
            else:
                ## if cluster_label is a list
                cluster_labels = cluster_label
                uni_cluster_labels = np.unique(cluster_labels)

        n_clusters = len(uni_cluster_labels)

        if not hasattr(self, 'z'):
            self.update_z()
        z = self.z
        mu = np.zeros((z.shape[1], n_clusters))
        for i,l in enumerate(uni_cluster_labels):
            mu[:,i] = np.mean(z[cluster_labels==l], axis=0)

        if dist is None:
            ### update cluster centers if some cluster centers are too close
            clustering = AgglomerativeClustering(
                n_clusters=None,
                distance_threshold=dist_thres,
                linkage='complete'
                ).fit(mu.T/np.sqrt(mu.shape[0]))
            n_clusters_new = clustering.n_clusters_
            if n_clusters_new < n_clusters:
                print("Merge clusters for cluster centers that are too close ...")
                n_clusters = n_clusters_new
                for i in range(n_clusters):
                    temp = uni_cluster_labels[clustering.labels_ == i]
                    idx = np.isin(cluster_labels, temp)
                    cluster_labels[idx] = ','.join(temp)
                    if np.sum(clustering.labels_==i)>1:
                        print('Merge %s'% ','.join(temp))
                uni_cluster_labels = np.unique(cluster_labels)
                mu = np.zeros((z.shape[1], n_clusters))
                for i,l in enumerate(uni_cluster_labels):
                    mu[:,i] = np.mean(z[cluster_labels==l], axis=0)

        self.adata.obs['vitae_init_clustering'] = cluster_labels
        self.adata.obs['vitae_init_clustering'] = self.adata.obs['vitae_init_clustering'].astype('category')
        print("Initial clustering labels saved as 'vitae_init_clustering' in self.adata.obs.")

        if (log_pi is None) and (cluster_labels is not None) and (n_clusters>3):
            n_states = int((n_clusters+1)*n_clusters/2)

            if dist is None:
                dist = _comp_dist(z, cluster_labels, mu.T)

            C = np.triu(np.ones(n_clusters))
            C[C>0] = np.arange(n_states)
            C = C + C.T - np.diag(np.diag(C))
            C = C.astype(int)

            log_pi = np.zeros((1,n_states))

            ## pruning to throw away edges for far-away clusters if there are too many clusters
            if ratio_prune is not None:
                log_pi[0, C[np.triu(dist)>np.quantile(dist[np.triu_indices(n_clusters, 1)], 1-ratio_prune)]] = - np.inf
            else:
                log_pi[0, C[np.triu(dist)>np.quantile(dist[np.triu_indices(n_clusters, 1)], 5/n_clusters) * 3]] = - np.inf

            ## also keep the top k neighbors of each cluster
            topk = max(0, min(topk, n_clusters-1)) + 1
            topk_indices = np.argsort(dist,axis=1)[:,:topk]
            for i in range(n_clusters):
                log_pi[0, C[i, topk_indices[i]]] = 0

        self.n_states = n_clusters
        self.labels = cluster_labels

        labels_map = pd.DataFrame.from_dict(
            {i:label for i,label in enumerate(uni_cluster_labels)},
            orient='index', columns=['label_names'], dtype=str
            )

        self.labels_map = labels_map
        self.vae.init_latent_space(self.n_states, mu, log_pi)
        self.inferer = Inferer(self.n_states)
        self.mu = self.vae.latent_space.mu.numpy()
        self.pi = np.triu(np.ones(self.n_states))
        self.pi[self.pi > 0] = tf.nn.softmax(self.vae.latent_space.pi).numpy()[0]

        if pilayer:
            self.vae.create_pilayer()

    def update_latent_space(self, dist_thres: float = 0.5):
        pi = self.pi[np.triu_indices(self.n_states)]
        mu = self.mu
        clustering = AgglomerativeClustering(
            n_clusters=None,
            distance_threshold=dist_thres,
            linkage='complete'
            ).fit(mu.T/np.sqrt(mu.shape[0]))
        n_clusters = clustering.n_clusters_

        if n_clusters<self.n_states:
            print("Merge clusters for cluster centers that are too close ...")
            mu_new = np.empty((self.dim_latent, n_clusters))
            C = np.zeros((self.n_states, self.n_states))
            C[np.triu_indices(self.n_states, 0)] = pi
            C = np.triu(C, 1) + C.T
            C_new = np.zeros((n_clusters, n_clusters))

            uni_cluster_labels = self.labels_map['label_names'].to_numpy()
            returned_order = {}
            cluster_labels = self.labels
            for i in range(n_clusters):
                temp = uni_cluster_labels[clustering.labels_ == i]
                idx = np.isin(cluster_labels, temp)
                cluster_labels[idx] = ','.join(temp)
                returned_order[i] = ','.join(temp)
                if np.sum(clustering.labels_==i)>1:
                    print('Merge %s'% ','.join(temp))
            uni_cluster_labels = np.unique(cluster_labels)
            for i,l in enumerate(uni_cluster_labels):  ## reorder the merged clusters based on the cluster names
                k = [key for key, val in returned_order.items() if val == l][0]
                mu_new[:, i] = np.mean(mu[:,clustering.labels_==k], axis=-1)
                # sum of the aggregated pi's
                C_new[i, i] = np.sum(np.triu(C[clustering.labels_==k,:][:,clustering.labels_==k]))
                for j in range(i+1, n_clusters):
                    k1 = [key for key, val in returned_order.items() if val == uni_cluster_labels[j]][0]
                    C_new[i, j] = np.sum(C[clustering.labels_==k, :][:, clustering.labels_==k1])

#            labels_map_new = {}
#            for i in range(n_clusters):
#                # update label map: int->str
#                labels_map_new[i] = self.labels_map.loc[clustering.labels_==i, 'label_names'].str.cat(sep=',')
#                if np.sum(clustering.labels_==i)>1:
#                    print('Merge %s'%labels_map_new[i])
#                # mean of the aggregated cluster means
#                mu_new[:, i] = np.mean(mu[:,clustering.labels_==i], axis=-1)
#                # sum of the aggregated pi's
#                C_new[i, i] = np.sum(np.triu(C[clustering.labels_==i,:][:,clustering.labels_==i]))
#                for j in range(i+1, n_clusters):
#                    C_new[i, j] = np.sum(C[clustering.labels_== i, :][:, clustering.labels_==j])
            C_new = np.triu(C_new,1) + C_new.T

            pi_new = C_new[np.triu_indices(n_clusters)]
            log_pi_new = np.log(pi_new, out=np.ones_like(pi_new)*(-np.inf), where=(pi_new!=0)).reshape((1,-1))
            self.n_states = n_clusters
            self.labels_map = pd.DataFrame.from_dict(
                {i:label for i,label in enumerate(uni_cluster_labels)},
                orient='index', columns=['label_names'], dtype=str
                )
            self.labels = cluster_labels
#            self.labels_map = pd.DataFrame.from_dict(
#                labels_map_new, orient='index', columns=['label_names'], dtype=str
#            )
            self.vae.init_latent_space(self.n_states, mu_new, log_pi_new)
            self.inferer = Inferer(self.n_states)
            self.mu = self.vae.latent_space.mu.numpy()
            self.pi = np.triu(np.ones(self.n_states))
            self.pi[self.pi > 0] = tf.nn.softmax(self.vae.latent_space.pi).numpy()[0]

    def train(self, stratify = False, test_size = 0.1, random_state: int = 0,
            learning_rate: float = 1e-3, batch_size: int = 256,
            L: int = 1, alpha: float = 0.10, beta: float = 1, gamma: float = 0, phi: float = 1,
            num_epoch: int = 200, num_step_per_epoch: Optional[int] = None,
            early_stopping_patience: int = 10, early_stopping_tolerance: float = 0.01,
            early_stopping_relative: bool = True, early_stopping_warmup: int = 0,
            path_to_weights: Optional[str] = None,
            verbose: bool = False, **kwargs):
        '''Train the model.

        Parameters
        ----------
        stratify : np.array, None, or False
            If an array is provided, or `stratify=None` and `self.labels` is available, then they will be used to perform stratified shuffle splitting. Otherwise, general shuffle splitting is used. Set to `False` if `self.labels` is not intended for stratified shuffle splitting.
        test_size : float or int, optional
            The proportion or size of the test set.
        random_state : int, optional
            The random state for data splitting.
        learning_rate : float, optional
            The initial learning rate for the Adam optimizer.
        batch_size : int, optional
            The batch size for training. The default is 256. Set to 32 if the number of cells is small (less than 1000).
        L : int, optional
            The number of MC samples.
        alpha : float, optional
            The value of alpha in [0,1] to encourage covariate adjustment. Not used if there are no covariates.
        beta : float, optional
            The value of beta in beta-VAE.
        gamma : float, optional
            The weight of the MMD loss.
        phi : float, optional
            The weight of the Jacobian norm of the encoder.
        num_epoch : int, optional
            The maximum number of epochs.
        num_step_per_epoch : int, optional
            The number of steps per epoch; it will be inferred from the number of cells and the batch size if it is None.
        early_stopping_patience : int, optional
            The maximum number of epochs if there is no improvement.
        early_stopping_tolerance : float, optional
            The minimum change of loss to be considered as an improvement.
        early_stopping_relative : bool, optional
            Whether to monitor the relative change of loss or not.
        early_stopping_warmup : int, optional
            The number of warmup epochs.
        path_to_weights : str, optional
            The path of the weight file to be saved; not saving the weights if None.
        **kwargs :
            Extra key-value arguments for dimension reduction algorithms.
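
        Examples
        --------
        A minimal sketch (assumes `init_latent_space` has been called)::

            model.train(beta=1, num_epoch=200, early_stopping_patience=10)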
        '''
        if gamma == 0 or self.conditions is None:
            conditions = np.array([np.nan] * self.adata.shape[0])
        else:
            conditions = self.conditions

        if stratify is None:
            stratify = self.labels
        elif stratify is False:
            stratify = None
        id_train, id_test = train_test_split(
                                np.arange(self.X_input.shape[0]),
                                test_size=test_size,
                                stratify=stratify,
                                random_state=random_state)
        if num_step_per_epoch is None:
            num_step_per_epoch = len(id_train)//batch_size+1
        c = None if self.covariates is None else self.covariates.astype(tf.keras.backend.floatx())
        self.train_dataset = train.warp_dataset(self.X_input[id_train].astype(tf.keras.backend.floatx()),
                                                None if c is None else c[id_train],
                                                batch_size,
                                                self.X_output[id_train].astype(tf.keras.backend.floatx()),
                                                self.scale_factor[id_train].astype(tf.keras.backend.floatx()),
                                                conditions = conditions[id_train],
                                                pi_cov = self.pi_cov[id_train])
        self.test_dataset = train.warp_dataset(self.X_input[id_test].astype(tf.keras.backend.floatx()),
                                                None if c is None else c[id_test],
                                                batch_size,
                                                self.X_output[id_test].astype(tf.keras.backend.floatx()),
                                                self.scale_factor[id_test].astype(tf.keras.backend.floatx()),
                                                conditions = conditions[id_test],
                                                pi_cov = self.pi_cov[id_test])

        self.vae = train.train(
            self.train_dataset,
            self.test_dataset,
            self.vae,
            learning_rate,
            L,
            alpha,
            beta,
            gamma,
            phi,
            num_epoch,
            num_step_per_epoch,
            early_stopping_patience,
            early_stopping_tolerance,
            early_stopping_relative,
            early_stopping_warmup,
            verbose,
            **kwargs
            )

        self.update_z()
        self.mu = self.vae.latent_space.mu.numpy()
        self.pi = np.triu(np.ones(self.n_states))
        self.pi[self.pi > 0] = tf.nn.softmax(self.vae.latent_space.pi).numpy()[0]

        if path_to_weights is not None:
            self.save_model(path_to_weights)

    def output_pi(self, pi_cov):
        """Return an n_states by n_states matrix and a mask for plotting, which can be used
        to cover the lower triangle (except the diagonal) of a heatmap."""
        p = self.vae.pilayer
        pi_cov = tf.expand_dims(tf.constant([pi_cov], dtype=tf.float32), 0)
        pi_val = tf.nn.softmax(p(pi_cov)).numpy()[0]
        # Create heatmap matrix
        n = self.vae.n_states
        matrix = np.zeros((n, n))
        matrix[np.triu_indices(n)] = pi_val
        mask = np.tril(np.ones_like(matrix), k=-1)
        return matrix, mask

    def return_pilayer_weights(self):
        """Return the parameters of the pilayer, which have dimension (dim(pi_cov) + 1) by n_categories; the last row contains the biases."""
        return np.vstack((self.vae.pilayer.weights[0].numpy(), self.vae.pilayer.weights[1].numpy().reshape(1, -1)))

    def posterior_estimation(self, batch_size: int = 32, L: int = 50, **kwargs):
        '''Initialize trajectory inference by computing the posterior estimations.

        Parameters
        ----------
        batch_size : int, optional
            The batch size when doing inference.
        L : int, optional
            The number of MC samples when doing inference.
        **kwargs :
            Extra key-value arguments for dimension reduction algorithms.
657
        c = None if self.covariates is None else self.covariates.astype(tf.keras.backend.floatx())
658
        self.test_dataset = train.warp_dataset(self.X_input.astype(tf.keras.backend.floatx()), 
659
                                               c,
660
                                               batch_size)
661
        _, _, self.pc_x,\
662
            self.cell_position_posterior,self.cell_position_variance,_ = self.vae.inference(self.test_dataset, L=L)
663
            
664
        uni_cluster_labels = self.labels_map['label_names'].to_numpy()
665
        self.adata.obs['vitae_new_clustering'] = uni_cluster_labels[np.argmax(self.cell_position_posterior, 1)]
666
        self.adata.obs['vitae_new_clustering'] = self.adata.obs['vitae_new_clustering'].astype('category')
667
        print("New clustering labels saved as 'vitae_new_clustering' in self.adata.obs.")
668
        return None
669
670
671
    def infer_backbone(self, method: str = 'modified_map', thres = 0.5,
            no_loop: bool = True, cutoff: float = 0,
            visualize: bool = True, color = 'vitae_new_clustering', path_to_fig = None, **kwargs):
        '''Compute edge scores and infer the trajectory backbone.

        Parameters
        ----------
        method : string, optional
            'mean', 'modified_mean', 'map', or 'modified_map'.
        thres : float, optional
            The threshold used for filtering edges \(e_{ij}\) with \((n_{i}+n_{j}+e_{ij})/N<thres\); only applied to the mean method.
        no_loop : boolean, optional
            Whether loops are allowed to exist in the graph. If no_loop is true, the graph is pruned to contain only the
            maximum spanning tree.
        cutoff : float, optional
            The score threshold for filtering edges with scores less than cutoff.
        visualize : boolean
            Whether to plot the current trajectory backbone (undirected graph).
        path_to_fig : string, optional
            The path to save the figure, or don't save if it is None.
        **kwargs :
            Extra key-value arguments passed to the plotting function.

        Returns
        ----------
        None. The weighted graph, with the weight on each edge indicating its score of
        existence, is stored as `self.backbone`.
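
        Examples
        --------
        A minimal sketch (assumes `posterior_estimation` has been run)::

            model.infer_backbone(method='modified_map', no_loop=True)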
        '''
        # build_graph, return graph
        self.backbone = self.inferer.build_graphs(self.cell_position_posterior, self.pc_x,
                method, thres, no_loop, cutoff)
        self.cell_position_projected = self.inferer.modify_wtilde(self.cell_position_posterior,
                np.array(list(self.backbone.edges)))

        uni_cluster_labels = self.labels_map['label_names'].to_numpy()
        temp_dict = {i:label for i,label in enumerate(uni_cluster_labels)}
        nx.relabel_nodes(self.backbone, temp_dict)

        self.adata.obs['vitae_new_clustering'] = uni_cluster_labels[np.argmax(self.cell_position_projected, 1)]
        self.adata.obs['vitae_new_clustering'] = self.adata.obs['vitae_new_clustering'].astype('category')
        print("'vitae_new_clustering' updated based on the projected cell positions.")

        self.uncertainty = np.sum((self.cell_position_projected - self.cell_position_posterior)**2, axis=-1) \
            + np.sum(self.cell_position_variance, axis=-1)
        self.adata.obs['projection_uncertainty'] = self.uncertainty
        print("Cell projection uncertainties stored as 'projection_uncertainty' in self.adata.obs")
        if visualize:
            self._adata.obs = self.adata.obs.copy()
            self.ax = self.plot_backbone(directed = False, color = color, **kwargs)
            if path_to_fig is not None:
                self.ax.figure.savefig(path_to_fig)
            self.ax.figure.show()
        return None

    def select_root(self, days, method: str = 'proportion'):
        '''Order the vertices/states based on the cells' collection time information to select the root state.

        Parameters
        ----------
        days : np.array
            The day information for the selected cells used to determine the root vertex.
            The dtype should be 'int' or 'float'.
        method : str, optional
            'proportion' or 'mean'.
            For 'proportion', the root is the one with the maximal proportion of cells from the earliest day.
            For 'mean', the root is the one with the earliest mean time among the cells associated with it.

        Returns
        ----------
        root_info : pd.DataFrame
            The state labels with their mean collection times and their proportions of cells
            from the earliest day, sorted by mean collection time.
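
        Examples
        --------
        A minimal sketch (assumes `days` is a numeric array aligned with the cells)::

            root_info = model.select_root(days, method='proportion')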
        '''
        if days is not None and len(days)!=self.X_input.shape[0]:
            raise ValueError("The length of day information ({}) is not "
                "consistent with the number of selected cells ({})!".format(
                    len(days), self.X_input.shape[0]))
        if not hasattr(self, 'cell_position_projected'):
            raise ValueError("Need to call 'infer_backbone' first!")

        collection_time = np.dot(days, self.cell_position_projected)/np.sum(self.cell_position_projected, axis = 0)
        earliest_prop = np.dot(days==np.min(days), self.cell_position_projected)/np.sum(self.cell_position_projected, axis = 0)

        root_info = self.labels_map.copy()
        root_info['mean_collection_time'] = collection_time
        root_info['earliest_time_prop'] = earliest_prop
        root_info.sort_values('mean_collection_time', inplace=True)
        return root_info

    def plot_backbone(self, directed: bool = False,
                      method: str = 'UMAP', color = 'vitae_new_clustering', **kwargs):
        '''Plot the current trajectory backbone (undirected or directed graph).

        Parameters
        ----------
        directed : boolean, optional
            Whether the backbone is directed or not.
        method : str, optional
            The dimension reduction method to use. The default is "UMAP".
        color : str, optional
            The key for annotations of observations/cells or variables/genes, e.g., 'ann1'.
            The default is 'vitae_new_clustering'.
        **kwargs :
            Extra key-value arguments that can be passed to scanpy plotting functions (scanpy.pl.XX).
        '''
        if not isinstance(color, str):
            raise ValueError('The color argument should be of type str!')
        ax = self.visualize_latent(method = method, color=color, show=False, **kwargs)
        dict_label_num = {j:i for i,j in self.labels_map['label_names'].to_dict().items()}
        uni_cluster_labels = self.adata.obs['vitae_init_clustering'].cat.categories
        cluster_labels = self.adata.obs['vitae_new_clustering'].to_numpy()
        embed_z = self._adata.obsm[self.dict_method_scname[method]]
        embed_mu = np.zeros((len(uni_cluster_labels), 2))
        for l in uni_cluster_labels:
            embed_mu[dict_label_num[l],:] = np.mean(embed_z[cluster_labels==l], axis=0)

        if directed:
            graph = self.directed_backbone
        else:
            graph = self.backbone
        edges = list(graph.edges)
        edge_scores = np.array([d['weight'] for (u,v,d) in graph.edges(data=True)])
        if max(edge_scores) - min(edge_scores) == 0:
            edge_scores = edge_scores/max(edge_scores)
        else:
            edge_scores = (edge_scores - min(edge_scores))/(max(edge_scores) - min(edge_scores))*3

        value_range = np.maximum(np.diff(ax.get_xlim())[0], np.diff(ax.get_ylim())[0])
        y_range = np.min(embed_z[:,1]), np.max(embed_z[:,1], axis=0)
        for i in range(len(edges)):
            points = embed_z[np.sum(self.cell_position_projected[:, edges[i]]>0, axis=-1)==2,:]
            points = points[points[:,0].argsort()]
            try:
                x_smooth, y_smooth = _get_smooth_curve(
                    points,
                    embed_mu[edges[i], :],
                    y_range
                    )
            except:
                x_smooth, y_smooth = embed_mu[edges[i], 0], embed_mu[edges[i], 1]
            ax.plot(x_smooth, y_smooth,
                '-',
                linewidth= 1 + edge_scores[i],
                color="black",
                alpha=0.8,
                path_effects=[pe.Stroke(linewidth=1+edge_scores[i]+1.5,
                                        foreground='white'), pe.Normal()],
                zorder=1
                )

            if directed:
                delta_x = embed_mu[edges[i][1], 0] - x_smooth[-2]
                delta_y = embed_mu[edges[i][1], 1] - y_smooth[-2]
                length = np.sqrt(delta_x**2 + delta_y**2) / 50 * value_range
                ax.arrow(
                        embed_mu[edges[i][1], 0]-delta_x/length,
                        embed_mu[edges[i][1], 1]-delta_y/length,
                        delta_x/length,
                        delta_y/length,
                        color='black', alpha=1.0,
                        shape='full', lw=0, length_includes_head=True,
                        head_width=np.maximum(0.01*(1 + edge_scores[i]), 0.03) * value_range,
                        zorder=2)

        colors = self._adata.uns['vitae_new_clustering_colors']

        for i,l in enumerate(uni_cluster_labels):
            ax.scatter(*embed_mu[dict_label_num[l]:dict_label_num[l]+1,:].T,
                       c=[colors[i]], edgecolors='white', # linewidths=10,  norm=norm,
                       s=250, marker='*', label=l)

        plt.setp(ax, xticks=[], yticks=[])
        box = ax.get_position()
        ax.set_position([box.x0, box.y0 + box.height * 0.1,
                            box.width, box.height * 0.9])
        if directed:
            ax.legend(loc='upper center', bbox_to_anchor=(0.5, -0.05),
                fancybox=True, shadow=True, ncol=5)

        return ax

    def plot_center(self, color = "vitae_new_clustering", plot_legend = True, legend_add_index = True,
                    method: str = 'UMAP', ncol = 2, font_size = "medium",
                    add_egde = False, add_direct = False, **kwargs):
        '''Plot the center of each cluster in the latent space.

        Parameters
        ----------
        color : str, optional
            The clustering to color by, either "vitae_new_clustering" or "vitae_init_clustering".
            Default is "vitae_new_clustering".
        plot_legend : bool, optional
            Whether to plot the legend. Default is True.
        legend_add_index : bool, optional
            Whether to add the index of each cluster in the legend. Default is True.
        method : str, optional
            The dimension reduction method used for visualization. Default is 'UMAP'.
        ncol : int, optional
            The number of columns in the legend. Default is 2.
        font_size : str, optional
            The font size of the legend. Default is "medium".
        add_egde : bool, optional
            Whether to add the edges between the centers of clusters. Default is False.
        add_direct : bool, optional
            Whether to add the direction of the edges. Default is False.
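
        Examples
        --------
        A minimal sketch (assumes `infer_backbone` has been run when `add_egde=True`)::

            model.plot_center(color='vitae_new_clustering', add_egde=True)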
        '''
        if color not in ["vitae_new_clustering", "vitae_init_clustering"]:
            raise ValueError("Can only plot center of vitae_new_clustering or vitae_init_clustering")
        dict_label_num = {j: i for i, j in self.labels_map['label_names'].to_dict().items()}
        if legend_add_index:
            self._adata.obs["index_"+color] = self._adata.obs[color].map(lambda x: dict_label_num[x])
            ax = self.visualize_latent(method=method, color="index_" + color, show=False, legend_loc="on data",
                                        legend_fontsize=font_size, **kwargs)
            colors = self._adata.uns["index_" + color + '_colors']
        else:
            ax = self.visualize_latent(method=method, color = color, show=False, **kwargs)
            colors = self._adata.uns[color + '_colors']
        uni_cluster_labels = self.adata.obs[color].cat.categories
        cluster_labels = self.adata.obs[color].to_numpy()
        embed_z = self._adata.obsm[self.dict_method_scname[method]]
        embed_mu = np.zeros((len(uni_cluster_labels), 2))
        for l in uni_cluster_labels:
            embed_mu[dict_label_num[l], :] = np.mean(embed_z[cluster_labels == l], axis=0)

        leg = (self.labels_map.index.astype(str) + " : " + self.labels_map.label_names).values
        for i, l in enumerate(uni_cluster_labels):
            ax.scatter(*embed_mu[dict_label_num[l]:dict_label_num[l] + 1, :].T,
                       c=[colors[i]], edgecolors='white', # linewidths=3,
                       s=250, marker='*', label=leg[i])
        if plot_legend:
            ax.legend(loc='center left', bbox_to_anchor=(1, 0.5), ncol=ncol, markerscale=0.8, frameon=False)
        plt.setp(ax, xticks=[], yticks=[])
        box = ax.get_position()
        ax.set_position([box.x0, box.y0 + box.height * 0.1,
                         box.width, box.height * 0.9])
        if add_egde:
            if add_direct:
                graph = self.directed_backbone
            else:
                graph = self.backbone
            edges = list(graph.edges)
            edge_scores = np.array([d['weight'] for (u, v, d) in graph.edges(data=True)])
            if max(edge_scores) - min(edge_scores) == 0:
                edge_scores = edge_scores / max(edge_scores)
            else:
                edge_scores = (edge_scores - min(edge_scores)) / (max(edge_scores) - min(edge_scores)) * 3

            value_range = np.maximum(np.diff(ax.get_xlim())[0], np.diff(ax.get_ylim())[0])
            y_range = np.min(embed_z[:, 1]), np.max(embed_z[:, 1], axis=0)
            for i in range(len(edges)):
                points = embed_z[np.sum(self.cell_position_projected[:, edges[i]] > 0, axis=-1) == 2, :]
                points = points[points[:, 0].argsort()]
                try:
                    x_smooth, y_smooth = _get_smooth_curve(
                        points,
                        embed_mu[edges[i], :],
                        y_range
                    )
                except:
                    x_smooth, y_smooth = embed_mu[edges[i], 0], embed_mu[edges[i], 1]
                ax.plot(x_smooth, y_smooth,
                        '-',
                        linewidth=1 + edge_scores[i],
                        color="black",
                        alpha=0.8,
                        path_effects=[pe.Stroke(linewidth=1 + edge_scores[i] + 1.5,
                                                foreground='white'), pe.Normal()],
                        zorder=1
                        )

                if add_direct:
                    delta_x = embed_mu[edges[i][1], 0] - x_smooth[-2]
                    delta_y = embed_mu[edges[i][1], 1] - y_smooth[-2]
                    length = np.sqrt(delta_x ** 2 + delta_y ** 2) / 50 * value_range
                    ax.arrow(
                        embed_mu[edges[i][1], 0] - delta_x / length,
                        embed_mu[edges[i][1], 1] - delta_y / length,
                        delta_x / length,
                        delta_y / length,
                        color='black', alpha=1.0,
                        shape='full', lw=0, length_includes_head=True,
                        head_width=np.maximum(0.01 * (1 + edge_scores[i]), 0.03) * value_range,
                        zorder=2)
        self.ax = ax
        self.ax.figure.show()
        return None

    def infer_trajectory(self, root: Union[int,str], digraph = None, color = "pseudotime",
                         visualize: bool = True, path_to_fig = None, **kwargs):
        '''Infer the trajectory.

        Parameters
        ----------
        root : int or string
            The root of the inferred trajectory. Can be either an int (vertex index) or a string (label name).
        digraph : nx.DiGraph, optional
            The directed graph to be used for trajectory inference. If None, a minimum spanning tree of the estimated trajectory backbone will be used.
        color : str, optional
            The key for annotations used to color the plot. The default is "pseudotime".
        visualize : boolean
            Whether to plot the current trajectory backbone (directed graph).
        path_to_fig : string, optional
            The path to save the figure, or don't save if it is None.
        **kwargs : dict, optional
            Other keyword arguments for plotting.
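
        Examples
        --------
        A minimal sketch (the root can be a state index or one of the label names)::

            model.infer_trajectory(root=0)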
        '''
        if isinstance(root, str):
            if root not in self.labels_map.values:
                raise ValueError("Root {} is not in the label names!".format(root))
            root = self.labels_map[self.labels_map['label_names']==root].index[0]

        if digraph is None:
            connected_comps = nx.node_connected_component(self.backbone, root)
            subG = self.backbone.subgraph(connected_comps)

            ## generate directed backbone which contains no loops
            DG = nx.DiGraph(nx.to_directed(self.backbone))
            temp = DG.subgraph(connected_comps)
            DG.remove_edges_from(temp.edges - nx.dfs_edges(DG, root))
            self.directed_backbone = DG
        else:
            if not nx.is_directed_acyclic_graph(digraph):
                raise ValueError("The graph 'digraph' should be a directed acyclic graph.")
            if set(digraph.nodes) != set(self.backbone.nodes):
                raise ValueError("The nodes in 'digraph' do not match the nodes in 'self.backbone'.")
            self.directed_backbone = digraph

            connected_comps = nx.node_connected_component(digraph.to_undirected(), root)
            subG = self.backbone.subgraph(connected_comps)

        if len(subG.edges)>0:
            milestone_net = self.inferer.build_milestone_net(subG, root)
            if self.inferer.no_loop is False and milestone_net.shape[0]<len(self.backbone.edges):
                warnings.warn("The directed graph shown is a minimum spanning tree of the estimated trajectory backbone to avoid arbitrary assignment of the directions.")
            self.pseudotime = self.inferer.comp_pseudotime(milestone_net, root, self.cell_position_projected)
        else:
            warnings.warn("There are no connected states starting from the given root.")
            self.pseudotime = -np.ones(self._adata.shape[0])

        self.adata.obs['pseudotime'] = self.pseudotime
        print("Pseudotime stored as 'pseudotime' in self.adata.obs")

        if visualize:
            self._adata.obs['pseudotime'] = self.pseudotime
            self.ax = self.plot_backbone(directed = True, color = color, **kwargs)
            if path_to_fig is not None:
                self.ax.figure.savefig(path_to_fig)
            self.ax.figure.show()

        return None

    def differential_expression_test(self, alpha: float = 0.05, cell_subset = None, order: int = 1):
        '''Differential gene expression test. All (selected and unselected) genes will be tested.
        Only cells in `cell_subset` will be used, which is useful when one needs to
        test differentially expressed genes on a branch of the inferred trajectory.

        Parameters
        ----------
        alpha : float, optional
            The cutoff of p-values.
        cell_subset : np.array, optional
            The subset of cells to be used for testing. If None, all cells will be used.
        order : int, optional
            The maximum order of pseudotime used in the regression.

        Returns
        ----------
        res_df : pandas.DataFrame
            The test results of expressed genes with two columns,
            the estimated coefficients and the adjusted p-values.
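
        Examples
        --------
        A minimal sketch (assumes `infer_trajectory` has been run so that the pseudotime exists)::

            res_df = model.differential_expression_test(alpha=0.05, order=1)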
        '''
        if not hasattr(self, 'pseudotime'):
            raise ReferenceError("Pseudotime does not exist! Please run 'infer_trajectory' first.")
        if cell_subset is None:
            cell_subset = np.arange(self.X_input.shape[0])
            print("All cells are selected.")
        if order < 1:
            raise ValueError("Maximal order of pseudotime in regression must be at least 1.")

        # Prepare X and Y for regression expression ~ rank(PDT) + covariates
        Y = self.adata.X[cell_subset,:]
#        std_Y = np.std(Y, ddof=1, axis=0, keepdims=True)
#        Y = np.divide(Y-np.mean(Y, axis=0, keepdims=True), std_Y, out=np.empty_like(Y)*np.nan, where=std_Y!=0)
        X = stats.rankdata(self.pseudotime[cell_subset])
        if order > 1:
            for _order in range(2, order+1):
                X = np.c_[X, X**_order]
        X = ((X-np.mean(X,axis=0, keepdims=True))/np.std(X, ddof=1, axis=0, keepdims=True))
        X = np.c_[np.ones((X.shape[0],1)), X]
        if self.covariates is not None:
            X = np.c_[X, self.covariates[cell_subset, :]]

        res_df = DE_test(Y, X, self.adata.var_names, i_test = np.array(list(range(1,order+1))), alpha = alpha)
        return res_df[res_df.pvalue_adjusted_1 != 0]

    def evaluate(self, milestone_net, begin_node_true, grouping = None,
                thres: float = 0.5, no_loop: bool = True, cutoff: Optional[float] = None,
                method: str = 'mean', path: Optional[str] = None):
        '''Evaluate the model.

        Parameters
        ----------
        milestone_net : pd.DataFrame
            The true milestone network. For real data, milestone_net is a DataFrame of the graph of nodes.
            E.g.

            from|to
            ---|---
            cluster 1 | cluster 1
            cluster 1 | cluster 2

            For synthetic data, milestone_net is a DataFrame of the (projected)
            positions of cells. The indexes are the orders of cells in the dataset.
            E.g.

            from|to|w
            ---|---|---
            cluster 1 | cluster 1 | 1
            cluster 1 | cluster 2 | 0.1
        begin_node_true : str or int
            The true begin node of the milestone.
        grouping : np.array, optional
            \([N,]\) The labels. For real data, grouping must be provided.

        Returns
        ----------
        res : pd.DataFrame
            The evaluation result.
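
        Examples
        --------
        A minimal sketch (hypothetical cluster names; `labels` stands for an array of
        true labels for real data)::

            milestone_net = pd.DataFrame({'from': ['0', '0'], 'to': ['1', '2']})
            res = model.evaluate(milestone_net, begin_node_true='0', grouping=labels)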
        '''
        if not hasattr(self, 'labels_map'):
            raise ValueError("No given labels for training.")

        '''
        # Evaluate for the whole dataset will ignore selected_cell_subset.
        if len(self.selected_cell_subset)!=len(self.cell_names):
            warnings.warn("Evaluate for the whole dataset.")
        '''

        # If begin_node_true is given as a string, encode it via the label map.
        # This dict is for the milestone net, whose labels are not merged;
        # all keys of label_map_dict are str.
        label_map_dict = dict()
        for i in range(self.labels_map.shape[0]):
            label_mapped = self.labels_map.loc[i]
            # merged cluster indices are joined by commas
            for each in label_mapped.values[0].split(","):
                label_map_dict[each] = i
        if isinstance(begin_node_true, str):
            begin_node_true = label_map_dict[begin_node_true]

        # For generated data, grouping information is already in milestone_net.
        if 'w' in milestone_net.columns:
            grouping = None

        # If milestone_net is provided, transform its labels to be numeric.
        if milestone_net is not None:
            milestone_net['from'] = [label_map_dict[x] for x in milestone_net["from"]]
            milestone_net['to'] = [label_map_dict[x] for x in milestone_net["to"]]

        # This dict maps (potentially merged) cluster names to indices.
        label_map_dict_for_merged_cluster = dict(zip(self.labels_map["label_names"],self.labels_map.index))
        mapped_labels = np.array([label_map_dict_for_merged_cluster[x] for x in self.labels])
        # Predicted begin node: the vertex closest (in mean squared distance)
        # to the latent positions of cells in the true begin cluster.
        begin_node_pred = int(np.argmin(np.mean((
            self.z[mapped_labels==begin_node_true,:,np.newaxis] -
            self.mu[np.newaxis,:,:])**2, axis=(0,1))))

        if cutoff is None:
            cutoff = 0.01

        G = self.backbone
        w = self.cell_position_projected
        pseudotime = self.pseudotime

        # 1. Topology
        G_pred = nx.Graph()
        G_pred.add_nodes_from(G.nodes)
        G_pred.add_edges_from(G.edges)
        nx.set_node_attributes(G_pred, False, 'is_init')
        G_pred.nodes[begin_node_pred]['is_init'] = True

        G_true = nx.Graph()
        G_true.add_nodes_from(G.nodes)
        # If 'grouping' is not provided, assume 'milestone_net' contains proportions.
        if grouping is None:
            G_true.add_edges_from(list(
                milestone_net[~pd.isna(milestone_net['w'])].groupby(['from', 'to']).count().index))
        # Otherwise, 'milestone_net' indicates edges.
        else:
            if milestone_net is not None:
                G_true.add_edges_from(list(
                    milestone_net.groupby(['from', 'to']).count().index))
            grouping = [label_map_dict[x] for x in grouping]
            grouping = np.array(grouping)
        G_true.remove_edges_from(nx.selfloop_edges(G_true))
        nx.set_node_attributes(G_true, False, 'is_init')
        G_true.nodes[begin_node_true]['is_init'] = True
        res = topology(G_true, G_pred)

        # 2. Milestones assignment
        if grouping is None:
            # Assign each cell to its 'from' milestone, unless it is closer
            # to 'to' (w < 0.5).
            milestones_true = milestone_net['from'].values.copy()
            milestones_true[(milestone_net['from']!=milestone_net['to'])
                           &(milestone_net['w']<0.5)] = milestone_net[(milestone_net['from']!=milestone_net['to'])
                                                                      &(milestone_net['w']<0.5)]['to'].values
        else:
            milestones_true = grouping
        milestones_pred = np.argmax(w, axis=1)
        # Rescale ARI from [-1, 1] to [0, 1].
        res['ARI'] = (adjusted_rand_score(milestones_true, milestones_pred) + 1)/2

        if grouping is None:
            n_samples = len(milestone_net)
            prop = np.zeros((n_samples,n_samples))
            prop[np.arange(n_samples), milestone_net['to']] = 1-milestone_net['w']
            prop[np.arange(n_samples), milestone_net['from']] = np.where(np.isnan(milestone_net['w']), 1, milestone_net['w'])
            res['GRI'] = get_GRI(prop, w)
        else:
            res['GRI'] = get_GRI(grouping, w)

        # 3. Correlation between geodesic distances / pseudotime
        if no_loop:
            if grouping is None:
                pseudotime_true = milestone_net['from'].values + 1 - milestone_net['w'].values
                pseudotime_true[np.isnan(pseudotime_true)] = milestone_net[pd.isna(milestone_net['w'])]['from'].values
            else:
                pseudotime_true = - np.ones(len(grouping))
                nx.set_edge_attributes(G_true, values = 1, name = 'weight')
                connected_comps = nx.node_connected_component(G_true, begin_node_true)
                subG = G_true.subgraph(connected_comps)
                milestone_net_true = self.inferer.build_milestone_net(subG, begin_node_true)
                if len(milestone_net_true)>0:
                    pseudotime_true[grouping==int(milestone_net_true[0,0])] = 0
                    for i in range(len(milestone_net_true)):
                        pseudotime_true[grouping==int(milestone_net_true[i,1])] = milestone_net_true[i,-1]
            pseudotime_true = pseudotime_true[pseudotime>-1]
            pseudotime_pred = pseudotime[pseudotime>-1]
            # Rescale the Pearson correlation from [-1, 1] to [0, 1].
            res['PDT score'] = (np.corrcoef(pseudotime_true,pseudotime_pred)[0,1]+1)/2
        else:
            res['PDT score'] = np.nan

        # 4. Shape (disabled)
        # score_cos_theta = 0
        # for (_from,_to) in G.edges:
        #     _z = self.z[(w[:,_from]>0) & (w[:,_to]>0),:]
        #     v_1 = _z - self.mu[:,_from]
        #     v_2 = _z - self.mu[:,_to]
        #     cos_theta = np.sum(v_1*v_2, -1)/(np.linalg.norm(v_1,axis=-1)*np.linalg.norm(v_2,axis=-1)+1e-12)
        #     score_cos_theta += np.sum((1-cos_theta)/2)
        # res['score_cos_theta'] = score_cos_theta/(np.sum(np.sum(w>0, axis=-1)==2)+1e-12)
        return res
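
    # A hedged usage sketch (assumes `model` is a trained VITAE instance on
    # which `infer_trajectory` has been run; `true_labels` and the cluster
    # name below are placeholders):
    #
    #     res = model.evaluate(milestone_net, begin_node_true='cluster 1',
    #                          grouping=true_labels)
    #     print(res['ARI'], res['GRI'], res['PDT score'])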
    def save_model(self, path_to_file: str = 'model.checkpoint', save_adata: bool = False):
        '''Save model weights.

        Parameters
        ----------
        path_to_file : str, optional
            The path to the weight files of the pre-trained or trained model.
        save_adata : boolean, optional
            Whether to save adata or not.
        '''
        self.vae.save_weights(path_to_file)
        if hasattr(self, 'labels') and self.labels is not None:
            with open(path_to_file + '.label', 'wb') as f:
                np.save(f, self.labels)
        # Store the configuration needed to rebuild the network before the
        # weights can be loaded back.
        with open(path_to_file + '.config', 'wb') as f:
            self.dim_origin = self.X_input.shape[1]
            np.save(f, np.array([
                self.dim_origin, self.dimensions, self.dim_latent,
                self.model_type, 0 if self.covariates is None else self.covariates.shape[1]], dtype=object))
        if hasattr(self, 'inferer') and hasattr(self, 'uncertainty'):
            with open(path_to_file + '.inference', 'wb') as f:
                np.save(f, np.array([
                    self.pi, self.mu, self.pc_x, self.cell_position_posterior, self.uncertainty,
                    self.z, self.cell_position_variance], dtype=object))
        if save_adata:
            self.adata.write(path_to_file + '.adata.h5ad')
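
    # A hedged usage sketch: this writes the weights plus '.label', '.config',
    # '.inference', and (optionally) '.adata.h5ad' files next to the checkpoint.
    #
    #     model.save_model('model.checkpoint', save_adata=True)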
    def load_model(self, path_to_file: str = 'model.checkpoint', load_labels: bool = False, load_adata: bool = False):
        '''Load model weights.

        Parameters
        ----------
        path_to_file : str, optional
            The path to the weight files of the pre-trained or trained model.
        load_labels : boolean, optional
            Whether to load clustering labels or not.
            If load_labels is True, the LatentSpace layer will be initialized based on the model.
            If load_labels is False, the LatentSpace layer will not be initialized.
        load_adata : boolean, optional
            Whether to load adata or not.
        '''
        if not os.path.exists(path_to_file + '.config'):
            raise AssertionError('Config file does not exist!')
        if load_labels and not os.path.exists(path_to_file + '.label'):
            raise AssertionError('Label file does not exist!')

        with open(path_to_file + '.config', 'rb') as f:
            [self.dim_origin, self.dimensions,
             self.dim_latent, self.model_type, cov_dim] = np.load(f, allow_pickle=True)
        self.vae = model.VariationalAutoEncoder(
            self.dim_origin, self.dimensions,
            self.dim_latent, self.model_type, False if cov_dim == 0 else True
        )

        if load_labels:
            with open(path_to_file + '.label', 'rb') as f:
                cluster_labels = np.load(f, allow_pickle=True)
            self.init_latent_space(cluster_labels, dist_thres=0)
            if os.path.exists(path_to_file + '.inference'):
                with open(path_to_file + '.inference', 'rb') as f:
                    arr = np.load(f, allow_pickle=True)
                    # Checkpoints with eight entries additionally include D_JS.
                    if len(arr) == 8:
                        [self.pi, self.mu, self.pc_x, self.cell_position_posterior, self.uncertainty,
                         self.D_JS, self.z, self.cell_position_variance] = arr
                    else:
                        [self.pi, self.mu, self.pc_x, self.cell_position_posterior, self.uncertainty,
                         self.z, self.cell_position_variance] = arr
                self._adata_z = sc.AnnData(self.z)
                sc.pp.neighbors(self._adata_z)
        # Build the encoder and decoder weights with a dummy forward pass
        # before loading the stored weights.
        self.vae.encoder(np.zeros((1, self.dim_origin + cov_dim)))
        self.vae.decoder(np.expand_dims(np.zeros((1, self.dim_latent + cov_dim)), 1))

        self.vae.load_weights(path_to_file)
        self.update_z()
        if load_adata:
            if not os.path.exists(path_to_file + '.adata.h5ad'):
                raise AssertionError('AnnData file does not exist!')
            self.adata = sc.read_h5ad(path_to_file + '.adata.h5ad')
            self._adata.obs = self.adata.obs.copy()
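
    # A hedged round-trip sketch (paths and construction arguments are
    # placeholders; the new instance must be built with compatible data):
    #
    #     model.save_model('model.checkpoint', save_adata=True)
    #     new_model = VITAE(adata)
    #     new_model.load_model('model.checkpoint', load_labels=True, load_adata=True)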