a b/src/multivelo/dynamical_chrom_func.py
1
from multivelo import mv_logging as logg
2
from multivelo import settings
3
4
import os
5
import sys
6
import numpy as np
7
from numpy.linalg import norm
8
import matplotlib.pyplot as plt
9
from scipy import sparse
10
from scipy.sparse import coo_matrix
11
from scipy.optimize import minimize
12
from scipy.spatial import KDTree
13
from sklearn.metrics import pairwise_distances
14
from sklearn.mixture import GaussianMixture
15
from scanpy import Neighbors
16
import scvelo as scv
17
import pandas as pd
18
import seaborn as sns
19
from numba import njit
20
import numba
21
from numba.typed import List
22
from tqdm.auto import tqdm
23
from joblib import Parallel, delayed
24
import math
25
import torch
26
from torch import nn
27
28
# Append the parent directory of this module to the import search path.
current_path = os.path.dirname(__file__)
src_path = os.path.join(current_path, "..")
sys.path.append(src_path)
31
32
33
# a function to check for invalid values of different parameters
34
def _bump_rate(value, step):
    """Return ``value`` raised by ``step``, or by one ULP if ``step`` is lost."""
    bumped = value + step
    if bumped == value:
        # ``step`` is below the floating-point resolution at this magnitude,
        # so move to the next representable number instead.
        bumped = np.nextafter(value, np.inf)
    return bumped


def check_params(alpha_c,
                 alpha,
                 beta,
                 gamma,
                 c0=None,
                 u0=None,
                 s0=None):
    """Replace invalid rate parameters and initial values with usable ones.

    The checks are applied in this order, each one reported through
    ``logg.error``:

    1. infinite values are replaced by ``1e10``;
    2. NaN values are replaced by ``1e-10``;
    3. rate parameters below ``1e-7`` are replaced by ``1e-10``;
    4. if two of ``alpha_c``, ``beta`` and ``gamma`` coincide *after* the
       replacements above, ``beta`` or ``gamma`` is nudged upwards so that
       the three rates are pairwise distinct (equal rates lead to a divide
       by zero, as the log messages state).

    Args:
        alpha_c: rate parameter (scalar).
        alpha: rate parameter (scalar).
        beta: rate parameter (scalar).
        gamma: rate parameter (scalar).
        c0: optional initial value; only checked for inf/NaN.
        u0: optional initial value; only checked for inf/NaN.
        s0: optional initial value; only checked for inf/NaN.

    Returns:
        ``(alpha_c, alpha, beta, gamma, c0, u0, s0)`` when ``c0``, ``u0`` and
        ``s0`` are all given, otherwise ``(alpha_c, alpha, beta, gamma)``.
    """
    inf_fix = 1e10
    zero_fix = 1e-10

    initials = (("c0", c0), ("u0", u0), ("s0", s0))
    rates = (("alpha_c", alpha_c), ("alpha", alpha), ("beta", beta),
             ("gamma", gamma))
    fixed = dict(initials + rates)

    # check if any of our parameters are infinite
    for name, value in initials + rates:
        if value is not None and math.isinf(value):
            logg.error(f"{name} is infinite.", v=1)
            fixed[name] = inf_fix

    # check if any of our parameters are nan
    for name, value in initials + rates:
        if value is not None and math.isnan(value):
            logg.error(f"{name} is Nan.", v=1)
            fixed[name] = zero_fix

    # check if any of our rate parameters are 0
    for name, value in rates:
        if value < 1e-7:
            logg.error(f"{name} is zero.", v=1)
            fixed[name] = zero_fix

    new_alpha_c = fixed["alpha_c"]
    new_alpha = fixed["alpha"]
    new_beta = fixed["beta"]
    new_gamma = fixed["gamma"]

    # Collisions are tested on the repaired values: two different invalid
    # inputs can be mapped onto the same replacement value.
    if new_beta == new_alpha_c:
        new_beta = _bump_rate(new_beta, zero_fix)
        logg.error("alpha_c and beta are equal, leading to divide by zero",
                   v=1)

    # gamma can collide with at most two other rates and every bump moves it
    # upwards, so two passes are enough to clear both of them.
    for _ in range(2):
        if new_gamma == new_beta:
            new_gamma = _bump_rate(new_gamma, zero_fix)
            logg.error("gamma and beta are equal, leading to divide by zero",
                       v=1)
        elif new_gamma == new_alpha_c:
            new_gamma = _bump_rate(new_gamma, zero_fix)
            logg.error("gamma and alpha_c are equal, leading to divide by "
                       "zero", v=1)

    if c0 is not None and u0 is not None and s0 is not None:
        return (new_alpha_c, new_alpha, new_beta, new_gamma,
                fixed["c0"], fixed["u0"], fixed["s0"])

    return new_alpha_c, new_alpha, new_beta, new_gamma
132
133
134
@njit(
    locals={
            "res": numba.types.float64[:, ::1],
            "eat": numba.types.float64[::1],
            "ebt": numba.types.float64[::1],
            "egt": numba.types.float64[::1],
    },
    fastmath=True)
def predict_exp(tau,
                c0,
                u0,
                s0,
                alpha_c,
                alpha,
                beta,
                gamma,
                scale_cc=1,
                pred_r=True,
                chrom_open=True,
                backward=False,
                rna_only=False):
    """Evaluate the closed-form (c, u, s) values at the time offsets ``tau``.

    Args:
        tau: 1-D array of time offsets measured from the initial state.
        c0, u0, s0: initial values for columns 0, 1 and 2 of the result.
        alpha_c, alpha, beta, gamma: rate parameters.
        scale_cc: factor applied to ``alpha_c`` when ``chrom_open`` is False
            (and ``rna_only`` is False).
        pred_r: when False, columns 1 and 2 are filled with zeros.
        chrom_open: selects the target level ``kc`` of column 0
            (1 when True, 0 when False).
        backward: when True, ``tau`` is negated before evaluation.
        rna_only: when True, ``kc`` and ``c0`` are both forced to 1.

    Returns:
        Array of shape ``(len(tau), 3)``; an empty ``(0, 3)`` array when
        ``tau`` is empty.
    """
    if len(tau) == 0:
        return np.empty((0, 3))
    if backward:
        tau = -tau
    res = np.empty((len(tau), 3))
    # NOTE(review): ``eat`` is computed before ``alpha_c`` is rescaled by
    # ``scale_cc`` below, while ``const`` uses the rescaled value - confirm
    # that this ordering is intended.
    eat = np.exp(-alpha_c * tau)
    ebt = np.exp(-beta * tau)
    egt = np.exp(-gamma * tau)
    if rna_only:
        kc = 1
        c0 = 1
    else:
        if chrom_open:
            kc = 1
        else:
            kc = 0
            alpha_c *= scale_cc

    # Callers are expected to keep alpha_c, beta and gamma pairwise distinct;
    # the differences below appear as denominators.
    const = (kc - c0) * alpha / (beta - alpha_c)

    res[:, 0] = kc - (kc - c0) * eat

    if pred_r:

        res[:, 1] = u0 * ebt + (alpha * kc / beta) * (1 - ebt)
        res[:, 1] += const * (ebt - eat)

        res[:, 2] = s0 * egt + (alpha * kc / gamma) * (1 - egt)
        res[:, 2] += ((beta / (gamma - beta)) *
                      ((alpha * kc / beta) - u0 - const) * (egt - ebt))
        res[:, 2] += (beta / (gamma - alpha_c)) * const * (egt - eat)

    else:
        res[:, 1] = np.zeros(len(tau))
        res[:, 2] = np.zeros(len(tau))
    return res
192
193
194
@njit(locals={
            "exp_sw1": numba.types.float64[:, ::1],
            "exp_sw2": numba.types.float64[:, ::1],
            "exp_sw3": numba.types.float64[:, ::1],
            "exp1": numba.types.float64[:, ::1],
            "exp2": numba.types.float64[:, ::1],
            "exp3": numba.types.float64[:, ::1],
            "exp4": numba.types.float64[:, ::1],
            "tau_sw1": numba.types.float64[::1],
            "tau_sw2": numba.types.float64[::1],
            "tau_sw3": numba.types.float64[::1],
            "tau1": numba.types.float64[::1],
            "tau2": numba.types.float64[::1],
            "tau3": numba.types.float64[::1],
            "tau4": numba.types.float64[::1]
    },
    fastmath=True)
def generate_exp(tau_list,
                 t_sw_array,
                 alpha_c,
                 alpha,
                 beta,
                 gamma,
                 scale_cc=1,
                 model=1,
                 rna_only=False):
    """Predict the piecewise trajectory and its values at the switch times.

    Each segment is evaluated with ``predict_exp``; the state reached at one
    switch time is the initial state of the next segment.  ``model``
    (0, 1 or 2) selects which ``pred_r`` / ``chrom_open`` / ``alpha``
    settings are used for each segment.

    Args:
        tau_list: sequence of per-segment time-offset arrays (up to four are
            read, depending on ``len(t_sw_array)``), or ``None`` to compute
            only the switch-time states.
        t_sw_array: switch times; its length (0 to 3) is the number of
            switches that are used.
        alpha_c, alpha, beta, gamma: rate parameters.
        scale_cc: forwarded to ``predict_exp``.
        model: segment configuration, 0, 1 or 2.
        rna_only: forwarded to ``predict_exp``.

    Returns:
        ``((exp1, exp2, exp3, exp4), (exp_sw1, exp_sw2, exp_sw3))``.  Entries
        that are not computed are empty ``(0, 3)`` arrays; when ``tau_list``
        is ``None`` all four ``exp*`` entries are empty.
    """
    # Equal rates would make the denominators in predict_exp vanish.
    if beta == alpha_c:
        beta += 1e-3
    if gamma == beta or gamma == alpha_c:
        gamma += 1e-3
    switch = len(t_sw_array)
    if switch >= 1:
        tau_sw1 = np.array([t_sw_array[0]])
        if switch >= 2:
            tau_sw2 = np.array([t_sw_array[1] - t_sw_array[0]])
            if switch == 3:
                tau_sw3 = np.array([t_sw_array[2] - t_sw_array[1]])
    exp_sw1, exp_sw2, exp_sw3 = (np.empty((0, 3)),
                                 np.empty((0, 3)),
                                 np.empty((0, 3)))
    if tau_list is None:
        # Only the states at the switch times are requested.
        if model == 0:
            if switch >= 1:
                exp_sw1 = predict_exp(tau_sw1, 0, 0, 0, alpha_c, alpha, beta,
                                      gamma, pred_r=False, scale_cc=scale_cc,
                                      rna_only=rna_only)
                if switch >= 2:
                    exp_sw2 = predict_exp(tau_sw2, exp_sw1[0, 0],
                                          exp_sw1[0, 1], exp_sw1[0, 2],
                                          alpha_c, alpha, beta, gamma,
                                          pred_r=False, chrom_open=False,
                                          scale_cc=scale_cc, rna_only=rna_only)
                    if switch >= 3:
                        exp_sw3 = predict_exp(tau_sw3, exp_sw2[0, 0],
                                              exp_sw2[0, 1], exp_sw2[0, 2],
                                              alpha_c, alpha, beta, gamma,
                                              chrom_open=False,
                                              scale_cc=scale_cc,
                                              rna_only=rna_only)
        elif model == 1:
            if switch >= 1:
                exp_sw1 = predict_exp(tau_sw1, 0, 0, 0, alpha_c, alpha, beta,
                                      gamma, pred_r=False, scale_cc=scale_cc,
                                      rna_only=rna_only)
                if switch >= 2:
                    exp_sw2 = predict_exp(tau_sw2, exp_sw1[0, 0],
                                          exp_sw1[0, 1], exp_sw1[0, 2],
                                          alpha_c, alpha, beta, gamma,
                                          scale_cc=scale_cc, rna_only=rna_only)
                    if switch >= 3:
                        exp_sw3 = predict_exp(tau_sw3, exp_sw2[0, 0],
                                              exp_sw2[0, 1], exp_sw2[0, 2],
                                              alpha_c, alpha, beta, gamma,
                                              chrom_open=False,
                                              scale_cc=scale_cc,
                                              rna_only=rna_only)
        elif model == 2:
            if switch >= 1:
                exp_sw1 = predict_exp(tau_sw1, 0, 0, 0, alpha_c, alpha, beta,
                                      gamma, pred_r=False, scale_cc=scale_cc,
                                      rna_only=rna_only)
                if switch >= 2:
                    exp_sw2 = predict_exp(tau_sw2, exp_sw1[0, 0],
                                          exp_sw1[0, 1], exp_sw1[0, 2],
                                          alpha_c, alpha, beta, gamma,
                                          scale_cc=scale_cc, rna_only=rna_only)
                    if switch >= 3:
                        exp_sw3 = predict_exp(tau_sw3, exp_sw2[0, 0],
                                              exp_sw2[0, 1], exp_sw2[0, 2],
                                              alpha_c, 0, beta, gamma,
                                              scale_cc=scale_cc,
                                              rna_only=rna_only)

        return (np.empty((0, 3)), np.empty((0, 3)), np.empty((0, 3)),
                np.empty((0, 3))), (exp_sw1, exp_sw2, exp_sw3)

    tau1 = tau_list[0]
    if switch >= 1:
        tau2 = tau_list[1]
        if switch >= 2:
            tau3 = tau_list[2]
            if switch == 3:
                tau4 = tau_list[3]
    exp1, exp2, exp3, exp4 = (np.empty((0, 3)), np.empty((0, 3)),
                              np.empty((0, 3)), np.empty((0, 3)))
    if model == 0:
        exp1 = predict_exp(tau1, 0, 0, 0, alpha_c, alpha, beta, gamma,
                           pred_r=False, scale_cc=scale_cc, rna_only=rna_only)
        if switch >= 1:
            exp_sw1 = predict_exp(tau_sw1, 0, 0, 0, alpha_c, alpha, beta,
                                  gamma, pred_r=False, scale_cc=scale_cc,
                                  rna_only=rna_only)
            exp2 = predict_exp(tau2, exp_sw1[0, 0], exp_sw1[0, 1],
                               exp_sw1[0, 2], alpha_c, alpha, beta, gamma,
                               pred_r=False, chrom_open=False,
                               scale_cc=scale_cc, rna_only=rna_only)
            if switch >= 2:
                exp_sw2 = predict_exp(tau_sw2, exp_sw1[0, 0], exp_sw1[0, 1],
                                      exp_sw1[0, 2], alpha_c, alpha, beta,
                                      gamma, pred_r=False, chrom_open=False,
                                      scale_cc=scale_cc, rna_only=rna_only)
                exp3 = predict_exp(tau3, exp_sw2[0, 0], exp_sw2[0, 1],
                                   exp_sw2[0, 2], alpha_c, alpha, beta, gamma,
                                   chrom_open=False, scale_cc=scale_cc,
                                   rna_only=rna_only)
                if switch == 3:
                    exp_sw3 = predict_exp(tau_sw3, exp_sw2[0, 0],
                                          exp_sw2[0, 1], exp_sw2[0, 2],
                                          alpha_c, alpha, beta, gamma,
                                          chrom_open=False, scale_cc=scale_cc,
                                          rna_only=rna_only)
                    exp4 = predict_exp(tau4, exp_sw3[0, 0], exp_sw3[0, 1],
                                       exp_sw3[0, 2], alpha_c, 0, beta, gamma,
                                       chrom_open=False, scale_cc=scale_cc,
                                       rna_only=rna_only)
    elif model == 1:
        exp1 = predict_exp(tau1, 0, 0, 0, alpha_c, alpha, beta, gamma,
                           pred_r=False, scale_cc=scale_cc, rna_only=rna_only)
        if switch >= 1:
            exp_sw1 = predict_exp(tau_sw1, 0, 0, 0, alpha_c, alpha, beta,
                                  gamma, pred_r=False, scale_cc=scale_cc,
                                  rna_only=rna_only)
            exp2 = predict_exp(tau2, exp_sw1[0, 0], exp_sw1[0, 1],
                               exp_sw1[0, 2], alpha_c, alpha, beta, gamma,
                               scale_cc=scale_cc, rna_only=rna_only)
            if switch >= 2:
                exp_sw2 = predict_exp(tau_sw2, exp_sw1[0, 0], exp_sw1[0, 1],
                                      exp_sw1[0, 2], alpha_c, alpha, beta,
                                      gamma, scale_cc=scale_cc,
                                      rna_only=rna_only)
                exp3 = predict_exp(tau3, exp_sw2[0, 0], exp_sw2[0, 1],
                                   exp_sw2[0, 2], alpha_c, alpha, beta, gamma,
                                   chrom_open=False, scale_cc=scale_cc,
                                   rna_only=rna_only)
                if switch == 3:
                    exp_sw3 = predict_exp(tau_sw3, exp_sw2[0, 0],
                                          exp_sw2[0, 1], exp_sw2[0, 2],
                                          alpha_c, alpha, beta, gamma,
                                          chrom_open=False, scale_cc=scale_cc,
                                          rna_only=rna_only)
                    exp4 = predict_exp(tau4, exp_sw3[0, 0], exp_sw3[0, 1],
                                       exp_sw3[0, 2], alpha_c, 0, beta, gamma,
                                       chrom_open=False, scale_cc=scale_cc,
                                       rna_only=rna_only)
    elif model == 2:
        exp1 = predict_exp(tau1, 0, 0, 0, alpha_c, alpha, beta, gamma,
                           pred_r=False, scale_cc=scale_cc, rna_only=rna_only)
        if switch >= 1:
            exp_sw1 = predict_exp(tau_sw1, 0, 0, 0, alpha_c, alpha, beta,
                                  gamma, pred_r=False, scale_cc=scale_cc,
                                  rna_only=rna_only)
            exp2 = predict_exp(tau2, exp_sw1[0, 0], exp_sw1[0, 1],
                               exp_sw1[0, 2], alpha_c, alpha, beta, gamma,
                               scale_cc=scale_cc, rna_only=rna_only)
            if switch >= 2:
                exp_sw2 = predict_exp(tau_sw2, exp_sw1[0, 0], exp_sw1[0, 1],
                                      exp_sw1[0, 2], alpha_c, alpha, beta,
                                      gamma, scale_cc=scale_cc,
                                      rna_only=rna_only)
                exp3 = predict_exp(tau3, exp_sw2[0, 0], exp_sw2[0, 1],
                                   exp_sw2[0, 2], alpha_c, 0, beta, gamma,
                                   scale_cc=scale_cc, rna_only=rna_only)
                if switch == 3:
                    exp_sw3 = predict_exp(tau_sw3, exp_sw2[0, 0],
                                          exp_sw2[0, 1], exp_sw2[0, 2],
                                          alpha_c, 0, beta, gamma,
                                          scale_cc=scale_cc, rna_only=rna_only)
                    exp4 = predict_exp(tau4, exp_sw3[0, 0], exp_sw3[0, 1],
                                       exp_sw3[0, 2], alpha_c, 0, beta, gamma,
                                       chrom_open=False, scale_cc=scale_cc,
                                       rna_only=rna_only)
    return (exp1, exp2, exp3, exp4), (exp_sw1, exp_sw2, exp_sw3)
387
388
389
@njit(locals={
            "exp_sw1": numba.types.float64[:, ::1],
            "exp_sw2": numba.types.float64[:, ::1],
            "exp_sw3": numba.types.float64[:, ::1],
            "exp1": numba.types.float64[:, ::1],
            "exp2": numba.types.float64[:, ::1],
            "exp3": numba.types.float64[:, ::1],
            "exp4": numba.types.float64[:, ::1],
            "tau_sw1": numba.types.float64[::1],
            "tau_sw2": numba.types.float64[::1],
            "tau_sw3": numba.types.float64[::1],
            "tau1": numba.types.float64[::1],
            "tau2": numba.types.float64[::1],
            "tau3": numba.types.float64[::1],
            "tau4": numba.types.float64[::1]
    },
    fastmath=True)
def generate_exp_backward(tau_list, t_sw_array, alpha_c, alpha, beta, gamma,
                          scale_cc=1, model=1):
    """Predict the piecewise trajectory backwards in time.

    Every segment is evaluated by ``predict_exp`` with ``backward=True``.
    The first segment starts from ``(1e-3, 1e-3, 1e-3)``; the state reached
    at one switch time is the initial state of the next segment.  ``model``
    (0, 1 or 2) selects the ``chrom_open`` / ``alpha`` settings per segment.

    Args:
        tau_list: sequence of per-segment time-offset arrays (up to three are
            read, depending on ``len(t_sw_array)``), or ``None`` to compute
            only the switch-time states.
        t_sw_array: switch times; at most the first two entries are used.
        alpha_c, alpha, beta, gamma: rate parameters.
        scale_cc: forwarded to ``predict_exp``.
        model: segment configuration, 0, 1 or 2.

    Returns:
        ``((exp1, exp2, exp3), (exp_sw1, exp_sw2))``.  Switch states that are
        not computed are empty ``(0, 3)`` arrays.  When ``tau_list`` is
        ``None`` the three ``exp*`` entries are empty ``(0, 0)`` arrays.
    """
    # Equal rates would make the denominators in predict_exp vanish.
    if beta == alpha_c:
        beta += 1e-3
    if gamma == beta or gamma == alpha_c:
        gamma += 1e-3
    switch = len(t_sw_array)
    if switch >= 1:
        tau_sw1 = np.array([t_sw_array[0]])
        if switch >= 2:
            tau_sw2 = np.array([t_sw_array[1] - t_sw_array[0]])
    # Give the switch states a defined value even when fewer than two
    # switches are available, so the returned tuple is always valid.
    exp_sw1, exp_sw2 = np.empty((0, 3)), np.empty((0, 3))
    if tau_list is None:
        # Only the states at the switch times are requested.
        if model == 0:
            if switch >= 1:
                exp_sw1 = predict_exp(tau_sw1, 1e-3, 1e-3, 1e-3, alpha_c, 0,
                                      beta, gamma, scale_cc=scale_cc,
                                      chrom_open=False, backward=True)
                if switch >= 2:
                    exp_sw2 = predict_exp(tau_sw2, exp_sw1[0, 0],
                                          exp_sw1[0, 1], exp_sw1[0, 2],
                                          alpha_c, alpha, beta, gamma,
                                          scale_cc=scale_cc, chrom_open=False,
                                          backward=True)
        elif model == 1:
            if switch >= 1:
                exp_sw1 = predict_exp(tau_sw1, 1e-3, 1e-3, 1e-3, alpha_c, 0,
                                      beta, gamma, scale_cc=scale_cc,
                                      chrom_open=False, backward=True)
                if switch >= 2:
                    exp_sw2 = predict_exp(tau_sw2, exp_sw1[0, 0],
                                          exp_sw1[0, 1], exp_sw1[0, 2],
                                          alpha_c, alpha, beta, gamma,
                                          scale_cc=scale_cc, chrom_open=False,
                                          backward=True)
        elif model == 2:
            if switch >= 1:
                exp_sw1 = predict_exp(tau_sw1, 1e-3, 1e-3, 1e-3, alpha_c, 0,
                                      beta, gamma, scale_cc=scale_cc,
                                      chrom_open=False, backward=True)
                if switch >= 2:
                    exp_sw2 = predict_exp(tau_sw2, exp_sw1[0, 0],
                                          exp_sw1[0, 1], exp_sw1[0, 2],
                                          alpha_c, 0, beta, gamma,
                                          scale_cc=scale_cc, backward=True)
        return (np.empty((0, 0)),
                np.empty((0, 0)),
                np.empty((0, 0))), (exp_sw1, exp_sw2)

    tau1 = tau_list[0]
    if switch >= 1:
        tau2 = tau_list[1]
        if switch >= 2:
            tau3 = tau_list[2]

    exp1, exp2, exp3 = np.empty((0, 3)), np.empty((0, 3)), np.empty((0, 3))
    if model == 0:
        exp1 = predict_exp(tau1, 1e-3, 1e-3, 1e-3, alpha_c, 0, beta, gamma,
                           scale_cc=scale_cc, chrom_open=False, backward=True)
        if switch >= 1:
            exp_sw1 = predict_exp(tau_sw1, 1e-3, 1e-3, 1e-3, alpha_c, 0, beta,
                                  gamma, scale_cc=scale_cc, chrom_open=False,
                                  backward=True)
            exp2 = predict_exp(tau2, exp_sw1[0, 0], exp_sw1[0, 1],
                               exp_sw1[0, 2], alpha_c, alpha, beta, gamma,
                               scale_cc=scale_cc, chrom_open=False,
                               backward=True)
            if switch >= 2:
                exp_sw2 = predict_exp(tau_sw2, exp_sw1[0, 0], exp_sw1[0, 1],
                                      exp_sw1[0, 2], alpha_c, alpha, beta,
                                      gamma, scale_cc=scale_cc,
                                      chrom_open=False, backward=True)
                # The third segment starts from the second switch state and
                # is evaluated at its own offsets tau3.
                exp3 = predict_exp(tau3, exp_sw2[0, 0], exp_sw2[0, 1],
                                   exp_sw2[0, 2], alpha_c, alpha, beta, gamma,
                                   scale_cc=scale_cc, chrom_open=False,
                                   backward=True)
    elif model == 1:
        exp1 = predict_exp(tau1, 1e-3, 1e-3, 1e-3, alpha_c, 0, beta, gamma,
                           scale_cc=scale_cc, chrom_open=False, backward=True)
        if switch >= 1:
            exp_sw1 = predict_exp(tau_sw1, 1e-3, 1e-3, 1e-3, alpha_c, 0, beta,
                                  gamma, scale_cc=scale_cc, chrom_open=False,
                                  backward=True)
            exp2 = predict_exp(tau2, exp_sw1[0, 0], exp_sw1[0, 1],
                               exp_sw1[0, 2], alpha_c, alpha, beta, gamma,
                               scale_cc=scale_cc, chrom_open=False,
                               backward=True)
            if switch >= 2:
                exp_sw2 = predict_exp(tau_sw2, exp_sw1[0, 0], exp_sw1[0, 1],
                                      exp_sw1[0, 2], alpha_c, alpha, beta,
                                      gamma, scale_cc=scale_cc,
                                      chrom_open=False, backward=True)
                exp3 = predict_exp(tau3, exp_sw2[0, 0], exp_sw2[0, 1],
                                   exp_sw2[0, 2], alpha_c, alpha, beta, gamma,
                                   scale_cc=scale_cc, backward=True)
    elif model == 2:
        exp1 = predict_exp(tau1, 1e-3, 1e-3, 1e-3, alpha_c, 0, beta, gamma,
                           scale_cc=scale_cc, chrom_open=False, backward=True)
        if switch >= 1:
            # Same settings as exp1: the switch state is the end point of
            # that segment, so alpha must be 0 here as well.
            exp_sw1 = predict_exp(tau_sw1, 1e-3, 1e-3, 1e-3, alpha_c, 0, beta,
                                  gamma, scale_cc=scale_cc, chrom_open=False,
                                  backward=True)
            exp2 = predict_exp(tau2, exp_sw1[0, 0], exp_sw1[0, 1],
                               exp_sw1[0, 2], alpha_c, 0, beta, gamma,
                               scale_cc=scale_cc, backward=True)
            if switch >= 2:
                exp_sw2 = predict_exp(tau_sw2, exp_sw1[0, 0], exp_sw1[0, 1],
                                      exp_sw1[0, 2], alpha_c, 0, beta, gamma,
                                      scale_cc=scale_cc, backward=True)
                exp3 = predict_exp(tau3, exp_sw2[0, 0], exp_sw2[0, 1],
                                   exp_sw2[0, 2], alpha_c, alpha, beta, gamma,
                                   scale_cc=scale_cc, backward=True)
    return (exp1, exp2, exp3), (exp_sw1, exp_sw2)
509
510
511
@njit(locals={
            "res": numba.types.float64[:, ::1],
    },
    fastmath=True)
def ss_exp(alpha_c, alpha, beta, gamma, pred_r=True, chrom_open=True):
    """Return the steady-state (c, u, s) values as a ``(1, 3)`` array.

    When ``chrom_open`` is False all three entries are 0.  Otherwise the
    first entry is 1 and the other two are ``alpha / beta`` and
    ``alpha / gamma`` if ``pred_r`` is True, or 0 if it is False.
    ``alpha_c`` is accepted but not used.
    """
    res = np.empty((1, 3))
    if not chrom_open:
        res[0, 0] = 0
        res[0, 1] = 0
        res[0, 2] = 0
    else:
        res[0, 0] = 1
        if pred_r:
            res[0, 1] = alpha / beta
            res[0, 2] = alpha / gamma
        else:
            res[0, 1] = 0
            res[0, 2] = 0
    return res
530
531
532
@njit(locals={
            "ss1": numba.types.float64[:, ::1],
            "ss2": numba.types.float64[:, ::1],
            "ss3": numba.types.float64[:, ::1],
            "ss4": numba.types.float64[:, ::1]
    },
    fastmath=True)
def compute_ss_exp(alpha_c, alpha, beta, gamma, model=0):
    """Stack the four ``ss_exp`` steady states of a model into a (4, 3) array.

    ``model`` (0, 1 or 2) selects the ``pred_r`` / ``chrom_open`` / ``alpha``
    settings passed to ``ss_exp`` for each of the four rows.
    """
    if model == 0:
        ss1 = ss_exp(alpha_c, alpha, beta, gamma, pred_r=False)
        ss2 = ss_exp(alpha_c, alpha, beta, gamma, pred_r=False,
                     chrom_open=False)
        ss3 = ss_exp(alpha_c, alpha, beta, gamma, chrom_open=False)
        ss4 = ss_exp(alpha_c, 0, beta, gamma, chrom_open=False)
    elif model == 1:
        ss1 = ss_exp(alpha_c, alpha, beta, gamma, pred_r=False)
        ss2 = ss_exp(alpha_c, alpha, beta, gamma)
        ss3 = ss_exp(alpha_c, alpha, beta, gamma, chrom_open=False)
        ss4 = ss_exp(alpha_c, 0, beta, gamma, chrom_open=False)
    elif model == 2:
        ss1 = ss_exp(alpha_c, alpha, beta, gamma, pred_r=False)
        ss2 = ss_exp(alpha_c, alpha, beta, gamma)
        ss3 = ss_exp(alpha_c, 0, beta, gamma)
        ss4 = ss_exp(alpha_c, 0, beta, gamma, chrom_open=False)
    return np.vstack((ss1, ss2, ss3, ss4))
557
558
559
@njit(fastmath=True)
def velocity_equations(c, u, s, alpha_c, alpha, beta, gamma, scale_cc=1,
                       pred_r=True, chrom_open=True, rna_only=False):
    """Return the time derivatives ``(vc, vu, vs)`` at the given state.

    Args:
        c, u, s: 1-D arrays holding the current state.
        alpha_c, alpha, beta, gamma: rate parameters.
        scale_cc: factor applied to ``alpha_c`` when ``chrom_open`` is False.
        pred_r: when False, ``vu`` and ``vs`` are returned as zeros.
        chrom_open: when True ``vc = alpha_c - alpha_c * c``, otherwise
            ``vc = -alpha_c * c`` (with the rescaled ``alpha_c``).
        rna_only: when True, ``c`` is replaced by an array of ones.

    Returns:
        Tuple ``(vc, vu, vs)``.
    """
    if rna_only:
        c = np.full(len(u), 1.0)
    if not chrom_open:
        alpha_c *= scale_cc
        if pred_r:
            return -alpha_c * c, alpha * c - beta * u, beta * u - gamma * s
        else:
            return -alpha_c * c, np.zeros(len(u)), np.zeros(len(u))
    else:
        if pred_r:
            return (alpha_c - alpha_c * c,
                    alpha * c - beta * u,
                    beta * u - gamma * s)
        else:
            return alpha_c - alpha_c * c, np.zeros(len(u)), np.zeros(len(u))
577
578
579
@njit(locals={
            "state0": numba.types.boolean[::1],
            "state1": numba.types.boolean[::1],
            "state2": numba.types.boolean[::1],
            "state3": numba.types.boolean[::1],
            "tau1": numba.types.float64[::1],
            "tau2": numba.types.float64[::1],
            "tau3": numba.types.float64[::1],
            "tau4": numba.types.float64[::1],
            "exp_list": numba.types.Tuple((numba.types.float64[:, ::1],
                                           numba.types.float64[:, ::1],
                                           numba.types.float64[:, ::1],
                                           numba.types.float64[:, ::1])),
            "exp_sw_list": numba.types.Tuple((numba.types.float64[:, ::1],
                                              numba.types.float64[:, ::1],
                                              numba.types.float64[:, ::1])),
            "c": numba.types.float64[::1],
            "u": numba.types.float64[::1],
            "s": numba.types.float64[::1],
            "vc_vec": numba.types.float64[::1],
            "vu_vec": numba.types.float64[::1],
            "vs_vec": numba.types.float64[::1]
    },
    fastmath=True)
def compute_velocity(t,
                     t_sw_array,
                     state,
                     alpha_c,
                     alpha,
                     beta,
                     gamma,
                     rescale_c,
                     rescale_u,
                     scale_cc=1,
                     model=1,
                     total_h=20,
                     rna_only=False):
    """Compute the (c, u, s) velocities at the times ``t``.

    The state values are first predicted with ``generate_exp`` and then fed
    into ``velocity_equations`` with the per-state settings of ``model``.

    Args:
        t: 1-D array of times.
        t_sw_array: the three switch times.
        state: array of state labels (0-3), one per entry of ``t``, or
            ``None`` to derive the states by comparing ``t`` with
            ``t_sw_array``.
        alpha_c, alpha, beta, gamma: rate parameters.
        rescale_c: factor applied to the returned c velocity.
        rescale_u: factor applied to the returned u velocity.
        scale_cc: forwarded to ``generate_exp`` / ``velocity_equations``.
        model: segment configuration, 0, 1 or 2.
        total_h: only switch times below this value are passed to
            ``generate_exp``.
        rna_only: forwarded to ``generate_exp`` / ``velocity_equations``.

    Returns:
        Tuple ``(vc * rescale_c, vu * rescale_u, vs)`` of 1-D arrays.
    """
    if state is None:
        state0 = t <= t_sw_array[0]
        state1 = (t_sw_array[0] < t) & (t <= t_sw_array[1])
        state2 = (t_sw_array[1] < t) & (t <= t_sw_array[2])
        state3 = t_sw_array[2] < t
    else:
        state0 = np.equal(state, 0)
        state1 = np.equal(state, 1)
        state2 = np.equal(state, 2)
        state3 = np.equal(state, 3)

    # Time elapsed since the start of the state each entry belongs to.
    tau1 = t[state0]
    tau2 = t[state1] - t_sw_array[0]
    tau3 = t[state2] - t_sw_array[1]
    tau4 = t[state3] - t_sw_array[2]
    switch = np.sum(t_sw_array < total_h)
    typed_tau_list = List()
    typed_tau_list.append(tau1)
    typed_tau_list.append(tau2)
    typed_tau_list.append(tau3)
    typed_tau_list.append(tau4)
    exp_list, exp_sw_list = generate_exp(typed_tau_list,
                                         t_sw_array[:switch],
                                         alpha_c,
                                         alpha,
                                         beta,
                                         gamma,
                                         model=model,
                                         scale_cc=scale_cc,
                                         rna_only=rna_only)

    # Scatter the per-state predictions back into per-entry arrays.
    c = np.empty(len(t))
    u = np.empty(len(t))
    s = np.empty(len(t))
    for i, ii in enumerate([state0, state1, state2, state3]):
        if np.any(ii):
            c[ii] = exp_list[i][:, 0]
            u[ii] = exp_list[i][:, 1]
            s[ii] = exp_list[i][:, 2]

    vc_vec = np.zeros(len(u))
    vu_vec = np.zeros(len(u))
    vs_vec = np.zeros(len(u))

    if model == 0:
        if np.any(state0):
            vc_vec[state0], vu_vec[state0], vs_vec[state0] = \
                velocity_equations(c[state0], u[state0], s[state0], alpha_c,
                                   alpha, beta, gamma, pred_r=False,
                                   scale_cc=scale_cc, rna_only=rna_only)
        if np.any(state1):
            vc_vec[state1], vu_vec[state1], vs_vec[state1] = \
                velocity_equations(c[state1], u[state1], s[state1], alpha_c,
                                   alpha, beta, gamma, pred_r=False,
                                   chrom_open=False, scale_cc=scale_cc,
                                   rna_only=rna_only)
        if np.any(state2):
            vc_vec[state2], vu_vec[state2], vs_vec[state2] = \
                velocity_equations(c[state2], u[state2], s[state2], alpha_c,
                                   alpha, beta, gamma, chrom_open=False,
                                   scale_cc=scale_cc, rna_only=rna_only)
        if np.any(state3):
            vc_vec[state3], vu_vec[state3], vs_vec[state3] = \
                velocity_equations(c[state3], u[state3], s[state3], alpha_c, 0,
                                   beta, gamma, chrom_open=False,
                                   scale_cc=scale_cc, rna_only=rna_only)
    elif model == 1:
        if np.any(state0):
            vc_vec[state0], vu_vec[state0], vs_vec[state0] = \
                velocity_equations(c[state0], u[state0], s[state0], alpha_c,
                                   alpha, beta, gamma, pred_r=False,
                                   scale_cc=scale_cc, rna_only=rna_only)
        if np.any(state1):
            vc_vec[state1], vu_vec[state1], vs_vec[state1] = \
                velocity_equations(c[state1], u[state1], s[state1], alpha_c,
                                   alpha, beta, gamma, scale_cc=scale_cc,
                                   rna_only=rna_only)
        if np.any(state2):
            vc_vec[state2], vu_vec[state2], vs_vec[state2] = \
                velocity_equations(c[state2], u[state2], s[state2], alpha_c,
                                   alpha, beta, gamma, chrom_open=False,
                                   scale_cc=scale_cc, rna_only=rna_only)
        if np.any(state3):
            vc_vec[state3], vu_vec[state3], vs_vec[state3] = \
                velocity_equations(c[state3], u[state3], s[state3], alpha_c, 0,
                                   beta, gamma, chrom_open=False,
                                   scale_cc=scale_cc, rna_only=rna_only)
    elif model == 2:
        if np.any(state0):
            vc_vec[state0], vu_vec[state0], vs_vec[state0] = \
                velocity_equations(c[state0], u[state0], s[state0], alpha_c,
                                   alpha, beta, gamma, pred_r=False,
                                   scale_cc=scale_cc, rna_only=rna_only)
        if np.any(state1):
            vc_vec[state1], vu_vec[state1], vs_vec[state1] = \
                velocity_equations(c[state1], u[state1], s[state1], alpha_c,
                                   alpha, beta, gamma, scale_cc=scale_cc,
                                   rna_only=rna_only)
        if np.any(state2):
            vc_vec[state2], vu_vec[state2], vs_vec[state2] = \
                velocity_equations(c[state2], u[state2], s[state2], alpha_c,
                                   0, beta, gamma, scale_cc=scale_cc,
                                   rna_only=rna_only)
        if np.any(state3):
            vc_vec[state3], vu_vec[state3], vs_vec[state3] = \
                velocity_equations(c[state3], u[state3], s[state3], alpha_c, 0,
                                   beta, gamma, chrom_open=False,
                                   scale_cc=scale_cc, rna_only=rna_only)
    return vc_vec * rescale_c, vu_vec * rescale_u, vs_vec
724
725
726
def log_valid(x):
    """Return ``log(x)`` after clipping ``x`` into ``[1e-3, 1 - 1e-3]``.

    The clip keeps the logarithm finite for arguments that are zero,
    negative, or at/above one.

    Parameters
    ----------
    x : float or numpy.ndarray
        Value(s) to take the logarithm of.

    Returns
    -------
    float or numpy.ndarray
        Natural logarithm of the clipped input, same shape as ``x``.
    """
    lower = 1e-3
    bounded = np.clip(x, lower, 1 - lower)
    return np.log(bounded)
728
729
730
def approx_tau(u, s, u0, s0, alpha, beta, gamma):
    """Approximate the elapsed time ``tau`` from ``u``/``s`` values.

    When ``beta > gamma`` the time is recovered from a linear combination
    of ``s`` and ``u`` (``s - b_new * u`` with
    ``b_new = beta / (gamma - beta)``); otherwise it is recovered from
    ``u`` alone.  The ratio fed to the logarithm is clipped by
    ``log_valid``.

    Parameters
    ----------
    u, s : float or numpy.ndarray
        Current values.
    u0, s0 : float
        Values at ``tau = 0``.
    alpha, beta, gamma : float
        Rate parameters (scalars: they are compared with ``==`` and ``>``).

    Returns
    -------
    float or numpy.ndarray
        Estimated ``tau``, same shape as ``u``/``s``.
    """
    if gamma == beta:
        # b_new below divides by (gamma - beta); nudge gamma to avoid 0.
        gamma -= 1e-3
    u_inf = alpha / beta
    if beta > gamma:
        b_new = beta / (gamma - beta)
        s_inf = alpha / gamma
        s_inf_new = s_inf - b_new * u_inf
        s_new = s - b_new * u
        s0_new = s0 - b_new * u0
        ratio = (s_new - s_inf_new) / (s0_new - s_inf_new)
        tau = -1.0 / gamma * log_valid(ratio)
    else:
        ratio = (u - u_inf) / (u0 - u_inf)
        tau = -1.0 / beta * log_valid(ratio)
    return tau
745
746
747
def anchor_points(t_sw_array, total_h=20, t=1000, mode='uniform',
                  return_time=False):
    """Split a uniform time grid into four phases around the switch times.

    A grid of ``t`` points on ``[0, total_h]`` is cut at the three switch
    times.  Each phase is returned as offsets (``tau``) measured from the
    switch time that opens it; the first phase is measured from 0.

    Parameters
    ----------
    t_sw_array : sequence of 3 floats
        Switch times; indexed at positions 0, 1 and 2.
    total_h : float
        End of the time grid.
    t : int
        Number of grid points.
    mode : str
        ``'log'`` warps every non-empty phase with ``expm1`` and rescales
        it so its largest offset equals the phase length.  Any other value
        leaves the offsets uniform.
    return_time : bool
        If True, also return the full time grid.

    Returns
    -------
    list of numpy.ndarray or (numpy.ndarray, list of numpy.ndarray)
        ``[tau1, tau2, tau3, tau4]``, preceded by the grid when
        ``return_time`` is True.
    """
    t_ = np.linspace(0, total_h, t)
    sw1, sw2, sw3 = t_sw_array[0], t_sw_array[1], t_sw_array[2]

    tau_list = [
        t_[t_ <= sw1],
        t_[(sw1 < t_) & (t_ <= sw2)] - sw1,
        t_[(sw2 < t_) & (t_ <= sw3)] - sw2,
        t_[sw3 < t_] - sw3,
    ]

    if mode == 'log':
        spans = [sw1, sw2 - sw1, sw3 - sw2, total_h - sw3]
        for idx, span in enumerate(spans):
            if len(tau_list[idx]) == 0:
                continue
            warped = np.expm1(tau_list[idx])
            peak = np.max(warped)
            if peak > 0:
                warped = warped / peak * span
            # peak == 0 happens when the phase only holds the grid point
            # 0; normalising would give 0/0 = NaN, so the zero offsets are
            # kept as they are.
            tau_list[idx] = warped

    if return_time:
        return t_, tau_list
    return tau_list
774
775
776
# @jit(nopython=True, fastmath=True, debug=True)
777
def pairwise_distance_square(X, Y):
    """Return squared Euclidean distances between rows of ``X`` and ``Y``.

    Only the first ``X.shape[1]`` columns of ``Y`` take part in the
    distance, so ``Y`` may carry extra trailing columns.

    Parameters
    ----------
    X : numpy.ndarray, shape (a, d)
    Y : numpy.ndarray, shape (b, >= d)

    Returns
    -------
    numpy.ndarray, shape (a, b), dtype ``X.dtype``
        ``res[i, j] = sum_k (X[i, k] - Y[j, k]) ** 2``.
    """
    # Accumulate one feature at a time: memory stays O(a * b) and the
    # summation order matches a per-pair loop over features.  The running
    # sum is kept in float64 and cast to X's dtype at the end.
    acc = np.zeros((X.shape[0], Y.shape[0]), dtype=np.float64)
    for i in range(X.shape[1]):
        delta = X[:, i, None] - Y[None, :, i]
        acc += delta**2
    return acc.astype(X.dtype)
787
788
789
def calculate_dist_and_time(c, u, s,
                            t_sw_array,
                            alpha_c, alpha, beta, gamma,
                            rescale_c, rescale_u,
                            scale_cc=1,
                            scale_factor=None,
                            model=1,
                            conn=None,
                            t=1000, k=1,
                            direction='complete',
                            total_h=20,
                            rna_only=False,
                            penalize_gap=True,
                            all_cells=True):
    """Match every observation to its nearest anchor on the trajectory.

    Anchor times come from ``anchor_points`` and are turned into expected
    (c, u, s) values by ``generate_exp``.  For each phase that is not
    skipped, a KD-tree over the scaled anchors is queried with the scaled
    observations; the phase with the smallest squared distance gives the
    predicted state and time of each observation.

    Parameters
    ----------
    c, u, s : numpy.ndarray
        Observations, one value per cell.
    t_sw_array : numpy.ndarray
        Three switch times.  Only those below ``total_h`` are active.
    alpha_c, alpha, beta, gamma : float
        Rate parameters; passed through ``check_params`` first.
    rescale_c, rescale_u : float
        Multipliers applied to the c and u columns of the model output.
    scale_cc : float
        Forwarded to ``generate_exp``.
    scale_factor : array-like of length 3, optional
        Per-feature divisor applied before measuring distances.  Defaults
        to the standard deviations of ``c``, ``u`` and ``s``.
    model : int
        Forwarded to ``generate_exp`` and ``compute_ss_exp``.
    conn : matrix-like, optional
        If given, per-cell distances are replaced by ``conn.dot(dist)``.
    t : int
        Number of anchor time points.
    k : int
        Number of nearest anchors queried per cell.  For ``k > 1`` the
        squared distances and the anchor times are averaged over the
        ``k`` neighbours.
    direction : {'complete', 'on', 'off'}
        With model 1 or 2, ``'off'`` skips phases 0-1 and ``'on'`` skips
        phases 2-3.  Gap handling only runs for ``'complete'``.
    total_h : float
        End of the time grid.
    rna_only : bool
        Drop the c column from all distances and skip phase 0.
    penalize_gap : bool
        Add the squared (u, s) jump across each detected time gap to the
        distances of cells and states located after the gap.
    all_cells : bool
        Selects the return signature (see below).  When False, anchors
        are thinned before the search and ``t_sw_adjust`` is computed.

    Returns
    -------
    tuple
        ``all_cells=True``: ``(min_dist, t_pred, state_pred, reach_ss,
        late_phase, max_u, max_s, anchor_t1_list, anchor_t2_list)``.
        ``all_cells=False``: ``(min_dist, state_pred, max_u, max_s,
        t_sw_adjust)``.
    """
    n = len(u)
    if scale_factor is None:
        scale_factor = np.array([np.std(c), np.std(u), np.std(s)])
    tau_list = anchor_points(t_sw_array, total_h, t)
    switch = np.sum(t_sw_array < total_h)
    typed_tau_list = List()
    for tau in tau_list:
        typed_tau_list.append(tau)
    alpha_c, alpha, beta, gamma = check_params(alpha_c, alpha, beta, gamma)
    exp_list, exp_sw_list = generate_exp(typed_tau_list,
                                         t_sw_array[:switch],
                                         alpha_c,
                                         alpha,
                                         beta,
                                         gamma,
                                         model=model,
                                         scale_cc=scale_cc,
                                         rna_only=rna_only)
    rescale_factor = np.array([rescale_c, rescale_u, 1.0])
    exp_list = [x*rescale_factor for x in exp_list]
    exp_sw_list = [x*rescale_factor for x in exp_sw_list]
    max_c = 0
    max_u = 0
    max_s = 0
    if rna_only:
        exp_mat = (np.hstack((np.reshape(u, (-1, 1)), np.reshape(s, (-1, 1))))
                   / scale_factor[1:])
    else:
        exp_mat = np.hstack((np.reshape(c, (-1, 1)), np.reshape(u, (-1, 1)),
                             np.reshape(s, (-1, 1)))) / scale_factor

    dists = np.full((n, 4), np.inf)
    taus = np.zeros((n, 4), dtype=u.dtype)
    ts = np.zeros((n, 4), dtype=u.dtype)
    anchor_exp, anchor_t = None, None

    for i in range(switch+1):
        has_anchors = exp_list[i].shape[0] > 0
        if not all_cells:
            max_ci = np.max(exp_list[i][:, 0]) if has_anchors else 0
            max_c = max_ci if max_ci > max_c else max_c
        max_ui = np.max(exp_list[i][:, 1]) if has_anchors else 0
        max_u = max_ui if max_ui > max_u else max_u
        max_si = np.max(exp_list[i][:, 2]) if has_anchors else 0
        max_s = max_si if max_si > max_s else max_s

        skip_phase = False
        if direction == 'off':
            if (model in [1, 2]) and (i < 2):
                skip_phase = True
        elif direction == 'on':
            if (model in [1, 2]) and (i >= 2):
                skip_phase = True
        if rna_only and i == 0:
            skip_phase = True
        if skip_phase:
            continue

        if rna_only:
            tmp = exp_list[i][:, 1:] / scale_factor[1:]
        else:
            tmp = exp_list[i] / scale_factor

        # absolute time of this phase's anchors
        phase_t = tau_list[i] + t_sw_array[i-1] if i >= 1 else tau_list[i]
        if anchor_exp is None:
            anchor_exp = exp_list[i]
            anchor_t = phase_t
        else:
            anchor_exp = np.vstack((anchor_exp, exp_list[i]))
            anchor_t = np.hstack((anchor_t, phase_t))

        if not all_cells:
            # Thin out anchors that barely move from their predecessor,
            # but always keep every third one.
            anchor_dist = np.diff(tmp, axis=0,
                                  prepend=np.zeros((1, tmp.shape[1])))
            anchor_dist = np.sqrt((anchor_dist**2).sum(axis=1))
            # The threshold is 1% of the largest value in the last column
            # of exp_mat (s in both layouts).  The previous code indexed
            # rows (exp_mat[1] / exp_mat[2]), i.e. one arbitrary cell.
            remove_cand = anchor_dist < 0.01*np.max(exp_mat[:, -1])
            step_idx = np.arange(0, len(anchor_dist), 1) % 3 > 0
            remove_cand &= step_idx
            keep_idx = np.where(~remove_cand)[0]
            tmp = tmp[keep_idx, :]

        tree = KDTree(tmp)
        dd, ii = tree.query(exp_mat, k=k)
        dd = dd**2
        if k > 1:
            dd = np.mean(dd, axis=1)
        if conn is not None:
            dd = conn.dot(dd)
        dists[:, i] = dd

        if not all_cells:
            # map indices in the thinned set back to the full anchor list
            ii = keep_idx[ii]
        if k == 1:
            taus[:, i] = tau_list[i][ii]
        else:
            # ii is (n, k).  Assigning the k taus to a scalar slot raised
            # ValueError; average them, mirroring the averaged distances.
            # NOTE(review): confirm mean (vs. nearest) is the intent.
            taus[:, i] = np.mean(tau_list[i][ii], axis=1)
        ts[:, i] = taus[:, i] + t_sw_array[i-1] if i >= 1 else taus[:, i]

    min_dist = np.min(dists, axis=1)
    state_pred = np.argmin(dists, axis=1)
    t_pred = ts[np.arange(n), state_pred]

    anchor_t1_list = []
    anchor_t2_list = []
    t_sw_adjust = np.zeros(3, dtype=u.dtype)

    if direction == 'complete':
        # A gap is a jump between consecutive sorted predicted times that
        # exceeds three times the 99th percentile of all jumps.
        t_sorted = np.sort(t_pred)
        dt = np.diff(t_sorted, prepend=0)
        gap_thresh = 3*np.percentile(dt, 99)
        gap_idx = np.where(dt > gap_thresh)[0]
        for g in gap_idx:
            t1 = t_sorted[g-1] if g > 0 else 0
            t2 = t_sorted[g]
            anchor_t1 = anchor_exp[np.argmin(np.abs(anchor_t - t1)), :]
            anchor_t2 = anchor_exp[np.argmin(np.abs(anchor_t - t2)), :]
            if all_cells:
                anchor_t1_list.append(np.ravel(anchor_t1))
                anchor_t2_list.append(np.ravel(anchor_t2))
            if not all_cells:
                for j in range(1, switch):
                    # both ends of the gap lie inside phase j ...
                    in_phase = ((t1 > t_sw_array[j-1])
                                and (t2 > t_sw_array[j-1])
                                and (t1 <= t_sw_array[j])
                                and (t2 <= t_sw_array[j]))
                    # ... and both ends sit close to the switch-j values
                    near_switch = all(
                        (np.abs(anchor_t1[col] - exp_sw_list[j][0, col])
                         < 0.02 * peak) and
                        (np.abs(anchor_t2[col] - exp_sw_list[j][0, col])
                         < 0.01 * peak)
                        for col, peak in ((2, max_s), (1, max_u),
                                          (0, max_c)))
                    if in_phase and near_switch:
                        t_sw_adjust[j] += t2 - t1
            if penalize_gap:
                dist_gap = np.sum(((anchor_t1[1:] - anchor_t2[1:]) /
                                   scale_factor[1:])**2)
                idx_to_adjust = t_pred >= t2
                t_sw_array_ = np.append(t_sw_array, total_h)
                state_to_adjust = np.where(t_sw_array_ > t2)[0]
                dists[np.ix_(idx_to_adjust, state_to_adjust)] += dist_gap
        min_dist = np.min(dists, axis=1)
        state_pred = np.argmin(dists, axis=1)
        if all_cells:
            t_pred = ts[np.arange(n), state_pred]

    if not all_cells:
        return min_dist, state_pred, max_u, max_s, t_sw_adjust

    exp_ss_mat = compute_ss_exp(alpha_c, alpha, beta, gamma, model=model)
    ss_scaled = exp_ss_mat * rescale_factor / scale_factor
    if rna_only:
        # exp_mat only holds (u, s) here, so drop the c column of the
        # steady states; otherwise u would be compared with c and s with u.
        ss_scaled = ss_scaled[:, 1:]
    dists_ss = pairwise_distance_square(exp_mat, ss_scaled)

    # a cell "reaches" a steady state when that state is closer than the
    # best trajectory anchor
    reach_ss = np.full((n, 4), False)
    reach_ss[:, :] = min_dist[:, None] > dists_ss[:, :4]

    late_phase = np.full(n, -1)
    for i in range(3):
        late_phase[np.abs(t_pred - t_sw_array[i]) < 0.1] = i
    return min_dist, t_pred, state_pred, reach_ss, late_phase, max_u, \
        max_s, anchor_t1_list, anchor_t2_list
974
975
976
def t_of_c(alpha_c, k_c, c_o, c, rescale_factor, sw_t):
    """Compute a time estimate from ``c`` alone.

    Evaluates ``-1/alpha_c * log((k_c - c') / (k_c - c_o) + 1e-9)`` where
    ``c'`` is ``c / rescale_factor`` clipped to ``[0, 1]``.  When
    ``k_c == 0`` the result is shifted by ``sw_t``.

    Parameters
    ----------
    alpha_c : float
        Rate used in the ``-1/alpha_c`` coefficient.
    k_c : int or float
        Callers in this file pass 1 or 0.
    c_o : float
        Initial ``c`` value of the phase.
    c : numpy.ndarray
        Observed values, still on the rescaled scale.
    rescale_factor : float
        Divisor that brings ``c`` back to the ``[0, 1]`` scale.
    sw_t : float
        Offset added when ``k_c == 0``.

    Returns
    -------
    numpy.ndarray
        Estimated time for every entry of ``c``.
    """
    # keeps the log argument away from exactly zero
    epsilon = 1e-9

    c_val = np.clip(c / rescale_factor, a_min=0, a_max=1)
    in_log = (float(k_c) - c_val) / float((k_c) - (c_o))
    elapsed = (-float(1) / alpha_c) * np.log(in_log + epsilon)

    if k_c == 0:
        elapsed += sw_t

    return elapsed
992
993
994
def make_X(c, u, s,
           max_u,
           max_s,
           alpha_c, alpha, beta, gamma,
           gene_sw_t,
           c0, c_sw1, c_sw2, c_sw3,
           u0, u_sw1, u_sw2, u_sw3,
           s0, s_sw1, s_sw2, s_sw3,
           model, direction, state):
    """Build the float32 feature matrix fed to the time-prediction nets.

    Every row starts with three per-cell features (``c`` or ``log c``,
    ``log u``, ``log s``), followed by a block of parameters that is
    identical for all rows, and ends with ``state``.  All logarithms except
    those of ``max_u``/``max_s`` are taken after adding ``1e-5``.

    Layout of the shared block, by ``direction``:

    * ``"complete"`` (21 columns in total): log rates, ``c_sw1..3``,
      log ``u_sw2/3``, log ``s_sw2/3``, log ``max_u``, log ``max_s``,
      the three switch times, ``model``.
    * ``"on"`` (16 columns): log rates, ``c_sw1/2``, log ``u_sw1/2``,
      log ``s_sw1/2``, ``gene_sw_t[0]``, ``model``.
    * ``"off"``, ``model == 1`` (18 columns, first column is ``log c``):
      log rates, ``c_sw2/3``, log ``u_sw2/3``, log ``s_sw2/3``,
      ``max_u_t``, log ``max_u``, log ``max_s``, ``gene_sw_t[2]``.
    * ``"off"``, ``model == 2`` (16 columns): log rates, ``c_sw2/3``,
      log ``u_sw2/3``, log ``s_sw2/3``, log ``max_u``, ``gene_sw_t[2]``.

    Parameters
    ----------
    c, u, s : numpy.ndarray, shape (n,)
        Per-cell values.
    max_u, max_s : float
    alpha_c, alpha, beta, gamma : float
    gene_sw_t : sequence of 3 floats
        Switch times.
    c0, u0, s0, c_sw*, u_sw*, s_sw* : float
        Initial and switch-point values.  ``u0``, ``s0``, ``s_sw3`` (for
        ``"on"``) and others not listed above are accepted but unused.
    model : int
    direction : {"complete", "on", "off"}
    state : int
        Written unchanged into the last column.

    Returns
    -------
    numpy.ndarray, shape (n, n_features), dtype float32

    Raises
    ------
    ValueError
        If ``direction`` is not one of the three known values, or if
        ``direction == "off"`` and ``model`` is neither 1 nor 2.
    """
    epsilon = 1e-5

    def _log_eps(value):
        return np.log(value + epsilon)

    n = c.shape[0]
    first_feature = c
    rate_feats = [_log_eps(alpha_c), _log_eps(alpha),
                  _log_eps(beta), _log_eps(gamma)]

    if direction == "complete":
        shared = rate_feats + [c_sw1, c_sw2, c_sw3,
                               _log_eps(u_sw2), _log_eps(u_sw3),
                               _log_eps(s_sw2), _log_eps(s_sw3),
                               np.log(max_u), np.log(max_s),
                               gene_sw_t[0], gene_sw_t[1], gene_sw_t[2],
                               model]
    elif direction == "on":
        shared = rate_feats + [c_sw1, c_sw2,
                               _log_eps(u_sw1), _log_eps(u_sw2),
                               _log_eps(s_sw1), _log_eps(s_sw2),
                               gene_sw_t[0],
                               model]
    elif direction == "off":
        if model == 1:
            # NOTE(review): c0 is indexed here, so it must be a sequence in
            # this branch, yet calculate_dist_and_time_nn passes a scalar 0
            # for c0 -- confirm which value was intended.
            max_u_t = -(float(1)/alpha_c)*np.log((max_u*beta)
                                                 / (alpha*c0[2]))
            first_feature = np.log(c + epsilon)
            shared = rate_feats + [c_sw2, c_sw3,
                                   _log_eps(u_sw2), _log_eps(u_sw3),
                                   _log_eps(s_sw2), _log_eps(s_sw3),
                                   max_u_t,
                                   np.log(max_u), np.log(max_s),
                                   gene_sw_t[2]]
        elif model == 2:
            shared = rate_feats + [c_sw2, c_sw3,
                                   _log_eps(u_sw2), _log_eps(u_sw3),
                                   _log_eps(s_sw2), _log_eps(s_sw3),
                                   np.log(max_u),
                                   gene_sw_t[2]]
        else:
            raise ValueError(
                "make_X: model must be 1 or 2 when direction is 'off', "
                "got {!r}".format(model))
    else:
        raise ValueError(
            "make_X: direction must be 'complete', 'on' or 'off', "
            "got {!r}".format(direction))

    # Everything is assembled feature-major (one row per feature) and
    # transposed at the end so that each cell becomes one row.
    cell_rows = np.array([first_feature,
                          np.log(u + epsilon),
                          np.log(s + epsilon)])
    shared_rows = np.full((n, len(shared)), shared).T
    state_row = np.full((n, 1), state).T

    x = np.concatenate((cell_rows, shared_rows, state_row)).T
    return x.astype(np.float32)
1098
1099
1100
def calculate_dist_and_time_nn(c, u, s,
                               max_u, max_s,
                               t_sw_array,
                               alpha_c, alpha, beta, gamma,
                               rescale_c, rescale_u,
                               ode_model_0, ode_model_1,
                               ode_model_2_m1, ode_model_2_m2,
                               device,
                               scale_cc=1,
                               scale_factor=None,
                               model=1,
                               conn=None,
                               t=1000, k=1,
                               direction='complete',
                               total_h=20,
                               rna_only=False,
                               penalize_gap=True,
                               all_cells=True):
    """Predict time and state per cell with the pre-trained networks.

    For every candidate state, a time is predicted for each cell (by the
    closed-form ``t_of_c`` expression or by a network forward pass), the
    time is clipped to that state's interval, converted back to expected
    (c, u, s) values with ``generate_exp``, and the scaled squared
    distance to the observation is recorded.  The state with the smallest
    distance wins.

    Parameters
    ----------
    c, u, s : numpy.ndarray
        Observations, one value per cell.
    max_u, max_s : float
        Used for the "u and s near zero" test (10% thresholds), passed to
        ``make_X`` and returned unchanged.
    t_sw_array : numpy.ndarray
        Three switch times.
    alpha_c, alpha, beta, gamma : float
        Rate parameters.
    rescale_c, rescale_u : float
        Multipliers between model scale and observation scale for c and u.
    ode_model_0, ode_model_1, ode_model_2_m1, ode_model_2_m2 : callable
        Networks for ``"complete"``, ``"on"``, ``"off"``/model 1 and
        ``"off"``/model 2 respectively.  Each receives a float tensor of
        shape ``(cells, features)``.
    device
        Torch device for the input tensor.
    scale_cc : float
        Forwarded to ``generate_exp``; also multiplies ``alpha_c`` in the
        network input when ``state > model``.
    scale_factor : array-like of length 3, optional
        Per-feature divisor for distances; defaults to the standard
        deviations of ``c``, ``u`` and ``s``.
    model : int
    conn : matrix-like, optional
        If given, distances are replaced by ``conn.dot(dists)``.
    t, k, penalize_gap
        Accepted for signature parity with ``calculate_dist_and_time``;
        not used here.
    direction : str
        ``"on"`` evaluates states 0-1, ``"off"`` states 2-3, anything
        else all four states.
    total_h : float
        Upper clip bound for times predicted in state 3.
    rna_only : bool
        Forwarded to ``generate_exp``.
    all_cells : bool
        Selects the return signature.

    Returns
    -------
    tuple
        ``all_cells=True``: ``(dists, t_pred, state_pred, max_u, max_s,
        penalty)``; otherwise ``(dists, state_pred, max_u, max_s,
        penalty)``.
    """
    rescale_factor = np.array([rescale_c, rescale_u, 1.0])

    _, exp_sw_list_net = generate_exp(None,
                                      t_sw_array,
                                      alpha_c,
                                      alpha,
                                      beta,
                                      gamma,
                                      model=model,
                                      scale_cc=scale_cc,
                                      rna_only=rna_only)

    N = len(c)
    N_list = np.arange(N)

    if scale_factor is None:
        cur_scale_factor = np.array([np.std(c),
                                     np.std(u),
                                     np.std(s)])
    else:
        cur_scale_factor = scale_factor

    if direction == "on":
        states = [0, 1]
        dire = 1
    elif direction == "off":
        states = [2, 3]
        dire = 2
    else:
        states = [0, 1, 2, 3]
        dire = 0

    # pick the network once; it only depends on direction and model
    if dire == 0:
        net = ode_model_0
    elif dire == 1:
        net = ode_model_1
    elif model == 1:
        net = ode_model_2_m1
    elif model == 2:
        net = ode_model_2_m2
    else:
        net = None

    dists_per_state = np.zeros((N, len(states)))
    t_pred_per_state = np.zeros((N, len(states)))

    # determine when we can consider u and s close to zero
    zero_us = np.logical_and((u < 0.1 * max_u), (s < 0.1 * max_s))

    t_pred = np.zeros(N)

    # pass all the data through the neural net as each valid state
    for col, state in enumerate(states):

        # when u and s = 0, it's better to use the inverse c equation
        # instead of the neural network, which happens for part of
        # state 3 and all of state 0
        inverse_c = np.logical_or(state == 0,
                                  np.logical_and(state == 3, zero_us))
        not_inverse_c = np.logical_not(inverse_c)

        if np.any(inverse_c):

            # find out at what switch time chromatin closes
            c_sw_t = t_sw_array[int(model)]

            # figure out whether chromatin is opening/closing and what
            # the initial c value is
            if state <= model:
                k_c = 1
                c_0_for_t_guess = 0
            elif state > model:
                k_c = 0
                c_0_for_t_guess = exp_sw_list_net[int(model)][0, 0]

            # calculate predicted time from the inverse c equation
            t_pred[inverse_c] = t_of_c(alpha_c,
                                       k_c, c_0_for_t_guess,
                                       c[inverse_c],
                                       rescale_factor[0],
                                       c_sw_t)

        if np.any(not_inverse_c):

            # create an input matrix from the data (model scale)
            x = make_X(c[not_inverse_c] / rescale_factor[0],
                       u[not_inverse_c] / rescale_factor[1],
                       s[not_inverse_c] / rescale_factor[2],
                       max_u,
                       max_s,
                       alpha_c*(scale_cc if state > model else 1),
                       alpha, beta, gamma,
                       t_sw_array,
                       0,
                       exp_sw_list_net[0][0, 0],
                       exp_sw_list_net[1][0, 0],
                       exp_sw_list_net[2][0, 0],
                       0,
                       exp_sw_list_net[0][0, 1],
                       exp_sw_list_net[1][0, 1],
                       exp_sw_list_net[2][0, 1],
                       0,
                       exp_sw_list_net[0][0, 2],
                       exp_sw_list_net[1][0, 2],
                       exp_sw_list_net[2][0, 2],
                       model, direction, state)

            # do a forward pass
            net_input = torch.tensor(x,
                                     dtype=torch.float,
                                     device=device).reshape(-1, x.shape[1])
            t_pred_ten = net(net_input)

            # make a numpy array out of our tensor of predicted time
            # points; the output is mapped to time as 21 * out - 1
            # NOTE(review): presumably this undoes the normalisation used
            # when the networks were trained -- confirm.
            t_pred[not_inverse_c] = (t_pred_ten.cpu().detach().numpy()
                                     .flatten()*21) - 1

        # clip the predicted times to this state's interval and express
        # them as tau (offset from the switch time opening the state)
        lower = 0 if state == 0 else t_sw_array[state - 1]
        # the last state is bounded by total_h (was a hard-coded 20)
        upper = total_h if state == 3 else t_sw_array[state]
        t_pred = np.clip(t_pred, a_min=lower, a_max=upper)

        # the other states get a single placeholder tau of 0
        tau_list = [np.array([0.0]) for _ in range(4)]
        tau_list[state] = t_pred if state == 0 else t_pred - lower

        # take the time points and get predicted c/u/s values from them
        exp_list, _ = generate_exp(tau_list,
                                   t_sw_array,
                                   alpha_c,
                                   alpha,
                                   beta,
                                   gamma,
                                   model=model,
                                   scale_cc=scale_cc,
                                   rna_only=rna_only)

        pred_c = exp_list[state][:, 0] * rescale_factor[0]
        pred_u = exp_list[state][:, 1] * rescale_factor[1]
        pred_s = exp_list[state][:, 2] * rescale_factor[2]

        # calculate distance between predicted and real values
        c_diff = (c - pred_c) / cur_scale_factor[0]
        u_diff = (u - pred_u) / cur_scale_factor[1]
        s_diff = (s - pred_s) / cur_scale_factor[2]

        dists = (c_diff*c_diff) + (u_diff*u_diff) + (s_diff*s_diff)

        if conn is not None:
            dists = conn.dot(dists)

        # store the distances and times for each state
        dists_per_state[:, col] = dists
        t_pred_per_state[:, col] = t_pred

    # whichever state has the smallest distance for a given data point
    # is our predicted state
    state_pred = np.argmin(dists_per_state, axis=1)

    # slice dists and predicted time over the correct state
    dists = dists_per_state[N_list, state_pred]
    t_pred = t_pred_per_state[N_list, state_pred]

    max_t = t_pred.max()
    min_t = t_pred.min()

    penalty = 0

    # for induction and complete genes, add a penalty to ensure that not
    # all points are in state 0
    if direction == "on" or direction == "complete":
        if t_sw_array[0] >= max_t:
            penalty += (t_sw_array[0] - max_t) + 10

    # for induction genes, add a penalty to ensure that predicted time
    # points are not "out of bounds" by being greater than the
    # second switch time
    if direction == "on":
        if min_t > t_sw_array[1]:
            penalty += (min_t - t_sw_array[1]) + 10

    # for repression genes, add a penalty to ensure that predicted time
    # points are not "out of bounds" by being smaller than the
    # second switch time
    if direction == "off":
        if t_sw_array[1] >= max_t:
            penalty += (t_sw_array[1] - max_t) + 10

    # add penalty to ensure that the time points aren't concentrated to
    # one spot
    if np.abs(max_t - min_t) <= 1e-2:
        penalty += np.abs(max_t - min_t) + 10

    # because the indices chosen by np.argmin are just indices,
    # we need to increment by two to get the true state number for
    # our "off" genes (e.g. so that they're in the domain of [2,3] instead
    # of [0,1])
    if direction == "off":
        state_pred += 2

    if all_cells:
        return dists, t_pred, state_pred, max_u, max_s, penalty
    return dists, state_pred, max_u, max_s, penalty
1384
1385
1386
# @jit(nopython=True, fastmath=True)
1387
def compute_likelihood(c, u, s,
                       t_sw_array,
                       alpha_c, alpha, beta, gamma,
                       rescale_c, rescale_u,
                       t_pred,
                       state_pred,
                       scale_cc=1,
                       scale_factor=None,
                       model=1,
                       weight=None,
                       total_h=20,
                       rna_only=False):
    """Compute a Gaussian likelihood of the data under the fitted model.

    Predicted times are split by state, turned into expected (c, u, s)
    values with ``generate_exp``, and compared with the observations.
    The negative log-likelihood per feature is
    ``0.5 * log(2 * pi * var) + 0.5 / n / var * sum(residual ** 2)``
    with ``var`` the variance of the residuals.

    Parameters
    ----------
    c, u, s : numpy.ndarray
        Observations, one value per cell.
    t_sw_array : numpy.ndarray
        Three switch times; only those below ``total_h`` are active.
    alpha_c, alpha, beta, gamma : float
        Rate parameters; passed through ``check_params`` first.
    rescale_c, rescale_u : float
        Multipliers applied to the c and u columns of the model output.
    t_pred, state_pred : numpy.ndarray
        Predicted time and state of every cell.
    scale_cc : float
        Forwarded to ``generate_exp``.
    scale_factor : array-like of length 3, optional
        Multiplier applied to both observations and predictions before
        taking residuals; defaults to ones.
    model : int
    weight : numpy.ndarray, optional
        Index/mask selecting the cells to include; defaults to all cells.
    total_h : float
    rna_only : bool
        If True, a single likelihood is computed from the combined u and
        s residuals and the per-feature outputs stay at 0.

    Returns
    -------
    tuple
        ``(likelihood, likelihood_c, ssd_c, var_c, likelihood_u,
        likelihood_s)``.
    """
    if weight is None:
        weight = np.full(c.shape, True)
    c_ = c[weight]
    u_ = u[weight]
    s_ = s[weight]
    t_pred_ = t_pred[weight]
    state_pred_ = state_pred[weight]

    n = len(u_)
    if scale_factor is None:
        scale_factor = np.ones(3)
    tau_list = [
        t_pred_[state_pred_ == 0],
        t_pred_[state_pred_ == 1] - t_sw_array[0],
        t_pred_[state_pred_ == 2] - t_sw_array[1],
        t_pred_[state_pred_ == 3] - t_sw_array[2],
    ]
    switch = np.sum(t_sw_array < total_h)
    typed_tau_list = List()
    for tau in tau_list:
        typed_tau_list.append(tau)
    alpha_c, alpha, beta, gamma = check_params(alpha_c, alpha, beta, gamma)
    exp_list, _ = generate_exp(typed_tau_list,
                               t_sw_array[:switch],
                               alpha_c,
                               alpha,
                               beta,
                               gamma,
                               model=model,
                               scale_cc=scale_cc,
                               rna_only=rna_only)
    rescale_factor = np.array([rescale_c, rescale_u, 1.0])
    exp_list = [x*rescale_factor*scale_factor for x in exp_list]
    exp_mat = np.hstack((np.reshape(c_, (-1, 1)), np.reshape(u_, (-1, 1)),
                         np.reshape(s_, (-1, 1)))) * scale_factor

    # Residuals per cell, columns ordered (c, u, s).  Zero-initialised so
    # that a cell whose state matches no active phase contributes a zero
    # residual instead of uninitialised memory.
    diffs = np.zeros((n, 3), dtype=u.dtype)
    for i in range(switch+1):
        index = state_pred_ == i
        if np.sum(index) > 0:
            diffs[index, :] = exp_mat[index, :] - exp_list[i]

    def _gaussian_nll(var, ssd):
        return 0.5 * np.log(2 * np.pi * var) + 0.5 / n / var * ssd

    likelihood_c = 0
    likelihood_u = 0
    likelihood_s = 0
    ssd_c, var_c = 0, 0

    if rna_only:
        # u and s live in columns 1 and 2; columns 0 and 1 (c and u) were
        # read here before.
        diff_u = np.ravel(diffs[:, 1])
        diff_s = np.ravel(diffs[:, 2])
        dist_us = diff_u ** 2 + diff_s ** 2
        # signed distance: the sign of the s residual gives the direction
        var_us = np.var(np.sign(diff_s) * np.sqrt(dist_us))
        nll = _gaussian_nll(var_us, np.sum(dist_us))
    else:
        diff_c = np.ravel(diffs[:, 0])
        diff_u = np.ravel(diffs[:, 1])
        diff_s = np.ravel(diffs[:, 2])
        var_c = np.var(diff_c)
        var_u = np.var(diff_u)
        var_s = np.var(diff_s)
        ssd_c = np.sum(diff_c ** 2)
        nll_c = _gaussian_nll(var_c, ssd_c)
        nll_u = _gaussian_nll(var_u, np.sum(diff_u ** 2))
        nll_s = _gaussian_nll(var_s, np.sum(diff_s ** 2))
        nll = nll_c + nll_u + nll_s
        likelihood_c = np.exp(-nll_c)
        likelihood_u = np.exp(-nll_u)
        likelihood_s = np.exp(-nll_s)

    likelihood = np.exp(-nll)
    return likelihood, likelihood_c, ssd_c, var_c, likelihood_u, likelihood_s
1473
1474
1475
class ChromatinDynamical:
1476
    def __init__(self, c, u, s,
                 gene=None,
                 model=None,
                 max_iter=10,
                 init_mode="grid",
                 device="cpu",
                 neural_net=False,
                 adam=False,
                 adam_lr=None,
                 adam_beta1=None,
                 adam_beta2=None,
                 batch_size=None,
                 local_std=None,
                 embed_coord=None,
                 connectivities=None,
                 plot=False,
                 save_plot=False,
                 plot_dir=None,
                 fit_args=None,
                 partial=None,
                 direction=None,
                 rna_only=False,
                 fit_decoupling=True,
                 extra_color=None,
                 rescale_u=None,
                 alpha=None,
                 beta=None,
                 gamma=None,
                 t_=None
                 ):
        """Set up the per-gene state used by the fitting routines.

        The constructor stores the options, offsets/filters/rescales the
        three input vectors, builds the four ODE networks and loads their
        weights from the ``neural_nets/`` directory next to this module,
        initialises rate and switch-time parameters, and finally calls
        ``check_partial_trajectory`` (plus
        ``initialize_steady_state_params`` when parameters are not supplied
        and the gene is not flagged low quality).

        Args:
            c, u, s: values for the three modalities; dense array-likes or
                scipy sparse matrices. Each is flattened to a 1-D float64
                array.
            gene: identifier used in log messages.
            model: model index; a float is cast to ``int``. ``None`` makes
                ``check_partial_trajectory`` run with
                ``determine_model=True``.
            max_iter, init_mode, fit_decoupling, neural_net, adam, adam_lr,
                adam_beta1, adam_beta2, batch_size, local_std: stored
                unchanged for later use (``init_mode`` 'invert' is replaced
                by 'grid' when ``model`` is 0).
            device: torch device name for the networks and tensors.
            embed_coord: optional per-cell coordinates; filtered with the
                same cell mask as the data when neither ``partial`` nor
                ``direction`` is given.
            connectivities: optional cell-by-cell matrix; restricted to the
                cells that pass filtering.
            plot, save_plot, plot_dir, extra_color: plotting options.
                ``plot_dir`` defaults to 'rna_plots' or 'plots'.
            fit_args: mapping with keys 't', 'k', 'thresh_multiplier',
                'weight_c', 'outlier', 'fig_size' and 'point_size'. It is
                indexed unconditionally, so it must be provided despite
                the ``None`` default.
            partial, direction: when given, override the trajectory type
                instead of having it inferred; ``direction`` takes
                precedence over ``partial``.
            rna_only: ignore the ``c`` modality (its offset is 0 and its
                scale is 1).
            rescale_u, alpha, beta, gamma, t_: if all five are given they
                are treated as known parameters and ``self.params`` is
                filled from them.
        """

        self.device = device
        self.gene = gene
        self.local_std = local_std
        self.conn = connectivities

        self.neural_net = neural_net
        self.adam = adam
        self.adam_lr = adam_lr
        self.adam_beta1 = adam_beta1
        self.adam_beta2 = adam_beta2
        self.batch_size = batch_size

        # Densify sparse inputs before anything touches them: scipy sparse
        # matrices raise on len(), which is used below. `.toarray()` is the
        # supported spelling; the `.A` shorthand is deprecated/removed in
        # recent SciPy releases.
        if sparse.issparse(c):
            c = c.toarray()
        if sparse.issparse(u):
            u = u.toarray()
        if sparse.issparse(s):
            s = s.toarray()

        # Python scalar type of the first element of u; passed as `dtype`
        # when tensors are created elsewhere in the class
        self.torch_type = type(u[0].item())

        # fitting arguments
        self.init_mode = init_mode
        self.rna_only = rna_only
        self.fit_decoupling = fit_decoupling
        self.max_iter = max_iter
        self.n_anchors = np.clip(int(fit_args['t']), 201, 2000)
        self.k_dist = np.clip(int(fit_args['k']), 1, 20)
        self.tm = np.clip(fit_args['thresh_multiplier'], 0.4, 2)
        self.weight_c = np.clip(fit_args['weight_c'], 0.1, 5)
        self.outlier = np.clip(fit_args['outlier'], 80, 100)
        self.model = int(model) if isinstance(model, float) else model
        self.model_ = None
        if self.model == 0 and self.init_mode == 'invert':
            self.init_mode = 'grid'

        # plot parameters
        self.plot = plot
        self.save_plot = save_plot
        self.extra_color = extra_color
        self.fig_size = fit_args['fig_size']
        self.point_size = fit_args['point_size']
        if plot_dir is None:
            self.plot_path = 'rna_plots' if self.rna_only else 'plots'
        else:
            self.plot_path = plot_dir
        self.color = ['tab:red', 'tab:orange', 'tab:green', 'tab:blue']
        self.fig = None
        self.ax = None

        # input
        self.total_n = len(u)
        self.c_all = np.ravel(np.array(c, dtype=np.float64))
        self.u_all = np.ravel(np.array(u, dtype=np.float64))
        self.s_all = np.ravel(np.array(s, dtype=np.float64))

        # adjust offset so every modality starts at zero (c is left
        # untouched in RNA-only mode)
        self.offset_c = 0 if self.rna_only else np.min(self.c_all)
        self.offset_u = np.min(self.u_all)
        self.offset_s = np.min(self.s_all)
        self.c_all -= self.offset_c
        self.u_all -= self.offset_u
        self.s_all -= self.offset_s

        # remove zero counts: keep cells with a positive value in at least
        # one modality
        self.non_zero = (np.ravel(self.c_all > 0) | np.ravel(self.u_all > 0) |
                         np.ravel(self.s_all > 0))

        # remove outliers: a cell must be at or below the `outlier`
        # percentile in all three modalities
        self.non_outlier = np.ravel(self.c_all <= np.percentile(self.c_all,
                                                                self.outlier))
        self.non_outlier &= np.ravel(self.u_all <= np.percentile(self.u_all,
                                                                 self.outlier))
        self.non_outlier &= np.ravel(self.s_all <= np.percentile(self.s_all,
                                                                 self.outlier))
        keep = self.non_zero & self.non_outlier
        # boolean indexing copies, so the in-place scaling of c/u/s below
        # does not alias the *_all arrays
        self.c = self.c_all[keep]
        self.u = self.u_all[keep]
        self.s = self.s_all[keep]
        self.low_quality = len(self.u) < 10

        # scale modalities
        self.std_c = np.std(self.c_all) if not self.rna_only else 1.0
        self.std_u = np.std(self.u_all)
        self.std_s = np.std(self.s_all)
        if self.std_u == 0 or self.std_s == 0:
            self.low_quality = True
        self.scale_c = np.max(self.c_all) if not self.rna_only else 1.0
        self.scale_u = self.std_u/self.std_s
        self.scale_s = 1.0

        # if we're on neural net mode, check to see if c is way bigger than
        # u or s, which would be very hard for the neural net to fit
        if not self.low_quality and neural_net:
            max_c_orig = np.max(self.c)
            if (max_c_orig / np.max(self.u) > 500
                    or max_c_orig / np.max(self.s) > 500):
                self.low_quality = True

        self.c_all /= self.scale_c
        self.u_all /= self.scale_u
        self.s_all /= self.scale_s
        self.c /= self.scale_c
        self.u /= self.scale_u
        self.s /= self.scale_s
        # np.std(self.c_all) here is taken after c_all has been rescaled
        self.scale_factor = np.array([np.std(self.c_all) / self.std_s /
                                      self.weight_c, 1.0, 1.0])
        if self.rna_only:
            self.scale_factor[0] = 1
        self.max_u, self.max_s = np.max(self.u), np.max(self.s)
        self.max_u_all, self.max_s_all = np.max(self.u_all), np.max(self.s_all)
        if self.conn is not None:
            self.conn_sub = self.conn[np.ix_(keep, keep)]
        else:
            self.conn_sub = None

        logg.update(f'{len(self.u)} cells passed filter and will be used to '
                    'compute trajectories.', v=2)
        # identity checks avoid element-wise `==` if an argument happens to
        # be array-valued
        self.known_pars = all(par is not None for par in
                              (rescale_u, alpha, beta, gamma, t_))
        if self.known_pars:
            logg.update(f'known parameters for gene {self.gene} are '
                        f'scaling={rescale_u}, alpha={alpha}, beta={beta},'
                        f' gamma={gamma}, t_={t_}.', v=1)

        # define neural networks
        self.ode_model_0 = nn.Sequential(
            nn.Linear(21, 150),
            nn.ReLU(),
            nn.Linear(150, 112),
            nn.ReLU(),
            nn.Linear(112, 75),
            nn.ReLU(),
            nn.Linear(75, 1),
            nn.Sigmoid()
        )

        self.ode_model_1 = nn.Sequential(
            nn.Linear(16, 64),
            nn.ReLU(),
            nn.Linear(64, 48),
            nn.ReLU(),
            nn.Linear(48, 32),
            nn.ReLU(),
            nn.Linear(32, 1),
            nn.Sigmoid()
        )

        self.ode_model_2_m1 = nn.Sequential(
            nn.Linear(18, 220),
            nn.ReLU(),
            nn.Linear(220, 165),
            nn.ReLU(),
            nn.Linear(165, 110),
            nn.ReLU(),
            nn.Linear(110, 1),
            nn.Sigmoid()
        )

        self.ode_model_2_m2 = nn.Sequential(
            nn.Linear(16, 150),
            nn.ReLU(),
            nn.Linear(150, 112),
            nn.ReLU(),
            nn.Linear(112, 75),
            nn.ReLU(),
            nn.Linear(75, 1),
            nn.Sigmoid()
        )

        # move the networks to the target device, then load their weights;
        # map_location lets checkpoints saved on another device be loaded
        # onto the one requested here
        torch_device = torch.device(self.device)
        net_path = os.path.dirname(os.path.abspath(__file__)) + \
            "/neural_nets/"
        networks = ((self.ode_model_0, "dir0.pt"),
                    (self.ode_model_1, "dir1.pt"),
                    (self.ode_model_2_m1, "dir2_m1.pt"),
                    (self.ode_model_2_m2, "dir2_m2.pt"))
        for network, _ in networks:
            network.to(torch_device)
        for network, weight_file in networks:
            network.load_state_dict(torch.load(net_path+weight_file,
                                               map_location=torch_device))

        # 4 rate parameters
        self.alpha_c = 0.1
        self.alpha = alpha if alpha is not None else 0.0
        self.beta = beta if beta is not None else 0.0
        self.gamma = gamma if gamma is not None else 0.0
        # 3 possible switch time points
        self.t_sw_1 = 0.1 if t_ is not None else 0.0
        self.t_sw_2 = t_+0.1 if t_ is not None else 0.0
        self.t_sw_3 = 20.0 if t_ is not None else 0.0
        # 2 rescale factors
        self.rescale_c = 1.0
        self.rescale_u = rescale_u if rescale_u is not None else 1.0
        self.rates = None
        self.t_sw_array = None
        self.fit_rescale = rescale_u is None
        self.params = None

        # other parameters or results
        self.t = None
        self.state = None
        self.loss = [np.inf]
        self.likelihood = -1.0
        self.l_c = 0
        self.ssd_c, self.var_c = 0, 0
        self.scale_cc = 1.0
        self.fitting_flag_ = 0
        self.velocity = None
        self.anchor_t1_list, self.anchor_t2_list = None, None
        self.anchor_exp = None
        self.anchor_exp_sw = None
        self.anchor_min_idx, self.anchor_max_idx, self.anchor_velo_min_idx, \
            self.anchor_velo_max_idx = None, None, None, None
        self.anchor_velo = None
        self.c0 = self.u0 = self.s0 = 0.0
        self.realign_ratio = 1.0
        self.partial = False
        self.direction = 'complete'
        self.steady_state_func = None

        # for fit and update
        self.cur_iter = 0
        self.cur_loss = None
        self.cur_state_pred = None
        self.cur_t_sw_adjust = None

        # partial checking and model examination
        determine_model = model is None
        if partial is None and direction is None:
            # nothing supplied by the caller: run the check with its
            # default options
            if embed_coord is not None:
                self.embed_coord = embed_coord[keep]
            else:
                self.embed_coord = None
            self.check_partial_trajectory(determine_model=determine_model)
        elif direction is not None:
            self.direction = direction
            self.partial = direction in ['on', 'off']
            self.check_partial_trajectory(fit_gmm=False, fit_slope=False,
                                          determine_model=determine_model)
        else:
            # only `partial` was supplied (the two cases above cover every
            # other combination)
            self.partial = partial
            self.check_partial_trajectory(fit_gmm=False,
                                          determine_model=determine_model)

        # initialize steady state parameters
        if not self.known_pars and not self.low_quality:
            self.initialize_steady_state_params(model_mismatch=self.model
                                                != self.model_)
        if self.known_pars:
            self.params = np.array([self.t_sw_1,
                                    self.t_sw_2-self.t_sw_1,
                                    self.t_sw_3-self.t_sw_2,
                                    self.alpha_c,
                                    self.alpha,
                                    self.beta,
                                    self.gamma,
                                    self.scale_cc,
                                    self.rescale_c,
                                    self.rescale_u])
1775
1776
    # the torch tensor version of the anchor points function
1777
    def anchor_points_ten(self, t_sw_array, total_h=20, t=1000, mode='uniform',
1778
                          return_time=False):
1779
1780
        t_ = torch.linspace(0, total_h, t, device=self.device,
1781
                            dtype=self.torch_type)
1782
        tau1 = t_[t_ <= t_sw_array[0]]
1783
        tau2 = t_[(t_sw_array[0] < t_) & (t_ <= t_sw_array[1])] - t_sw_array[0]
1784
        tau3 = t_[(t_sw_array[1] < t_) & (t_ <= t_sw_array[2])] - t_sw_array[1]
1785
        tau4 = t_[t_sw_array[2] < t_] - t_sw_array[2]
1786
1787
        if mode == 'log':
1788
            if len(tau1) > 0:
1789
                tau1 = torch.expm1(tau1)
1790
                tau1 = tau1 / torch.max(tau1) * (t_sw_array[0])
1791
            if len(tau2) > 0:
1792
                tau2 = torch.expm1(tau2)
1793
                tau2 = tau2 / torch.max(tau2) * (t_sw_array[1] - t_sw_array[0])
1794
            if len(tau3) > 0:
1795
                tau3 = torch.expm1(tau3)
1796
                tau3 = tau3 / torch.max(tau3) * (t_sw_array[2] - t_sw_array[1])
1797
            if len(tau4) > 0:
1798
                tau4 = torch.expm1(tau4)
1799
                tau4 = tau4 / torch.max(tau4) * (total_h - t_sw_array[2])
1800
1801
        tau_list = [tau1, tau2, tau3, tau4]
1802
        if return_time:
1803
            return t_, tau_list
1804
        else:
1805
            return tau_list
1806
1807
    # the torch version of the predict_exp function
1808
    def predict_exp_ten(self,
1809
                        tau,
1810
                        c0,
1811
                        u0,
1812
                        s0,
1813
                        alpha_c,
1814
                        alpha,
1815
                        beta,
1816
                        gamma,
1817
                        scale_cc=None,
1818
                        pred_r=True,
1819
                        chrom_open=True,
1820
                        backward=False,
1821
                        rna_only=False):
1822
1823
        if scale_cc is None:
1824
            scale_cc = torch.tensor(1.0, requires_grad=True,
1825
                                    device=self.device,
1826
                                    dtype=self.torch_type)
1827
1828
        if len(tau) == 0:
1829
            return torch.empty((0, 3),
1830
                               requires_grad=True,
1831
                               device=self.device,
1832
                               dtype=self.torch_type)
1833
        if backward:
1834
            tau = -tau
1835
1836
        eat = torch.exp(-alpha_c * tau)
1837
        ebt = torch.exp(-beta * tau)
1838
        egt = torch.exp(-gamma * tau)
1839
        if rna_only:
1840
            kc = 1
1841
            c0 = 1
1842
        else:
1843
            if chrom_open:
1844
                kc = 1
1845
            else:
1846
                kc = 0
1847
                alpha_c = alpha_c * scale_cc
1848
1849
        const = (kc - c0) * alpha / (beta - alpha_c)
1850
1851
        res0 = kc - (kc - c0) * eat
1852
1853
        if pred_r:
1854
1855
            res1 = u0 * ebt + (alpha * kc / beta) * (1 - ebt)
1856
            res1 += const * (ebt - eat)
1857
1858
            res2 = s0 * egt + (alpha * kc / gamma) * (1 - egt)
1859
            res2 += ((beta / (gamma - beta)) *
1860
                     ((alpha * kc / beta) - u0 - const) * (egt - ebt))
1861
            res2 += (beta / (gamma - alpha_c)) * const * (egt - eat)
1862
1863
        else:
1864
            res1 = torch.zeros(len(tau), device=self.device,
1865
                               requires_grad=True,
1866
                               dtype=self.torch_type)
1867
            res2 = torch.zeros(len(tau), device=self.device,
1868
                               requires_grad=True,
1869
                               dtype=self.torch_type)
1870
1871
        res = torch.stack((res0, res1, res2), 1)
1872
1873
        return res
1874
1875
    # the torch tensor version of the generate_exp function
1876
    def generate_exp_tens(self,
1877
                          tau_list,
1878
                          t_sw_array,
1879
                          alpha_c,
1880
                          alpha,
1881
                          beta,
1882
                          gamma,
1883
                          scale_cc=None,
1884
                          model=1,
1885
                          rna_only=False):
1886
1887
        if scale_cc is None:
1888
            scale_cc = torch.tensor(1.0, requires_grad=True,
1889
                                    device=self.device,
1890
                                    dtype=self.torch_type)
1891
1892
        if beta == alpha_c:
1893
            beta += 1e-3
1894
        if gamma == beta or gamma == alpha_c:
1895
            gamma += 1e-3
1896
        switch = int(t_sw_array.size(dim=0))
1897
        if switch >= 1:
1898
            tau_sw1 = torch.tensor([t_sw_array[0]], requires_grad=True,
1899
                                   device=self.device,
1900
                                   dtype=self.torch_type)
1901
            if switch >= 2:
1902
                tau_sw2 = torch.tensor([t_sw_array[1] - t_sw_array[0]],
1903
                                       requires_grad=True,
1904
                                       device=self.device,
1905
                                       dtype=self.torch_type)
1906
                if switch == 3:
1907
                    tau_sw3 = torch.tensor([t_sw_array[2] - t_sw_array[1]],
1908
                                           requires_grad=True,
1909
                                           device=self.device,
1910
                                           dtype=self.torch_type)
1911
        exp_sw1, exp_sw2, exp_sw3 = (torch.empty((0, 3),
1912
                                                 requires_grad=True,
1913
                                                 device=self.device,
1914
                                                 dtype=self.torch_type),
1915
                                     torch.empty((0, 3),
1916
                                                 requires_grad=True,
1917
                                                 device=self.device,
1918
                                                 dtype=self.torch_type),
1919
                                     torch.empty((0, 3),
1920
                                                 requires_grad=True,
1921
                                                 device=self.device,
1922
                                                 dtype=self.torch_type))
1923
        if tau_list is None:
1924
            if model == 0:
1925
                if switch >= 1:
1926
                    exp_sw1 = self.predict_exp_ten(tau_sw1, 0, 0, 0, alpha_c,
1927
                                                   alpha, beta, gamma,
1928
                                                   pred_r=False,
1929
                                                   scale_cc=scale_cc,
1930
                                                   rna_only=rna_only)
1931
                    if switch >= 2:
1932
                        exp_sw2 = self.predict_exp_ten(tau_sw2, exp_sw1[0, 0],
1933
                                                       exp_sw1[0, 1],
1934
                                                       exp_sw1[0, 2],
1935
                                                       alpha_c, alpha, beta,
1936
                                                       gamma, pred_r=False,
1937
                                                       chrom_open=False,
1938
                                                       scale_cc=scale_cc,
1939
                                                       rna_only=rna_only)
1940
                        if switch >= 3:
1941
                            exp_sw3 = self.predict_exp_ten(tau_sw3,
1942
                                                           exp_sw2[0, 0],
1943
                                                           exp_sw2[0, 1],
1944
                                                           exp_sw2[0, 2],
1945
                                                           alpha_c, alpha,
1946
                                                           beta, gamma,
1947
                                                           chrom_open=False,
1948
                                                           scale_cc=scale_cc,
1949
                                                           rna_only=rna_only)
1950
            elif model == 1:
1951
                if switch >= 1:
1952
                    exp_sw1 = self.predict_exp_ten(tau_sw1, 0, 0, 0, alpha_c,
1953
                                                   alpha, beta, gamma,
1954
                                                   pred_r=False,
1955
                                                   scale_cc=scale_cc,
1956
                                                   rna_only=rna_only)
1957
                    if switch >= 2:
1958
                        exp_sw2 = self.predict_exp_ten(tau_sw2, exp_sw1[0, 0],
1959
                                                       exp_sw1[0, 1],
1960
                                                       exp_sw1[0, 2],
1961
                                                       alpha_c, alpha,
1962
                                                       beta, gamma,
1963
                                                       scale_cc=scale_cc,
1964
                                                       rna_only=rna_only)
1965
                        if switch >= 3:
1966
                            exp_sw3 = self.predict_exp_ten(tau_sw3,
1967
                                                           exp_sw2[0, 0],
1968
                                                           exp_sw2[0, 1],
1969
                                                           exp_sw2[0, 2],
1970
                                                           alpha_c, alpha,
1971
                                                           beta, gamma,
1972
                                                           chrom_open=False,
1973
                                                           scale_cc=scale_cc,
1974
                                                           rna_only=rna_only)
1975
            elif model == 2:
1976
                if switch >= 1:
1977
                    exp_sw1 = self.predict_exp_ten(tau_sw1, 0, 0, 0, alpha_c,
1978
                                                   alpha, beta, gamma,
1979
                                                   pred_r=False,
1980
                                                   scale_cc=scale_cc,
1981
                                                   rna_only=rna_only)
1982
                    if switch >= 2:
1983
                        exp_sw2 = self.predict_exp_ten(tau_sw2, exp_sw1[0, 0],
1984
                                                       exp_sw1[0, 1],
1985
                                                       exp_sw1[0, 2], alpha_c,
1986
                                                       alpha, beta, gamma,
1987
                                                       scale_cc=scale_cc,
1988
                                                       rna_only=rna_only)
1989
                        if switch >= 3:
1990
                            exp_sw3 = self.predict_exp_ten(tau_sw3,
1991
                                                           exp_sw2[0, 0],
1992
                                                           exp_sw2[0, 1],
1993
                                                           exp_sw2[0, 2],
1994
                                                           alpha_c, 0, beta,
1995
                                                           gamma,
1996
                                                           scale_cc=scale_cc,
1997
                                                           rna_only=rna_only)
1998
1999
            return [torch.empty((0, 3), requires_grad=True,
2000
                                device=self.device,
2001
                                dtype=self.torch_type),
2002
                    torch.empty((0, 3), requires_grad=True,
2003
                                device=self.device,
2004
                                dtype=self.torch_type),
2005
                    torch.empty((0, 3), requires_grad=True,
2006
                                device=self.device,
2007
                                dtype=self.torch_type),
2008
                    torch.empty((0, 3), requires_grad=True,
2009
                                device=self.device,
2010
                                dtype=self.torch_type)], \
2011
                   [exp_sw1, exp_sw2, exp_sw3]
2012
2013
        tau1 = tau_list[0]
2014
        if switch >= 1:
2015
            tau2 = tau_list[1]
2016
            if switch >= 2:
2017
                tau3 = tau_list[2]
2018
                if switch == 3:
2019
                    tau4 = tau_list[3]
2020
        exp1, exp2, exp3, exp4 = (torch.empty((0, 3), requires_grad=True,
2021
                                              device=self.device,
2022
                                              dtype=self.torch_type),
2023
                                  torch.empty((0, 3), requires_grad=True,
2024
                                              device=self.device,
2025
                                              dtype=self.torch_type),
2026
                                  torch.empty((0, 3), requires_grad=True,
2027
                                              device=self.device,
2028
                                              dtype=self.torch_type),
2029
                                  torch.empty((0, 3), requires_grad=True,
2030
                                              device=self.device,
2031
                                              dtype=self.torch_type))
2032
        if model == 0:
2033
            exp1 = self.predict_exp_ten(tau1, 0, 0, 0, alpha_c, alpha, beta,
2034
                                        gamma, pred_r=False, scale_cc=scale_cc,
2035
                                        rna_only=rna_only)
2036
            if switch >= 1:
2037
                exp_sw1 = self.predict_exp_ten(tau_sw1, 0, 0, 0, alpha_c,
2038
                                               alpha, beta, gamma,
2039
                                               pred_r=False, scale_cc=scale_cc,
2040
                                               rna_only=rna_only)
2041
                exp2 = self.predict_exp_ten(tau2, exp_sw1[0, 0], exp_sw1[0, 1],
2042
                                            exp_sw1[0, 2], alpha_c, alpha,
2043
                                            beta, gamma, pred_r=False,
2044
                                            chrom_open=False,
2045
                                            scale_cc=scale_cc,
2046
                                            rna_only=rna_only)
2047
                if switch >= 2:
2048
                    exp_sw2 = self.predict_exp_ten(tau_sw2, exp_sw1[0, 0],
2049
                                                   exp_sw1[0, 1],
2050
                                                   exp_sw1[0, 2],
2051
                                                   alpha_c, alpha, beta, gamma,
2052
                                                   pred_r=False,
2053
                                                   chrom_open=False,
2054
                                                   scale_cc=scale_cc,
2055
                                                   rna_only=rna_only)
2056
                    exp3 = self.predict_exp_ten(tau3, exp_sw2[0, 0],
2057
                                                exp_sw2[0, 1], exp_sw2[0, 2],
2058
                                                alpha_c, alpha, beta, gamma,
2059
                                                chrom_open=False,
2060
                                                scale_cc=scale_cc,
2061
                                                rna_only=rna_only)
2062
                    if switch == 3:
2063
                        exp_sw3 = self.predict_exp_ten(tau_sw3, exp_sw2[0, 0],
2064
                                                       exp_sw2[0, 1],
2065
                                                       exp_sw2[0, 2],
2066
                                                       alpha_c, alpha, beta,
2067
                                                       gamma,
2068
                                                       chrom_open=False,
2069
                                                       scale_cc=scale_cc,
2070
                                                       rna_only=rna_only)
2071
                        exp4 = self.predict_exp_ten(tau4, exp_sw3[0, 0],
2072
                                                    exp_sw3[0, 1],
2073
                                                    exp_sw3[0, 2],
2074
                                                    alpha_c, 0, beta, gamma,
2075
                                                    chrom_open=False,
2076
                                                    scale_cc=scale_cc,
2077
                                                    rna_only=rna_only)
2078
        elif model == 1:
2079
            exp1 = self.predict_exp_ten(tau1, 0, 0, 0, alpha_c, alpha, beta,
2080
                                        gamma, pred_r=False, scale_cc=scale_cc,
2081
                                        rna_only=rna_only)
2082
            if switch >= 1:
2083
                exp_sw1 = self.predict_exp_ten(tau_sw1, 0, 0, 0, alpha_c,
2084
                                               alpha, beta, gamma,
2085
                                               pred_r=False, scale_cc=scale_cc,
2086
                                               rna_only=rna_only)
2087
                exp2 = self.predict_exp_ten(tau2, exp_sw1[0, 0], exp_sw1[0, 1],
2088
                                            exp_sw1[0, 2], alpha_c, alpha,
2089
                                            beta, gamma, scale_cc=scale_cc,
2090
                                            rna_only=rna_only)
2091
                if switch >= 2:
2092
                    exp_sw2 = self.predict_exp_ten(tau_sw2, exp_sw1[0, 0],
2093
                                                   exp_sw1[0, 1],
2094
                                                   exp_sw1[0, 2], alpha_c,
2095
                                                   alpha, beta, gamma,
2096
                                                   scale_cc=scale_cc,
2097
                                                   rna_only=rna_only)
2098
                    exp3 = self.predict_exp_ten(tau3, exp_sw2[0, 0],
2099
                                                exp_sw2[0, 1], exp_sw2[0, 2],
2100
                                                alpha_c, alpha, beta, gamma,
2101
                                                chrom_open=False,
2102
                                                scale_cc=scale_cc,
2103
                                                rna_only=rna_only)
2104
                    if switch == 3:
2105
                        exp_sw3 = self.predict_exp_ten(tau_sw3, exp_sw2[0, 0],
2106
                                                       exp_sw2[0, 1],
2107
                                                       exp_sw2[0, 2],
2108
                                                       alpha_c, alpha, beta,
2109
                                                       gamma,
2110
                                                       chrom_open=False,
2111
                                                       scale_cc=scale_cc,
2112
                                                       rna_only=rna_only)
2113
                        exp4 = self.predict_exp_ten(tau4, exp_sw3[0, 0],
2114
                                                    exp_sw3[0, 1],
2115
                                                    exp_sw3[0, 2], alpha_c, 0,
2116
                                                    beta, gamma,
2117
                                                    chrom_open=False,
2118
                                                    scale_cc=scale_cc,
2119
                                                    rna_only=rna_only)
2120
        elif model == 2:
2121
            exp1 = self.predict_exp_ten(tau1, 0, 0, 0, alpha_c, alpha, beta,
2122
                                        gamma, pred_r=False, scale_cc=scale_cc,
2123
                                        rna_only=rna_only)
2124
            if switch >= 1:
2125
                exp_sw1 = self.predict_exp_ten(tau_sw1, 0, 0, 0, alpha_c,
2126
                                               alpha, beta, gamma,
2127
                                               pred_r=False, scale_cc=scale_cc,
2128
                                               rna_only=rna_only)
2129
                exp2 = self.predict_exp_ten(tau2, exp_sw1[0, 0], exp_sw1[0, 1],
2130
                                            exp_sw1[0, 2], alpha_c, alpha,
2131
                                            beta, gamma, scale_cc=scale_cc,
2132
                                            rna_only=rna_only)
2133
                if switch >= 2:
2134
                    exp_sw2 = self.predict_exp_ten(tau_sw2, exp_sw1[0, 0],
2135
                                                   exp_sw1[0, 1],
2136
                                                   exp_sw1[0, 2], alpha_c,
2137
                                                   alpha, beta, gamma,
2138
                                                   scale_cc=scale_cc,
2139
                                                   rna_only=rna_only)
2140
                    exp3 = self.predict_exp_ten(tau3, exp_sw2[0, 0],
2141
                                                exp_sw2[0, 1],
2142
                                                exp_sw2[0, 2], alpha_c, 0,
2143
                                                beta, gamma, scale_cc=scale_cc,
2144
                                                rna_only=rna_only)
2145
                    if switch == 3:
2146
                        exp_sw3 = self.predict_exp_ten(tau_sw3, exp_sw2[0, 0],
2147
                                                       exp_sw2[0, 1],
2148
                                                       exp_sw2[0, 2],
2149
                                                       alpha_c, 0, beta, gamma,
2150
                                                       scale_cc=scale_cc,
2151
                                                       rna_only=rna_only)
2152
                        exp4 = self.predict_exp_ten(tau4, exp_sw3[0, 0],
2153
                                                    exp_sw3[0, 1],
2154
                                                    exp_sw3[0, 2],
2155
                                                    alpha_c, 0, beta, gamma,
2156
                                                    chrom_open=False,
2157
                                                    scale_cc=scale_cc,
2158
                                                    rna_only=rna_only)
2159
        return [exp1, exp2, exp3, exp4], [exp_sw1, exp_sw2, exp_sw3]
2160
2161
    def check_partial_trajectory(self, fit_gmm=True, fit_slope=True,
                                 determine_model=True):
        """Decide whether the trajectory is partial and pre-select a model.

        Three heuristics are applied in order: a two-component Gaussian
        mixture on pairwise embedding distances, a comparison of squared
        residuals on either side of the steady-state line, and (when the
        direction is neither 'on' nor 'off') a vote of high-chromatin cells.

        Args:
            fit_gmm: run the GMM test. It is skipped when fewer than 10
                cells fall in the 30th-40th percentile band of spliced
                counts, or when ``self.local_std`` / ``self.embed_coord``
                is None.
            fit_slope: run the slope-based direction test. It is skipped
                (and the direction set to 'complete') when fewer than 10
                cells are available.
            determine_model: copy the pre-determined ``self.model_`` into
                ``self.model``.

        Side effects:
            May set ``self.low_quality``, ``self.partial``,
            ``self.direction``, ``self.steady_state_func``,
            ``self.thickness``, ``self.model_`` and ``self.model``.
            Returns None; on a low-quality verdict it returns early.
        """
        # cells where c, u and s all reach at least 10% of their maximum
        expressed = ((self.c >= 0.1 * np.max(self.c)) &
                     (self.u >= 0.1 * np.max(self.u)) &
                     (self.s >= 0.1 * np.max(self.s)))
        u_expr = self.u[expressed]
        s_expr = self.s[expressed]
        if len(u_expr) < 10:
            self.low_quality = True
            return

        # ---- GMM on pairwise distances in the embedding ----
        band = ((np.percentile(s_expr, 30) <= s_expr) &
                (s_expr <= np.percentile(s_expr, 40)))
        if np.sum(band) < 10:
            # too few cells in the band: fall back to "partial"
            fit_gmm = False
            self.partial = True
        has_local_std = self.local_std is not None
        has_embedding = self.embed_coord is not None
        if not has_local_std:
            logg.warn('local standard deviation not provided. '
                      'Skipping GMM..', v=2)
        if not has_embedding:
            logg.warn('Warning: embedded coordinates not provided. '
                      'Skipping GMM..')
        if fit_gmm and has_local_std and has_embedding:
            pair_d = pairwise_distances(
                self.embed_coord[expressed, :][band, :])
            # strict upper triangle = each unordered pair exactly once
            upper = pair_d[np.triu_indices_from(pair_d, k=1)]
            samples = np.ravel(upper).reshape(-1, 1)
            gmm = GaussianMixture(n_components=2, covariance_type='tied',
                                  random_state=2021).fit(samples)
            mean_gap = np.abs(gmm.means_[1][0] - gmm.means_[0][0])
            means_separated = mean_gap > self.local_std / self.tm
            logg.update(f'GMM: difference between means = {mean_gap}, '
                        f'threshold = {self.local_std / self.tm}.', v=2)
            second_is_heavy = np.all(gmm.weights_[1] > 0.2 / self.tm)
            logg.update('GMM: weight of the second Gaussian ='
                        f' {gmm.weights_[1]}.', v=2)
            # two well-separated, well-populated clusters => not partial
            self.partial = not (means_separated and second_is_heavy)
            logg.update(f'GMM decides {"" if self.partial else "not "}'
                        'partial.', v=2)

        # ---- steady-state slope from the top 5% cells in u or s ----
        top_u = self.u >= np.percentile(u_expr, 95)
        top_s = self.s >= np.percentile(s_expr, 95)
        extreme = top_u | top_s
        ss_u = self.u[extreme]
        ss_s = self.s[extreme]
        if np.all(ss_u == 0) or np.all(ss_s == 0):
            self.low_quality = True
            return
        gamma = np.dot(ss_u, ss_s) / np.dot(ss_s, ss_s)
        self.steady_state_func = lambda x: gamma*x

        # ---- thickness of phase portrait: second singular value ----
        portrait = np.hstack(((u_expr / np.max(self.u)).reshape(-1, 1),
                              (s_expr / np.max(self.s)).reshape(-1, 1)))
        _, sing_vals, _ = np.linalg.svd(portrait)
        self.thickness = sing_vals[1]

        # ---- slope-based direction decision ----
        with np.errstate(divide='ignore', invalid='ignore'):
            ratio = self.u / self.s
        non_nan = ~np.isnan(ratio)
        ratio = ratio[non_nan]
        on = ratio >= gamma
        off = ratio < gamma
        if len(ss_u) < 10 or len(u_expr) < 10:
            fit_slope = False
            self.direction = 'complete'
        if fit_slope:
            ratio_expr = u_expr / s_expr
            above = ratio_expr >= gamma
            below = ratio_expr < gamma
            on_dist = np.sum((u_expr[above] - gamma * s_expr[above])**2)
            off_dist = np.sum((gamma * s_expr[below] - u_expr[below])**2)
            logg.update(f'Slope: SSE on induction phase = {on_dist},'
                        f' SSE on repression phase = {off_dist}.', v=2)
            narrow = bool(self.thickness < 1.5 / np.sqrt(self.tm))
            logg.update(f'Thickness of trajectory = {self.thickness}. '
                        f'Trajectory is {"narrow" if narrow else "normal"}.',
                        v=2)
            if on_dist > 10 * self.tm**2 * off_dist:
                # overwhelmingly on the induction side
                self.direction = 'on'
                self.partial = True
            elif off_dist > 10 * self.tm**2 * on_dist:
                # overwhelmingly on the repression side
                self.direction = 'off'
                self.partial = True
            elif self.partial is True:
                if on_dist > 3 * self.tm * off_dist:
                    self.direction = 'on'
                elif off_dist > 3 * self.tm * on_dist:
                    self.direction = 'off'
                elif narrow:
                    self.direction = 'on'
                else:
                    self.direction = 'complete'
                    self.partial = False
            elif narrow:
                if off_dist > 2 * self.tm * on_dist:
                    self.direction = 'off'
                else:
                    self.direction = 'on'
                self.partial = True
            else:
                self.direction = 'complete'

        # ---- model pre-determination ----
        if self.direction == 'on':
            self.model_ = 1
        elif self.direction == 'off':
            self.model_ = 2
        else:
            # progressively relax the "high chromatin" cut-off until at
            # least 10 cells qualify (or the last cut-off has been tried)
            c_high = (self.c >= np.mean(self.c) + 2 * np.std(self.c))[non_nan]
            if np.sum(c_high) < 10:
                c_high = (self.c >= np.mean(self.c) + np.std(self.c))[non_nan]
            if np.sum(c_high) < 10:
                c_high = (self.c >= np.percentile(self.c, 90))[non_nan]
            if np.sum(self.c[non_nan][c_high] == 0) > 0.5*np.sum(c_high):
                self.low_quality = True
                return
            votes_on = np.sum(c_high & on)
            votes_off = np.sum(c_high & off)
            self.model_ = 1 if votes_on > votes_off else 2
        if determine_model:
            self.model = self.model_

        if self.known_pars:
            return
        if fit_gmm or fit_slope:
            logg.update(f'predicted partial trajectory: {self.partial}',
                        v=1)
            logg.update('predicted trajectory direction:'
                        f'{self.direction}', v=1)
        if determine_model:
            logg.update(f'predicted model: {self.model}', v=1)
2312
2313
    def initialize_steady_state_params(self, model_mismatch=False):
        """Derive initial rates and switch times from steady-state cells.

        The method estimates a rescale factor for ``u``, takes the cells in
        the top 3% of (rescaled) ``u`` or ``s`` as steady-state cells,
        initialises ``alpha_c``, ``alpha``, ``beta`` and ``gamma`` from
        them, and then seeds the switch times according to
        ``self.init_mode`` ('grid', 'simple' or 'invert').

        Args:
            model_mismatch: when True (or when ``self.fit_decoupling`` is
                False) the repressed-state chromatin level is re-estimated
                from the cells on either side of ``self.steady_state_func``.

        Side effects:
            Sets ``scale_cc``, ``rescale_c``, ``rescale_u``, the rate
            attributes, ``rates``, ``loss``, ``t_sw_array`` and
            ``t_sw_1..3``. In 'simple' mode and RNA-only 'invert' mode
            ``self.params`` is assigned directly; in the other modes
            candidates are handed to ``self.update(..., initialize=True)``,
            which is presumably responsible for storing the best one in
            ``self.params`` -- confirm against ``update``.
            Sets ``self.low_quality`` and returns early when the rescale
            factor is zero or when no valid induction/repression chromatin
            level can be computed.
        """
        self.scale_cc = 1.0
        self.rescale_c = 1.0

        # estimate rescale factor for u from the spread of normalised s
        # among cells with mid-range u
        s_norm = self.s / self.max_s
        u_mid = (self.u >= 0.4 * self.max_u) & (self.u <= 0.6 * self.max_u)
        if np.sum(u_mid) < 10:
            self.rescale_u = self.thickness / 5
        else:
            s_low, s_high = np.percentile(s_norm[u_mid], [2, 98])
            self.rescale_u = s_high - s_low
        if self.rescale_u == 0:
            self.low_quality = True
            return

        c = self.c / self.rescale_c
        u = self.u / self.rescale_u
        s = self.s

        # some extreme values
        extreme = (u >= np.percentile(u, 97)) | (s >= np.percentile(s, 97))
        ss_u = u[extreme]
        ss_s = s[extreme]
        c_upper = np.mean(c[extreme])

        # _r stands for repressed state
        c0_r = np.mean(c[c >= np.mean(c)])
        u0_r = np.mean(ss_u)
        s0_r = np.mean(ss_s)
        if c0_r < c_upper:
            c0_r = c_upper + 0.1

        # adjust chromatin level for reasonable initialization
        if model_mismatch or not self.fit_decoupling:
            steady_u = self.steady_state_func(self.s)
            induction = self.u > steady_u
            repression = self.u < steady_u
            # An empty side would make np.mean return NaN (with a
            # warning); treat that, and NaN chromatin values, as low
            # quality. NaN never compares equal to itself, so it has to
            # be detected with np.isnan rather than ``== np.nan``.
            if not np.any(induction) or not np.any(repression):
                self.low_quality = True
                return
            c_indu = np.mean(c[induction])
            c_repr = np.mean(c[repression])
            if np.isnan(c_indu) or np.isnan(c_repr):
                self.low_quality = True
                return
            c0_r = np.mean(c[c >= np.min([c_indu, c_repr])])

        # initialize rates
        self.alpha_c = 0.1
        self.beta = 1.0
        self.gamma = np.dot(ss_u, ss_s) / np.dot(ss_s, ss_s)
        alpha = u0_r
        self.alpha = u0_r
        self.rates = np.array([self.alpha_c, self.alpha, self.beta,
                               self.gamma])

        def pack(t_sw_1, t_sw_2, t_sw_3, alpha_c):
            # Parameter vector layout: first switch time, the two
            # following intervals, four rates, then three scale factors.
            # self.alpha is read at call time because 'invert' mode
            # changes it after the rates above were initialised.
            return np.array([t_sw_1,
                             t_sw_2-t_sw_1,
                             t_sw_3-t_sw_2,
                             alpha_c,
                             self.alpha,
                             self.beta,
                             self.gamma,
                             self.scale_cc,
                             self.rescale_c,
                             self.rescale_u])

        # RNA-only
        if self.rna_only:
            t_sw_1 = 0.1
            t_sw_3 = 20.0
            if self.init_mode == 'grid':
                # arange returns sequence [2,6,10,14,18]
                for t_sw_2 in np.arange(2, 20, 4, dtype=np.float64):
                    # build the candidate for this t_sw_2; the vector
                    # was previously never constructed in this branch
                    params = pack(t_sw_1, t_sw_2, t_sw_3, self.alpha_c)
                    self.update(params, initialize=True, adjust_time=False,
                                plot=False)

            elif self.init_mode == 'simple':
                t_sw_2 = 10
                self.params = pack(t_sw_1, t_sw_2, t_sw_3, self.alpha_c)

            elif self.init_mode == 'invert':
                t_sw_2 = approx_tau(u0_r, s0_r, 0, 0, alpha, self.beta,
                                    self.gamma)
                # keep the second switch strictly inside (t_sw_1, t_sw_3)
                if t_sw_2 <= 0.2:
                    t_sw_2 = 1.0
                elif t_sw_2 >= 19.9:
                    t_sw_2 = 19.0
                self.params = pack(t_sw_1, t_sw_2, t_sw_3, self.alpha_c)

        # chromatin-RNA
        else:
            if self.init_mode == 'grid':
                # arange returns sequence [1,5,9,13,17]
                for t_sw_1 in np.arange(1, 18, 4, dtype=np.float64):
                    # arange returns sequence 2,6,10,14,18
                    for t_sw_2 in np.arange(t_sw_1+1, 19, 4, dtype=np.float64):
                        # arange returns sequence [3,7,11,15,19]
                        for t_sw_3 in np.arange(t_sw_2+1, 20, 4,
                                                dtype=np.float64):
                            if not self.fit_decoupling:
                                # without decoupling the third switch
                                # directly follows the second one
                                t_sw_3 = t_sw_2 + 30 / self.n_anchors
                            params = pack(t_sw_1, t_sw_2, t_sw_3,
                                          self.alpha_c)
                            self.update(params, initialize=True,
                                        adjust_time=False, plot=False)
                            if not self.fit_decoupling:
                                # t_sw_3 is fixed, one candidate suffices
                                break

            elif self.init_mode == 'simple':
                t_sw_1, t_sw_2 = 5, 10
                # NOTE(review): the conditional applies to t_sw_3 only and
                # yields 15 when decoupling is NOT fitted, 10.1 when it is.
                # The 'grid' and 'invert' modes do the opposite (t_sw_3
                # just after t_sw_2 when decoupling is not fitted), so
                # this may be inverted -- behaviour kept as is, confirm.
                t_sw_3 = 15 if not self.fit_decoupling else 10.1
                self.params = pack(t_sw_1, t_sw_2, t_sw_3, self.alpha_c)

            elif self.init_mode == 'invert':
                self.alpha = u0_r / c_upper
                if model_mismatch or not self.fit_decoupling:
                    self.alpha = u0_r / c0_r
                # note: the local ``alpha`` (= u0_r), not self.alpha, is
                # what gets passed to approx_tau here
                rna_interval = approx_tau(u0_r, s0_r, 0, 0, alpha, self.beta,
                                          self.gamma)
                rna_interval = np.clip(rna_interval, 3, 12)
                if self.model == 1:
                    for t_sw_1 in np.arange(1, rna_interval-1, 2,
                                            dtype=np.float64):
                        t_sw_3 = rna_interval + t_sw_1
                        for t_sw_2 in np.arange(t_sw_1+1, rna_interval, 2,
                                                dtype=np.float64):
                            if not self.fit_decoupling:
                                t_sw_2 = t_sw_3 - 30 / self.n_anchors

                            alpha_c = -np.log(1 - c0_r) / t_sw_2
                            params = pack(t_sw_1, t_sw_2, t_sw_3, alpha_c)
                            self.update(params, initialize=True,
                                        adjust_time=False, plot=False)
                            if not self.fit_decoupling:
                                break

                elif self.model == 2:
                    for t_sw_1 in np.arange(1, rna_interval, 2,
                                            dtype=np.float64):
                        t_sw_2 = rna_interval + t_sw_1
                        for t_sw_3 in np.arange(t_sw_2+1, t_sw_2+6, 2,
                                                dtype=np.float64):
                            if not self.fit_decoupling:
                                t_sw_3 = t_sw_2 + 30 / self.n_anchors

                            alpha_c = -np.log(1 - c0_r) / t_sw_3
                            params = pack(t_sw_1, t_sw_2, t_sw_3, alpha_c)
                            self.update(params, initialize=True,
                                        adjust_time=False, plot=False)
                            if not self.fit_decoupling:
                                break

        self.loss = [self.mse(self.params)]
        # params[0:3] hold the first switch time and two intervals;
        # accumulate them into absolute switch times
        self.t_sw_array = np.array([self.params[0],
                                    self.params[0]+self.params[1],
                                    self.params[0]+self.params[1]
                                    + self.params[2]])
        self.t_sw_1, self.t_sw_2, self.t_sw_3 = self.t_sw_array

        logg.update(f'initial params:\nswitch time array = {self.t_sw_array},'
                    '\n'
                    f'rates = {self.rates},\ncc scale = {self.scale_cc},\n'
                    f'c rescale factor = {self.rescale_c},\n'
                    f'u rescale factor = {self.rescale_u}', v=1)
        logg.update(f'initial loss: {self.loss[-1]}', v=1)
2519
2520
    def fit(self):
        """Fit the dynamics and derive likelihood, velocities and anchors.

        Steps, in order: optional interactive figure set-up, parameter
        fitting via ``self.fit_dyn()`` (unless ``self.known_pars``), a
        realignment that removes long time gaps after the last switch,
        likelihood computation, per-cell velocity computation, anchor
        expression/velocity computation, optional plot saving, and a final
        call to ``self.realign_time_and_velocity``.

        Returns:
            ``self.loss``. When ``self.low_quality`` is set the method
            returns it immediately without doing any work.
        """
        if self.low_quality:
            return self.loss

        if self.plot:
            plt.ion()
            self.fig = plt.figure(figsize=self.fig_size)
            # 2-D axes for RNA-only data, 3-D axes otherwise
            axes_kwargs = {} if self.rna_only else {'projection': '3d'}
            self.ax = self.fig.add_subplot(111, **axes_kwargs)

        if not self.known_pars:
            self.fit_dyn()

        self.update(self.params, perform_update=True, fit_outlier=True,
                    plot=True)

        # remove long gaps in the last observed state
        ordered_t = np.sort(self.t)
        steps = np.diff(ordered_t, prepend=0)
        step_mean = np.mean(steps)
        step_std = np.std(steps)
        gap_thresh = np.clip(step_mean+3*step_std, 3*20/self.n_anchors, None)
        if gap_thresh > 0:
            gap_total = 0
            final_switch = np.max(self.t_sw_array[self.t_sw_array < 20])
            for pos in np.where(steps > gap_thresh)[0]:
                left = ordered_t[pos-1] if pos > 0 else 0
                right = ordered_t[pos]
                # only gaps located after the last switch are counted
                if left > final_switch and right <= 20:
                    gap_total += np.clip(right - left - step_mean, 0, None)
            if final_switch > np.max(self.t):
                gap_total += 20 - final_switch
            realign_ratio = np.clip(20/(20 - gap_total), None,
                                    20/final_switch)
            logg.update(f'removing gaps and realigning by {realign_ratio}..',
                        v=1)
            # stretch time by the ratio and slow the rates down accordingly
            self.rates /= realign_ratio
            self.alpha_c, self.alpha, self.beta, self.gamma = self.rates
            self.params[:3] *= realign_ratio
            self.params[3:7] = self.rates
            first = self.params[0]
            second = self.params[0]+self.params[1]
            third = self.params[0]+self.params[1] + self.params[2]
            self.t_sw_array = np.array([first, second, third])
            self.t_sw_1, self.t_sw_2, self.t_sw_3 = self.t_sw_array
            self.update(self.params, perform_update=True, fit_outlier=True,
                        plot=True)

        if self.plot:
            plt.ioff()
            plt.show(block=True)

        # likelihood
        logg.update('computing likelihood..', v=1)
        u_floor = 0.2 * np.percentile(self.u_all, 99.5)
        s_floor = 0.2 * np.percentile(self.s_all, 99.5)
        keep = self.non_zero & self.non_outlier & \
            (self.u_all > u_floor) & \
            (self.s_all > s_floor)
        scale_factor = np.array([self.scale_c / self.std_c,
                                 self.scale_u / self.std_u,
                                 self.scale_s / self.std_s])
        if np.sum(keep) >= 10:
            (self.likelihood, self.l_c, self.ssd_c, self.var_c,
             l_u, l_s) = compute_likelihood(self.c_all,
                                            self.u_all,
                                            self.s_all,
                                            self.t_sw_array,
                                            self.alpha_c,
                                            self.alpha,
                                            self.beta,
                                            self.gamma,
                                            self.rescale_c,
                                            self.rescale_u,
                                            self.t,
                                            self.state,
                                            scale_cc=self.scale_cc,
                                            scale_factor=scale_factor,
                                            model=self.model,
                                            weight=keep,
                                            rna_only=self.rna_only)
        else:
            # too few usable cells: report zeros for every likelihood term
            self.likelihood = 0
            self.l_c = 0
            self.ssd_c = 0
            self.var_c = 0
            l_u = 0
            # TODO: Keep? Remove??
            l_s = 0

        if not self.rna_only:
            logg.update(f'likelihood of c: {self.l_c}, likelihood of u: {l_u},'
                        f' likelihood of s: {l_s}', v=1)

        # velocity
        logg.update('computing velocities..', v=1)
        self.velocity = np.empty((len(self.u_all), 3))
        if self.conn is None:
            new_time = self.t
            new_state = self.state
        else:
            # transform the times with the connectivity matrix, cap them
            # at 20, and re-derive each cell's state from the switch times
            new_time = self.conn.dot(self.t)
            new_time[new_time > 20] = 20
            new_state = self.state.copy()
            new_state[new_time <= self.t_sw_1] = 0
            new_state[(self.t_sw_1 < new_time) & (new_time <= self.t_sw_2)] = 1
            new_state[(self.t_sw_2 < new_time) & (new_time <= self.t_sw_3)] = 2
            new_state[self.t_sw_3 < new_time] = 3

        self.alpha_c, self.alpha, self.beta, self.gamma = \
            check_params(self.alpha_c, self.alpha, self.beta, self.gamma)
        vc, vu, vs = compute_velocity(new_time,
                                      self.t_sw_array,
                                      new_state,
                                      self.alpha_c,
                                      self.alpha,
                                      self.beta,
                                      self.gamma,
                                      self.rescale_c,
                                      self.rescale_u,
                                      scale_cc=self.scale_cc,
                                      model=self.model,
                                      rna_only=self.rna_only)

        self.velocity[:, 0] = vc * self.scale_c
        self.velocity[:, 1] = vu * self.scale_u
        self.velocity[:, 2] = vs * self.scale_s

        # anchor expression and velocity
        anchor_time, tau_list = anchor_points(self.t_sw_array, 20,
                                              self.n_anchors, return_time=True)
        # number of switch times below 20
        switch = np.sum(self.t_sw_array < 20)
        typed_tau_list = List()
        for tau in tau_list:
            typed_tau_list.append(tau)
        (self.alpha_c, self.alpha, self.beta, self.gamma,
         self.c0, self.u0, self.s0) = check_params(self.alpha_c,
                                                   self.alpha,
                                                   self.beta,
                                                   self.gamma,
                                                   c0=self.c0,
                                                   u0=self.u0,
                                                   s0=self.s0)
        exp_list, exp_sw_list = generate_exp(typed_tau_list,
                                             self.t_sw_array[:switch],
                                             self.alpha_c,
                                             self.alpha,
                                             self.beta,
                                             self.gamma,
                                             scale_cc=self.scale_cc,
                                             model=self.model,
                                             rna_only=self.rna_only)
        rescale_factor = np.array([self.rescale_c, self.rescale_u, 1.0])
        exp_list = [part*rescale_factor for part in exp_list]
        exp_sw_list = [part*rescale_factor for part in exp_sw_list]

        def gather(parts, count, column):
            # flatten one column across the first `count` segments
            return np.ravel(np.concatenate([parts[k][:, column]
                                            for k in range(count)]))

        c = gather(exp_list, switch+1, 0)
        u = gather(exp_list, switch+1, 1)
        s = gather(exp_list, switch+1, 2)
        c_sw = gather(exp_sw_list, switch, 0)
        u_sw = gather(exp_sw_list, switch, 1)
        s_sw = gather(exp_sw_list, switch, 2)
        self.alpha_c, self.alpha, self.beta, self.gamma = \
            check_params(self.alpha_c, self.alpha, self.beta, self.gamma)
        vc, vu, vs = compute_velocity(anchor_time,
                                      self.t_sw_array,
                                      None,
                                      self.alpha_c,
                                      self.alpha,
                                      self.beta,
                                      self.gamma,
                                      self.rescale_c,
                                      self.rescale_u,
                                      scale_cc=self.scale_cc,
                                      model=self.model,
                                      rna_only=self.rna_only)

        # scale and shift back to original scale
        c_ = c * self.scale_c + self.offset_c
        u_ = u * self.scale_u + self.offset_u
        s_ = s * self.scale_s + self.offset_s
        c_sw_ = c_sw * self.scale_c + self.offset_c
        u_sw_ = u_sw * self.scale_u + self.offset_u
        s_sw_ = s_sw * self.scale_s + self.offset_s
        vc = vc * self.scale_c
        vu = vu * self.scale_u
        vs = vs * self.scale_s

        self.anchor_exp = np.empty((len(u_), 3))
        self.anchor_exp[:, 0] = c_
        self.anchor_exp[:, 1] = u_
        self.anchor_exp[:, 2] = s_
        self.anchor_exp_sw = np.empty((len(u_sw_), 3))
        self.anchor_exp_sw[:, 0] = c_sw_
        self.anchor_exp_sw[:, 1] = u_sw_
        self.anchor_exp_sw[:, 2] = s_sw_
        self.anchor_velo = np.empty((len(u_), 3))
        self.anchor_velo[:, 0] = vc
        self.anchor_velo[:, 1] = vu
        self.anchor_velo[:, 2] = vs
        # anchor index range spanned by the observed (transformed) times
        self.anchor_velo_min_idx = np.sum(anchor_time < np.min(new_time))
        self.anchor_velo_max_idx = np.sum(anchor_time < np.max(new_time)) - 1

        if self.save_plot:
            logg.update('saving plots..', v=1)
            self.save_dyn_plot(c_, u_, s_, c_sw_, u_sw_, s_sw_, tau_list)

        self.realign_time_and_velocity(c, u, s, anchor_time)

        logg.update(f'final params:\nswitch time array = {self.t_sw_array},\n'
                    f'rates = {self.rates},\ncc scale = {self.scale_cc},\n'
                    f'c rescale factor = {self.rescale_c},\n'
                    f'u rescale factor = {self.rescale_u}',
                    v=1)
        logg.update(f'final loss: {self.loss[-1]}', v=1)
        logg.update(f'final likelihood: {self.likelihood}', v=1)

        return self.loss
2734
2735
    # the adam algorithm
2736
    # NOTE: The starting point for this function was an example on the
2737
    # GeeksForGeeks website. The particular article is linked below:
2738
    # www.geeksforgeeks.org/how-to-implement-adam-gradient-descent-from-scratch-using-python/
2739
    def AdamMin(self, x, n_iter, tol, eps=1e-8):
        """Run a short Adam descent on ``self.mse_ten`` starting from ``x``.

        The loss of every iterate is compared with the best loss seen so far
        (initially ``self.loss[-1]``). When the loss improved more than once,
        ``self.cur_loss`` is set to the best loss and the best iterate is
        handed to ``self.update``.

        Args:
            x: 1-D array-like with the initial parameter values.
            n_iter: maximum number of Adam steps.
            tol: relative tolerance for early stopping. The loop stops once
                the loss has improved at least three times and both the last
                step (relative to the parameter magnitudes) and the loss
                improvement (relative to the previous best loss) are below
                ``tol``.
            eps: small constant added to the denominator of the Adam step.

        Returns:
            The return value of ``self.update`` when the best iterate was
            committed, otherwise ``False``.
        """
        n = len(x)

        x_ten = torch.tensor(x, requires_grad=True, device=self.device,
                             dtype=self.torch_type)

        # benchmark: the best loss so far is the most recently recorded loss
        lowest_loss = torch.tensor(np.array(self.loss[-1], dtype=self.u.dtype),
                                   device=self.device,
                                   dtype=self.torch_type)

        # parameters that produced the lowest loss
        lowest_x_ten = x_ten

        # Adam first/second moment estimates. They are optimizer state and
        # are never differentiated, so they do not need requires_grad.
        m = torch.zeros(n, device=self.device, dtype=self.torch_type)
        v = torch.zeros(n, device=self.device, dtype=self.torch_type)

        # last applied step; infinite until the first step has been taken so
        # that the convergence test cannot pass before then
        u = torch.full((n,), float("inf"), device=self.device,
                       dtype=self.torch_type)

        # how many times the new loss was lower than the lowest loss
        update_count = 0

        for t in range(n_iter):

            loss = self.mse_ten(x_ten)

            if loss < lowest_loss:

                lowest_x_ten = x_ten
                update_count += 1

                # Relative size of the last step and of the loss improvement.
                # Magnitudes are compared so that a negative parameter cannot
                # satisfy the step criterion trivially.
                if update_count >= 3 and \
                        torch.all((torch.abs(u) / torch.abs(lowest_x_ten))
                                  < tol) and \
                        (torch.abs(loss - lowest_loss) / lowest_loss) < tol:
                    # lowest_loss is refreshed after the loop in this case
                    break

                lowest_loss = loss

            # gradient of the loss w.r.t. the current parameter values
            loss.backward(inputs=x_ten)
            g = x_ten.grad

            # Adam update with bias-corrected moment estimates
            m = (self.adam_beta1 * m) + ((1.0 - self.adam_beta1) * g)
            v = (self.adam_beta2 * v) + ((1.0 - self.adam_beta2) * g * g)

            mhat = m / (1.0 - (self.adam_beta1**(t+1)))
            vhat = v / (1.0 - (self.adam_beta2**(t+1)))

            u = -(self.adam_lr * mhat) / (torch.sqrt(vhat) + eps)

            # Start the next iterate as a fresh leaf tensor so the autograd
            # graph does not chain across iterations.
            x_ten = (x_ten + u).detach().requires_grad_(True)

        # Commit only when the loss improved more than once (threshold kept
        # from the original implementation).
        if update_count <= 1:
            return False

        # After an early break the best loss has not been recorded yet.
        if loss < lowest_loss:
            lowest_loss = loss

        self.cur_loss = lowest_loss.item()

        # let update() install the best parameters that were found
        return self.update(lowest_x_ten.cpu().detach().numpy())
2836
2837
    def fit_dyn(self):
        """Refine the model parameters until ``self.max_iter`` is reached.

        Every pass increments ``self.cur_iter`` and runs a fixed sequence of
        short Nelder-Mead searches on ``self.mse`` over subsets of
        ``self.params``, with ``self.update`` as the optimizer callback. When
        ``self.adam`` is set (chromatin-RNA mode only), one ``AdamMin`` pass
        over all parameters followed by a Nelder-Mead search on the three
        switch intervals is run instead.

        ``self.fitting_flag_`` is set before each search; it is read by
        ``_variables`` together with the length of the candidate vector to
        tell the fitted parameter subsets apart.

        On the first pass a few perturbed values of alpha (RNA-only) or gamma
        (chromatin-RNA) are offered to ``self.update`` before the first
        search.
        """

        def nelder_mead(x0, maxiter):
            # One short Nelder-Mead search. x0 is built by the caller right
            # before the call, so it reflects the current self.params.
            minimize(self.mse, x0=x0, method='Nelder-Mead', tol=1e-2,
                     callback=self.update, options={'maxiter': maxiter})

        while self.cur_iter < self.max_iter:
            self.cur_iter += 1

            # RNA-only
            if self.rna_only:
                logg.update('Nelder Mead on t_sw_2 and alpha..', v=2)
                self.fitting_flag_ = 0
                if self.cur_iter == 1:
                    var_test = (self.alpha +
                                np.array([-2, -1, -0.5, 0.5, 1, 2]) * 0.1
                                * self.alpha)
                    new_params = self.params.copy()
                    for var in var_test:
                        new_params[4] = var
                        self.update(new_params, adjust_time=False,
                                    penalize_gap=False)
                nelder_mead([self.params[1], self.params[4]], 3)

                if self.fit_rescale:
                    logg.update('Nelder Mead on t_sw_2, beta, and rescale u..',
                                v=2)
                    nelder_mead([self.params[1], self.params[5],
                                 self.params[9]], 5)

                logg.update('Nelder Mead on alpha and gamma..', v=2)
                self.fitting_flag_ = 1
                nelder_mead([self.params[4], self.params[6]], 3)

                logg.update('Nelder Mead on t_sw_2..', v=2)
                nelder_mead([self.params[1]], 2)

                logg.update('Full Nelder Mead..', v=2)
                nelder_mead([self.params[1], self.params[4],
                             self.params[5], self.params[6]], 5)

            # chromatin-RNA, Nelder-Mead only
            elif not self.adam:
                # Positions in self.params of the chromatin and RNA switch
                # intervals; they swap between models 0/1 and model 2. Any
                # other model value skips the model-specific searches.
                if self.model == 0 or self.model == 1:
                    chrom_sw, rna_sw = 1, 2
                elif self.model == 2:
                    chrom_sw, rna_sw = 2, 1
                else:
                    chrom_sw = rna_sw = None

                logg.update('Nelder Mead on t_sw_1, chromatin switch time, '
                            'and alpha_c..', v=2)
                self.fitting_flag_ = 1
                if self.cur_iter == 1:
                    var_test = (self.gamma + np.array([-1, -0.5, 0.5, 1])
                                * 0.1 * self.gamma)
                    new_params = self.params.copy()
                    for var in var_test:
                        new_params[6] = var
                        self.update(new_params, adjust_time=False)
                if chrom_sw is not None:
                    nelder_mead([self.params[0], self.params[chrom_sw],
                                 self.params[3]], 20)

                logg.update('Nelder Mead on chromatin switch time, '
                            'chromatin closing rate scaling, and rescale '
                            'c..', v=2)
                self.fitting_flag_ = 2
                if chrom_sw is not None:
                    nelder_mead([self.params[chrom_sw], self.params[7],
                                 self.params[8]], 20)

                logg.update('Nelder Mead on rna switch time and alpha..',
                            v=2)
                self.fitting_flag_ = 1
                if rna_sw is not None:
                    nelder_mead([self.params[rna_sw], self.params[4]], 10)

                logg.update('Nelder Mead on rna switch time, beta, and '
                            'rescale u..', v=2)
                self.fitting_flag_ = 3
                if rna_sw is not None:
                    nelder_mead([self.params[rna_sw], self.params[5],
                                 self.params[9]], 20)

                logg.update('Nelder Mead on alpha and gamma..', v=2)
                self.fitting_flag_ = 2
                nelder_mead([self.params[4], self.params[6]], 10)

                logg.update('Nelder Mead on t_sw..', v=2)
                self.fitting_flag_ = 4
                nelder_mead(self.params[:3], 20)

            # chromatin-RNA, Adam followed by Nelder-Mead on switch times
            else:
                logg.update('Adam on all parameters', v=2)
                self.AdamMin(np.array(self.params, dtype=self.u.dtype), 20,
                             tol=1e-2)

                logg.update('Nelder Mead on t_sw..', v=2)
                self.fitting_flag_ = 4
                nelder_mead(self.params[:3], 15)

            logg.update(f'iteration {self.cur_iter} finished', v=2)
2997
2998
    def _variables(self, x):
        """Expand a candidate vector ``x`` into the full parameter set.

        Which parameters ``x`` holds is decided from ``len(x)``,
        ``self.rna_only``, ``self.fitting_flag_`` and ``self.model``; every
        parameter not contained in ``x`` is taken from the current state of
        the instance. A vector of length 10 holds everything: three switch
        intervals, the four rates (alpha_c, alpha, beta, gamma), ``scale_cc``,
        ``rescale_c`` and ``rescale_u``.

        Unless ``self.known_pars`` is set, the results are clipped: switch
        intervals to at least 0.1 (when ``self.fit_decoupling`` is false the
        third interval is fixed to ``30 / self.n_anchors`` instead), rates to
        [0.001, 1000], ``rescale_c`` to [0.75, 1.5] and ``rescale_u`` to
        [0.2, 3]. ``scale_cc`` is clipped around ``self.scale_cc`` whenever
        ``self.fitting_flag_`` is non-zero and ``self.adam`` is false.

        Neither ``x`` nor ``self.params`` is modified.

        Returns:
            ``(t3, r4, scale_cc, rescale_c, rescale_u)``, or ``None`` when the
            length of ``x`` or the flag/model combination is not recognised.
        """
        scale_cc = self.scale_cc
        rescale_c = self.rescale_c
        rescale_u = self.rescale_u

        n_par = len(x)
        # stay None when no branch below recognises the request
        t3 = r4 = None

        # RNA-only
        if self.rna_only:
            if n_par == 1:  # fit t_sw_2
                t3 = np.array([self.t_sw_1, x[0],
                               self.t_sw_3 - self.t_sw_1 - x[0]])
                r4 = self.rates

            elif n_par == 2:
                if self.fitting_flag_:  # fit alpha and gamma
                    t3 = self.params[:3]
                    r4 = np.array([self.alpha_c, x[0], self.beta, x[1]])
                else:  # fit t_sw_2 and alpha
                    t3 = np.array([self.t_sw_1, x[0],
                                   self.t_sw_3 - self.t_sw_1 - x[0]])
                    r4 = np.array([self.alpha_c, x[1], self.beta, self.gamma])

            elif n_par == 3:  # fit t_sw_2, beta, and rescale u
                t3 = np.array([self.t_sw_1, x[0],
                               self.t_sw_3 - self.t_sw_1 - x[0]])
                r4 = np.array([self.alpha_c, self.alpha, x[1], self.gamma])
                rescale_u = x[2]

            elif n_par == 4:  # fit t_sw_2, alpha, beta, and gamma
                t3 = np.array([self.t_sw_1, x[0],
                               self.t_sw_3 - self.t_sw_1 - x[0]])
                r4 = np.array([self.alpha_c, x[1], x[2], x[3]])

            elif n_par == 10:  # all available
                t3 = x[:3]
                r4 = x[3:7]
                scale_cc = x[7]
                rescale_c = x[8]
                rescale_u = x[9]

        # chromatin-RNA
        else:
            flag = self.fitting_flag_
            model_01 = self.model == 0 or self.model == 1
            model_2 = self.model == 2

            if n_par == 2:
                if flag == 1:  # fit rna switch time and alpha
                    if model_01:
                        t3 = np.array([self.t_sw_1, self.params[1], x[0]])
                    elif model_2:
                        t3 = np.array([self.t_sw_1, x[0],
                                       self.t_sw_3 - self.t_sw_1 - x[0]])
                    r4 = np.array([self.alpha_c, x[1], self.beta, self.gamma])
                elif flag == 2:  # fit alpha and gamma
                    t3 = self.params[:3]
                    r4 = np.array([self.alpha_c, x[0], self.beta, x[1]])

            elif n_par == 3:
                # fit t_sw_1, chromatin switch time, and alpha_c
                if flag == 1:
                    if model_01:
                        t3 = np.array([x[0], x[1], self.t_sw_3 - x[0] - x[1]])
                    elif model_2:
                        t3 = np.array([x[0], self.t_sw_2 - x[0], x[1]])
                    r4 = np.array([x[2], self.alpha, self.beta, self.gamma])
                # fit chromatin switch time, chromatin closing rate scaling,
                # and rescale c
                elif flag == 2:
                    if model_01:
                        t3 = np.array([self.t_sw_1, x[0],
                                       self.t_sw_3 - self.t_sw_1 - x[0]])
                    elif model_2:
                        t3 = np.array([self.t_sw_1, self.params[1], x[0]])
                    r4 = self.rates
                    scale_cc = x[1]
                    rescale_c = x[2]
                # fit rna switch time, beta, and rescale u
                elif flag == 3:
                    if model_01:
                        t3 = np.array([self.t_sw_1, self.params[1], x[0]])
                    elif model_2:
                        t3 = np.array([self.t_sw_1, x[0],
                                       self.t_sw_3 - self.t_sw_1 - x[0]])
                    r4 = np.array([self.alpha_c, self.alpha, x[1], self.gamma])
                    rescale_u = x[2]
                # fit three switch times
                elif flag == 4:
                    t3 = x
                    r4 = self.rates

            elif n_par == 7:  # switch intervals and rates
                t3 = x[:3]
                r4 = x[3:]

            elif n_par == 10:  # all available
                t3 = x[:3]
                r4 = x[3:7]
                scale_cc = x[7]
                rescale_c = x[8]
                rescale_u = x[9]

        # Unsupported length or flag/model combination: report it the same
        # way for every case instead of failing on an unbound local below.
        if t3 is None or r4 is None:
            return None

        # clip to meaningful values
        if self.fitting_flag_ and not self.adam:
            scale_cc = np.clip(scale_cc,
                               np.max([0.5*self.scale_cc, 0.25]),
                               np.min([2*self.scale_cc, 4]))

        if not self.known_pars:
            if self.fit_decoupling:
                t3 = np.clip(t3, 0.1, None)
            else:
                # t3 may alias x or be a view of self.params; copy before
                # writing so the caller's data is left untouched.
                t3 = np.array(t3)
                t3[2] = 30 / self.n_anchors
                t3[:2] = np.clip(t3[:2], 0.1, None)
            r4 = np.clip(r4, 0.001, 1000)
            rescale_c = np.clip(rescale_c, 0.75, 1.5)
            rescale_u = np.clip(rescale_u, 0.2, 3)

        return t3, r4, scale_cc, rescale_c, rescale_u
3119
3120
    # the tensor version of the calculate_dist_and_time function
3121
    def calculate_dist_and_time_ten(self,
3122
                                    c, u, s,
3123
                                    t_sw_array,
3124
                                    alpha_c, alpha, beta, gamma,
3125
                                    rescale_c, rescale_u,
3126
                                    scale_cc=1,
3127
                                    scale_factor=None,
3128
                                    model=1,
3129
                                    conn=None,
3130
                                    t=1000, k=1,
3131
                                    direction='complete',
3132
                                    total_h=20,
3133
                                    rna_only=False,
3134
                                    penalize_gap=True,
3135
                                    all_cells=True):
3136
3137
        conn = torch.tensor(conn.todense(),
3138
                            device=self.device,
3139
                            dtype=self.torch_type)
3140
3141
        c_ten = torch.tensor(c, device=self.device, dtype=self.torch_type)
3142
        u_ten = torch.tensor(u, device=self.device, dtype=self.torch_type)
3143
        s_ten = torch.tensor(s, device=self.device, dtype=self.torch_type)
3144
3145
        n = len(u)
3146
        if scale_factor is None:
3147
            scale_factor_ten = torch.stack((torch.std(c_ten), torch.std(u_ten),
3148
                                            torch.std(s_ten)))
3149
        else:
3150
            scale_factor_ten = torch.tensor(scale_factor, device=self.device,
3151
                                            dtype=self.torch_type)
3152
3153
        tau_list = self.anchor_points_ten(t_sw_array, total_h, t)
3154
3155
        switch = torch.sum(t_sw_array < total_h)
3156
3157
        exp_list, exp_sw_list = self.generate_exp_tens(tau_list,
3158
                                                       t_sw_array[:switch],
3159
                                                       alpha_c,
3160
                                                       alpha,
3161
                                                       beta,
3162
                                                       gamma,
3163
                                                       model=model,
3164
                                                       scale_cc=scale_cc,
3165
                                                       rna_only=rna_only)
3166
3167
        rescale_factor = torch.stack((rescale_c, rescale_u,
3168
                                     torch.tensor(1.0, device=self.device,
3169
                                                  requires_grad=True,
3170
                                                  dtype=self.torch_type)))
3171
3172
        for i in range(len(exp_list)):
3173
            exp_list[i] = exp_list[i]*rescale_factor
3174
3175
            if i < len(exp_list)-1:
3176
                exp_sw_list[i] = exp_sw_list[i]*rescale_factor
3177
3178
        max_c = 0
3179
        max_u = 0
3180
        max_s = 0
3181
3182
        if rna_only:
3183
            exp_mat = (torch.hstack((torch.reshape(u_ten, (-1, 1)),
3184
                                     torch.reshape(s_ten, (-1, 1))))
3185
                       / scale_factor_ten[1:])
3186
        else:
3187
            exp_mat = torch.hstack((torch.reshape(c_ten, (-1, 1)),
3188
                                    torch.reshape(u_ten, (-1, 1)),
3189
                                    torch.reshape(s_ten, (-1, 1))))\
3190
                                    / scale_factor_ten
3191
3192
        taus = torch.zeros((1, n), device=self.device,
3193
                           requires_grad=True,
3194
                           dtype=self.torch_type)
3195
        anchor_exp, anchor_t = None, None
3196
3197
        dists0 = torch.full((1, n), 0.0 if direction == "on"
3198
                            or direction == "complete" else np.inf,
3199
                            device=self.device,
3200
                            requires_grad=True,
3201
                            dtype=self.torch_type)
3202
        dists1 = torch.full((1, n), 0.0 if direction == "on"
3203
                            or direction == "complete" else np.inf,
3204
                            device=self.device,
3205
                            requires_grad=True,
3206
                            dtype=self.torch_type)
3207
        dists2 = torch.full((1, n), 0.0 if direction == "off"
3208
                            or direction == "complete" else np.inf,
3209
                            device=self.device,
3210
                            requires_grad=True,
3211
                            dtype=self.torch_type)
3212
        dists3 = torch.full((1, n), 0.0 if direction == "off"
3213
                            or direction == "complete" else np.inf,
3214
                            device=self.device,
3215
                            requires_grad=True,
3216
                            dtype=self.torch_type)
3217
3218
        ts0 = torch.zeros((1, n), device=self.device,
3219
                          requires_grad=True,
3220
                          dtype=self.torch_type)
3221
        ts1 = torch.zeros((1, n), device=self.device,
3222
                          requires_grad=True,
3223
                          dtype=self.torch_type)
3224
        ts2 = torch.zeros((1, n), device=self.device,
3225
                          requires_grad=True,
3226
                          dtype=self.torch_type)
3227
        ts3 = torch.zeros((1, n), device=self.device,
3228
                          requires_grad=True,
3229
                          dtype=self.torch_type)
3230
3231
        for i in range(switch+1):
3232
3233
            if not all_cells:
3234
                max_ci = (torch.max(exp_list[i][:, 0])
3235
                          if exp_list[i].shape[0] > 0
3236
                          else 0)
3237
                max_c = max_ci if max_ci > max_c else max_c
3238
            max_ui = torch.max(exp_list[i][:, 1]) if exp_list[i].shape[0] > 0 \
3239
                else 0
3240
            max_u = max_ui if max_ui > max_u else max_u
3241
            max_si = torch.max(exp_list[i][:, 2]) if exp_list[i].shape[0] > 0 \
3242
                else 0
3243
            max_s = max_si if max_si > max_s else max_s
3244
3245
            skip_phase = False
3246
            if direction == 'off':
3247
                if (model in [1, 2]) and (i < 2):
3248
                    skip_phase = True
3249
            elif direction == 'on':
3250
                if (model in [1, 2]) and (i >= 2):
3251
                    skip_phase = True
3252
            if rna_only and i == 0:
3253
                skip_phase = True
3254
3255
            if not skip_phase:
3256
                if rna_only:
3257
                    tmp = exp_list[i][:, 1:] / scale_factor_ten[1:]
3258
                else:
3259
                    tmp = exp_list[i] / scale_factor_ten
3260
                if anchor_exp is None:
3261
                    anchor_exp = exp_list[i]
3262
                    anchor_t = (tau_list[i] + t_sw_array[i-1] if i >= 1
3263
                                else tau_list[i])
3264
                else:
3265
                    anchor_exp = torch.vstack((anchor_exp, exp_list[i]))
3266
                    anchor_t = torch.hstack((anchor_t,
3267
                                             tau_list[i] + t_sw_array[i-1]
3268
                                             if i >= 1 else tau_list[i]))
3269
3270
                if not all_cells:
3271
                    anchor_prepend_rna = torch.zeros((1, 2),
3272
                                                     device=self.device,
3273
                                                     dtype=self.torch_type)
3274
                    anchor_prepend_chrom = torch.zeros((1, 3),
3275
                                                       device=self.device,
3276
                                                       dtype=self.torch_type)
3277
                    anchor_dist = torch.diff(tmp, dim=0,
3278
                                             prepend=anchor_prepend_rna
3279
                                             if rna_only
3280
                                             else anchor_prepend_chrom)
3281
3282
                    anchor_dist = torch.sqrt((anchor_dist*anchor_dist)
3283
                                             .sum(axis=1))
3284
                    remove_cand = anchor_dist < (0.01*torch.max(exp_mat[1])
3285
                                                 if rna_only
3286
                                                 else
3287
                                                 0.01*torch.max(exp_mat[2]))
3288
                    step_idx = torch.arange(0, anchor_dist.size()[0], 1,
3289
                                            device=self.device,
3290
                                            dtype=self.torch_type) % 3 > 0
3291
                    remove_cand &= step_idx
3292
                    keep_idx = torch.where(~remove_cand)[0]
3293
3294
                    tmp = tmp[keep_idx, :]
3295
3296
                model = NearestNeighbors(n_neighbors=k, output_type="numpy")
3297
                model.fit(tmp.detach())
3298
                dd, ii = model.kneighbors(exp_mat.detach())
3299
                ii = ii.T[0]
3300
3301
                new_dd = ((exp_mat[:, 0] - tmp[ii, 0])
3302
                          * (exp_mat[:, 0] - tmp[ii, 0])
3303
                          + (exp_mat[:, 1] - tmp[ii, 1])
3304
                          * (exp_mat[:, 1] - tmp[ii, 1])
3305
                          + (exp_mat[:, 2] - tmp[ii, 2])
3306
                          * (exp_mat[:, 2] - tmp[ii, 2]))
3307
3308
                if k > 1:
3309
                    new_dd = torch.mean(new_dd, dim=1)
3310
                if conn is not None:
3311
                    new_dd = torch.matmul(conn, new_dd)
3312
3313
                if i == 0:
3314
                    dists0 = dists0 + new_dd
3315
                elif i == 1:
3316
                    dists1 = dists1 + new_dd
3317
                elif i == 2:
3318
                    dists2 = dists2 + new_dd
3319
                elif i == 3:
3320
                    dists3 = dists3 + new_dd
3321
3322
                if not all_cells:
3323
                    ii = keep_idx[ii]
3324
                if k == 1:
3325
                    taus = tau_list[i][ii]
3326
                else:
3327
                    for j in range(n):
3328
                        taus[j] = tau_list[i][ii[j, :]]
3329
3330
                if i == 0:
3331
                    ts0 = ts0 + taus
3332
                elif i == 1:
3333
                    ts1 = ts1 + taus + t_sw_array[0]
3334
                elif i == 2:
3335
                    ts2 = ts2 + taus + t_sw_array[1]
3336
                elif i == 3:
3337
                    ts3 = ts3 + taus + t_sw_array[2]
3338
3339
        dists = torch.cat((dists0, dists1, dists2, dists3), 0)
3340
3341
        ts = torch.cat((ts0, ts1, ts2, ts3), 0)
3342
3343
        state_pred = torch.argmin(dists, axis=0)
3344
3345
        t_pred = ts[state_pred, torch.arange(n, device=self.device)]
3346
3347
        anchor_t1_list = []
3348
        anchor_t2_list = []
3349
3350
        t_sw_adjust = torch.zeros(3, device=self.device, dtype=self.torch_type)
3351
3352
        if direction == 'complete':
3353
3354
            dist_gap_add = torch.zeros((1, n), device=self.device,
3355
                                       dtype=self.torch_type)
3356
3357
            t_sorted = torch.clone(t_pred)
3358
            t_sorted, t_sorted_indices = torch.sort(t_sorted)
3359
3360
            dt = torch.diff(t_sorted, dim=0,
3361
                            prepend=torch.zeros(1, device=self.device,
3362
                                                dtype=self.torch_type))
3363
3364
            gap_thresh = 3*torch.quantile(dt, 0.99)
3365
3366
            idx = torch.where(dt > gap_thresh)[0]
3367
3368
            if len(idx) > 0 and penalize_gap:
3369
                h_tens = torch.tensor([total_h], device=self.device,
3370
                                      dtype=self.torch_type)
3371
3372
            for i in idx:
3373
3374
                t1 = t_sorted[i-1] if i > 0 else 0
3375
                t2 = t_sorted[i]
3376
                anchor_t1 = anchor_exp[torch.argmin(torch.abs(anchor_t - t1)),
3377
                                       :]
3378
                anchor_t2 = anchor_exp[torch.argmin(torch.abs(anchor_t - t2)),
3379
                                       :]
3380
                if all_cells:
3381
                    anchor_t1_list.append(torch.ravel(anchor_t1))
3382
                    anchor_t2_list.append(torch.ravel(anchor_t2))
3383
                if not all_cells:
3384
                    for j in range(1, switch):
3385
                        crit1 = ((t1 > t_sw_array[j-1])
3386
                                 and (t2 > t_sw_array[j-1])
3387
                                 and (t1 <= t_sw_array[j])
3388
                                 and (t2 <= t_sw_array[j]))
3389
                        crit2 = ((torch.abs(anchor_t1[2]
3390
                                            - exp_sw_list[j][0, 2])
3391
                                  < 0.02 * max_s) and
3392
                                 (torch.abs(anchor_t2[2]
3393
                                            - exp_sw_list[j][0, 2])
3394
                                 < 0.01 * max_s))
3395
                        crit3 = ((torch.abs(anchor_t1[1]
3396
                                            - exp_sw_list[j][0, 1])
3397
                                 < 0.02 * max_u) and
3398
                                 (torch.abs(anchor_t2[1]
3399
                                            - exp_sw_list[j][0, 1])
3400
                                 < 0.01 * max_u))
3401
                        crit4 = ((torch.abs(anchor_t1[0]
3402
                                            - exp_sw_list[j][0, 0])
3403
                                 < 0.02 * max_c) and
3404
                                 (torch.abs(anchor_t2[0]
3405
                                            - exp_sw_list[j][0, 0])
3406
                                 < 0.01 * max_c))
3407
                        if crit1 and crit2 and crit3 and crit4:
3408
                            t_sw_adjust[j] += t2 - t1
3409
                if penalize_gap:
3410
                    dist_gap = torch.sum(((anchor_t1[1:] - anchor_t2[1:]) /
3411
                                          scale_factor_ten[1:])**2)
3412
3413
                    idx_to_adjust = torch.tensor(t_pred >= t2,
3414
                                                 device=self.device)
3415
3416
                    idx_to_adjust = torch.reshape(idx_to_adjust,
3417
                                                  (1, idx_to_adjust.size()[0]))
3418
3419
                    true_tensor = torch.tensor([True], device=self.device)
3420
                    false_tensor = torch.tensor([False], device=self.device)
3421
3422
                    t_sw_array_ = torch.cat((t_sw_array, h_tens), dim=0)
3423
                    state_to_adjust = torch.where(t_sw_array_ > t2,
3424
                                                  true_tensor, false_tensor)
3425
3426
                    dist_gap_add[idx_to_adjust] += dist_gap
3427
3428
                    if state_to_adjust[0].item():
3429
                        dists0 += dist_gap_add
3430
                    if state_to_adjust[1].item():
3431
                        dists1 += dist_gap_add
3432
                    if state_to_adjust[2].item():
3433
                        dists2 += dist_gap_add
3434
                    if state_to_adjust[3].item():
3435
                        dists3 += dist_gap_add
3436
3437
                    dist_gap_add[idx_to_adjust] -= dist_gap
3438
3439
            dists = torch.cat((dists0, dists1, dists2, dists3), 0)
3440
3441
            state_pred = torch.argmin(dists, dim=0)
3442
3443
            if all_cells:
3444
                t_pred = ts[torch.arange(n, device=self.device), state_pred]
3445
3446
        min_dist = torch.min(dists, dim=0).values
3447
3448
        if all_cells:
3449
            exp_ss_mat = compute_ss_exp(alpha_c, alpha, beta, gamma,
3450
                                        model=model)
3451
            if rna_only:
3452
                exp_ss_mat[:, 0] = 1
3453
            dists_ss = pairwise_distance_square(exp_mat, exp_ss_mat *
3454
                                                rescale_factor / scale_factor)
3455
3456
            reach_ss = np.full((n, 4), False)
3457
            for i in range(n):
3458
                for j in range(4):
3459
                    if min_dist[i] > dists_ss[i, j]:
3460
                        reach_ss[i, j] = True
3461
            late_phase = np.full(n, -1)
3462
            for i in range(3):
3463
                late_phase[torch.abs(t_pred - t_sw_array[i]) < 0.1] = i
3464
3465
            return min_dist, t_pred, state_pred.cpu().detach().numpy(), \
3466
                reach_ss, late_phase, max_u, max_s, anchor_t1_list, \
3467
                anchor_t2_list
3468
3469
        else:
3470
            return min_dist, state_pred.cpu().detach().numpy(), max_u, max_s, \
3471
                   t_sw_adjust.cpu().detach().numpy()
3472
3473
    # the torch tensor version of the mse function
3474
    def mse_ten(self, x, fit_outlier=False,
                penalize_gap=True):
        """Torch loss: mean cell-to-anchor distance plus penalties.

        Parameters
        ----------
        x : torch.Tensor
            Parameter vector, read as ``x[:3]`` (three switch-time
            intervals, cumulatively summed into switch times), ``x[3:7]``
            (four rate parameters), ``x[7]`` (``scale_cc``), ``x[8]``
            (``rescale_c``) and ``x[9]`` (``rescale_u``).
        fit_outlier : bool
            When True the ``*_all`` arrays and ``self.conn`` are used and
            ``all_cells=True`` is passed on; otherwise ``self.c/u/s`` and
            ``self.conn_sub`` are used.
        penalize_gap : bool
            Forwarded unchanged to ``calculate_dist_and_time_ten``.

        Returns
        -------
        ``loss`` when ``fit_outlier`` is False, ``(loss, t_pred)`` otherwise.

        Side effects
        ------------
        Stores ``self.cur_loss`` and ``self.cur_state_pred``; additionally
        stores ``self.cur_t_sw_adjust`` (``fit_outlier`` False) or
        ``self.anchor_t1_list`` / ``self.anchor_t2_list`` (``fit_outlier``
        True).
        """
        t3 = x[:3]
        r4 = x[3:7]
        scale_cc = x[7]
        rescale_c = x[8]
        rescale_u = x[9]

        if not self.known_pars:
            if self.fit_decoupling:
                t3 = torch.clip(t3, 0.1, None)
            else:
                # t3 is still a slice of x here, so these two assignments
                # write through to x itself.
                t3[2] = 30 / self.n_anchors
                t3[:2] = torch.clip(t3[:2], 0.1, None)
            r4 = torch.clip(r4, 0.001, 1000)
            rescale_c = torch.clip(rescale_c, 0.75, 1.5)
            rescale_u = torch.clip(rescale_u, 0.2, 3)

        t_sw_array = torch.cumsum(t3, dim=0)

        if self.rna_only:
            t_sw_array[2] = 20

        # conditions for minimum switch time and rate params
        penalty = 0
        if torch.any(t3 < 0.2) or torch.any(r4 < 0.005):
            short_t = t3 if self.fit_decoupling else t3[:2]
            penalty = torch.sum(0.2 - short_t[short_t < 0.2])
            penalty = penalty + torch.sum(0.005 - r4[r4 < 0.005]) * 1e2

        # condition for all params: accumulate on top of the penalty above
        # instead of replacing it, so an oversized parameter cannot cancel
        # the switch-time / rate penalty.
        if torch.any(x > 500):
            penalty = penalty + torch.sum(x[x > 500] - 500) * 1e-2

        c_array = self.c_all if fit_outlier else self.c
        u_array = self.u_all if fit_outlier else self.u
        s_array = self.s_all if fit_outlier else self.s
        conn_for_calc = self.conn if fit_outlier else self.conn_sub

        if self.batch_size is not None and self.batch_size < len(c_array):
            # random mini-batch of cells, drawn without replacement
            subset_choice = np.random.choice(len(c_array), self.batch_size,
                                             replace=False)

            c_array = c_array[subset_choice]
            u_array = u_array[subset_choice]
            s_array = s_array[subset_choice]

            # restrict the connectivity matrix to the sampled rows, then
            # to the sampled columns
            conn_for_calc = conn_for_calc[subset_choice]
            conn_for_calc = ((conn_for_calc.T)[subset_choice]).T

        scale_factor_func = np.array(self.scale_factor, dtype=self.u.dtype)

        # distances and time assignments
        res = self.calculate_dist_and_time_ten(c_array,
                                               u_array,
                                               s_array,
                                               t_sw_array,
                                               r4[0],
                                               r4[1],
                                               r4[2],
                                               r4[3],
                                               rescale_c,
                                               rescale_u,
                                               scale_cc=scale_cc,
                                               scale_factor=scale_factor_func,
                                               model=self.model,
                                               direction=self.direction,
                                               conn=conn_for_calc,
                                               k=self.k_dist,
                                               t=self.n_anchors,
                                               rna_only=self.rna_only,
                                               penalize_gap=penalize_gap,
                                               all_cells=fit_outlier)

        if fit_outlier:
            min_dist, t_pred, state_pred, reach_ss, late_phase, max_u, max_s, \
                self.anchor_t1_list, self.anchor_t2_list = res
        else:
            min_dist, state_pred, max_u, max_s, t_sw_adjust = res

        loss = torch.mean(min_dist)

        # avoid exceeding maximum expressions
        # NOTE(review): torch.tensor([...]) builds a fresh tensor from the
        # values, so presumably no gradient flows from this term back to x;
        # confirm that is intended.
        over_s = torch.max(torch.tensor([0, max_s - torch.tensor(self.max_s)],
                                        requires_grad=True,
                                        dtype=self.torch_type))
        over_u = torch.max(torch.tensor([0, max_u - torch.tensor(self.max_u)],
                                        requires_grad=True,
                                        dtype=self.torch_type))
        reg = over_s + over_u

        loss += reg

        loss += 1e-1 * penalty

        self.cur_loss = loss.item()
        self.cur_state_pred = state_pred

        if fit_outlier:
            return loss, t_pred

        self.cur_t_sw_adjust = t_sw_adjust
        return loss
3589
3590
    def mse(self, x, fit_outlier=False, penalize_gap=True):
        """NumPy loss: mean cell-to-anchor distance plus penalties.

        Parameters
        ----------
        x : array-like
            Parameter vector; it is converted with ``np.array`` and split
            by ``self._variables`` into ``t3`` (switch-time intervals),
            ``r4`` (four rates), ``scale_cc``, ``rescale_c`` and
            ``rescale_u``.
        fit_outlier : bool
            When True the ``*_all`` arrays and ``self.conn`` are used and
            ``all_cells=True`` is passed on; otherwise ``self.c/u/s`` and
            ``self.conn_sub`` are used.
        penalize_gap : bool
            Forwarded unchanged to the distance routine.

        Returns
        -------
        ``loss`` when ``fit_outlier`` is False, ``(loss, t_pred)`` otherwise.

        Side effects
        ------------
        Stores ``self.cur_loss`` and ``self.cur_state_pred``; additionally
        stores ``self.cur_t_sw_adjust`` when ``fit_outlier`` is False. On
        the non-neural-net path with ``fit_outlier`` True it also stores
        ``self.anchor_t1_list`` and ``self.anchor_t2_list``.
        """
        x = np.array(x)

        t3, r4, scale_cc, rescale_c, rescale_u = self._variables(x)

        t_sw_array = np.array([t3[0], t3[0]+t3[1], t3[0]+t3[1]+t3[2]])
        if self.rna_only:
            t_sw_array[2] = 20

        # conditions for minimum switch time and rate params
        penalty = 0
        if np.any(t3 < 0.2) or np.any(r4 < 0.005):
            short_t = t3 if self.fit_decoupling else t3[:2]
            penalty = np.sum(0.2 - short_t[short_t < 0.2])
            penalty += np.sum(0.005 - r4[r4 < 0.005]) * 1e2

        # condition for all params: accumulate on top of the penalty above
        # instead of replacing it, so an oversized parameter cannot cancel
        # the switch-time / rate penalty.
        if np.any(x > 500):
            penalty += np.sum(x[x > 500] - 500) * 1e-2

        c_array = self.c_all if fit_outlier else self.c
        u_array = self.u_all if fit_outlier else self.u
        s_array = self.s_all if fit_outlier else self.s
        conn = self.conn if fit_outlier else self.conn_sub

        if self.neural_net:

            res = calculate_dist_and_time_nn(c_array,
                                             u_array,
                                             s_array,
                                             self.max_u_all if fit_outlier
                                             else self.max_u,
                                             self.max_s_all if fit_outlier
                                             else self.max_s,
                                             t_sw_array,
                                             r4[0],
                                             r4[1],
                                             r4[2],
                                             r4[3],
                                             rescale_c,
                                             rescale_u,
                                             self.ode_model_0,
                                             self.ode_model_1,
                                             self.ode_model_2_m1,
                                             self.ode_model_2_m2,
                                             self.device,
                                             scale_cc=scale_cc,
                                             scale_factor=self.scale_factor,
                                             model=self.model,
                                             direction=self.direction,
                                             conn=conn,
                                             k=self.k_dist,
                                             t=self.n_anchors,
                                             rna_only=self.rna_only,
                                             penalize_gap=penalize_gap,
                                             all_cells=fit_outlier)

            if fit_outlier:
                min_dist, t_pred, state_pred, max_u, max_s, nn_penalty = res
            else:
                min_dist, state_pred, max_u, max_s, nn_penalty = res

            penalty += nn_penalty

            # the neural-net routine returns no switch-time adjustment
            t_sw_adjust = [0, 0, 0]

        else:

            # distances and time assignments
            res = calculate_dist_and_time(c_array,
                                          u_array,
                                          s_array,
                                          t_sw_array,
                                          r4[0],
                                          r4[1],
                                          r4[2],
                                          r4[3],
                                          rescale_c,
                                          rescale_u,
                                          scale_cc=scale_cc,
                                          scale_factor=self.scale_factor,
                                          model=self.model,
                                          direction=self.direction,
                                          conn=conn,
                                          k=self.k_dist,
                                          t=self.n_anchors,
                                          rna_only=self.rna_only,
                                          penalize_gap=penalize_gap,
                                          all_cells=fit_outlier)

            if fit_outlier:
                min_dist, t_pred, state_pred, reach_ss, late_phase, max_u, \
                    max_s, self.anchor_t1_list, self.anchor_t2_list = res
            else:
                min_dist, state_pred, max_u, max_s, t_sw_adjust = res

        loss = np.mean(min_dist)

        # avoid exceeding maximum expressions
        reg = np.max([0, max_s - self.max_s]) + np.max([0, max_u - self.max_u])
        loss += reg

        loss += 1e-1 * penalty
        self.cur_loss = loss
        self.cur_state_pred = state_pred

        if fit_outlier:
            return loss, t_pred

        self.cur_t_sw_adjust = t_sw_adjust
        return loss
3703
3704
    def update(self, x, perform_update=False, initialize=False,
               fit_outlier=False, adjust_time=True, penalize_gap=True,
               plot=True):
        """Accept a parameter vector if it lowers the loss (or if forced).

        Parameters
        ----------
        x : array-like
            Parameter vector, split by ``self._variables``.
        perform_update : bool
            Force acceptance even when the loss did not decrease.
        initialize : bool
            Evaluate ``self.mse(x)`` here instead of reusing the cached
            ``self.cur_loss``; ``self.state`` is left untouched.
        fit_outlier : bool
            Evaluate ``self.mse(x, fit_outlier=True)`` and store the
            returned times in ``self.t``. Ignored for the evaluation when
            ``initialize`` is True.
        adjust_time : bool
            When True (and ``fit_outlier`` is False) subtract the
            cumulative ``t_sw_adjust`` from the switch times.
        penalize_gap : bool
            Forwarded to ``self.mse``.
        plot : bool
            Redraw the interactive axes; only honoured when ``self.plot``
            is also truthy.

        Returns
        -------
        bool
            True when the parameters were accepted and stored.
        """
        t3, r4, scale_cc, rescale_c, rescale_u = self._variables(x)
        t_sw_array = np.array([t3[0], t3[0]+t3[1], t3[0]+t3[1]+t3[2]])

        # read results
        if initialize:
            new_loss = self.mse(x, penalize_gap=penalize_gap)
            # mse() has just stored the adjustment; bind it here so that
            # adjust_time=True does not hit an unbound local below.
            t_sw_adjust = self.cur_t_sw_adjust
            # NOTE(review): initialize=True together with fit_outlier=True
            # still leaves t_pred unbound further down - confirm callers
            # never combine the two.
        elif fit_outlier:
            new_loss, t_pred = self.mse(x, fit_outlier=True,
                                        penalize_gap=penalize_gap)
        else:
            new_loss = self.cur_loss
            t_sw_adjust = self.cur_t_sw_adjust
        state_pred = self.cur_state_pred

        if new_loss < self.loss[-1] or perform_update:
            perform_update = True

            self.loss.append(new_loss)
            self.alpha_c, self.alpha, self.beta, self.gamma = r4
            self.rates = r4
            self.scale_cc = scale_cc
            self.rescale_c = rescale_c
            self.rescale_u = rescale_u

            # adjust overcrowded anchors
            if not fit_outlier and adjust_time:
                t_sw_array -= np.cumsum(t_sw_adjust)
                if self.rna_only:
                    t_sw_array[2] = 20

            self.t_sw_1, self.t_sw_2, self.t_sw_3 = t_sw_array
            self.t_sw_array = t_sw_array
            self.params = np.array([self.t_sw_1,
                                    self.t_sw_2-self.t_sw_1,
                                    self.t_sw_3-self.t_sw_2,
                                    self.alpha_c,
                                    self.alpha,
                                    self.beta,
                                    self.gamma,
                                    self.scale_cc,
                                    self.rescale_c,
                                    self.rescale_u])
            if not initialize:
                self.state = state_pred
            if fit_outlier:
                self.t = t_pred

            logg.update(f'params updated as: {self.t_sw_array} {self.rates} '
                        f'{self.scale_cc} {self.rescale_c} {self.rescale_u}',
                        v=2)

            # interactive plot
            if self.plot and plot:
                tau_list = anchor_points(self.t_sw_array, 20, self.n_anchors)
                # number of switch times below 20
                switch = np.sum(self.t_sw_array < 20)
                typed_tau_list = List()
                for tau in tau_list:
                    typed_tau_list.append(tau)
                self.alpha_c, self.alpha, self.beta, self.gamma, \
                    self.c0, self.u0, self.s0 = \
                    check_params(self.alpha_c, self.alpha, self.beta,
                                 self.gamma, c0=self.c0, u0=self.u0,
                                 s0=self.s0)
                exp_list, exp_sw_list = generate_exp(typed_tau_list,
                                                     self.t_sw_array[:switch],
                                                     self.alpha_c,
                                                     self.alpha,
                                                     self.beta,
                                                     self.gamma,
                                                     scale_cc=self.scale_cc,
                                                     model=self.model,
                                                     rna_only=self.rna_only)
                rescale_factor = np.array([self.rescale_c,
                                           self.rescale_u,
                                           1.0])
                exp_list = [e*rescale_factor for e in exp_list]
                exp_sw_list = [e*rescale_factor for e in exp_sw_list]
                # anchor curves: columns 0/1/2 of each state's block are
                # plotted as c/u/s
                c = np.ravel(np.concatenate([exp_list[k][:, 0] for k in
                                             range(switch+1)]))
                u = np.ravel(np.concatenate([exp_list[k][:, 1] for k in
                                             range(switch+1)]))
                s = np.ravel(np.concatenate([exp_list[k][:, 2] for k in
                                             range(switch+1)]))
                c_ = self.c_all if fit_outlier else self.c
                u_ = self.u_all if fit_outlier else self.u
                s_ = self.s_all if fit_outlier else self.s
                # one marker style per switch point, in order
                sw_markers = ("om", "Xm", "Dm")
                self.ax.clear()
                plt.pause(0.1)
                if self.rna_only:
                    self.ax.scatter(s, u, s=self.point_size*1.5, c='black',
                                    alpha=0.6, zorder=2)
                    for sw_idx in range(switch):
                        _, u_sw, s_sw = exp_sw_list[sw_idx][0]
                        self.ax.plot([s_sw], [u_sw], sw_markers[sw_idx],
                                     markersize=self.point_size, zorder=5)
                    if np.max(self.t) == 20:
                        self.ax.plot([s[-1]], [u[-1]], "*m",
                                     markersize=self.point_size, zorder=5)
                    for i in range(4):
                        in_state = (self.state == i)
                        if any(in_state):
                            self.ax.scatter(s_[in_state],
                                            u_[in_state],
                                            s=self.point_size, c=self.color[i])
                    self.ax.set_xlabel('s')
                    self.ax.set_ylabel('u')

                else:
                    self.ax.scatter(s, u, c, s=self.point_size*1.5,
                                    c='black', alpha=0.6, zorder=2)
                    for sw_idx in range(switch):
                        c_sw, u_sw, s_sw = exp_sw_list[sw_idx][0]
                        self.ax.plot([s_sw], [u_sw], [c_sw],
                                     sw_markers[sw_idx],
                                     markersize=self.point_size, zorder=5)
                    if np.max(self.t) == 20:
                        self.ax.plot([s[-1]], [u[-1]], [c[-1]], "*m",
                                     markersize=self.point_size, zorder=5)
                    for i in range(4):
                        in_state = (self.state == i)
                        if any(in_state):
                            self.ax.scatter(s_[in_state],
                                            u_[in_state],
                                            c_[in_state],
                                            s=self.point_size, c=self.color[i])
                    self.ax.set_xlabel('s')
                    self.ax.set_ylabel('u')
                    self.ax.set_zlabel('c')
                self.fig.canvas.draw()
                plt.pause(0.1)
        return perform_update
3850
3851
    def save_dyn_plot(self, c, u, s, c_sw, u_sw, s_sw, tau_list,
3852
                      show_all=False):
3853
        if not os.path.exists(self.plot_path):
3854
            os.makedirs(self.plot_path)
3855
            logg.update(f'{self.plot_path} directory created.', v=2)
3856
3857
        switch = np.sum(self.t_sw_array < 20)
3858
        scale_back = np.array([self.scale_c, self.scale_u, self.scale_s])
3859
        shift_back = np.array([self.offset_c, self.offset_u, self.offset_s])
3860
        if switch >= 1:
3861
            c_sw1, u_sw1, s_sw1 = c_sw[0], u_sw[0], s_sw[0]
3862
        if switch >= 2:
3863
            c_sw2, u_sw2, s_sw2 = c_sw[1], u_sw[1], s_sw[1]
3864
        if switch == 3:
3865
            c_sw3, u_sw3, s_sw3 = c_sw[2], u_sw[2], s_sw[2]
3866
3867
        if not show_all:
3868
            n_anchors = len(u)
3869
            t_lower = np.min(self.t)
3870
            t_upper = np.max(self.t)
3871
            t_ = np.concatenate((tau_list[0], tau_list[1] + self.t_sw_array[0],
3872
                                 tau_list[2] + self.t_sw_array[1],
3873
                                 tau_list[3] + self.t_sw_array[2]))
3874
            c_pre = c[t_[:n_anchors] <= t_lower]
3875
            u_pre = u[t_[:n_anchors] <= t_lower]
3876
            s_pre = s[t_[:n_anchors] <= t_lower]
3877
            c = c[(t_lower < t_[:n_anchors]) & (t_[:n_anchors] < t_upper)]
3878
            u = u[(t_lower < t_[:n_anchors]) & (t_[:n_anchors] < t_upper)]
3879
            s = s[(t_lower < t_[:n_anchors]) & (t_[:n_anchors] < t_upper)]
3880
3881
        c_all = self.c_all * self.scale_c + self.offset_c
3882
        u_all = self.u_all * self.scale_u + self.offset_u
3883
        s_all = self.s_all * self.scale_s + self.offset_s
3884
3885
        fig = plt.figure(figsize=self.fig_size)
3886
        fig.patch.set_facecolor('white')
3887
        ax = fig.add_subplot(111, facecolor='white')
3888
        if not show_all and len(u_pre) > 0:
3889
            ax.scatter(s_pre, u_pre, s=self.point_size/2, c='black',
3890
                       alpha=0.4, zorder=2)
3891
        ax.scatter(s, u, s=self.point_size*1.5, c='black', alpha=0.6, zorder=2)
3892
        for i in range(4):
3893
            if any(self.state == i):
3894
                ax.scatter(s_all[(self.state == i) & (self.non_outlier)],
3895
                           u_all[(self.state == i) & (self.non_outlier)],
3896
                           s=self.point_size, c=self.color[i])
3897
        ax.scatter(s_all[~self.non_outlier], u_all[~self.non_outlier],
3898
                   s=self.point_size/2, c='grey')
3899
        if show_all or t_lower <= self.t_sw_array[0]:
3900
            ax.plot([s_sw1], [u_sw1], "om", markersize=self.point_size,
3901
                    zorder=5)
3902
        if switch >= 2 and (show_all or (t_lower <= self.t_sw_array[1] and
3903
                                         t_upper >= self.t_sw_array[1])):
3904
            ax.plot([s_sw2], [u_sw2], "Xm", markersize=self.point_size,
3905
                    zorder=5)
3906
        if switch >= 3 and (show_all or (t_lower <= self.t_sw_array[2] and
3907
                                         t_upper >= self.t_sw_array[2])):
3908
            ax.plot([s_sw3], [u_sw3], "Dm", markersize=self.point_size,
3909
                    zorder=5)
3910
        if np.max(self.t) == 20:
3911
            ax.plot([s[-1]], [u[-1]], "*m", markersize=self.point_size,
3912
                    zorder=5)
3913
        if (self.anchor_t1_list is not None and len(self.anchor_t1_list) > 0
3914
                and show_all):
3915
            for i in range(len(self.anchor_t1_list)):
3916
                exp_t1 = self.anchor_t1_list[i] * scale_back + shift_back
3917
                exp_t2 = self.anchor_t2_list[i] * scale_back + shift_back
3918
                ax.plot([exp_t1[2]], [exp_t1[1]], "|y",
3919
                        markersize=self.point_size*1.5)
3920
                ax.plot([exp_t2[2]], [exp_t2[1]], "|c",
3921
                        markersize=self.point_size*1.5)
3922
        ax.plot(s_all,
3923
                self.steady_state_func(self.s_all) * self.scale_u
3924
                + self.offset_u, c='grey', ls=':', lw=self.point_size/4,
3925
                alpha=0.7)
3926
        ax.set_xlabel('s')
3927
        ax.set_ylabel('u')
3928
        ax.set_title(f'{self.gene}-{self.model}')
3929
        plt.tight_layout()
3930
        fig.savefig(f'{self.plot_path}/{self.gene}-{self.model}-us.png',
3931
                    dpi=fig.dpi, facecolor=fig.get_facecolor(),
3932
                    transparent=False, edgecolor='none')
3933
        plt.close(fig)
3934
        plt.pause(0.2)
3935
3936
        if self.extra_color is not None:
3937
            fig = plt.figure(figsize=self.fig_size)
3938
            fig.patch.set_facecolor('white')
3939
            ax = fig.add_subplot(111, facecolor='white')
3940
            if not show_all and len(u_pre) > 0:
3941
                ax.scatter(s_pre, u_pre, s=self.point_size/2, c='black',
3942
                           alpha=0.4, zorder=2)
3943
            ax.scatter(s, u, s=self.point_size*1.5, c='black', alpha=0.6,
3944
                       zorder=2)
3945
            ax.scatter(s_all, u_all, s=self.point_size, c=self.extra_color)
3946
            if show_all or t_lower <= self.t_sw_array[0]:
3947
                ax.plot([s_sw1], [u_sw1], "om", markersize=self.point_size,
3948
                        zorder=5)
3949
            if switch >= 2 and (show_all or (t_lower <= self.t_sw_array[1] and
3950
                                             t_upper >= self.t_sw_array[1])):
3951
                ax.plot([s_sw2], [u_sw2], "Xm", markersize=self.point_size,
3952
                        zorder=5)
3953
            if switch >= 3 and (show_all or (t_lower <= self.t_sw_array[2] and
3954
                                             t_upper >= self.t_sw_array[2])):
3955
                ax.plot([s_sw3], [u_sw3], "Dm", markersize=self.point_size,
3956
                        zorder=5)
3957
            if np.max(self.t) == 20:
3958
                ax.plot([s[-1]], [u[-1]], "*m", markersize=self.point_size,
3959
                        zorder=5)
3960
            if (self.anchor_t1_list is not None and
3961
                    len(self.anchor_t1_list) > 0 and show_all):
3962
                for i in range(len(self.anchor_t1_list)):
3963
                    exp_t1 = self.anchor_t1_list[i] * scale_back + shift_back
3964
                    exp_t2 = self.anchor_t2_list[i] * scale_back + shift_back
3965
                    ax.plot([exp_t1[2]], [exp_t1[1]], "|y",
3966
                            markersize=self.point_size*1.5)
3967
                    ax.plot([exp_t2[2]], [exp_t2[1]], "|c",
3968
                            markersize=self.point_size*1.5)
3969
            ax.plot(s_all, self.steady_state_func(self.s_all) * self.scale_u
3970
                    + self.offset_u, c='grey', ls=':', lw=self.point_size/4,
3971
                    alpha=0.7)
3972
            ax.set_xlabel('s')
3973
            ax.set_ylabel('u')
3974
            ax.set_title(f'{self.gene}-{self.model}')
3975
            plt.tight_layout()
3976
            fig.savefig(f'{self.plot_path}/{self.gene}-{self.model}-'
3977
                        'us_colorby_extra.png', dpi=fig.dpi,
3978
                        facecolor=fig.get_facecolor(), transparent=False,
3979
                        edgecolor='none')
3980
            plt.close(fig)
3981
            plt.pause(0.2)
3982
3983
            if not self.rna_only:
3984
                fig = plt.figure(figsize=self.fig_size)
3985
                fig.patch.set_facecolor('white')
3986
                ax = fig.add_subplot(111, facecolor='white')
3987
                if not show_all and len(u_pre) > 0:
3988
                    ax.scatter(u_pre, c_pre, s=self.point_size/2, c='black',
3989
                               alpha=0.4, zorder=2)
3990
                ax.scatter(u, c, s=self.point_size*1.5, c='black', alpha=0.6,
3991
                           zorder=2)
3992
                ax.scatter(u_all, c_all, s=self.point_size, c=self.extra_color)
3993
                if show_all or t_lower <= self.t_sw_array[0]:
3994
                    ax.plot([u_sw1], [c_sw1], "om", markersize=self.point_size,
3995
                            zorder=5)
3996
                if switch >= 2 and (show_all or (t_lower <= self.t_sw_array[1]
3997
                                                 and t_upper >=
3998
                                                 self.t_sw_array[1])):
3999
                    ax.plot([u_sw2], [c_sw2], "Xm", markersize=self.point_size,
4000
                            zorder=5)
4001
                if switch >= 3 and (show_all or (t_lower <= self.t_sw_array[2]
4002
                                                 and t_upper >=
4003
                                                 self.t_sw_array[2])):
4004
                    ax.plot([u_sw3], [c_sw3], "Dm", markersize=self.point_size,
4005
                            zorder=5)
4006
                if np.max(self.t) == 20:
4007
                    ax.plot([u[-1]], [c[-1]], "*m", markersize=self.point_size,
4008
                            zorder=5)
4009
                ax.set_xlabel('u')
4010
                ax.set_ylabel('c')
4011
                ax.set_title(f'{self.gene}-{self.model}')
4012
                plt.tight_layout()
4013
                fig.savefig(f'{self.plot_path}/{self.gene}-{self.model}-'
4014
                            'cu_colorby_extra.png', dpi=fig.dpi,
4015
                            facecolor=fig.get_facecolor(), transparent=False,
4016
                            edgecolor='none')
4017
                plt.close(fig)
4018
                plt.pause(0.2)
4019
4020
        if not self.rna_only:
4021
            fig = plt.figure(figsize=self.fig_size)
4022
            fig.patch.set_facecolor('white')
4023
            ax = fig.add_subplot(111, projection='3d', facecolor='white')
4024
            if not show_all and len(u_pre) > 0:
4025
                ax.scatter(s_pre, u_pre, c_pre, s=self.point_size/2, c='black',
4026
                           alpha=0.4, zorder=2)
4027
            ax.scatter(s, u, c, s=self.point_size*1.5, c='black', alpha=0.6,
4028
                       zorder=2)
4029
            for i in range(4):
4030
                if any(self.state == i):
4031
                    ax.scatter(s_all[(self.state == i) & (self.non_outlier)],
4032
                               u_all[(self.state == i) & (self.non_outlier)],
4033
                               c_all[(self.state == i) & (self.non_outlier)],
4034
                               s=self.point_size, c=self.color[i])
4035
            ax.scatter(s_all[~self.non_outlier], u_all[~self.non_outlier],
4036
                       c_all[~self.non_outlier], s=self.point_size/2, c='grey')
4037
            if show_all or t_lower <= self.t_sw_array[0]:
4038
                ax.plot([s_sw1], [u_sw1], [c_sw1], "om",
4039
                        markersize=self.point_size, zorder=5)
4040
            if switch >= 2 and (show_all or (t_lower <= self.t_sw_array[1] and
4041
                                             t_upper >= self.t_sw_array[1])):
4042
                ax.plot([s_sw2], [u_sw2], [c_sw2], "Xm",
4043
                        markersize=self.point_size, zorder=5)
4044
            if switch >= 3 and (show_all or (t_lower <= self.t_sw_array[2] and
4045
                                             t_upper >= self.t_sw_array[2])):
4046
                ax.plot([s_sw3], [u_sw3], [c_sw3], "Dm",
4047
                        markersize=self.point_size, zorder=5)
4048
            if np.max(self.t) == 20:
4049
                ax.plot([s[-1]], [u[-1]], [c[-1]], "*m",
4050
                        markersize=self.point_size, zorder=5)
4051
            ax.set_xlabel('s')
4052
            ax.set_ylabel('u')
4053
            ax.set_zlabel('c')
4054
            ax.set_title(f'{self.gene}-{self.model}')
4055
            plt.tight_layout()
4056
            fig.savefig(f'{self.plot_path}/{self.gene}-{self.model}-cus.png',
4057
                        dpi=fig.dpi, facecolor=fig.get_facecolor(),
4058
                        transparent=False, edgecolor='none')
4059
            plt.close(fig)
4060
            plt.pause(0.2)
4061
4062
            fig = plt.figure(figsize=self.fig_size)
4063
            fig.patch.set_facecolor('white')
4064
            ax = fig.add_subplot(111, facecolor='white')
4065
            if not show_all and len(u_pre) > 0:
4066
                ax.scatter(s_pre, u_pre, s=self.point_size/2, c='black',
4067
                           alpha=0.4, zorder=2)
4068
            ax.scatter(s, u, s=self.point_size*1.5, c='black', alpha=0.6,
4069
                       zorder=2)
4070
            ax.scatter(s_all, u_all, s=self.point_size, c=np.log1p(self.c_all),
4071
                       cmap='coolwarm')
4072
            if show_all or t_lower <= self.t_sw_array[0]:
4073
                ax.plot([s_sw1], [u_sw1], "om", markersize=self.point_size,
4074
                        zorder=5)
4075
            if switch >= 2 and (show_all or (t_lower <= self.t_sw_array[1] and
4076
                                             t_upper >= self.t_sw_array[1])):
4077
                ax.plot([s_sw2], [u_sw2], "Xm", markersize=self.point_size,
4078
                        zorder=5)
4079
            if switch >= 3 and (show_all or (t_lower <= self.t_sw_array[2] and
4080
                                             t_upper >= self.t_sw_array[2])):
4081
                ax.plot([s_sw3], [u_sw3], "Dm", markersize=self.point_size,
4082
                        zorder=5)
4083
            if np.max(self.t) == 20:
4084
                ax.plot([s[-1]], [u[-1]], "*m", markersize=self.point_size,
4085
                        zorder=5)
4086
            ax.plot(s_all, self.steady_state_func(self.s_all) * self.scale_u +
4087
                    self.offset_u, c='grey', ls=':', lw=self.point_size/4,
4088
                    alpha=0.7)
4089
            ax.set_xlabel('s')
4090
            ax.set_ylabel('u')
4091
            ax.set_title(f'{self.gene}-{self.model}')
4092
            plt.tight_layout()
4093
            fig.savefig(f'{self.plot_path}/{self.gene}-{self.model}-'
4094
                        'us_colorby_c.png', dpi=fig.dpi,
4095
                        facecolor=fig.get_facecolor(), transparent=False,
4096
                        edgecolor='none')
4097
            plt.close(fig)
4098
            plt.pause(0.2)
4099
4100
            fig = plt.figure(figsize=self.fig_size)
4101
            fig.patch.set_facecolor('white')
4102
            ax = fig.add_subplot(111, facecolor='white')
4103
            if not show_all and len(u_pre) > 0:
4104
                ax.scatter(u_pre, c_pre, s=self.point_size/2, c='black',
4105
                           alpha=0.4, zorder=2)
4106
            ax.scatter(u, c, s=self.point_size*1.5, c='black', alpha=0.6,
4107
                       zorder=2)
4108
            for i in range(4):
4109
                if any(self.state == i):
4110
                    ax.scatter(u_all[(self.state == i) & (self.non_outlier)],
4111
                               c_all[(self.state == i) & (self.non_outlier)],
4112
                               s=self.point_size, c=self.color[i])
4113
            ax.scatter(u_all[~self.non_outlier], c_all[~self.non_outlier],
4114
                       s=self.point_size/2, c='grey')
4115
            if show_all or t_lower <= self.t_sw_array[0]:
4116
                ax.plot([u_sw1], [c_sw1], "om", markersize=self.point_size,
4117
                        zorder=5)
4118
            if switch >= 2 and (show_all or (t_lower <= self.t_sw_array[1] and
4119
                                             t_upper >= self.t_sw_array[1])):
4120
                ax.plot([u_sw2], [c_sw2], "Xm", markersize=self.point_size,
4121
                        zorder=5)
4122
            if switch >= 3 and (show_all or (t_lower <= self.t_sw_array[2] and
4123
                                             t_upper >= self.t_sw_array[2])):
4124
                ax.plot([u_sw3], [c_sw3], "Dm", markersize=self.point_size,
4125
                        zorder=5)
4126
            if np.max(self.t) == 20:
4127
                ax.plot([u[-1]], [c[-1]], "*m", markersize=self.point_size,
4128
                        zorder=5)
4129
            ax.set_xlabel('u')
4130
            ax.set_ylabel('c')
4131
            ax.set_title(f'{self.gene}-{self.model}')
4132
            plt.tight_layout()
4133
            fig.savefig(f'{self.plot_path}/{self.gene}-{self.model}-cu.png',
4134
                        dpi=fig.dpi, facecolor=fig.get_facecolor(),
4135
                        transparent=False, edgecolor='none')
4136
            plt.close(fig)
4137
            plt.pause(0.2)
4138
4139
    def get_loss(self):
        """Return the ``loss`` attribute stored on this instance.

        ``regress_func`` in this file reads the last element of the returned
        value (``loss[-1]``) as the final loss of the fit.
        """
        recorded_loss = self.loss
        return recorded_loss
4141
4142
    def get_model(self):
        """Return the ``model`` attribute stored on this instance."""
        chosen_model = self.model
        return chosen_model
4144
4145
    def get_params(self):
        """Return the fitted parameters as a 6-tuple.

        Returns
        -------
        tuple
            ``(t_sw_array, rates, scale_cc, rescale_c, rescale_u,
            realign_ratio)``, read directly from the instance attributes of
            the same names.
        """
        return (
            self.t_sw_array,
            self.rates,
            self.scale_cc,
            self.rescale_c,
            self.rescale_u,
            self.realign_ratio,
        )
4148
4149
    def is_partial(self):
        """Return the ``partial`` attribute stored on this instance."""
        partial_flag = self.partial
        return partial_flag
4151
4152
    def get_direction(self):
        """Return the ``direction`` attribute stored on this instance."""
        current_direction = self.direction
        return current_direction
4154
4155
    def realign_time_and_velocity(self, c, u, s, anchor_time):
        """Rescale fitted times, rates and velocities onto the range (0, 20).

        Parameters
        ----------
        c, u, s
            Indexable sequences; the element at ``anchor_min_idx`` of each is
            stored as the initial value ``c0`` / ``u0`` / ``s0``.
        anchor_time
            Array compared element-wise with the smallest and largest entry
            of ``self.t`` to obtain ``anchor_min_idx`` / ``anchor_max_idx``.
            NOTE(review): presumably sorted ascending so the counts act as
            indices - confirm against the caller.

        Notes
        -----
        ``self.rates``, ``self.t``, ``self.velocity`` and ``self.anchor_velo``
        are first updated in place (``/=`` / ``-=``), so other references to
        those arrays observe the change.
        """
        # self.t is untouched until the in-place shift further down, so its
        # extrema can be computed once up front.
        t_min = np.min(self.t)
        t_max = np.max(self.t)

        # number of anchors strictly earlier than the first / last observed
        # time (with a 1e-5 tolerance)
        self.anchor_min_idx = np.sum(anchor_time < (t_min - 1e-5))
        self.anchor_max_idx = np.sum(anchor_time < (t_max - 1e-5))
        first_anchor = self.anchor_min_idx
        self.c0 = c[first_anchor]
        self.u0 = u[first_anchor]
        self.s0 = s[first_anchor]

        self.realign_ratio = 20 / (t_max - t_min)
        ratio = self.realign_ratio

        fitted_msg = (f'fitted params:\nswitch time array = {self.t_sw_array},\n'
                      f'rates = {self.rates},\ncc scale = {self.scale_cc},\n'
                      f'c rescale factor = {self.rescale_c},\n'
                      f'u rescale factor = {self.rescale_u}')
        logg.update(fitted_msg, v=1)
        logg.update(f'aligning to range (0,20) by {self.realign_ratio}..',
                    v=1)

        # rates shrink by the same factor the time axis is stretched by
        self.rates /= ratio
        self.alpha_c, self.alpha, self.beta, self.gamma = self.rates
        self.params[3:7] = self.rates

        # switch times move onto the new axis; params[:3] holds the first
        # switch time followed by the two consecutive differences
        self.t_sw_array = (self.t_sw_array - t_min) * ratio
        sw_1, sw_2, sw_3 = self.t_sw_array
        self.t_sw_1, self.t_sw_2, self.t_sw_3 = sw_1, sw_2, sw_3
        self.params[:3] = np.array([sw_1, sw_2 - sw_1, sw_3 - sw_2])

        # shift times to start at zero, then stretch so the maximum is 20
        self.t -= t_min
        self.t = self.t * 20 / np.max(self.t)

        # (observed values, scale) for the c, u and s columns, in that order
        scaled_axes = ((self.c_all, self.scale_c),
                       (self.u_all, self.scale_u),
                       (self.s_all, self.scale_s))

        # per-cell velocities: lower-bounded by minus the scaled observed
        # value of the same cell
        self.velocity /= ratio
        for col, (values, scale) in enumerate(scaled_axes):
            self.velocity[:, col] = np.clip(self.velocity[:, col],
                                            -values * scale, None)

        # anchor velocities: lower-bounded by minus the largest scaled
        # observed value
        self.anchor_velo /= ratio
        for col, (values, scale) in enumerate(scaled_axes):
            self.anchor_velo[:, col] = np.clip(self.anchor_velo[:, col],
                                               -np.max(values * scale), None)
4197
4198
    def get_initial_exp(self):
        """Return ``c0``, ``u0`` and ``s0`` packed into one NumPy array."""
        initial_values = [self.c0, self.u0, self.s0]
        return np.array(initial_values)
4200
4201
    def get_time_assignment(self):
        """Return the time array ``self.t``.

        When ``self.low_quality`` is truthy, a zero vector with one entry per
        element of ``self.u_all`` is returned instead.
        """
        if not self.low_quality:
            return self.t
        n_cells = len(self.u_all)
        return np.zeros(n_cells)
4205
4206
    def get_state_assignment(self):
        """Return the state array ``self.state``.

        When ``self.low_quality`` is truthy, a zero vector with one entry per
        element of ``self.u_all`` is returned instead.
        """
        if not self.low_quality:
            return self.state
        n_cells = len(self.u_all)
        return np.zeros(n_cells)
4210
4211
    def get_velocity(self):
        """Return the velocity array ``self.velocity``.

        When ``self.low_quality`` is truthy, a zero array of shape
        ``(len(self.u_all), 3)`` is returned instead.
        """
        if not self.low_quality:
            return self.velocity
        n_cells = len(self.u_all)
        return np.zeros((n_cells, 3))
4215
4216
    def get_likelihood(self):
        """Return ``(likelihood, l_c, ssd_c, var_c)`` from the instance."""
        return (
            self.likelihood,
            self.l_c,
            self.ssd_c,
            self.var_c,
        )
4218
4219
    def get_anchors(self):
        """Return the anchor arrays and their index bounds as a 7-tuple.

        Returns
        -------
        tuple
            ``(anchor_exp, anchor_exp_sw, anchor_velo, anchor_min_idx,
            anchor_max_idx, anchor_velo_min_idx, anchor_velo_max_idx)``.
            When ``self.low_quality`` is truthy, three ``(1, 3)`` zero arrays
            followed by four zeros are returned instead.
        """
        if not self.low_quality:
            return (self.anchor_exp, self.anchor_exp_sw, self.anchor_velo,
                    self.anchor_min_idx, self.anchor_max_idx,
                    self.anchor_velo_min_idx, self.anchor_velo_max_idx)
        placeholder_shape = (1, 3)
        return (np.zeros(placeholder_shape), np.zeros(placeholder_shape),
                np.zeros(placeholder_shape), 0, 0, 0, 0)
4226
4227
4228
def regress_func(c, u, s, m, mi, im, dev, nn, ad, lr, b1, b2, bs, gpdist,
                 embed, conn, pl, sp, pdir, fa, gene, pa, di, ro, fit, fd,
                 extra, ru, alpha, beta, gamma, t_, verbosity, log_folder,
                 log_filename):
    """Build, and optionally fit, one ``ChromatinDynamical`` model for a gene.

    Parameters
    ----------
    c, u, s
        Per-cell value arrays handed to ``ChromatinDynamical``; their 90th
        percentiles decide whether the gene is skipped as low quality.
    m, mi, im, dev, nn, ad, lr, b1, b2, bs
        Forwarded as ``model``, ``max_iter``, ``init_mode``, ``device``,
        ``neural_net``, ``adam``, ``adam_lr``, ``adam_beta1``,
        ``adam_beta2`` and ``batch_size``. Inside this function ``nn``
        shadows the module-level ``torch.nn`` import.
    gpdist
        Pairwise distance matrix, or ``None``. When given, the standard
        deviation of the distances among (at most 3000) cells with
        ``s > 0.1 * percentile(s, 99)`` is passed on as ``local_std``.
    embed, conn, pl, sp, pdir, fa, gene, pa, di, ro, fd, extra, ru
        Forwarded as ``embed_coord``, ``connectivities``, ``plot``,
        ``save_plot``, ``plot_dir``, ``fit_args``, ``gene``, ``partial``,
        ``direction``, ``rna_only``, ``fit_decoupling``, ``extra_color`` and
        ``rescale_u``.
    fit
        When truthy, ``ChromatinDynamical.fit()`` is run.
    alpha, beta, gamma, t_
        Forwarded under the same names.
    verbosity, log_folder, log_filename
        Written to the module-level ``settings`` before anything is logged.

    Returns
    -------
    tuple
        ``(final_loss, model, direction, parameters, initial_exp, time,
        state, velocity, likelihood, anchors)``. For a low-quality gene the
        loss is ``np.inf`` and the remaining entries are placeholders.
    """
    # logging configuration lives in the shared settings module
    settings.VERBOSITY = verbosity
    settings.LOG_FOLDER = log_folder
    settings.LOG_FILENAME = log_filename
    settings.GENE = gene

    if m is not None:
        logg.update('#########################################################'
                    '######################################', v=1)
        logg.update(f'testing model {m}', v=1)

    # a zero 90th percentile marks the gene as low quality; chromatin is only
    # taken into account when not running in RNA-only mode
    c_90, u_90, s_90 = (np.percentile(values, 90) for values in (c, u, s))
    low_quality = u_90 == 0 or s_90 == 0
    if not ro:
        low_quality = c_90 == 0 or low_quality

    if low_quality:
        logg.update(f'low quality gene {gene}, skipping', v=1)
        n_cells = len(u)
        empty_params = (np.zeros(3), np.zeros(4), 0, 0, 0, 0)
        empty_anchors = (np.zeros((1, 3)), np.zeros((1, 3)),
                         np.zeros((1, 3)), 0, 0, 0, 0)
        return (np.inf, np.nan, '', empty_params, np.zeros(3),
                np.zeros(n_cells), np.zeros(n_cells),
                np.zeros((n_cells, 3)), (-1.0, 0, 0, 0), empty_anchors)

    local_std = None
    if gpdist is not None:
        expressed = np.where(s > 0.1 * np.percentile(s, 99))[0]
        if len(expressed) > 3000:
            # fixed seed keeps the subsample reproducible
            rng = np.random.default_rng(2021)
            expressed = rng.choice(expressed, 3000, replace=False)
        local_pdist = gpdist[np.ix_(expressed, expressed)]
        # strict upper triangle: every unordered pair once, no diagonal
        upper = local_pdist[np.triu_indices_from(local_pdist, k=1)]
        dists = np.ravel(upper).reshape(-1, 1)
        local_std = np.std(dists)

    model_kwargs = {
        'model': m,
        'max_iter': mi,
        'init_mode': im,
        'device': dev,
        'neural_net': nn,
        'adam': ad,
        'adam_lr': lr,
        'adam_beta1': b1,
        'adam_beta2': b2,
        'batch_size': bs,
        'local_std': local_std,
        'embed_coord': embed,
        'connectivities': conn,
        'plot': pl,
        'save_plot': sp,
        'plot_dir': pdir,
        'fit_args': fa,
        'gene': gene,
        'partial': pa,
        'direction': di,
        'rna_only': ro,
        'fit_decoupling': fd,
        'extra_color': extra,
        'rescale_u': ru,
        'alpha': alpha,
        'beta': beta,
        'gamma': gamma,
        't_': t_,
    }
    cdc = ChromatinDynamical(c, u, s, **model_kwargs)

    if fit:
        fit_loss = cdc.fit()
        if fit_loss[-1] == np.inf:
            logg.update(f'low quality gene {gene}, skipping..', v=1)

    # results are always read back through the accessors, fitted or not
    loss = cdc.get_loss()
    model = cdc.get_model()
    direction = cdc.get_direction()
    parameters = cdc.get_params()
    initial_exp = cdc.get_initial_exp()
    velocity = cdc.get_velocity()
    likelihood = cdc.get_likelihood()
    time = cdc.get_time_assignment()
    state = cdc.get_state_assignment()
    anchors = cdc.get_anchors()
    return (loss[-1], model, direction, parameters, initial_exp, time, state,
            velocity, likelihood, anchors)
4316
4317
4318
def multimodel_helper(c, u, s,
                      model_to_run,
                      max_iter,
                      init_mode,
                      device,
                      neural_net,
                      adam,
                      adam_lr,
                      adam_beta1,
                      adam_beta2,
                      batch_size,
                      global_pdist,
                      embed_coord,
                      conn,
                      plot,
                      save_plot,
                      plot_dir,
                      fit_args,
                      gene,
                      partial,
                      direction,
                      rna_only,
                      fit,
                      fit_decoupling,
                      extra_color,
                      rescale_u,
                      alpha,
                      beta,
                      gamma,
                      t_,
                      verbosity, log_folder, log_filename):
    """Run ``regress_func`` for each candidate model and keep the best one.

    Parameters
    ----------
    c, u, s
        Per-cell value arrays passed through to ``regress_func``.
    model_to_run
        Iterable of candidate models; each is fitted in turn.
    verbosity, log_folder, log_filename
        Logging configuration forwarded to ``regress_func``.
    (all remaining parameters)
        Passed through to ``regress_func`` unchanged, in order.

    Returns
    -------
    tuple
        ``(loss, model, direction, parameters, initial_exp, time, state,
        velocity, likelihood, anchors)`` where ``loss`` is the list of final
        losses of *all* candidates, ``model`` is the candidate with the
        smallest loss (``np.nan`` when ``rna_only`` is truthy), ``direction``
        is the one reported by the last candidate run, and the remaining
        entries belong to the best candidate.

    Raises
    ------
    ValueError
        From ``np.argmin`` when ``model_to_run`` is empty.
    """
    loss, param_cand, initial_cand, time_cand = [], [], [], []
    state_cand, velo_cand, likelihood_cand, anch_cand = [], [], [], []

    for model in model_to_run:
        # regress_func takes the three logging settings as required trailing
        # positional arguments; leaving them out raises a TypeError.
        (loss_m, _, direction_, parameters, initial_exp,
         time, state, velocity, likelihood, anchors) = \
            regress_func(c, u, s, model, max_iter, init_mode, device,
                         neural_net, adam, adam_lr, adam_beta1, adam_beta2,
                         batch_size, global_pdist, embed_coord, conn, plot,
                         save_plot, plot_dir, fit_args, gene, partial,
                         direction, rna_only, fit, fit_decoupling,
                         extra_color, rescale_u, alpha, beta, gamma, t_,
                         verbosity, log_folder, log_filename)
        loss.append(loss_m)
        param_cand.append(parameters)
        initial_cand.append(initial_exp)
        time_cand.append(time)
        state_cand.append(state)
        velo_cand.append(velocity)
        likelihood_cand.append(likelihood)
        anch_cand.append(anchors)

    # the candidate with the smallest final loss wins
    best_model = np.argmin(loss)
    model = np.nan if rna_only else model_to_run[best_model]
    parameters = param_cand[best_model]
    initial_exp = initial_cand[best_model]
    time = time_cand[best_model]
    state = state_cand[best_model]
    velocity = velo_cand[best_model]
    likelihood = likelihood_cand[best_model]
    anchors = anch_cand[best_model]
    # NOTE(review): `loss` is the full per-candidate list and `direction_`
    # comes from the last candidate, not the best one; kept as-is because the
    # caller is not visible here - confirm this is intended.
    return loss, model, direction_, parameters, initial_exp, time, state, \
        velocity, likelihood, anchors
4382
4383
4384
def recover_dynamics_chrom(adata_rna,
4385
                           adata_atac=None,
4386
                           gene_list=None,
4387
                           max_iter=5,
4388
                           init_mode='invert',
4389
                           device="cpu",
4390
                           neural_net=False,
4391
                           adam=False,
4392
                           adam_lr=None,
4393
                           adam_beta1=None,
4394
                           adam_beta2=None,
4395
                           batch_size=None,
4396
                           model_to_run=None,
4397
                           plot=False,
4398
                           parallel=True,
4399
                           n_jobs=None,
4400
                           save_plot=False,
4401
                           plot_dir=None,
4402
                           rna_only=False,
4403
                           fit=True,
4404
                           fit_decoupling=True,
4405
                           extra_color_key=None,
4406
                           embedding='X_umap',
4407
                           n_anchors=500,
4408
                           k_dist=1,
4409
                           thresh_multiplier=1.0,
4410
                           weight_c=0.6,
4411
                           outlier=99.8,
4412
                           n_pcs=30,
4413
                           n_neighbors=30,
4414
                           fig_size=(8, 6),
4415
                           point_size=7,
4416
                           partial=None,
4417
                           direction=None,
4418
                           rescale_u=None,
4419
                           alpha=None,
4420
                           beta=None,
4421
                           gamma=None,
4422
                           t_sw=None
4423
                           ):
4424
4425
    """Multi-omic dynamics recovery.
4426
4427
    This function optimizes the joint chromatin and RNA model parameters in
4428
    ODE solutions.
4429
4430
    Parameters
4431
    ----------
4432
    adata_rna: :class:`~anndata.AnnData`
4433
        RNA anndata object. Required fields: `Mu`, `Ms`, and `connectivities`.
4434
    adata_atac: :class:`~anndata.AnnData` (default: `None`)
4435
        ATAC anndata object. Required fields: `Mc`.
4436
    gene_list: `str`,  list of `str` (default: highly variable genes)
4437
        Genes to use for model fitting.
4438
    max_iter: `int` (default: `5`)
4439
        Iterations to run for parameter optimization.
4440
    init_mode: `str` (default: `'invert'`)
4441
        Initialization method for switch times.
4442
        `'invert'`: initial RNA switch time will be computed with scVelo time
4443
        inversion method.
4444
        `'grid'`: grid search the best set of switch times.
4445
        `'simple'`: simply initialize switch times to be 5, 10, and 15.
4446
    device: `str` (default: `'cpu'`)
4447
        The CUDA device that pytorch tensor calculations will be run on. Only
4448
        to be used with Adam or Neural Network mode.
4449
    neural_net: `bool` (default: `False`)
4450
        Whether to run time predictions with a neural network or not. Shortens
4451
        runtime at the expense of accuracy. If False, uses the usual method of
4452
        assigning each data point to an anchor time point as outlined in the
4453
        Multivelo paper.
4454
    adam: `bool` (default: `False`)
4455
        Whether MSE minimization is handled by the Adam algorithm or not. When
4456
        set to the default of False, function uses Nelder-Mead instead.
4457
    adam_lr: `float` (default: `None`)
4458
        The learning rate to use the Adam algorithm. If adam is False, this
4459
        value is ignored.
4460
    adam_beta1: `float` (default: `None`)
4461
        The beta1 parameter for the Adam algorithm. If adam is False, this
4462
        value is ignored.
4463
    adam_beta2: `float` (default: `None`)
4464
        The beta2 parameter for the Adam algorithm. If adam is False, this
4465
        value is ignored.
4466
    batch_size: `int` (default: `None`)
4467
        Speeds up performance using minibatch training. Specifies number of
4468
        cells to use per run of MSE when running the Adam algorithm. Ignored
4469
        if Adam is set to False.
4470
    model_to_run: `int` or list of `int` (default: `None`)
4471
        User specified models for each gene. Possible values are 1 and 2. If
4472
        `None`, the model
4473
        for each gene will be inferred based on expression patterns. If more
4474
        than one value is given,
4475
        the best model will be decided based on loss of fit.
4476
    plot: `bool` or `None` (default: `False`)
4477
        Whether to interactively plot the 3D gene portraits. Ignored if
4478
        parallel is True.
4479
    parallel: `bool` (default: `True`)
4480
        Whether to fit genes in a parallel fashion (recommended).
4481
    n_jobs: `int` (default: available threads)
4482
        Number of parallel jobs.
4483
    save_plot: `bool` (default: `False`)
4484
        Whether to save the fitted gene portrait figures as files. This will
4485
        take some disk space.
4486
    plot_dir: `str` (default: `plots` for multiome and `rna_plots` for
4487
    RNA-only)
4488
        Directory to save the plots.
4489
    rna_only: `bool` (default: `False`)
4490
        Whether to only use RNA for fitting (RNA velocity).
4491
    fit: `bool` (default: `True`)
4492
        Whether to fit the models. If False, only pre-determination and
4493
        initialization will be run.
4494
    fit_decoupling: `bool` (default: `True`)
4495
        Whether to fit decoupling phase (Model 1 vs Model 2 distinction).
4496
    n_anchors: `int` (default: 500)
4497
        Number of anchor time-points to generate as a representation of the
4498
        trajectory.
4499
    k_dist: `int` (default: 1)
4500
        Number of anchors to use to determine a cell's gene time. If more than
4501
        1, time will be averaged.
4502
    thresh_multiplier: `float` (default: 1.0)
4503
        Multiplier for the heuristic threshold of partial versus complete
4504
        trajectory pre-determination.
4505
    weight_c: `float` (default: 0.6)
4506
        Weighting of scaled chromatin distances when performing 3D residual
4507
        calculation.
4508
    outlier: `float` (default: 99.8)
4509
        The percentile to mark as outlier that will be excluded when fitting
4510
        the model.
4511
    n_pcs: `int` (default: 30)
4512
        Number of principal components to compute distance smoothing neighbors.
4513
        This can be different from the one used for expression smoothing.
4514
    n_neighbors: `int` (default: 30)
4515
        Number of nearest neighbors for distance smoothing.
4516
        This can be different from the one used for expression smoothing.
4517
    fig_size: `tuple` (default: (8,6))
4518
        Size of each figure when saved.
4519
    point_size: `float` (default: 7)
4520
        Marker point size for plotting.
4521
    extra_color_key: `str` (default: `None`)
4522
        Extra color key used for plotting. Common choices are `leiden`,
4523
        `celltype`, etc.
4524
        The colors for each category must be present in one of anndatas, which
4525
        can be pre-computed
4526
        with `scanpy.pl.scatter` function.
4527
    embedding: `str` (default: `X_umap`)
4528
        2D coordinates of the low-dimensional embedding of cells.
4529
    partial: `bool` or list of `bool` (default: `None`)
4530
        User specified trajectory completeness for each gene.
4531
    direction: `str` or list of `str` (default: `None`)
4532
        User specified trajectory directionality for each gene.
4533
    rescale_u: `float` or list of `float` (default: `None`)
4534
        Known scaling factors for unspliced. Can be computed from scVelo
4535
        `fit_scaling` values
4536
        as `rescale_u = fit_scaling / std(u) * std(s)`.
4537
    alpha: `float` or list of `float` (default: `None`)
4538
        Known transcription rates. Can be computed from scVelo `fit_alpha`
4539
        values
4540
        as `alpha = fit_alpha * fit_alignment_scaling`.
4541
    beta: `float` or list of `float` (default: `None`)
4542
        Known splicing rates. Can be computed from scVelo `fit_beta` values
4543
        as `beta = fit_beta * fit_alignment_scaling`.
4544
    gamma: `float` or list of `float` (default: `None`)
4545
        Known degradation rates. Can be computed from scVelo `fit_gamma` values
4546
        as `gamma = fit_gamma * fit_alignment_scaling`.
4547
    t_sw: `float` or list of `float` (default: `None`)
4548
        Known RNA switch time. Can be computed from scVelo `fit_t_` values
4549
        as `t_sw = fit_t_ / fit_alignment_scaling`.
4550
4551
    Returns
4552
    -------
4553
    fit_alpha_c, fit_alpha, fit_beta, fit_gamma: `.var`
4554
        inferred chromatin opening, transcription, splicing, and degradation
4555
        (nuclear export) rates
4556
    fit_t_sw1, fit_t_sw2, fit_t_sw3: `.var`
4557
        inferred switching time points
4558
    fit_rescale_c, fit_rescale_u: `.var`
4559
        inferred scaling factor for chromatin and unspliced counts
4560
    fit_scale_cc: `.var`
4561
        inferred scaling value for chromatin closing rate compared to opening
4562
        rate
4563
    fit_alignment_scaling: `.var`
4564
        ratio used to realign observed time range to 0-20
4565
    fit_c0, fit_u0, fit_s0: `.var`
4566
        initial expression values at earliest observed time
4567
    fit_model: `.var`
4568
        inferred gene model
4569
    fit_direction: `.var`
4570
        inferred gene direction
4571
    fit_loss: `.var`
4572
        loss of model fit
4573
    fit_likelihood: `.var`
4574
        likelihood of model fit
4575
    fit_likelihood_c: `.var`
4576
        likelihood of chromatin fit
4577
    fit_anchor_c, fit_anchor_u, fit_anchor_s: `.varm`
4578
        anchor expressions
4579
    fit_anchor_c_sw, fit_anchor_u_sw, fit_anchor_s_sw: `.varm`
4580
        switch time-point expressions
4581
    fit_anchor_c_velo, fit_anchor_u_velo, fit_anchor_s_velo: `.varm`
4582
        velocities of anchors
4583
    fit_anchor_min_idx: `.var`
4584
        first anchor mapped to observations
4585
    fit_anchor_max_idx: `.var`
4586
        last anchor mapped to observations
4587
    fit_anchor_velo_min_idx: `.var`
4588
        first velocity anchor mapped to observations
4589
    fit_anchor_velo_max_idx: `.var`
4590
        last velocity anchor mapped to observations
4591
    fit_t: `.layers`
4592
        inferred gene time
4593
    fit_state: `.layers`
4594
        inferred state assignments
4595
    velo_s, velo_u, velo_chrom: `.layers`
4596
        velocities in spliced, unspliced, and chromatin space
4597
    velo_s_genes, velo_u_genes, velo_chrom_genes: `.var`
4598
        velocity genes
4599
    velo_s_params, velo_u_params, velo_chrom_params: `.var`
4600
        fitting arguments used
4601
    ATAC: `.layers`
4602
        KNN smoothed chromatin accessibilities copied from adata_atac
4603
    """
4604
4605
    fit_args = {}
4606
    fit_args['max_iter'] = max_iter
4607
    fit_args['init_mode'] = init_mode
4608
    fit_args['fit_decoupling'] = fit_decoupling
4609
    n_anchors = np.clip(int(n_anchors), 201, 2000)
4610
    fit_args['t'] = n_anchors
4611
    fit_args['k'] = k_dist
4612
    fit_args['thresh_multiplier'] = thresh_multiplier
4613
    fit_args['weight_c'] = weight_c
4614
    fit_args['outlier'] = outlier
4615
    fit_args['n_pcs'] = n_pcs
4616
    fit_args['n_neighbors'] = n_neighbors
4617
    fit_args['fig_size'] = list(fig_size)
4618
    fit_args['point_size'] = point_size
4619
4620
    if adam and neural_net:
4621
        raise Exception("ADAM and Neural Net mode can not be run concurently."
4622
                        " Please choose one to run on.")
4623
4624
    if not adam and not neural_net and not device == "cpu":
4625
        raise Exception("Multivelo only uses non-CPU devices for Adam or"
4626
                        " Neural Network mode. Please use one of those or"
4627
                        "set the device to \"cpu\"")
4628
4629
    if adam and not device[0:5] == "cuda:":
4630
        raise Exception("ADAM and Neural Net mode are only possible on a cuda "
4631
                        "device. Please try again.")
4632
    if not adam and batch_size is not None:
4633
        raise Exception("Batch training is for ADAM only, please set "
4634
                        "batch_size to None")
4635
4636
    if adam:
4637
        from cuml.neighbors import NearestNeighbors
4638
4639
    all_genes = adata_rna.var_names
4640
    if adata_atac is None:
4641
        import anndata as ad
4642
        rna_only = True
4643
        adata_atac = ad.AnnData(X=np.ones(adata_rna.shape), obs=adata_rna.obs,
4644
                                var=adata_rna.var)
4645
        adata_atac.layers['Mc'] = np.ones(adata_rna.shape)
4646
    if adata_rna.shape != adata_atac.shape:
4647
        raise ValueError('Shape of RNA and ATAC adata objects do not match: '
4648
                         f'{adata_rna.shape} {adata_atac.shape}')
4649
    if not np.all(adata_rna.obs_names == adata_atac.obs_names):
4650
        raise ValueError('obs_names of RNA and ATAC adata objects do not '
4651
                         'match, please check if they are consistent')
4652
    if not np.all(all_genes == adata_atac.var_names):
4653
        raise ValueError('var_names of RNA and ATAC adata objects do not '
4654
                         'match, please check if they are consistent')
4655
    if 'connectivities' not in adata_rna.obsp.keys():
4656
        raise ValueError('Missing connectivities entry in RNA adata object')
4657
    if extra_color_key is None:
4658
        extra_color = None
4659
    elif (isinstance(extra_color_key, str) and extra_color_key in adata_rna.obs
4660
            and adata_rna.obs[extra_color_key].dtype.name == 'category'):
4661
        ngroups = len(adata_rna.obs[extra_color_key].cat.categories)
4662
        extra_color = adata_rna.obs[extra_color_key].cat.rename_categories(
4663
            adata_rna.uns[extra_color_key+'_colors'][:ngroups]).to_numpy()
4664
    elif (isinstance(extra_color_key, str) and extra_color_key in
4665
          adata_atac.obs and
4666
          adata_rna.obs[extra_color_key].dtype.name == 'category'):
4667
        ngroups = len(adata_atac.obs[extra_color_key].cat.categories)
4668
        extra_color = adata_atac.obs[extra_color_key].cat.rename_categories(
4669
            adata_atac.uns[extra_color_key+'_colors'][:ngroups]).to_numpy()
4670
    else:
4671
        raise ValueError('Currently, extra_color_key must be a single string '
4672
                         'of categories and available in adata obs, and its '
4673
                         'colors can be found in adata uns')
4674
    if ('connectivities' not in adata_rna.obsp.keys() or
4675
            (adata_rna.obsp['connectivities'] > 0).sum(1).min()
4676
            > (n_neighbors-1)):
4677
        neighbors = Neighbors(adata_rna)
4678
        neighbors.compute_neighbors(n_neighbors=n_neighbors, knn=True,
4679
                                    n_pcs=n_pcs)
4680
        rna_conn = neighbors.connectivities
4681
    else:
4682
        rna_conn = adata_rna.obsp['connectivities'].copy()
4683
    rna_conn.setdiag(1)
4684
    rna_conn = rna_conn.multiply(1.0 / rna_conn.sum(1)).tocsr()
4685
    if not rna_only:
4686
        if 'connectivities' not in adata_atac.obsp.keys():
4687
            logg.update('Missing connectivities in ATAC adata object, using '
4688
                        'RNA connectivities instead', v=1)
4689
            atac_conn = rna_conn
4690
        else:
4691
            atac_conn = adata_atac.obsp['connectivities'].copy()
4692
            atac_conn.setdiag(1)
4693
        atac_conn = atac_conn.multiply(1.0 / atac_conn.sum(1)).tocsr()
4694
    if gene_list is None:
4695
        if 'highly_variable' in adata_rna.var:
4696
            gene_list = adata_rna.var_names[adata_rna.var['highly_variable']]\
4697
                .values
4698
        else:
4699
            gene_list = adata_rna.var_names.values[
4700
                (~np.isnan(np.asarray(adata_rna.layers['Mu'].sum(0))
4701
                             .reshape(-1)
4702
                           if sparse.issparse(adata_rna.layers['Mu'])
4703
                           else np.sum(adata_rna.layers['Mu'], axis=0)))
4704
                & (~np.isnan(np.asarray(adata_rna.layers['Ms'].sum(0))
4705
                             .reshape(-1)
4706
                             if sparse.issparse(adata_rna.layers['Ms'])
4707
                             else np.sum(adata_rna.layers['Ms'], axis=0)))
4708
                & (~np.isnan(np.asarray(adata_atac.layers['Mc'].sum(0))
4709
                             .reshape(-1)
4710
                             if sparse.issparse(adata_atac.layers['Mc'])
4711
                             else np.sum(adata_atac.layers['Mc'], axis=0)))]
4712
    elif isinstance(gene_list, (list, np.ndarray, pd.Index, pd.Series)):
4713
        gene_list = np.array([x for x in gene_list if x in all_genes])
4714
    elif isinstance(gene_list, str):
4715
        gene_list = np.array([gene_list]) if gene_list in all_genes else []
4716
    else:
4717
        raise ValueError('Invalid gene list, must be one of (str, np.ndarray,'
4718
                         'pd.Index, pd.Series)')
4719
    gn = len(gene_list)
4720
    if gn == 0:
4721
        raise ValueError('None of the genes specified are in the adata object')
4722
    logg.update(f'{gn} genes will be fitted', v=1)
4723
4724
    models = np.zeros(gn)
4725
    t_sws = np.zeros((gn, 3))
4726
    rates = np.zeros((gn, 4))
4727
    scale_ccs = np.zeros(gn)
4728
    rescale_cs = np.zeros(gn)
4729
    rescale_us = np.zeros(gn)
4730
    realign_ratios = np.zeros(gn)
4731
    initial_exps = np.zeros((gn, 3))
4732
    times = np.zeros((adata_rna.n_obs, gn))
4733
    states = np.zeros((adata_rna.n_obs, gn))
4734
    if not rna_only:
4735
        velo_c = np.zeros((adata_rna.n_obs, gn))
4736
    velo_u = np.zeros((adata_rna.n_obs, gn))
4737
    velo_s = np.zeros((adata_rna.n_obs, gn))
4738
    likelihoods = np.zeros(gn)
4739
    l_cs = np.zeros(gn)
4740
    ssd_cs = np.zeros(gn)
4741
    var_cs = np.zeros(gn)
4742
    directions = []
4743
    anchor_c = np.zeros((n_anchors, gn))
4744
    anchor_u = np.zeros((n_anchors, gn))
4745
    anchor_s = np.zeros((n_anchors, gn))
4746
    anchor_c_sw = np.zeros((3, gn))
4747
    anchor_u_sw = np.zeros((3, gn))
4748
    anchor_s_sw = np.zeros((3, gn))
4749
    anchor_vc = np.zeros((n_anchors, gn))
4750
    anchor_vu = np.zeros((n_anchors, gn))
4751
    anchor_vs = np.zeros((n_anchors, gn))
4752
    anchor_min_idx = np.zeros(gn)
4753
    anchor_max_idx = np.zeros(gn)
4754
    anchor_velo_min_idx = np.zeros(gn)
4755
    anchor_velo_max_idx = np.zeros(gn)
4756
4757
    if rna_only:
4758
        model_to_run = [2]
4759
        logg.update('Skipping model checking for RNA-only, running model 2',
4760
                    v=1)
4761
4762
    m_per_g = False
4763
    if model_to_run is not None:
4764
        if isinstance(model_to_run, (list, np.ndarray, pd.Index, pd.Series)):
4765
            model_to_run = [int(x) for x in model_to_run]
4766
            if np.any(~np.isin(model_to_run, [0, 1, 2])):
4767
                raise ValueError('Invalid model number (must be values in'
4768
                                 ' [0,1,2])')
4769
            if len(model_to_run) == gn:
4770
                losses = np.zeros((gn, 1))
4771
                m_per_g = True
4772
                func_to_call = regress_func
4773
            else:
4774
                losses = np.zeros((gn, len(model_to_run)))
4775
                func_to_call = multimodel_helper
4776
        elif isinstance(model_to_run, (int, float)):
4777
            model_to_run = int(model_to_run)
4778
            if not np.isin(model_to_run, [0, 1, 2]):
4779
                raise ValueError('Invalid model number (must be values in '
4780
                                 '[0,1,2])')
4781
            model_to_run = [model_to_run]
4782
            losses = np.zeros((gn, 1))
4783
            func_to_call = multimodel_helper
4784
        else:
4785
            raise ValueError('Invalid model number (must be values in '
4786
                             '[0,1,2])')
4787
    else:
4788
        losses = np.zeros((gn, 1))
4789
        func_to_call = regress_func
4790
4791
    p_per_g = False
4792
    if partial is not None:
4793
        if isinstance(partial, (list, np.ndarray, pd.Index, pd.Series)):
4794
            if np.any(~np.isin(partial, [True, False])):
4795
                raise ValueError('Invalid partial argument (must be values in'
4796
                                 ' [True,False])')
4797
            if len(partial) == gn:
4798
                p_per_g = True
4799
            else:
4800
                raise ValueError('Incorrect partial argument length')
4801
        elif isinstance(partial, bool):
4802
            if not np.isin(partial, [True, False]):
4803
                raise ValueError('Invalid partial argument (must be values in'
4804
                                 ' [True,False])')
4805
        else:
4806
            raise ValueError('Invalid partial argument (must be values in'
4807
                             ' [True,False])')
4808
4809
    d_per_g = False
4810
    if direction is not None:
4811
        if isinstance(direction, (list, np.ndarray, pd.Index, pd.Series)):
4812
            if np.any(~np.isin(direction, ['on', 'off', 'complete'])):
4813
                raise ValueError('Invalid direction argument (must be values'
4814
                                 ' in ["on","off","complete"])')
4815
            if len(direction) == gn:
4816
                d_per_g = True
4817
            else:
4818
                raise ValueError('Incorrect direction argument length')
4819
        elif isinstance(direction, str):
4820
            if not np.isin(direction, ['on', 'off', 'complete']):
4821
                raise ValueError('Invalid direction argument (must be values'
4822
                                 ' in ["on","off","complete"])')
4823
        else:
4824
            raise ValueError('Invalid direction argument (must be values in'
4825
                             ' ["on","off","complete"])')
4826
4827
    known_pars = [rescale_u, alpha, beta, gamma, t_sw]
4828
    for x in known_pars:
4829
        if x is not None:
4830
            if isinstance(x, (list, np.ndarray)):
4831
                if np.sum(np.isnan(x)) + np.sum(np.isinf(x)) > 0:
4832
                    raise ValueError('Known parameters cannot contain NaN or'
4833
                                     ' Inf')
4834
            elif isinstance(x, (int, float)):
4835
                if x == np.nan or x == np.inf:
4836
                    raise ValueError('Known parameters cannot contain NaN or'
4837
                                     ' Inf')
4838
            else:
4839
                raise ValueError('Invalid known parameters type')
4840
4841
    if ((embedding not in adata_rna.obsm) and
4842
            (embedding not in adata_atac.obsm)):
4843
        raise ValueError(f'{embedding} is not found in obsm')
4844
    embed_coord = adata_rna.obsm[embedding] if embedding in adata_rna.obsm \
4845
        else adata_atac.obsm[embedding]
4846
    global_pdist = pairwise_distances(embed_coord)
4847
4848
    u_mat = adata_rna[:, gene_list].layers['Mu'].A \
4849
        if sparse.issparse(adata_rna.layers['Mu']) \
4850
        else adata_rna[:, gene_list].layers['Mu']
4851
    s_mat = adata_rna[:, gene_list].layers['Ms'].A \
4852
        if sparse.issparse(adata_rna.layers['Ms']) \
4853
        else adata_rna[:, gene_list].layers['Ms']
4854
    c_mat = adata_atac[:, gene_list].layers['Mc'].A \
4855
        if sparse.issparse(adata_atac.layers['Mc']) \
4856
        else adata_atac[:, gene_list].layers['Mc']
4857
4858
    ru = rescale_u if rescale_u is not None else None
4859
4860
    if parallel:
4861
        if (n_jobs is None or not isinstance(n_jobs, int) or n_jobs < 0 or
4862
                n_jobs > os.cpu_count()):
4863
            n_jobs = os.cpu_count()
4864
        if n_jobs > gn:
4865
            n_jobs = gn
4866
        batches = -(-gn // n_jobs)
4867
        if n_jobs > 1:
4868
            logg.update(f'running {n_jobs} jobs in parallel', v=1)
4869
    else:
4870
        n_jobs = 1
4871
        batches = gn
4872
    if n_jobs == 1:
4873
        parallel = False
4874
4875
    pbar = tqdm(total=gn)
4876
    for group in range(batches):
4877
        gene_indices = range(group * n_jobs, np.min([gn, (group+1) * n_jobs]))
4878
        if parallel:
4879
            verb = 51 if settings.VERBOSITY >= 2 else 0
4880
            plot = False
4881
4882
            # clear the settings file if it exists
4883
            open("settings.txt", "w").close()
4884
4885
            # write our current settings to the file
4886
            with open("settings.txt", "a") as sfile:
4887
                sfile.write(str(settings.VERBOSITY) + "\n")
4888
                sfile.write(str(settings.CWD) + "\n")
4889
                sfile.write(str(settings.LOG_FOLDER) + "\n")
4890
                sfile.write(str(settings.LOG_FILENAME) + "\n")
4891
4892
            res = Parallel(n_jobs=n_jobs, backend='loky', verbose=verb)(
4893
                delayed(func_to_call)(
4894
                    c_mat[:, i],
4895
                    u_mat[:, i],
4896
                    s_mat[:, i],
4897
                    model_to_run[i] if m_per_g else model_to_run,
4898
                    max_iter,
4899
                    init_mode,
4900
                    device,
4901
                    neural_net,
4902
                    adam,
4903
                    adam_lr,
4904
                    adam_beta1,
4905
                    adam_beta2,
4906
                    batch_size,
4907
                    global_pdist,
4908
                    embed_coord,
4909
                    rna_conn,
4910
                    plot,
4911
                    save_plot,
4912
                    plot_dir,
4913
                    fit_args,
4914
                    gene_list[i],
4915
                    partial[i] if p_per_g else partial,
4916
                    direction[i] if d_per_g else direction,
4917
                    rna_only,
4918
                    fit,
4919
                    fit_decoupling,
4920
                    extra_color,
4921
                    ru[i] if isinstance(ru, (list, np.ndarray)) else ru,
4922
                    alpha[i] if isinstance(alpha, (list, np.ndarray))
4923
                    else alpha,
4924
                    beta[i] if isinstance(beta, (list, np.ndarray))
4925
                    else beta,
4926
                    gamma[i] if isinstance(gamma, (list, np.ndarray))
4927
                    else gamma,
4928
                    t_sw[i] if isinstance(t_sw, (list, np.ndarray)) else t_sw,
4929
                    settings.VERBOSITY,
4930
                    settings.LOG_FOLDER,
4931
                    settings.LOG_FILENAME)
4932
                for i in gene_indices)
4933
4934
            for i, r in zip(gene_indices, res):
4935
                (loss, model, direct_out, parameters, initial_exp,
4936
                 time, state, velocity, likelihood, anchors) = r
4937
                switch, rate, scale_cc, rescale_c, rescale_u, realign_ratio = \
4938
                    parameters
4939
                likelihood, l_c, ssd_c, var_c = likelihood
4940
                losses[i, :] = loss
4941
                models[i] = model
4942
                directions.append(direct_out)
4943
                t_sws[i, :] = switch
4944
                rates[i, :] = rate
4945
                scale_ccs[i] = scale_cc
4946
                rescale_cs[i] = rescale_c
4947
                rescale_us[i] = rescale_u
4948
                realign_ratios[i] = realign_ratio
4949
                likelihoods[i] = likelihood
4950
                l_cs[i] = l_c
4951
                ssd_cs[i] = ssd_c
4952
                var_cs[i] = var_c
4953
                if fit:
4954
                    initial_exps[i, :] = initial_exp
4955
                    times[:, i] = time
4956
                    states[:, i] = state
4957
                    n_anchors_ = anchors[0].shape[0]
4958
                    n_switch = anchors[1].shape[0]
4959
                    if not rna_only:
4960
                        velo_c[:, i] = smooth_scale(atac_conn, velocity[:, 0])
4961
                        anchor_c[:n_anchors_, i] = anchors[0][:, 0]
4962
                        anchor_c_sw[:n_switch, i] = anchors[1][:, 0]
4963
                        anchor_vc[:n_anchors_, i] = anchors[2][:, 0]
4964
                    velo_u[:, i] = smooth_scale(rna_conn, velocity[:, 1])
4965
                    velo_s[:, i] = smooth_scale(rna_conn, velocity[:, 2])
4966
                    anchor_u[:n_anchors_, i] = anchors[0][:, 1]
4967
                    anchor_s[:n_anchors_, i] = anchors[0][:, 2]
4968
                    anchor_u_sw[:n_switch, i] = anchors[1][:, 1]
4969
                    anchor_s_sw[:n_switch, i] = anchors[1][:, 2]
4970
                    anchor_vu[:n_anchors_, i] = anchors[2][:, 1]
4971
                    anchor_vs[:n_anchors_, i] = anchors[2][:, 2]
4972
                    anchor_min_idx[i] = anchors[3]
4973
                    anchor_max_idx[i] = anchors[4]
4974
                    anchor_velo_min_idx[i] = anchors[5]
4975
                    anchor_velo_max_idx[i] = anchors[6]
4976
        else:
4977
            i = group
4978
            gene = gene_list[i]
4979
            logg.update(f'@@@@@fitting {gene}', v=1)
4980
            (loss, model, direct_out,
4981
             parameters, initial_exp,
4982
             time, state, velocity,
4983
             likelihood, anchors) = \
4984
                func_to_call(c_mat[:, i], u_mat[:, i], s_mat[:, i],
4985
                             model_to_run[i] if m_per_g else model_to_run,
4986
                             max_iter, init_mode,
4987
                             device,
4988
                             neural_net,
4989
                             adam,
4990
                             adam_lr,
4991
                             adam_beta1,
4992
                             adam_beta2,
4993
                             batch_size,
4994
                             global_pdist, embed_coord,
4995
                             rna_conn, plot, save_plot, plot_dir,
4996
                             fit_args, gene,
4997
                             partial[i] if p_per_g else partial,
4998
                             direction[i] if d_per_g else direction,
4999
                             rna_only, fit, fit_decoupling, extra_color,
5000
                             ru[i] if isinstance(ru, (list, np.ndarray))
5001
                             else ru,
5002
                             alpha[i] if isinstance(alpha, (list, np.ndarray))
5003
                             else alpha,
5004
                             beta[i] if isinstance(beta, (list, np.ndarray))
5005
                             else beta,
5006
                             gamma[i] if isinstance(gamma, (list, np.ndarray))
5007
                             else gamma,
5008
                             t_sw[i] if isinstance(t_sw, (list, np.ndarray))
5009
                             else t_sw,
5010
                             settings.VERBOSITY,
5011
                             settings.LOG_FOLDER,
5012
                             settings.LOG_FILENAME)
5013
            switch, rate, scale_cc, rescale_c, rescale_u, realign_ratio = \
5014
                parameters
5015
            likelihood, l_c, ssd_c, var_c = likelihood
5016
            losses[i, :] = loss
5017
            models[i] = model
5018
            directions.append(direct_out)
5019
            t_sws[i, :] = switch
5020
            rates[i, :] = rate
5021
            scale_ccs[i] = scale_cc
5022
            rescale_cs[i] = rescale_c
5023
            rescale_us[i] = rescale_u
5024
            realign_ratios[i] = realign_ratio
5025
            likelihoods[i] = likelihood
5026
            l_cs[i] = l_c
5027
            ssd_cs[i] = ssd_c
5028
            var_cs[i] = var_c
5029
            if fit:
5030
                initial_exps[i, :] = initial_exp
5031
                times[:, i] = time
5032
                states[:, i] = state
5033
                n_anchors_ = anchors[0].shape[0]
5034
                n_switch = anchors[1].shape[0]
5035
                if not rna_only:
5036
                    velo_c[:, i] = smooth_scale(atac_conn, velocity[:, 0])
5037
                    anchor_c[:n_anchors_, i] = anchors[0][:, 0]
5038
                    anchor_c_sw[:n_switch, i] = anchors[1][:, 0]
5039
                    anchor_vc[:n_anchors_, i] = anchors[2][:, 0]
5040
                velo_u[:, i] = smooth_scale(rna_conn, velocity[:, 1])
5041
                velo_s[:, i] = smooth_scale(rna_conn, velocity[:, 2])
5042
                anchor_u[:n_anchors_, i] = anchors[0][:, 1]
5043
                anchor_s[:n_anchors_, i] = anchors[0][:, 2]
5044
                anchor_u_sw[:n_switch, i] = anchors[1][:, 1]
5045
                anchor_s_sw[:n_switch, i] = anchors[1][:, 2]
5046
                anchor_vu[:n_anchors_, i] = anchors[2][:, 1]
5047
                anchor_vs[:n_anchors_, i] = anchors[2][:, 2]
5048
                anchor_min_idx[i] = anchors[3]
5049
                anchor_max_idx[i] = anchors[4]
5050
                anchor_velo_min_idx[i] = anchors[5]
5051
                anchor_velo_max_idx[i] = anchors[6]
5052
        pbar.update(len(gene_indices))
5053
    pbar.close()
5054
    directions = np.array(directions)
5055
5056
    filt = np.sum(losses != np.inf, 1) >= 1
5057
    if np.sum(filt) == 0:
5058
        raise ValueError('None of the genes were fitted due to low quality,'
5059
                         ' not returning')
5060
    adata_copy = adata_rna[:, gene_list[filt]].copy()
5061
    adata_copy.layers['ATAC'] = c_mat[:, filt]
5062
    adata_copy.var['fit_alpha_c'] = rates[filt, 0]
5063
    adata_copy.var['fit_alpha'] = rates[filt, 1]
5064
    adata_copy.var['fit_beta'] = rates[filt, 2]
5065
    adata_copy.var['fit_gamma'] = rates[filt, 3]
5066
    adata_copy.var['fit_t_sw1'] = t_sws[filt, 0]
5067
    adata_copy.var['fit_t_sw2'] = t_sws[filt, 1]
5068
    adata_copy.var['fit_t_sw3'] = t_sws[filt, 2]
5069
    adata_copy.var['fit_scale_cc'] = scale_ccs[filt]
5070
    adata_copy.var['fit_rescale_c'] = rescale_cs[filt]
5071
    adata_copy.var['fit_rescale_u'] = rescale_us[filt]
5072
    adata_copy.var['fit_alignment_scaling'] = realign_ratios[filt]
5073
    adata_copy.var['fit_model'] = models[filt]
5074
    adata_copy.var['fit_direction'] = directions[filt]
5075
    if model_to_run is not None and not m_per_g and not rna_only:
5076
        for i, m in enumerate(model_to_run):
5077
            adata_copy.var[f'fit_loss_M{m}'] = losses[filt, i]
5078
    else:
5079
        adata_copy.var['fit_loss'] = losses[filt, 0]
5080
    adata_copy.var['fit_likelihood'] = likelihoods[filt]
5081
    adata_copy.var['fit_likelihood_c'] = l_cs[filt]
5082
    adata_copy.var['fit_ssd_c'] = ssd_cs[filt]
5083
    adata_copy.var['fit_var_c'] = var_cs[filt]
5084
    if fit:
5085
        adata_copy.layers['fit_t'] = times[:, filt]
5086
        adata_copy.layers['fit_state'] = states[:, filt]
5087
        adata_copy.layers['velo_s'] = velo_s[:, filt]
5088
        adata_copy.layers['velo_u'] = velo_u[:, filt]
5089
        if not rna_only:
5090
            adata_copy.layers['velo_chrom'] = velo_c[:, filt]
5091
        adata_copy.var['fit_c0'] = initial_exps[filt, 0]
5092
        adata_copy.var['fit_u0'] = initial_exps[filt, 1]
5093
        adata_copy.var['fit_s0'] = initial_exps[filt, 2]
5094
        adata_copy.var['fit_anchor_min_idx'] = anchor_min_idx[filt]
5095
        adata_copy.var['fit_anchor_max_idx'] = anchor_max_idx[filt]
5096
        adata_copy.var['fit_anchor_velo_min_idx'] = anchor_velo_min_idx[filt]
5097
        adata_copy.var['fit_anchor_velo_max_idx'] = anchor_velo_max_idx[filt]
5098
        adata_copy.varm['fit_anchor_c'] = np.transpose(anchor_c[:, filt])
5099
        adata_copy.varm['fit_anchor_u'] = np.transpose(anchor_u[:, filt])
5100
        adata_copy.varm['fit_anchor_s'] = np.transpose(anchor_s[:, filt])
5101
        adata_copy.varm['fit_anchor_c_sw'] = np.transpose(anchor_c_sw[:, filt])
5102
        adata_copy.varm['fit_anchor_u_sw'] = np.transpose(anchor_u_sw[:, filt])
5103
        adata_copy.varm['fit_anchor_s_sw'] = np.transpose(anchor_s_sw[:, filt])
5104
        adata_copy.varm['fit_anchor_c_velo'] = np.transpose(anchor_vc[:, filt])
5105
        adata_copy.varm['fit_anchor_u_velo'] = np.transpose(anchor_vu[:, filt])
5106
        adata_copy.varm['fit_anchor_s_velo'] = np.transpose(anchor_vs[:, filt])
5107
    v_genes = adata_copy.var['fit_likelihood'] >= 0.05
5108
    adata_copy.var['velo_s_genes'] = adata_copy.var['velo_u_genes'] = \
5109
        adata_copy.var['velo_chrom_genes'] = v_genes
5110
    adata_copy.uns['velo_s_params'] = adata_copy.uns['velo_u_params'] = \
5111
        adata_copy.uns['velo_chrom_params'] = {'mode': 'dynamical'}
5112
    adata_copy.uns['velo_s_params'].update(fit_args)
5113
    adata_copy.uns['velo_u_params'].update(fit_args)
5114
    adata_copy.uns['velo_chrom_params'].update(fit_args)
5115
    adata_copy.obsp['_RNA_conn'] = rna_conn
5116
    if not rna_only:
5117
        adata_copy.obsp['_ATAC_conn'] = atac_conn
5118
    return adata_copy
5119
5120
5121
def smooth_scale(conn, vector):
    """Smooth a vector with a weight matrix and restore its original range.

    The vector is multiplied by ``conn`` and the product is then linearly
    rescaled so that its minimum and maximum equal those of the input.

    Parameters
    ----------
    conn: matrix-like exposing ``.dot``
        Weights applied to ``vector`` (callers in this file pass sparse,
        row-normalized connectivities).
    vector: `numpy.ndarray`
        Values to smooth.

    Returns
    -------
    The smoothed values mapped onto ``[min(vector), max(vector)]``. When the
    smoothed values are all identical, every entry is ``min(vector)``.
    """
    max_to = np.max(vector)
    min_to = np.min(vector)
    v = conn.dot(vector.T).T
    max_from = np.max(v)
    min_from = np.min(v)
    span_from = max_from - min_from
    if span_from == 0:
        # Every smoothed entry sits at min_from, which the linear map sends
        # to min_to; evaluating the formula here would compute 0/0 and fill
        # the result with NaN.
        return np.full_like(v, min_to, dtype=float)
    res = ((v - min_from) * (max_to - min_to) / span_from) + min_to
    return res
5129
5130
5131
def top_n_sparse(conn, n):
    """Keep the ``n`` largest stored entries per row and bin their values.

    Parameters
    ----------
    conn: scipy sparse matrix
        Weight matrix. It is converted with ``tolil``; the row selection
        below is applied to that copy, not to ``conn`` itself.
    n: `int`
        Number of largest stored values kept in each row. Values of 0 or
        below keep nothing.

    Returns
    -------
    CSR matrix in which each kept positive entry is replaced by 0.25
    (value <= 0.25), 0.5 (0.25 < value <= 0.5) or 1 (value > 0.5). Kept
    non-positive entries are left as they are and explicit zeros are removed.
    """
    # copy=True so that an input already in LIL format is not edited in place
    conn_ll = conn.tolil(copy=True)
    for i in range(conn_ll.shape[0]):
        row_data = np.asarray(conn_ll.data[i])
        row_idx = np.asarray(conn_ll.rows[i], dtype=int)
        if n > 0:
            # Positions of the n largest values, put back in ascending
            # position order so the row keeps sorted column indices.
            keep = np.sort(row_data.argsort()[-n:])
        else:
            # A plain [-n:] slice with n == 0 would select the whole row.
            keep = np.empty(0, dtype=int)
        conn_ll.data[i] = row_data[keep].tolist()
        conn_ll.rows[i] = row_idx[keep].tolist()
    conn = conn_ll.tocsr()
    # Bin the stored values; every threshold is tested against the values
    # as they were before binning, so later assignments override earlier ones.
    values = conn.data
    binned = values.copy()
    binned[values > 0] = 0.25
    binned[values > 0.25] = 0.5
    binned[values > 0.5] = 1
    conn.data = binned
    conn.eliminate_zeros()
    return conn
5150
5151
5152
def set_velocity_genes(adata,
                       likelihood_lower=0.05,
                       rescale_u_upper=None,
                       rescale_u_lower=None,
                       rescale_c_upper=None,
                       rescale_c_lower=None,
                       primed_upper=None,
                       primed_lower=None,
                       decoupled_upper=None,
                       decoupled_lower=None,
                       alpha_c_upper=None,
                       alpha_c_lower=None,
                       alpha_upper=None,
                       alpha_lower=None,
                       beta_upper=None,
                       beta_lower=None,
                       gamma_upper=None,
                       gamma_lower=None,
                       scale_cc_upper=None,
                       scale_cc_lower=None
                       ):
    """Reset velocity genes.

    This function resets velocity genes based on criteria of variables.
    A bound left as `None` is ignored, and only the `.var` columns required
    by the bounds that were actually given are read. All bounds are
    inclusive and combined with logical AND.

    Parameters
    ----------
    adata: :class:`~anndata.AnnData`
        Anndata result from dynamics recovery.
    likelihood_lower: `float` (default: 0.05)
        Minimum likelihood.
    rescale_u_upper: `float` (default: `None`)
        Maximum rescale_u.
    rescale_u_lower: `float` (default: `None`)
        Minimum rescale_u.
    rescale_c_upper: `float` (default: `None`)
        Maximum rescale_c.
    rescale_c_lower: `float` (default: `None`)
        Minimum rescale_c.
    primed_upper: `float` (default: `None`)
        Maximum primed interval.
    primed_lower: `float` (default: `None`)
        Minimum primed interval.
    decoupled_upper: `float` (default: `None`)
        Maximum decoupled interval.
    decoupled_lower: `float` (default: `None`)
        Minimum decoupled interval.
    alpha_c_upper: `float` (default: `None`)
        Maximum alpha_c.
    alpha_c_lower: `float` (default: `None`)
        Minimum alpha_c.
    alpha_upper: `float` (default: `None`)
        Maximum alpha.
    alpha_lower: `float` (default: `None`)
        Minimum alpha.
    beta_upper: `float` (default: `None`)
        Maximum beta.
    beta_lower: `float` (default: `None`)
        Minimum beta.
    gamma_upper: `float` (default: `None`)
        Maximum gamma.
    gamma_lower: `float` (default: `None`)
        Minimum gamma.
    scale_cc_upper: `float` (default: `None`)
        Maximum scale_cc.
    scale_cc_lower: `float` (default: `None`)
        Minimum scale_cc.

    Returns
    -------
    velo_s_genes, velo_u_genes, velo_chrom_genes: `.var`
        new velocity genes for each modalities.
    """
    var = adata.var
    v_genes = (var['fit_likelihood'] >= likelihood_lower)

    # Bounds that apply directly to a fitted `.var` column.
    column_bounds = (
        ('fit_rescale_u', rescale_u_upper, rescale_u_lower),
        ('fit_rescale_c', rescale_c_upper, rescale_c_lower),
        ('fit_alpha_c', alpha_c_upper, alpha_c_lower),
        ('fit_alpha', alpha_upper, alpha_lower),
        ('fit_beta', beta_upper, beta_lower),
        ('fit_gamma', gamma_upper, gamma_lower),
        ('fit_scale_cc', scale_cc_upper, scale_cc_lower),
    )
    for column, upper, lower in column_bounds:
        if upper is not None or lower is not None:
            v_genes = _restrict_to_bounds(v_genes, var[column], upper, lower)

    # The primed interval needs the anchor columns, so it is only evaluated
    # when a primed bound was requested; otherwise a missing
    # `fit_anchor_min_idx` column must not make the call fail.
    if primed_upper is not None or primed_lower is not None:
        t_sw1 = var['fit_t_sw1'] + 20 / adata.uns['velo_s_params']['t'] * \
            var['fit_anchor_min_idx'] * var['fit_alignment_scaling']
        v_genes = _restrict_to_bounds(v_genes, t_sw1,
                                      primed_upper, primed_lower)

    # Decoupled interval: distance between the second and third switch
    # times, each capped at 20.
    if decoupled_upper is not None or decoupled_lower is not None:
        t_sw2 = np.clip(var['fit_t_sw2'], None, 20)
        t_sw3 = np.clip(var['fit_t_sw3'], None, 20)
        v_genes = _restrict_to_bounds(v_genes, t_sw3 - t_sw2,
                                      decoupled_upper, decoupled_lower)

    logg.update(f'{np.sum(v_genes)} velocity genes were selected', v=1)
    adata.var['velo_s_genes'] = adata.var['velo_u_genes'] = \
        adata.var['velo_chrom_genes'] = v_genes


def _restrict_to_bounds(mask, values, upper, lower):
    """Return ``mask`` limited to entries with lower <= values <= upper.

    A bound given as `None` is not applied.
    """
    if upper is not None:
        mask = mask & (values <= upper)
    if lower is not None:
        mask = mask & (values >= lower)
    return mask
5271
5272
5273
def velocity_graph(adata, vkey='velo_s', xkey='Ms', **kwargs):
    """Computes velocity graph.

    The velocity layer is scaled gene-wise by the sum of its absolute
    values, rescaled by the global mean of the result, and stored under
    `<vkey>_norm`. The graph is then computed on that normalized layer with
    `scvelo.tl.velocity_graph`.

    Parameters
    ----------
    adata: :class:`~anndata.AnnData`
        Anndata result from dynamics recovery.
    vkey: `str` (default: `velo_s`)
        Default to use spliced velocities.
    xkey: `str` (default: `Ms`)
        Default to use smoothed spliced counts.
    Additional parameters passed to `scvelo.tl.velocity_graph`.

    Returns
    -------
    Normalized velocity matrix and associated velocity genes and params.
    Outputs of `scvelo.tl.velocity_graph`.
    """
    norm_key = f'{vkey}_norm'
    norm_genes_key = f'{norm_key}_genes'

    if vkey not in adata.layers:
        raise ValueError('Velocity matrix is not found. Please run multivelo'
                         '.recover_dynamics_chrom function first.')

    if norm_key not in adata.layers:
        raw_velocity = adata.layers[vkey]
        # column-wise (per gene) sum of absolute velocities
        gene_totals = np.sum(np.abs(raw_velocity), 0)
        adata.layers[norm_key] = raw_velocity / gene_totals
        # rescale in place by the overall mean of the normalized layer
        adata.layers[norm_key] /= np.mean(adata.layers[norm_key])
        adata.uns[f'{norm_key}_params'] = adata.uns[f'{vkey}_params']

    if norm_genes_key not in adata.var.columns:
        adata.var[norm_genes_key] = adata.var[f'{vkey}_genes']

    scv.tl.velocity_graph(adata, vkey=norm_key, xkey=xkey, **kwargs)
5305
5306
5307
def velocity_embedding_stream(adata, vkey='velo_s', show=True, **kwargs):
    """Plots velocity stream.

    This function plots velocity streamplot with
    `scvelo.pl.velocity_embedding_stream`.

    Parameters
    ----------
    adata: :class:`~anndata.AnnData`
        Anndata result from dynamics recovery.
    vkey: `str` (default: `velo_s`)
        Default to use spliced velocities. The normalized matrix will be used.
    show: `bool` (default: `True`)
        Whether to show the plot.
    Additional parameters are forwarded to
    `scvelo.pl.velocity_embedding_stream` and, when the velocity graph has
    not been computed yet, also to `velocity_graph`.

    Returns
    -------
    If `show==False`, a matplotlib axis object.

    Raises
    ------
    ValueError
        If the `vkey` layer is missing from `adata.layers`.
    """
    norm_key = vkey + '_norm'
    if vkey not in adata.layers:
        raise ValueError('Velocity matrix is not found. Please run multivelo.'
                         'recover_dynamics_chrom function first.')
    if norm_key not in adata.layers.keys():
        # Gene-wise scaling by the sum of absolute velocities, followed by
        # the same global-mean rescale used when the layer is created by
        # the graph computation, so the stored layer does not depend on
        # which function is called first.
        adata.layers[norm_key] = adata.layers[vkey] / np.sum(
            np.abs(adata.layers[vkey]), 0)
        adata.layers[norm_key] /= np.mean(adata.layers[norm_key])
        adata.uns[norm_key + '_params'] = adata.uns[vkey + '_params']
    if norm_key + '_genes' not in adata.var.columns:
        adata.var[norm_key + '_genes'] = adata.var[vkey + '_genes']
    if norm_key + '_graph' not in adata.uns.keys():
        velocity_graph(adata, vkey=vkey, **kwargs)
    out = scv.pl.velocity_embedding_stream(adata, vkey=norm_key, show=show,
                                           **kwargs)
    if not show:
        return out
5342
5343
5344
def latent_time(adata, vkey='velo_s', **kwargs):
    """Computes latent time.

    This function computes latent time with `scvelo.tl.latent_time` on the
    normalized velocity layer `<vkey>_norm`.

    Parameters
    ----------
    adata: :class:`~anndata.AnnData`
        Anndata result from dynamics recovery.
    vkey: `str` (default: `velo_s`)
        Default to use spliced velocities. The normalized matrix will be used.
    Additional parameters are forwarded to `scvelo.tl.latent_time` and, when
    the velocity graph is missing from `.uns`, also to `velocity_graph`.

    Returns
    -------
    Outputs of `scvelo.tl.latent_time`.
    """
    layers = adata.layers
    norm_key = f'{vkey}_norm'

    has_fit = vkey in layers and 'fit_t' in layers
    if not has_fit:
        raise ValueError('Velocity or time matrix is not found. Please run '
                         'multivelo.recover_dynamics_chrom function first.')
    if norm_key not in layers:
        raise ValueError('Normalized velocity matrix is not found. Please '
                         'run multivelo.velocity_graph function first.')

    # build the graph on demand before asking scvelo for latent time
    if f'{norm_key}_graph' not in adata.uns:
        velocity_graph(adata, vkey=vkey, **kwargs)
    scv.tl.latent_time(adata, vkey=norm_key, **kwargs)
5370
5371
5372
def LRT_decoupling(adata_rna, adata_atac, **kwargs):
    """Computes likelihood ratio test for decoupling state.

    The dynamics are fitted twice with `recover_dynamics_chrom`, once with
    and once without the decoupling state, and the fit likelihoods of the
    genes present in both results are compared with a chi-square test with
    one degree of freedom.

    Parameters
    ----------
    adata_rna: :class:`~anndata.AnnData`
        RNA anndata object
    adata_atac: :class:`~anndata.AnnData`
        ATAC anndata object.
    Additional parameters passed to `recover_dynamics_chrom`.

    Returns
    -------
    adata_result_w_decoupled: class:`~anndata.AnnData`
        fit result with decoupling state
    adata_result_wo_decoupled: class:`~anndata.AnnData`
        fit result without decoupling state
    res: `pandas.DataFrame`
        LRT statistics
    """
    from scipy.stats.distributions import chi2

    n_obs = adata_rna.n_obs

    def _lrt(l_without, l_with):
        # likelihood ratio statistic and its chi-square (df=1) p-value
        stat = -2 * n_obs * (np.log(l_without) - np.log(l_with))
        return stat, chi2.sf(stat, 1)

    logg.update('fitting models with decoupling intervals', v=0)
    adata_result_w_decoupled = recover_dynamics_chrom(
        adata_rna, adata_atac, fit_decoupling=True, **kwargs)

    logg.update('fitting models without decoupling intervals', v=0)
    adata_result_wo_decoupled = recover_dynamics_chrom(
        adata_rna, adata_atac, fit_decoupling=False, **kwargs)

    logg.update('testing likelihood ratio', v=0)
    shared_genes = pd.Index(np.intersect1d(
        adata_result_w_decoupled.var_names,
        adata_result_wo_decoupled.var_names))
    var_w = adata_result_w_decoupled[:, shared_genes].var
    var_wo = adata_result_wo_decoupled[:, shared_genes].var

    # chromatin-only likelihoods
    l_c_w_decoupled = var_w['fit_likelihood_c'].values
    l_c_wo_decoupled = var_wo['fit_likelihood_c'].values
    LRT_c, p_c = _lrt(l_c_wo_decoupled, l_c_w_decoupled)

    # overall likelihoods
    l_w_decoupled = var_w['fit_likelihood'].values
    l_wo_decoupled = var_wo['fit_likelihood'].values
    LRT, p = _lrt(l_wo_decoupled, l_w_decoupled)

    res = pd.DataFrame(
        {
            'likelihood_c_w_decoupled': l_c_w_decoupled,
            'likelihood_c_wo_decoupled': l_c_wo_decoupled,
            'LRT_c': LRT_c,
            'pval_c': p_c,
            'likelihood_w_decoupled': l_w_decoupled,
            'likelihood_wo_decoupled': l_wo_decoupled,
            'LRT': LRT,
            'pval': p,
        },
        index=shared_genes,
    )
    return adata_result_w_decoupled, adata_result_wo_decoupled, res
5431
5432
5433
def transition_matrix_s(s_mat, velo_s, knn):
    """Cosine-similarity transition matrices from spliced velocities.

    For every cell `i`, the displacement `s_mat[i] - s_mat[j]` to each of
    its one- and two-step neighbors `j` is compared with the velocity of
    cell `i` by cosine similarity.

    Parameters
    ----------
    s_mat: sparse or dense matrix of shape (n_cells, n_genes)
        Spliced counts. Must support row indexing with an index array.
    velo_s: array of shape (n_cells, n_genes)
        Spliced velocities.
    knn: `numpy.ndarray` of shape (n_cells, k)
        Neighbor indices of every cell.

    Returns
    -------
    tm: `scipy.sparse.csr_matrix`
        Positive similarities, clipped to [0, 1].
    tm_neg: `scipy.sparse.csr_matrix`
        Negative similarities, clipped to [-1, 0].

    Notes
    -----
    Pairs with a zero displacement or zero velocity norm yield NaN entries.
    """
    def _dense_rows(mat, idx):
        # `.toarray()` is used because the `.A` shorthand is no longer
        # available on SciPy sparse matrices in recent releases.
        rows = mat[idx, :]
        return rows.toarray() if sparse.issparse(rows) else np.asarray(rows)

    knn = knn.astype(int)
    n_cells = s_mat.shape[0]
    tm_val, tm_col, tm_row = [], [], []
    for i in range(knn.shape[0]):
        # union of direct neighbors and neighbors of neighbors
        two_step_knn = np.unique(np.concatenate(
            (knn[i, :], np.ravel(knn[knn[i, :], :]))))
        # The self pair has zero displacement and the diagonal is defined as
        # zero anyway, so it is left out instead of computing 0/0.
        two_step_knn = two_step_knn[two_step_knn != i]
        if two_step_knn.size == 0:
            continue
        s = _dense_rows(s_mat, [i])
        sn = _dense_rows(s_mat, two_step_knn)
        dx = s - sn
        velo = np.ravel(_dense_rows(velo_s, [i]))
        cos_sim = dx.dot(velo) / (norm(dx, axis=1) * norm(velo))
        tm_val.extend(cos_sim.tolist())
        tm_col.extend(two_step_knn.tolist())
        tm_row.extend([i] * two_step_knn.size)
    tm = coo_matrix((tm_val, (tm_row, tm_col)),
                    shape=(n_cells, n_cells)).tocsr()
    tm_neg = tm.copy()
    tm.data = np.clip(tm.data, 0, 1)
    tm_neg.data = np.clip(tm_neg.data, -1, 0)
    tm.eliminate_zeros()
    tm_neg.eliminate_zeros()
    return tm, tm_neg
5460
5461
5462
def transition_matrix_chrom(c_mat, u_mat, s_mat, velo_c, velo_u, velo_s, knn):
    """Cosine-similarity transition matrices from all three modalities.

    For every cell `i`, the chromatin, unspliced and spliced displacements
    to each one- and two-step neighbor `j` are each divided by the standard
    deviation of cell `i`'s values in that modality, concatenated, and
    compared with the concatenated velocities of cell `i` by cosine
    similarity.

    Parameters
    ----------
    c_mat, u_mat, s_mat: sparse or dense matrices of shape (n_cells, n_genes)
        Chromatin accessibility, unspliced and spliced counts. Must support
        row indexing with an index array.
    velo_c, velo_u, velo_s: arrays of shape (n_cells, n_genes)
        Velocities of the three modalities.
    knn: `numpy.ndarray` of shape (n_cells, k)
        Neighbor indices of every cell.

    Returns
    -------
    tm: `scipy.sparse.csr_matrix`
        Positive similarities, clipped to [0, 1].
    tm_neg: `scipy.sparse.csr_matrix`
        Negative similarities, clipped to [-1, 0].

    Notes
    -----
    Pairs with a zero displacement, zero velocity norm, or a cell whose
    values have zero standard deviation yield non-finite entries.
    """
    def _dense_rows(mat, idx):
        # `.toarray()` is used because the `.A` shorthand is no longer
        # available on SciPy sparse matrices in recent releases.
        rows = mat[idx, :]
        return rows.toarray() if sparse.issparse(rows) else np.asarray(rows)

    knn = knn.astype(int)
    n_cells = c_mat.shape[0]
    modalities = ((c_mat, velo_c), (u_mat, velo_u), (s_mat, velo_s))
    tm_val, tm_col, tm_row = [], [], []
    for i in range(knn.shape[0]):
        # union of direct neighbors and neighbors of neighbors
        two_step_knn = np.unique(np.concatenate(
            (knn[i, :], np.ravel(knn[knn[i, :], :]))))
        # The self pair has zero displacement and the diagonal is defined as
        # zero anyway, so it is left out instead of computing 0/0.
        two_step_knn = two_step_knn[two_step_knn != i]
        if two_step_knn.size == 0:
            continue
        scaled_diffs = []
        velo_parts = []
        for mat, velo_mat in modalities:
            own = _dense_rows(mat, [i])
            others = _dense_rows(mat, two_step_knn)
            # scale by the spread of the cell's own values in this modality
            scaled_diffs.append((own - others) / np.std(own))
            velo_parts.append(np.ravel(_dense_rows(velo_mat, [i])))
        dx = np.hstack(scaled_diffs)
        velo = np.hstack(velo_parts)
        cos_sim = dx.dot(velo) / (norm(dx, axis=1) * norm(velo))
        tm_val.extend(cos_sim.tolist())
        tm_col.extend(two_step_knn.tolist())
        tm_row.extend([i] * two_step_knn.size)
    tm = coo_matrix((tm_val, (tm_row, tm_col)),
                    shape=(n_cells, n_cells)).tocsr()
    tm_neg = tm.copy()
    tm.data = np.clip(tm.data, 0, 1)
    tm_neg.data = np.clip(tm_neg.data, -1, 0)
    tm.eliminate_zeros()
    tm_neg.eliminate_zeros()
    return tm, tm_neg
5495
5496
5497
def likelihood_plot(adata,
                    genes=None,
                    figsize=(14, 10),
                    bins=50,
                    pointsize=4
                    ):
    """Likelihood plots.

    This function plots likelihood and variable distributions on a 4x5
    grid: rows 1 and 3 are histograms, rows 2 and 4 are scatter plots of
    each quantity against the fit likelihood.

    Parameters
    ----------
    adata: :class:`~anndata.AnnData`
        Anndata result from dynamics recovery.
    genes: `str`,  list of `str` (default: `None`)
        If `None`, will use all fitted genes.
    figsize: `tuple` (default: (14,10))
        Figure size.
    bins: `int` (default: 50)
        Number of bins for histograms.
    pointsize: `float` (default: 4)
        Point size for scatter plots.
    """
    # Subset once so that every plotted quantity, including the summed
    # spliced counts, refers to the same set of genes.
    subset = adata if genes is None else adata[:, np.array(genes)]
    var = subset.var

    likelihood = var[['fit_likelihood']].values
    rescale_u = var[['fit_rescale_u']].values
    rescale_c = var[['fit_rescale_c']].values
    t_interval1 = var['fit_t_sw1'] + 20 / adata.uns['velo_s_params']['t'] \
        * var['fit_anchor_min_idx'] * var['fit_alignment_scaling']
    t_sw2 = np.clip(var['fit_t_sw2'], None, 20)
    t_sw3 = np.clip(var['fit_t_sw3'], None, 20)
    t_interval3 = t_sw3 - t_sw2
    log_s = np.ravel(np.asarray(
        np.log1p(np.sum(subset.layers['Ms'], axis=0))))
    alpha_c = var[['fit_alpha_c']].values
    alpha = var[['fit_alpha']].values
    beta = var[['fit_beta']].values
    gamma = var[['fit_gamma']].values
    scale_cc = var[['fit_scale_cc']].values

    # (grid row) -> [(label, values), ...] for the five columns
    hist_rows = {
        0: [('likelihood', likelihood),
            ('rescale u', rescale_u),
            ('rescale c', rescale_c),
            ('primed interval', t_interval1.values),
            ('decoupled interval', t_interval3)],
        2: [('alpha c', alpha_c),
            ('alpha', alpha),
            ('beta', beta),
            ('gamma', gamma),
            ('scale cc', scale_cc)],
    }
    scatter_rows = {
        1: [('log spliced', log_s),
            ('rescale u', rescale_u),
            ('rescale c', rescale_c),
            ('primed interval', t_interval1.values),
            ('decoupled interval', t_interval3)],
        3: [('alpha c', alpha_c),
            ('alpha', alpha),
            ('beta', beta),
            ('gamma', gamma),
            ('scale cc', scale_cc)],
    }

    fig, axes = plt.subplots(4, 5, figsize=figsize)
    for grid_row, panels in hist_rows.items():
        for col, (title, values) in enumerate(panels):
            axes[grid_row, col].hist(values, bins=bins)
            axes[grid_row, col].set_title(title)
    for grid_row, panels in scatter_rows.items():
        for col, (label, values) in enumerate(panels):
            axes[grid_row, col].scatter(values, likelihood, s=pointsize)
            axes[grid_row, col].set_xlabel(label)
        # only the left-most panel of a scatter row carries the y label
        axes[grid_row, 0].set_ylabel('likelihood')
    fig.tight_layout()
5587
5588
5589
def pie_summary(adata, genes=None):
    """Summary of directions and models.

    Draws a donut-style pie chart with the number of genes in each
    (pre-determined or specified) direction / model category. Categories
    without any gene are left out.
    `induction`: induction-only genes.
    `repression`: repression-only genes.
    `Model 1`: model 1 complete genes.
    `Model 2`: model 2 complete genes.

    Parameters
    ----------
    adata: :class:`~anndata.AnnData`
        Anndata result from dynamics recovery.
    genes: `str`,  list of `str` (default: `None`)
        If `None`, will use all fitted genes.
    """
    selected = adata.var_names if genes is None else genes

    is_complete = adata.var['fit_direction'] == 'complete'
    is_selected = np.isin(adata.var_names, selected)
    models = adata[:, is_complete & is_selected].var['fit_model'].values
    directions = adata[:, selected].var['fit_direction'].values

    counts = [
        ('induction', np.sum(directions == 'on')),
        ('repression', np.sum(directions == 'off')),
        ('Model 1', np.sum(models == 1)),
        ('Model 2', np.sum(models == 2)),
    ]
    # keep only the categories that actually occur
    present = [(label, count) for label, count in counts if count > 0]
    labels = [label for label, _ in present]
    sizes = [count for _, count in present]

    df = pd.DataFrame({'data': sizes}, index=labels)
    df.plot.pie(y='data', autopct='%1.1f%%', legend=False, startangle=30,
                ylabel='')
    # a white disc on top of the pie turns it into a donut
    hole = plt.Circle((0, 0), 0.8, fc='white')
    plt.gcf().gca().add_artist(hole)
5622
5623
5624
def switch_time_summary(adata, genes=None):
    """Summary of switch times.

    Draws a box plot of the four intervals derived from the fitted switch
    times, each expressed as a fraction of the total time range of 20.
    `primed`: primed intervals.
    `coupled-on`: coupled induction intervals.
    `decoupled`: decoupled intervals.
    `coupled-off`: coupled repression intervals.

    Parameters
    ----------
    adata: :class:`~anndata.AnnData`
        Anndata result from dynamics recovery.
    genes: `str`,  list of `str` (default: `None`)
        If `None`, will use velocity genes.
    """
    selector = adata.var['velo_s_genes'] if genes is None else genes
    switch_cols = ['fit_t_sw1', 'fit_t_sw2', 'fit_t_sw3']
    times = adata[:, selector].var[switch_cols].copy()

    # cap switch times at 20 and blank out negative ones
    times = times.mask(times > 20, 20)
    times = times.mask(times < 0)

    sw1 = times['fit_t_sw1']
    sw2 = times['fit_t_sw2']
    sw3 = times['fit_t_sw3']
    intervals = pd.DataFrame({
        'primed': sw1,
        'coupled-on': sw2 - sw1,
        'decoupled': sw3 - sw2,
        'coupled-off': 20 - sw3,
    })
    # intervals outside (0, 20] are not shown
    intervals = intervals.mask(intervals <= 0)
    intervals = intervals.mask(intervals > 20)
    intervals = intervals / 20

    fig, ax = plt.subplots(figsize=(4, 5))
    ax = sns.boxplot(data=intervals, width=0.5, palette='Set2', ax=ax)
    ax.set_yticks(np.linspace(0, 1, 5))
    ax.set_title('Switch Intervals')
5660
5661
5662
def dynamic_plot(adata,
                 genes,
                 by='expression',
                 color_by='state',
                 gene_time=True,
                 axis_on=True,
                 frame_on=True,
                 show_anchors=True,
                 show_switches=True,
                 downsample=1,
                 full_range=False,
                 figsize=None,
                 pointsize=2,
                 linewidth=1.5,
                 cmap='coolwarm'
                 ):
    """Gene dynamics plot.

    This function plots accessibility, expression, or velocity by time.
    One row of three panels (chromatin, unspliced, spliced) is drawn per
    gene. Genes that are not in `adata.var_names` are reported and skipped;
    nothing is drawn if no requested gene is found.

    Parameters
    ----------
    adata: :class:`~anndata.AnnData`
        Anndata result from dynamics recovery.
    genes: `str`,  list of `str`
        List of genes to plot.
    by: `str` (default: `expression`)
        Plot accessibilities and expressions if `expression`. Plot velocities
        if `velocity`.
    color_by: `str` (default: `state`)
        Color by the four potential states if `state`. Other common values are
        leiden, louvain, celltype, etc.
        If not `state`, the color field must be present in `.uns`, which can
        be pre-computed with `scanpy.pl.scatter`.
        For `state`, red, orange, green, and blue represent state 1, 2, 3, and
        4, respectively.
    gene_time: `bool` (default: `True`)
        Whether to use individual gene fitted time, or shared global latent
        time.
        Mean values of 20 equal sized windows will be connected and shown if
        `gene_time==False`.
    axis_on: `bool` (default: `True`)
        Whether to show axis labels.
    frame_on: `bool` (default: `True`)
        Whether to show plot frames.
    show_anchors: `bool` (default: `True`)
        Whether to display anchors.
    show_switches: `bool` (default: `True`)
        Whether to show switch times. The switch times are indicated by
        vertical dotted line.
    downsample: `int` (default: 1)
        How much to downsample the cells. The remaining number will be
        `1/downsample` of original.
    full_range: `bool` (default: `False`)
        Whether to show the full time range of velocities before smoothing or
        subset to only smoothed range.
    figsize: `tuple` (default: `None`)
        Total figure size.
    pointsize: `float` (default: 2)
        Point size for scatter plots.
    linewidth: `float` (default: 1.5)
        Line width for anchor line or mean line.
    cmap: `str` (default: `coolwarm`)
        Color map for continuous color key.

    Raises
    ------
    ValueError
        If `by` is not `expression` or `velocity`, or if `color_by` is
        neither `state`, a numeric `.obs` column, nor a categorical `.obs`
        column with matching `<color_by>_colors` in `.uns`.
    """
    from pandas.api.types import is_numeric_dtype

    def _dense_1d(mat):
        # `.toarray()` is used because the `.A` shorthand is no longer
        # available on SciPy sparse matrices in recent releases.
        return np.ravel(mat.toarray() if sparse.issparse(mat) else mat)

    if by not in ['expression', 'velocity']:
        raise ValueError('"by" must be either "expression" or "velocity".')
    is_expression = by == 'expression'
    if not is_expression:
        show_switches = False

    if color_by == 'state':
        types = [0, 1, 2, 3]
        colors = ['tab:red', 'tab:orange', 'tab:green', 'tab:blue']
    elif color_by in adata.obs and is_numeric_dtype(adata.obs[color_by]):
        types = None
        colors = adata.obs[color_by].values
    elif color_by in adata.obs \
            and isinstance(adata.obs[color_by].dtype, pd.CategoricalDtype) \
            and color_by+'_colors' in adata.uns.keys():
        types = adata.obs[color_by].cat.categories
        colors = adata.uns[f'{color_by}_colors']
    else:
        raise ValueError('Currently, color key must be a single string of '
                         'either numerical or categorical available in adata '
                         'obs, and the colors of categories can be found in '
                         'adata uns.')

    downsample = np.clip(int(downsample), 1, 10)
    genes = np.atleast_1d(np.array(genes))
    found = np.isin(genes, adata.var_names)
    missing_genes = genes[~found]
    if len(missing_genes) > 0:
        logg.update(f'{missing_genes} not found', v=0)
    genes = genes[found]
    gn = len(genes)
    if gn == 0:
        return
    if not gene_time:
        # anchors live on gene-specific time, so they cannot be overlaid
        show_anchors = False
        global_time = np.array(adata.obs['latent_time'])
        time_window = global_time // 0.05
        time_window = time_window.astype(int)
        time_window[time_window == 20] = 19
    if 'velo_s_params' in adata.uns.keys() and 'outlier' \
            in adata.uns['velo_s_params']:
        outlier = adata.uns['velo_s_params']['outlier']
    else:
        outlier = 99

    # per-panel configuration, in column order: chromatin, unspliced, spliced
    if is_expression:
        layer_keys = ('ATAC', 'Mu', 'Ms')
        anchor_keys = ('fit_anchor_c', 'fit_anchor_u', 'fit_anchor_s')
        title_suffixes = (' ATAC', ' unspliced', ' spliced')
        x_label = 't'
        y_labels = ('c', 'u', 's')
    else:
        layer_keys = ('velo_chrom', 'velo_u', 'velo_s')
        anchor_keys = ('fit_anchor_c_velo', 'fit_anchor_u_velo',
                       'fit_anchor_s_velo')
        title_suffixes = (' chromatin velocity', ' unspliced velocity',
                          ' spliced velocity')
        x_label = '~t'
        y_labels = ('dc/dt', 'du/dt', 'ds/dt')

    fig, axs = plt.subplots(gn, 3, squeeze=False, figsize=(10, 2.3*gn)
                            if figsize is None else figsize)
    fig.patch.set_facecolor('white')
    for row, gene in enumerate(genes):
        gene_view = adata[:, gene]
        c, u, s = (_dense_1d(gene_view.layers[key]) for key in layer_keys)
        # drop cells above the outlier percentile in any of the modalities
        non_outlier = c <= np.percentile(c, outlier)
        non_outlier &= u <= np.percentile(u, outlier)
        non_outlier &= s <= np.percentile(s, outlier)
        c, u, s = c[non_outlier], u[non_outlier], s[non_outlier]
        panel_values = (c, u, s)

        time = np.array(gene_view.layers['fit_t'] if gene_time
                        else global_time)
        if not is_expression:
            time = np.reshape(time, (-1, 1))
            time = np.ravel(adata.obsp['_RNA_conn'].dot(time))
        time = time[non_outlier]

        if types is not None:
            # cell labels are looked up once per gene, not once per category
            if color_by == 'state':
                labels = _dense_1d(gene_view.layers['fit_state'])[non_outlier]
            else:
                labels = np.asarray(adata.obs[color_by])[non_outlier]
            for i, group in enumerate(types):
                filt = np.ravel(labels == group)
                if np.sum(filt) > 0:
                    for col, values in enumerate(panel_values):
                        axs[row, col].scatter(time[filt][::downsample],
                                              values[filt][::downsample],
                                              s=pointsize, c=colors[i],
                                              alpha=0.6)
        else:
            for col, values in enumerate(panel_values):
                axs[row, col].scatter(time[::downsample],
                                      values[::downsample],
                                      s=pointsize,
                                      c=colors[non_outlier][::downsample],
                                      alpha=0.6, cmap=cmap)

        if not gene_time:
            windows = time_window[non_outlier]
            window_count = np.zeros(20)
            window_means = np.zeros((3, 20))
            for i in np.unique(windows):
                idx = windows == i
                window_count[i] = np.sum(idx)
                for col, values in enumerate(panel_values):
                    window_means[col, i] = np.mean(values[idx])
            # only windows with more than 20 cells contribute to the line
            window_idx = np.where(window_count > 20)[0]
            for col in range(3):
                axs[row, col].plot(window_idx*0.05+0.025,
                                   window_means[col, window_idx],
                                   linewidth=linewidth, color='black',
                                   alpha=0.5)

        if show_anchors:
            gene_var = gene_view.var
            n_anchors = adata.uns['velo_s_params']['t']
            switch_times = np.array([gene_var['fit_t_sw1'].iloc[0],
                                     gene_var['fit_t_sw2'].iloc[0],
                                     gene_var['fit_t_sw3'].iloc[0]])
            switch_times = switch_times[switch_times < 20]
            # scalars are extracted explicitly; int() on a one-element
            # Series is deprecated in pandas
            anchor_min = gene_var['fit_anchor_min_idx'].iloc[0]
            anchor_max = gene_var['fit_anchor_max_idx'].iloc[0]
            min_idx = int(anchor_min)
            max_idx = int(anchor_max)
            old_t = np.linspace(0, 20, n_anchors)[min_idx:max_idx+1]
            new_t = old_t - np.min(old_t)
            new_t = new_t * 20 / np.max(new_t)
            if by == 'velocity' and not full_range:
                anchor_interval = 20 / (max_idx + 1 - min_idx)
                min_idx = int(gene_var['fit_anchor_velo_min_idx'].iloc[0])
                max_idx = int(gene_var['fit_anchor_velo_max_idx'].iloc[0])
                start = 0 + (min_idx - anchor_min) * anchor_interval
                end = 20 + (max_idx - anchor_max) * anchor_interval
                new_t = np.linspace(start, end, max_idx + 1 - min_idx)
            for col, values in enumerate(panel_values):
                ax = axs[row, col]
                anchors = gene_view.varm[anchor_keys[col]]\
                    .ravel()[min_idx:max_idx+1]
                if show_switches:
                    for t_sw in switch_times:
                        if t_sw > 0:
                            ax.vlines(t_sw, np.min(values), np.max(values),
                                      colors='black', linestyles='dashed',
                                      alpha=0.5)
                ax.plot(new_t, anchors, linewidth=linewidth,
                        color='black', alpha=0.5)

        for col in range(3):
            ax = axs[row, col]
            ax.set_title(f'{gene}' + title_suffixes[col])
            ax.set_xlabel(x_label)
            ax.set_ylabel(y_labels[col])
            if not axis_on:
                ax.xaxis.set_ticks_position('none')
                ax.yaxis.set_ticks_position('none')
                ax.get_xaxis().set_visible(False)
                ax.get_yaxis().set_visible(False)
            if not frame_on:
                ax.xaxis.set_ticks_position('none')
                ax.yaxis.set_ticks_position('none')
                ax.set_frame_on(False)
    fig.tight_layout()
5923
5924
5925
def _scatter_layer_1d(gene_view, key):
    """Return layer `key` of a single-gene view as a flat float array.

    The result is always a fresh dense copy, so callers may rescale it in
    place without touching the data stored in the AnnData object.
    """
    values = gene_view.layers[key]
    if sparse.issparse(values):
        # `toarray()` rather than the `.A` shortcut, which is not available
        # on every SciPy version.
        values = values.toarray()
    return np.array(values, dtype=float).ravel()


def _scatter_velocity_1d(gene_view, keys, n_obs):
    """Return the first available velocity layer scaled to unit magnitude.

    `keys` are tried in order; if none of them is a layer, a zero vector of
    length `n_obs` is used so that arrows can still be drawn.
    """
    for key in keys:
        if key in gene_view.layers.keys():
            velo = _scatter_layer_1d(gene_view, key)
            break
    else:
        velo = np.zeros(n_obs)
    return velo / np.max([np.max(np.abs(velo)), 1e-6])


def _scatter_plane(by, s_vals, u_vals, c_vals):
    """Order an (s, u, c) triplet into the plotted axes of plane `by`."""
    if by == 'us':
        return [s_vals, u_vals]
    if by == 'cu':
        return [u_vals, c_vals]
    return [s_vals, u_vals, c_vals]


def _scatter_anchors(gene_view, suffix, window, scales):
    """Return the (s, u, c) anchor arrays of a gene, sliced and rescaled.

    `suffix` selects the `.varm` family (`''` or `'_sw'`), `window` is the
    slice applied to each flattened array and `scales` are the (s, u, c)
    divisors (all 1 when the cell values were not normalized).
    """
    curves = []
    for name, scale in zip(('s', 'u', 'c'), scales):
        values = np.array(gene_view.varm[f'fit_anchor_{name}{suffix}'],
                          dtype=float).ravel()[window]
        curves.append(values / scale)
    return curves


def scatter_plot(adata,
                 genes,
                 by='us',
                 color_by='state',
                 n_cols=5,
                 axis_on=True,
                 frame_on=True,
                 show_anchors=True,
                 show_switches=True,
                 show_all_anchors=False,
                 title_more_info=False,
                 velocity_arrows=False,
                 downsample=1,
                 figsize=None,
                 pointsize=2,
                 markersize=5,
                 linewidth=2,
                 cmap='coolwarm',
                 view_3d_elev=None,
                 view_3d_azim=None,
                 full_name=False
                 ):
    """Gene scatter plot.

    This function plots phase portraits of the specified plane.

    Parameters
    ----------
    adata: :class:`~anndata.AnnData`
        Anndata result from dynamics recovery.
    genes: `str`,  list of `str`
        List of genes to plot. Genes missing from `adata.var_names` are
        reported and skipped.
    by: `str` (default: `us`)
        Plot unspliced-spliced plane if `us`. Plot chromatin-unspliced plane
        if `cu`.
        Plot 3D phase portraits if `cus`.
    color_by: `str` (default: `state`)
        Color by the four potential states if `state`. Other common values are
        leiden, louvain, celltype, etc.
        If not `state`, the key must be a numerical or categorical column of
        `.obs`; for a categorical column the colors are read from
        `.uns['<color_by>_colors']`, which can be pre-computed with
        `scanpy.pl.scatter`.
        For `state`, red, orange, green, and blue represent state 1, 2, 3, and
        4, respectively.
        When `by=='us'`, `color_by` can also be `c`, which displays the log
        accessibility on U-S phase portraits.
    n_cols: `int` (default: 5)
        Number of columns to plot on each row.
    axis_on: `bool` (default: `True`)
        Whether to show axis labels.
    frame_on: `bool` (default: `True`)
        Whether to show plot frames.
    show_anchors: `bool` (default: `True`)
        Whether to display anchors. Silently disabled when the fitted anchors
        or a chromatin layer (`ATAC` or `Mc`) are not available.
    show_switches: `bool` (default: `True`)
        Whether to show switch times. The three switch times and the end of
        trajectory are indicated by
        circle, cross, diamond, and star, respectively.
    show_all_anchors: `bool` (default: `False`)
        Whether to display full range of (predicted) anchors even for
        repression-only genes.
    title_more_info: `bool` (default: `False`)
        Whether to display model, direction, and likelihood information for
        the gene in title.
    velocity_arrows: `bool` (default: `False`)
        Whether to show velocity arrows of cells on the phase portraits.
        Values and velocities are then rescaled by their maxima; a missing
        velocity layer is drawn as zero velocity.
    downsample: `int` (default: 1)
        How much to downsample the cells. The remaining number will be
        `1/downsample` of original. Clipped to the range 1 to 10.
    figsize: `tuple` (default: `None`)
        Total figure size.
    pointsize: `float` (default: 2)
        Point size for scatter plots.
    markersize: `float` (default: 5)
        Point size for switch time points.
    linewidth: `float` (default: 2)
        Line width for connected anchors.
    cmap: `str` (default: `coolwarm`)
        Color map for log accessibilities or other continuous color keys when
        plotting on U-S plane.
    view_3d_elev: `float` (default: `None`)
        Matplotlib 3D plot `elev` argument. `elev=90` is the same as U-S plane,
        and `elev=0` is the same as C-U plane.
    view_3d_azim: `float` (default: `None`)
        Matplotlib 3D plot `azim` argument. `azim=270` is the same as U-S
        plane, and `azim=0` is the same as C-U plane.
    full_name: `bool` (default: `False`)
        Show full names for chromatin, unspliced, and spliced rather than
        using abbreviated terms c, u, and s.

    Raises
    ------
    ValueError
        If `by` is not one of `us`, `cu`, `cus`; if `color_by` cannot be
        resolved to colors; if `color_by=='state'` but `fit_state` is not a
        layer; or if a chromatin layer is needed for the requested plot but
        neither `ATAC` nor `Mc` exists.
    """
    from pandas.api.types import is_numeric_dtype

    if by not in ['us', 'cu', 'cus']:
        raise ValueError("'by' argument must be one of ['us', 'cu', 'cus']")

    # Resolve the coloring mode. `types` is the list of categories to draw
    # one by one, or None for continuous coloring.
    colors = None
    if color_by == 'state':
        types = [0, 1, 2, 3]
        colors = ['tab:red', 'tab:orange', 'tab:green', 'tab:blue']
    elif by == 'us' and color_by == 'c':
        types = None
    elif color_by in adata.obs and is_numeric_dtype(adata.obs[color_by]):
        types = None
        colors = adata.obs[color_by].values
    elif color_by in adata.obs \
            and isinstance(adata.obs[color_by].dtype, pd.CategoricalDtype) \
            and color_by+'_colors' in adata.uns.keys():
        types = adata.obs[color_by].cat.categories
        colors = adata.uns[f'{color_by}_colors']
    else:
        raise ValueError('Currently, color key must be a single string of '
                         'either numerical or categorical available in adata'
                         ' obs, and the colors of categories can be found in'
                         ' adata uns.')

    if 'velo_s_params' not in adata.uns.keys() \
            or 'fit_anchor_s' not in adata.varm.keys():
        show_anchors = False
    if color_by == 'state' and 'fit_state' not in adata.layers.keys():
        raise ValueError('fit_state is not found. Please run '
                         'recover_dynamics_chrom function first or provide a '
                         'valid color key.')

    downsample = np.clip(int(downsample), 1, 10)
    genes = np.atleast_1d(genes)
    present = np.isin(genes, adata.var_names)
    missing_genes = genes[~present]
    if len(missing_genes) > 0:
        logg.update(f'{missing_genes} not found', v=0)
    genes = genes[present]
    gn = len(genes)
    if gn == 0:
        return
    if gn < n_cols:
        n_cols = gn
    n_rows = -(-gn // n_cols)  # ceiling division

    # Pick the chromatin layer once; 'ATAC' takes precedence over 'Mc'.
    if 'ATAC' in adata.layers.keys():
        chrom_key = 'ATAC'
    elif 'Mc' in adata.layers.keys():
        chrom_key = 'Mc'
    else:
        chrom_key = None
        show_anchors = False
        if by != 'us' or color_by == 'c':
            # Fail early with a clear message instead of a NameError later.
            raise ValueError("Chromatin accessibility layer ('ATAC' or "
                             "'Mc') is not found, but it is required when "
                             "'by' is 'cu' or 'cus' or when 'color_by' is "
                             "'c'.")
    u_key = 'Mu' if 'Mu' in adata.layers else 'unspliced'
    s_key = 'Ms' if 'Ms' in adata.layers else 'spliced'

    if by == 'cus':
        fig, axs = plt.subplots(n_rows, n_cols, squeeze=False,
                                figsize=(3.2*n_cols, 2.7*n_rows)
                                if figsize is None else figsize,
                                subplot_kw={'projection': '3d'})
        quiver_kw = dict(alpha=0.4, length=0.1, arrow_length_ratio=0.5,
                         normalize=True)
    else:
        fig, axs = plt.subplots(n_rows, n_cols, squeeze=False,
                                figsize=(2.7*n_cols, 2.4*n_rows)
                                if figsize is None else figsize)
        quiver_kw = dict(alpha=0.5, scale_units='xy', scale=10, width=0.005,
                         headwidth=4, headaxislength=5.5)
    fig.patch.set_facecolor('white')

    for count, gene in enumerate(genes):
        gene_view = adata[:, gene]
        u = _scatter_layer_1d(gene_view, u_key)
        s = _scatter_layer_1d(gene_view, s_key)
        c = None
        if chrom_key is not None:
            c = _scatter_layer_1d(gene_view, chrom_key)

        # With arrows, positions and velocities are both normalized so that
        # arrow lengths are comparable across genes; the same scales are
        # later applied to the anchors.
        vu = vs = vc = None
        max_u = max_s = max_c = 1.0
        if velocity_arrows:
            max_u = np.max([np.max(u), 1e-6])
            u /= max_u
            vu = _scatter_velocity_1d(gene_view, ('velo_u', 'velocity_u'),
                                      adata.n_obs)
            max_s = np.max([np.max(s), 1e-6])
            s /= max_s
            vs = _scatter_velocity_1d(gene_view, ('velo_s', 'velocity'),
                                      adata.n_obs)
            if c is not None:
                max_c = np.max([np.max(c), 1e-6])
                c /= max_c
                vc = _scatter_velocity_1d(gene_view, ('velo_chrom',),
                                          adata.n_obs)

        row, col = divmod(count, n_cols)
        ax = axs[row, col]
        if types is not None:
            # Discrete coloring: one scatter/quiver call per category.
            for i, category in enumerate(types):
                if color_by == 'state':
                    filt = gene_view.layers['fit_state'] == category
                else:
                    filt = adata.obs[color_by] == category
                filt = np.ravel(filt)
                pos = [v[filt][::downsample]
                       for v in _scatter_plane(by, s, u, c)]
                if velocity_arrows:
                    vel = [v[filt][::downsample]
                           for v in _scatter_plane(by, vs, vu, vc)]
                    ax.quiver(*pos, *vel, color=colors[i], **quiver_kw)
                else:
                    ax.scatter(*pos, s=pointsize, c=colors[i], alpha=0.7)
        elif color_by == 'c':
            # U-S portrait colored by log accessibility; zero and outlier
            # cells are hidden.
            if 'velo_s_params' in adata.uns.keys() and \
                    'outlier' in adata.uns['velo_s_params']:
                outlier = adata.uns['velo_s_params']['outlier']
            else:
                outlier = 99.8
            keep = (u > 0) & (s > 0) & (c > 0)
            keep &= u < np.percentile(u, outlier)
            keep &= s < np.percentile(s, outlier)
            keep &= c < np.percentile(c, outlier)
            c_scaled = c - np.min(c)
            c_range = np.max(c_scaled)
            if c_range > 0:
                # Skip the division for constant accessibility (avoids NaN).
                c_scaled = c_scaled / c_range
            log_c = np.log1p(c_scaled[keep][::downsample])
            if velocity_arrows:
                ax.quiver(s[keep][::downsample], u[keep][::downsample],
                          vs[keep][::downsample], vu[keep][::downsample],
                          log_c, alpha=0.5,
                          scale_units='xy', scale=10, width=0.005,
                          headwidth=4, headaxislength=5.5, cmap=cmap)
            else:
                ax.scatter(s[keep][::downsample], u[keep][::downsample],
                           s=pointsize, c=log_c, alpha=0.8, cmap=cmap)
        else:
            # Continuous coloring by a numerical `.obs` column.
            pos = [v[::downsample] for v in _scatter_plane(by, s, u, c)]
            if velocity_arrows:
                vel = [v[::downsample]
                       for v in _scatter_plane(by, vs, vu, vc)]
                ax.quiver(*pos, *vel, colors[::downsample], cmap=cmap,
                          **quiver_kw)
            else:
                ax.scatter(*pos, s=pointsize, c=colors[::downsample],
                           alpha=0.7, cmap=cmap)

        if show_anchors:
            gene_var = gene_view.var
            # `.iloc[0]` rather than `int(Series)`, which newer pandas
            # versions warn about.
            min_idx = int(gene_var['fit_anchor_min_idx'].iloc[0])
            max_idx = int(gene_var['fit_anchor_max_idx'].iloc[0])
            scales = (max_s, max_u, max_c)
            a_s, a_u, a_c = _scatter_anchors(
                gene_view, '', slice(min_idx, max_idx+1), scales)
            ax.plot(*_scatter_plane(by, a_s, a_u, a_c), linewidth=linewidth,
                    color='black', alpha=0.7, zorder=1000)
            if show_all_anchors:
                # Anchors before the fitted range, drawn slightly thinner.
                pre_s, pre_u, pre_c = _scatter_anchors(
                    gene_view, '', slice(None, min_idx), scales)
                if len(pre_c) > 0:
                    ax.plot(*_scatter_plane(by, pre_s, pre_u, pre_c),
                            linewidth=linewidth/1.3, color='black',
                            alpha=0.6, zorder=1000)
            if show_switches:
                t_sw_array = np.array([gene_var[f'fit_t_sw{k}'].values[0]
                                       for k in (1, 2, 3)])
                in_range = (t_sw_array > 0) & (t_sw_array < 20)
                sw_s, sw_u, sw_c = _scatter_anchors(
                    gene_view, '_sw', slice(None), scales)
                # Circle, cross and diamond mark switch 1, 2 and 3.
                for k, marker in enumerate(("om", "Xm", "Dm")):
                    if in_range[k]:
                        ax.plot(*_scatter_plane(by, [sw_s[k]], [sw_u[k]],
                                                [sw_c[k]]),
                                marker, markersize=markersize, zorder=2000)
                # Star marks the end of the trajectory when the anchors
                # reach (nearly) the last time point.
                if max_idx > adata.uns['velo_s_params']['t'] - 4:
                    ax.plot(*_scatter_plane(by, [a_s[-1]], [a_u[-1]],
                                            [a_c[-1]]),
                            "*m", markersize=markersize, zorder=2000)

        if by == 'cus' and \
                (view_3d_elev is not None or view_3d_azim is not None):
            # US: elev=90, azim=270. CU: elev=0, azim=0.
            ax.view_init(elev=view_3d_elev, azim=view_3d_azim)

        title = gene
        if title_more_info:
            gene_var = gene_view.var
            if 'fit_model' in adata.var:
                title += f" M{int(gene_var['fit_model'].values[0])}"
            if 'fit_direction' in adata.var:
                title += f" {gene_var['fit_direction'].values[0]}"
            if 'fit_likelihood' in adata.var \
                    and not np.all(adata.var['fit_likelihood'].values == -1):
                # Append the value itself; a bare f-string on its own line
                # would be evaluated and discarded.
                title += f" {gene_var['fit_likelihood'].values[0]:.3g}"
        ax.set_title(f'{title}', fontsize=11)

        if full_name:
            names = ('spliced', 'unspliced', 'chromatin')
        else:
            names = ('s', 'u', 'c')
        setters = [ax.set_xlabel, ax.set_ylabel]
        if by == 'cus':
            setters.append(ax.set_zlabel)
        for set_label, name in zip(setters, _scatter_plane(by, *names)):
            set_label(name)

        if by in ['us', 'cu']:
            if not axis_on:
                ax.xaxis.set_ticks_position('none')
                ax.yaxis.set_ticks_position('none')
                ax.get_xaxis().set_visible(False)
                ax.get_yaxis().set_visible(False)
            if not frame_on:
                ax.xaxis.set_ticks_position('none')
                ax.yaxis.set_ticks_position('none')
                ax.set_frame_on(False)
        else:
            axes_3d = (ax.xaxis, ax.yaxis, ax.zaxis)
            if not axis_on:
                ax.set_xlabel('')
                ax.set_ylabel('')
                ax.set_zlabel('')
                for axis in axes_3d:
                    axis.set_ticklabels([])
            if not frame_on:
                # Make the 3D grid transparent and collapse the tick marks.
                for axis in axes_3d:
                    axis._axinfo['grid']['color'] = (1, 1, 1, 0)
                    axis._axinfo['tick']['inward_factor'] = 0
                    axis._axinfo['tick']['outward_factor'] = 0

    # Remove the unused axes at the end of the last row.
    for i in range(col+1, n_cols):
        fig.delaxes(axs[row, i])
    fig.tight_layout()