|
a |
|
b/src/multivelo/dynamical_chrom_func.py |
|
|
1 |
from multivelo import mv_logging as logg |
|
|
2 |
from multivelo import settings |
|
|
3 |
|
|
|
4 |
import os |
|
|
5 |
import sys |
|
|
6 |
import numpy as np |
|
|
7 |
from numpy.linalg import norm |
|
|
8 |
import matplotlib.pyplot as plt |
|
|
9 |
from scipy import sparse |
|
|
10 |
from scipy.sparse import coo_matrix |
|
|
11 |
from scipy.optimize import minimize |
|
|
12 |
from scipy.spatial import KDTree |
|
|
13 |
from sklearn.metrics import pairwise_distances |
|
|
14 |
from sklearn.mixture import GaussianMixture |
|
|
15 |
from scanpy import Neighbors |
|
|
16 |
import scvelo as scv |
|
|
17 |
import pandas as pd |
|
|
18 |
import seaborn as sns |
|
|
19 |
from numba import njit |
|
|
20 |
import numba |
|
|
21 |
from numba.typed import List |
|
|
22 |
from tqdm.auto import tqdm |
|
|
23 |
from joblib import Parallel, delayed |
|
|
24 |
import math |
|
|
25 |
import torch |
|
|
26 |
from torch import nn |
|
|
27 |
|
|
|
28 |
# Put the directory one level above this module on the import path. It is
# appended, so it is searched after every entry already on sys.path.
current_path = os.path.dirname(__file__)
src_path = os.path.join(current_path, os.pardir)
sys.path.append(src_path)
|
|
31 |
|
|
|
32 |
|
|
|
33 |
# a funciton to check for invalid values of different parameters |
|
|
34 |
def check_params(alpha_c, |
|
|
35 |
alpha, |
|
|
36 |
beta, |
|
|
37 |
gamma, |
|
|
38 |
c0=None, |
|
|
39 |
u0=None, |
|
|
40 |
s0=None): |
|
|
41 |
|
|
|
42 |
new_alpha_c = alpha_c |
|
|
43 |
new_alpha = alpha |
|
|
44 |
new_beta = beta |
|
|
45 |
new_gamma = gamma |
|
|
46 |
|
|
|
47 |
new_c0 = c0 |
|
|
48 |
new_u0 = u0 |
|
|
49 |
new_s0 = s0 |
|
|
50 |
|
|
|
51 |
inf_fix = 1e10 |
|
|
52 |
zero_fix = 1e-10 |
|
|
53 |
|
|
|
54 |
# check if any of our parameters are infinite |
|
|
55 |
if c0 is not None and math.isinf(c0): |
|
|
56 |
logg.error("c0 is infinite.", v=1) |
|
|
57 |
new_c0 = inf_fix |
|
|
58 |
if u0 is not None and math.isinf(u0): |
|
|
59 |
logg.error("u0 is infinite.", v=1) |
|
|
60 |
new_u0 = inf_fix |
|
|
61 |
if s0 is not None and math.isinf(s0): |
|
|
62 |
logg.error("s0 is infinite.", v=1) |
|
|
63 |
new_s0 = inf_fix |
|
|
64 |
if math.isinf(alpha_c): |
|
|
65 |
new_alpha_c = inf_fix |
|
|
66 |
logg.error("alpha_c is infinite.", v=1) |
|
|
67 |
if math.isinf(alpha): |
|
|
68 |
new_alpha = inf_fix |
|
|
69 |
logg.error("alpha is infinite.", v=1) |
|
|
70 |
if math.isinf(beta): |
|
|
71 |
new_beta = inf_fix |
|
|
72 |
logg.error("beta is infinite.", v=1) |
|
|
73 |
if math.isinf(gamma): |
|
|
74 |
new_gamma = inf_fix |
|
|
75 |
logg.error("gamma is infinite.", v=1) |
|
|
76 |
|
|
|
77 |
# check if any of our parameters are nan |
|
|
78 |
if c0 is not None and math.isnan(c0): |
|
|
79 |
logg.error("c0 is Nan.", v=1) |
|
|
80 |
new_c0 = zero_fix |
|
|
81 |
if u0 is not None and math.isnan(u0): |
|
|
82 |
logg.error("u0 is Nan.", v=1) |
|
|
83 |
new_u0 = zero_fix |
|
|
84 |
if s0 is not None and math.isnan(s0): |
|
|
85 |
logg.error("s0 is Nan.", v=1) |
|
|
86 |
new_s0 = zero_fix |
|
|
87 |
if math.isnan(alpha_c): |
|
|
88 |
new_alpha_c = zero_fix |
|
|
89 |
logg.error("alpha_c is Nan.", v=1) |
|
|
90 |
if math.isnan(alpha): |
|
|
91 |
new_alpha = zero_fix |
|
|
92 |
logg.error("alpha is Nan.", v=1) |
|
|
93 |
if math.isnan(beta): |
|
|
94 |
new_beta = zero_fix |
|
|
95 |
logg.error("beta is Nan.", v=1) |
|
|
96 |
if math.isnan(gamma): |
|
|
97 |
new_gamma = zero_fix |
|
|
98 |
logg.error("gamma is Nan.", v=1) |
|
|
99 |
|
|
|
100 |
# check if any of our rate parameters are 0 |
|
|
101 |
if alpha_c < 1e-7: |
|
|
102 |
new_alpha_c = zero_fix |
|
|
103 |
logg.error("alpha_c is zero.", v=1) |
|
|
104 |
if alpha < 1e-7: |
|
|
105 |
new_alpha = zero_fix |
|
|
106 |
logg.error("alpha is zero.", v=1) |
|
|
107 |
if beta < 1e-7: |
|
|
108 |
new_beta = zero_fix |
|
|
109 |
logg.error("beta is zero.", v=1) |
|
|
110 |
if gamma < 1e-7: |
|
|
111 |
new_gamma = zero_fix |
|
|
112 |
logg.error("gamma is zero.", v=1) |
|
|
113 |
|
|
|
114 |
if beta == alpha_c: |
|
|
115 |
new_beta += zero_fix |
|
|
116 |
logg.error("alpha_c and beta are equal, leading to divide by zero", |
|
|
117 |
v=1) |
|
|
118 |
if beta == gamma: |
|
|
119 |
new_gamma += zero_fix |
|
|
120 |
logg.error("gamma and beta are equal, leading to divide by zero", |
|
|
121 |
v=1) |
|
|
122 |
if alpha_c == gamma: |
|
|
123 |
new_gamma += zero_fix |
|
|
124 |
logg.error("gamma and alpha_c are equal, leading to divide by zero", |
|
|
125 |
v=1) |
|
|
126 |
|
|
|
127 |
if c0 is not None and u0 is not None and s0 is not None: |
|
|
128 |
return new_alpha_c, new_alpha, new_beta, new_gamma, new_c0, new_u0, \ |
|
|
129 |
new_s0 |
|
|
130 |
|
|
|
131 |
return new_alpha_c, new_alpha, new_beta, new_gamma |
|
|
132 |
|
|
|
133 |
|
|
|
134 |
@njit(
    locals={
        "res": numba.types.float64[:, ::1],
        "eat": numba.types.float64[::1],
        "ebt": numba.types.float64[::1],
        "egt": numba.types.float64[::1],
    },
    fastmath=True)
def predict_exp(tau,
                c0,
                u0,
                s0,
                alpha_c,
                alpha,
                beta,
                gamma,
                scale_cc=1,
                pred_r=True,
                chrom_open=True,
                backward=False,
                rna_only=False):
    """Evaluate the closed-form (c, u, s) solution of one segment at ``tau``.

    Parameters
    ----------
    tau : 1-D float array
        Time offsets from the segment start (negated when ``backward``).
    c0, u0, s0 : float
        State at ``tau == 0``.
    alpha_c, alpha, beta, gamma : float
        Rate parameters. The solution divides by ``beta - alpha_c``,
        ``gamma - beta`` and ``gamma - alpha_c``; callers must keep these
        nonzero.
    scale_cc : float
        Multiplies ``alpha_c`` when chromatin is closing.
    pred_r : bool
        If False, the u and s columns are zero.
    chrom_open : bool
        c relaxes toward 1 when True, toward 0 otherwise.
    backward : bool
        Evaluate at ``-tau``.
    rna_only : bool
        Force ``c0 = 1`` with target 1, so c stays at 1.

    Returns
    -------
    ndarray of shape (len(tau), 3)
        Columns c, u, s. The array has shape ``(0, 3)`` if ``tau`` is empty.
    """
    if len(tau) == 0:
        return np.empty((0, 3))
    if backward:
        tau = -tau
    res = np.empty((len(tau), 3))

    if rna_only:
        kc = 1
        c0 = 1
    else:
        if chrom_open:
            kc = 1
        else:
            kc = 0
            alpha_c *= scale_cc

    # Build the exponentials only after alpha_c has been scaled. c, the
    # (beta - alpha_c) coefficient in u, and the (gamma - alpha_c) term in s
    # must all use the same chromatin rate, which is the rate
    # velocity_equations uses.
    eat = np.exp(-alpha_c * tau)
    ebt = np.exp(-beta * tau)
    egt = np.exp(-gamma * tau)

    const = (kc - c0) * alpha / (beta - alpha_c)

    res[:, 0] = kc - (kc - c0) * eat

    if pred_r:

        res[:, 1] = u0 * ebt + (alpha * kc / beta) * (1 - ebt)
        res[:, 1] += const * (ebt - eat)

        res[:, 2] = s0 * egt + (alpha * kc / gamma) * (1 - egt)
        res[:, 2] += ((beta / (gamma - beta)) *
                      ((alpha * kc / beta) - u0 - const) * (egt - ebt))
        res[:, 2] += (beta / (gamma - alpha_c)) * const * (egt - eat)

    else:
        res[:, 1] = np.zeros(len(tau))
        res[:, 2] = np.zeros(len(tau))
    return res
|
|
192 |
|
|
|
193 |
|
|
|
194 |
@njit(locals={
    "exp_sw1": numba.types.float64[:, ::1],
    "exp_sw2": numba.types.float64[:, ::1],
    "exp_sw3": numba.types.float64[:, ::1],
    "exp1": numba.types.float64[:, ::1],
    "exp2": numba.types.float64[:, ::1],
    "exp3": numba.types.float64[:, ::1],
    "exp4": numba.types.float64[:, ::1],
    "tau_sw1": numba.types.float64[::1],
    "tau_sw2": numba.types.float64[::1],
    "tau_sw3": numba.types.float64[::1],
    "tau1": numba.types.float64[::1],
    "tau2": numba.types.float64[::1],
    "tau3": numba.types.float64[::1],
    "tau4": numba.types.float64[::1]
    },
    fastmath=True)
def generate_exp(tau_list,
                 t_sw_array,
                 alpha_c,
                 alpha,
                 beta,
                 gamma,
                 scale_cc=1,
                 model=1,
                 rna_only=False):
    """Predict (c, u, s) on each kinetic segment and at each switch point.

    Segment 1 starts from (0, 0, 0). Each later segment starts from the state
    reached at the end of the previous one. Every evaluation is done by
    ``predict_exp``. ``model`` selects per-segment settings:

    - model 0: (open, pred_r=False), (closed, pred_r=False), (closed),
      (closed, alpha=0)
    - model 1: (open, pred_r=False), (open), (closed), (closed, alpha=0)
    - model 2: (open, pred_r=False), (open), (open, alpha=0),
      (closed, alpha=0)

    Parameters
    ----------
    tau_list : typed list of 1-D float arrays, or None
        Time offsets within each segment. ``tau_list[k]`` is read only when
        ``len(t_sw_array) >= k``. If None, only switch-point states are
        computed.
    t_sw_array : 1-D array
        Up to three switch times. Segment durations are its successive
        differences.
    alpha_c, alpha, beta, gamma : float
        Rate parameters.
    scale_cc : float
        Forwarded to ``predict_exp``.
    model : int
        0, 1 or 2 (see above). Any other value leaves every output empty.
    rna_only : bool
        Forwarded to ``predict_exp``.

    Returns
    -------
    ((exp1, exp2, exp3, exp4), (exp_sw1, exp_sw2, exp_sw3))
        ``(n, 3)`` arrays of (c, u, s). Slots that are not computed are
        empty ``(0, 3)`` arrays. All four segment slots are empty when
        ``tau_list`` is None.
    """
    # predict_exp divides by (beta - alpha_c), (gamma - beta) and
    # (gamma - alpha_c), so equal rates are pulled apart first.
    if beta == alpha_c:
        beta += 1e-3
    if gamma == beta or gamma == alpha_c:
        gamma += 1e-3
    switch = len(t_sw_array)
    # Duration of each segment that ends at a switch point.
    if switch >= 1:
        tau_sw1 = np.array([t_sw_array[0]])
        if switch >= 2:
            tau_sw2 = np.array([t_sw_array[1] - t_sw_array[0]])
            if switch == 3:
                tau_sw3 = np.array([t_sw_array[2] - t_sw_array[1]])
    exp_sw1, exp_sw2, exp_sw3 = (np.empty((0, 3)),
                                 np.empty((0, 3)),
                                 np.empty((0, 3)))
    if tau_list is None:
        # Only the states at the switch points are requested.
        if model == 0:
            if switch >= 1:
                exp_sw1 = predict_exp(tau_sw1, 0, 0, 0, alpha_c, alpha, beta,
                                      gamma, pred_r=False, scale_cc=scale_cc,
                                      rna_only=rna_only)
                if switch >= 2:
                    exp_sw2 = predict_exp(tau_sw2, exp_sw1[0, 0],
                                          exp_sw1[0, 1], exp_sw1[0, 2],
                                          alpha_c, alpha, beta, gamma,
                                          pred_r=False, chrom_open=False,
                                          scale_cc=scale_cc, rna_only=rna_only)
                    if switch >= 3:
                        exp_sw3 = predict_exp(tau_sw3, exp_sw2[0, 0],
                                              exp_sw2[0, 1], exp_sw2[0, 2],
                                              alpha_c, alpha, beta, gamma,
                                              chrom_open=False,
                                              scale_cc=scale_cc,
                                              rna_only=rna_only)
        elif model == 1:
            if switch >= 1:
                exp_sw1 = predict_exp(tau_sw1, 0, 0, 0, alpha_c, alpha, beta,
                                      gamma, pred_r=False, scale_cc=scale_cc,
                                      rna_only=rna_only)
                if switch >= 2:
                    exp_sw2 = predict_exp(tau_sw2, exp_sw1[0, 0],
                                          exp_sw1[0, 1], exp_sw1[0, 2],
                                          alpha_c, alpha, beta, gamma,
                                          scale_cc=scale_cc, rna_only=rna_only)
                    if switch >= 3:
                        exp_sw3 = predict_exp(tau_sw3, exp_sw2[0, 0],
                                              exp_sw2[0, 1], exp_sw2[0, 2],
                                              alpha_c, alpha, beta, gamma,
                                              chrom_open=False,
                                              scale_cc=scale_cc,
                                              rna_only=rna_only)
        elif model == 2:
            if switch >= 1:
                exp_sw1 = predict_exp(tau_sw1, 0, 0, 0, alpha_c, alpha, beta,
                                      gamma, pred_r=False, scale_cc=scale_cc,
                                      rna_only=rna_only)
                if switch >= 2:
                    exp_sw2 = predict_exp(tau_sw2, exp_sw1[0, 0],
                                          exp_sw1[0, 1], exp_sw1[0, 2],
                                          alpha_c, alpha, beta, gamma,
                                          scale_cc=scale_cc, rna_only=rna_only)
                    if switch >= 3:
                        exp_sw3 = predict_exp(tau_sw3, exp_sw2[0, 0],
                                              exp_sw2[0, 1], exp_sw2[0, 2],
                                              alpha_c, 0, beta, gamma,
                                              scale_cc=scale_cc,
                                              rna_only=rna_only)

        return (np.empty((0, 3)), np.empty((0, 3)), np.empty((0, 3)),
                np.empty((0, 3))), (exp_sw1, exp_sw2, exp_sw3)

    # Offsets of the requested points within each segment.
    tau1 = tau_list[0]
    if switch >= 1:
        tau2 = tau_list[1]
        if switch >= 2:
            tau3 = tau_list[2]
            if switch == 3:
                tau4 = tau_list[3]
    exp1, exp2, exp3, exp4 = (np.empty((0, 3)), np.empty((0, 3)),
                              np.empty((0, 3)), np.empty((0, 3)))
    # In every model the state at switch point k (exp_swk) seeds segment
    # k + 1.
    if model == 0:
        exp1 = predict_exp(tau1, 0, 0, 0, alpha_c, alpha, beta, gamma,
                           pred_r=False, scale_cc=scale_cc, rna_only=rna_only)
        if switch >= 1:
            exp_sw1 = predict_exp(tau_sw1, 0, 0, 0, alpha_c, alpha, beta,
                                  gamma, pred_r=False, scale_cc=scale_cc,
                                  rna_only=rna_only)
            exp2 = predict_exp(tau2, exp_sw1[0, 0], exp_sw1[0, 1],
                               exp_sw1[0, 2], alpha_c, alpha, beta, gamma,
                               pred_r=False, chrom_open=False,
                               scale_cc=scale_cc, rna_only=rna_only)
            if switch >= 2:
                exp_sw2 = predict_exp(tau_sw2, exp_sw1[0, 0], exp_sw1[0, 1],
                                      exp_sw1[0, 2], alpha_c, alpha, beta,
                                      gamma, pred_r=False, chrom_open=False,
                                      scale_cc=scale_cc, rna_only=rna_only)
                exp3 = predict_exp(tau3, exp_sw2[0, 0], exp_sw2[0, 1],
                                   exp_sw2[0, 2], alpha_c, alpha, beta, gamma,
                                   chrom_open=False, scale_cc=scale_cc,
                                   rna_only=rna_only)
                if switch == 3:
                    exp_sw3 = predict_exp(tau_sw3, exp_sw2[0, 0],
                                          exp_sw2[0, 1], exp_sw2[0, 2],
                                          alpha_c, alpha, beta, gamma,
                                          chrom_open=False, scale_cc=scale_cc,
                                          rna_only=rna_only)
                    exp4 = predict_exp(tau4, exp_sw3[0, 0], exp_sw3[0, 1],
                                       exp_sw3[0, 2], alpha_c, 0, beta, gamma,
                                       chrom_open=False, scale_cc=scale_cc,
                                       rna_only=rna_only)
    elif model == 1:
        exp1 = predict_exp(tau1, 0, 0, 0, alpha_c, alpha, beta, gamma,
                           pred_r=False, scale_cc=scale_cc, rna_only=rna_only)
        if switch >= 1:
            exp_sw1 = predict_exp(tau_sw1, 0, 0, 0, alpha_c, alpha, beta,
                                  gamma, pred_r=False, scale_cc=scale_cc,
                                  rna_only=rna_only)
            exp2 = predict_exp(tau2, exp_sw1[0, 0], exp_sw1[0, 1],
                               exp_sw1[0, 2], alpha_c, alpha, beta, gamma,
                               scale_cc=scale_cc, rna_only=rna_only)
            if switch >= 2:
                exp_sw2 = predict_exp(tau_sw2, exp_sw1[0, 0], exp_sw1[0, 1],
                                      exp_sw1[0, 2], alpha_c, alpha, beta,
                                      gamma, scale_cc=scale_cc,
                                      rna_only=rna_only)
                exp3 = predict_exp(tau3, exp_sw2[0, 0], exp_sw2[0, 1],
                                   exp_sw2[0, 2], alpha_c, alpha, beta, gamma,
                                   chrom_open=False, scale_cc=scale_cc,
                                   rna_only=rna_only)
                if switch == 3:
                    exp_sw3 = predict_exp(tau_sw3, exp_sw2[0, 0],
                                          exp_sw2[0, 1], exp_sw2[0, 2],
                                          alpha_c, alpha, beta, gamma,
                                          chrom_open=False, scale_cc=scale_cc,
                                          rna_only=rna_only)
                    exp4 = predict_exp(tau4, exp_sw3[0, 0], exp_sw3[0, 1],
                                       exp_sw3[0, 2], alpha_c, 0, beta, gamma,
                                       chrom_open=False, scale_cc=scale_cc,
                                       rna_only=rna_only)
    elif model == 2:
        exp1 = predict_exp(tau1, 0, 0, 0, alpha_c, alpha, beta, gamma,
                           pred_r=False, scale_cc=scale_cc, rna_only=rna_only)
        if switch >= 1:
            exp_sw1 = predict_exp(tau_sw1, 0, 0, 0, alpha_c, alpha, beta,
                                  gamma, pred_r=False, scale_cc=scale_cc,
                                  rna_only=rna_only)
            exp2 = predict_exp(tau2, exp_sw1[0, 0], exp_sw1[0, 1],
                               exp_sw1[0, 2], alpha_c, alpha, beta, gamma,
                               scale_cc=scale_cc, rna_only=rna_only)
            if switch >= 2:
                exp_sw2 = predict_exp(tau_sw2, exp_sw1[0, 0], exp_sw1[0, 1],
                                      exp_sw1[0, 2], alpha_c, alpha, beta,
                                      gamma, scale_cc=scale_cc,
                                      rna_only=rna_only)
                exp3 = predict_exp(tau3, exp_sw2[0, 0], exp_sw2[0, 1],
                                   exp_sw2[0, 2], alpha_c, 0, beta, gamma,
                                   scale_cc=scale_cc, rna_only=rna_only)
                if switch == 3:
                    exp_sw3 = predict_exp(tau_sw3, exp_sw2[0, 0],
                                          exp_sw2[0, 1], exp_sw2[0, 2],
                                          alpha_c, 0, beta, gamma,
                                          scale_cc=scale_cc, rna_only=rna_only)
                    exp4 = predict_exp(tau4, exp_sw3[0, 0], exp_sw3[0, 1],
                                       exp_sw3[0, 2], alpha_c, 0, beta, gamma,
                                       chrom_open=False, scale_cc=scale_cc,
                                       rna_only=rna_only)
    return (exp1, exp2, exp3, exp4), (exp_sw1, exp_sw2, exp_sw3)
|
|
387 |
|
|
|
388 |
|
|
|
389 |
@njit(locals={
    "exp_sw1": numba.types.float64[:, ::1],
    "exp_sw2": numba.types.float64[:, ::1],
    "exp_sw3": numba.types.float64[:, ::1],
    "exp1": numba.types.float64[:, ::1],
    "exp2": numba.types.float64[:, ::1],
    "exp3": numba.types.float64[:, ::1],
    "exp4": numba.types.float64[:, ::1],
    "tau_sw1": numba.types.float64[::1],
    "tau_sw2": numba.types.float64[::1],
    "tau_sw3": numba.types.float64[::1],
    "tau1": numba.types.float64[::1],
    "tau2": numba.types.float64[::1],
    "tau3": numba.types.float64[::1],
    "tau4": numba.types.float64[::1]
    },
    fastmath=True)
def generate_exp_backward(tau_list, t_sw_array, alpha_c, alpha, beta, gamma,
                          scale_cc=1, model=1):
    """Backward-time counterpart of ``generate_exp`` for up to three segments.

    Every segment is evaluated by ``predict_exp`` with ``backward=True``.
    Segment 1 starts from ``(1e-3, 1e-3, 1e-3)``. Each later segment starts
    from the state reached at the previous switch point, and uses the same
    settings that produced that switch-point state.

    Parameters
    ----------
    tau_list : typed list of 1-D float arrays, or None
        Per-segment time offsets. ``tau_list[k]`` is read only when
        ``len(t_sw_array) >= k``. If None, only switch-point states are
        computed.
    t_sw_array : 1-D array
        Switch times. ``t_sw_array[0]`` and
        ``t_sw_array[1] - t_sw_array[0]`` are the first two segment
        durations.
    alpha_c, alpha, beta, gamma : float
        Rate parameters. beta / gamma are nudged by 1e-3 if they coincide.
    scale_cc : float
        Forwarded to ``predict_exp``.
    model : int
        0, 1 or 2. Selects the per-segment chromatin / transcription setup.

    Returns
    -------
    ((exp1, exp2, exp3), (exp_sw1, exp_sw2))
        ``(n, 3)`` arrays from ``predict_exp``. Slots that are not computed
        are empty ``(0, 3)`` arrays.
    """
    # predict_exp divides by these rate differences.
    if beta == alpha_c:
        beta += 1e-3
    if gamma == beta or gamma == alpha_c:
        gamma += 1e-3
    switch = len(t_sw_array)
    # Define every switch-point slot up front, so the return value is fully
    # defined and typed even when fewer than two switch times are given.
    exp_sw1, exp_sw2 = np.empty((0, 3)), np.empty((0, 3))
    if switch >= 1:
        tau_sw1 = np.array([t_sw_array[0]])
        if switch >= 2:
            tau_sw2 = np.array([t_sw_array[1] - t_sw_array[0]])

    if tau_list is None:
        # Only the states at the switch points are requested.
        if model == 0 or model == 1:
            # Models 0 and 1 share the same first two backward segments.
            if switch >= 1:
                exp_sw1 = predict_exp(tau_sw1, 1e-3, 1e-3, 1e-3, alpha_c, 0,
                                      beta, gamma, scale_cc=scale_cc,
                                      chrom_open=False, backward=True)
                if switch >= 2:
                    exp_sw2 = predict_exp(tau_sw2, exp_sw1[0, 0],
                                          exp_sw1[0, 1], exp_sw1[0, 2],
                                          alpha_c, alpha, beta, gamma,
                                          scale_cc=scale_cc, chrom_open=False,
                                          backward=True)
        elif model == 2:
            if switch >= 1:
                exp_sw1 = predict_exp(tau_sw1, 1e-3, 1e-3, 1e-3, alpha_c, 0,
                                      beta, gamma, scale_cc=scale_cc,
                                      chrom_open=False, backward=True)
                if switch >= 2:
                    exp_sw2 = predict_exp(tau_sw2, exp_sw1[0, 0],
                                          exp_sw1[0, 1], exp_sw1[0, 2],
                                          alpha_c, 0, beta, gamma,
                                          scale_cc=scale_cc, backward=True)
        return (np.empty((0, 3)),
                np.empty((0, 3)),
                np.empty((0, 3))), (exp_sw1, exp_sw2)

    tau1 = tau_list[0]
    if switch >= 1:
        tau2 = tau_list[1]
        if switch >= 2:
            tau3 = tau_list[2]

    exp1, exp2, exp3 = np.empty((0, 3)), np.empty((0, 3)), np.empty((0, 3))
    if model == 0:
        exp1 = predict_exp(tau1, 1e-3, 1e-3, 1e-3, alpha_c, 0, beta, gamma,
                           scale_cc=scale_cc, chrom_open=False, backward=True)
        if switch >= 1:
            exp_sw1 = predict_exp(tau_sw1, 1e-3, 1e-3, 1e-3, alpha_c, 0, beta,
                                  gamma, scale_cc=scale_cc, chrom_open=False,
                                  backward=True)
            exp2 = predict_exp(tau2, exp_sw1[0, 0], exp_sw1[0, 1],
                               exp_sw1[0, 2], alpha_c, alpha, beta, gamma,
                               scale_cc=scale_cc, chrom_open=False,
                               backward=True)
            if switch >= 2:
                exp_sw2 = predict_exp(tau_sw2, exp_sw1[0, 0], exp_sw1[0, 1],
                                      exp_sw1[0, 2], alpha_c, alpha, beta,
                                      gamma, scale_cc=scale_cc,
                                      chrom_open=False, backward=True)
                # Segment 3 is evaluated at its own offsets (tau3), starting
                # from the state at the second switch point.
                exp3 = predict_exp(tau3, exp_sw2[0, 0], exp_sw2[0, 1],
                                   exp_sw2[0, 2], alpha_c, alpha, beta, gamma,
                                   scale_cc=scale_cc, chrom_open=False,
                                   backward=True)
    elif model == 1:
        exp1 = predict_exp(tau1, 1e-3, 1e-3, 1e-3, alpha_c, 0, beta, gamma,
                           scale_cc=scale_cc, chrom_open=False, backward=True)
        if switch >= 1:
            exp_sw1 = predict_exp(tau_sw1, 1e-3, 1e-3, 1e-3, alpha_c, 0, beta,
                                  gamma, scale_cc=scale_cc, chrom_open=False,
                                  backward=True)
            exp2 = predict_exp(tau2, exp_sw1[0, 0], exp_sw1[0, 1],
                               exp_sw1[0, 2], alpha_c, alpha, beta, gamma,
                               scale_cc=scale_cc, chrom_open=False,
                               backward=True)
            if switch >= 2:
                exp_sw2 = predict_exp(tau_sw2, exp_sw1[0, 0], exp_sw1[0, 1],
                                      exp_sw1[0, 2], alpha_c, alpha, beta,
                                      gamma, scale_cc=scale_cc,
                                      chrom_open=False, backward=True)
                exp3 = predict_exp(tau3, exp_sw2[0, 0], exp_sw2[0, 1],
                                   exp_sw2[0, 2], alpha_c, alpha, beta, gamma,
                                   scale_cc=scale_cc, backward=True)
    elif model == 2:
        exp1 = predict_exp(tau1, 1e-3, 1e-3, 1e-3, alpha_c, 0, beta, gamma,
                           scale_cc=scale_cc, chrom_open=False, backward=True)
        if switch >= 1:
            # The end of segment 1 uses segment 1's settings (alpha = 0,
            # closed), matching exp1 and the tau_list-is-None branch.
            exp_sw1 = predict_exp(tau_sw1, 1e-3, 1e-3, 1e-3, alpha_c, 0,
                                  beta, gamma, scale_cc=scale_cc,
                                  chrom_open=False, backward=True)
            exp2 = predict_exp(tau2, exp_sw1[0, 0], exp_sw1[0, 1],
                               exp_sw1[0, 2], alpha_c, 0, beta, gamma,
                               scale_cc=scale_cc, backward=True)
            if switch >= 2:
                exp_sw2 = predict_exp(tau_sw2, exp_sw1[0, 0], exp_sw1[0, 1],
                                      exp_sw1[0, 2], alpha_c, 0, beta, gamma,
                                      scale_cc=scale_cc, backward=True)
                exp3 = predict_exp(tau3, exp_sw2[0, 0], exp_sw2[0, 1],
                                   exp_sw2[0, 2], alpha_c, alpha, beta, gamma,
                                   scale_cc=scale_cc, backward=True)
    return (exp1, exp2, exp3), (exp_sw1, exp_sw2)
|
|
509 |
|
|
|
510 |
|
|
|
511 |
@njit(locals={"res": numba.types.float64[:, ::1]}, fastmath=True)
def ss_exp(alpha_c, alpha, beta, gamma, pred_r=True, chrom_open=True):
    """Return the steady state of one segment as a (1, 3) array (c, u, s).

    - Closed chromatin: all zeros.
    - Open chromatin: c = 1. If ``pred_r`` is True, also u = alpha / beta and
      s = alpha / gamma; otherwise u and s are 0.

    ``alpha_c`` is accepted for a uniform signature but is not used.
    """
    res = np.zeros((1, 3))
    if chrom_open:
        res[0, 0] = 1
        if pred_r:
            res[0, 1] = alpha / beta
            res[0, 2] = alpha / gamma
    return res
|
|
530 |
|
|
|
531 |
|
|
|
532 |
@njit(locals={
    "ss1": numba.types.float64[:, ::1],
    "ss2": numba.types.float64[:, ::1],
    "ss3": numba.types.float64[:, ::1],
    "ss4": numba.types.float64[:, ::1]
    },
    fastmath=True)
def compute_ss_exp(alpha_c, alpha, beta, gamma, model=0):
    """Stack the steady states of the four segments of ``model``.

    Returns a (4, 3) array. Row k is the ``ss_exp`` output for segment k + 1.
    Only segments 2 and 3 depend on ``model`` (0, 1 or 2). Any other model
    value leaves them undefined and fails.
    """
    if model == 0:
        ss2 = ss_exp(alpha_c, alpha, beta, gamma, pred_r=False,
                     chrom_open=False)
        ss3 = ss_exp(alpha_c, alpha, beta, gamma, chrom_open=False)
    elif model == 1:
        ss2 = ss_exp(alpha_c, alpha, beta, gamma)
        ss3 = ss_exp(alpha_c, alpha, beta, gamma, chrom_open=False)
    elif model == 2:
        ss2 = ss_exp(alpha_c, alpha, beta, gamma)
        ss3 = ss_exp(alpha_c, 0, beta, gamma)
    # Segment 1 (open, pred_r=False) and segment 4 (closed, alpha=0) are
    # identical in every model.
    ss1 = ss_exp(alpha_c, alpha, beta, gamma, pred_r=False)
    ss4 = ss_exp(alpha_c, 0, beta, gamma, chrom_open=False)
    return np.vstack((ss1, ss2, ss3, ss4))
|
|
557 |
|
|
|
558 |
|
|
|
559 |
@njit(fastmath=True)
def velocity_equations(c, u, s, alpha_c, alpha, beta, gamma, scale_cc=1,
                       pred_r=True, chrom_open=True, rna_only=False):
    """Evaluate the c / u / s time derivatives element-wise.

    - Open chromatin: dc = alpha_c - alpha_c * c.
    - Closed chromatin: alpha_c is first multiplied by ``scale_cc``, then
      dc = -alpha_c * c.
    - du = alpha * c - beta * u and ds = beta * u - gamma * s. Both are
      zero arrays when ``pred_r`` is False.
    - With ``rna_only``, ``c`` is replaced by ones of length ``len(u)``.

    Returns the tuple (dc, du, ds).
    """
    if rna_only:
        c = np.full(len(u), 1.0)

    if chrom_open:
        dc = alpha_c - alpha_c * c
    else:
        alpha_c *= scale_cc
        dc = -alpha_c * c

    if not pred_r:
        return dc, np.zeros(len(u)), np.zeros(len(u))
    return dc, alpha * c - beta * u, beta * u - gamma * s
|
|
577 |
|
|
|
578 |
|
|
|
579 |
@njit(locals={
    "state0": numba.types.boolean[::1],
    "state1": numba.types.boolean[::1],
    "state2": numba.types.boolean[::1],
    "state3": numba.types.boolean[::1],
    "tau1": numba.types.float64[::1],
    "tau2": numba.types.float64[::1],
    "tau3": numba.types.float64[::1],
    "tau4": numba.types.float64[::1],
    "exp_list": numba.types.Tuple((numba.types.float64[:, ::1],
                                   numba.types.float64[:, ::1],
                                   numba.types.float64[:, ::1],
                                   numba.types.float64[:, ::1])),
    "exp_sw_list": numba.types.Tuple((numba.types.float64[:, ::1],
                                      numba.types.float64[:, ::1],
                                      numba.types.float64[:, ::1])),
    "c": numba.types.float64[::1],
    "u": numba.types.float64[::1],
    "s": numba.types.float64[::1],
    "vc_vec": numba.types.float64[::1],
    "vu_vec": numba.types.float64[::1],
    "vs_vec": numba.types.float64[::1]
    },
    fastmath=True)
def compute_velocity(t,
                     t_sw_array,
                     state,
                     alpha_c,
                     alpha,
                     beta,
                     gamma,
                     rescale_c,
                     rescale_u,
                     scale_cc=1,
                     model=1,
                     total_h=20,
                     rna_only=False):
    """Compute (vc, vu, vs) for every time point in ``t``.

    Each point is assigned to one of four segments. The predicted (c, u, s)
    at that point comes from ``generate_exp``, and the derivatives come from
    ``velocity_equations`` with the segment's settings for ``model``
    (0, 1 or 2).

    Parameters
    ----------
    t : 1-D float array
        Time of each point.
    t_sw_array : 1-D array
        Three switch times. Indices 0-2 are read when ``state`` is None, and
        when offsetting times within segments.
    state : 1-D int array or None
        Segment index (0-3) per point. If None, the segment is derived from
        ``t`` relative to the switch times.
    alpha_c, alpha, beta, gamma : float
        Rate parameters.
    rescale_c, rescale_u : float
        Multipliers applied to the returned vc and vu.
    scale_cc, rna_only : optional
        Forwarded to ``generate_exp`` and ``velocity_equations``.
    model : int
        Segment configuration selector.
    total_h : float
        Only switch times below this value are passed to ``generate_exp``.

    Returns
    -------
    (vc * rescale_c, vu * rescale_u, vs)
    """
    if state is None:
        state0 = t <= t_sw_array[0]
        state1 = (t_sw_array[0] < t) & (t <= t_sw_array[1])
        state2 = (t_sw_array[1] < t) & (t <= t_sw_array[2])
        state3 = t_sw_array[2] < t
    else:
        state0 = np.equal(state, 0)
        state1 = np.equal(state, 1)
        state2 = np.equal(state, 2)
        state3 = np.equal(state, 3)

    # Time offset of each point from the start of its segment.
    tau1 = t[state0]
    tau2 = t[state1] - t_sw_array[0]
    tau3 = t[state2] - t_sw_array[1]
    tau4 = t[state3] - t_sw_array[2]
    tau_list = [tau1, tau2, tau3, tau4]
    switch = np.sum(t_sw_array < total_h)
    # generate_exp is compiled by numba and takes a numba typed List; the
    # comprehension is used only for its side effect of filling that list.
    typed_tau_list = List()
    [typed_tau_list.append(x) for x in tau_list]
    exp_list, exp_sw_list = generate_exp(typed_tau_list,
                                         t_sw_array[:switch],
                                         alpha_c,
                                         alpha,
                                         beta,
                                         gamma,
                                         model=model,
                                         scale_cc=scale_cc,
                                         rna_only=rna_only)

    # Scatter each segment's predictions back to the original point order.
    c = np.empty(len(t))
    u = np.empty(len(t))
    s = np.empty(len(t))
    for i, ii in enumerate([state0, state1, state2, state3]):
        if np.any(ii):
            c[ii] = exp_list[i][:, 0]
            u[ii] = exp_list[i][:, 1]
            s[ii] = exp_list[i][:, 2]

    # Points outside segments 0-3 keep zero velocity.
    vc_vec = np.zeros(len(u))
    vu_vec = np.zeros(len(u))
    vs_vec = np.zeros(len(u))

    # Per-segment derivative settings mirror those used by generate_exp.
    if model == 0:
        if np.any(state0):
            vc_vec[state0], vu_vec[state0], vs_vec[state0] = \
                velocity_equations(c[state0], u[state0], s[state0], alpha_c,
                                   alpha, beta, gamma, pred_r=False,
                                   scale_cc=scale_cc, rna_only=rna_only)
        if np.any(state1):
            vc_vec[state1], vu_vec[state1], vs_vec[state1] = \
                velocity_equations(c[state1], u[state1], s[state1], alpha_c,
                                   alpha, beta, gamma, pred_r=False,
                                   chrom_open=False, scale_cc=scale_cc,
                                   rna_only=rna_only)
        if np.any(state2):
            vc_vec[state2], vu_vec[state2], vs_vec[state2] = \
                velocity_equations(c[state2], u[state2], s[state2], alpha_c,
                                   alpha, beta, gamma, chrom_open=False,
                                   scale_cc=scale_cc, rna_only=rna_only)
        if np.any(state3):
            vc_vec[state3], vu_vec[state3], vs_vec[state3] = \
                velocity_equations(c[state3], u[state3], s[state3], alpha_c, 0,
                                   beta, gamma, chrom_open=False,
                                   scale_cc=scale_cc, rna_only=rna_only)
    elif model == 1:
        if np.any(state0):
            vc_vec[state0], vu_vec[state0], vs_vec[state0] = \
                velocity_equations(c[state0], u[state0], s[state0], alpha_c,
                                   alpha, beta, gamma, pred_r=False,
                                   scale_cc=scale_cc, rna_only=rna_only)
        if np.any(state1):
            vc_vec[state1], vu_vec[state1], vs_vec[state1] = \
                velocity_equations(c[state1], u[state1], s[state1], alpha_c,
                                   alpha, beta, gamma, scale_cc=scale_cc,
                                   rna_only=rna_only)
        if np.any(state2):
            vc_vec[state2], vu_vec[state2], vs_vec[state2] = \
                velocity_equations(c[state2], u[state2], s[state2], alpha_c,
                                   alpha, beta, gamma, chrom_open=False,
                                   scale_cc=scale_cc, rna_only=rna_only)
        if np.any(state3):
            vc_vec[state3], vu_vec[state3], vs_vec[state3] = \
                velocity_equations(c[state3], u[state3], s[state3], alpha_c, 0,
                                   beta, gamma, chrom_open=False,
                                   scale_cc=scale_cc, rna_only=rna_only)
    elif model == 2:
        if np.any(state0):
            vc_vec[state0], vu_vec[state0], vs_vec[state0] = \
                velocity_equations(c[state0], u[state0], s[state0], alpha_c,
                                   alpha, beta, gamma, pred_r=False,
                                   scale_cc=scale_cc, rna_only=rna_only)
        if np.any(state1):
            vc_vec[state1], vu_vec[state1], vs_vec[state1] = \
                velocity_equations(c[state1], u[state1], s[state1], alpha_c,
                                   alpha, beta, gamma, scale_cc=scale_cc,
                                   rna_only=rna_only)
        if np.any(state2):
            vc_vec[state2], vu_vec[state2], vs_vec[state2] = \
                velocity_equations(c[state2], u[state2], s[state2], alpha_c,
                                   0, beta, gamma, scale_cc=scale_cc,
                                   rna_only=rna_only)
        if np.any(state3):
            vc_vec[state3], vu_vec[state3], vs_vec[state3] = \
                velocity_equations(c[state3], u[state3], s[state3], alpha_c, 0,
                                   beta, gamma, chrom_open=False,
                                   scale_cc=scale_cc, rna_only=rna_only)
    return vc_vec * rescale_c, vu_vec * rescale_u, vs_vec
|
|
724 |
|
|
|
725 |
|
|
|
726 |
def log_valid(x):
    """Return ``np.log`` of ``x`` after clipping it into [1e-3, 1 - 1e-3].

    The clipping keeps the result finite for inputs at or beyond 0 and 1.
    """
    eps = 1e-3
    bounded = np.clip(x, eps, 1 - eps)
    return np.log(bounded)
|
|
728 |
|
|
|
729 |
|
|
|
730 |
def approx_tau(u, s, u0, s0, alpha, beta, gamma):
    """Invert the exponential relaxation of (u, s) from (u0, s0) to get tau.

    If ``beta > gamma``, the combined quantity ``s - k*u`` with
    ``k = beta / (gamma - beta)`` is inverted at rate ``gamma``. Otherwise
    ``u`` alone is inverted at rate ``beta``, toward ``u_inf = alpha / beta``.
    Each ratio is passed through ``log_valid``, which clips it into
    [1e-3, 1 - 1e-3].
    """
    if gamma == beta:
        # Avoid the zero denominator in beta / (gamma - beta).
        gamma -= 1e-3
    u_inf = alpha / beta
    if beta > gamma:
        coef = beta / (gamma - beta)
        shifted_s_inf = alpha / gamma - coef * u_inf
        ratio = ((s - coef * u - shifted_s_inf) /
                 (s0 - coef * u0 - shifted_s_inf))
        return -1.0 / gamma * log_valid(ratio)
    return -1.0 / beta * log_valid((u - u_inf) / (u0 - u_inf))
|
|
745 |
|
|
|
746 |
|
|
|
747 |
def _log_warp(tau, span):
    """Respace non-negative offsets ``tau`` by ``expm1`` and rescale them so
    the largest one equals ``span``.

    Empty input is returned unchanged. All-zero input, which has no spread
    to rescale, stays zero instead of producing 0/0 = NaN.
    """
    if len(tau) == 0:
        return tau
    warped = np.expm1(tau)
    peak = np.max(warped)
    if peak == 0:
        return warped
    return warped / peak * span


def anchor_points(t_sw_array, total_h=20, t=1000, mode='uniform',
                  return_time=False):
    """Split an evenly spaced time grid into per-segment offsets.

    Parameters
    ----------
    t_sw_array : sequence of 3 floats
        Switch times that delimit four segments on ``[0, total_h]``.
    total_h : float
        End of the time grid.
    t : int
        Number of grid points (``np.linspace(0, total_h, t)``).
    mode : str
        ``'uniform'`` keeps the grid spacing. ``'log'`` respaces each
        segment's offsets with ``expm1`` and rescales them to span the
        segment's duration.
    return_time : bool
        Also return the grid.

    Returns
    -------
    list of 4 arrays, or (grid, list)
        Offsets of the grid points falling in each segment, measured from
        that segment's start.
    """
    t_ = np.linspace(0, total_h, t)
    tau1 = t_[t_ <= t_sw_array[0]]
    tau2 = t_[(t_sw_array[0] < t_) & (t_ <= t_sw_array[1])] - t_sw_array[0]
    tau3 = t_[(t_sw_array[1] < t_) & (t_ <= t_sw_array[2])] - t_sw_array[1]
    tau4 = t_[t_sw_array[2] < t_] - t_sw_array[2]

    if mode == 'log':
        tau1 = _log_warp(tau1, t_sw_array[0])
        tau2 = _log_warp(tau2, t_sw_array[1] - t_sw_array[0])
        tau3 = _log_warp(tau3, t_sw_array[2] - t_sw_array[1])
        tau4 = _log_warp(tau4, total_h - t_sw_array[2])

    tau_list = [tau1, tau2, tau3, tau4]
    if return_time:
        return t_, tau_list
    return tau_list
|
|
774 |
|
|
|
775 |
|
|
|
776 |
def pairwise_distance_square(X, Y):
    """Squared Euclidean distances between all rows of ``X`` and ``Y``.

    Parameters
    ----------
    X : array of shape (n_x, n_features)
    Y : array of shape (n_y, n_features)

    Returns
    -------
    ndarray of shape (n_x, n_y) with ``X``'s dtype
        Entry ``[a, b]`` is ``sum_i (X[a, i] - Y[b, i]) ** 2``.
    """
    X = np.asarray(X)
    Y = np.asarray(Y)
    acc = np.zeros((X.shape[0], Y.shape[0]))
    # Accumulate one feature at a time. This adds the squared differences in
    # the same order as a per-pair loop would, but each step is vectorized
    # over all (a, b) pairs.
    for i in range(X.shape[1]):
        acc += (X[:, i, None] - Y[None, :, i]) ** 2
    return acc.astype(X.dtype, copy=False)
|
|
787 |
|
|
|
788 |
|
|
|
789 |
def calculate_dist_and_time(c, u, s,
                            t_sw_array,
                            alpha_c, alpha, beta, gamma,
                            rescale_c, rescale_u,
                            scale_cc=1,
                            scale_factor=None,
                            model=1,
                            conn=None,
                            t=1000, k=1,
                            direction='complete',
                            total_h=20,
                            rna_only=False,
                            penalize_gap=True,
                            all_cells=True):
    """Assign every cell to its nearest point on the model trajectory.

    The model curve of each phase is sampled at the anchor times from
    ``anchor_points`` and evaluated with ``generate_exp``. For every phase
    that is not skipped, a KD-tree over the scaled anchor expression is
    queried with the scaled observed expression. The phase with the
    smallest squared distance gives each cell's state and time.

    Parameters
    ----------
    c, u, s : 1-D ndarray
        Observed chromatin, unspliced and spliced values per cell.
    t_sw_array : ndarray
        Switch times; entries ``< total_h`` count as active switches.
    alpha_c, alpha, beta, gamma : float
        Rate parameters, sanitised with ``check_params``.
    rescale_c, rescale_u : float
        Multipliers applied to the model's c and u columns.
    scale_cc : float
        Forwarded to ``generate_exp``.
    scale_factor : ndarray of length 3, optional
        Per-modality divisor applied before measuring distances. Defaults
        to the standard deviations of ``c``, ``u`` and ``s``.
    model : int
        Model variant forwarded to ``generate_exp`` / ``compute_ss_exp``.
    conn : matrix, optional
        If given, per-phase distances are smoothed as ``conn.dot(d)``.
    t : int
        Number of anchor time points.
    k : int
        Nearest anchors per cell. Squared distances and anchor times are
        averaged over the ``k`` neighbours.
    direction : str
        For models 1 and 2, ``'on'`` skips phases 2-3 and ``'off'`` skips
        phases 0-1. Only ``'complete'`` runs the gap-detection step.
    total_h : float
        End of the time axis.
    rna_only : bool
        Use only u and s for distances; phase 0 is skipped.
    penalize_gap : bool
        In ``'complete'`` mode, add a distance penalty to phases after each
        large gap in the sorted predicted times.
    all_cells : bool
        Selects the return signature. When False, anchors are thinned and
        switch-time adjustments are computed.

    Returns
    -------
    If ``all_cells``: ``(min_dist, t_pred, state_pred, reach_ss,
    late_phase, max_u, max_s, anchor_t1_list, anchor_t2_list)``.
    Otherwise: ``(min_dist, state_pred, max_u, max_s, t_sw_adjust)``.
    ``max_u``/``max_s`` are maxima of the rescaled model curves.
    """
    n = len(u)
    if scale_factor is None:
        scale_factor = np.array([np.std(c), np.std(u), np.std(s)])
    tau_list = anchor_points(t_sw_array, total_h, t)
    switch = np.sum(t_sw_array < total_h)
    # generate_exp is called with a numba typed list here
    typed_tau_list = List()
    for tau in tau_list:
        typed_tau_list.append(tau)
    alpha_c, alpha, beta, gamma = check_params(alpha_c, alpha, beta, gamma)
    exp_list, exp_sw_list = generate_exp(typed_tau_list,
                                         t_sw_array[:switch],
                                         alpha_c,
                                         alpha,
                                         beta,
                                         gamma,
                                         model=model,
                                         scale_cc=scale_cc,
                                         rna_only=rna_only)
    rescale_factor = np.array([rescale_c, rescale_u, 1.0])
    exp_list = [x*rescale_factor for x in exp_list]
    exp_sw_list = [x*rescale_factor for x in exp_sw_list]
    max_c = 0
    max_u = 0
    max_s = 0
    if rna_only:
        exp_mat = (np.hstack((np.reshape(u, (-1, 1)), np.reshape(s, (-1, 1))))
                   / scale_factor[1:])
    else:
        exp_mat = np.hstack((np.reshape(c, (-1, 1)), np.reshape(u, (-1, 1)),
                             np.reshape(s, (-1, 1)))) / scale_factor
    # column of exp_mat that holds the scaled spliced values
    s_col = 1 if rna_only else 2

    dists = np.full((n, 4), np.inf)
    taus = np.zeros((n, 4), dtype=u.dtype)
    ts = np.zeros((n, 4), dtype=u.dtype)
    anchor_exp, anchor_t = None, None

    for i in range(switch+1):
        has_anchors = exp_list[i].shape[0] > 0
        if not all_cells:
            max_ci = np.max(exp_list[i][:, 0]) if has_anchors else 0
            max_c = max_ci if max_ci > max_c else max_c
        max_ui = np.max(exp_list[i][:, 1]) if has_anchors else 0
        max_u = max_ui if max_ui > max_u else max_u
        max_si = np.max(exp_list[i][:, 2]) if has_anchors else 0
        max_s = max_si if max_si > max_s else max_s

        skip_phase = False
        if direction == 'off':
            if (model in [1, 2]) and (i < 2):
                skip_phase = True
        elif direction == 'on':
            if (model in [1, 2]) and (i >= 2):
                skip_phase = True
        if rna_only and i == 0:
            skip_phase = True
        if skip_phase:
            continue

        if rna_only:
            tmp = exp_list[i][:, 1:] / scale_factor[1:]
        else:
            tmp = exp_list[i] / scale_factor
        # global (not phase-local) times of this phase's anchors
        phase_t = tau_list[i] + t_sw_array[i-1] if i >= 1 else tau_list[i]
        if anchor_exp is None:
            anchor_exp = exp_list[i]
            anchor_t = phase_t
        else:
            anchor_exp = np.vstack((anchor_exp, exp_list[i]))
            anchor_t = np.hstack((anchor_t, phase_t))

        if not all_cells:
            # Drop anchors that lie within 1% of the largest scaled spliced
            # value of their predecessor; every third anchor is always kept.
            anchor_dist = np.diff(tmp, axis=0,
                                  prepend=np.zeros((1, tmp.shape[1])))
            anchor_dist = np.sqrt((anchor_dist**2).sum(axis=1))
            remove_cand = anchor_dist < 0.01 * np.max(exp_mat[:, s_col])
            step_idx = np.arange(len(anchor_dist)) % 3 > 0
            remove_cand &= step_idx
            keep_idx = np.where(~remove_cand)[0]
            tmp = tmp[keep_idx, :]

        tree = KDTree(tmp)
        dd, ii = tree.query(exp_mat, k=k)
        dd = dd**2
        if k > 1:
            dd = np.mean(dd, axis=1)
        if conn is not None:
            dd = conn.dot(dd)
        dists[:, i] = dd

        if not all_cells:
            ii = keep_idx[ii]
        if k == 1:
            taus[:, i] = tau_list[i][ii]
        else:
            # average anchor time over the k neighbours, like dd above
            taus[:, i] = np.mean(tau_list[i][ii], axis=1)
        ts[:, i] = taus[:, i] + t_sw_array[i-1] if i >= 1 else taus[:, i]

    min_dist = np.min(dists, axis=1)
    state_pred = np.argmin(dists, axis=1)
    t_pred = ts[np.arange(n), state_pred]

    anchor_t1_list = []
    anchor_t2_list = []
    t_sw_adjust = np.zeros(3, dtype=u.dtype)

    if direction == 'complete':
        t_sorted = np.sort(t_pred)
        dt = np.diff(t_sorted, prepend=0)
        # a gap is a jump larger than 3x the 99th percentile of all jumps
        gap_thresh = 3*np.percentile(dt, 99)
        idx = np.where(dt > gap_thresh)[0]
        for i in idx:
            t1 = t_sorted[i-1] if i > 0 else 0
            t2 = t_sorted[i]
            anchor_t1 = anchor_exp[np.argmin(np.abs(anchor_t - t1)), :]
            anchor_t2 = anchor_exp[np.argmin(np.abs(anchor_t - t2)), :]
            if all_cells:
                anchor_t1_list.append(np.ravel(anchor_t1))
                anchor_t2_list.append(np.ravel(anchor_t2))
            else:
                for j in range(1, switch):
                    crit1 = ((t1 > t_sw_array[j-1]) and (t2 > t_sw_array[j-1])
                             and (t1 <= t_sw_array[j])
                             and (t2 <= t_sw_array[j]))
                    crit2 = ((np.abs(anchor_t1[2] - exp_sw_list[j][0, 2])
                              < 0.02 * max_s) and
                             (np.abs(anchor_t2[2] - exp_sw_list[j][0, 2])
                              < 0.01 * max_s))
                    crit3 = ((np.abs(anchor_t1[1] - exp_sw_list[j][0, 1])
                              < 0.02 * max_u) and
                             (np.abs(anchor_t2[1] - exp_sw_list[j][0, 1])
                              < 0.01 * max_u))
                    crit4 = ((np.abs(anchor_t1[0] - exp_sw_list[j][0, 0])
                              < 0.02 * max_c) and
                             (np.abs(anchor_t2[0] - exp_sw_list[j][0, 0])
                              < 0.01 * max_c))
                    if crit1 and crit2 and crit3 and crit4:
                        t_sw_adjust[j] += t2 - t1
            if penalize_gap:
                dist_gap = np.sum(((anchor_t1[1:] - anchor_t2[1:]) /
                                   scale_factor[1:])**2)
                idx_to_adjust = t_pred >= t2
                t_sw_array_ = np.append(t_sw_array, total_h)
                state_to_adjust = np.where(t_sw_array_ > t2)[0]
                dists[np.ix_(idx_to_adjust, state_to_adjust)] += dist_gap
        min_dist = np.min(dists, axis=1)
        state_pred = np.argmin(dists, axis=1)
        if all_cells:
            t_pred = ts[np.arange(n), state_pred]

    if all_cells:
        exp_ss_mat = compute_ss_exp(alpha_c, alpha, beta, gamma, model=model)
        if rna_only:
            exp_ss_mat[:, 0] = 1
        ss_mat = exp_ss_mat * rescale_factor / scale_factor
        if rna_only:
            # exp_mat only holds (u, s) here: drop the c column so the
            # steady-state columns line up with the observations
            ss_mat = ss_mat[:, 1:]
        dists_ss = pairwise_distance_square(exp_mat, ss_mat)

        # reach_ss[i, j]: cell i is closer to steady state j than to its
        # best trajectory anchor (assumes compute_ss_exp returns at least
        # four rows -- TODO confirm)
        reach_ss = min_dist[:, np.newaxis] > dists_ss[:, :4]
        late_phase = np.full(n, -1)
        for i in range(3):
            late_phase[np.abs(t_pred - t_sw_array[i]) < 0.1] = i
        return min_dist, t_pred, state_pred, reach_ss, late_phase, max_u, \
            max_s, anchor_t1_list, anchor_t2_list
    return min_dist, state_pred, max_u, max_s, t_sw_adjust
|
|
974 |
|
|
|
975 |
|
|
|
976 |
def t_of_c(alpha_c, k_c, c_o, c, rescale_factor, sw_t):
    """Estimate time from chromatin values by inverting the c equation.

    ``c`` is divided by ``rescale_factor`` and clipped to ``[0, 1]``. The
    time is ``-log((k_c - c) / (k_c - c_o) + 1e-9) / alpha_c``. When
    ``k_c == 0``, ``sw_t`` is added to the result.

    Parameters
    ----------
    alpha_c : float
        Chromatin rate.
    k_c : float
        Target level used in the inversion (the caller passes 1 or 0).
    c_o : float
        Initial chromatin value.
    c : ndarray
        Observed chromatin values.
    rescale_factor : float
        Divisor applied to ``c`` before clipping.
    sw_t : float
        Offset added when ``k_c == 0``.
    """
    normalized_c = np.clip(c / rescale_factor, a_min=0, a_max=1)
    ratio = (float(k_c) - normalized_c) / float(k_c - c_o)
    # the small offset keeps the log finite when the ratio is exactly zero
    t_est = np.log(ratio + 1e-9) * (-1.0 / alpha_c)
    if k_c == 0:
        t_est += sw_t
    return t_est
|
|
992 |
|
|
|
993 |
|
|
|
994 |
def _make_x_features(data_cols, param_vals, state, n):
    """Stack per-cell columns, constant parameter columns and the state.

    Returns a float32 matrix of shape ``(n, len(data_cols) +
    len(param_vals) + 1)``.
    """
    return np.concatenate((np.array(data_cols),
                           np.full((n, len(param_vals)), param_vals).T,
                           np.full((n, 1), state).T
                           )).T.astype(np.float32)


def make_X(c, u, s,
           max_u,
           max_s,
           alpha_c, alpha, beta, gamma,
           gene_sw_t,
           c0, c_sw1, c_sw2, c_sw3,
           u0, u_sw1, u_sw2, u_sw3,
           s0, s_sw1, s_sw2, s_sw3,
           model, direction, state):
    """Build the input matrix for the time-prediction networks.

    Each row holds one cell's (c, log u, log s), followed by parameters
    that are constant across cells, followed by ``state``. The column
    layout depends on ``direction`` (and ``model`` for ``'off'``). The
    resulting widths are 21 (complete), 16 (on), 18 (off, model 1) and
    16 (off, model 2), matching the ``nn.Linear`` input sizes of the
    networks defined in ``ChromatinDynamical``.

    Parameters
    ----------
    c, u, s : 1-D ndarray
        Per-cell values.
    max_u, max_s : float
        Maximum u and s; log-transformed without an epsilon.
    alpha_c, alpha, beta, gamma : float
        Rate parameters.
    gene_sw_t : sequence
        Switch times.
    c0, c_sw1..c_sw3, u0, u_sw1..u_sw3, s0, s_sw1..s_sw3 : float
        Initial and switch-point values. Only the subset needed by the
        chosen layout is read. ``u0``/``s0`` are never read.
    model : int
    direction : {'complete', 'on', 'off'}
    state : int

    Returns
    -------
    ndarray of float32, shape (n_cells, n_features)

    Raises
    ------
    ValueError
        For an unknown ``direction``, or ``'off'`` with a model other
        than 1 or 2.
    """
    epsilon = 1e-5
    n = c.shape[0]
    log_u = np.log(u + epsilon)
    log_s = np.log(s + epsilon)
    log_rates = [np.log(alpha_c + epsilon),
                 np.log(alpha + epsilon),
                 np.log(beta + epsilon),
                 np.log(gamma + epsilon)]

    if direction == "complete":
        data_cols = [c, log_u, log_s]
        params = log_rates + [c_sw1, c_sw2, c_sw3,
                              np.log(u_sw2 + epsilon),
                              np.log(u_sw3 + epsilon),
                              np.log(s_sw2 + epsilon),
                              np.log(s_sw3 + epsilon),
                              np.log(max_u),
                              np.log(max_s),
                              gene_sw_t[0],
                              gene_sw_t[1],
                              gene_sw_t[2],
                              model]
    elif direction == "on":
        data_cols = [c, log_u, log_s]
        params = log_rates + [c_sw1, c_sw2,
                              np.log(u_sw1 + epsilon),
                              np.log(u_sw2 + epsilon),
                              np.log(s_sw1 + epsilon),
                              np.log(s_sw2 + epsilon),
                              gene_sw_t[0],
                              model]
    elif direction == "off":
        if model == 1:
            # NOTE(review): c0 is indexed here, but calculate_dist_and_time_nn
            # passes the scalar 0 for c0, which would raise TypeError.
            # Confirm the intended value against how dir2_m1.pt was trained.
            max_u_t = -(float(1)/alpha_c)*np.log((max_u*beta)
                                                 / (alpha*c0[2]))
            data_cols = [np.log(c + epsilon), log_u, log_s]
            params = log_rates + [c_sw2, c_sw3,
                                  np.log(u_sw2 + epsilon),
                                  np.log(u_sw3 + epsilon),
                                  np.log(s_sw2 + epsilon),
                                  np.log(s_sw3 + epsilon),
                                  max_u_t,
                                  np.log(max_u),
                                  np.log(max_s),
                                  gene_sw_t[2]]
        elif model == 2:
            data_cols = [c, log_u, log_s]
            params = log_rates + [c_sw2, c_sw3,
                                  np.log(u_sw2 + epsilon),
                                  np.log(u_sw3 + epsilon),
                                  np.log(s_sw2 + epsilon),
                                  np.log(s_sw3 + epsilon),
                                  np.log(max_u),
                                  gene_sw_t[2]]
        else:
            raise ValueError(f"direction 'off' requires model 1 or 2, "
                             f"got {model!r}")
    else:
        raise ValueError(f"unknown direction {direction!r}; expected "
                         f"'complete', 'on' or 'off'")

    return _make_x_features(data_cols, params, state, n)
|
|
1098 |
|
|
|
1099 |
|
|
|
1100 |
def calculate_dist_and_time_nn(c, u, s,
                               max_u, max_s,
                               t_sw_array,
                               alpha_c, alpha, beta, gamma,
                               rescale_c, rescale_u,
                               ode_model_0, ode_model_1,
                               ode_model_2_m1, ode_model_2_m2,
                               device,
                               scale_cc=1,
                               scale_factor=None,
                               model=1,
                               conn=None,
                               t=1000, k=1,
                               direction='complete',
                               total_h=20,
                               rna_only=False,
                               penalize_gap=True,
                               all_cells=True):
    """Neural-network counterpart of ``calculate_dist_and_time``.

    For every phase allowed by ``direction``, each cell gets a time
    estimate from one of two sources:

    * the closed-form chromatin inversion ``t_of_c``, used for all of
      phase 0 and for phase-3 cells whose u and s are both below 10% of
      ``max_u``/``max_s``;
    * otherwise, a pretrained time network.

    The times are clipped to the phase's interval, converted back to model
    expression with ``generate_exp`` and compared with the observations.
    The phase with the smallest squared scaled distance wins.

    Parameters
    ----------
    c, u, s : 1-D ndarray
        Observed values per cell.
    max_u, max_s : float
        Maxima used for the near-zero test and as network features.
    t_sw_array : ndarray
        Switch times (first three entries used).
    alpha_c, alpha, beta, gamma : float
        Rate parameters.
    rescale_c, rescale_u : float
        Multipliers applied to the model's c and u.
    ode_model_0, ode_model_1, ode_model_2_m1, ode_model_2_m2 : callable
        Networks for the 'complete', 'on', and 'off' (model 1 / model 2)
        layouts produced by ``make_X``.
    device : str or torch.device
        Device on which network inputs are created.
    scale_cc : float
        Forwarded to ``generate_exp``; also multiplies ``alpha_c`` in the
        network features for phases after ``model``.
    scale_factor : ndarray of length 3, optional
        Divisor for the residuals. Defaults to the std of c, u and s.
    model : int
    conn : matrix, optional
        If given, per-phase distances are smoothed as ``conn.dot(d)``.
    t, k, penalize_gap :
        Unused; kept for signature compatibility with
        ``calculate_dist_and_time``.
    direction : str
        'on' evaluates phases 0-1, 'off' phases 2-3, anything else all
        four.
    total_h : float
        Upper bound of the time axis; phase-3 times are clipped to it.
    rna_only : bool
        Forwarded to ``generate_exp``.
    all_cells : bool
        Selects the return signature.

    Returns
    -------
    If ``all_cells``: ``(dists, t_pred, state_pred, max_u, max_s,
    penalty)``. Otherwise: ``(dists, state_pred, max_u, max_s, penalty)``.
    ``penalty`` grows when predicted times fall outside the range
    expected for ``direction`` or collapse onto (almost) a single value.

    Raises
    ------
    ValueError
        If ``direction == 'off'`` and ``model`` is not 1 or 2, because no
        network exists for that case.
    """
    rescale_factor = np.array([rescale_c, rescale_u, 1.0])

    # only the switch-point values are needed from this call
    _, exp_sw_list_net = generate_exp(None,
                                      t_sw_array,
                                      alpha_c,
                                      alpha,
                                      beta,
                                      gamma,
                                      model=model,
                                      scale_cc=scale_cc,
                                      rna_only=rna_only)

    N = len(c)
    N_list = np.arange(N)

    if scale_factor is None:
        cur_scale_factor = np.array([np.std(c), np.std(u), np.std(s)])
    else:
        cur_scale_factor = scale_factor

    # phases to evaluate and the network matching make_X's layout
    if direction == "on":
        states = [0, 1]
        time_net = ode_model_1
    elif direction == "off":
        states = [2, 3]
        if model == 1:
            time_net = ode_model_2_m1
        elif model == 2:
            time_net = ode_model_2_m2
        else:
            time_net = None
    else:
        states = [0, 1, 2, 3]
        time_net = ode_model_0

    # [start, end] of every phase on the global time axis
    phase_bounds = [(0, t_sw_array[0]),
                    (t_sw_array[0], t_sw_array[1]),
                    (t_sw_array[1], t_sw_array[2]),
                    (t_sw_array[2], total_h)]

    # cells whose u and s can both be considered close to zero
    zero_us = np.logical_and((u < 0.1 * max_u), (s < 0.1 * max_s))

    dists_per_state = np.zeros((N, len(states)))
    t_pred_per_state = np.zeros((N, len(states)))

    for increment, state in enumerate(states):
        t_pred = np.zeros(N)

        # the inverse c equation is used when u and s carry no signal:
        # all of state 0 and the near-zero part of state 3
        inverse_c = np.logical_or(state == 0,
                                  np.logical_and(state == 3, zero_us))
        not_inverse_c = np.logical_not(inverse_c)

        if np.any(inverse_c):
            # switch time at which chromatin closes
            c_sw_t = t_sw_array[int(model)]
            # chromatin opening (k_c=1, from 0) or closing (k_c=0, from
            # its value at the closing switch)
            if state <= model:
                k_c = 1
                c_0_for_t_guess = 0
            else:
                k_c = 0
                c_0_for_t_guess = exp_sw_list_net[int(model)][0, 0]
            t_pred[inverse_c] = t_of_c(alpha_c,
                                       k_c, c_0_for_t_guess,
                                       c[inverse_c],
                                       rescale_factor[0],
                                       c_sw_t)

        if np.any(not_inverse_c):
            if time_net is None:
                raise ValueError(f"no time network for direction 'off' "
                                 f"with model {model!r}")
            x = make_X(c[not_inverse_c] / rescale_factor[0],
                       u[not_inverse_c] / rescale_factor[1],
                       s[not_inverse_c] / rescale_factor[2],
                       max_u,
                       max_s,
                       alpha_c*(scale_cc if state > model else 1),
                       alpha, beta, gamma,
                       t_sw_array,
                       0,
                       exp_sw_list_net[0][0, 0],
                       exp_sw_list_net[1][0, 0],
                       exp_sw_list_net[2][0, 0],
                       0,
                       exp_sw_list_net[0][0, 1],
                       exp_sw_list_net[1][0, 1],
                       exp_sw_list_net[2][0, 1],
                       0,
                       exp_sw_list_net[0][0, 2],
                       exp_sw_list_net[1][0, 2],
                       exp_sw_list_net[2][0, 2],
                       model, direction, state)
            x_ten = torch.tensor(x, dtype=torch.float,
                                 device=device).reshape(-1, x.shape[1])
            # inference only: no autograd graph is needed
            with torch.no_grad():
                t_pred_ten = time_net(x_ten)
            # map network output y to 21*y - 1
            t_pred[not_inverse_c] = (t_pred_ten.cpu().detach().numpy()
                                     .flatten()*21) - 1

        # keep times inside this phase and convert to phase-local tau
        lower, upper = phase_bounds[state]
        t_pred = np.clip(t_pred, a_min=lower, a_max=upper)
        tau_list = [[], [], [], []]
        tau_list[state] = (t_pred if state == 0
                           else t_pred - t_sw_array[state - 1])

        # generate_exp gets a placeholder for every unused phase
        valid_vals = []
        for i, tau in enumerate(tau_list):
            if len(tau) == 0:
                tau_list[i] = np.array([0.0])
            else:
                valid_vals.append(i)

        exp_list, _ = generate_exp(tau_list,
                                   t_sw_array,
                                   alpha_c,
                                   alpha,
                                   beta,
                                   gamma,
                                   model=model,
                                   scale_cc=scale_cc,
                                   rna_only=rna_only)

        pred_c = np.concatenate([exp_list[x][:, 0] * rescale_factor[0]
                                 for x in valid_vals])
        pred_u = np.concatenate([exp_list[x][:, 1] * rescale_factor[1]
                                 for x in valid_vals])
        pred_s = np.concatenate([exp_list[x][:, 2] * rescale_factor[2]
                                 for x in valid_vals])

        c_diff = (c - pred_c) / cur_scale_factor[0]
        u_diff = (u - pred_u) / cur_scale_factor[1]
        s_diff = (s - pred_s) / cur_scale_factor[2]
        dists = (c_diff*c_diff) + (u_diff*u_diff) + (s_diff*s_diff)
        if conn is not None:
            dists = conn.dot(dists)

        dists_per_state[:, increment] = dists
        t_pred_per_state[:, increment] = t_pred

    # the phase with the smallest distance is the predicted state
    state_pred = np.argmin(dists_per_state, axis=1)
    dists = dists_per_state[N_list, state_pred]
    t_pred = t_pred_per_state[N_list, state_pred]

    max_t = t_pred.max()
    min_t = t_pred.min()

    penalty = 0

    # 'on'/'complete': not every cell may sit before the first switch
    if direction == "on" or direction == "complete":
        if t_sw_array[0] >= max_t:
            penalty += (t_sw_array[0] - max_t) + 10

    # 'on': times should not all lie beyond the second switch
    if direction == "on":
        if min_t > t_sw_array[1]:
            penalty += (min_t - t_sw_array[1]) + 10

    # 'off': times should not all lie before the second switch
    if direction == "off":
        if t_sw_array[1] >= max_t:
            penalty += (t_sw_array[1] - max_t) + 10

    # times must not collapse onto (almost) a single value
    if np.abs(max_t - min_t) <= 1e-2:
        penalty += np.abs(max_t - min_t) + 10

    # argmin returns column indices; 'off' columns correspond to states 2-3
    if direction == "off":
        state_pred += 2

    if all_cells:
        return dists, t_pred, state_pred, max_u, max_s, penalty
    return dists, state_pred, max_u, max_s, penalty
|
|
1384 |
|
|
|
1385 |
|
|
|
1386 |
def compute_likelihood(c, u, s,
                       t_sw_array,
                       alpha_c, alpha, beta, gamma,
                       rescale_c, rescale_u,
                       t_pred,
                       state_pred,
                       scale_cc=1,
                       scale_factor=None,
                       model=1,
                       weight=None,
                       total_h=20,
                       rna_only=False):
    """Gaussian likelihood of the observations under the fitted model.

    The selected cells are mapped to model expression at their assigned
    state and time via ``generate_exp``. Per-modality residuals are then
    scored with a Gaussian negative log-likelihood whose variance is the
    empirical variance of those residuals.

    Parameters
    ----------
    c, u, s : 1-D ndarray
        Observed values per cell.
    t_sw_array : ndarray
        Switch times.
    alpha_c, alpha, beta, gamma : float
        Rate parameters, sanitised with ``check_params``.
    rescale_c, rescale_u : float
        Multipliers applied to the model's c and u.
    t_pred : 1-D ndarray
        Global time per cell.
    state_pred : 1-D ndarray of int
        Phase index (0-3) per cell.
    scale_cc : float
        Forwarded to ``generate_exp``.
    scale_factor : ndarray of length 3, optional
        Per-modality multiplier applied to both observations and model.
        Defaults to ones.
    model : int
        Forwarded to ``generate_exp``.
    weight : boolean ndarray, optional
        Mask of cells to include; all cells by default.
    total_h : float
        Switch times ``>= total_h`` are treated as inactive.
    rna_only : bool
        If True, the u and s residuals are combined into one term and c is
        ignored.

    Returns
    -------
    likelihood, likelihood_c, ssd_c, var_c, likelihood_u, likelihood_s
        ``likelihood = exp(-nll)``. The c/u/s terms are only computed when
        ``rna_only`` is False; otherwise they stay 0.
    """
    if weight is None:
        weight = np.full(c.shape, True)
    c_ = c[weight]
    u_ = u[weight]
    s_ = s[weight]
    t_pred_ = t_pred[weight]
    state_pred_ = state_pred[weight]

    n = len(u_)
    if scale_factor is None:
        scale_factor = np.ones(3)
    # phase-local times
    tau_list = [t_pred_[state_pred_ == 0],
                t_pred_[state_pred_ == 1] - t_sw_array[0],
                t_pred_[state_pred_ == 2] - t_sw_array[1],
                t_pred_[state_pred_ == 3] - t_sw_array[2]]
    switch = np.sum(t_sw_array < total_h)
    typed_tau_list = List()
    for tau in tau_list:
        typed_tau_list.append(tau)
    alpha_c, alpha, beta, gamma = check_params(alpha_c, alpha, beta, gamma)
    exp_list, _ = generate_exp(typed_tau_list,
                               t_sw_array[:switch],
                               alpha_c,
                               alpha,
                               beta,
                               gamma,
                               model=model,
                               scale_cc=scale_cc,
                               rna_only=rna_only)
    rescale_factor = np.array([rescale_c, rescale_u, 1.0])
    exp_list = [x*rescale_factor*scale_factor for x in exp_list]
    exp_mat = np.hstack((np.reshape(c_, (-1, 1)), np.reshape(u_, (-1, 1)),
                         np.reshape(s_, (-1, 1)))) * scale_factor

    # residual columns are (c, u, s); zero-initialised so that cells in a
    # phase beyond `switch` do not read uninitialised memory
    diffs = np.zeros((n, 3), dtype=u.dtype)
    for i in range(switch+1):
        index = state_pred_ == i
        if np.any(index):
            diffs[index, :] = exp_mat[index, :] - exp_list[i]

    likelihood_c = 0
    likelihood_u = 0
    likelihood_s = 0
    ssd_c, var_c = 0, 0
    if rna_only:
        diff_u = np.ravel(diffs[:, 1])
        diff_s = np.ravel(diffs[:, 2])
        dist_us = diff_u ** 2 + diff_s ** 2
        # one signed radial residual per cell, sign taken from s
        var_us = np.var(np.sign(diff_s) * np.sqrt(dist_us))
        nll = (0.5 * np.log(2 * np.pi * var_us) + 0.5 / n /
               var_us * np.sum(dist_us))
    else:
        diff_c = np.ravel(diffs[:, 0])
        diff_u = np.ravel(diffs[:, 1])
        diff_s = np.ravel(diffs[:, 2])
        dist_c = diff_c ** 2
        dist_u = diff_u ** 2
        dist_s = diff_s ** 2
        var_c = np.var(diff_c)
        var_u = np.var(diff_u)
        var_s = np.var(diff_s)
        ssd_c = np.sum(dist_c)
        nll_c = (0.5 * np.log(2 * np.pi * var_c) + 0.5 / n /
                 var_c * np.sum(dist_c))
        nll_u = (0.5 * np.log(2 * np.pi * var_u) + 0.5 / n /
                 var_u * np.sum(dist_u))
        nll_s = (0.5 * np.log(2 * np.pi * var_s) + 0.5 / n /
                 var_s * np.sum(dist_s))
        nll = nll_c + nll_u + nll_s
        likelihood_c = np.exp(-nll_c)
        likelihood_u = np.exp(-nll_u)
        likelihood_s = np.exp(-nll_s)
    likelihood = np.exp(-nll)
    return likelihood, likelihood_c, ssd_c, var_c, likelihood_u, likelihood_s
|
|
1473 |
|
|
|
1474 |
|
|
|
1475 |
class ChromatinDynamical: |
|
|
1476 |
def __init__(self, c, u, s, |
|
|
1477 |
gene=None, |
|
|
1478 |
model=None, |
|
|
1479 |
max_iter=10, |
|
|
1480 |
init_mode="grid", |
|
|
1481 |
device="cpu", |
|
|
1482 |
neural_net=False, |
|
|
1483 |
adam=False, |
|
|
1484 |
adam_lr=None, |
|
|
1485 |
adam_beta1=None, |
|
|
1486 |
adam_beta2=None, |
|
|
1487 |
batch_size=None, |
|
|
1488 |
local_std=None, |
|
|
1489 |
embed_coord=None, |
|
|
1490 |
connectivities=None, |
|
|
1491 |
plot=False, |
|
|
1492 |
save_plot=False, |
|
|
1493 |
plot_dir=None, |
|
|
1494 |
fit_args=None, |
|
|
1495 |
partial=None, |
|
|
1496 |
direction=None, |
|
|
1497 |
rna_only=False, |
|
|
1498 |
fit_decoupling=True, |
|
|
1499 |
extra_color=None, |
|
|
1500 |
rescale_u=None, |
|
|
1501 |
alpha=None, |
|
|
1502 |
beta=None, |
|
|
1503 |
gamma=None, |
|
|
1504 |
t_=None |
|
|
1505 |
): |
|
|
1506 |
|
|
|
1507 |
self.device = device |
|
|
1508 |
self.gene = gene |
|
|
1509 |
self.local_std = local_std |
|
|
1510 |
self.conn = connectivities |
|
|
1511 |
|
|
|
1512 |
self.neural_net = neural_net |
|
|
1513 |
self.adam = adam |
|
|
1514 |
self.adam_lr = adam_lr |
|
|
1515 |
self.adam_beta1 = adam_beta1 |
|
|
1516 |
self.adam_beta2 = adam_beta2 |
|
|
1517 |
self.batch_size = batch_size |
|
|
1518 |
|
|
|
1519 |
self.torch_type = type(u[0].item()) |
|
|
1520 |
|
|
|
1521 |
# fitting arguments |
|
|
1522 |
self.init_mode = init_mode |
|
|
1523 |
self.rna_only = rna_only |
|
|
1524 |
self.fit_decoupling = fit_decoupling |
|
|
1525 |
self.max_iter = max_iter |
|
|
1526 |
self.n_anchors = np.clip(int(fit_args['t']), 201, 2000) |
|
|
1527 |
self.k_dist = np.clip(int(fit_args['k']), 1, 20) |
|
|
1528 |
self.tm = np.clip(fit_args['thresh_multiplier'], 0.4, 2) |
|
|
1529 |
self.weight_c = np.clip(fit_args['weight_c'], 0.1, 5) |
|
|
1530 |
self.outlier = np.clip(fit_args['outlier'], 80, 100) |
|
|
1531 |
self.model = int(model) if isinstance(model, float) else model |
|
|
1532 |
self.model_ = None |
|
|
1533 |
if self.model == 0 and self.init_mode == 'invert': |
|
|
1534 |
self.init_mode = 'grid' |
|
|
1535 |
|
|
|
1536 |
# plot parameters |
|
|
1537 |
self.plot = plot |
|
|
1538 |
self.save_plot = save_plot |
|
|
1539 |
self.extra_color = extra_color |
|
|
1540 |
self.fig_size = fit_args['fig_size'] |
|
|
1541 |
self.point_size = fit_args['point_size'] |
|
|
1542 |
if plot_dir is None: |
|
|
1543 |
self.plot_path = 'rna_plots' if self.rna_only else 'plots' |
|
|
1544 |
else: |
|
|
1545 |
self.plot_path = plot_dir |
|
|
1546 |
self.color = ['tab:red', 'tab:orange', 'tab:green', 'tab:blue'] |
|
|
1547 |
self.fig = None |
|
|
1548 |
self.ax = None |
|
|
1549 |
|
|
|
1550 |
# input |
|
|
1551 |
self.total_n = len(u) |
|
|
1552 |
if sparse.issparse(c): |
|
|
1553 |
c = c.A |
|
|
1554 |
if sparse.issparse(u): |
|
|
1555 |
u = u.A |
|
|
1556 |
if sparse.issparse(s): |
|
|
1557 |
s = s.A |
|
|
1558 |
self.c_all = np.ravel(np.array(c, dtype=np.float64)) |
|
|
1559 |
self.u_all = np.ravel(np.array(u, dtype=np.float64)) |
|
|
1560 |
self.s_all = np.ravel(np.array(s, dtype=np.float64)) |
|
|
1561 |
|
|
|
1562 |
# adjust offset |
|
|
1563 |
self.offset_c, self.offset_u, self.offset_s = np.min(self.c_all), \ |
|
|
1564 |
np.min(self.u_all), np.min(self.s_all) |
|
|
1565 |
self.offset_c = 0 if self.rna_only else self.offset_c |
|
|
1566 |
self.c_all -= self.offset_c |
|
|
1567 |
self.u_all -= self.offset_u |
|
|
1568 |
self.s_all -= self.offset_s |
|
|
1569 |
# remove zero counts |
|
|
1570 |
self.non_zero = (np.ravel(self.c_all > 0) | np.ravel(self.u_all > 0) | |
|
|
1571 |
np.ravel(self.s_all > 0)) |
|
|
1572 |
# remove outliers |
|
|
1573 |
self.non_outlier = np.ravel(self.c_all <= np.percentile(self.c_all, |
|
|
1574 |
self.outlier)) |
|
|
1575 |
self.non_outlier &= np.ravel(self.u_all <= np.percentile(self.u_all, |
|
|
1576 |
self.outlier)) |
|
|
1577 |
self.non_outlier &= np.ravel(self.s_all <= np.percentile(self.s_all, |
|
|
1578 |
self.outlier)) |
|
|
1579 |
self.c = self.c_all[self.non_zero & self.non_outlier] |
|
|
1580 |
self.u = self.u_all[self.non_zero & self.non_outlier] |
|
|
1581 |
self.s = self.s_all[self.non_zero & self.non_outlier] |
|
|
1582 |
self.low_quality = len(self.u) < 10 |
|
|
1583 |
# scale modalities |
|
|
1584 |
self.std_c, self.std_u, self.std_s = (np.std(self.c_all) |
|
|
1585 |
if not self.rna_only |
|
|
1586 |
else 1.0, np.std(self.u_all), |
|
|
1587 |
np.std(self.s_all)) |
|
|
1588 |
if self.std_u == 0 or self.std_s == 0: |
|
|
1589 |
self.low_quality = True |
|
|
1590 |
self.scale_c, self.scale_u, self.scale_s = np.max(self.c_all) \ |
|
|
1591 |
if not self.rna_only else 1.0, self.std_u/self.std_s, 1.0 |
|
|
1592 |
|
|
|
1593 |
# if we're on neural net mode, check to see if c is way bigger than |
|
|
1594 |
# u or s, which would be very hard for the neural net to fit |
|
|
1595 |
if not self.low_quality and neural_net: |
|
|
1596 |
max_c_orig = np.max(self.c) |
|
|
1597 |
if max_c_orig / np.max(self.u) > 500: |
|
|
1598 |
self.low_quality = True |
|
|
1599 |
|
|
|
1600 |
if not self.low_quality: |
|
|
1601 |
if max_c_orig / np.max(self.s) > 500: |
|
|
1602 |
self.low_quality = True |
|
|
1603 |
|
|
|
1604 |
self.c_all /= self.scale_c |
|
|
1605 |
self.u_all /= self.scale_u |
|
|
1606 |
self.s_all /= self.scale_s |
|
|
1607 |
self.c /= self.scale_c |
|
|
1608 |
self.u /= self.scale_u |
|
|
1609 |
self.s /= self.scale_s |
|
|
1610 |
self.scale_factor = np.array([np.std(self.c_all) / self.std_s / |
|
|
1611 |
self.weight_c, 1.0, 1.0]) |
|
|
1612 |
self.scale_factor[0] = 1 if self.rna_only else self.scale_factor[0] |
|
|
1613 |
self.max_u, self.max_s = np.max(self.u), np.max(self.s) |
|
|
1614 |
self.max_u_all, self.max_s_all = np.max(self.u_all), np.max(self.s_all) |
|
|
1615 |
if self.conn is not None: |
|
|
1616 |
self.conn_sub = self.conn[np.ix_(self.non_zero & self.non_outlier, |
|
|
1617 |
self.non_zero & self.non_outlier)] |
|
|
1618 |
else: |
|
|
1619 |
self.conn_sub = None |
|
|
1620 |
|
|
|
1621 |
logg.update(f'{len(self.u)} cells passed filter and will be used to ' |
|
|
1622 |
'compute trajectories.', v=2) |
|
|
1623 |
self.known_pars = (True |
|
|
1624 |
if None not in [rescale_u, alpha, beta, gamma, t_] |
|
|
1625 |
else False) |
|
|
1626 |
if self.known_pars: |
|
|
1627 |
logg.update(f'known parameters for gene {self.gene} are ' |
|
|
1628 |
f'scaling={rescale_u}, alpha={alpha}, beta={beta},' |
|
|
1629 |
f' gamma={gamma}, t_={t_}.', v=1) |
|
|
1630 |
|
|
|
1631 |
# define neural networks |
|
|
1632 |
self.ode_model_0 = nn.Sequential( |
|
|
1633 |
nn.Linear(21, 150), |
|
|
1634 |
nn.ReLU(), |
|
|
1635 |
nn.Linear(150, 112), |
|
|
1636 |
nn.ReLU(), |
|
|
1637 |
nn.Linear(112, 75), |
|
|
1638 |
nn.ReLU(), |
|
|
1639 |
nn.Linear(75, 1), |
|
|
1640 |
nn.Sigmoid() |
|
|
1641 |
) |
|
|
1642 |
|
|
|
1643 |
self.ode_model_1 = nn.Sequential( |
|
|
1644 |
nn.Linear(16, 64), |
|
|
1645 |
nn.ReLU(), |
|
|
1646 |
nn.Linear(64, 48), |
|
|
1647 |
nn.ReLU(), |
|
|
1648 |
nn.Linear(48, 32), |
|
|
1649 |
nn.ReLU(), |
|
|
1650 |
nn.Linear(32, 1), |
|
|
1651 |
nn.Sigmoid() |
|
|
1652 |
) |
|
|
1653 |
|
|
|
1654 |
self.ode_model_2_m1 = nn.Sequential( |
|
|
1655 |
nn.Linear(18, 220), |
|
|
1656 |
nn.ReLU(), |
|
|
1657 |
nn.Linear(220, 165), |
|
|
1658 |
nn.ReLU(), |
|
|
1659 |
nn.Linear(165, 110), |
|
|
1660 |
nn.ReLU(), |
|
|
1661 |
nn.Linear(110, 1), |
|
|
1662 |
nn.Sigmoid() |
|
|
1663 |
) |
|
|
1664 |
|
|
|
1665 |
self.ode_model_2_m2 = nn.Sequential( |
|
|
1666 |
nn.Linear(16, 150), |
|
|
1667 |
nn.ReLU(), |
|
|
1668 |
nn.Linear(150, 112), |
|
|
1669 |
nn.ReLU(), |
|
|
1670 |
nn.Linear(112, 75), |
|
|
1671 |
nn.ReLU(), |
|
|
1672 |
nn.Linear(75, 1), |
|
|
1673 |
nn.Sigmoid() |
|
|
1674 |
) |
|
|
1675 |
|
|
|
1676 |
self.ode_model_0.to(torch.device(self.device)) |
|
|
1677 |
self.ode_model_1.to(torch.device(self.device)) |
|
|
1678 |
self.ode_model_2_m1.to(torch.device(self.device)) |
|
|
1679 |
self.ode_model_2_m2.to(torch.device(self.device)) |
|
|
1680 |
|
|
|
1681 |
# load in neural network |
|
|
1682 |
net_path = os.path.dirname(os.path.abspath(__file__)) + \ |
|
|
1683 |
"/neural_nets/" |
|
|
1684 |
|
|
|
1685 |
self.ode_model_0.load_state_dict(torch.load(net_path+"dir0.pt")) |
|
|
1686 |
self.ode_model_1.load_state_dict(torch.load(net_path+"dir1.pt")) |
|
|
1687 |
self.ode_model_2_m1.load_state_dict(torch.load(net_path+"dir2_m1.pt")) |
|
|
1688 |
self.ode_model_2_m2.load_state_dict(torch.load(net_path+"dir2_m2.pt")) |
|
|
1689 |
|
|
|
1690 |
# 4 rate parameters |
|
|
1691 |
self.alpha_c = 0.1 |
|
|
1692 |
self.alpha = alpha if alpha is not None else 0.0 |
|
|
1693 |
self.beta = beta if beta is not None else 0.0 |
|
|
1694 |
self.gamma = gamma if gamma is not None else 0.0 |
|
|
1695 |
# 3 possible switch time points |
|
|
1696 |
self.t_sw_1 = 0.1 if t_ is not None else 0.0 |
|
|
1697 |
self.t_sw_2 = t_+0.1 if t_ is not None else 0.0 |
|
|
1698 |
self.t_sw_3 = 20.0 if t_ is not None else 0.0 |
|
|
1699 |
# 2 rescale factors |
|
|
1700 |
self.rescale_c = 1.0 |
|
|
1701 |
self.rescale_u = rescale_u if rescale_u is not None else 1.0 |
|
|
1702 |
self.rates = None |
|
|
1703 |
self.t_sw_array = None |
|
|
1704 |
self.fit_rescale = True if rescale_u is None else False |
|
|
1705 |
self.params = None |
|
|
1706 |
|
|
|
1707 |
# other parameters or results |
|
|
1708 |
self.t = None |
|
|
1709 |
self.state = None |
|
|
1710 |
self.loss = [np.inf] |
|
|
1711 |
self.likelihood = -1.0 |
|
|
1712 |
self.l_c = 0 |
|
|
1713 |
self.ssd_c, self.var_c = 0, 0 |
|
|
1714 |
self.scale_cc = 1.0 |
|
|
1715 |
self.fitting_flag_ = 0 |
|
|
1716 |
self.velocity = None |
|
|
1717 |
self.anchor_t1_list, self.anchor_t2_list = None, None |
|
|
1718 |
self.anchor_exp = None |
|
|
1719 |
self.anchor_exp_sw = None |
|
|
1720 |
self.anchor_min_idx, self.anchor_max_idx, self.anchor_velo_min_idx, \ |
|
|
1721 |
self.anchor_velo_max_idx = None, None, None, None |
|
|
1722 |
self.anchor_velo = None |
|
|
1723 |
self.c0 = self.u0 = self.s0 = 0.0 |
|
|
1724 |
self.realign_ratio = 1.0 |
|
|
1725 |
self.partial = False |
|
|
1726 |
self.direction = 'complete' |
|
|
1727 |
self.steady_state_func = None |
|
|
1728 |
|
|
|
1729 |
# for fit and update |
|
|
1730 |
self.cur_iter = 0 |
|
|
1731 |
self.cur_loss = None |
|
|
1732 |
self.cur_state_pred = None |
|
|
1733 |
self.cur_t_sw_adjust = None |
|
|
1734 |
|
|
|
1735 |
# partial checking and model examination |
|
|
1736 |
determine_model = model is None |
|
|
1737 |
if partial is None and direction is None: |
|
|
1738 |
if embed_coord is not None: |
|
|
1739 |
self.embed_coord = embed_coord[self.non_zero & |
|
|
1740 |
self.non_outlier] |
|
|
1741 |
else: |
|
|
1742 |
self.embed_coord = None |
|
|
1743 |
self.check_partial_trajectory(determine_model=determine_model) |
|
|
1744 |
elif direction is not None: |
|
|
1745 |
self.direction = direction |
|
|
1746 |
if direction in ['on', 'off']: |
|
|
1747 |
self.partial = True |
|
|
1748 |
else: |
|
|
1749 |
self.partial = False |
|
|
1750 |
self.check_partial_trajectory(fit_gmm=False, fit_slope=False, |
|
|
1751 |
determine_model=determine_model) |
|
|
1752 |
elif partial is not None: |
|
|
1753 |
self.partial = partial |
|
|
1754 |
self.check_partial_trajectory(fit_gmm=False, |
|
|
1755 |
determine_model=determine_model) |
|
|
1756 |
else: |
|
|
1757 |
self.check_partial_trajectory(fit_gmm=False, fit_slope=False, |
|
|
1758 |
determine_model=determine_model) |
|
|
1759 |
|
|
|
1760 |
# intialize steady state parameters |
|
|
1761 |
if not self.known_pars and not self.low_quality: |
|
|
1762 |
self.initialize_steady_state_params(model_mismatch=self.model |
|
|
1763 |
!= self.model_) |
|
|
1764 |
if self.known_pars: |
|
|
1765 |
self.params = np.array([self.t_sw_1, |
|
|
1766 |
self.t_sw_2-self.t_sw_1, |
|
|
1767 |
self.t_sw_3-self.t_sw_2, |
|
|
1768 |
self.alpha_c, |
|
|
1769 |
self.alpha, |
|
|
1770 |
self.beta, |
|
|
1771 |
self.gamma, |
|
|
1772 |
self.scale_cc, |
|
|
1773 |
self.rescale_c, |
|
|
1774 |
self.rescale_u]) |
|
|
1775 |
|
|
|
1776 |
# the torch tensor version of the anchor points function |
|
|
1777 |
def anchor_points_ten(self, t_sw_array, total_h=20, t=1000, mode='uniform', |
|
|
1778 |
return_time=False): |
|
|
1779 |
|
|
|
1780 |
t_ = torch.linspace(0, total_h, t, device=self.device, |
|
|
1781 |
dtype=self.torch_type) |
|
|
1782 |
tau1 = t_[t_ <= t_sw_array[0]] |
|
|
1783 |
tau2 = t_[(t_sw_array[0] < t_) & (t_ <= t_sw_array[1])] - t_sw_array[0] |
|
|
1784 |
tau3 = t_[(t_sw_array[1] < t_) & (t_ <= t_sw_array[2])] - t_sw_array[1] |
|
|
1785 |
tau4 = t_[t_sw_array[2] < t_] - t_sw_array[2] |
|
|
1786 |
|
|
|
1787 |
if mode == 'log': |
|
|
1788 |
if len(tau1) > 0: |
|
|
1789 |
tau1 = torch.expm1(tau1) |
|
|
1790 |
tau1 = tau1 / torch.max(tau1) * (t_sw_array[0]) |
|
|
1791 |
if len(tau2) > 0: |
|
|
1792 |
tau2 = torch.expm1(tau2) |
|
|
1793 |
tau2 = tau2 / torch.max(tau2) * (t_sw_array[1] - t_sw_array[0]) |
|
|
1794 |
if len(tau3) > 0: |
|
|
1795 |
tau3 = torch.expm1(tau3) |
|
|
1796 |
tau3 = tau3 / torch.max(tau3) * (t_sw_array[2] - t_sw_array[1]) |
|
|
1797 |
if len(tau4) > 0: |
|
|
1798 |
tau4 = torch.expm1(tau4) |
|
|
1799 |
tau4 = tau4 / torch.max(tau4) * (total_h - t_sw_array[2]) |
|
|
1800 |
|
|
|
1801 |
tau_list = [tau1, tau2, tau3, tau4] |
|
|
1802 |
if return_time: |
|
|
1803 |
return t_, tau_list |
|
|
1804 |
else: |
|
|
1805 |
return tau_list |
|
|
1806 |
|
|
|
1807 |
# the torch version of the predict_exp function |
|
|
1808 |
def predict_exp_ten(self, |
|
|
1809 |
tau, |
|
|
1810 |
c0, |
|
|
1811 |
u0, |
|
|
1812 |
s0, |
|
|
1813 |
alpha_c, |
|
|
1814 |
alpha, |
|
|
1815 |
beta, |
|
|
1816 |
gamma, |
|
|
1817 |
scale_cc=None, |
|
|
1818 |
pred_r=True, |
|
|
1819 |
chrom_open=True, |
|
|
1820 |
backward=False, |
|
|
1821 |
rna_only=False): |
|
|
1822 |
|
|
|
1823 |
if scale_cc is None: |
|
|
1824 |
scale_cc = torch.tensor(1.0, requires_grad=True, |
|
|
1825 |
device=self.device, |
|
|
1826 |
dtype=self.torch_type) |
|
|
1827 |
|
|
|
1828 |
if len(tau) == 0: |
|
|
1829 |
return torch.empty((0, 3), |
|
|
1830 |
requires_grad=True, |
|
|
1831 |
device=self.device, |
|
|
1832 |
dtype=self.torch_type) |
|
|
1833 |
if backward: |
|
|
1834 |
tau = -tau |
|
|
1835 |
|
|
|
1836 |
eat = torch.exp(-alpha_c * tau) |
|
|
1837 |
ebt = torch.exp(-beta * tau) |
|
|
1838 |
egt = torch.exp(-gamma * tau) |
|
|
1839 |
if rna_only: |
|
|
1840 |
kc = 1 |
|
|
1841 |
c0 = 1 |
|
|
1842 |
else: |
|
|
1843 |
if chrom_open: |
|
|
1844 |
kc = 1 |
|
|
1845 |
else: |
|
|
1846 |
kc = 0 |
|
|
1847 |
alpha_c = alpha_c * scale_cc |
|
|
1848 |
|
|
|
1849 |
const = (kc - c0) * alpha / (beta - alpha_c) |
|
|
1850 |
|
|
|
1851 |
res0 = kc - (kc - c0) * eat |
|
|
1852 |
|
|
|
1853 |
if pred_r: |
|
|
1854 |
|
|
|
1855 |
res1 = u0 * ebt + (alpha * kc / beta) * (1 - ebt) |
|
|
1856 |
res1 += const * (ebt - eat) |
|
|
1857 |
|
|
|
1858 |
res2 = s0 * egt + (alpha * kc / gamma) * (1 - egt) |
|
|
1859 |
res2 += ((beta / (gamma - beta)) * |
|
|
1860 |
((alpha * kc / beta) - u0 - const) * (egt - ebt)) |
|
|
1861 |
res2 += (beta / (gamma - alpha_c)) * const * (egt - eat) |
|
|
1862 |
|
|
|
1863 |
else: |
|
|
1864 |
res1 = torch.zeros(len(tau), device=self.device, |
|
|
1865 |
requires_grad=True, |
|
|
1866 |
dtype=self.torch_type) |
|
|
1867 |
res2 = torch.zeros(len(tau), device=self.device, |
|
|
1868 |
requires_grad=True, |
|
|
1869 |
dtype=self.torch_type) |
|
|
1870 |
|
|
|
1871 |
res = torch.stack((res0, res1, res2), 1) |
|
|
1872 |
|
|
|
1873 |
return res |
|
|
1874 |
|
|
|
1875 |
# the torch tensor version of the generate_exp function |
|
|
1876 |
def generate_exp_tens(self, |
|
|
1877 |
tau_list, |
|
|
1878 |
t_sw_array, |
|
|
1879 |
alpha_c, |
|
|
1880 |
alpha, |
|
|
1881 |
beta, |
|
|
1882 |
gamma, |
|
|
1883 |
scale_cc=None, |
|
|
1884 |
model=1, |
|
|
1885 |
rna_only=False): |
|
|
1886 |
|
|
|
1887 |
if scale_cc is None: |
|
|
1888 |
scale_cc = torch.tensor(1.0, requires_grad=True, |
|
|
1889 |
device=self.device, |
|
|
1890 |
dtype=self.torch_type) |
|
|
1891 |
|
|
|
1892 |
if beta == alpha_c: |
|
|
1893 |
beta += 1e-3 |
|
|
1894 |
if gamma == beta or gamma == alpha_c: |
|
|
1895 |
gamma += 1e-3 |
|
|
1896 |
switch = int(t_sw_array.size(dim=0)) |
|
|
1897 |
if switch >= 1: |
|
|
1898 |
tau_sw1 = torch.tensor([t_sw_array[0]], requires_grad=True, |
|
|
1899 |
device=self.device, |
|
|
1900 |
dtype=self.torch_type) |
|
|
1901 |
if switch >= 2: |
|
|
1902 |
tau_sw2 = torch.tensor([t_sw_array[1] - t_sw_array[0]], |
|
|
1903 |
requires_grad=True, |
|
|
1904 |
device=self.device, |
|
|
1905 |
dtype=self.torch_type) |
|
|
1906 |
if switch == 3: |
|
|
1907 |
tau_sw3 = torch.tensor([t_sw_array[2] - t_sw_array[1]], |
|
|
1908 |
requires_grad=True, |
|
|
1909 |
device=self.device, |
|
|
1910 |
dtype=self.torch_type) |
|
|
1911 |
exp_sw1, exp_sw2, exp_sw3 = (torch.empty((0, 3), |
|
|
1912 |
requires_grad=True, |
|
|
1913 |
device=self.device, |
|
|
1914 |
dtype=self.torch_type), |
|
|
1915 |
torch.empty((0, 3), |
|
|
1916 |
requires_grad=True, |
|
|
1917 |
device=self.device, |
|
|
1918 |
dtype=self.torch_type), |
|
|
1919 |
torch.empty((0, 3), |
|
|
1920 |
requires_grad=True, |
|
|
1921 |
device=self.device, |
|
|
1922 |
dtype=self.torch_type)) |
|
|
1923 |
if tau_list is None: |
|
|
1924 |
if model == 0: |
|
|
1925 |
if switch >= 1: |
|
|
1926 |
exp_sw1 = self.predict_exp_ten(tau_sw1, 0, 0, 0, alpha_c, |
|
|
1927 |
alpha, beta, gamma, |
|
|
1928 |
pred_r=False, |
|
|
1929 |
scale_cc=scale_cc, |
|
|
1930 |
rna_only=rna_only) |
|
|
1931 |
if switch >= 2: |
|
|
1932 |
exp_sw2 = self.predict_exp_ten(tau_sw2, exp_sw1[0, 0], |
|
|
1933 |
exp_sw1[0, 1], |
|
|
1934 |
exp_sw1[0, 2], |
|
|
1935 |
alpha_c, alpha, beta, |
|
|
1936 |
gamma, pred_r=False, |
|
|
1937 |
chrom_open=False, |
|
|
1938 |
scale_cc=scale_cc, |
|
|
1939 |
rna_only=rna_only) |
|
|
1940 |
if switch >= 3: |
|
|
1941 |
exp_sw3 = self.predict_exp_ten(tau_sw3, |
|
|
1942 |
exp_sw2[0, 0], |
|
|
1943 |
exp_sw2[0, 1], |
|
|
1944 |
exp_sw2[0, 2], |
|
|
1945 |
alpha_c, alpha, |
|
|
1946 |
beta, gamma, |
|
|
1947 |
chrom_open=False, |
|
|
1948 |
scale_cc=scale_cc, |
|
|
1949 |
rna_only=rna_only) |
|
|
1950 |
elif model == 1: |
|
|
1951 |
if switch >= 1: |
|
|
1952 |
exp_sw1 = self.predict_exp_ten(tau_sw1, 0, 0, 0, alpha_c, |
|
|
1953 |
alpha, beta, gamma, |
|
|
1954 |
pred_r=False, |
|
|
1955 |
scale_cc=scale_cc, |
|
|
1956 |
rna_only=rna_only) |
|
|
1957 |
if switch >= 2: |
|
|
1958 |
exp_sw2 = self.predict_exp_ten(tau_sw2, exp_sw1[0, 0], |
|
|
1959 |
exp_sw1[0, 1], |
|
|
1960 |
exp_sw1[0, 2], |
|
|
1961 |
alpha_c, alpha, |
|
|
1962 |
beta, gamma, |
|
|
1963 |
scale_cc=scale_cc, |
|
|
1964 |
rna_only=rna_only) |
|
|
1965 |
if switch >= 3: |
|
|
1966 |
exp_sw3 = self.predict_exp_ten(tau_sw3, |
|
|
1967 |
exp_sw2[0, 0], |
|
|
1968 |
exp_sw2[0, 1], |
|
|
1969 |
exp_sw2[0, 2], |
|
|
1970 |
alpha_c, alpha, |
|
|
1971 |
beta, gamma, |
|
|
1972 |
chrom_open=False, |
|
|
1973 |
scale_cc=scale_cc, |
|
|
1974 |
rna_only=rna_only) |
|
|
1975 |
elif model == 2: |
|
|
1976 |
if switch >= 1: |
|
|
1977 |
exp_sw1 = self.predict_exp_ten(tau_sw1, 0, 0, 0, alpha_c, |
|
|
1978 |
alpha, beta, gamma, |
|
|
1979 |
pred_r=False, |
|
|
1980 |
scale_cc=scale_cc, |
|
|
1981 |
rna_only=rna_only) |
|
|
1982 |
if switch >= 2: |
|
|
1983 |
exp_sw2 = self.predict_exp_ten(tau_sw2, exp_sw1[0, 0], |
|
|
1984 |
exp_sw1[0, 1], |
|
|
1985 |
exp_sw1[0, 2], alpha_c, |
|
|
1986 |
alpha, beta, gamma, |
|
|
1987 |
scale_cc=scale_cc, |
|
|
1988 |
rna_only=rna_only) |
|
|
1989 |
if switch >= 3: |
|
|
1990 |
exp_sw3 = self.predict_exp_ten(tau_sw3, |
|
|
1991 |
exp_sw2[0, 0], |
|
|
1992 |
exp_sw2[0, 1], |
|
|
1993 |
exp_sw2[0, 2], |
|
|
1994 |
alpha_c, 0, beta, |
|
|
1995 |
gamma, |
|
|
1996 |
scale_cc=scale_cc, |
|
|
1997 |
rna_only=rna_only) |
|
|
1998 |
|
|
|
1999 |
return [torch.empty((0, 3), requires_grad=True, |
|
|
2000 |
device=self.device, |
|
|
2001 |
dtype=self.torch_type), |
|
|
2002 |
torch.empty((0, 3), requires_grad=True, |
|
|
2003 |
device=self.device, |
|
|
2004 |
dtype=self.torch_type), |
|
|
2005 |
torch.empty((0, 3), requires_grad=True, |
|
|
2006 |
device=self.device, |
|
|
2007 |
dtype=self.torch_type), |
|
|
2008 |
torch.empty((0, 3), requires_grad=True, |
|
|
2009 |
device=self.device, |
|
|
2010 |
dtype=self.torch_type)], \ |
|
|
2011 |
[exp_sw1, exp_sw2, exp_sw3] |
|
|
2012 |
|
|
|
2013 |
tau1 = tau_list[0] |
|
|
2014 |
if switch >= 1: |
|
|
2015 |
tau2 = tau_list[1] |
|
|
2016 |
if switch >= 2: |
|
|
2017 |
tau3 = tau_list[2] |
|
|
2018 |
if switch == 3: |
|
|
2019 |
tau4 = tau_list[3] |
|
|
2020 |
exp1, exp2, exp3, exp4 = (torch.empty((0, 3), requires_grad=True, |
|
|
2021 |
device=self.device, |
|
|
2022 |
dtype=self.torch_type), |
|
|
2023 |
torch.empty((0, 3), requires_grad=True, |
|
|
2024 |
device=self.device, |
|
|
2025 |
dtype=self.torch_type), |
|
|
2026 |
torch.empty((0, 3), requires_grad=True, |
|
|
2027 |
device=self.device, |
|
|
2028 |
dtype=self.torch_type), |
|
|
2029 |
torch.empty((0, 3), requires_grad=True, |
|
|
2030 |
device=self.device, |
|
|
2031 |
dtype=self.torch_type)) |
|
|
2032 |
if model == 0: |
|
|
2033 |
exp1 = self.predict_exp_ten(tau1, 0, 0, 0, alpha_c, alpha, beta, |
|
|
2034 |
gamma, pred_r=False, scale_cc=scale_cc, |
|
|
2035 |
rna_only=rna_only) |
|
|
2036 |
if switch >= 1: |
|
|
2037 |
exp_sw1 = self.predict_exp_ten(tau_sw1, 0, 0, 0, alpha_c, |
|
|
2038 |
alpha, beta, gamma, |
|
|
2039 |
pred_r=False, scale_cc=scale_cc, |
|
|
2040 |
rna_only=rna_only) |
|
|
2041 |
exp2 = self.predict_exp_ten(tau2, exp_sw1[0, 0], exp_sw1[0, 1], |
|
|
2042 |
exp_sw1[0, 2], alpha_c, alpha, |
|
|
2043 |
beta, gamma, pred_r=False, |
|
|
2044 |
chrom_open=False, |
|
|
2045 |
scale_cc=scale_cc, |
|
|
2046 |
rna_only=rna_only) |
|
|
2047 |
if switch >= 2: |
|
|
2048 |
exp_sw2 = self.predict_exp_ten(tau_sw2, exp_sw1[0, 0], |
|
|
2049 |
exp_sw1[0, 1], |
|
|
2050 |
exp_sw1[0, 2], |
|
|
2051 |
alpha_c, alpha, beta, gamma, |
|
|
2052 |
pred_r=False, |
|
|
2053 |
chrom_open=False, |
|
|
2054 |
scale_cc=scale_cc, |
|
|
2055 |
rna_only=rna_only) |
|
|
2056 |
exp3 = self.predict_exp_ten(tau3, exp_sw2[0, 0], |
|
|
2057 |
exp_sw2[0, 1], exp_sw2[0, 2], |
|
|
2058 |
alpha_c, alpha, beta, gamma, |
|
|
2059 |
chrom_open=False, |
|
|
2060 |
scale_cc=scale_cc, |
|
|
2061 |
rna_only=rna_only) |
|
|
2062 |
if switch == 3: |
|
|
2063 |
exp_sw3 = self.predict_exp_ten(tau_sw3, exp_sw2[0, 0], |
|
|
2064 |
exp_sw2[0, 1], |
|
|
2065 |
exp_sw2[0, 2], |
|
|
2066 |
alpha_c, alpha, beta, |
|
|
2067 |
gamma, |
|
|
2068 |
chrom_open=False, |
|
|
2069 |
scale_cc=scale_cc, |
|
|
2070 |
rna_only=rna_only) |
|
|
2071 |
exp4 = self.predict_exp_ten(tau4, exp_sw3[0, 0], |
|
|
2072 |
exp_sw3[0, 1], |
|
|
2073 |
exp_sw3[0, 2], |
|
|
2074 |
alpha_c, 0, beta, gamma, |
|
|
2075 |
chrom_open=False, |
|
|
2076 |
scale_cc=scale_cc, |
|
|
2077 |
rna_only=rna_only) |
|
|
2078 |
elif model == 1: |
|
|
2079 |
exp1 = self.predict_exp_ten(tau1, 0, 0, 0, alpha_c, alpha, beta, |
|
|
2080 |
gamma, pred_r=False, scale_cc=scale_cc, |
|
|
2081 |
rna_only=rna_only) |
|
|
2082 |
if switch >= 1: |
|
|
2083 |
exp_sw1 = self.predict_exp_ten(tau_sw1, 0, 0, 0, alpha_c, |
|
|
2084 |
alpha, beta, gamma, |
|
|
2085 |
pred_r=False, scale_cc=scale_cc, |
|
|
2086 |
rna_only=rna_only) |
|
|
2087 |
exp2 = self.predict_exp_ten(tau2, exp_sw1[0, 0], exp_sw1[0, 1], |
|
|
2088 |
exp_sw1[0, 2], alpha_c, alpha, |
|
|
2089 |
beta, gamma, scale_cc=scale_cc, |
|
|
2090 |
rna_only=rna_only) |
|
|
2091 |
if switch >= 2: |
|
|
2092 |
exp_sw2 = self.predict_exp_ten(tau_sw2, exp_sw1[0, 0], |
|
|
2093 |
exp_sw1[0, 1], |
|
|
2094 |
exp_sw1[0, 2], alpha_c, |
|
|
2095 |
alpha, beta, gamma, |
|
|
2096 |
scale_cc=scale_cc, |
|
|
2097 |
rna_only=rna_only) |
|
|
2098 |
exp3 = self.predict_exp_ten(tau3, exp_sw2[0, 0], |
|
|
2099 |
exp_sw2[0, 1], exp_sw2[0, 2], |
|
|
2100 |
alpha_c, alpha, beta, gamma, |
|
|
2101 |
chrom_open=False, |
|
|
2102 |
scale_cc=scale_cc, |
|
|
2103 |
rna_only=rna_only) |
|
|
2104 |
if switch == 3: |
|
|
2105 |
exp_sw3 = self.predict_exp_ten(tau_sw3, exp_sw2[0, 0], |
|
|
2106 |
exp_sw2[0, 1], |
|
|
2107 |
exp_sw2[0, 2], |
|
|
2108 |
alpha_c, alpha, beta, |
|
|
2109 |
gamma, |
|
|
2110 |
chrom_open=False, |
|
|
2111 |
scale_cc=scale_cc, |
|
|
2112 |
rna_only=rna_only) |
|
|
2113 |
exp4 = self.predict_exp_ten(tau4, exp_sw3[0, 0], |
|
|
2114 |
exp_sw3[0, 1], |
|
|
2115 |
exp_sw3[0, 2], alpha_c, 0, |
|
|
2116 |
beta, gamma, |
|
|
2117 |
chrom_open=False, |
|
|
2118 |
scale_cc=scale_cc, |
|
|
2119 |
rna_only=rna_only) |
|
|
2120 |
elif model == 2: |
|
|
2121 |
exp1 = self.predict_exp_ten(tau1, 0, 0, 0, alpha_c, alpha, beta, |
|
|
2122 |
gamma, pred_r=False, scale_cc=scale_cc, |
|
|
2123 |
rna_only=rna_only) |
|
|
2124 |
if switch >= 1: |
|
|
2125 |
exp_sw1 = self.predict_exp_ten(tau_sw1, 0, 0, 0, alpha_c, |
|
|
2126 |
alpha, beta, gamma, |
|
|
2127 |
pred_r=False, scale_cc=scale_cc, |
|
|
2128 |
rna_only=rna_only) |
|
|
2129 |
exp2 = self.predict_exp_ten(tau2, exp_sw1[0, 0], exp_sw1[0, 1], |
|
|
2130 |
exp_sw1[0, 2], alpha_c, alpha, |
|
|
2131 |
beta, gamma, scale_cc=scale_cc, |
|
|
2132 |
rna_only=rna_only) |
|
|
2133 |
if switch >= 2: |
|
|
2134 |
exp_sw2 = self.predict_exp_ten(tau_sw2, exp_sw1[0, 0], |
|
|
2135 |
exp_sw1[0, 1], |
|
|
2136 |
exp_sw1[0, 2], alpha_c, |
|
|
2137 |
alpha, beta, gamma, |
|
|
2138 |
scale_cc=scale_cc, |
|
|
2139 |
rna_only=rna_only) |
|
|
2140 |
exp3 = self.predict_exp_ten(tau3, exp_sw2[0, 0], |
|
|
2141 |
exp_sw2[0, 1], |
|
|
2142 |
exp_sw2[0, 2], alpha_c, 0, |
|
|
2143 |
beta, gamma, scale_cc=scale_cc, |
|
|
2144 |
rna_only=rna_only) |
|
|
2145 |
if switch == 3: |
|
|
2146 |
exp_sw3 = self.predict_exp_ten(tau_sw3, exp_sw2[0, 0], |
|
|
2147 |
exp_sw2[0, 1], |
|
|
2148 |
exp_sw2[0, 2], |
|
|
2149 |
alpha_c, 0, beta, gamma, |
|
|
2150 |
scale_cc=scale_cc, |
|
|
2151 |
rna_only=rna_only) |
|
|
2152 |
exp4 = self.predict_exp_ten(tau4, exp_sw3[0, 0], |
|
|
2153 |
exp_sw3[0, 1], |
|
|
2154 |
exp_sw3[0, 2], |
|
|
2155 |
alpha_c, 0, beta, gamma, |
|
|
2156 |
chrom_open=False, |
|
|
2157 |
scale_cc=scale_cc, |
|
|
2158 |
rna_only=rna_only) |
|
|
2159 |
return [exp1, exp2, exp3, exp4], [exp_sw1, exp_sw2, exp_sw3] |
|
|
2160 |
|
|
|
2161 |
def check_partial_trajectory(self, fit_gmm=True, fit_slope=True, |
|
|
2162 |
determine_model=True): |
|
|
2163 |
w_non_zero = ((self.c >= 0.1 * np.max(self.c)) & |
|
|
2164 |
(self.u >= 0.1 * np.max(self.u)) & |
|
|
2165 |
(self.s >= 0.1 * np.max(self.s))) |
|
|
2166 |
u_non_zero = self.u[w_non_zero] |
|
|
2167 |
s_non_zero = self.s[w_non_zero] |
|
|
2168 |
if len(u_non_zero) < 10: |
|
|
2169 |
self.low_quality = True |
|
|
2170 |
return |
|
|
2171 |
|
|
|
2172 |
# GMM |
|
|
2173 |
w_low = ((np.percentile(s_non_zero, 30) <= s_non_zero) & |
|
|
2174 |
(s_non_zero <= np.percentile(s_non_zero, 40))) |
|
|
2175 |
if np.sum(w_low) < 10: |
|
|
2176 |
fit_gmm = False |
|
|
2177 |
self.partial = True |
|
|
2178 |
if self.local_std is None: |
|
|
2179 |
logg.warn('local standard deviation not provided. ' |
|
|
2180 |
'Skipping GMM..', v=2) |
|
|
2181 |
if self.embed_coord is None: |
|
|
2182 |
logg.warn('Warning: embedded coordinates not provided. ' |
|
|
2183 |
'Skipping GMM..') |
|
|
2184 |
if (fit_gmm and self.local_std is not None and self.embed_coord |
|
|
2185 |
is not None): |
|
|
2186 |
|
|
|
2187 |
pdist = pairwise_distances( |
|
|
2188 |
self.embed_coord[w_non_zero, :][w_low, :]) |
|
|
2189 |
dists = (np.ravel(pdist[np.triu_indices_from(pdist, k=1)]) |
|
|
2190 |
.reshape(-1, 1)) |
|
|
2191 |
model = GaussianMixture(n_components=2, covariance_type='tied', |
|
|
2192 |
random_state=2021).fit(dists) |
|
|
2193 |
mean_diff = np.abs(model.means_[1][0] - model.means_[0][0]) |
|
|
2194 |
criterion1 = mean_diff > self.local_std / self.tm |
|
|
2195 |
logg.update(f'GMM: difference between means = {mean_diff}, ' |
|
|
2196 |
f'threshold = {self.local_std / self.tm}.', v=2) |
|
|
2197 |
criterion2 = np.all(model.weights_[1] > 0.2 / self.tm) |
|
|
2198 |
logg.update('GMM: weight of the second Gaussian =' |
|
|
2199 |
f' {model.weights_[1]}.', v=2) |
|
|
2200 |
if criterion1 and criterion2: |
|
|
2201 |
self.partial = False |
|
|
2202 |
else: |
|
|
2203 |
self.partial = True |
|
|
2204 |
logg.update(f'GMM decides {"" if self.partial else "not "}' |
|
|
2205 |
'partial.', v=2) |
|
|
2206 |
|
|
|
2207 |
# steady-state slope |
|
|
2208 |
wu = self.u >= np.percentile(u_non_zero, 95) |
|
|
2209 |
ws = self.s >= np.percentile(s_non_zero, 95) |
|
|
2210 |
ss_u = self.u[wu | ws] |
|
|
2211 |
ss_s = self.s[wu | ws] |
|
|
2212 |
if np.all(ss_u == 0) or np.all(ss_s == 0): |
|
|
2213 |
self.low_quality = True |
|
|
2214 |
return |
|
|
2215 |
gamma = np.dot(ss_u, ss_s) / np.dot(ss_s, ss_s) |
|
|
2216 |
self.steady_state_func = lambda x: gamma*x |
|
|
2217 |
|
|
|
2218 |
# thickness of phase portrait |
|
|
2219 |
u_norm = u_non_zero / np.max(self.u) |
|
|
2220 |
s_norm = s_non_zero / np.max(self.s) |
|
|
2221 |
exp = np.hstack((np.reshape(u_norm, (-1, 1)), |
|
|
2222 |
np.reshape(s_norm, (-1, 1)))) |
|
|
2223 |
U, S, Vh = np.linalg.svd(exp) |
|
|
2224 |
self.thickness = S[1] |
|
|
2225 |
|
|
|
2226 |
# slope-based direction decision |
|
|
2227 |
with np.errstate(divide='ignore', invalid='ignore'): |
|
|
2228 |
slope = self.u / self.s |
|
|
2229 |
non_nan = ~np.isnan(slope) |
|
|
2230 |
slope = slope[non_nan] |
|
|
2231 |
on = slope >= gamma |
|
|
2232 |
off = slope < gamma |
|
|
2233 |
if len(ss_u) < 10 or len(u_non_zero) < 10: |
|
|
2234 |
fit_slope = False |
|
|
2235 |
self.direction = 'complete' |
|
|
2236 |
if fit_slope: |
|
|
2237 |
slope_ = u_non_zero / s_non_zero |
|
|
2238 |
on_ = slope_ >= gamma |
|
|
2239 |
off_ = slope_ < gamma |
|
|
2240 |
on_dist = np.sum((u_non_zero[on_] - gamma * s_non_zero[on_])**2) |
|
|
2241 |
off_dist = np.sum((gamma * s_non_zero[off_] - u_non_zero[off_])**2) |
|
|
2242 |
logg.update(f'Slope: SSE on induction phase = {on_dist},' |
|
|
2243 |
f' SSE on repression phase = {off_dist}.', v=2) |
|
|
2244 |
if self.thickness < 1.5 / np.sqrt(self.tm): |
|
|
2245 |
narrow = True |
|
|
2246 |
else: |
|
|
2247 |
narrow = False |
|
|
2248 |
logg.update(f'Thickness of trajectory = {self.thickness}. ' |
|
|
2249 |
f'Trajectory is {"narrow" if narrow else "normal"}.', |
|
|
2250 |
v=2) |
|
|
2251 |
if on_dist > 10 * self.tm**2 * off_dist: |
|
|
2252 |
self.direction = 'on' |
|
|
2253 |
self.partial = True |
|
|
2254 |
elif off_dist > 10 * self.tm**2 * on_dist: |
|
|
2255 |
self.direction = 'off' |
|
|
2256 |
self.partial = True |
|
|
2257 |
else: |
|
|
2258 |
if self.partial is True: |
|
|
2259 |
if on_dist > 3 * self.tm * off_dist: |
|
|
2260 |
self.direction = 'on' |
|
|
2261 |
elif off_dist > 3 * self.tm * on_dist: |
|
|
2262 |
self.direction = 'off' |
|
|
2263 |
else: |
|
|
2264 |
if narrow: |
|
|
2265 |
self.direction = 'on' |
|
|
2266 |
else: |
|
|
2267 |
self.direction = 'complete' |
|
|
2268 |
self.partial = False |
|
|
2269 |
else: |
|
|
2270 |
if narrow: |
|
|
2271 |
self.direction = ('off' |
|
|
2272 |
if off_dist > 2 * self.tm * on_dist |
|
|
2273 |
else 'on') |
|
|
2274 |
self.partial = True |
|
|
2275 |
else: |
|
|
2276 |
self.direction = 'complete' |
|
|
2277 |
|
|
|
2278 |
# model pre-determination |
|
|
2279 |
if self.direction == 'on': |
|
|
2280 |
self.model_ = 1 |
|
|
2281 |
elif self.direction == 'off': |
|
|
2282 |
self.model_ = 2 |
|
|
2283 |
else: |
|
|
2284 |
c_high = self.c >= np.mean(self.c) + 2 * np.std(self.c) |
|
|
2285 |
c_high = c_high[non_nan] |
|
|
2286 |
if np.sum(c_high) < 10: |
|
|
2287 |
c_high = self.c >= np.mean(self.c) + np.std(self.c) |
|
|
2288 |
c_high = c_high[non_nan] |
|
|
2289 |
if np.sum(c_high) < 10: |
|
|
2290 |
c_high = self.c >= np.percentile(self.c, 90) |
|
|
2291 |
c_high = c_high[non_nan] |
|
|
2292 |
if np.sum(self.c[non_nan][c_high] == 0) > 0.5*np.sum(c_high): |
|
|
2293 |
self.low_quality = True |
|
|
2294 |
return |
|
|
2295 |
c_high_on = np.sum(c_high & on) |
|
|
2296 |
c_high_off = np.sum(c_high & off) |
|
|
2297 |
if c_high_on > c_high_off: |
|
|
2298 |
self.model_ = 1 |
|
|
2299 |
else: |
|
|
2300 |
self.model_ = 2 |
|
|
2301 |
if determine_model: |
|
|
2302 |
self.model = self.model_ |
|
|
2303 |
|
|
|
2304 |
if not self.known_pars: |
|
|
2305 |
if fit_gmm or fit_slope: |
|
|
2306 |
logg.update(f'predicted partial trajectory: {self.partial}', |
|
|
2307 |
v=1) |
|
|
2308 |
logg.update('predicted trajectory direction:' |
|
|
2309 |
f'{self.direction}', v=1) |
|
|
2310 |
if determine_model: |
|
|
2311 |
logg.update(f'predicted model: {self.model}', v=1) |
|
|
2312 |
|
|
|
2313 |
def initialize_steady_state_params(self, model_mismatch=False): |
|
|
2314 |
self.scale_cc = 1.0 |
|
|
2315 |
self.rescale_c = 1.0 |
|
|
2316 |
# estimate rescale factor for u |
|
|
2317 |
s_norm = self.s / self.max_s |
|
|
2318 |
u_mid = (self.u >= 0.4 * self.max_u) & (self.u <= 0.6 * self.max_u) |
|
|
2319 |
if np.sum(u_mid) < 10: |
|
|
2320 |
self.rescale_u = self.thickness / 5 |
|
|
2321 |
else: |
|
|
2322 |
s_low, s_high = np.percentile(s_norm[u_mid], [2, 98]) |
|
|
2323 |
s_dist = s_high - s_low |
|
|
2324 |
self.rescale_u = s_dist |
|
|
2325 |
if self.rescale_u == 0: |
|
|
2326 |
self.low_quality = True |
|
|
2327 |
return |
|
|
2328 |
|
|
|
2329 |
c = self.c / self.rescale_c |
|
|
2330 |
u = self.u / self.rescale_u |
|
|
2331 |
s = self.s |
|
|
2332 |
|
|
|
2333 |
# some extreme values |
|
|
2334 |
wu = u >= np.percentile(u, 97) |
|
|
2335 |
ws = s >= np.percentile(s, 97) |
|
|
2336 |
ss_u = u[wu | ws] |
|
|
2337 |
ss_s = s[wu | ws] |
|
|
2338 |
c_upper = np.mean(c[wu | ws]) |
|
|
2339 |
|
|
|
2340 |
c_high = c >= np.mean(c) |
|
|
2341 |
# _r stands for repressed state |
|
|
2342 |
c0_r = np.mean(c[c_high]) |
|
|
2343 |
u0_r = np.mean(ss_u) |
|
|
2344 |
s0_r = np.mean(ss_s) |
|
|
2345 |
if c0_r < c_upper: |
|
|
2346 |
c0_r = c_upper + 0.1 |
|
|
2347 |
|
|
|
2348 |
# adjust chromatin level for reasonable initialization |
|
|
2349 |
if model_mismatch or not self.fit_decoupling: |
|
|
2350 |
c_indu = np.mean(c[self.u > self.steady_state_func(self.s)]) |
|
|
2351 |
c_repr = np.mean(c[self.u < self.steady_state_func(self.s)]) |
|
|
2352 |
if c_indu == np.nan or c_repr == np.nan: |
|
|
2353 |
self.low_quality = True |
|
|
2354 |
return |
|
|
2355 |
c0_r = np.mean(c[c >= np.min([c_indu, c_repr])]) |
|
|
2356 |
|
|
|
2357 |
# initialize rates |
|
|
2358 |
self.alpha_c = 0.1 |
|
|
2359 |
self.beta = 1.0 |
|
|
2360 |
self.gamma = np.dot(ss_u, ss_s) / np.dot(ss_s, ss_s) |
|
|
2361 |
alpha = u0_r |
|
|
2362 |
self.alpha = u0_r |
|
|
2363 |
self.rates = np.array([self.alpha_c, self.alpha, self.beta, |
|
|
2364 |
self.gamma]) |
|
|
2365 |
|
|
|
2366 |
# RNA-only |
|
|
2367 |
if self.rna_only: |
|
|
2368 |
t_sw_1 = 0.1 |
|
|
2369 |
t_sw_3 = 20.0 |
|
|
2370 |
if self.init_mode == 'grid': |
|
|
2371 |
# arange returns sequence [2,6,10,14,18] |
|
|
2372 |
for t_sw_2 in np.arange(2, 20, 4, dtype=np.float64): |
|
|
2373 |
self.update(params, initialize=True, adjust_time=False, |
|
|
2374 |
plot=False) |
|
|
2375 |
|
|
|
2376 |
elif self.init_mode == 'simple': |
|
|
2377 |
t_sw_2 = 10 |
|
|
2378 |
self.params = np.array([t_sw_1, |
|
|
2379 |
t_sw_2-t_sw_1, |
|
|
2380 |
t_sw_3-t_sw_2, |
|
|
2381 |
self.alpha_c, |
|
|
2382 |
self.alpha, |
|
|
2383 |
self.beta, |
|
|
2384 |
self.gamma, |
|
|
2385 |
self.scale_cc, |
|
|
2386 |
self.rescale_c, |
|
|
2387 |
self.rescale_u]) |
|
|
2388 |
|
|
|
2389 |
elif self.init_mode == 'invert': |
|
|
2390 |
t_sw_2 = approx_tau(u0_r, s0_r, 0, 0, alpha, self.beta, |
|
|
2391 |
self.gamma) |
|
|
2392 |
if t_sw_2 <= 0.2: |
|
|
2393 |
t_sw_2 = 1.0 |
|
|
2394 |
elif t_sw_2 >= 19.9: |
|
|
2395 |
t_sw_2 = 19.0 |
|
|
2396 |
self.params = np.array([t_sw_1, |
|
|
2397 |
t_sw_2-t_sw_1, |
|
|
2398 |
t_sw_3-t_sw_2, |
|
|
2399 |
self.alpha_c, |
|
|
2400 |
self.alpha, |
|
|
2401 |
self.beta, |
|
|
2402 |
self.gamma, |
|
|
2403 |
self.scale_cc, |
|
|
2404 |
self.rescale_c, |
|
|
2405 |
self.rescale_u]) |
|
|
2406 |
|
|
|
2407 |
# chromatin-RNA |
|
|
2408 |
else: |
|
|
2409 |
if self.init_mode == 'grid': |
|
|
2410 |
# arange returns sequence [1,5,9,13,17] |
|
|
2411 |
for t_sw_1 in np.arange(1, 18, 4, dtype=np.float64): |
|
|
2412 |
# arange returns sequence 2,6,10,14,18 |
|
|
2413 |
for t_sw_2 in np.arange(t_sw_1+1, 19, 4, dtype=np.float64): |
|
|
2414 |
# arange returns sequence [3,7,11,15,19] |
|
|
2415 |
for t_sw_3 in np.arange(t_sw_2+1, 20, 4, |
|
|
2416 |
dtype=np.float64): |
|
|
2417 |
if not self.fit_decoupling: |
|
|
2418 |
t_sw_3 = t_sw_2 + 30 / self.n_anchors |
|
|
2419 |
params = np.array([t_sw_1, |
|
|
2420 |
t_sw_2-t_sw_1, |
|
|
2421 |
t_sw_3-t_sw_2, |
|
|
2422 |
self.alpha_c, |
|
|
2423 |
self.alpha, |
|
|
2424 |
self.beta, |
|
|
2425 |
self.gamma, |
|
|
2426 |
self.scale_cc, |
|
|
2427 |
self.rescale_c, |
|
|
2428 |
self.rescale_u]) |
|
|
2429 |
self.update(params, initialize=True, |
|
|
2430 |
adjust_time=False, plot=False) |
|
|
2431 |
if not self.fit_decoupling: |
|
|
2432 |
break |
|
|
2433 |
|
|
|
2434 |
elif self.init_mode == 'simple': |
|
|
2435 |
t_sw_1, t_sw_2, t_sw_3 = 5, 10, 15 \ |
|
|
2436 |
if not self.fit_decoupling \ |
|
|
2437 |
else 10.1 |
|
|
2438 |
self.params = np.array([t_sw_1, |
|
|
2439 |
t_sw_2-t_sw_1, |
|
|
2440 |
t_sw_3-t_sw_2, |
|
|
2441 |
self.alpha_c, |
|
|
2442 |
self.alpha, |
|
|
2443 |
self.beta, |
|
|
2444 |
self.gamma, |
|
|
2445 |
self.scale_cc, |
|
|
2446 |
self.rescale_c, |
|
|
2447 |
self.rescale_u]) |
|
|
2448 |
|
|
|
2449 |
elif self.init_mode == 'invert': |
|
|
2450 |
self.alpha = u0_r / c_upper |
|
|
2451 |
if model_mismatch or not self.fit_decoupling: |
|
|
2452 |
self.alpha = u0_r / c0_r |
|
|
2453 |
rna_interval = approx_tau(u0_r, s0_r, 0, 0, alpha, self.beta, |
|
|
2454 |
self.gamma) |
|
|
2455 |
rna_interval = np.clip(rna_interval, 3, 12) |
|
|
2456 |
if self.model == 1: |
|
|
2457 |
for t_sw_1 in np.arange(1, rna_interval-1, 2, |
|
|
2458 |
dtype=np.float64): |
|
|
2459 |
t_sw_3 = rna_interval + t_sw_1 |
|
|
2460 |
for t_sw_2 in np.arange(t_sw_1+1, rna_interval, 2, |
|
|
2461 |
dtype=np.float64): |
|
|
2462 |
if not self.fit_decoupling: |
|
|
2463 |
t_sw_2 = t_sw_3 - 30 / self.n_anchors |
|
|
2464 |
|
|
|
2465 |
alpha_c = -np.log(1 - c0_r) / t_sw_2 |
|
|
2466 |
params = np.array([t_sw_1, |
|
|
2467 |
t_sw_2-t_sw_1, |
|
|
2468 |
t_sw_3-t_sw_2, |
|
|
2469 |
alpha_c, |
|
|
2470 |
self.alpha, |
|
|
2471 |
self.beta, |
|
|
2472 |
self.gamma, |
|
|
2473 |
self.scale_cc, |
|
|
2474 |
self.rescale_c, |
|
|
2475 |
self.rescale_u]) |
|
|
2476 |
self.update(params, initialize=True, |
|
|
2477 |
adjust_time=False, plot=False) |
|
|
2478 |
if not self.fit_decoupling: |
|
|
2479 |
break |
|
|
2480 |
|
|
|
2481 |
elif self.model == 2: |
|
|
2482 |
for t_sw_1 in np.arange(1, rna_interval, 2, |
|
|
2483 |
dtype=np.float64): |
|
|
2484 |
t_sw_2 = rna_interval + t_sw_1 |
|
|
2485 |
for t_sw_3 in np.arange(t_sw_2+1, t_sw_2+6, 2, |
|
|
2486 |
dtype=np.float64): |
|
|
2487 |
if not self.fit_decoupling: |
|
|
2488 |
t_sw_3 = t_sw_2 + 30 / self.n_anchors |
|
|
2489 |
|
|
|
2490 |
alpha_c = -np.log(1 - c0_r) / t_sw_3 |
|
|
2491 |
params = np.array([t_sw_1, |
|
|
2492 |
t_sw_2-t_sw_1, |
|
|
2493 |
t_sw_3-t_sw_2, |
|
|
2494 |
alpha_c, |
|
|
2495 |
self.alpha, |
|
|
2496 |
self.beta, |
|
|
2497 |
self.gamma, |
|
|
2498 |
self.scale_cc, |
|
|
2499 |
self.rescale_c, |
|
|
2500 |
self.rescale_u]) |
|
|
2501 |
self.update(params, initialize=True, |
|
|
2502 |
adjust_time=False, plot=False) |
|
|
2503 |
if not self.fit_decoupling: |
|
|
2504 |
break |
|
|
2505 |
|
|
|
2506 |
self.loss = [self.mse(self.params)] |
|
|
2507 |
self.t_sw_array = np.array([self.params[0], |
|
|
2508 |
self.params[0]+self.params[1], |
|
|
2509 |
self.params[0]+self.params[1] |
|
|
2510 |
+ self.params[2]]) |
|
|
2511 |
self.t_sw_1, self.t_sw_2, self.t_sw_3 = self.t_sw_array |
|
|
2512 |
|
|
|
2513 |
logg.update(f'initial params:\nswitch time array = {self.t_sw_array},' |
|
|
2514 |
'\n' |
|
|
2515 |
f'rates = {self.rates},\ncc scale = {self.scale_cc},\n' |
|
|
2516 |
f'c rescale factor = {self.rescale_c},\n' |
|
|
2517 |
f'u rescale factor = {self.rescale_u}', v=1) |
|
|
2518 |
logg.update(f'initial loss: {self.loss[-1]}', v=1) |
|
|
2519 |
|
|
|
2520 |
def fit(self):
    """Fit the dynamical model and derive likelihood, velocity and anchors.

    Unless ``self.known_pars`` is set, the parameters are first optimized
    with ``self.fit_dyn``. The method then

    * removes long time gaps after the last active switch time (< 20) by
      stretching the switch times and dividing the rates by the same factor,
    * computes the likelihood of the fit on non-zero, non-outlier cells
      whose u and s exceed 20% of their 99.5th percentile (all likelihood
      terms are 0 when fewer than 10 such cells exist),
    * computes per-cell velocities, and
    * computes expression and velocity at the anchor time points.

    Returns
    -------
    self.loss
        The loss history. It is returned immediately, without fitting, when
        ``self.low_quality`` is set.
    """
    if self.low_quality:
        return self.loss

    if self.plot:
        plt.ion()
        self.fig = plt.figure(figsize=self.fig_size)
        if self.rna_only:
            self.ax = self.fig.add_subplot(111)
        else:
            self.ax = self.fig.add_subplot(111, projection='3d')

    if not self.known_pars:
        self.fit_dyn()

    self.update(self.params, perform_update=True, fit_outlier=True,
                plot=True)

    # remove long gaps in the last observed state
    t_sorted = np.sort(self.t)
    dt = np.diff(t_sorted, prepend=0)
    mean_dt = np.mean(dt)
    std_dt = np.std(dt)
    # a gap is "long" if it exceeds mean + 3 std of the spacing, but never
    # less than 3*20/n_anchors
    gap_thresh = np.clip(mean_dt+3*std_dt, 3*20/self.n_anchors, None)
    # Only switch times below 20 are active. Without one there is no last
    # observed state to realign, and np.max would fail on an empty array.
    active_sw = self.t_sw_array[self.t_sw_array < 20]
    if gap_thresh > 0 and active_sw.size > 0:
        idx = np.where(dt > gap_thresh)[0]
        gap_sum = 0
        last_t_sw = np.max(active_sw)
        for i in idx:
            t1 = t_sorted[i-1] if i > 0 else 0
            t2 = t_sorted[i]
            if t1 > last_t_sw and t2 <= 20:
                gap_sum += np.clip(t2 - t1 - mean_dt, 0, None)
        # every cell lies before the last switch: the whole tail is a gap
        if last_t_sw > np.max(self.t):
            gap_sum += 20 - last_t_sw
        # the ratio is capped so the last active switch time ends at most at 20
        realign_ratio = np.clip(20/(20 - gap_sum), None, 20/last_t_sw)
        logg.update(f'removing gaps and realigning by {realign_ratio}..',
                    v=1)
        # Stretch the switch times by the ratio and divide the rates by it,
        # so rate * time products stay unchanged.
        self.rates /= realign_ratio
        self.alpha_c, self.alpha, self.beta, self.gamma = self.rates
        self.params[:3] *= realign_ratio
        self.params[3:7] = self.rates
        self.t_sw_array = np.array([self.params[0],
                                    self.params[0]+self.params[1],
                                    self.params[0]+self.params[1]
                                    + self.params[2]])
        self.t_sw_1, self.t_sw_2, self.t_sw_3 = self.t_sw_array
        self.update(self.params, perform_update=True, fit_outlier=True,
                    plot=True)

    if self.plot:
        plt.ioff()
        plt.show(block=True)

    # likelihood
    logg.update('computing likelihood..', v=1)
    keep = self.non_zero & self.non_outlier & \
        (self.u_all > 0.2 * np.percentile(self.u_all, 99.5)) & \
        (self.s_all > 0.2 * np.percentile(self.s_all, 99.5))
    scale_factor = np.array([self.scale_c / self.std_c,
                             self.scale_u / self.std_u,
                             self.scale_s / self.std_s])
    if np.sum(keep) >= 10:
        self.likelihood, self.l_c, self.ssd_c, self.var_c, l_u, l_s = \
            compute_likelihood(self.c_all,
                               self.u_all,
                               self.s_all,
                               self.t_sw_array,
                               self.alpha_c,
                               self.alpha,
                               self.beta,
                               self.gamma,
                               self.rescale_c,
                               self.rescale_u,
                               self.t,
                               self.state,
                               scale_cc=self.scale_cc,
                               scale_factor=scale_factor,
                               model=self.model,
                               weight=keep,
                               rna_only=self.rna_only)
    else:
        # Too few informative cells: every likelihood term is 0. l_s is
        # still needed because the log message below reports it.
        self.likelihood, self.l_c, self.ssd_c, self.var_c, l_u, l_s = \
            0, 0, 0, 0, 0, 0

    if not self.rna_only:
        logg.update(f'likelihood of c: {self.l_c}, likelihood of u: {l_u},'
                    f' likelihood of s: {l_s}', v=1)

    # velocity
    logg.update('computing velocities..', v=1)
    self.velocity = np.empty((len(self.u_all), 3))
    if self.conn is not None:
        # Propagate latent time through the connectivity matrix (presumably
        # a neighbor average), cap it at 20, and reassign each cell's state
        # from the new time.
        new_time = self.conn.dot(self.t)
        new_time[new_time > 20] = 20
        new_state = self.state.copy()
        new_state[new_time <= self.t_sw_1] = 0
        new_state[(self.t_sw_1 < new_time) & (new_time <= self.t_sw_2)] = 1
        new_state[(self.t_sw_2 < new_time) & (new_time <= self.t_sw_3)] = 2
        new_state[self.t_sw_3 < new_time] = 3
    else:
        new_time = self.t
        new_state = self.state

    self.alpha_c, self.alpha, self.beta, self.gamma = \
        check_params(self.alpha_c, self.alpha, self.beta, self.gamma)
    vc, vu, vs = compute_velocity(new_time,
                                  self.t_sw_array,
                                  new_state,
                                  self.alpha_c,
                                  self.alpha,
                                  self.beta,
                                  self.gamma,
                                  self.rescale_c,
                                  self.rescale_u,
                                  scale_cc=self.scale_cc,
                                  model=self.model,
                                  rna_only=self.rna_only)

    self.velocity[:, 0] = vc * self.scale_c
    self.velocity[:, 1] = vu * self.scale_u
    self.velocity[:, 2] = vs * self.scale_s

    # anchor expression and velocity
    anchor_time, tau_list = anchor_points(self.t_sw_array, 20,
                                          self.n_anchors, return_time=True)
    switch = np.sum(self.t_sw_array < 20)
    # generate_exp takes a numba typed List
    typed_tau_list = List()
    for tau in tau_list:
        typed_tau_list.append(tau)
    self.alpha_c, self.alpha, self.beta, self.gamma, \
        self.c0, self.u0, self.s0 = \
        check_params(self.alpha_c, self.alpha, self.beta, self.gamma,
                     c0=self.c0, u0=self.u0, s0=self.s0)
    exp_list, exp_sw_list = generate_exp(typed_tau_list,
                                         self.t_sw_array[:switch],
                                         self.alpha_c,
                                         self.alpha,
                                         self.beta,
                                         self.gamma,
                                         scale_cc=self.scale_cc,
                                         model=self.model,
                                         rna_only=self.rna_only)
    rescale_factor = np.array([self.rescale_c, self.rescale_u, 1.0])
    exp_list = [x*rescale_factor for x in exp_list]
    exp_sw_list = [x*rescale_factor for x in exp_sw_list]
    # exp_list holds one array per phase (switch + 1 of them) and
    # exp_sw_list one per active switch; columns are (c, u, s)
    c = np.ravel(np.concatenate([exp_list[x][:, 0]
                                 for x in range(switch+1)]))
    u = np.ravel(np.concatenate([exp_list[x][:, 1]
                                 for x in range(switch+1)]))
    s = np.ravel(np.concatenate([exp_list[x][:, 2]
                                 for x in range(switch+1)]))
    c_sw = np.ravel(np.concatenate([exp_sw_list[x][:, 0]
                                    for x in range(switch)]))
    u_sw = np.ravel(np.concatenate([exp_sw_list[x][:, 1]
                                    for x in range(switch)]))
    s_sw = np.ravel(np.concatenate([exp_sw_list[x][:, 2]
                                    for x in range(switch)]))
    self.alpha_c, self.alpha, self.beta, self.gamma = \
        check_params(self.alpha_c, self.alpha, self.beta, self.gamma)
    vc, vu, vs = compute_velocity(anchor_time,
                                  self.t_sw_array,
                                  None,
                                  self.alpha_c,
                                  self.alpha,
                                  self.beta,
                                  self.gamma,
                                  self.rescale_c,
                                  self.rescale_u,
                                  scale_cc=self.scale_cc,
                                  model=self.model,
                                  rna_only=self.rna_only)

    # scale and shift back to original scale
    c_ = c * self.scale_c + self.offset_c
    u_ = u * self.scale_u + self.offset_u
    s_ = s * self.scale_s + self.offset_s
    c_sw_ = c_sw * self.scale_c + self.offset_c
    u_sw_ = u_sw * self.scale_u + self.offset_u
    s_sw_ = s_sw * self.scale_s + self.offset_s
    vc = vc * self.scale_c
    vu = vu * self.scale_u
    vs = vs * self.scale_s

    self.anchor_exp = np.empty((len(u_), 3))
    self.anchor_exp[:, 0], self.anchor_exp[:, 1], self.anchor_exp[:, 2] = \
        c_, u_, s_
    self.anchor_exp_sw = np.empty((len(u_sw_), 3))
    self.anchor_exp_sw[:, 0], self.anchor_exp_sw[:, 1], \
        self.anchor_exp_sw[:, 2] = c_sw_, u_sw_, s_sw_
    self.anchor_velo = np.empty((len(u_), 3))
    self.anchor_velo[:, 0] = vc
    self.anchor_velo[:, 1] = vu
    self.anchor_velo[:, 2] = vs
    # range of anchors covered by the observed (possibly smoothed) times
    self.anchor_velo_min_idx = np.sum(anchor_time < np.min(new_time))
    self.anchor_velo_max_idx = np.sum(anchor_time < np.max(new_time)) - 1

    if self.save_plot:
        logg.update('saving plots..', v=1)
        self.save_dyn_plot(c_, u_, s_, c_sw_, u_sw_, s_sw_, tau_list)

    self.realign_time_and_velocity(c, u, s, anchor_time)

    logg.update(f'final params:\nswitch time array = {self.t_sw_array},\n'
                f'rates = {self.rates},\ncc scale = {self.scale_cc},\n'
                f'c rescale factor = {self.rescale_c},\n'
                f'u rescale factor = {self.rescale_u}',
                v=1)
    logg.update(f'final loss: {self.loss[-1]}', v=1)
    logg.update(f'final likelihood: {self.likelihood}', v=1)

    return self.loss
|
|
2734 |
|
|
|
2735 |
# the adam algorithm
# NOTE: The starting point for this function was an example on the
# GeeksForGeeks website. The particular article is linked below:
# www.geeksforgeeks.org/how-to-implement-adam-gradient-descent-from-scratch-using-python/
def AdamMin(self, x, n_iter, tol, eps=1e-8):
    """Minimize ``self.mse_ten`` over the full parameter vector with Adam.

    Parameters
    ----------
    x : numpy.ndarray
        Starting parameter vector.
    n_iter : int
        Maximum number of Adam steps.
    tol : float
        Relative tolerance. Iteration stops early when a new best loss is
        found and all of the following hold: the relative loss change is
        below ``tol``, every parameter's last relative step is below
        ``tol``, and at least three improvements have been recorded.
    eps : float
        Numerical stabilizer in the Adam denominator.

    Returns
    -------
    The value returned by ``self.update`` for the best parameters found,
    or ``False`` if no evaluated point improved on ``self.loss[-1]``.
    On success, ``self.cur_loss`` is set to the best loss.
    """
    n = len(x)

    # Current point: a leaf tensor, so the loss can be differentiated w.r.t. it.
    x_ten = torch.tensor(x, requires_grad=True, device=self.device,
                         dtype=self.torch_type)

    # record lowest loss as a benchmark
    # (right now the lowest loss is the current loss)
    lowest_loss = torch.tensor(np.array(self.loss[-1], dtype=self.u.dtype),
                               device=self.device,
                               dtype=self.torch_type)

    # the parameters that produced the lowest loss
    lowest_x_ten = x_ten

    # Adam first/second moment estimates (plain state, no autograd needed)
    m = torch.zeros(n, device=self.device, dtype=self.torch_type)
    v = torch.zeros(n, device=self.device, dtype=self.torch_type)

    # The last step added to x. It starts at infinity so convergence cannot
    # be declared before a step has been taken.
    step = torch.full((n,), float("inf"), device=self.device,
                      dtype=self.torch_type)

    # how many times the new loss was lower than the lowest loss
    update_count = 0

    for t in range(n_iter):
        loss = self.mse_ten(x_ten)

        if loss < lowest_loss:
            lowest_x_ten = x_ten
            update_count += 1

            # Relative step size per parameter (by magnitude, so negative
            # parameters cannot pass trivially) and relative loss change.
            rel_step = torch.abs(step) / torch.abs(lowest_x_ten)
            rel_loss = torch.abs(loss - lowest_loss) / lowest_loss
            if torch.all(rel_step < tol) and rel_loss < tol and \
                    update_count >= 3:
                break

            lowest_loss = loss

        # gradient of the mse w.r.t. the current parameter values
        grad, = torch.autograd.grad(loss, x_ten)

        with torch.no_grad():
            m = (self.adam_beta1 * m) + ((1.0 - self.adam_beta1) * grad)
            v = (self.adam_beta2 * v) + ((1.0 - self.adam_beta2)
                                         * grad * grad)
            mhat = m / (1.0 - (self.adam_beta1**(t+1)))
            vhat = v / (1.0 - (self.adam_beta2**(t+1)))
            step = -(self.adam_lr * mhat) / (torch.sqrt(vhat) + eps)

        # the next point is a fresh leaf, so its gradient is well defined
        x_ten = (x_ten.detach() + step).requires_grad_(True)

    # as long as we've found at least one better x tensor...
    if update_count >= 1:
        # only true when the loop broke right after finding a new best
        if loss < lowest_loss:
            lowest_loss = loss

        self.cur_loss = lowest_loss.item()

        # use update() so the gene's parameters become the best ones found
        return self.update(lowest_x_ten.detach().cpu().numpy())

    # we never found a better x tensor
    return False
|
|
2836 |
|
|
|
2837 |
def fit_dyn(self):
    """Run the iterative parameter search until ``self.max_iter``.

    Each iteration runs a fixed sequence of short Nelder-Mead searches
    (``scipy.optimize.minimize``) over subsets of ``self.params``. For
    chromatin-RNA fits with ``self.adam`` set, it instead runs one Adam pass
    over all parameters, followed by a Nelder-Mead pass over the switch
    times. ``self.fitting_flag_`` together with ``len(x0)`` selects which
    parameters ``x0`` represents; see ``_variables``, which is presumably
    how ``self.mse`` decodes it. ``self.update`` is the per-step callback
    that records accepted parameters, so ``minimize``'s return values are
    not used.
    """
    def _nelder_mead(x0, maxiter):
        # one short Nelder-Mead search; progress is kept by the callback
        minimize(self.mse, x0=x0, method='Nelder-Mead', tol=1e-2,
                 callback=self.update, options={'maxiter': maxiter})

    while self.cur_iter < self.max_iter:
        self.cur_iter += 1

        # RNA-only
        if self.rna_only:
            logg.update('Nelder Mead on t_sw_2 and alpha..', v=2)
            self.fitting_flag_ = 0
            if self.cur_iter == 1:
                # first iteration: probe alpha at +/-5%, 10% and 20%
                var_test = (self.alpha +
                            np.array([-2, -1, -0.5, 0.5, 1, 2]) * 0.1
                            * self.alpha)
                new_params = self.params.copy()
                for var in var_test:
                    new_params[4] = var
                    self.update(new_params, adjust_time=False,
                                penalize_gap=False)
            _nelder_mead([self.params[1], self.params[4]], 3)

            if self.fit_rescale:
                logg.update('Nelder Mead on t_sw_2, beta, and rescale u..',
                            v=2)
                _nelder_mead([self.params[1],
                              self.params[5],
                              self.params[9]], 5)

            logg.update('Nelder Mead on alpha and gamma..', v=2)
            self.fitting_flag_ = 1
            _nelder_mead([self.params[4], self.params[6]], 3)

            logg.update('Nelder Mead on t_sw_2..', v=2)
            _nelder_mead([self.params[1]], 2)

            logg.update('Full Nelder Mead..', v=2)
            _nelder_mead([self.params[1], self.params[4],
                          self.params[5], self.params[6]], 5)

        # chromatin-RNA, Nelder-Mead only
        elif not self.adam:
            # Index in self.params of the chromatin and the RNA switch
            # interval; their order depends on the model (see _variables).
            # For any other model, the model-specific searches are skipped.
            if self.model == 0 or self.model == 1:
                chrom_idx, rna_idx = 1, 2
            elif self.model == 2:
                chrom_idx, rna_idx = 2, 1
            else:
                chrom_idx = rna_idx = None

            logg.update('Nelder Mead on t_sw_1, chromatin switch time, '
                        'and alpha_c..', v=2)
            self.fitting_flag_ = 1
            if self.cur_iter == 1:
                # first iteration: probe gamma at +/-5% and 10%
                var_test = (self.gamma + np.array([-1, -0.5, 0.5, 1])
                            * 0.1 * self.gamma)
                new_params = self.params.copy()
                for var in var_test:
                    new_params[6] = var
                    self.update(new_params, adjust_time=False)
            if chrom_idx is not None:
                _nelder_mead([self.params[0],
                              self.params[chrom_idx],
                              self.params[3]], 20)

            logg.update('Nelder Mead on chromatin switch time, '
                        'chromatin closing rate scaling, and rescale '
                        'c..', v=2)
            self.fitting_flag_ = 2
            if chrom_idx is not None:
                _nelder_mead([self.params[chrom_idx],
                              self.params[7],
                              self.params[8]], 20)

            logg.update('Nelder Mead on rna switch time and alpha..',
                        v=2)
            self.fitting_flag_ = 1
            if rna_idx is not None:
                _nelder_mead([self.params[rna_idx],
                              self.params[4]], 10)

            logg.update('Nelder Mead on rna switch time, beta, and '
                        'rescale u..', v=2)
            self.fitting_flag_ = 3
            if rna_idx is not None:
                _nelder_mead([self.params[rna_idx],
                              self.params[5],
                              self.params[9]], 20)

            logg.update('Nelder Mead on alpha and gamma..', v=2)
            self.fitting_flag_ = 2
            _nelder_mead([self.params[4], self.params[6]], 10)

            logg.update('Nelder Mead on t_sw..', v=2)
            self.fitting_flag_ = 4
            _nelder_mead(self.params[:3], 20)

        # chromatin-RNA with Adam
        else:
            logg.update('Adam on all parameters', v=2)
            self.AdamMin(np.array(self.params, dtype=self.u.dtype), 20,
                         tol=1e-2)

            logg.update('Nelder Mead on t_sw..', v=2)
            self.fitting_flag_ = 4
            _nelder_mead(self.params[:3], 15)

        logg.update(f'iteration {self.cur_iter} finished', v=2)
|
|
2997 |
|
|
|
2998 |
def _variables(self, x):
    """Expand an optimizer vector ``x`` into the full set of model variables.

    The optimizers vary only a subset of the parameters at a time. That
    subset is identified by ``len(x)`` together with ``self.fitting_flag_``
    (and, for chromatin-RNA fits, ``self.model``). Every variable not in
    ``x`` is taken from the current state on ``self``.

    Parameters
    ----------
    x : array-like
        Values proposed by the optimizer.

    Returns
    -------
    tuple or None
        ``(t3, r4, scale_cc, rescale_c, rescale_u)``. ``t3`` is the first
        switch time followed by the two following switch intervals, and
        ``r4`` is ``[alpha_c, alpha, beta, gamma]``. Unless
        ``self.known_pars`` is set, the values are clipped to meaningful
        ranges. ``None`` is returned when the combination of ``len(x)``,
        ``fitting_flag_`` and ``model`` is not supported.
    """
    scale_cc = self.scale_cc
    rescale_c = self.rescale_c
    rescale_u = self.rescale_u
    n_x = len(x)

    # RNA-only
    if self.rna_only:
        if n_x == 1:  # fit t_sw_2
            t3 = np.array([self.t_sw_1, x[0],
                           self.t_sw_3 - self.t_sw_1 - x[0]])
            r4 = self.rates

        elif n_x == 2:
            if self.fitting_flag_:  # fit alpha and gamma
                t3 = self.params[:3]
                r4 = np.array([self.alpha_c, x[0], self.beta, x[1]])
            else:  # fit t_sw_2 and alpha
                t3 = np.array([self.t_sw_1, x[0],
                               self.t_sw_3 - self.t_sw_1 - x[0]])
                r4 = np.array([self.alpha_c, x[1], self.beta, self.gamma])

        elif n_x == 3:  # fit t_sw_2, beta, and rescale u
            t3 = np.array([self.t_sw_1,
                           x[0], self.t_sw_3 - self.t_sw_1 - x[0]])
            r4 = np.array([self.alpha_c, self.alpha, x[1], self.gamma])
            rescale_u = x[2]

        elif n_x == 4:  # fit t_sw_2, alpha, beta, and gamma
            t3 = np.array([self.t_sw_1, x[0], self.t_sw_3 - self.t_sw_1
                           - x[0]])
            r4 = np.array([self.alpha_c, x[1], x[2], x[3]])

        elif n_x == 10:  # all available
            t3 = x[:3]
            r4 = x[3:7]
            scale_cc = x[7]
            rescale_c = x[8]
            rescale_u = x[9]

        else:
            return None

    # chromatin-RNA
    else:
        # The order of the chromatin and RNA switch intervals in t3
        # depends on the model. Any other model is unsupported.
        model_01 = self.model == 0 or self.model == 1
        model_2 = self.model == 2
        flag = self.fitting_flag_

        if n_x == 2 and flag == 1:  # fit rna switch time and alpha
            if model_01:
                t3 = np.array([self.t_sw_1, self.params[1], x[0]])
            elif model_2:
                t3 = np.array([self.t_sw_1, x[0],
                               self.t_sw_3 - self.t_sw_1 - x[0]])
            else:
                return None
            r4 = np.array([self.alpha_c, x[1], self.beta, self.gamma])

        elif n_x == 2 and flag == 2:  # fit alpha and gamma
            t3 = self.params[:3]
            r4 = np.array([self.alpha_c, x[0], self.beta, x[1]])

        elif n_x == 3 and flag == 1:
            # fit t_sw_1, chromatin switch time, and alpha_c
            if model_01:
                t3 = np.array([x[0], x[1], self.t_sw_3 - x[0] - x[1]])
            elif model_2:
                t3 = np.array([x[0], self.t_sw_2 - x[0], x[1]])
            else:
                return None
            r4 = np.array([x[2], self.alpha, self.beta, self.gamma])

        elif n_x == 3 and flag == 2:
            # fit chromatin switch time, chromatin closing rate scaling,
            # and rescale c
            if model_01:
                t3 = np.array([self.t_sw_1, x[0],
                               self.t_sw_3 - self.t_sw_1 - x[0]])
            elif model_2:
                t3 = np.array([self.t_sw_1, self.params[1], x[0]])
            else:
                return None
            r4 = self.rates
            scale_cc = x[1]
            rescale_c = x[2]

        elif n_x == 3 and flag == 3:
            # fit rna switch time, beta, and rescale u
            if model_01:
                t3 = np.array([self.t_sw_1, self.params[1], x[0]])
            elif model_2:
                t3 = np.array([self.t_sw_1, x[0],
                               self.t_sw_3 - self.t_sw_1 - x[0]])
            else:
                return None
            r4 = np.array([self.alpha_c, self.alpha, x[1], self.gamma])
            rescale_u = x[2]

        elif n_x == 3 and flag == 4:  # fit three switch times
            t3 = x
            r4 = self.rates

        elif n_x == 7:  # switch times and rates
            t3 = x[:3]
            r4 = x[3:]

        elif n_x == 10:  # all available
            t3 = x[:3]
            r4 = x[3:7]
            scale_cc = x[7]
            rescale_c = x[8]
            rescale_u = x[9]

        else:
            return None

    # clip to meaningful values
    if self.fitting_flag_ and not self.adam:
        scale_cc = np.clip(scale_cc,
                           np.max([0.5*self.scale_cc, 0.25]),
                           np.min([2*self.scale_cc, 4]))

    if not self.known_pars:
        if self.fit_decoupling:
            t3 = np.clip(t3, 0.1, None)
        else:
            # NOTE(review): t3 may be a view of x or self.params here
            # (t3 = x, x[:3] or self.params[:3]), so these item assignments
            # write through to the caller's array. Confirm that this side
            # effect is intended before changing it.
            t3[2] = 30 / self.n_anchors
            t3[:2] = np.clip(t3[:2], 0.1, None)
        r4 = np.clip(r4, 0.001, 1000)
        rescale_c = np.clip(rescale_c, 0.75, 1.5)
        rescale_u = np.clip(rescale_u, 0.2, 3)

    return t3, r4, scale_cc, rescale_c, rescale_u
|
|
3119 |
|
|
|
3120 |
# the tensor version of the calculate_dist_and_time function |
|
|
3121 |
def calculate_dist_and_time_ten(self, |
|
|
3122 |
c, u, s, |
|
|
3123 |
t_sw_array, |
|
|
3124 |
alpha_c, alpha, beta, gamma, |
|
|
3125 |
rescale_c, rescale_u, |
|
|
3126 |
scale_cc=1, |
|
|
3127 |
scale_factor=None, |
|
|
3128 |
model=1, |
|
|
3129 |
conn=None, |
|
|
3130 |
t=1000, k=1, |
|
|
3131 |
direction='complete', |
|
|
3132 |
total_h=20, |
|
|
3133 |
rna_only=False, |
|
|
3134 |
penalize_gap=True, |
|
|
3135 |
all_cells=True): |
|
|
3136 |
|
|
|
3137 |
conn = torch.tensor(conn.todense(), |
|
|
3138 |
device=self.device, |
|
|
3139 |
dtype=self.torch_type) |
|
|
3140 |
|
|
|
3141 |
c_ten = torch.tensor(c, device=self.device, dtype=self.torch_type) |
|
|
3142 |
u_ten = torch.tensor(u, device=self.device, dtype=self.torch_type) |
|
|
3143 |
s_ten = torch.tensor(s, device=self.device, dtype=self.torch_type) |
|
|
3144 |
|
|
|
3145 |
n = len(u) |
|
|
3146 |
if scale_factor is None: |
|
|
3147 |
scale_factor_ten = torch.stack((torch.std(c_ten), torch.std(u_ten), |
|
|
3148 |
torch.std(s_ten))) |
|
|
3149 |
else: |
|
|
3150 |
scale_factor_ten = torch.tensor(scale_factor, device=self.device, |
|
|
3151 |
dtype=self.torch_type) |
|
|
3152 |
|
|
|
3153 |
tau_list = self.anchor_points_ten(t_sw_array, total_h, t) |
|
|
3154 |
|
|
|
3155 |
switch = torch.sum(t_sw_array < total_h) |
|
|
3156 |
|
|
|
3157 |
exp_list, exp_sw_list = self.generate_exp_tens(tau_list, |
|
|
3158 |
t_sw_array[:switch], |
|
|
3159 |
alpha_c, |
|
|
3160 |
alpha, |
|
|
3161 |
beta, |
|
|
3162 |
gamma, |
|
|
3163 |
model=model, |
|
|
3164 |
scale_cc=scale_cc, |
|
|
3165 |
rna_only=rna_only) |
|
|
3166 |
|
|
|
3167 |
rescale_factor = torch.stack((rescale_c, rescale_u, |
|
|
3168 |
torch.tensor(1.0, device=self.device, |
|
|
3169 |
requires_grad=True, |
|
|
3170 |
dtype=self.torch_type))) |
|
|
3171 |
|
|
|
3172 |
for i in range(len(exp_list)): |
|
|
3173 |
exp_list[i] = exp_list[i]*rescale_factor |
|
|
3174 |
|
|
|
3175 |
if i < len(exp_list)-1: |
|
|
3176 |
exp_sw_list[i] = exp_sw_list[i]*rescale_factor |
|
|
3177 |
|
|
|
3178 |
max_c = 0 |
|
|
3179 |
max_u = 0 |
|
|
3180 |
max_s = 0 |
|
|
3181 |
|
|
|
3182 |
if rna_only: |
|
|
3183 |
exp_mat = (torch.hstack((torch.reshape(u_ten, (-1, 1)), |
|
|
3184 |
torch.reshape(s_ten, (-1, 1)))) |
|
|
3185 |
/ scale_factor_ten[1:]) |
|
|
3186 |
else: |
|
|
3187 |
exp_mat = torch.hstack((torch.reshape(c_ten, (-1, 1)), |
|
|
3188 |
torch.reshape(u_ten, (-1, 1)), |
|
|
3189 |
torch.reshape(s_ten, (-1, 1))))\ |
|
|
3190 |
/ scale_factor_ten |
|
|
3191 |
|
|
|
3192 |
taus = torch.zeros((1, n), device=self.device, |
|
|
3193 |
requires_grad=True, |
|
|
3194 |
dtype=self.torch_type) |
|
|
3195 |
anchor_exp, anchor_t = None, None |
|
|
3196 |
|
|
|
3197 |
dists0 = torch.full((1, n), 0.0 if direction == "on" |
|
|
3198 |
or direction == "complete" else np.inf, |
|
|
3199 |
device=self.device, |
|
|
3200 |
requires_grad=True, |
|
|
3201 |
dtype=self.torch_type) |
|
|
3202 |
dists1 = torch.full((1, n), 0.0 if direction == "on" |
|
|
3203 |
or direction == "complete" else np.inf, |
|
|
3204 |
device=self.device, |
|
|
3205 |
requires_grad=True, |
|
|
3206 |
dtype=self.torch_type) |
|
|
3207 |
dists2 = torch.full((1, n), 0.0 if direction == "off" |
|
|
3208 |
or direction == "complete" else np.inf, |
|
|
3209 |
device=self.device, |
|
|
3210 |
requires_grad=True, |
|
|
3211 |
dtype=self.torch_type) |
|
|
3212 |
dists3 = torch.full((1, n), 0.0 if direction == "off" |
|
|
3213 |
or direction == "complete" else np.inf, |
|
|
3214 |
device=self.device, |
|
|
3215 |
requires_grad=True, |
|
|
3216 |
dtype=self.torch_type) |
|
|
3217 |
|
|
|
3218 |
ts0 = torch.zeros((1, n), device=self.device, |
|
|
3219 |
requires_grad=True, |
|
|
3220 |
dtype=self.torch_type) |
|
|
3221 |
ts1 = torch.zeros((1, n), device=self.device, |
|
|
3222 |
requires_grad=True, |
|
|
3223 |
dtype=self.torch_type) |
|
|
3224 |
ts2 = torch.zeros((1, n), device=self.device, |
|
|
3225 |
requires_grad=True, |
|
|
3226 |
dtype=self.torch_type) |
|
|
3227 |
ts3 = torch.zeros((1, n), device=self.device, |
|
|
3228 |
requires_grad=True, |
|
|
3229 |
dtype=self.torch_type) |
|
|
3230 |
|
|
|
3231 |
for i in range(switch+1): |
|
|
3232 |
|
|
|
3233 |
if not all_cells: |
|
|
3234 |
max_ci = (torch.max(exp_list[i][:, 0]) |
|
|
3235 |
if exp_list[i].shape[0] > 0 |
|
|
3236 |
else 0) |
|
|
3237 |
max_c = max_ci if max_ci > max_c else max_c |
|
|
3238 |
max_ui = torch.max(exp_list[i][:, 1]) if exp_list[i].shape[0] > 0 \ |
|
|
3239 |
else 0 |
|
|
3240 |
max_u = max_ui if max_ui > max_u else max_u |
|
|
3241 |
max_si = torch.max(exp_list[i][:, 2]) if exp_list[i].shape[0] > 0 \ |
|
|
3242 |
else 0 |
|
|
3243 |
max_s = max_si if max_si > max_s else max_s |
|
|
3244 |
|
|
|
3245 |
skip_phase = False |
|
|
3246 |
if direction == 'off': |
|
|
3247 |
if (model in [1, 2]) and (i < 2): |
|
|
3248 |
skip_phase = True |
|
|
3249 |
elif direction == 'on': |
|
|
3250 |
if (model in [1, 2]) and (i >= 2): |
|
|
3251 |
skip_phase = True |
|
|
3252 |
if rna_only and i == 0: |
|
|
3253 |
skip_phase = True |
|
|
3254 |
|
|
|
3255 |
if not skip_phase: |
|
|
3256 |
if rna_only: |
|
|
3257 |
tmp = exp_list[i][:, 1:] / scale_factor_ten[1:] |
|
|
3258 |
else: |
|
|
3259 |
tmp = exp_list[i] / scale_factor_ten |
|
|
3260 |
if anchor_exp is None: |
|
|
3261 |
anchor_exp = exp_list[i] |
|
|
3262 |
anchor_t = (tau_list[i] + t_sw_array[i-1] if i >= 1 |
|
|
3263 |
else tau_list[i]) |
|
|
3264 |
else: |
|
|
3265 |
anchor_exp = torch.vstack((anchor_exp, exp_list[i])) |
|
|
3266 |
anchor_t = torch.hstack((anchor_t, |
|
|
3267 |
tau_list[i] + t_sw_array[i-1] |
|
|
3268 |
if i >= 1 else tau_list[i])) |
|
|
3269 |
|
|
|
3270 |
if not all_cells: |
|
|
3271 |
anchor_prepend_rna = torch.zeros((1, 2), |
|
|
3272 |
device=self.device, |
|
|
3273 |
dtype=self.torch_type) |
|
|
3274 |
anchor_prepend_chrom = torch.zeros((1, 3), |
|
|
3275 |
device=self.device, |
|
|
3276 |
dtype=self.torch_type) |
|
|
3277 |
anchor_dist = torch.diff(tmp, dim=0, |
|
|
3278 |
prepend=anchor_prepend_rna |
|
|
3279 |
if rna_only |
|
|
3280 |
else anchor_prepend_chrom) |
|
|
3281 |
|
|
|
3282 |
anchor_dist = torch.sqrt((anchor_dist*anchor_dist) |
|
|
3283 |
.sum(axis=1)) |
|
|
3284 |
remove_cand = anchor_dist < (0.01*torch.max(exp_mat[1]) |
|
|
3285 |
if rna_only |
|
|
3286 |
else |
|
|
3287 |
0.01*torch.max(exp_mat[2])) |
|
|
3288 |
step_idx = torch.arange(0, anchor_dist.size()[0], 1, |
|
|
3289 |
device=self.device, |
|
|
3290 |
dtype=self.torch_type) % 3 > 0 |
|
|
3291 |
remove_cand &= step_idx |
|
|
3292 |
keep_idx = torch.where(~remove_cand)[0] |
|
|
3293 |
|
|
|
3294 |
tmp = tmp[keep_idx, :] |
|
|
3295 |
|
|
|
3296 |
model = NearestNeighbors(n_neighbors=k, output_type="numpy") |
|
|
3297 |
model.fit(tmp.detach()) |
|
|
3298 |
dd, ii = model.kneighbors(exp_mat.detach()) |
|
|
3299 |
ii = ii.T[0] |
|
|
3300 |
|
|
|
3301 |
new_dd = ((exp_mat[:, 0] - tmp[ii, 0]) |
|
|
3302 |
* (exp_mat[:, 0] - tmp[ii, 0]) |
|
|
3303 |
+ (exp_mat[:, 1] - tmp[ii, 1]) |
|
|
3304 |
* (exp_mat[:, 1] - tmp[ii, 1]) |
|
|
3305 |
+ (exp_mat[:, 2] - tmp[ii, 2]) |
|
|
3306 |
* (exp_mat[:, 2] - tmp[ii, 2])) |
|
|
3307 |
|
|
|
3308 |
if k > 1: |
|
|
3309 |
new_dd = torch.mean(new_dd, dim=1) |
|
|
3310 |
if conn is not None: |
|
|
3311 |
new_dd = torch.matmul(conn, new_dd) |
|
|
3312 |
|
|
|
3313 |
if i == 0: |
|
|
3314 |
dists0 = dists0 + new_dd |
|
|
3315 |
elif i == 1: |
|
|
3316 |
dists1 = dists1 + new_dd |
|
|
3317 |
elif i == 2: |
|
|
3318 |
dists2 = dists2 + new_dd |
|
|
3319 |
elif i == 3: |
|
|
3320 |
dists3 = dists3 + new_dd |
|
|
3321 |
|
|
|
3322 |
if not all_cells: |
|
|
3323 |
ii = keep_idx[ii] |
|
|
3324 |
if k == 1: |
|
|
3325 |
taus = tau_list[i][ii] |
|
|
3326 |
else: |
|
|
3327 |
for j in range(n): |
|
|
3328 |
taus[j] = tau_list[i][ii[j, :]] |
|
|
3329 |
|
|
|
3330 |
if i == 0: |
|
|
3331 |
ts0 = ts0 + taus |
|
|
3332 |
elif i == 1: |
|
|
3333 |
ts1 = ts1 + taus + t_sw_array[0] |
|
|
3334 |
elif i == 2: |
|
|
3335 |
ts2 = ts2 + taus + t_sw_array[1] |
|
|
3336 |
elif i == 3: |
|
|
3337 |
ts3 = ts3 + taus + t_sw_array[2] |
|
|
3338 |
|
|
|
3339 |
dists = torch.cat((dists0, dists1, dists2, dists3), 0) |
|
|
3340 |
|
|
|
3341 |
ts = torch.cat((ts0, ts1, ts2, ts3), 0) |
|
|
3342 |
|
|
|
3343 |
state_pred = torch.argmin(dists, axis=0) |
|
|
3344 |
|
|
|
3345 |
t_pred = ts[state_pred, torch.arange(n, device=self.device)] |
|
|
3346 |
|
|
|
3347 |
anchor_t1_list = [] |
|
|
3348 |
anchor_t2_list = [] |
|
|
3349 |
|
|
|
3350 |
t_sw_adjust = torch.zeros(3, device=self.device, dtype=self.torch_type) |
|
|
3351 |
|
|
|
3352 |
if direction == 'complete': |
|
|
3353 |
|
|
|
3354 |
dist_gap_add = torch.zeros((1, n), device=self.device, |
|
|
3355 |
dtype=self.torch_type) |
|
|
3356 |
|
|
|
3357 |
t_sorted = torch.clone(t_pred) |
|
|
3358 |
t_sorted, t_sorted_indices = torch.sort(t_sorted) |
|
|
3359 |
|
|
|
3360 |
dt = torch.diff(t_sorted, dim=0, |
|
|
3361 |
prepend=torch.zeros(1, device=self.device, |
|
|
3362 |
dtype=self.torch_type)) |
|
|
3363 |
|
|
|
3364 |
gap_thresh = 3*torch.quantile(dt, 0.99) |
|
|
3365 |
|
|
|
3366 |
idx = torch.where(dt > gap_thresh)[0] |
|
|
3367 |
|
|
|
3368 |
if len(idx) > 0 and penalize_gap: |
|
|
3369 |
h_tens = torch.tensor([total_h], device=self.device, |
|
|
3370 |
dtype=self.torch_type) |
|
|
3371 |
|
|
|
3372 |
for i in idx: |
|
|
3373 |
|
|
|
3374 |
t1 = t_sorted[i-1] if i > 0 else 0 |
|
|
3375 |
t2 = t_sorted[i] |
|
|
3376 |
anchor_t1 = anchor_exp[torch.argmin(torch.abs(anchor_t - t1)), |
|
|
3377 |
:] |
|
|
3378 |
anchor_t2 = anchor_exp[torch.argmin(torch.abs(anchor_t - t2)), |
|
|
3379 |
:] |
|
|
3380 |
if all_cells: |
|
|
3381 |
anchor_t1_list.append(torch.ravel(anchor_t1)) |
|
|
3382 |
anchor_t2_list.append(torch.ravel(anchor_t2)) |
|
|
3383 |
if not all_cells: |
|
|
3384 |
for j in range(1, switch): |
|
|
3385 |
crit1 = ((t1 > t_sw_array[j-1]) |
|
|
3386 |
and (t2 > t_sw_array[j-1]) |
|
|
3387 |
and (t1 <= t_sw_array[j]) |
|
|
3388 |
and (t2 <= t_sw_array[j])) |
|
|
3389 |
crit2 = ((torch.abs(anchor_t1[2] |
|
|
3390 |
- exp_sw_list[j][0, 2]) |
|
|
3391 |
< 0.02 * max_s) and |
|
|
3392 |
(torch.abs(anchor_t2[2] |
|
|
3393 |
- exp_sw_list[j][0, 2]) |
|
|
3394 |
< 0.01 * max_s)) |
|
|
3395 |
crit3 = ((torch.abs(anchor_t1[1] |
|
|
3396 |
- exp_sw_list[j][0, 1]) |
|
|
3397 |
< 0.02 * max_u) and |
|
|
3398 |
(torch.abs(anchor_t2[1] |
|
|
3399 |
- exp_sw_list[j][0, 1]) |
|
|
3400 |
< 0.01 * max_u)) |
|
|
3401 |
crit4 = ((torch.abs(anchor_t1[0] |
|
|
3402 |
- exp_sw_list[j][0, 0]) |
|
|
3403 |
< 0.02 * max_c) and |
|
|
3404 |
(torch.abs(anchor_t2[0] |
|
|
3405 |
- exp_sw_list[j][0, 0]) |
|
|
3406 |
< 0.01 * max_c)) |
|
|
3407 |
if crit1 and crit2 and crit3 and crit4: |
|
|
3408 |
t_sw_adjust[j] += t2 - t1 |
|
|
3409 |
if penalize_gap: |
|
|
3410 |
dist_gap = torch.sum(((anchor_t1[1:] - anchor_t2[1:]) / |
|
|
3411 |
scale_factor_ten[1:])**2) |
|
|
3412 |
|
|
|
3413 |
idx_to_adjust = torch.tensor(t_pred >= t2, |
|
|
3414 |
device=self.device) |
|
|
3415 |
|
|
|
3416 |
idx_to_adjust = torch.reshape(idx_to_adjust, |
|
|
3417 |
(1, idx_to_adjust.size()[0])) |
|
|
3418 |
|
|
|
3419 |
true_tensor = torch.tensor([True], device=self.device) |
|
|
3420 |
false_tensor = torch.tensor([False], device=self.device) |
|
|
3421 |
|
|
|
3422 |
t_sw_array_ = torch.cat((t_sw_array, h_tens), dim=0) |
|
|
3423 |
state_to_adjust = torch.where(t_sw_array_ > t2, |
|
|
3424 |
true_tensor, false_tensor) |
|
|
3425 |
|
|
|
3426 |
dist_gap_add[idx_to_adjust] += dist_gap |
|
|
3427 |
|
|
|
3428 |
if state_to_adjust[0].item(): |
|
|
3429 |
dists0 += dist_gap_add |
|
|
3430 |
if state_to_adjust[1].item(): |
|
|
3431 |
dists1 += dist_gap_add |
|
|
3432 |
if state_to_adjust[2].item(): |
|
|
3433 |
dists2 += dist_gap_add |
|
|
3434 |
if state_to_adjust[3].item(): |
|
|
3435 |
dists3 += dist_gap_add |
|
|
3436 |
|
|
|
3437 |
dist_gap_add[idx_to_adjust] -= dist_gap |
|
|
3438 |
|
|
|
3439 |
dists = torch.cat((dists0, dists1, dists2, dists3), 0) |
|
|
3440 |
|
|
|
3441 |
state_pred = torch.argmin(dists, dim=0) |
|
|
3442 |
|
|
|
3443 |
if all_cells: |
|
|
3444 |
t_pred = ts[torch.arange(n, device=self.device), state_pred] |
|
|
3445 |
|
|
|
3446 |
min_dist = torch.min(dists, dim=0).values |
|
|
3447 |
|
|
|
3448 |
if all_cells: |
|
|
3449 |
exp_ss_mat = compute_ss_exp(alpha_c, alpha, beta, gamma, |
|
|
3450 |
model=model) |
|
|
3451 |
if rna_only: |
|
|
3452 |
exp_ss_mat[:, 0] = 1 |
|
|
3453 |
dists_ss = pairwise_distance_square(exp_mat, exp_ss_mat * |
|
|
3454 |
rescale_factor / scale_factor) |
|
|
3455 |
|
|
|
3456 |
reach_ss = np.full((n, 4), False) |
|
|
3457 |
for i in range(n): |
|
|
3458 |
for j in range(4): |
|
|
3459 |
if min_dist[i] > dists_ss[i, j]: |
|
|
3460 |
reach_ss[i, j] = True |
|
|
3461 |
late_phase = np.full(n, -1) |
|
|
3462 |
for i in range(3): |
|
|
3463 |
late_phase[torch.abs(t_pred - t_sw_array[i]) < 0.1] = i |
|
|
3464 |
|
|
|
3465 |
return min_dist, t_pred, state_pred.cpu().detach().numpy(), \ |
|
|
3466 |
reach_ss, late_phase, max_u, max_s, anchor_t1_list, \ |
|
|
3467 |
anchor_t2_list |
|
|
3468 |
|
|
|
3469 |
else: |
|
|
3470 |
return min_dist, state_pred.cpu().detach().numpy(), max_u, max_s, \ |
|
|
3471 |
t_sw_adjust.cpu().detach().numpy() |
|
|
3472 |
|
|
|
3473 |
# the torch tensor version of the mse function |
|
|
3474 |
def mse_ten(self, x, fit_outlier=False,
            penalize_gap=True):
    """Torch-tensor version of ``mse``: fitting loss of parameter vector ``x``.

    ``x`` is read as: ``x[:3]`` switch-time increments, ``x[3:7]`` the rates
    (alpha_c, alpha, beta, gamma), ``x[7]`` scale_cc, ``x[8]`` the c rescale
    factor and ``x[9]`` the u rescale factor. Unless ``self.known_pars`` is
    set, these are clipped to their allowed ranges before use. ``x`` itself
    is never modified.

    Parameters
    ----------
    x : torch.Tensor
        Parameter vector laid out as above.
    fit_outlier : bool
        If True, evaluate on all cells (``c_all``/``u_all``/``s_all`` and
        ``self.conn``) and also return the predicted cell times; otherwise
        use ``c``/``u``/``s`` and ``self.conn_sub``.
    penalize_gap : bool
        Forwarded to ``calculate_dist_and_time_ten``.

    Returns
    -------
    loss : torch.Tensor
        Mean of the per-cell minimum distances, plus a penalty for the
        returned ``max_s``/``max_u`` exceeding ``self.max_s``/``self.max_u``,
        plus 0.1 times the parameter-range penalty.
    t_pred
        Only when ``fit_outlier`` is True: per-cell predicted times.

    Side effects: sets ``self.cur_loss`` and ``self.cur_state_pred``; sets
    ``self.anchor_t1_list``/``self.anchor_t2_list`` when ``fit_outlier`` is
    True, otherwise ``self.cur_t_sw_adjust``.
    """
    t3 = x[:3]
    r4 = x[3:7]
    scale_cc = x[7]
    rescale_c = x[8]
    rescale_u = x[9]

    if not self.known_pars:
        if self.fit_decoupling:
            t3 = torch.clip(t3, 0.1, None)
        else:
            # The third increment is fixed. Build a new tensor instead of
            # assigning into ``t3`` (a view of ``x``): in-place assignment
            # would overwrite the caller's parameters and is rejected by
            # autograd for leaf tensors that require grad.
            t3_fixed = torch.full((1,), 30 / self.n_anchors,
                                  dtype=x.dtype, device=x.device)
            t3 = torch.cat((torch.clip(t3[:2], 0.1, None), t3_fixed))
        r4 = torch.clip(r4, 0.001, 1000)
        rescale_c = torch.clip(rescale_c, 0.75, 1.5)
        rescale_u = torch.clip(rescale_u, 0.2, 3)

    t_sw_array = torch.cumsum(t3, dim=0)

    if self.rna_only:
        t_sw_array[2] = 20

    # conditions for minimum switch time and rate params
    penalty = 0
    if torch.any(t3 < 0.2) or torch.any(r4 < 0.005):
        if self.fit_decoupling:
            penalty = torch.sum(0.2 - t3[t3 < 0.2])
        else:
            penalty = torch.sum(0.2 - t3[:2][t3[:2] < 0.2])
        penalty = penalty + torch.sum(0.005 - r4[r4 < 0.005]) * 1e2

    # condition for all params; added to (not replacing) the penalty above
    if torch.any(x > 500):
        penalty = penalty + torch.sum(x[x > 500] - 500) * 1e-2

    c_array = self.c_all if fit_outlier else self.c
    u_array = self.u_all if fit_outlier else self.u
    s_array = self.s_all if fit_outlier else self.s
    conn_for_calc = self.conn if fit_outlier else self.conn_sub

    if self.batch_size is not None and self.batch_size < len(c_array):
        # random mini-batch of cells; restrict the connectivity matrix to
        # the same rows and columns
        subset_choice = np.random.choice(len(c_array), self.batch_size,
                                         replace=False)

        c_array = c_array[subset_choice]
        u_array = u_array[subset_choice]
        s_array = s_array[subset_choice]

        conn_for_calc = ((conn_for_calc[subset_choice]).T[subset_choice]).T

    scale_factor_func = np.array(self.scale_factor, dtype=self.u.dtype)

    # distances and time assignments
    res = self.calculate_dist_and_time_ten(c_array,
                                           u_array,
                                           s_array,
                                           t_sw_array,
                                           r4[0],
                                           r4[1],
                                           r4[2],
                                           r4[3],
                                           rescale_c,
                                           rescale_u,
                                           scale_cc=scale_cc,
                                           scale_factor=scale_factor_func,
                                           model=self.model,
                                           direction=self.direction,
                                           conn=conn_for_calc,
                                           k=self.k_dist,
                                           t=self.n_anchors,
                                           rna_only=self.rna_only,
                                           penalize_gap=penalize_gap,
                                           all_cells=fit_outlier)

    if fit_outlier:
        (min_dist, t_pred, state_pred, _reach_ss, _late_phase, max_u, max_s,
         self.anchor_t1_list, self.anchor_t2_list) = res
    else:
        min_dist, state_pred, max_u, max_s, t_sw_adjust = res

    loss = torch.mean(min_dist)

    # avoid exceeding maximum expressions
    reg = torch.max(torch.tensor([0, max_s - torch.tensor(self.max_s)],
                                 requires_grad=True,
                                 dtype=self.torch_type))\
        + torch.max(torch.tensor([0, max_u - torch.tensor(self.max_u)],
                                 requires_grad=True,
                                 dtype=self.torch_type))

    loss += reg

    loss += 1e-1 * penalty

    self.cur_loss = loss.item()
    self.cur_state_pred = state_pred

    if fit_outlier:
        return loss, t_pred

    self.cur_t_sw_adjust = t_sw_adjust
    return loss
|
|
3589 |
|
|
|
3590 |
def mse(self, x, fit_outlier=False, penalize_gap=True):
    """Numpy version of the fitting loss for parameter vector ``x``.

    ``x`` is split by ``self._variables`` into switch-time increments
    ``t3``, rates ``r4`` (alpha_c, alpha, beta, gamma), ``scale_cc`` and the
    c/u rescale factors.

    Parameters
    ----------
    x : array-like
        Parameter vector (converted with ``np.array``).
    fit_outlier : bool
        If True, evaluate on all cells (``c_all``/``u_all``/``s_all``,
        ``self.conn``) and also return the predicted cell times; otherwise
        use ``c``/``u``/``s`` and ``self.conn_sub``.
    penalize_gap : bool
        Forwarded to the distance/time routine.

    Returns
    -------
    loss
        Mean per-cell minimum distance, plus a penalty for the returned
        ``max_s``/``max_u`` exceeding ``self.max_s``/``self.max_u``, plus
        0.1 times the parameter-range penalty (and, in neural-net mode, the
        penalty returned by ``calculate_dist_and_time_nn``).
    t_pred
        Only when ``fit_outlier`` is True: per-cell predicted times.

    Side effects: sets ``self.cur_loss`` and ``self.cur_state_pred``; when
    ``fit_outlier`` is True (non-neural-net) sets ``self.anchor_t1_list`` /
    ``self.anchor_t2_list``, otherwise sets ``self.cur_t_sw_adjust``.
    """
    x = np.array(x)

    t3, r4, scale_cc, rescale_c, rescale_u = self._variables(x)

    t_sw_array = np.array([t3[0], t3[0]+t3[1], t3[0]+t3[1]+t3[2]])
    if self.rna_only:
        t_sw_array[2] = 20

    # conditions for minimum switch time and rate params
    penalty = 0
    if np.any(t3 < 0.2) or np.any(r4 < 0.005):
        penalty = (np.sum(0.2 - t3[t3 < 0.2]) if self.fit_decoupling
                   else np.sum(0.2 - t3[:2][t3[:2] < 0.2]))
        penalty += np.sum(0.005 - r4[r4 < 0.005]) * 1e2

    # condition for all params; added to (not replacing) the penalty above
    if np.any(x > 500):
        penalty += np.sum(x[x > 500] - 500) * 1e-2

    c_array = self.c_all if fit_outlier else self.c
    u_array = self.u_all if fit_outlier else self.u
    s_array = self.s_all if fit_outlier else self.s
    conn = self.conn if fit_outlier else self.conn_sub

    if self.neural_net:

        res = calculate_dist_and_time_nn(c_array,
                                         u_array,
                                         s_array,
                                         self.max_u_all if fit_outlier
                                         else self.max_u,
                                         self.max_s_all if fit_outlier
                                         else self.max_s,
                                         t_sw_array,
                                         r4[0],
                                         r4[1],
                                         r4[2],
                                         r4[3],
                                         rescale_c,
                                         rescale_u,
                                         self.ode_model_0,
                                         self.ode_model_1,
                                         self.ode_model_2_m1,
                                         self.ode_model_2_m2,
                                         self.device,
                                         scale_cc=scale_cc,
                                         scale_factor=self.scale_factor,
                                         model=self.model,
                                         direction=self.direction,
                                         conn=conn,
                                         k=self.k_dist,
                                         t=self.n_anchors,
                                         rna_only=self.rna_only,
                                         penalize_gap=penalize_gap,
                                         all_cells=fit_outlier)

        if fit_outlier:
            min_dist, t_pred, state_pred, max_u, max_s, nn_penalty = res
        else:
            min_dist, state_pred, max_u, max_s, nn_penalty = res

        penalty += nn_penalty

        # the neural-net path does not compute anchor-time adjustments
        t_sw_adjust = [0, 0, 0]

    else:

        # distances and time assignments
        res = calculate_dist_and_time(c_array,
                                      u_array,
                                      s_array,
                                      t_sw_array,
                                      r4[0],
                                      r4[1],
                                      r4[2],
                                      r4[3],
                                      rescale_c,
                                      rescale_u,
                                      scale_cc=scale_cc,
                                      scale_factor=self.scale_factor,
                                      model=self.model,
                                      direction=self.direction,
                                      conn=conn,
                                      k=self.k_dist,
                                      t=self.n_anchors,
                                      rna_only=self.rna_only,
                                      penalize_gap=penalize_gap,
                                      all_cells=fit_outlier)

        if fit_outlier:
            (min_dist, t_pred, state_pred, _reach_ss, _late_phase, max_u,
             max_s, self.anchor_t1_list, self.anchor_t2_list) = res
        else:
            min_dist, state_pred, max_u, max_s, t_sw_adjust = res

    loss = np.mean(min_dist)

    # avoid exceeding maximum expressions
    reg = np.max([0, max_s - self.max_s]) + np.max([0, max_u - self.max_u])
    loss += reg

    loss += 1e-1 * penalty
    self.cur_loss = loss
    self.cur_state_pred = state_pred

    if fit_outlier:
        return loss, t_pred

    self.cur_t_sw_adjust = t_sw_adjust
    return loss
|
|
3703 |
|
|
|
3704 |
def update(self, x, perform_update=False, initialize=False,
           fit_outlier=False, adjust_time=True, penalize_gap=True,
           plot=True):
    """Accept ``x`` as the current parameters if it improves the loss.

    ``x`` is accepted when ``perform_update`` is True or its loss is below
    the last entry of ``self.loss``. On acceptance, the following are set:

    - the new loss, appended to ``self.loss``;
    - the rates and the scale and rescale factors;
    - the switch times (``t_sw_1..3``, ``t_sw_array``) and ``self.params``;
    - ``self.state``, unless ``initialize`` is True;
    - ``self.t``, when ``fit_outlier`` is True.

    If plotting is enabled, the interactive axes are also redrawn.

    Parameters
    ----------
    x : array-like
        Candidate parameter vector (decoded by ``self._variables``).
    perform_update : bool
        Force acceptance regardless of the loss.
    initialize : bool
        Compute the loss with ``self.mse`` on the fitting subset; the cell
        states are not updated.
    fit_outlier : bool
        Compute the loss with ``self.mse`` on all cells and store the
        predicted cell times in ``self.t``.
    adjust_time : bool
        When not ``fit_outlier``, shift the switch times back by the
        cumulative ``self.cur_t_sw_adjust`` from the latest loss call.
    penalize_gap : bool
        Forwarded to ``self.mse``.
    plot : bool
        Redraw the interactive plot (only when ``self.plot`` is also set).

    Returns
    -------
    bool
        True if the parameters were accepted, else False.
    """
    t3, r4, scale_cc, rescale_c, rescale_u = self._variables(x)
    t_sw_array = np.array([t3[0], t3[0]+t3[1], t3[0]+t3[1]+t3[2]])

    # read results
    if initialize:
        new_loss = self.mse(x, penalize_gap=penalize_gap)
    elif fit_outlier:
        new_loss, t_pred = self.mse(x, fit_outlier=True,
                                    penalize_gap=penalize_gap)
    else:
        # reuse the loss cached by the most recent mse/mse_ten call
        # (presumably made on x by the optimizer)
        new_loss = self.cur_loss
    state_pred = self.cur_state_pred

    # Check the forcing flag first so a forced update does not need a
    # previous loss entry.
    if not (perform_update or new_loss < self.loss[-1]):
        return False

    self.loss.append(new_loss)
    self.alpha_c, self.alpha, self.beta, self.gamma = r4
    self.rates = r4
    self.scale_cc = scale_cc
    self.rescale_c = rescale_c
    self.rescale_u = rescale_u

    # adjust overcrowded anchors, using the shifts cached by the latest
    # non-outlier mse/mse_ten call (the call above sets them when
    # initialize is True)
    if not fit_outlier and adjust_time:
        t_sw_array -= np.cumsum(self.cur_t_sw_adjust)
    # RNA-only models always pin the last switch time to 20, as mse does
    if self.rna_only:
        t_sw_array[2] = 20

    self.t_sw_1, self.t_sw_2, self.t_sw_3 = t_sw_array
    self.t_sw_array = t_sw_array
    self.params = np.array([self.t_sw_1,
                            self.t_sw_2-self.t_sw_1,
                            self.t_sw_3-self.t_sw_2,
                            self.alpha_c,
                            self.alpha,
                            self.beta,
                            self.gamma,
                            self.scale_cc,
                            self.rescale_c,
                            self.rescale_u])
    if not initialize:
        self.state = state_pred
    if fit_outlier:
        self.t = t_pred

    logg.update(f'params updated as: {self.t_sw_array} {self.rates} '
                f'{self.scale_cc} {self.rescale_c} {self.rescale_u}',
                v=2)

    # interactive plot
    if self.plot and plot:
        tau_list = anchor_points(self.t_sw_array, 20, self.n_anchors)
        switch = np.sum(self.t_sw_array < 20)
        typed_tau_list = List()
        for tau in tau_list:
            typed_tau_list.append(tau)
        self.alpha_c, self.alpha, self.beta, self.gamma, \
            self.c0, self.u0, self.s0 = \
            check_params(self.alpha_c, self.alpha, self.beta,
                         self.gamma, c0=self.c0, u0=self.u0,
                         s0=self.s0)
        exp_list, exp_sw_list = generate_exp(typed_tau_list,
                                             self.t_sw_array[:switch],
                                             self.alpha_c,
                                             self.alpha,
                                             self.beta,
                                             self.gamma,
                                             scale_cc=self.scale_cc,
                                             model=self.model,
                                             rna_only=self.rna_only)
        rescale_factor = np.array([self.rescale_c,
                                   self.rescale_u,
                                   1.0])
        exp_list = [e * rescale_factor for e in exp_list]
        exp_sw_list = [e * rescale_factor for e in exp_sw_list]
        c = np.ravel(np.concatenate([exp_list[k][:, 0] for k in
                                     range(switch+1)]))
        u = np.ravel(np.concatenate([exp_list[k][:, 1] for k in
                                     range(switch+1)]))
        s = np.ravel(np.concatenate([exp_list[k][:, 2] for k in
                                     range(switch+1)]))
        c_ = self.c_all if fit_outlier else self.c
        u_ = self.u_all if fit_outlier else self.u
        s_ = self.s_all if fit_outlier else self.s
        self.ax.clear()
        plt.pause(0.1)
        # marker per switch time: first, second, third
        sw_markers = ("om", "Xm", "Dm")
        if self.rna_only:
            self.ax.scatter(s, u, s=self.point_size*1.5, c='black',
                            alpha=0.6, zorder=2)
            for k in range(int(switch)):
                _, u_sw, s_sw = exp_sw_list[k][0]
                self.ax.plot([s_sw], [u_sw], sw_markers[k],
                             markersize=self.point_size, zorder=5)
            if np.max(self.t) == 20:
                self.ax.plot([s[-1]], [u[-1]], "*m",
                             markersize=self.point_size, zorder=5)
            for i in range(4):
                if any(self.state == i):
                    self.ax.scatter(s_[(self.state == i)],
                                    u_[(self.state == i)],
                                    s=self.point_size, c=self.color[i])
            self.ax.set_xlabel('s')
            self.ax.set_ylabel('u')

        else:
            # 3-D axes: the third positional argument is the z data
            self.ax.scatter(s, u, c, s=self.point_size*1.5,
                            c='black', alpha=0.6, zorder=2)
            for k in range(int(switch)):
                c_sw, u_sw, s_sw = exp_sw_list[k][0]
                self.ax.plot([s_sw], [u_sw], [c_sw], sw_markers[k],
                             markersize=self.point_size, zorder=5)
            if np.max(self.t) == 20:
                self.ax.plot([s[-1]], [u[-1]], [c[-1]], "*m",
                             markersize=self.point_size, zorder=5)
            for i in range(4):
                if any(self.state == i):
                    self.ax.scatter(s_[(self.state == i)],
                                    u_[(self.state == i)],
                                    c_[(self.state == i)],
                                    s=self.point_size, c=self.color[i])
            self.ax.set_xlabel('s')
            self.ax.set_ylabel('u')
            self.ax.set_zlabel('c')
        self.fig.canvas.draw()
        plt.pause(0.1)
    return True
|
|
3850 |
|
|
|
3851 |
def save_dyn_plot(self, c, u, s, c_sw, u_sw, s_sw, tau_list,
                  show_all=False):
    """Save static phase-portrait figures of the fitted trajectory.

    PNG files named ``{gene}-{model}-<kind>.png`` are written to
    ``self.plot_path``, which is created if missing:

    - ``us``: always.
    - ``us_colorby_extra``: when ``self.extra_color`` is set.
    - ``cu_colorby_extra``: when ``self.extra_color`` is set and the model
      is not RNA-only.
    - ``cus`` (3-D), ``us_colorby_c`` and ``cu``: for non-RNA-only models.

    Parameters
    ----------
    c, u, s : array-like
        Predicted trajectory values at the anchor points, ordered like the
        concatenated ``tau_list`` times.
    c_sw, u_sw, s_sw : sequence
        Predicted values at the switch times; entry ``k`` is used only when
        at least ``k + 1`` switch times are below 20.
    tau_list : sequence of four arrays
        Per-state anchor times, each relative to its state's start.
    show_all : bool
        If False, only anchors whose time lies strictly inside the range of
        fitted cell times form the trajectory, and earlier anchors are drawn
        faded. If True, all anchors are drawn, and the gap anchors in
        ``self.anchor_t1_list`` / ``self.anchor_t2_list`` are marked.
    """
    if not os.path.exists(self.plot_path):
        os.makedirs(self.plot_path)
        logg.update(f'{self.plot_path} directory created.', v=2)

    switch = int(np.sum(self.t_sw_array < 20))
    scale_back = np.array([self.scale_c, self.scale_u, self.scale_s])
    shift_back = np.array([self.offset_c, self.offset_u, self.offset_s])
    point_size = self.point_size
    # marker per switch time: first, second, third
    sw_markers = ("om", "Xm", "Dm")
    sw_pts = {'c': c_sw, 'u': u_sw, 's': s_sw}

    t_lower = t_upper = None
    pre = None
    if not show_all:
        n_anchors = len(u)
        t_lower = np.min(self.t)
        t_upper = np.max(self.t)
        t_ = np.concatenate((tau_list[0], tau_list[1] + self.t_sw_array[0],
                             tau_list[2] + self.t_sw_array[1],
                             tau_list[3] + self.t_sw_array[2]))
        t_anchor = t_[:n_anchors]
        is_pre = t_anchor <= t_lower
        inside = (t_lower < t_anchor) & (t_anchor < t_upper)
        pre = {'c': c[is_pre], 'u': u[is_pre], 's': s[is_pre]}
        c, u, s = c[inside], u[inside], s[inside]

    traj = {'c': c, 'u': u, 's': s}
    # observed values mapped back to the original (unscaled) space
    obs = {'c': self.c_all * self.scale_c + self.offset_c,
           'u': self.u_all * self.scale_u + self.offset_u,
           's': self.s_all * self.scale_s + self.offset_s}

    def _new_figure(projection=None):
        # White figure with a single (optionally 3-D) axes.
        fig = plt.figure(figsize=self.fig_size)
        fig.patch.set_facecolor('white')
        if projection is None:
            ax = fig.add_subplot(111, facecolor='white')
        else:
            ax = fig.add_subplot(111, projection=projection,
                                 facecolor='white')
        return fig, ax

    def _draw_trajectory(ax, dims):
        # Faded anchors before the first cell time, then the trajectory.
        if pre is not None and len(pre['u']) > 0:
            ax.scatter(*[pre[d] for d in dims], s=point_size/2, c='black',
                       alpha=0.4, zorder=2)
        ax.scatter(*[traj[d] for d in dims], s=point_size*1.5, c='black',
                   alpha=0.6, zorder=2)

    def _draw_cells_by_state(ax, dims):
        # Non-outlier cells coloured by state, outliers in grey.
        for i in range(4):
            if any(self.state == i):
                keep = (self.state == i) & (self.non_outlier)
                ax.scatter(*[obs[d][keep] for d in dims],
                           s=point_size, c=self.color[i])
        outlier = ~self.non_outlier
        ax.scatter(*[obs[d][outlier] for d in dims], s=point_size/2,
                   c='grey')

    def _draw_switches(ax, dims):
        # Switch-time markers that fall in the shown time range, then the
        # end-of-trajectory marker. The first switch only needs
        # t_lower <= t_sw; t_lower/t_upper are only read when not show_all.
        for k in range(min(switch, len(sw_markers))):
            t_sw = self.t_sw_array[k]
            if k == 0:
                visible = show_all or t_lower <= t_sw
            else:
                visible = show_all or (t_lower <= t_sw and t_upper >= t_sw)
            if visible:
                ax.plot(*[[sw_pts[d][k]] for d in dims], sw_markers[k],
                        markersize=point_size, zorder=5)
        if np.max(self.t) == 20 and len(traj['u']) > 0:
            ax.plot(*[[traj[d][-1]] for d in dims], "*m",
                    markersize=point_size, zorder=5)

    def _draw_gap_anchors(ax):
        # Anchor pairs bracketing detected time gaps (show_all only).
        if (self.anchor_t1_list is not None and len(self.anchor_t1_list) > 0
                and show_all):
            for i in range(len(self.anchor_t1_list)):
                exp_t1 = self.anchor_t1_list[i] * scale_back + shift_back
                exp_t2 = self.anchor_t2_list[i] * scale_back + shift_back
                ax.plot([exp_t1[2]], [exp_t1[1]], "|y",
                        markersize=point_size*1.5)
                ax.plot([exp_t2[2]], [exp_t2[1]], "|c",
                        markersize=point_size*1.5)

    def _draw_steady_state(ax):
        # Dotted grey curve from self.steady_state_func in the u-s plane.
        ax.plot(obs['s'],
                self.steady_state_func(self.s_all) * self.scale_u
                + self.offset_u, c='grey', ls=':', lw=point_size/4,
                alpha=0.7)

    def _save(fig, ax, dims, kind):
        # Label axes by dimension name, save the PNG and close the figure.
        ax.set_xlabel(dims[0])
        ax.set_ylabel(dims[1])
        if len(dims) == 3:
            ax.set_zlabel(dims[2])
        ax.set_title(f'{self.gene}-{self.model}')
        plt.tight_layout()
        fig.savefig(f'{self.plot_path}/{self.gene}-{self.model}-{kind}.png',
                    dpi=fig.dpi, facecolor=fig.get_facecolor(),
                    transparent=False, edgecolor='none')
        plt.close(fig)
        plt.pause(0.2)

    su = ('s', 'u')
    uc = ('u', 'c')
    suc = ('s', 'u', 'c')

    # u-s portrait, cells coloured by state
    fig, ax = _new_figure()
    _draw_trajectory(ax, su)
    _draw_cells_by_state(ax, su)
    _draw_switches(ax, su)
    _draw_gap_anchors(ax)
    _draw_steady_state(ax)
    _save(fig, ax, su, 'us')

    if self.extra_color is not None:
        # u-s portrait, cells coloured by the user-supplied colours
        fig, ax = _new_figure()
        _draw_trajectory(ax, su)
        ax.scatter(obs['s'], obs['u'], s=point_size, c=self.extra_color)
        _draw_switches(ax, su)
        _draw_gap_anchors(ax)
        _draw_steady_state(ax)
        _save(fig, ax, su, 'us_colorby_extra')

        if not self.rna_only:
            # c-u portrait, cells coloured by the user-supplied colours
            fig, ax = _new_figure()
            _draw_trajectory(ax, uc)
            ax.scatter(obs['u'], obs['c'], s=point_size, c=self.extra_color)
            _draw_switches(ax, uc)
            _save(fig, ax, uc, 'cu_colorby_extra')

    if not self.rna_only:
        # 3-D c-u-s portrait, cells coloured by state
        fig, ax = _new_figure(projection='3d')
        _draw_trajectory(ax, suc)
        _draw_cells_by_state(ax, suc)
        _draw_switches(ax, suc)
        _save(fig, ax, suc, 'cus')

        # u-s portrait, cells coloured by log1p of chromatin
        fig, ax = _new_figure()
        _draw_trajectory(ax, su)
        ax.scatter(obs['s'], obs['u'], s=point_size, c=np.log1p(self.c_all),
                   cmap='coolwarm')
        _draw_switches(ax, su)
        _draw_steady_state(ax)
        _save(fig, ax, su, 'us_colorby_c')

        # c-u portrait, cells coloured by state
        fig, ax = _new_figure()
        _draw_trajectory(ax, uc)
        _draw_cells_by_state(ax, uc)
        _draw_switches(ax, uc)
        _save(fig, ax, uc, 'cu')
|
|
4138 |
|
|
|
4139 |
def get_loss(self):
    """Return the stored loss values (``self.loss``) for this fit."""
    recorded_loss = self.loss
    return recorded_loss
|
|
4141 |
|
|
|
4142 |
def get_model(self):
    """Return the model identifier stored on this instance (``self.model``)."""
    chosen_model = self.model
    return chosen_model
|
|
4144 |
|
|
|
4145 |
def get_params(self):
    """Return the fitted parameters as a 6-tuple.

    Returns
    -------
    tuple
        ``(t_sw_array, rates, scale_cc, rescale_c, rescale_u,
        realign_ratio)``, read directly from the instance attributes.
    """
    fitted = (self.t_sw_array,
              self.rates,
              self.scale_cc,
              self.rescale_c,
              self.rescale_u,
              self.realign_ratio)
    return fitted
|
|
4148 |
|
|
|
4149 |
def is_partial(self):
    """Return the stored partial-trajectory flag (``self.partial``)."""
    partial_flag = self.partial
    return partial_flag
|
|
4151 |
|
|
|
4152 |
def get_direction(self):
    """Return the stored trajectory direction (``self.direction``)."""
    trajectory_direction = self.direction
    return trajectory_direction
|
|
4154 |
|
|
|
4155 |
def realign_time_and_velocity(self, c, u, s, anchor_time):
    """Realign fitted time to the range (0, 20) and rescale dependents.

    Steps, in order:

    1. Count the anchors lying before the first and last observed cell
       time (``anchor_min_idx`` / ``anchor_max_idx``, with a 1e-5
       tolerance).
    2. Store the entries of ``c``, ``u``, ``s`` at ``anchor_min_idx``
       as ``c0``, ``u0``, ``s0``.
    3. Compute ``realign_ratio = 20 / (max(t) - min(t))`` and log the
       parameters as they were before rescaling.
    4. Divide the rates (in place) by the ratio. Shift and stretch the
       switch times and ``self.t`` onto (0, 20). Refresh the unpacked
       attributes and the ``params`` slots.
    5. Divide ``velocity`` and ``anchor_velo`` (in place) by the ratio.
       Then clip each column from below: per cell by
       ``-observed * scale``, and for anchors by the largest such value.

    Parameters
    ----------
    c, u, s:
        Indexable expression values. Only the element at
        ``anchor_min_idx`` is read (presumably anchor expressions --
        confirm against the caller).
    anchor_time:
        Array of anchor time points, compared elementwise against the
        range of ``self.t``.
    """
    # Both extremes are taken from self.t *before* it is shifted below.
    t_first = np.min(self.t)
    t_last = np.max(self.t)

    self.anchor_min_idx = np.sum(anchor_time < (t_first - 1e-5))
    self.anchor_max_idx = np.sum(anchor_time < (t_last - 1e-5))

    start = self.anchor_min_idx
    self.c0 = c[start]
    self.u0 = u[start]
    self.s0 = s[start]

    self.realign_ratio = 20 / (t_last - t_first)

    # Logged before any rescaling, so the original fitted values appear.
    logg.update(f'fitted params:\nswitch time array = {self.t_sw_array},\n'
                f'rates = {self.rates},\ncc scale = {self.scale_cc},\n'
                f'c rescale factor = {self.rescale_c},\n'
                f'u rescale factor = {self.rescale_u}',
                v=1)
    logg.update(f'aligning to range (0,20) by {self.realign_ratio}..',
                v=1)

    # Rates: rescale in place, then refresh the unpacked copies.
    self.rates /= self.realign_ratio
    self.alpha_c, self.alpha, self.beta, self.gamma = self.rates
    self.params[3:7] = self.rates

    # Switch times: same shift/stretch as the cell times.
    self.t_sw_array = (self.t_sw_array - t_first) * self.realign_ratio
    self.t_sw_1, self.t_sw_2, self.t_sw_3 = self.t_sw_array
    self.params[:3] = np.array([self.t_sw_1,
                                self.t_sw_2 - self.t_sw_1,
                                self.t_sw_3 - self.t_sw_2])

    # Cell times: in-place shift, then stretch so the maximum becomes 20.
    self.t -= t_first
    self.t = self.t * 20 / np.max(self.t)

    # (observed values, scale) for velocity columns 0 = c, 1 = u, 2 = s.
    per_column = ((self.c_all, self.scale_c),
                  (self.u_all, self.scale_u),
                  (self.s_all, self.scale_s))

    self.velocity /= self.realign_ratio
    for col, (observed, scale) in enumerate(per_column):
        # Per-cell lower bound: velocity may not drop below -observed*scale.
        self.velocity[:, col] = np.clip(self.velocity[:, col],
                                        -observed * scale, None)

    self.anchor_velo /= self.realign_ratio
    for col, (observed, scale) in enumerate(per_column):
        # Anchors share a single bound: the largest scaled observation.
        self.anchor_velo[:, col] = np.clip(self.anchor_velo[:, col],
                                           -np.max(observed * scale),
                                           None)
|
|
4197 |
|
|
|
4198 |
def get_initial_exp(self):
    """Return ``[c0, u0, s0]`` as a NumPy array."""
    initial_values = (self.c0, self.u0, self.s0)
    return np.array(initial_values)
|
|
4200 |
|
|
|
4201 |
def get_time_assignment(self):
    """Return the per-cell fitted time.

    For a low-quality fit, a zero vector with one entry per element of
    ``self.u_all`` is returned instead of ``self.t``.
    """
    return np.zeros(len(self.u_all)) if self.low_quality else self.t
|
|
4205 |
|
|
|
4206 |
def get_state_assignment(self):
    """Return the per-cell state assignment.

    For a low-quality fit, a zero vector with one entry per element of
    ``self.u_all`` is returned instead of ``self.state``.
    """
    return np.zeros(len(self.u_all)) if self.low_quality else self.state
|
|
4210 |
|
|
|
4211 |
def get_velocity(self):
    """Return the per-cell velocity matrix (one column each for c, u, s).

    For a low-quality fit, a ``(len(self.u_all), 3)`` zero matrix is
    returned instead of ``self.velocity``.
    """
    if not self.low_quality:
        return self.velocity
    return np.zeros((len(self.u_all), 3))
|
|
4215 |
|
|
|
4216 |
def get_likelihood(self):
    """Return ``(likelihood, l_c, ssd_c, var_c)`` from the instance."""
    stats = (self.likelihood, self.l_c, self.ssd_c, self.var_c)
    return stats
|
|
4218 |
|
|
|
4219 |
def get_anchors(self):
    """Return anchor expressions, velocities and index bounds as a 7-tuple.

    Returns
    -------
    tuple
        ``(anchor_exp, anchor_exp_sw, anchor_velo, anchor_min_idx,
        anchor_max_idx, anchor_velo_min_idx, anchor_velo_max_idx)``.
        For a low-quality fit, the first three entries are independent
        ``(1, 3)`` zero arrays and the four indices are ``0``.
    """
    if not self.low_quality:
        return (self.anchor_exp, self.anchor_exp_sw, self.anchor_velo,
                self.anchor_min_idx, self.anchor_max_idx,
                self.anchor_velo_min_idx, self.anchor_velo_max_idx)
    # Three separate arrays, so callers mutating one do not affect others.
    placeholders = tuple(np.zeros((1, 3)) for _ in range(3))
    return placeholders + (0, 0, 0, 0)
|
|
4226 |
|
|
|
4227 |
|
|
|
4228 |
def regress_func(c, u, s, m, mi, im, dev, nn, ad, lr, b1, b2, bs, gpdist,
                 embed, conn, pl, sp, pdir, fa, gene, pa, di, ro, fit, fd,
                 extra, ru, alpha, beta, gamma, t_, verbosity, log_folder,
                 log_filename):
    """Fit one gene with one model and collect every fitted output.

    The short argument names map onto ``ChromatinDynamical`` keywords:
    ``m`` model, ``mi`` max_iter, ``im`` init_mode, ``dev`` device,
    ``nn`` neural_net, ``ad`` adam, ``lr``/``b1``/``b2`` Adam learning
    rate and betas, ``bs`` batch_size, ``embed`` embed_coord, ``conn``
    connectivities, ``pl`` plot, ``sp`` save_plot, ``pdir`` plot_dir,
    ``fa`` fit_args, ``pa`` partial, ``di`` direction, ``ro`` rna_only,
    ``fd`` fit_decoupling, ``extra`` extra_color, ``ru`` rescale_u.
    ``gpdist`` is a cell-by-cell distance matrix (or ``None``) used only
    to derive ``local_std``. ``verbosity``, ``log_folder`` and
    ``log_filename`` are copied into ``settings`` together with ``gene``.

    Returns
    -------
    tuple
        ``(final_loss, model, direction, parameters, initial_exp, time,
        state, velocity, likelihood, anchors)``. Low-quality genes (a
        90th percentile of 0 in any checked modality) short-circuit to a
        placeholder tuple of the same shape with ``np.inf`` loss.
    """
    settings.VERBOSITY = verbosity
    settings.LOG_FOLDER = log_folder
    settings.LOG_FILENAME = log_filename
    settings.GENE = gene

    if m is not None:
        logg.update('#########################################################'
                    '######################################', v=1)
        logg.update(f'testing model {m}', v=1)

    # All three percentiles are computed; chromatin is only checked
    # when the fit is not RNA-only.
    c_90, u_90, s_90 = (np.percentile(x, 90) for x in (c, u, s))
    checked = (u_90, s_90) if ro else (c_90, u_90, s_90)
    if any(q == 0 for q in checked):
        logg.update(f'low quality gene {gene}, skipping', v=1)
        n_obs = len(u)
        return (np.inf, np.nan, '',
                (np.zeros(3), np.zeros(4), 0, 0, 0, 0),
                np.zeros(3), np.zeros(n_obs), np.zeros(n_obs),
                np.zeros((n_obs, 3)), (-1.0, 0, 0, 0),
                (np.zeros((1, 3)), np.zeros((1, 3)), np.zeros((1, 3)),
                 0, 0, 0, 0))

    local_std = None
    if gpdist is not None:
        # Spread of pairwise distances among cells with appreciable
        # spliced expression (at most 3000, sampled reproducibly).
        expressing = np.where(s > 0.1 * np.percentile(s, 99))[0]
        if len(expressing) > 3000:
            sampler = np.random.default_rng(2021)
            expressing = sampler.choice(expressing, 3000, replace=False)
        block = gpdist[np.ix_(expressing, expressing)]
        upper = np.ravel(block[np.triu_indices_from(block, k=1)])
        local_std = np.std(upper.reshape(-1, 1))

    cdc = ChromatinDynamical(c,
                             u,
                             s,
                             model=m,
                             max_iter=mi,
                             init_mode=im,
                             device=dev,
                             neural_net=nn,
                             adam=ad,
                             adam_lr=lr,
                             adam_beta1=b1,
                             adam_beta2=b2,
                             batch_size=bs,
                             local_std=local_std,
                             embed_coord=embed,
                             connectivities=conn,
                             plot=pl,
                             save_plot=sp,
                             plot_dir=pdir,
                             fit_args=fa,
                             gene=gene,
                             partial=pa,
                             direction=di,
                             rna_only=ro,
                             fit_decoupling=fd,
                             extra_color=extra,
                             rescale_u=ru,
                             alpha=alpha,
                             beta=beta,
                             gamma=gamma,
                             t_=t_)
    if fit:
        fit_losses = cdc.fit()
        if fit_losses[-1] == np.inf:
            logg.update(f'low quality gene {gene}, skipping..', v=1)

    return (cdc.get_loss()[-1], cdc.get_model(), cdc.get_direction(),
            cdc.get_params(), cdc.get_initial_exp(),
            cdc.get_time_assignment(), cdc.get_state_assignment(),
            cdc.get_velocity(), cdc.get_likelihood(), cdc.get_anchors())
|
|
4316 |
|
|
|
4317 |
|
|
|
4318 |
def multimodel_helper(c, u, s,
                      model_to_run,
                      max_iter,
                      init_mode,
                      device,
                      neural_net,
                      adam,
                      adam_lr,
                      adam_beta1,
                      adam_beta2,
                      batch_size,
                      global_pdist,
                      embed_coord,
                      conn,
                      plot,
                      save_plot,
                      plot_dir,
                      fit_args,
                      gene,
                      partial,
                      direction,
                      rna_only,
                      fit,
                      fit_decoupling,
                      extra_color,
                      rescale_u,
                      alpha,
                      beta,
                      gamma,
                      t_,
                      verbosity, log_folder, log_filename):
    """Fit every candidate model for one gene and keep the best fit.

    Each model in ``model_to_run`` is fitted with ``regress_func``. All
    arguments are forwarded unchanged, including the logging settings
    ``verbosity``, ``log_folder`` and ``log_filename``.

    Returns
    -------
    tuple
        ``(loss, model, direction, parameters, initial_exp, time, state,
        velocity, likelihood, anchors)``. ``loss`` is the list of final
        losses of every candidate, in ``model_to_run`` order. Every other
        entry, including the inferred direction, comes from the candidate
        with the lowest loss. ``model`` is ``np.nan`` for RNA-only fits.

    Raises
    ------
    ValueError
        If ``model_to_run`` is empty.
    """
    if len(model_to_run) == 0:
        raise ValueError('model_to_run must contain at least one model')

    loss = []
    candidates = []
    for model in model_to_run:
        # regress_func requires the three logging arguments as well;
        # omitting them raises a TypeError for a missing argument.
        result = regress_func(c, u, s, model, max_iter, init_mode, device,
                              neural_net, adam, adam_lr, adam_beta1,
                              adam_beta2, batch_size, global_pdist,
                              embed_coord, conn, plot, save_plot, plot_dir,
                              fit_args, gene, partial, direction, rna_only,
                              fit, fit_decoupling, extra_color, rescale_u,
                              alpha, beta, gamma, t_, verbosity, log_folder,
                              log_filename)
        loss.append(result[0])
        candidates.append(result)

    best_model = np.argmin(loss)
    # Take every output, including the inferred direction, from the
    # winning candidate, not from whichever model happened to run last.
    (_, _, direction_, parameters, initial_exp, time, state, velocity,
     likelihood, anchors) = candidates[best_model]
    model = np.nan if rna_only else model_to_run[best_model]
    return loss, model, direction_, parameters, initial_exp, time, state, \
        velocity, likelihood, anchors
|
|
4382 |
|
|
|
4383 |
|
|
|
4384 |
def recover_dynamics_chrom(adata_rna, |
|
|
4385 |
adata_atac=None, |
|
|
4386 |
gene_list=None, |
|
|
4387 |
max_iter=5, |
|
|
4388 |
init_mode='invert', |
|
|
4389 |
device="cpu", |
|
|
4390 |
neural_net=False, |
|
|
4391 |
adam=False, |
|
|
4392 |
adam_lr=None, |
|
|
4393 |
adam_beta1=None, |
|
|
4394 |
adam_beta2=None, |
|
|
4395 |
batch_size=None, |
|
|
4396 |
model_to_run=None, |
|
|
4397 |
plot=False, |
|
|
4398 |
parallel=True, |
|
|
4399 |
n_jobs=None, |
|
|
4400 |
save_plot=False, |
|
|
4401 |
plot_dir=None, |
|
|
4402 |
rna_only=False, |
|
|
4403 |
fit=True, |
|
|
4404 |
fit_decoupling=True, |
|
|
4405 |
extra_color_key=None, |
|
|
4406 |
embedding='X_umap', |
|
|
4407 |
n_anchors=500, |
|
|
4408 |
k_dist=1, |
|
|
4409 |
thresh_multiplier=1.0, |
|
|
4410 |
weight_c=0.6, |
|
|
4411 |
outlier=99.8, |
|
|
4412 |
n_pcs=30, |
|
|
4413 |
n_neighbors=30, |
|
|
4414 |
fig_size=(8, 6), |
|
|
4415 |
point_size=7, |
|
|
4416 |
partial=None, |
|
|
4417 |
direction=None, |
|
|
4418 |
rescale_u=None, |
|
|
4419 |
alpha=None, |
|
|
4420 |
beta=None, |
|
|
4421 |
gamma=None, |
|
|
4422 |
t_sw=None |
|
|
4423 |
): |
|
|
4424 |
|
|
|
4425 |
"""Multi-omic dynamics recovery. |
|
|
4426 |
|
|
|
4427 |
This function optimizes the joint chromatin and RNA model parameters in |
|
|
4428 |
ODE solutions. |
|
|
4429 |
|
|
|
4430 |
Parameters |
|
|
4431 |
---------- |
|
|
4432 |
adata_rna: :class:`~anndata.AnnData` |
|
|
4433 |
RNA anndata object. Required fields: `Mu`, `Ms`, and `connectivities`. |
|
|
4434 |
adata_atac: :class:`~anndata.AnnData` (default: `None`) |
|
|
4435 |
ATAC anndata object. Required fields: `Mc`. |
|
|
4436 |
gene_list: `str`, list of `str` (default: highly variable genes) |
|
|
4437 |
Genes to use for model fitting. |
|
|
4438 |
max_iter: `int` (default: `5`) |
|
|
4439 |
Iterations to run for parameter optimization. |
|
|
4440 |
init_mode: `str` (default: `'invert'`) |
|
|
4441 |
Initialization method for switch times. |
|
|
4442 |
`'invert'`: initial RNA switch time will be computed with scVelo time |
|
|
4443 |
inversion method. |
|
|
4444 |
`'grid'`: grid search the best set of switch times. |
|
|
4445 |
`'simple'`: simply initialize switch times to be 5, 10, and 15. |
|
|
4446 |
device: `str` (default: `'cpu'`) |
|
|
4447 |
The CUDA device that pytorch tensor calculations will be run on. Only |
|
|
4448 |
to be used with Adam or Neural Network mode. |
|
|
4449 |
neural_net: `bool` (default: `False`) |
|
|
4450 |
Whether to run time predictions with a neural network or not. Shortens |
|
|
4451 |
runtime at the expense of accuracy. If False, uses the usual method of |
|
|
4452 |
assigning each data point to an anchor time point as outlined in the |
|
|
4453 |
Multivelo paper. |
|
|
4454 |
adam: `bool` (default: `False`) |
|
|
4455 |
Whether MSE minimization is handled by the Adam algorithm or not. When |
|
|
4456 |
set to the default of False, function uses Nelder-Mead instead. |
|
|
4457 |
adam_lr: `float` (default: `None`) |
|
|
4458 |
The learning rate to use the Adam algorithm. If adam is False, this |
|
|
4459 |
value is ignored. |
|
|
4460 |
adam_beta1: `float` (default: `None`) |
|
|
4461 |
The beta1 parameter for the Adam algorithm. If adam is False, this |
|
|
4462 |
value is ignored. |
|
|
4463 |
adam_beta2: `float` (default: `None`) |
|
|
4464 |
The beta2 parameter for the Adam algorithm. If adam is False, this |
|
|
4465 |
value is ignored. |
|
|
4466 |
batch_size: `int` (default: `None`) |
|
|
4467 |
Speeds up performance using minibatch training. Specifies number of |
|
|
4468 |
cells to use per run of MSE when running the Adam algorithm. Ignored |
|
|
4469 |
if Adam is set to False. |
|
|
4470 |
model_to_run: `int` or list of `int` (default: `None`) |
|
|
4471 |
User specified models for each gene. Possible values are 0, 1, or 2. If |
|
|
4472 |
`None`, the model |
|
|
4473 |
for each gene will be inferred based on expression patterns. If more |
|
|
4474 |
than one value is given, |
|
|
4475 |
the best model will be decided based on loss of fit. |
|
|
4476 |
plot: `bool` or `None` (default: `False`) |
|
|
4477 |
Whether to interactively plot the 3D gene portraits. Ignored if |
|
|
4478 |
parallel is True. |
|
|
4479 |
parallel: `bool` (default: `True`) |
|
|
4480 |
Whether to fit genes in a parallel fashion (recommended). |
|
|
4481 |
n_jobs: `int` (default: available threads) |
|
|
4482 |
Number of parallel jobs. |
|
|
4483 |
save_plot: `bool` (default: `False`) |
|
|
4484 |
Whether to save the fitted gene portrait figures as files. This will |
|
|
4485 |
take some disk space. |
|
|
4486 |
plot_dir: `str` (default: `plots` for multiome and `rna_plots` for |
|
|
4487 |
RNA-only) |
|
|
4488 |
Directory to save the plots. |
|
|
4489 |
rna_only: `bool` (default: `False`) |
|
|
4490 |
Whether to only use RNA for fitting (RNA velocity). |
|
|
4491 |
fit: `bool` (default: `True`) |
|
|
4492 |
Whether to fit the models. If False, only pre-determination and |
|
|
4493 |
initialization will be run. |
|
|
4494 |
fit_decoupling: `bool` (default: `True`) |
|
|
4495 |
Whether to fit decoupling phase (Model 1 vs Model 2 distinction). |
|
|
4496 |
n_anchors: `int` (default: 500) |
|
|
4497 |
Number of anchor time-points to generate as a representation of the |
|
|
4498 |
trajectory. |
|
|
4499 |
k_dist: `int` (default: 1) |
|
|
4500 |
Number of anchors to use to determine a cell's gene time. If more than |
|
|
4501 |
1, time will be averaged. |
|
|
4502 |
thresh_multiplier: `float` (default: 1.0) |
|
|
4503 |
Multiplier for the heuristic threshold of partial versus complete |
|
|
4504 |
trajectory pre-determination. |
|
|
4505 |
weight_c: `float` (default: 0.6) |
|
|
4506 |
Weighting of scaled chromatin distances when performing 3D residual |
|
|
4507 |
calculation. |
|
|
4508 |
outlier: `float` (default: 99.8) |
|
|
4509 |
The percentile to mark as outlier that will be excluded when fitting |
|
|
4510 |
the model. |
|
|
4511 |
n_pcs: `int` (default: 30) |
|
|
4512 |
Number of principal components to compute distance smoothing neighbors. |
|
|
4513 |
This can be different from the one used for expression smoothing. |
|
|
4514 |
n_neighbors: `int` (default: 30) |
|
|
4515 |
Number of nearest neighbors for distance smoothing. |
|
|
4516 |
This can be different from the one used for expression smoothing. |
|
|
4517 |
fig_size: `tuple` (default: (8,6)) |
|
|
4518 |
Size of each figure when saved. |
|
|
4519 |
point_size: `float` (default: 7) |
|
|
4520 |
Marker point size for plotting. |
|
|
4521 |
extra_color_key: `str` (default: `None`) |
|
|
4522 |
Extra color key used for plotting. Common choices are `leiden`, |
|
|
4523 |
`celltype`, etc. |
|
|
4524 |
The colors for each category must be present in one of anndatas, which |
|
|
4525 |
can be pre-computed |
|
|
4526 |
with `scanpy.pl.scatter` function. |
|
|
4527 |
embedding: `str` (default: `X_umap`) |
|
|
4528 |
2D coordinates of the low-dimensional embedding of cells. |
|
|
4529 |
partial: `bool` or list of `bool` (default: `None`) |
|
|
4530 |
User specified trajectory completeness for each gene. |
|
|
4531 |
direction: `str` or list of `str` (default: `None`) |
|
|
4532 |
User specified trajectory directionality for each gene. |
|
|
4533 |
rescale_u: `float` or list of `float` (default: `None`) |
|
|
4534 |
Known scaling factors for unspliced. Can be computed from scVelo |
|
|
4535 |
`fit_scaling` values |
|
|
4536 |
as `rescale_u = fit_scaling / std(u) * std(s)`. |
|
|
4537 |
alpha: `float` or list of `float` (default: `None`) |
|
|
4538 |
Known transcription rates. Can be computed from scVelo `fit_alpha` |
|
|
4539 |
values |
|
|
4540 |
as `alpha = fit_alpha * fit_alignment_scaling`. |
|
|
4541 |
beta: `float` or list of `float` (default: `None`) |
|
|
4542 |
Known splicing rates. Can be computed from scVelo `fit_beta` values |
|
|
4543 |
as `beta = fit_beta * fit_alignment_scaling`. |
|
|
4544 |
gamma: `float` or list of `float` (default: `None`) |
|
|
4545 |
Known degradation rates. Can be computed from scVelo `fit_gamma` values |
|
|
4546 |
as `gamma = fit_gamma * fit_alignment_scaling`. |
|
|
4547 |
t_sw: `float` or list of `float` (default: `None`) |
|
|
4548 |
Known RNA switch time. Can be computed from scVelo `fit_t_` values |
|
|
4549 |
as `t_sw = fit_t_ / fit_alignment_scaling`. |
|
|
4550 |
|
|
|
4551 |
Returns |
|
|
4552 |
------- |
|
|
4553 |
fit_alpha_c, fit_alpha, fit_beta, fit_gamma: `.var` |
|
|
4554 |
inferred chromatin opening, transcription, splicing, and degradation |
|
|
4555 |
(nuclear export) rates |
|
|
4556 |
fit_t_sw1, fit_t_sw2, fit_t_sw3: `.var` |
|
|
4557 |
inferred switching time points |
|
|
4558 |
fit_rescale_c, fit_rescale_u: `.var` |
|
|
4559 |
inferred scaling factor for chromatin and unspliced counts |
|
|
4560 |
fit_scale_cc: `.var` |
|
|
4561 |
inferred scaling value for chromatin closing rate compared to opening |
|
|
4562 |
rate |
|
|
4563 |
fit_alignment_scaling: `.var` |
|
|
4564 |
ratio used to realign observed time range to 0-20 |
|
|
4565 |
fit_c0, fit_u0, fit_s0: `.var` |
|
|
4566 |
initial expression values at earliest observed time |
|
|
4567 |
fit_model: `.var` |
|
|
4568 |
inferred gene model |
|
|
4569 |
fit_direction: `.var` |
|
|
4570 |
inferred gene direction |
|
|
4571 |
fit_loss: `.var` |
|
|
4572 |
loss of model fit |
|
|
4573 |
fit_likelihood: `.var` |
|
|
4574 |
likelihood of model fit |
|
|
4575 |
fit_likelihood_c: `.var` |
|
|
4576 |
likelihood of chromatin fit |
|
|
4577 |
fit_anchor_c, fit_anchor_u, fit_anchor_s: `.varm` |
|
|
4578 |
anchor expressions |
|
|
4579 |
fit_anchor_c_sw, fit_anchor_u_sw, fit_anchor_s_sw: `.varm` |
|
|
4580 |
switch time-point expressions |
|
|
4581 |
fit_anchor_c_velo, fit_anchor_u_velo, fit_anchor_s_velo: `.varm` |
|
|
4582 |
velocities of anchors |
|
|
4583 |
fit_anchor_min_idx: `.var` |
|
|
4584 |
first anchor mapped to observations |
|
|
4585 |
fit_anchor_max_idx: `.var` |
|
|
4586 |
last anchor mapped to observations |
|
|
4587 |
fit_anchor_velo_min_idx: `.var` |
|
|
4588 |
first velocity anchor mapped to observations |
|
|
4589 |
fit_anchor_velo_max_idx: `.var` |
|
|
4590 |
last velocity anchor mapped to observations |
|
|
4591 |
fit_t: `.layers` |
|
|
4592 |
inferred gene time |
|
|
4593 |
fit_state: `.layers` |
|
|
4594 |
inferred state assignments |
|
|
4595 |
velo_s, velo_u, velo_chrom: `.layers` |
|
|
4596 |
velocities in spliced, unspliced, and chromatin space |
|
|
4597 |
velo_s_genes, velo_u_genes, velo_chrom_genes: `.var` |
|
|
4598 |
velocity genes |
|
|
4599 |
velo_s_params, velo_u_params, velo_chrom_params: `.var` |
|
|
4600 |
fitting arguments used |
|
|
4601 |
ATAC: `.layers` |
|
|
4602 |
KNN smoothed chromatin accessibilities copied from adata_atac |
|
|
4603 |
""" |
|
|
4604 |
|
|
|
4605 |
fit_args = {} |
|
|
4606 |
fit_args['max_iter'] = max_iter |
|
|
4607 |
fit_args['init_mode'] = init_mode |
|
|
4608 |
fit_args['fit_decoupling'] = fit_decoupling |
|
|
4609 |
n_anchors = np.clip(int(n_anchors), 201, 2000) |
|
|
4610 |
fit_args['t'] = n_anchors |
|
|
4611 |
fit_args['k'] = k_dist |
|
|
4612 |
fit_args['thresh_multiplier'] = thresh_multiplier |
|
|
4613 |
fit_args['weight_c'] = weight_c |
|
|
4614 |
fit_args['outlier'] = outlier |
|
|
4615 |
fit_args['n_pcs'] = n_pcs |
|
|
4616 |
fit_args['n_neighbors'] = n_neighbors |
|
|
4617 |
fit_args['fig_size'] = list(fig_size) |
|
|
4618 |
fit_args['point_size'] = point_size |
|
|
4619 |
|
|
|
4620 |
if adam and neural_net: |
|
|
4621 |
raise Exception("ADAM and Neural Net mode can not be run concurently." |
|
|
4622 |
" Please choose one to run on.") |
|
|
4623 |
|
|
|
4624 |
if not adam and not neural_net and not device == "cpu": |
|
|
4625 |
raise Exception("Multivelo only uses non-CPU devices for Adam or" |
|
|
4626 |
" Neural Network mode. Please use one of those or" |
|
|
4627 |
"set the device to \"cpu\"") |
|
|
4628 |
|
|
|
4629 |
if adam and not device[0:5] == "cuda:": |
|
|
4630 |
raise Exception("ADAM and Neural Net mode are only possible on a cuda " |
|
|
4631 |
"device. Please try again.") |
|
|
4632 |
if not adam and batch_size is not None: |
|
|
4633 |
raise Exception("Batch training is for ADAM only, please set " |
|
|
4634 |
"batch_size to None") |
|
|
4635 |
|
|
|
4636 |
if adam: |
|
|
4637 |
from cuml.neighbors import NearestNeighbors |
|
|
4638 |
|
|
|
4639 |
all_genes = adata_rna.var_names |
|
|
4640 |
if adata_atac is None: |
|
|
4641 |
import anndata as ad |
|
|
4642 |
rna_only = True |
|
|
4643 |
adata_atac = ad.AnnData(X=np.ones(adata_rna.shape), obs=adata_rna.obs, |
|
|
4644 |
var=adata_rna.var) |
|
|
4645 |
adata_atac.layers['Mc'] = np.ones(adata_rna.shape) |
|
|
4646 |
if adata_rna.shape != adata_atac.shape: |
|
|
4647 |
raise ValueError('Shape of RNA and ATAC adata objects do not match: ' |
|
|
4648 |
f'{adata_rna.shape} {adata_atac.shape}') |
|
|
4649 |
if not np.all(adata_rna.obs_names == adata_atac.obs_names): |
|
|
4650 |
raise ValueError('obs_names of RNA and ATAC adata objects do not ' |
|
|
4651 |
'match, please check if they are consistent') |
|
|
4652 |
if not np.all(all_genes == adata_atac.var_names): |
|
|
4653 |
raise ValueError('var_names of RNA and ATAC adata objects do not ' |
|
|
4654 |
'match, please check if they are consistent') |
|
|
4655 |
if 'connectivities' not in adata_rna.obsp.keys(): |
|
|
4656 |
raise ValueError('Missing connectivities entry in RNA adata object') |
|
|
4657 |
if extra_color_key is None: |
|
|
4658 |
extra_color = None |
|
|
4659 |
elif (isinstance(extra_color_key, str) and extra_color_key in adata_rna.obs |
|
|
4660 |
and adata_rna.obs[extra_color_key].dtype.name == 'category'): |
|
|
4661 |
ngroups = len(adata_rna.obs[extra_color_key].cat.categories) |
|
|
4662 |
extra_color = adata_rna.obs[extra_color_key].cat.rename_categories( |
|
|
4663 |
adata_rna.uns[extra_color_key+'_colors'][:ngroups]).to_numpy() |
|
|
4664 |
elif (isinstance(extra_color_key, str) and extra_color_key in |
|
|
4665 |
adata_atac.obs and |
|
|
4666 |
adata_rna.obs[extra_color_key].dtype.name == 'category'): |
|
|
4667 |
ngroups = len(adata_atac.obs[extra_color_key].cat.categories) |
|
|
4668 |
extra_color = adata_atac.obs[extra_color_key].cat.rename_categories( |
|
|
4669 |
adata_atac.uns[extra_color_key+'_colors'][:ngroups]).to_numpy() |
|
|
4670 |
else: |
|
|
4671 |
raise ValueError('Currently, extra_color_key must be a single string ' |
|
|
4672 |
'of categories and available in adata obs, and its ' |
|
|
4673 |
'colors can be found in adata uns') |
|
|
4674 |
if ('connectivities' not in adata_rna.obsp.keys() or |
|
|
4675 |
(adata_rna.obsp['connectivities'] > 0).sum(1).min() |
|
|
4676 |
> (n_neighbors-1)): |
|
|
4677 |
neighbors = Neighbors(adata_rna) |
|
|
4678 |
neighbors.compute_neighbors(n_neighbors=n_neighbors, knn=True, |
|
|
4679 |
n_pcs=n_pcs) |
|
|
4680 |
rna_conn = neighbors.connectivities |
|
|
4681 |
else: |
|
|
4682 |
rna_conn = adata_rna.obsp['connectivities'].copy() |
|
|
4683 |
rna_conn.setdiag(1) |
|
|
4684 |
rna_conn = rna_conn.multiply(1.0 / rna_conn.sum(1)).tocsr() |
|
|
4685 |
if not rna_only: |
|
|
4686 |
if 'connectivities' not in adata_atac.obsp.keys(): |
|
|
4687 |
logg.update('Missing connectivities in ATAC adata object, using ' |
|
|
4688 |
'RNA connectivities instead', v=1) |
|
|
4689 |
atac_conn = rna_conn |
|
|
4690 |
else: |
|
|
4691 |
atac_conn = adata_atac.obsp['connectivities'].copy() |
|
|
4692 |
atac_conn.setdiag(1) |
|
|
4693 |
atac_conn = atac_conn.multiply(1.0 / atac_conn.sum(1)).tocsr() |
|
|
4694 |
if gene_list is None: |
|
|
4695 |
if 'highly_variable' in adata_rna.var: |
|
|
4696 |
gene_list = adata_rna.var_names[adata_rna.var['highly_variable']]\ |
|
|
4697 |
.values |
|
|
4698 |
else: |
|
|
4699 |
gene_list = adata_rna.var_names.values[ |
|
|
4700 |
(~np.isnan(np.asarray(adata_rna.layers['Mu'].sum(0)) |
|
|
4701 |
.reshape(-1) |
|
|
4702 |
if sparse.issparse(adata_rna.layers['Mu']) |
|
|
4703 |
else np.sum(adata_rna.layers['Mu'], axis=0))) |
|
|
4704 |
& (~np.isnan(np.asarray(adata_rna.layers['Ms'].sum(0)) |
|
|
4705 |
.reshape(-1) |
|
|
4706 |
if sparse.issparse(adata_rna.layers['Ms']) |
|
|
4707 |
else np.sum(adata_rna.layers['Ms'], axis=0))) |
|
|
4708 |
& (~np.isnan(np.asarray(adata_atac.layers['Mc'].sum(0)) |
|
|
4709 |
.reshape(-1) |
|
|
4710 |
if sparse.issparse(adata_atac.layers['Mc']) |
|
|
4711 |
else np.sum(adata_atac.layers['Mc'], axis=0)))] |
|
|
4712 |
elif isinstance(gene_list, (list, np.ndarray, pd.Index, pd.Series)): |
|
|
4713 |
gene_list = np.array([x for x in gene_list if x in all_genes]) |
|
|
4714 |
elif isinstance(gene_list, str): |
|
|
4715 |
gene_list = np.array([gene_list]) if gene_list in all_genes else [] |
|
|
4716 |
else: |
|
|
4717 |
raise ValueError('Invalid gene list, must be one of (str, np.ndarray,' |
|
|
4718 |
'pd.Index, pd.Series)') |
|
|
4719 |
gn = len(gene_list) |
|
|
4720 |
if gn == 0: |
|
|
4721 |
raise ValueError('None of the genes specified are in the adata object') |
|
|
4722 |
logg.update(f'{gn} genes will be fitted', v=1) |
|
|
4723 |
|
|
|
4724 |
models = np.zeros(gn) |
|
|
4725 |
t_sws = np.zeros((gn, 3)) |
|
|
4726 |
rates = np.zeros((gn, 4)) |
|
|
4727 |
scale_ccs = np.zeros(gn) |
|
|
4728 |
rescale_cs = np.zeros(gn) |
|
|
4729 |
rescale_us = np.zeros(gn) |
|
|
4730 |
realign_ratios = np.zeros(gn) |
|
|
4731 |
initial_exps = np.zeros((gn, 3)) |
|
|
4732 |
times = np.zeros((adata_rna.n_obs, gn)) |
|
|
4733 |
states = np.zeros((adata_rna.n_obs, gn)) |
|
|
4734 |
if not rna_only: |
|
|
4735 |
velo_c = np.zeros((adata_rna.n_obs, gn)) |
|
|
4736 |
velo_u = np.zeros((adata_rna.n_obs, gn)) |
|
|
4737 |
velo_s = np.zeros((adata_rna.n_obs, gn)) |
|
|
4738 |
likelihoods = np.zeros(gn) |
|
|
4739 |
l_cs = np.zeros(gn) |
|
|
4740 |
ssd_cs = np.zeros(gn) |
|
|
4741 |
var_cs = np.zeros(gn) |
|
|
4742 |
directions = [] |
|
|
4743 |
anchor_c = np.zeros((n_anchors, gn)) |
|
|
4744 |
anchor_u = np.zeros((n_anchors, gn)) |
|
|
4745 |
anchor_s = np.zeros((n_anchors, gn)) |
|
|
4746 |
anchor_c_sw = np.zeros((3, gn)) |
|
|
4747 |
anchor_u_sw = np.zeros((3, gn)) |
|
|
4748 |
anchor_s_sw = np.zeros((3, gn)) |
|
|
4749 |
anchor_vc = np.zeros((n_anchors, gn)) |
|
|
4750 |
anchor_vu = np.zeros((n_anchors, gn)) |
|
|
4751 |
anchor_vs = np.zeros((n_anchors, gn)) |
|
|
4752 |
anchor_min_idx = np.zeros(gn) |
|
|
4753 |
anchor_max_idx = np.zeros(gn) |
|
|
4754 |
anchor_velo_min_idx = np.zeros(gn) |
|
|
4755 |
anchor_velo_max_idx = np.zeros(gn) |
|
|
4756 |
|
|
|
4757 |
if rna_only: |
|
|
4758 |
model_to_run = [2] |
|
|
4759 |
logg.update('Skipping model checking for RNA-only, running model 2', |
|
|
4760 |
v=1) |
|
|
4761 |
|
|
|
4762 |
m_per_g = False |
|
|
4763 |
if model_to_run is not None: |
|
|
4764 |
if isinstance(model_to_run, (list, np.ndarray, pd.Index, pd.Series)): |
|
|
4765 |
model_to_run = [int(x) for x in model_to_run] |
|
|
4766 |
if np.any(~np.isin(model_to_run, [0, 1, 2])): |
|
|
4767 |
raise ValueError('Invalid model number (must be values in' |
|
|
4768 |
' [0,1,2])') |
|
|
4769 |
if len(model_to_run) == gn: |
|
|
4770 |
losses = np.zeros((gn, 1)) |
|
|
4771 |
m_per_g = True |
|
|
4772 |
func_to_call = regress_func |
|
|
4773 |
else: |
|
|
4774 |
losses = np.zeros((gn, len(model_to_run))) |
|
|
4775 |
func_to_call = multimodel_helper |
|
|
4776 |
elif isinstance(model_to_run, (int, float)): |
|
|
4777 |
model_to_run = int(model_to_run) |
|
|
4778 |
if not np.isin(model_to_run, [0, 1, 2]): |
|
|
4779 |
raise ValueError('Invalid model number (must be values in ' |
|
|
4780 |
'[0,1,2])') |
|
|
4781 |
model_to_run = [model_to_run] |
|
|
4782 |
losses = np.zeros((gn, 1)) |
|
|
4783 |
func_to_call = multimodel_helper |
|
|
4784 |
else: |
|
|
4785 |
raise ValueError('Invalid model number (must be values in ' |
|
|
4786 |
'[0,1,2])') |
|
|
4787 |
else: |
|
|
4788 |
losses = np.zeros((gn, 1)) |
|
|
4789 |
func_to_call = regress_func |
|
|
4790 |
|
|
|
4791 |
p_per_g = False |
|
|
4792 |
if partial is not None: |
|
|
4793 |
if isinstance(partial, (list, np.ndarray, pd.Index, pd.Series)): |
|
|
4794 |
if np.any(~np.isin(partial, [True, False])): |
|
|
4795 |
raise ValueError('Invalid partial argument (must be values in' |
|
|
4796 |
' [True,False])') |
|
|
4797 |
if len(partial) == gn: |
|
|
4798 |
p_per_g = True |
|
|
4799 |
else: |
|
|
4800 |
raise ValueError('Incorrect partial argument length') |
|
|
4801 |
elif isinstance(partial, bool): |
|
|
4802 |
if not np.isin(partial, [True, False]): |
|
|
4803 |
raise ValueError('Invalid partial argument (must be values in' |
|
|
4804 |
' [True,False])') |
|
|
4805 |
else: |
|
|
4806 |
raise ValueError('Invalid partial argument (must be values in' |
|
|
4807 |
' [True,False])') |
|
|
4808 |
|
|
|
4809 |
d_per_g = False |
|
|
4810 |
if direction is not None: |
|
|
4811 |
if isinstance(direction, (list, np.ndarray, pd.Index, pd.Series)): |
|
|
4812 |
if np.any(~np.isin(direction, ['on', 'off', 'complete'])): |
|
|
4813 |
raise ValueError('Invalid direction argument (must be values' |
|
|
4814 |
' in ["on","off","complete"])') |
|
|
4815 |
if len(direction) == gn: |
|
|
4816 |
d_per_g = True |
|
|
4817 |
else: |
|
|
4818 |
raise ValueError('Incorrect direction argument length') |
|
|
4819 |
elif isinstance(direction, str): |
|
|
4820 |
if not np.isin(direction, ['on', 'off', 'complete']): |
|
|
4821 |
raise ValueError('Invalid direction argument (must be values' |
|
|
4822 |
' in ["on","off","complete"])') |
|
|
4823 |
else: |
|
|
4824 |
raise ValueError('Invalid direction argument (must be values in' |
|
|
4825 |
' ["on","off","complete"])') |
|
|
4826 |
|
|
|
4827 |
known_pars = [rescale_u, alpha, beta, gamma, t_sw] |
|
|
4828 |
for x in known_pars: |
|
|
4829 |
if x is not None: |
|
|
4830 |
if isinstance(x, (list, np.ndarray)): |
|
|
4831 |
if np.sum(np.isnan(x)) + np.sum(np.isinf(x)) > 0: |
|
|
4832 |
raise ValueError('Known parameters cannot contain NaN or' |
|
|
4833 |
' Inf') |
|
|
4834 |
elif isinstance(x, (int, float)): |
|
|
4835 |
if x == np.nan or x == np.inf: |
|
|
4836 |
raise ValueError('Known parameters cannot contain NaN or' |
|
|
4837 |
' Inf') |
|
|
4838 |
else: |
|
|
4839 |
raise ValueError('Invalid known parameters type') |
|
|
4840 |
|
|
|
4841 |
if ((embedding not in adata_rna.obsm) and |
|
|
4842 |
(embedding not in adata_atac.obsm)): |
|
|
4843 |
raise ValueError(f'{embedding} is not found in obsm') |
|
|
4844 |
embed_coord = adata_rna.obsm[embedding] if embedding in adata_rna.obsm \ |
|
|
4845 |
else adata_atac.obsm[embedding] |
|
|
4846 |
global_pdist = pairwise_distances(embed_coord) |
|
|
4847 |
|
|
|
4848 |
u_mat = adata_rna[:, gene_list].layers['Mu'].A \ |
|
|
4849 |
if sparse.issparse(adata_rna.layers['Mu']) \ |
|
|
4850 |
else adata_rna[:, gene_list].layers['Mu'] |
|
|
4851 |
s_mat = adata_rna[:, gene_list].layers['Ms'].A \ |
|
|
4852 |
if sparse.issparse(adata_rna.layers['Ms']) \ |
|
|
4853 |
else adata_rna[:, gene_list].layers['Ms'] |
|
|
4854 |
c_mat = adata_atac[:, gene_list].layers['Mc'].A \ |
|
|
4855 |
if sparse.issparse(adata_atac.layers['Mc']) \ |
|
|
4856 |
else adata_atac[:, gene_list].layers['Mc'] |
|
|
4857 |
|
|
|
4858 |
ru = rescale_u if rescale_u is not None else None |
|
|
4859 |
|
|
|
4860 |
if parallel: |
|
|
4861 |
if (n_jobs is None or not isinstance(n_jobs, int) or n_jobs < 0 or |
|
|
4862 |
n_jobs > os.cpu_count()): |
|
|
4863 |
n_jobs = os.cpu_count() |
|
|
4864 |
if n_jobs > gn: |
|
|
4865 |
n_jobs = gn |
|
|
4866 |
batches = -(-gn // n_jobs) |
|
|
4867 |
if n_jobs > 1: |
|
|
4868 |
logg.update(f'running {n_jobs} jobs in parallel', v=1) |
|
|
4869 |
else: |
|
|
4870 |
n_jobs = 1 |
|
|
4871 |
batches = gn |
|
|
4872 |
if n_jobs == 1: |
|
|
4873 |
parallel = False |
|
|
4874 |
|
|
|
4875 |
pbar = tqdm(total=gn) |
|
|
4876 |
for group in range(batches): |
|
|
4877 |
gene_indices = range(group * n_jobs, np.min([gn, (group+1) * n_jobs])) |
|
|
4878 |
if parallel: |
|
|
4879 |
verb = 51 if settings.VERBOSITY >= 2 else 0 |
|
|
4880 |
plot = False |
|
|
4881 |
|
|
|
4882 |
# clear the settings file if it exists |
|
|
4883 |
open("settings.txt", "w").close() |
|
|
4884 |
|
|
|
4885 |
# write our current settings to the file |
|
|
4886 |
with open("settings.txt", "a") as sfile: |
|
|
4887 |
sfile.write(str(settings.VERBOSITY) + "\n") |
|
|
4888 |
sfile.write(str(settings.CWD) + "\n") |
|
|
4889 |
sfile.write(str(settings.LOG_FOLDER) + "\n") |
|
|
4890 |
sfile.write(str(settings.LOG_FILENAME) + "\n") |
|
|
4891 |
|
|
|
4892 |
res = Parallel(n_jobs=n_jobs, backend='loky', verbose=verb)( |
|
|
4893 |
delayed(func_to_call)( |
|
|
4894 |
c_mat[:, i], |
|
|
4895 |
u_mat[:, i], |
|
|
4896 |
s_mat[:, i], |
|
|
4897 |
model_to_run[i] if m_per_g else model_to_run, |
|
|
4898 |
max_iter, |
|
|
4899 |
init_mode, |
|
|
4900 |
device, |
|
|
4901 |
neural_net, |
|
|
4902 |
adam, |
|
|
4903 |
adam_lr, |
|
|
4904 |
adam_beta1, |
|
|
4905 |
adam_beta2, |
|
|
4906 |
batch_size, |
|
|
4907 |
global_pdist, |
|
|
4908 |
embed_coord, |
|
|
4909 |
rna_conn, |
|
|
4910 |
plot, |
|
|
4911 |
save_plot, |
|
|
4912 |
plot_dir, |
|
|
4913 |
fit_args, |
|
|
4914 |
gene_list[i], |
|
|
4915 |
partial[i] if p_per_g else partial, |
|
|
4916 |
direction[i] if d_per_g else direction, |
|
|
4917 |
rna_only, |
|
|
4918 |
fit, |
|
|
4919 |
fit_decoupling, |
|
|
4920 |
extra_color, |
|
|
4921 |
ru[i] if isinstance(ru, (list, np.ndarray)) else ru, |
|
|
4922 |
alpha[i] if isinstance(alpha, (list, np.ndarray)) |
|
|
4923 |
else alpha, |
|
|
4924 |
beta[i] if isinstance(beta, (list, np.ndarray)) |
|
|
4925 |
else beta, |
|
|
4926 |
gamma[i] if isinstance(gamma, (list, np.ndarray)) |
|
|
4927 |
else gamma, |
|
|
4928 |
t_sw[i] if isinstance(t_sw, (list, np.ndarray)) else t_sw, |
|
|
4929 |
settings.VERBOSITY, |
|
|
4930 |
settings.LOG_FOLDER, |
|
|
4931 |
settings.LOG_FILENAME) |
|
|
4932 |
for i in gene_indices) |
|
|
4933 |
|
|
|
4934 |
for i, r in zip(gene_indices, res): |
|
|
4935 |
(loss, model, direct_out, parameters, initial_exp, |
|
|
4936 |
time, state, velocity, likelihood, anchors) = r |
|
|
4937 |
switch, rate, scale_cc, rescale_c, rescale_u, realign_ratio = \ |
|
|
4938 |
parameters |
|
|
4939 |
likelihood, l_c, ssd_c, var_c = likelihood |
|
|
4940 |
losses[i, :] = loss |
|
|
4941 |
models[i] = model |
|
|
4942 |
directions.append(direct_out) |
|
|
4943 |
t_sws[i, :] = switch |
|
|
4944 |
rates[i, :] = rate |
|
|
4945 |
scale_ccs[i] = scale_cc |
|
|
4946 |
rescale_cs[i] = rescale_c |
|
|
4947 |
rescale_us[i] = rescale_u |
|
|
4948 |
realign_ratios[i] = realign_ratio |
|
|
4949 |
likelihoods[i] = likelihood |
|
|
4950 |
l_cs[i] = l_c |
|
|
4951 |
ssd_cs[i] = ssd_c |
|
|
4952 |
var_cs[i] = var_c |
|
|
4953 |
if fit: |
|
|
4954 |
initial_exps[i, :] = initial_exp |
|
|
4955 |
times[:, i] = time |
|
|
4956 |
states[:, i] = state |
|
|
4957 |
n_anchors_ = anchors[0].shape[0] |
|
|
4958 |
n_switch = anchors[1].shape[0] |
|
|
4959 |
if not rna_only: |
|
|
4960 |
velo_c[:, i] = smooth_scale(atac_conn, velocity[:, 0]) |
|
|
4961 |
anchor_c[:n_anchors_, i] = anchors[0][:, 0] |
|
|
4962 |
anchor_c_sw[:n_switch, i] = anchors[1][:, 0] |
|
|
4963 |
anchor_vc[:n_anchors_, i] = anchors[2][:, 0] |
|
|
4964 |
velo_u[:, i] = smooth_scale(rna_conn, velocity[:, 1]) |
|
|
4965 |
velo_s[:, i] = smooth_scale(rna_conn, velocity[:, 2]) |
|
|
4966 |
anchor_u[:n_anchors_, i] = anchors[0][:, 1] |
|
|
4967 |
anchor_s[:n_anchors_, i] = anchors[0][:, 2] |
|
|
4968 |
anchor_u_sw[:n_switch, i] = anchors[1][:, 1] |
|
|
4969 |
anchor_s_sw[:n_switch, i] = anchors[1][:, 2] |
|
|
4970 |
anchor_vu[:n_anchors_, i] = anchors[2][:, 1] |
|
|
4971 |
anchor_vs[:n_anchors_, i] = anchors[2][:, 2] |
|
|
4972 |
anchor_min_idx[i] = anchors[3] |
|
|
4973 |
anchor_max_idx[i] = anchors[4] |
|
|
4974 |
anchor_velo_min_idx[i] = anchors[5] |
|
|
4975 |
anchor_velo_max_idx[i] = anchors[6] |
|
|
4976 |
else: |
|
|
4977 |
i = group |
|
|
4978 |
gene = gene_list[i] |
|
|
4979 |
logg.update(f'@@@@@fitting {gene}', v=1) |
|
|
4980 |
(loss, model, direct_out, |
|
|
4981 |
parameters, initial_exp, |
|
|
4982 |
time, state, velocity, |
|
|
4983 |
likelihood, anchors) = \ |
|
|
4984 |
func_to_call(c_mat[:, i], u_mat[:, i], s_mat[:, i], |
|
|
4985 |
model_to_run[i] if m_per_g else model_to_run, |
|
|
4986 |
max_iter, init_mode, |
|
|
4987 |
device, |
|
|
4988 |
neural_net, |
|
|
4989 |
adam, |
|
|
4990 |
adam_lr, |
|
|
4991 |
adam_beta1, |
|
|
4992 |
adam_beta2, |
|
|
4993 |
batch_size, |
|
|
4994 |
global_pdist, embed_coord, |
|
|
4995 |
rna_conn, plot, save_plot, plot_dir, |
|
|
4996 |
fit_args, gene, |
|
|
4997 |
partial[i] if p_per_g else partial, |
|
|
4998 |
direction[i] if d_per_g else direction, |
|
|
4999 |
rna_only, fit, fit_decoupling, extra_color, |
|
|
5000 |
ru[i] if isinstance(ru, (list, np.ndarray)) |
|
|
5001 |
else ru, |
|
|
5002 |
alpha[i] if isinstance(alpha, (list, np.ndarray)) |
|
|
5003 |
else alpha, |
|
|
5004 |
beta[i] if isinstance(beta, (list, np.ndarray)) |
|
|
5005 |
else beta, |
|
|
5006 |
gamma[i] if isinstance(gamma, (list, np.ndarray)) |
|
|
5007 |
else gamma, |
|
|
5008 |
t_sw[i] if isinstance(t_sw, (list, np.ndarray)) |
|
|
5009 |
else t_sw, |
|
|
5010 |
settings.VERBOSITY, |
|
|
5011 |
settings.LOG_FOLDER, |
|
|
5012 |
settings.LOG_FILENAME) |
|
|
5013 |
switch, rate, scale_cc, rescale_c, rescale_u, realign_ratio = \ |
|
|
5014 |
parameters |
|
|
5015 |
likelihood, l_c, ssd_c, var_c = likelihood |
|
|
5016 |
losses[i, :] = loss |
|
|
5017 |
models[i] = model |
|
|
5018 |
directions.append(direct_out) |
|
|
5019 |
t_sws[i, :] = switch |
|
|
5020 |
rates[i, :] = rate |
|
|
5021 |
scale_ccs[i] = scale_cc |
|
|
5022 |
rescale_cs[i] = rescale_c |
|
|
5023 |
rescale_us[i] = rescale_u |
|
|
5024 |
realign_ratios[i] = realign_ratio |
|
|
5025 |
likelihoods[i] = likelihood |
|
|
5026 |
l_cs[i] = l_c |
|
|
5027 |
ssd_cs[i] = ssd_c |
|
|
5028 |
var_cs[i] = var_c |
|
|
5029 |
if fit: |
|
|
5030 |
initial_exps[i, :] = initial_exp |
|
|
5031 |
times[:, i] = time |
|
|
5032 |
states[:, i] = state |
|
|
5033 |
n_anchors_ = anchors[0].shape[0] |
|
|
5034 |
n_switch = anchors[1].shape[0] |
|
|
5035 |
if not rna_only: |
|
|
5036 |
velo_c[:, i] = smooth_scale(atac_conn, velocity[:, 0]) |
|
|
5037 |
anchor_c[:n_anchors_, i] = anchors[0][:, 0] |
|
|
5038 |
anchor_c_sw[:n_switch, i] = anchors[1][:, 0] |
|
|
5039 |
anchor_vc[:n_anchors_, i] = anchors[2][:, 0] |
|
|
5040 |
velo_u[:, i] = smooth_scale(rna_conn, velocity[:, 1]) |
|
|
5041 |
velo_s[:, i] = smooth_scale(rna_conn, velocity[:, 2]) |
|
|
5042 |
anchor_u[:n_anchors_, i] = anchors[0][:, 1] |
|
|
5043 |
anchor_s[:n_anchors_, i] = anchors[0][:, 2] |
|
|
5044 |
anchor_u_sw[:n_switch, i] = anchors[1][:, 1] |
|
|
5045 |
anchor_s_sw[:n_switch, i] = anchors[1][:, 2] |
|
|
5046 |
anchor_vu[:n_anchors_, i] = anchors[2][:, 1] |
|
|
5047 |
anchor_vs[:n_anchors_, i] = anchors[2][:, 2] |
|
|
5048 |
anchor_min_idx[i] = anchors[3] |
|
|
5049 |
anchor_max_idx[i] = anchors[4] |
|
|
5050 |
anchor_velo_min_idx[i] = anchors[5] |
|
|
5051 |
anchor_velo_max_idx[i] = anchors[6] |
|
|
5052 |
pbar.update(len(gene_indices)) |
|
|
5053 |
pbar.close() |
|
|
5054 |
directions = np.array(directions) |
|
|
5055 |
|
|
|
5056 |
filt = np.sum(losses != np.inf, 1) >= 1 |
|
|
5057 |
if np.sum(filt) == 0: |
|
|
5058 |
raise ValueError('None of the genes were fitted due to low quality,' |
|
|
5059 |
' not returning') |
|
|
5060 |
adata_copy = adata_rna[:, gene_list[filt]].copy() |
|
|
5061 |
adata_copy.layers['ATAC'] = c_mat[:, filt] |
|
|
5062 |
adata_copy.var['fit_alpha_c'] = rates[filt, 0] |
|
|
5063 |
adata_copy.var['fit_alpha'] = rates[filt, 1] |
|
|
5064 |
adata_copy.var['fit_beta'] = rates[filt, 2] |
|
|
5065 |
adata_copy.var['fit_gamma'] = rates[filt, 3] |
|
|
5066 |
adata_copy.var['fit_t_sw1'] = t_sws[filt, 0] |
|
|
5067 |
adata_copy.var['fit_t_sw2'] = t_sws[filt, 1] |
|
|
5068 |
adata_copy.var['fit_t_sw3'] = t_sws[filt, 2] |
|
|
5069 |
adata_copy.var['fit_scale_cc'] = scale_ccs[filt] |
|
|
5070 |
adata_copy.var['fit_rescale_c'] = rescale_cs[filt] |
|
|
5071 |
adata_copy.var['fit_rescale_u'] = rescale_us[filt] |
|
|
5072 |
adata_copy.var['fit_alignment_scaling'] = realign_ratios[filt] |
|
|
5073 |
adata_copy.var['fit_model'] = models[filt] |
|
|
5074 |
adata_copy.var['fit_direction'] = directions[filt] |
|
|
5075 |
if model_to_run is not None and not m_per_g and not rna_only: |
|
|
5076 |
for i, m in enumerate(model_to_run): |
|
|
5077 |
adata_copy.var[f'fit_loss_M{m}'] = losses[filt, i] |
|
|
5078 |
else: |
|
|
5079 |
adata_copy.var['fit_loss'] = losses[filt, 0] |
|
|
5080 |
adata_copy.var['fit_likelihood'] = likelihoods[filt] |
|
|
5081 |
adata_copy.var['fit_likelihood_c'] = l_cs[filt] |
|
|
5082 |
adata_copy.var['fit_ssd_c'] = ssd_cs[filt] |
|
|
5083 |
adata_copy.var['fit_var_c'] = var_cs[filt] |
|
|
5084 |
if fit: |
|
|
5085 |
adata_copy.layers['fit_t'] = times[:, filt] |
|
|
5086 |
adata_copy.layers['fit_state'] = states[:, filt] |
|
|
5087 |
adata_copy.layers['velo_s'] = velo_s[:, filt] |
|
|
5088 |
adata_copy.layers['velo_u'] = velo_u[:, filt] |
|
|
5089 |
if not rna_only: |
|
|
5090 |
adata_copy.layers['velo_chrom'] = velo_c[:, filt] |
|
|
5091 |
adata_copy.var['fit_c0'] = initial_exps[filt, 0] |
|
|
5092 |
adata_copy.var['fit_u0'] = initial_exps[filt, 1] |
|
|
5093 |
adata_copy.var['fit_s0'] = initial_exps[filt, 2] |
|
|
5094 |
adata_copy.var['fit_anchor_min_idx'] = anchor_min_idx[filt] |
|
|
5095 |
adata_copy.var['fit_anchor_max_idx'] = anchor_max_idx[filt] |
|
|
5096 |
adata_copy.var['fit_anchor_velo_min_idx'] = anchor_velo_min_idx[filt] |
|
|
5097 |
adata_copy.var['fit_anchor_velo_max_idx'] = anchor_velo_max_idx[filt] |
|
|
5098 |
adata_copy.varm['fit_anchor_c'] = np.transpose(anchor_c[:, filt]) |
|
|
5099 |
adata_copy.varm['fit_anchor_u'] = np.transpose(anchor_u[:, filt]) |
|
|
5100 |
adata_copy.varm['fit_anchor_s'] = np.transpose(anchor_s[:, filt]) |
|
|
5101 |
adata_copy.varm['fit_anchor_c_sw'] = np.transpose(anchor_c_sw[:, filt]) |
|
|
5102 |
adata_copy.varm['fit_anchor_u_sw'] = np.transpose(anchor_u_sw[:, filt]) |
|
|
5103 |
adata_copy.varm['fit_anchor_s_sw'] = np.transpose(anchor_s_sw[:, filt]) |
|
|
5104 |
adata_copy.varm['fit_anchor_c_velo'] = np.transpose(anchor_vc[:, filt]) |
|
|
5105 |
adata_copy.varm['fit_anchor_u_velo'] = np.transpose(anchor_vu[:, filt]) |
|
|
5106 |
adata_copy.varm['fit_anchor_s_velo'] = np.transpose(anchor_vs[:, filt]) |
|
|
5107 |
v_genes = adata_copy.var['fit_likelihood'] >= 0.05 |
|
|
5108 |
adata_copy.var['velo_s_genes'] = adata_copy.var['velo_u_genes'] = \ |
|
|
5109 |
adata_copy.var['velo_chrom_genes'] = v_genes |
|
|
5110 |
adata_copy.uns['velo_s_params'] = adata_copy.uns['velo_u_params'] = \ |
|
|
5111 |
adata_copy.uns['velo_chrom_params'] = {'mode': 'dynamical'} |
|
|
5112 |
adata_copy.uns['velo_s_params'].update(fit_args) |
|
|
5113 |
adata_copy.uns['velo_u_params'].update(fit_args) |
|
|
5114 |
adata_copy.uns['velo_chrom_params'].update(fit_args) |
|
|
5115 |
adata_copy.obsp['_RNA_conn'] = rna_conn |
|
|
5116 |
if not rna_only: |
|
|
5117 |
adata_copy.obsp['_ATAC_conn'] = atac_conn |
|
|
5118 |
return adata_copy |
|
|
5119 |
|
|
|
5120 |
|
|
|
5121 |
def smooth_scale(conn, vector):
    """Smooth a vector through a connectivity matrix and keep its value range.

    The vector is propagated through ``conn`` (``conn.dot(vector.T).T``) and
    the smoothed values are linearly rescaled so that their minimum and
    maximum coincide with the minimum and maximum of the input ``vector``.

    Parameters
    ----------
    conn: square matrix (dense or scipy sparse)
        Weights used for smoothing; presumably cell-by-cell neighbor
        connectivities in the callers of this module — confirm.
    vector: `numpy.ndarray`
        Values to smooth.

    Returns
    -------
    Smoothed values mapped onto ``[min(vector), max(vector)]``. If smoothing
    makes all values identical, the min/max mapping is undefined (it would
    divide by zero and produce NaN). In that case every entry is set to the
    midpoint of the input range, which equals the input value when the input
    itself is constant.
    """
    max_to = np.max(vector)
    min_to = np.min(vector)
    v = conn.dot(vector.T).T
    max_from = np.max(v)
    min_from = np.min(v)
    span_from = max_from - min_from
    if span_from == 0:
        # Degenerate case (e.g. an all-zero velocity): avoid 0 / 0 -> NaN.
        return np.full_like(v, (max_to + min_to) / 2, dtype=float)
    return ((v - min_from) * (max_to - min_to) / span_from) + min_to
|
|
5129 |
|
|
|
5130 |
|
|
|
5131 |
def top_n_sparse(conn, n):
    """Keep the ``n`` largest entries of each row and discretize them.

    For every row of ``conn`` only the ``n`` largest stored values are kept.
    The surviving values are then binned:

    - values in ``(0, 0.25]`` become 0.25;
    - values in ``(0.25, 0.5]`` become 0.5;
    - values above 0.5 become 1.

    Non-positive values are left unchanged and explicit zeros are dropped.

    Parameters
    ----------
    conn: scipy sparse matrix
        Weighted connectivity matrix. It is not modified.
    n: `int`
        Number of entries to keep per row; ``n <= 0`` keeps nothing.

    Returns
    -------
    `scipy.sparse.csr_matrix`
        The pruned, binned matrix with column indices sorted in each row.
    """
    # copy=True: for LIL input tolil() would otherwise return ``conn`` itself
    # and the loop below would overwrite the caller's matrix.
    conn_ll = conn.tolil(copy=True)
    n_keep = max(int(n), 0)
    for i in range(conn_ll.shape[0]):
        row_data = np.asarray(conn_ll.data[i])
        row_idx = np.asarray(conn_ll.rows[i], dtype=int)
        # Largest values first. Slicing with [:n_keep] (rather than [-n:])
        # makes n == 0 keep nothing instead of the whole row.
        keep = np.argsort(row_data)[::-1][:n_keep]
        # LIL rows must list their column indices in increasing order.
        keep = keep[np.argsort(row_idx[keep])]
        conn_ll.data[i] = row_data[keep].tolist()
        conn_ll.rows[i] = row_idx[keep].tolist()
    conn = conn_ll.tocsr()
    # Bin the stored values directly. The masks are computed before any
    # assignment, so each value is classified by its original magnitude.
    data = conn.data
    positive = data > 0
    above_quarter = data > 0.25
    above_half = data > 0.5
    data[positive] = 0.25
    data[above_quarter] = 0.5
    data[above_half] = 1
    conn.eliminate_zeros()
    return conn
|
|
5150 |
|
|
|
5151 |
|
|
|
5152 |
def set_velocity_genes(adata,
                       likelihood_lower=0.05,
                       rescale_u_upper=None,
                       rescale_u_lower=None,
                       rescale_c_upper=None,
                       rescale_c_lower=None,
                       primed_upper=None,
                       primed_lower=None,
                       decoupled_upper=None,
                       decoupled_lower=None,
                       alpha_c_upper=None,
                       alpha_c_lower=None,
                       alpha_upper=None,
                       alpha_lower=None,
                       beta_upper=None,
                       beta_lower=None,
                       gamma_upper=None,
                       gamma_lower=None,
                       scale_cc_upper=None,
                       scale_cc_lower=None
                       ):
    """Reset velocity genes.

    Re-selects the velocity genes of a fitted object. A gene is kept when
    its fit likelihood is at least ``likelihood_lower`` and every other
    fitted quantity lies within the given inclusive bounds. A bound left as
    `None` is not applied.

    Parameters
    ----------
    adata: :class:`~anndata.AnnData`
        Anndata result from dynamics recovery.
    likelihood_lower: `float` (default: 0.05)
        Minimum likelihood.
    rescale_u_upper: `float` (default: `None`)
        Maximum rescale_u.
    rescale_u_lower: `float` (default: `None`)
        Minimum rescale_u.
    rescale_c_upper: `float` (default: `None`)
        Maximum rescale_c.
    rescale_c_lower: `float` (default: `None`)
        Minimum rescale_c.
    primed_upper: `float` (default: `None`)
        Maximum primed interval.
    primed_lower: `float` (default: `None`)
        Minimum primed interval.
    decoupled_upper: `float` (default: `None`)
        Maximum decoupled interval.
    decoupled_lower: `float` (default: `None`)
        Minimum decoupled interval.
    alpha_c_upper: `float` (default: `None`)
        Maximum alpha_c.
    alpha_c_lower: `float` (default: `None`)
        Minimum alpha_c.
    alpha_upper: `float` (default: `None`)
        Maximum alpha.
    alpha_lower: `float` (default: `None`)
        Minimum alpha.
    beta_upper: `float` (default: `None`)
        Maximum beta.
    beta_lower: `float` (default: `None`)
        Minimum beta.
    gamma_upper: `float` (default: `None`)
        Maximum gamma.
    gamma_lower: `float` (default: `None`)
        Minimum gamma.
    scale_cc_upper: `float` (default: `None`)
        Maximum scale_cc.
    scale_cc_lower: `float` (default: `None`)
        Minimum scale_cc.

    Returns
    -------
    velo_s_genes, velo_u_genes, velo_chrom_genes: `.var`
        New velocity genes for each modality (the same mask for all three).
    """
    var = adata.var

    def narrow(mask, source, lower, upper):
        # Intersect ``mask`` with ``lower <= values <= upper``. ``source`` is
        # either a ``.var`` column name (read only if a bound is set) or an
        # already computed Series.
        if upper is None and lower is None:
            return mask
        values = var[source] if isinstance(source, str) else source
        if upper is not None:
            mask &= values <= upper
        if lower is not None:
            mask &= values >= lower
        return mask

    v_genes = var['fit_likelihood'] >= likelihood_lower
    v_genes = narrow(v_genes, 'fit_rescale_u',
                     rescale_u_lower, rescale_u_upper)
    v_genes = narrow(v_genes, 'fit_rescale_c',
                     rescale_c_lower, rescale_c_upper)

    # Primed interval: first switch time shifted by
    # 20 / t * fit_anchor_min_idx * fit_alignment_scaling.
    primed = var['fit_t_sw1'] + 20 / adata.uns['velo_s_params']['t'] * \
        var['fit_anchor_min_idx'] * var['fit_alignment_scaling']
    v_genes = narrow(v_genes, primed, primed_lower, primed_upper)

    # Decoupled interval: gap between the 2nd and 3rd switch times, each
    # capped at 20.
    capped_sw2 = np.clip(var['fit_t_sw2'], None, 20)
    capped_sw3 = np.clip(var['fit_t_sw3'], None, 20)
    v_genes = narrow(v_genes, capped_sw3 - capped_sw2,
                     decoupled_lower, decoupled_upper)

    rate_bounds = (
        ('fit_alpha_c', alpha_c_lower, alpha_c_upper),
        ('fit_alpha', alpha_lower, alpha_upper),
        ('fit_beta', beta_lower, beta_upper),
        ('fit_gamma', gamma_lower, gamma_upper),
        ('fit_scale_cc', scale_cc_lower, scale_cc_upper),
    )
    for column, lower, upper in rate_bounds:
        v_genes = narrow(v_genes, column, lower, upper)

    logg.update(f'{np.sum(v_genes)} velocity genes were selected', v=1)
    adata.var['velo_s_genes'] = adata.var['velo_u_genes'] = \
        adata.var['velo_chrom_genes'] = v_genes
|
|
5271 |
|
|
|
5272 |
|
|
|
5273 |
def velocity_graph(adata, vkey='velo_s', xkey='Ms', **kwargs):
    """Computes velocity graph.

    This function normalizes the velocity matrix and computes velocity graph
    with `scvelo.tl.velocity_graph`.

    If ``adata.layers[vkey + '_norm']`` does not exist yet, it is built in
    two steps:

    1. Every gene (column) is divided by the sum of its absolute velocities.
    2. The whole matrix is divided by the absolute value of its mean.

    Genes whose velocity is zero everywhere stay zero instead of becoming
    NaN, and the rescaling never flips the sign of the velocities.

    Parameters
    ----------
    adata: :class:`~anndata.AnnData`
        Anndata result from dynamics recovery.
    vkey: `str` (default: `velo_s`)
        Default to use spliced velocities.
    xkey: `str` (default: `Ms`)
        Default to use smoothed spliced counts.
    **kwargs:
        Additional parameters passed to `scvelo.tl.velocity_graph`.

    Returns
    -------
    Normalized velocity matrix and associated velocity genes and params.
    Outputs of `scvelo.tl.velocity_graph`.

    Raises
    ------
    ValueError
        If ``vkey`` is not a layer of ``adata``.
    """
    if vkey not in adata.layers.keys():
        raise ValueError('Velocity matrix is not found. Please run multivelo'
                         '.recover_dynamics_chrom function first.')
    norm_key = vkey + '_norm'
    if norm_key not in adata.layers.keys():
        velo = adata.layers[vkey]
        # Per-gene L1 norm. An all-zero gene would give 0 / 0 = NaN, and the
        # global mean below would then turn the entire layer into NaN.
        gene_l1 = np.sum(np.abs(velo), 0)
        gene_l1 = np.where(gene_l1 == 0, 1, gene_l1)
        velo_norm = velo / gene_l1
        # Rescale by the magnitude of the mean only. Dividing by a negative
        # mean would reverse every velocity and hence the graph direction.
        scale = np.abs(np.nanmean(velo_norm))
        if np.isfinite(scale) and scale > 0:
            velo_norm = velo_norm / scale
        adata.layers[norm_key] = velo_norm
        adata.uns[norm_key + '_params'] = adata.uns[vkey + '_params']
    if norm_key + '_genes' not in adata.var.columns:
        adata.var[norm_key + '_genes'] = adata.var[vkey + '_genes']
    scv.tl.velocity_graph(adata, vkey=norm_key, xkey=xkey, **kwargs)
|
|
5305 |
|
|
|
5306 |
|
|
|
5307 |
def velocity_embedding_stream(adata, vkey='velo_s', show=True, **kwargs):
    """Plots velocity stream.

    This function plots velocity streamplot with
    `scvelo.pl.velocity_embedding_stream`, using the normalized layer
    ``vkey + '_norm'``.

    If either the normalized layer or its graph
    (``adata.uns[vkey + '_norm_graph']``) is missing, `velocity_graph` is
    called to create both. The normalization is therefore the same no matter
    which function of this module is run first. Previously this function
    built the layer itself with a different recipe.

    Parameters
    ----------
    adata: :class:`~anndata.AnnData`
        Anndata result from dynamics recovery.
    vkey: `str` (default: `velo_s`)
        Default to use spliced velocities. The normalized matrix will be used.
    show: `bool` (default: `True`)
        Whether to show the plot.
    **kwargs:
        Additional parameters; forwarded both to `velocity_graph` (when it is
        called) and to `scvelo.pl.velocity_embedding_stream`.

    Returns
    -------
    If `show==False`, a matplotlib axis object.

    Raises
    ------
    ValueError
        If ``vkey`` is not a layer of ``adata``.
    """
    if vkey not in adata.layers:
        raise ValueError('Velocity matrix is not found. Please run multivelo.'
                         'recover_dynamics_chrom function first.')
    norm_key = vkey + '_norm'
    if (norm_key not in adata.layers.keys()
            or norm_key + '_graph' not in adata.uns.keys()):
        velocity_graph(adata, vkey=vkey, **kwargs)
    if norm_key + '_genes' not in adata.var.columns:
        adata.var[norm_key + '_genes'] = adata.var[vkey + '_genes']
    out = scv.pl.velocity_embedding_stream(adata, vkey=norm_key, show=show,
                                           **kwargs)
    if not show:
        return out
|
|
5342 |
|
|
|
5343 |
|
|
|
5344 |
def latent_time(adata, vkey='velo_s', **kwargs):
    """Computes latent time.

    Runs `scvelo.tl.latent_time` on the normalized velocity layer
    ``vkey + '_norm'``. If ``adata.uns`` has no ``vkey + '_norm_graph'``, the
    velocity graph is first computed with `velocity_graph`.

    Parameters
    ----------
    adata: :class:`~anndata.AnnData`
        Anndata result from dynamics recovery.
    vkey: `str` (default: `velo_s`)
        Default to use spliced velocities. The normalized matrix will be used.
    **kwargs:
        Additional parameters; forwarded to `velocity_graph` (when it is
        called) and to `scvelo.tl.latent_time`.

    Returns
    -------
    Outputs of `scvelo.tl.latent_time`.

    Raises
    ------
    ValueError
        If ``vkey`` or ``'fit_t'`` is missing from ``adata.layers``, or the
        normalized layer ``vkey + '_norm'`` has not been created yet.
    """
    layer_keys = adata.layers.keys()
    fitted = vkey in layer_keys and 'fit_t' in layer_keys
    if not fitted:
        raise ValueError('Velocity or time matrix is not found. Please run '
                         'multivelo.recover_dynamics_chrom function first.')
    norm_key = vkey + '_norm'
    if norm_key not in layer_keys:
        raise ValueError('Normalized velocity matrix is not found. Please '
                         'run multivelo.velocity_graph function first.')
    graph_missing = norm_key + '_graph' not in adata.uns.keys()
    if graph_missing:
        velocity_graph(adata, vkey=vkey, **kwargs)
    scv.tl.latent_time(adata, vkey=norm_key, **kwargs)
|
|
5370 |
|
|
|
5371 |
|
|
|
5372 |
def LRT_decoupling(adata_rna, adata_atac, **kwargs):
    """Computes likelihood ratio test for decoupling state.

    Fits the models twice with `recover_dynamics_chrom`, once with and once
    without the decoupling interval. For the genes present in both results,
    it tests whether keeping the decoupling state improves the fit
    likelihood. The statistic is
    ``-2 * n_obs * (log(l_without) - log(l_with))``, where ``n_obs`` is
    ``adata_rna.n_obs``. P-values come from a chi-squared distribution with
    one degree of freedom. The test is done both for the chromatin
    likelihood (``fit_likelihood_c``) and for the overall likelihood
    (``fit_likelihood``).

    Parameters
    ----------
    adata_rna: :class:`~anndata.AnnData`
        RNA anndata object.
    adata_atac: :class:`~anndata.AnnData`
        ATAC anndata object.
    **kwargs:
        Additional parameters passed to `recover_dynamics_chrom`.

    Returns
    -------
    adata_result_w_decoupled: class:`~anndata.AnnData`
        fit result with decoupling state
    adata_result_wo_decoupled: class:`~anndata.AnnData`
        fit result without decoupling state
    res: `pandas.DataFrame`
        LRT statistics and p-values, indexed by the shared genes.
    """
    from scipy.stats.distributions import chi2

    logg.update('fitting models with decoupling intervals', v=0)
    adata_result_w_decoupled = recover_dynamics_chrom(
        adata_rna, adata_atac, fit_decoupling=True, **kwargs)
    logg.update('fitting models without decoupling intervals', v=0)
    adata_result_wo_decoupled = recover_dynamics_chrom(
        adata_rna, adata_atac, fit_decoupling=False, **kwargs)
    logg.update('testing likelihood ratio', v=0)

    shared_genes = pd.Index(np.intersect1d(
        adata_result_w_decoupled.var_names,
        adata_result_wo_decoupled.var_names))
    n_obs = adata_rna.n_obs

    def ratio_test(column):
        # Likelihoods of both fits for the shared genes, the LRT statistic
        # and its chi-squared (1 dof) survival-function p-value.
        with_dec = adata_result_w_decoupled[:, shared_genes].var[column].values
        without_dec = \
            adata_result_wo_decoupled[:, shared_genes].var[column].values
        stat = -2 * n_obs * (np.log(without_dec) - np.log(with_dec))
        return with_dec, without_dec, stat, chi2.sf(stat, 1)

    lc_with, lc_without, lrt_c, pval_c = ratio_test('fit_likelihood_c')
    l_with, l_without, lrt, pval = ratio_test('fit_likelihood')

    res = pd.DataFrame({'likelihood_c_w_decoupled': lc_with,
                        'likelihood_c_wo_decoupled': lc_without,
                        'LRT_c': lrt_c,
                        'pval_c': pval_c,
                        'likelihood_w_decoupled': l_with,
                        'likelihood_wo_decoupled': l_without,
                        'LRT': lrt,
                        'pval': pval,
                        }, index=shared_genes)
    return adata_result_w_decoupled, adata_result_wo_decoupled, res
|
|
5431 |
|
|
|
5432 |
|
|
|
5433 |
def transition_matrix_s(s_mat, velo_s, knn):
    """Cosine-similarity transition matrices from spliced expression.

    For every cell ``i``, and every other cell ``j`` among its neighbors and
    its neighbors' neighbors, entry ``(i, j)`` is the cosine similarity
    between the expression difference ``s_mat[i] - s_mat[j]`` and the
    velocity ``velo_s[i]``.

    Parameters
    ----------
    s_mat: scipy sparse matrix or `numpy.ndarray`
        Cell-by-gene spliced expression.
    velo_s: `numpy.ndarray`
        Cell-by-gene spliced velocity.
    knn: `numpy.ndarray`
        Neighbor indices, one row per cell (cast to `int`).

    Returns
    -------
    tm: `scipy.sparse.csr_matrix`
        Positive similarities (clipped to at most 1).
    tm_neg: `scipy.sparse.csr_matrix`
        Negative similarities (clipped to at least -1).

    Zero similarities are not stored and the diagonal is empty. A pair whose
    difference or velocity has zero norm gives NaN, which is kept in both
    matrices.
    """
    knn = knn.astype(int)
    n_cells = s_mat.shape[0]

    def dense_row(mat, idx):
        # Row ``idx`` as a flat ndarray, for sparse or dense ``mat``
        # (``toarray`` instead of ``.A``, which newer SciPy removed).
        row = mat[idx, :]
        if sparse.issparse(row):
            row = row.toarray()
        return np.ravel(row)

    tm_val, tm_col, tm_row = [], [], []
    for i in range(knn.shape[0]):
        neighbors = knn[i, :]
        two_step_knn = np.unique(
            np.concatenate((neighbors, knn[neighbors, :].ravel())))
        # Loop-invariant for the inner loop: cell i's expression and velocity.
        s = dense_row(s_mat, i)
        velo = velo_s[i, :]
        velo_norm = norm(velo)
        for j in two_step_knn:
            if j == i:
                # The self pair is 0 / 0 and the diagonal is dropped anyway.
                continue
            dx = s - dense_row(s_mat, j)
            tm_val.append(np.dot(dx, velo) / (norm(dx) * velo_norm))
            tm_col.append(j)
            tm_row.append(i)
    tm = coo_matrix((tm_val, (tm_row, tm_col)),
                    shape=(n_cells, n_cells)).tocsr()
    tm_neg = tm.copy()
    tm.data = np.clip(tm.data, 0, 1)
    tm_neg.data = np.clip(tm_neg.data, -1, 0)
    tm.eliminate_zeros()
    tm_neg.eliminate_zeros()
    return tm, tm_neg
|
|
5460 |
|
|
|
5461 |
|
|
|
5462 |
def transition_matrix_chrom(c_mat, u_mat, s_mat, velo_c, velo_u, velo_s, knn):
    """Cosine-similarity transition matrices from chromatin, u and s.

    For every cell ``i``, and every other cell ``j`` among its neighbors and
    its neighbors' neighbors, entry ``(i, j)`` is the cosine similarity
    between two concatenated vectors:

    - the differences ``x[i] - x[j]`` for ``x`` in ``(c_mat, u_mat, s_mat)``,
      each divided by the standard deviation of cell ``i``'s row;
    - the velocity ``(velo_c[i], velo_u[i], velo_s[i])``.

    Parameters
    ----------
    c_mat, u_mat, s_mat: scipy sparse matrix or `numpy.ndarray`
        Cell-by-gene chromatin, unspliced and spliced values.
    velo_c, velo_u, velo_s: `numpy.ndarray`
        Matching cell-by-gene velocities.
    knn: `numpy.ndarray`
        Neighbor indices, one row per cell (cast to `int`).

    Returns
    -------
    tm: `scipy.sparse.csr_matrix`
        Positive similarities (clipped to at most 1).
    tm_neg: `scipy.sparse.csr_matrix`
        Negative similarities (clipped to at least -1).

    Zero similarities are not stored and the diagonal is empty. Zero norms
    or zero standard deviations give NaN/inf values, which are kept as in
    the computation.
    """
    knn = knn.astype(int)
    n_cells = c_mat.shape[0]

    def dense_row(mat, idx):
        # Row ``idx`` as a flat ndarray, for sparse or dense ``mat``
        # (``toarray`` instead of ``.A``, which newer SciPy removed).
        row = mat[idx, :]
        if sparse.issparse(row):
            row = row.toarray()
        return np.ravel(row)

    tm_val, tm_col, tm_row = [], [], []
    for i in range(knn.shape[0]):
        neighbors = knn[i, :]
        two_step_knn = np.unique(
            np.concatenate((neighbors, knn[neighbors, :].ravel())))
        # Loop-invariant for the inner loop: cell i's rows, their standard
        # deviations and the stacked velocity.
        c = dense_row(c_mat, i)
        u = dense_row(u_mat, i)
        s = dense_row(s_mat, i)
        c_std, u_std, s_std = np.std(c), np.std(u), np.std(s)
        velo = np.hstack((velo_c[i, :], velo_u[i, :], velo_s[i, :]))
        velo_norm = norm(velo)
        for j in two_step_knn:
            if j == i:
                # The self pair is 0 / 0 and the diagonal is dropped anyway.
                continue
            dx = np.concatenate(((c - dense_row(c_mat, j)) / c_std,
                                 (u - dense_row(u_mat, j)) / u_std,
                                 (s - dense_row(s_mat, j)) / s_std))
            tm_val.append(np.dot(dx, velo) / (norm(dx) * velo_norm))
            tm_col.append(j)
            tm_row.append(i)
    tm = coo_matrix((tm_val, (tm_row, tm_col)),
                    shape=(n_cells, n_cells)).tocsr()
    tm_neg = tm.copy()
    tm.data = np.clip(tm.data, 0, 1)
    tm_neg.data = np.clip(tm_neg.data, -1, 0)
    tm.eliminate_zeros()
    tm_neg.eliminate_zeros()
    return tm, tm_neg
|
|
5495 |
|
|
|
5496 |
|
|
|
5497 |
def likelihood_plot(adata,
                    genes=None,
                    figsize=(14, 10),
                    bins=50,
                    pointsize=4
                    ):
    """Likelihood plots.

    This function plots likelihood and variable distributions on a 4 x 5
    grid:

    - Row 1: histograms of likelihood, rescale u/c and the primed and
      decoupled intervals.
    - Row 2: those variables plotted against the likelihood; the first
      panel uses log total spliced counts instead.
    - Row 3: histograms of the kinetic rates and scale_cc.
    - Row 4: the rates and scale_cc plotted against the likelihood.

    Parameters
    ----------
    adata: :class:`~anndata.AnnData`
        Anndata result from dynamics recovery.
    genes: `str`, list of `str` (default: `None`)
        If `None`, will use all fitted genes.
    figsize: `tuple` (default: (14,10))
        Figure size.
    bins: `int` (default: 50)
        Number of bins for histograms.
    pointsize: `float` (default: 4)
        Point size for scatter plots.
    """
    if genes is None:
        subset = adata
    else:
        # atleast_1d so that a single gene name selects a one-gene subset
        # instead of becoming a 0-d array indexer.
        subset = adata[:, np.atleast_1d(np.array(genes))]
    var = subset.var
    likelihood = var[['fit_likelihood']].values
    rescale_u = var[['fit_rescale_u']].values
    rescale_c = var[['fit_rescale_c']].values
    t_interval1 = var['fit_t_sw1'] + 20 / adata.uns['velo_s_params']['t'] \
        * var['fit_anchor_min_idx'] * var['fit_alignment_scaling']
    t_sw2 = np.clip(var['fit_t_sw2'], None, 20)
    t_sw3 = np.clip(var['fit_t_sw3'], None, 20)
    t_interval3 = t_sw3 - t_sw2
    # Summed smoothed spliced counts of the *selected* genes only, so the x
    # values line up with the per-gene likelihoods in the scatter plot.
    # `.sum(axis=0)` works for dense and sparse layers; asarray/ravel
    # flattens a sparse 1 x n result.
    log_s = np.log1p(np.ravel(np.asarray(subset.layers['Ms'].sum(axis=0))))
    alpha_c = var[['fit_alpha_c']].values
    alpha = var[['fit_alpha']].values
    beta = var[['fit_beta']].values
    gamma = var[['fit_gamma']].values
    scale_cc = var[['fit_scale_cc']].values

    fig, axes = plt.subplots(4, 5, figsize=figsize)
    axes[0, 0].hist(likelihood, bins=bins)
    axes[0, 0].set_title('likelihood')
    axes[0, 1].hist(rescale_u, bins=bins)
    axes[0, 1].set_title('rescale u')
    axes[0, 2].hist(rescale_c, bins=bins)
    axes[0, 2].set_title('rescale c')
    axes[0, 3].hist(t_interval1.values, bins=bins)
    axes[0, 3].set_title('primed interval')
    axes[0, 4].hist(t_interval3, bins=bins)
    axes[0, 4].set_title('decoupled interval')

    axes[1, 0].scatter(log_s, likelihood, s=pointsize)
    axes[1, 0].set_xlabel('log spliced')
    axes[1, 0].set_ylabel('likelihood')
    axes[1, 1].scatter(rescale_u, likelihood, s=pointsize)
    axes[1, 1].set_xlabel('rescale u')
    axes[1, 2].scatter(rescale_c, likelihood, s=pointsize)
    axes[1, 2].set_xlabel('rescale c')
    axes[1, 3].scatter(t_interval1.values, likelihood, s=pointsize)
    axes[1, 3].set_xlabel('primed interval')
    axes[1, 4].scatter(t_interval3, likelihood, s=pointsize)
    axes[1, 4].set_xlabel('decoupled interval')

    axes[2, 0].hist(alpha_c, bins=bins)
    axes[2, 0].set_title('alpha c')
    axes[2, 1].hist(alpha, bins=bins)
    axes[2, 1].set_title('alpha')
    axes[2, 2].hist(beta, bins=bins)
    axes[2, 2].set_title('beta')
    axes[2, 3].hist(gamma, bins=bins)
    axes[2, 3].set_title('gamma')
    axes[2, 4].hist(scale_cc, bins=bins)
    axes[2, 4].set_title('scale cc')

    axes[3, 0].scatter(alpha_c, likelihood, s=pointsize)
    axes[3, 0].set_xlabel('alpha c')
    axes[3, 0].set_ylabel('likelihood')
    axes[3, 1].scatter(alpha, likelihood, s=pointsize)
    axes[3, 1].set_xlabel('alpha')
    axes[3, 2].scatter(beta, likelihood, s=pointsize)
    axes[3, 2].set_xlabel('beta')
    axes[3, 3].scatter(gamma, likelihood, s=pointsize)
    axes[3, 3].set_xlabel('gamma')
    axes[3, 4].scatter(scale_cc, likelihood, s=pointsize)
    axes[3, 4].set_xlabel('scale cc')
    fig.tight_layout()
|
|
5587 |
|
|
|
5588 |
|
|
|
5589 |
def pie_summary(adata, genes=None):
    """Summary of directions and models.

    Draws a donut-style pie chart summarizing fitted directions and models
    for the selected genes. Slices with zero genes are omitted.
    `induction`: induction-only genes (`fit_direction == 'on'`).
    `repression`: repression-only genes (`fit_direction == 'off'`).
    `Model 1`: model 1 complete genes.
    `Model 2`: model 2 complete genes.

    Parameters
    ----------
    adata: :class:`~anndata.AnnData`
        Anndata result from dynamics recovery.
    genes: `str`, list of `str` (default: `None`)
        If `None`, will use all fitted genes.
    """
    selected = adata.var_names if genes is None else genes

    # Model counts only consider genes fitted with a complete trajectory.
    complete_mask = (adata.var['fit_direction'] == 'complete') & \
        np.isin(adata.var_names, selected)
    models = adata[:, complete_mask].var['fit_model'].values
    directions = adata[:, selected].var['fit_direction'].values

    counts = {
        'induction': np.sum(directions == 'on'),
        'repression': np.sum(directions == 'off'),
        'Model 1': np.sum(models == 1),
        'Model 2': np.sum(models == 2),
    }
    # Drop empty categories so no 0% wedges are drawn.
    nonzero = {label: count for label, count in counts.items() if count > 0}

    frame = pd.DataFrame({'data': list(nonzero.values())},
                         index=list(nonzero.keys()))
    frame.plot.pie(y='data', autopct='%1.1f%%', legend=False, startangle=30,
                   ylabel='')

    # A white disc over the center turns the pie into a donut.
    fig = plt.gcf()
    fig.gca().add_artist(plt.Circle((0, 0), 0.8, fc='white'))
|
|
5622 |
|
|
|
5623 |
|
|
|
5624 |
def switch_time_summary(adata, genes=None):
    """Summary of switch times.

    Draws a box plot of the lengths of the four fitted phases, each expressed
    as a fraction of the full time axis (20 time units).
    `primed`: primed intervals.
    `coupled-on`: coupled induction intervals.
    `decoupled`: decoupled intervals.
    `coupled-off`: coupled repression intervals.

    Parameters
    ----------
    adata: :class:`~anndata.AnnData`
        Anndata result from dynamics recovery.
    genes: `str`, list of `str` (default: `None`)
        If `None`, will use velocity genes.
    """
    selection = adata.var['velo_s_genes'] if genes is None else genes
    switch_cols = ['fit_t_sw1', 'fit_t_sw2', 'fit_t_sw3']
    t_sw = adata[:, selection].var[switch_cols].copy()

    # Clamp switch times to the end of the axis and discard negative ones.
    t_sw = t_sw.mask(t_sw > 20, 20)
    t_sw = t_sw.mask(t_sw < 0)

    sw1, sw2, sw3 = t_sw['fit_t_sw1'], t_sw['fit_t_sw2'], t_sw['fit_t_sw3']
    t_sw['interval 1'] = sw1
    t_sw['t_sw2 - t_sw1'] = sw2 - sw1
    t_sw['t_sw3 - t_sw2'] = sw3 - sw2
    t_sw['20 - t_sw3'] = 20 - sw3

    # Keep only strictly positive intervals that fit within the axis.
    t_sw = t_sw.mask(t_sw <= 0)
    t_sw = t_sw.mask(t_sw > 20)

    t_sw.columns = pd.Index(['time 1', 'time 2', 'time 3', 'primed',
                             'coupled-on', 'decoupled', 'coupled-off'])
    fractions = t_sw[['primed', 'coupled-on', 'decoupled', 'coupled-off']] / 20

    fig, ax = plt.subplots(figsize=(4, 5))
    ax = sns.boxplot(data=fractions, width=0.5, palette='Set2', ax=ax)
    ax.set_yticks(np.linspace(0, 1, 5))
    ax.set_title('Switch Intervals')
|
|
5660 |
|
|
|
5661 |
|
|
|
5662 |
def dynamic_plot(adata,
                 genes,
                 by='expression',
                 color_by='state',
                 gene_time=True,
                 axis_on=True,
                 frame_on=True,
                 show_anchors=True,
                 show_switches=True,
                 downsample=1,
                 full_range=False,
                 figsize=None,
                 pointsize=2,
                 linewidth=1.5,
                 cmap='coolwarm'
                 ):
    """Gene dynamics plot.

    This function plots accessibility, expression, or velocity by time.

    Parameters
    ----------
    adata: :class:`~anndata.AnnData`
        Anndata result from dynamics recovery.
    genes: `str`, list of `str`
        List of genes to plot. Genes not found in `adata.var_names` are
        reported and skipped; nothing is plotted if none remain.
    by: `str` (default: `expression`)
        Plot accessibilities and expressions if `expression`. Plot velocities
        if `velocity`.
    color_by: `str` (default: `state`)
        Color by the four potential states if `state`. Other common values are
        leiden, louvain, celltype, etc.
        If not `state`, the color field must be a column of `.obs`. Numeric
        columns are colored with `cmap`. Categorical columns additionally
        require `<color_by>_colors` in `.uns`, which can be pre-computed with
        `scanpy.pl.scatter`.
        For `state`, red, orange, green, and blue represent state 1, 2, 3, and
        4, respectively.
    gene_time: `bool` (default: `True`)
        Whether to use individual gene fitted time, or shared global latent
        time.
        Mean values of 20 equal sized windows will be connected and shown if
        `gene_time==False`. Only windows holding more than 20 cells are drawn.
    axis_on: `bool` (default: `True`)
        Whether to show axis labels.
    frame_on: `bool` (default: `True`)
        Whether to show plot frames.
    show_anchors: `bool` (default: `True`)
        Whether to display anchors. Anchors are skipped when
        `gene_time==False` or when the fitted anchors (or
        `adata.uns['velo_s_params']['t']`) are not available.
    show_switches: `bool` (default: `True`)
        Whether to show switch times. The switch times are indicated by
        vertical dashed lines. Always disabled when `by=='velocity'`.
    downsample: `int` (default: 1)
        How much to downsample the cells. The remaining number will be
        `1/downsample` of original. Clipped to the range [1, 10].
    full_range: `bool` (default: `False`)
        Whether to show the full time range of velocities before smoothing or
        subset to only smoothed range.
    figsize: `tuple` (default: `None`)
        Total figure size.
    pointsize: `float` (default: 2)
        Point size for scatter plots.
    linewidth: `float` (default: 1.5)
        Line width for anchor line or mean line.
    cmap: `str` (default: `coolwarm`)
        Color map for continuous color key.

    Raises
    ------
    ValueError
        If `by` is not `expression` or `velocity`, or if `color_by` cannot be
        resolved to colors as described above.
    """
    from pandas.api.types import is_numeric_dtype

    def _flat(x):
        # Densify a (possibly sparse) single-gene layer and flatten it to 1D.
        # `.toarray()` replaces the `.A` shortcut that SciPy removed from
        # sparse matrices.
        x = x.toarray() if sparse.issparse(x) else x
        return np.asarray(x).ravel()

    def _var_value(gene_var, key):
        # Scalar entry of a one-gene `.var` frame. This avoids the deprecated
        # `int(Series)` conversion.
        return gene_var[key].iloc[0]

    if by not in ['expression', 'velocity']:
        raise ValueError('"by" must be either "expression" or "velocity".')
    is_expression = by == 'expression'
    if not is_expression:
        # Switch-time lines are never drawn on velocity plots.
        show_switches = False

    if color_by == 'state':
        types = [0, 1, 2, 3]
        colors = ['tab:red', 'tab:orange', 'tab:green', 'tab:blue']
    elif color_by in adata.obs and is_numeric_dtype(adata.obs[color_by]):
        types = None
        colors = adata.obs[color_by].values
    elif color_by in adata.obs \
            and isinstance(adata.obs[color_by].dtype, pd.CategoricalDtype) \
            and color_by+'_colors' in adata.uns.keys():
        types = adata.obs[color_by].cat.categories
        colors = adata.uns[f'{color_by}_colors']
    else:
        raise ValueError('Currently, color key must be a single string of '
                         'either numerical or categorical available in adata '
                         'obs, and the colors of categories can be found in '
                         'adata uns.')

    # Anchors are read from varm and velo_s_params. If they are absent,
    # disable anchors instead of raising KeyError (mirrors scatter_plot).
    anchor_suffix = '' if is_expression else '_velo'
    anchor_keys = [f'fit_anchor_{x}{anchor_suffix}' for x in ('c', 'u', 's')]
    if 'velo_s_params' not in adata.uns.keys() \
            or 't' not in adata.uns['velo_s_params'] \
            or any(key not in adata.varm.keys() for key in anchor_keys):
        show_anchors = False

    downsample = np.clip(int(downsample), 1, 10)
    genes = np.atleast_1d(np.array(genes))
    found = np.isin(genes, adata.var_names)
    missing_genes = genes[~found]
    if len(missing_genes) > 0:
        logg.update(f'{missing_genes} not found', v=0)
    genes = genes[found]
    gn = len(genes)
    if gn == 0:
        return
    if not gene_time:
        show_anchors = False
        latent_time = np.array(adata.obs['latent_time'])
        # Bin latent time into 20 windows of width 0.05; the last bin is
        # closed so that a value of exactly 1.0 lands in window 19.
        time_window = (latent_time // 0.05).astype(int)
        time_window[time_window == 20] = 19
    if 'velo_s_params' in adata.uns.keys() and 'outlier' \
            in adata.uns['velo_s_params']:
        outlier = adata.uns['velo_s_params']['outlier']
    else:
        outlier = 99

    layer_keys = (['ATAC', 'Mu', 'Ms'] if is_expression
                  else ['velo_chrom', 'velo_u', 'velo_s'])
    x_label = 't' if is_expression else '~t'
    y_labels = (['c', 'u', 's'] if is_expression
                else ['dc/dt', 'du/dt', 'ds/dt'])

    fig, axs = plt.subplots(gn, 3, squeeze=False, figsize=(10, 2.3*gn)
                            if figsize is None else figsize)
    fig.patch.set_facecolor('white')
    for row, gene in enumerate(genes):
        gene_data = adata[:, gene]
        row_axes = axs[row]

        # Columns: chromatin, unspliced, spliced.
        c, u, s = (_flat(gene_data.layers[key]) for key in layer_keys)
        non_outlier = c <= np.percentile(c, outlier)
        non_outlier &= u <= np.percentile(u, outlier)
        non_outlier &= s <= np.percentile(s, outlier)
        c, u, s = c[non_outlier], u[non_outlier], s[non_outlier]
        panels = (c, u, s)

        time = _flat(gene_data.layers['fit_t']) if gene_time \
            else latent_time
        if not is_expression:
            # Smooth time over the `_RNA_conn` neighbor graph.
            time = np.ravel(adata.obsp['_RNA_conn'].dot(
                np.reshape(time, (-1, 1))))
        time = time[non_outlier]

        if types is not None:
            if color_by == 'state':
                labels = _flat(gene_data.layers['fit_state'])[non_outlier]
            else:
                labels = np.asarray(adata.obs[color_by])[non_outlier]
            for i, label in enumerate(types):
                filt = np.ravel(labels == label)
                if np.sum(filt) > 0:
                    for ax, values in zip(row_axes, panels):
                        ax.scatter(time[filt][::downsample],
                                   values[filt][::downsample],
                                   s=pointsize, c=colors[i], alpha=0.6)
        else:
            point_colors = colors[non_outlier][::downsample]
            for ax, values in zip(row_axes, panels):
                ax.scatter(time[::downsample], values[::downsample],
                           s=pointsize, c=point_colors, alpha=0.6, cmap=cmap)

        if not gene_time:
            cell_windows = time_window[non_outlier]
            window_count = np.zeros(20)
            window_means = np.zeros((3, 20))
            for w in np.unique(cell_windows):
                idx = cell_windows == w
                window_count[w] = np.sum(idx)
                for k, values in enumerate(panels):
                    window_means[k, w] = np.mean(values[idx])
            window_idx = np.where(window_count > 20)[0]
            window_centers = window_idx * 0.05 + 0.025
            for ax, means in zip(row_axes, window_means):
                ax.plot(window_centers, means[window_idx],
                        linewidth=linewidth, color='black', alpha=0.5)

        if show_anchors:
            gene_var = gene_data.var
            n_anchors = adata.uns['velo_s_params']['t']
            t_sw_array = np.array([_var_value(gene_var, f'fit_t_sw{k}')
                                   for k in (1, 2, 3)])
            t_sw_array = t_sw_array[t_sw_array < 20]
            fit_min_idx = int(_var_value(gene_var, 'fit_anchor_min_idx'))
            fit_max_idx = int(_var_value(gene_var, 'fit_anchor_max_idx'))
            min_idx, max_idx = fit_min_idx, fit_max_idx
            old_t = np.linspace(0, 20, n_anchors)[min_idx:max_idx+1]
            # Rescale the fitted anchor span onto [0, 20]. A single-anchor
            # span has zero width and stays at 0 instead of becoming NaN.
            new_t = old_t - np.min(old_t)
            if np.max(new_t) > 0:
                new_t = new_t * 20 / np.max(new_t)
            if not is_expression and not full_range:
                anchor_interval = 20 / (fit_max_idx + 1 - fit_min_idx)
                min_idx = int(_var_value(gene_var, 'fit_anchor_velo_min_idx'))
                max_idx = int(_var_value(gene_var, 'fit_anchor_velo_max_idx'))
                # Scalars (not 1-element Series) keep new_t one-dimensional.
                start = (min_idx - fit_min_idx) * anchor_interval
                end = 20 + (max_idx - fit_max_idx) * anchor_interval
                new_t = np.linspace(start, end, max_idx + 1 - min_idx)
            for ax, values, key in zip(row_axes, panels, anchor_keys):
                anchors = gene_data.varm[key].ravel()[min_idx:max_idx+1]
                if show_switches:
                    for t_sw in t_sw_array:
                        if t_sw > 0:
                            ax.vlines(t_sw, np.min(values), np.max(values),
                                      colors='black', linestyles='dashed',
                                      alpha=0.5)
                ax.plot(new_t, anchors, linewidth=linewidth,
                        color='black', alpha=0.5)

        row_axes[0].set_title(f'{gene} ATAC' if is_expression
                              else f'{gene} chromatin velocity')
        row_axes[1].set_title(f'{gene} unspliced' + ('' if is_expression
                                                     else ' velocity'))
        row_axes[2].set_title(f'{gene} spliced' + ('' if is_expression
                                                   else ' velocity'))
        for ax, y_label in zip(row_axes, y_labels):
            ax.set_xlabel(x_label)
            ax.set_ylabel(y_label)
            if not axis_on:
                ax.xaxis.set_ticks_position('none')
                ax.yaxis.set_ticks_position('none')
                ax.get_xaxis().set_visible(False)
                ax.get_yaxis().set_visible(False)
            if not frame_on:
                ax.xaxis.set_ticks_position('none')
                ax.yaxis.set_ticks_position('none')
                ax.set_frame_on(False)
    fig.tight_layout()
|
|
5923 |
|
|
|
5924 |
|
|
|
5925 |
def scatter_plot(adata, |
|
|
5926 |
genes, |
|
|
5927 |
by='us', |
|
|
5928 |
color_by='state', |
|
|
5929 |
n_cols=5, |
|
|
5930 |
axis_on=True, |
|
|
5931 |
frame_on=True, |
|
|
5932 |
show_anchors=True, |
|
|
5933 |
show_switches=True, |
|
|
5934 |
show_all_anchors=False, |
|
|
5935 |
title_more_info=False, |
|
|
5936 |
velocity_arrows=False, |
|
|
5937 |
downsample=1, |
|
|
5938 |
figsize=None, |
|
|
5939 |
pointsize=2, |
|
|
5940 |
markersize=5, |
|
|
5941 |
linewidth=2, |
|
|
5942 |
cmap='coolwarm', |
|
|
5943 |
view_3d_elev=None, |
|
|
5944 |
view_3d_azim=None, |
|
|
5945 |
full_name=False |
|
|
5946 |
): |
|
|
5947 |
"""Gene scatter plot. |
|
|
5948 |
|
|
|
5949 |
This function plots phase portraits of the specified plane. |
|
|
5950 |
|
|
|
5951 |
Parameters |
|
|
5952 |
---------- |
|
|
5953 |
adata: :class:`~anndata.AnnData` |
|
|
5954 |
Anndata result from dynamics recovery. |
|
|
5955 |
genes: `str`, list of `str` |
|
|
5956 |
List of genes to plot. |
|
|
5957 |
by: `str` (default: `us`) |
|
|
5958 |
Plot unspliced-spliced plane if `us`. Plot chromatin-unspliced plane |
|
|
5959 |
if `cu`. |
|
|
5960 |
Plot 3D phase portraits if `cus`. |
|
|
5961 |
color_by: `str` (default: `state`) |
|
|
5962 |
Color by the four potential states if `state`. Other common values are |
|
|
5963 |
leiden, louvain, celltype, etc. |
|
|
5964 |
If not `state`, the color field must be present in `.uns`, which can be |
|
|
5965 |
pre-computed with `scanpy.pl.scatter`. |
|
|
5966 |
For `state`, red, orange, green, and blue represent state 1, 2, 3, and |
|
|
5967 |
4, respectively. |
|
|
5968 |
When `by=='us'`, `color_by` can also be `c`, which displays the log |
|
|
5969 |
accessibility on U-S phase portraits. |
|
|
5970 |
n_cols: `int` (default: 5) |
|
|
5971 |
Number of columns to plot on each row. |
|
|
5972 |
axis_on: `bool` (default: `True`) |
|
|
5973 |
Whether to show axis labels. |
|
|
5974 |
frame_on: `bool` (default: `True`) |
|
|
5975 |
Whether to show plot frames. |
|
|
5976 |
show_anchors: `bool` (default: `True`) |
|
|
5977 |
Whether to display anchors. |
|
|
5978 |
show_switches: `bool` (default: `True`) |
|
|
5979 |
Whether to show switch times. The three switch times and the end of |
|
|
5980 |
trajectory are indicated by |
|
|
5981 |
circle, cross, dismond, and star, respectively. |
|
|
5982 |
show_all_anchors: `bool` (default: `False`) |
|
|
5983 |
Whether to display full range of (predicted) anchors even for |
|
|
5984 |
repression-only genes. |
|
|
5985 |
title_more_info: `bool` (default: `False`) |
|
|
5986 |
Whether to display model, direction, and likelihood information for |
|
|
5987 |
the gene in title. |
|
|
5988 |
velocity_arrows: `bool` (default: `False`) |
|
|
5989 |
Whether to show velocity arrows of cells on the phase portraits. |
|
|
5990 |
downsample: `int` (default: 1) |
|
|
5991 |
How much to downsample the cells. The remaining number will be |
|
|
5992 |
`1/downsample` of original. |
|
|
5993 |
figsize: `tuple` (default: `None`) |
|
|
5994 |
Total figure size. |
|
|
5995 |
pointsize: `float` (default: 2) |
|
|
5996 |
Point size for scatter plots. |
|
|
5997 |
markersize: `float` (default: 5) |
|
|
5998 |
Point size for switch time points. |
|
|
5999 |
linewidth: `float` (default: 2) |
|
|
6000 |
Line width for connected anchors. |
|
|
6001 |
cmap: `str` (default: `coolwarm`) |
|
|
6002 |
Color map for log accessibilities or other continuous color keys when |
|
|
6003 |
plotting on U-S plane. |
|
|
6004 |
view_3d_elev: `float` (default: `None`) |
|
|
6005 |
Matplotlib 3D plot `elev` argument. `elev=90` is the same as U-S plane, |
|
|
6006 |
and `elev=0` is the same as C-U plane. |
|
|
6007 |
view_3d_azim: `float` (default: `None`) |
|
|
6008 |
Matplotlib 3D plot `azim` argument. `azim=270` is the same as U-S |
|
|
6009 |
plane, and `azim=0` is the same as C-U plane. |
|
|
6010 |
full_name: `bool` (default: `False`) |
|
|
6011 |
Show full names for chromatin, unspliced, and spliced rather than |
|
|
6012 |
using abbreviated terms c, u, and s. |
|
|
6013 |
""" |
|
|
6014 |
from pandas.api.types import is_numeric_dtype, is_categorical_dtype |
|
|
6015 |
if by not in ['us', 'cu', 'cus']: |
|
|
6016 |
raise ValueError("'by' argument must be one of ['us', 'cu', 'cus']") |
|
|
6017 |
if color_by == 'state': |
|
|
6018 |
types = [0, 1, 2, 3] |
|
|
6019 |
colors = ['tab:red', 'tab:orange', 'tab:green', 'tab:blue'] |
|
|
6020 |
elif by == 'us' and color_by == 'c': |
|
|
6021 |
types = None |
|
|
6022 |
elif color_by in adata.obs and is_numeric_dtype(adata.obs[color_by]): |
|
|
6023 |
types = None |
|
|
6024 |
colors = adata.obs[color_by].values |
|
|
6025 |
elif color_by in adata.obs and is_categorical_dtype(adata.obs[color_by]) \ |
|
|
6026 |
and color_by+'_colors' in adata.uns.keys(): |
|
|
6027 |
types = adata.obs[color_by].cat.categories |
|
|
6028 |
colors = adata.uns[f'{color_by}_colors'] |
|
|
6029 |
else: |
|
|
6030 |
raise ValueError('Currently, color key must be a single string of ' |
|
|
6031 |
'either numerical or categorical available in adata' |
|
|
6032 |
' obs, and the colors of categories can be found in' |
|
|
6033 |
' adata uns.') |
|
|
6034 |
|
|
|
6035 |
if 'velo_s_params' not in adata.uns.keys() \ |
|
|
6036 |
or 'fit_anchor_s' not in adata.varm.keys(): |
|
|
6037 |
show_anchors = False |
|
|
6038 |
if color_by == 'state' and 'fit_state' not in adata.layers.keys(): |
|
|
6039 |
raise ValueError('fit_state is not found. Please run ' |
|
|
6040 |
'recover_dynamics_chrom function first or provide a ' |
|
|
6041 |
'valid color key.') |
|
|
6042 |
|
|
|
6043 |
downsample = np.clip(int(downsample), 1, 10) |
|
|
6044 |
genes = np.array(genes) |
|
|
6045 |
missing_genes = genes[~np.isin(genes, adata.var_names)] |
|
|
6046 |
if len(missing_genes) > 0: |
|
|
6047 |
logg.update(f'{missing_genes} not found', v=0) |
|
|
6048 |
genes = genes[np.isin(genes, adata.var_names)] |
|
|
6049 |
gn = len(genes) |
|
|
6050 |
if gn == 0: |
|
|
6051 |
return |
|
|
6052 |
if gn < n_cols: |
|
|
6053 |
n_cols = gn |
|
|
6054 |
if by == 'cus': |
|
|
6055 |
fig, axs = plt.subplots(-(-gn // n_cols), n_cols, squeeze=False, |
|
|
6056 |
figsize=(3.2*n_cols, 2.7*(-(-gn // n_cols))) |
|
|
6057 |
if figsize is None else figsize, |
|
|
6058 |
subplot_kw={'projection': '3d'}) |
|
|
6059 |
else: |
|
|
6060 |
fig, axs = plt.subplots(-(-gn // n_cols), n_cols, squeeze=False, |
|
|
6061 |
figsize=(2.7*n_cols, 2.4*(-(-gn // n_cols))) |
|
|
6062 |
if figsize is None else figsize) |
|
|
6063 |
fig.patch.set_facecolor('white') |
|
|
6064 |
count = 0 |
|
|
6065 |
for gene in genes: |
|
|
6066 |
u = adata[:, gene].layers['Mu'].copy() if 'Mu' in adata.layers \ |
|
|
6067 |
else adata[:, gene].layers['unspliced'].copy() |
|
|
6068 |
s = adata[:, gene].layers['Ms'].copy() if 'Ms' in adata.layers \ |
|
|
6069 |
else adata[:, gene].layers['spliced'].copy() |
|
|
6070 |
u = u.A if sparse.issparse(u) else u |
|
|
6071 |
s = s.A if sparse.issparse(s) else s |
|
|
6072 |
u, s = np.ravel(u), np.ravel(s) |
|
|
6073 |
if 'ATAC' not in adata.layers.keys() and \ |
|
|
6074 |
'Mc' not in adata.layers.keys(): |
|
|
6075 |
show_anchors = False |
|
|
6076 |
elif 'ATAC' in adata.layers.keys(): |
|
|
6077 |
c = adata[:, gene].layers['ATAC'].copy() |
|
|
6078 |
c = c.A if sparse.issparse(c) else c |
|
|
6079 |
c = np.ravel(c) |
|
|
6080 |
elif 'Mc' in adata.layers.keys(): |
|
|
6081 |
c = adata[:, gene].layers['Mc'].copy() |
|
|
6082 |
c = c.A if sparse.issparse(c) else c |
|
|
6083 |
c = np.ravel(c) |
|
|
6084 |
|
|
|
6085 |
if velocity_arrows: |
|
|
6086 |
if 'velo_u' in adata.layers.keys(): |
|
|
6087 |
vu = adata[:, gene].layers['velo_u'].copy() |
|
|
6088 |
elif 'velocity_u' in adata.layers.keys(): |
|
|
6089 |
vu = adata[:, gene].layers['velocity_u'].copy() |
|
|
6090 |
else: |
|
|
6091 |
vu = np.zeros(adata.n_obs) |
|
|
6092 |
max_u = np.max([np.max(u), 1e-6]) |
|
|
6093 |
u /= max_u |
|
|
6094 |
vu = np.ravel(vu) |
|
|
6095 |
vu /= np.max([np.max(np.abs(vu)), 1e-6]) |
|
|
6096 |
if 'velo_s' in adata.layers.keys(): |
|
|
6097 |
vs = adata[:, gene].layers['velo_s'].copy() |
|
|
6098 |
elif 'velocity' in adata.layers.keys(): |
|
|
6099 |
vs = adata[:, gene].layers['velocity'].copy() |
|
|
6100 |
max_s = np.max([np.max(s), 1e-6]) |
|
|
6101 |
s /= max_s |
|
|
6102 |
vs = np.ravel(vs) |
|
|
6103 |
vs /= np.max([np.max(np.abs(vs)), 1e-6]) |
|
|
6104 |
if 'velo_chrom' in adata.layers.keys(): |
|
|
6105 |
vc = adata[:, gene].layers['velo_chrom'].copy() |
|
|
6106 |
max_c = np.max([np.max(c), 1e-6]) |
|
|
6107 |
c /= max_c |
|
|
6108 |
vc = np.ravel(vc) |
|
|
6109 |
vc /= np.max([np.max(np.abs(vc)), 1e-6]) |
|
|
6110 |
|
|
|
6111 |
row = count // n_cols |
|
|
6112 |
col = count % n_cols |
|
|
6113 |
ax = axs[row, col] |
|
|
6114 |
if types is not None: |
|
|
6115 |
for i in range(len(types)): |
|
|
6116 |
if color_by == 'state': |
|
|
6117 |
filt = adata[:, gene].layers['fit_state'] == types[i] |
|
|
6118 |
else: |
|
|
6119 |
filt = adata.obs[color_by] == types[i] |
|
|
6120 |
filt = np.ravel(filt) |
|
|
6121 |
if by == 'us': |
|
|
6122 |
if velocity_arrows: |
|
|
6123 |
ax.quiver(s[filt][::downsample], u[filt][::downsample], |
|
|
6124 |
vs[filt][::downsample], |
|
|
6125 |
vu[filt][::downsample], color=colors[i], |
|
|
6126 |
alpha=0.5, scale_units='xy', scale=10, |
|
|
6127 |
width=0.005, headwidth=4, headaxislength=5.5) |
|
|
6128 |
else: |
|
|
6129 |
ax.scatter(s[filt][::downsample], |
|
|
6130 |
u[filt][::downsample], s=pointsize, |
|
|
6131 |
c=colors[i], alpha=0.7) |
|
|
6132 |
elif by == 'cu': |
|
|
6133 |
if velocity_arrows: |
|
|
6134 |
ax.quiver(u[filt][::downsample], |
|
|
6135 |
c[filt][::downsample], |
|
|
6136 |
vu[filt][::downsample], |
|
|
6137 |
vc[filt][::downsample], color=colors[i], |
|
|
6138 |
alpha=0.5, scale_units='xy', scale=10, |
|
|
6139 |
width=0.005, headwidth=4, headaxislength=5.5) |
|
|
6140 |
else: |
|
|
6141 |
ax.scatter(u[filt][::downsample], |
|
|
6142 |
c[filt][::downsample], s=pointsize, |
|
|
6143 |
c=colors[i], alpha=0.7) |
|
|
6144 |
else: |
|
|
6145 |
if velocity_arrows: |
|
|
6146 |
ax.quiver(s[filt][::downsample], |
|
|
6147 |
u[filt][::downsample], c[filt][::downsample], |
|
|
6148 |
vs[filt][::downsample], |
|
|
6149 |
vu[filt][::downsample], |
|
|
6150 |
vc[filt][::downsample], |
|
|
6151 |
color=colors[i], alpha=0.4, length=0.1, |
|
|
6152 |
arrow_length_ratio=0.5, normalize=True) |
|
|
6153 |
else: |
|
|
6154 |
ax.scatter(s[filt][::downsample], |
|
|
6155 |
u[filt][::downsample], |
|
|
6156 |
c[filt][::downsample], s=pointsize, |
|
|
6157 |
c=colors[i], alpha=0.7) |
|
|
6158 |
elif color_by == 'c': |
|
|
6159 |
if 'velo_s_params' in adata.uns.keys() and \ |
|
|
6160 |
'outlier' in adata.uns['velo_s_params']: |
|
|
6161 |
outlier = adata.uns['velo_s_params']['outlier'] |
|
|
6162 |
else: |
|
|
6163 |
outlier = 99.8 |
|
|
6164 |
non_zero = (u > 0) & (s > 0) & (c > 0) |
|
|
6165 |
non_outlier = u < np.percentile(u, outlier) |
|
|
6166 |
non_outlier &= s < np.percentile(s, outlier) |
|
|
6167 |
non_outlier &= c < np.percentile(c, outlier) |
|
|
6168 |
c -= np.min(c) |
|
|
6169 |
c /= np.max(c) |
|
|
6170 |
if velocity_arrows: |
|
|
6171 |
ax.quiver(s[non_zero & non_outlier][::downsample], |
|
|
6172 |
u[non_zero & non_outlier][::downsample], |
|
|
6173 |
vs[non_zero & non_outlier][::downsample], |
|
|
6174 |
vu[non_zero & non_outlier][::downsample], |
|
|
6175 |
np.log1p(c[non_zero & non_outlier][::downsample]), |
|
|
6176 |
alpha=0.5, |
|
|
6177 |
scale_units='xy', scale=10, width=0.005, |
|
|
6178 |
headwidth=4, headaxislength=5.5, cmap=cmap) |
|
|
6179 |
else: |
|
|
6180 |
ax.scatter(s[non_zero & non_outlier][::downsample], |
|
|
6181 |
u[non_zero & non_outlier][::downsample], |
|
|
6182 |
s=pointsize, |
|
|
6183 |
c=np.log1p(c[non_zero & non_outlier][::downsample]), |
|
|
6184 |
alpha=0.8, cmap=cmap) |
|
|
6185 |
else: |
|
|
6186 |
if by == 'us': |
|
|
6187 |
if velocity_arrows: |
|
|
6188 |
ax.quiver(s[::downsample], u[::downsample], |
|
|
6189 |
vs[::downsample], vu[::downsample], |
|
|
6190 |
colors[::downsample], alpha=0.5, |
|
|
6191 |
scale_units='xy', scale=10, width=0.005, |
|
|
6192 |
headwidth=4, headaxislength=5.5, cmap=cmap) |
|
|
6193 |
else: |
|
|
6194 |
ax.scatter(s[::downsample], u[::downsample], s=pointsize, |
|
|
6195 |
c=colors[::downsample], alpha=0.7, cmap=cmap) |
|
|
6196 |
elif by == 'cu': |
|
|
6197 |
if velocity_arrows: |
|
|
6198 |
ax.quiver(u[::downsample], c[::downsample], |
|
|
6199 |
vu[::downsample], vc[::downsample], |
|
|
6200 |
colors[::downsample], alpha=0.5, |
|
|
6201 |
scale_units='xy', scale=10, width=0.005, |
|
|
6202 |
headwidth=4, headaxislength=5.5, cmap=cmap) |
|
|
6203 |
else: |
|
|
6204 |
ax.scatter(u[::downsample], c[::downsample], s=pointsize, |
|
|
6205 |
c=colors[::downsample], alpha=0.7, cmap=cmap) |
|
|
6206 |
else: |
|
|
6207 |
if velocity_arrows: |
|
|
6208 |
ax.quiver(s[::downsample], u[::downsample], |
|
|
6209 |
c[::downsample], vs[::downsample], |
|
|
6210 |
vu[::downsample], vc[::downsample], |
|
|
6211 |
colors[::downsample], alpha=0.4, length=0.1, |
|
|
6212 |
arrow_length_ratio=0.5, normalize=True, |
|
|
6213 |
cmap=cmap) |
|
|
6214 |
else: |
|
|
6215 |
ax.scatter(s[::downsample], u[::downsample], |
|
|
6216 |
c[::downsample], s=pointsize, |
|
|
6217 |
c=colors[::downsample], alpha=0.7, cmap=cmap) |
|
|
6218 |
|
|
|
6219 |
if show_anchors: |
|
|
6220 |
min_idx = int(adata[:, gene].var['fit_anchor_min_idx']) |
|
|
6221 |
max_idx = int(adata[:, gene].var['fit_anchor_max_idx']) |
|
|
6222 |
a_c = adata[:, gene].varm['fit_anchor_c']\ |
|
|
6223 |
.ravel()[min_idx:max_idx+1].copy() |
|
|
6224 |
a_u = adata[:, gene].varm['fit_anchor_u']\ |
|
|
6225 |
.ravel()[min_idx:max_idx+1].copy() |
|
|
6226 |
a_s = adata[:, gene].varm['fit_anchor_s']\ |
|
|
6227 |
.ravel()[min_idx:max_idx+1].copy() |
|
|
6228 |
if velocity_arrows: |
|
|
6229 |
a_c /= max_c |
|
|
6230 |
a_u /= max_u |
|
|
6231 |
a_s /= max_s |
|
|
6232 |
if by == 'us': |
|
|
6233 |
ax.plot(a_s, a_u, linewidth=linewidth, color='black', |
|
|
6234 |
alpha=0.7, zorder=1000) |
|
|
6235 |
elif by == 'cu': |
|
|
6236 |
ax.plot(a_u, a_c, linewidth=linewidth, color='black', |
|
|
6237 |
alpha=0.7, zorder=1000) |
|
|
6238 |
else: |
|
|
6239 |
ax.plot(a_s, a_u, a_c, linewidth=linewidth, color='black', |
|
|
6240 |
alpha=0.7, zorder=1000) |
|
|
6241 |
if show_all_anchors: |
|
|
6242 |
a_c_pre = adata[:, gene].varm['fit_anchor_c']\ |
|
|
6243 |
.ravel()[:min_idx].copy() |
|
|
6244 |
a_u_pre = adata[:, gene].varm['fit_anchor_u']\ |
|
|
6245 |
.ravel()[:min_idx].copy() |
|
|
6246 |
a_s_pre = adata[:, gene].varm['fit_anchor_s']\ |
|
|
6247 |
.ravel()[:min_idx].copy() |
|
|
6248 |
if velocity_arrows: |
|
|
6249 |
a_c_pre /= max_c |
|
|
6250 |
a_u_pre /= max_u |
|
|
6251 |
a_s_pre /= max_s |
|
|
6252 |
if len(a_c_pre) > 0: |
|
|
6253 |
if by == 'us': |
|
|
6254 |
ax.plot(a_s_pre, a_u_pre, linewidth=linewidth/1.3, |
|
|
6255 |
color='black', alpha=0.6, zorder=1000) |
|
|
6256 |
elif by == 'cu': |
|
|
6257 |
ax.plot(a_u_pre, a_c_pre, linewidth=linewidth/1.3, |
|
|
6258 |
color='black', alpha=0.6, zorder=1000) |
|
|
6259 |
else: |
|
|
6260 |
ax.plot(a_s_pre, a_u_pre, a_c_pre, |
|
|
6261 |
linewidth=linewidth/1.3, color='black', |
|
|
6262 |
alpha=0.6, zorder=1000) |
|
|
6263 |
if show_switches: |
|
|
6264 |
t_sw_array = np.array([adata[:, gene].var['fit_t_sw1'] |
|
|
6265 |
.values[0], |
|
|
6266 |
adata[:, gene].var['fit_t_sw2'] |
|
|
6267 |
.values[0], |
|
|
6268 |
adata[:, gene].var['fit_t_sw3'] |
|
|
6269 |
.values[0]]) |
|
|
6270 |
in_range = (t_sw_array > 0) & (t_sw_array < 20) |
|
|
6271 |
a_c_sw = adata[:, gene].varm['fit_anchor_c_sw'].ravel().copy() |
|
|
6272 |
a_u_sw = adata[:, gene].varm['fit_anchor_u_sw'].ravel().copy() |
|
|
6273 |
a_s_sw = adata[:, gene].varm['fit_anchor_s_sw'].ravel().copy() |
|
|
6274 |
if velocity_arrows: |
|
|
6275 |
a_c_sw /= max_c |
|
|
6276 |
a_u_sw /= max_u |
|
|
6277 |
a_s_sw /= max_s |
|
|
6278 |
if in_range[0]: |
|
|
6279 |
c_sw1, u_sw1, s_sw1 = a_c_sw[0], a_u_sw[0], a_s_sw[0] |
|
|
6280 |
if by == 'us': |
|
|
6281 |
ax.plot([s_sw1], [u_sw1], "om", markersize=markersize, |
|
|
6282 |
zorder=2000) |
|
|
6283 |
elif by == 'cu': |
|
|
6284 |
ax.plot([u_sw1], [c_sw1], "om", markersize=markersize, |
|
|
6285 |
zorder=2000) |
|
|
6286 |
else: |
|
|
6287 |
ax.plot([s_sw1], [u_sw1], [c_sw1], "om", |
|
|
6288 |
markersize=markersize, zorder=2000) |
|
|
6289 |
if in_range[1]: |
|
|
6290 |
c_sw2, u_sw2, s_sw2 = a_c_sw[1], a_u_sw[1], a_s_sw[1] |
|
|
6291 |
if by == 'us': |
|
|
6292 |
ax.plot([s_sw2], [u_sw2], "Xm", markersize=markersize, |
|
|
6293 |
zorder=2000) |
|
|
6294 |
elif by == 'cu': |
|
|
6295 |
ax.plot([u_sw2], [c_sw2], "Xm", markersize=markersize, |
|
|
6296 |
zorder=2000) |
|
|
6297 |
else: |
|
|
6298 |
ax.plot([s_sw2], [u_sw2], [c_sw2], "Xm", |
|
|
6299 |
markersize=markersize, zorder=2000) |
|
|
6300 |
if in_range[2]: |
|
|
6301 |
c_sw3, u_sw3, s_sw3 = a_c_sw[2], a_u_sw[2], a_s_sw[2] |
|
|
6302 |
if by == 'us': |
|
|
6303 |
ax.plot([s_sw3], [u_sw3], "Dm", markersize=markersize, |
|
|
6304 |
zorder=2000) |
|
|
6305 |
elif by == 'cu': |
|
|
6306 |
ax.plot([u_sw3], [c_sw3], "Dm", markersize=markersize, |
|
|
6307 |
zorder=2000) |
|
|
6308 |
else: |
|
|
6309 |
ax.plot([s_sw3], [u_sw3], [c_sw3], "Dm", |
|
|
6310 |
markersize=markersize, zorder=2000) |
|
|
6311 |
if max_idx > adata.uns['velo_s_params']['t'] - 4: |
|
|
6312 |
if by == 'us': |
|
|
6313 |
ax.plot([a_s[-1]], [a_u[-1]], "*m", |
|
|
6314 |
markersize=markersize, zorder=2000) |
|
|
6315 |
elif by == 'cu': |
|
|
6316 |
ax.plot([a_u[-1]], [a_c[-1]], "*m", |
|
|
6317 |
markersize=markersize, zorder=2000) |
|
|
6318 |
else: |
|
|
6319 |
ax.plot([a_s[-1]], [a_u[-1]], [a_c[-1]], "*m", |
|
|
6320 |
markersize=markersize, zorder=2000) |
|
|
6321 |
|
|
|
6322 |
if by == 'cus' and \ |
|
|
6323 |
(view_3d_elev is not None or view_3d_azim is not None): |
|
|
6324 |
# US: elev=90, azim=270. CU: elev=0, azim=0. |
|
|
6325 |
ax.view_init(elev=view_3d_elev, azim=view_3d_azim) |
|
|
6326 |
title = gene |
|
|
6327 |
if title_more_info: |
|
|
6328 |
if 'fit_model' in adata.var: |
|
|
6329 |
title += f" M{int(adata[:,gene].var['fit_model'].values[0])}" |
|
|
6330 |
if 'fit_direction' in adata.var: |
|
|
6331 |
title += f" {adata[:,gene].var['fit_direction'].values[0]}" |
|
|
6332 |
if 'fit_likelihood' in adata.var \ |
|
|
6333 |
and not np.all(adata.var['fit_likelihood'].values == -1): |
|
|
6334 |
title += " " |
|
|
6335 |
f"{adata[:,gene].var['fit_likelihood'].values[0]:.3g}" |
|
|
6336 |
ax.set_title(f'{title}', fontsize=11) |
|
|
6337 |
if by == 'us': |
|
|
6338 |
ax.set_xlabel('spliced' if full_name else 's') |
|
|
6339 |
ax.set_ylabel('unspliced' if full_name else 'u') |
|
|
6340 |
elif by == 'cu': |
|
|
6341 |
ax.set_xlabel('unspliced' if full_name else 'u') |
|
|
6342 |
ax.set_ylabel('chromatin' if full_name else 'c') |
|
|
6343 |
elif by == 'cus': |
|
|
6344 |
ax.set_xlabel('spliced' if full_name else 's') |
|
|
6345 |
ax.set_ylabel('unspliced' if full_name else 'u') |
|
|
6346 |
ax.set_zlabel('chromatin' if full_name else 'c') |
|
|
6347 |
if by in ['us', 'cu']: |
|
|
6348 |
if not axis_on: |
|
|
6349 |
ax.xaxis.set_ticks_position('none') |
|
|
6350 |
ax.yaxis.set_ticks_position('none') |
|
|
6351 |
ax.get_xaxis().set_visible(False) |
|
|
6352 |
ax.get_yaxis().set_visible(False) |
|
|
6353 |
if not frame_on: |
|
|
6354 |
ax.xaxis.set_ticks_position('none') |
|
|
6355 |
ax.yaxis.set_ticks_position('none') |
|
|
6356 |
ax.set_frame_on(False) |
|
|
6357 |
elif by == 'cus': |
|
|
6358 |
if not axis_on: |
|
|
6359 |
ax.set_xlabel('') |
|
|
6360 |
ax.set_ylabel('') |
|
|
6361 |
ax.set_zlabel('') |
|
|
6362 |
ax.xaxis.set_ticklabels([]) |
|
|
6363 |
ax.yaxis.set_ticklabels([]) |
|
|
6364 |
ax.zaxis.set_ticklabels([]) |
|
|
6365 |
if not frame_on: |
|
|
6366 |
ax.xaxis._axinfo['grid']['color'] = (1, 1, 1, 0) |
|
|
6367 |
ax.yaxis._axinfo['grid']['color'] = (1, 1, 1, 0) |
|
|
6368 |
ax.zaxis._axinfo['grid']['color'] = (1, 1, 1, 0) |
|
|
6369 |
ax.xaxis._axinfo['tick']['inward_factor'] = 0 |
|
|
6370 |
ax.xaxis._axinfo['tick']['outward_factor'] = 0 |
|
|
6371 |
ax.yaxis._axinfo['tick']['inward_factor'] = 0 |
|
|
6372 |
ax.yaxis._axinfo['tick']['outward_factor'] = 0 |
|
|
6373 |
ax.zaxis._axinfo['tick']['inward_factor'] = 0 |
|
|
6374 |
ax.zaxis._axinfo['tick']['outward_factor'] = 0 |
|
|
6375 |
count += 1 |
|
|
6376 |
for i in range(col+1, n_cols): |
|
|
6377 |
fig.delaxes(axs[row, i]) |
|
|
6378 |
fig.tight_layout() |