|
a |
|
b/VITAE/inference.py |
|
|
1 |
import warnings |
|
|
2 |
#from typing import Optional |
|
|
3 |
|
|
|
4 |
import numpy as np |
|
|
5 |
import networkx as nx |
|
|
6 |
|
|
|
7 |
|
|
|
8 |
class Inferer(object):
    '''
    The class for doing inference based on posterior estimations.
    '''

    def __init__(self, n_states: int):
        r'''
        Parameters
        ----------
        n_states : int
            The number of vertices in the latent space.
        '''
        self.n_states = n_states
        # One category per vertex plus one per unordered pair of vertices.
        self.n_categories = int(n_states * (n_states + 1) // 2)
        # Indicator of the categories: C[i, j] with i <= j holds the index of
        # category (i, j), numbered row by row over the upper triangle.
        # The strictly lower triangle is left at 0 and is never looked up.
        self.C = np.zeros((n_states, n_states), dtype=int)
        self.C[np.triu_indices(n_states)] = np.arange(self.n_categories)

    def build_graphs(self, w_tilde, pc_x, method: str = 'mean',
                     thres: float = 0.5, no_loop: bool = False,
                     cutoff=0):
        r'''Build the backbone.

        Parameters
        ----------
        w_tilde : np.array
            \([N, k]\) The estimated \(\tilde{w}\). Only read by the
            'modified_map', 'w_base' and 'modified_w_base' methods.
        pc_x : np.array
            \([N, K]\) The estimated \(p(c_i|Y_i,X_i)\). Only read by the
            'mean', 'modified_mean', 'map', 'modified_map' and 'raw_map'
            methods.
        method : string, optional
            'mean', 'modified_mean', 'map', 'modified_map', 'raw_map',
            'w_base', or 'modified_w_base'.
        thres : float, optional
            The threshold used for filtering edges \(e_{ij}\) that
            \((n_{i}+n_{j}+e_{ij})/N<thres\), only applied to the 'mean'
            and 'modified_mean' methods.
        no_loop : boolean, optional
            If True and the scored graph is not a tree, it is replaced by its
            maximum spanning tree.
        cutoff : float, optional
            Edges whose score is less than or equal to `cutoff` are removed.

        Returns
        ----------
        G : nx.Graph
            The graph of edge scores.

        Raises
        ----------
        ValueError
            If `method` is not one of the supported names.
        '''
        self.no_loop = no_loop

        n = self.n_states
        C = self.C
        graph = np.zeros((n, n))
        # Every unordered pair of distinct vertices, i < j.
        pairs = [(i, j) for i in range(n - 1) for j in range(i + 1, n)]

        if method == 'mean':
            for i, j in pairs:
                # Categories of vertex i, edge (i, j) and vertex j.
                cats = C[[i, i, j], [i, j, j]]
                idx = np.sum(pc_x[:, cats], axis=1) >= thres
                if np.any(idx):
                    # Mean over the selected cells of the per-cell edge share.
                    graph[i, j] = np.mean(
                        pc_x[idx, C[i, j]] / np.sum(pc_x[idx][:, cats], axis=-1))
        elif method == 'modified_mean':
            for i, j in pairs:
                cats = C[[i, i, j], [i, j, j]]
                idx = np.sum(pc_x[:, cats], axis=1) >= thres
                if np.any(idx):
                    # Ratio of the pooled edge mass to the pooled total mass.
                    graph[i, j] = np.sum(pc_x[idx, C[i, j]]) / np.sum(pc_x[idx][:, cats])
        elif method == 'map':
            c = np.argmax(pc_x, axis=-1)
            for i, j in pairs:
                on_edge = (c == C[i, j])
                if np.any(on_edge):
                    graph[i, j] = np.sum(on_edge) / np.sum(
                        on_edge | (c == C[i, i]) | (c == C[j, j]))
        elif method == 'modified_map':
            c = np.argmax(pc_x, axis=-1)
            for i, j in pairs:
                # The small constant guards against an empty denominator.
                graph[i, j] = np.sum(c == C[i, j]) / (
                    np.sum((w_tilde[:, i] > 0.5) | (w_tilde[:, j] > 0.5)) + 1e-16)
        elif method == 'raw_map':
            c = np.argmax(pc_x, axis=-1)
            # Number of cells whose MAP category is not a vertex category.
            n_off_vertex = np.sum(~np.isin(c, np.diagonal(C)))
            for i, j in pairs:
                n_edge = np.sum(c == C[i, j])
                if n_edge > 0:
                    graph[i, j] = n_edge / n_off_vertex
        elif method == "w_base":
            top1 = np.argmax(w_tilde, axis=1)
            for i, j in pairs:
                # Cells whose largest weight is on vertex i or vertex j.
                near = w_tilde[(top1 == i) | (top1 == j), :]
                num_two_vertice = near.shape[0]
                if num_two_vertice > 0:
                    # Fraction of those cells with nearly equal weights on i and j.
                    graph[i, j] = np.sum(
                        np.abs(near[:, i] - near[:, j]) < 0.1) / num_two_vertice
        elif method == "modified_w_base":
            top1 = np.argmax(w_tilde, axis=1)
            # Indices of the two largest weights per cell, sorted so that a
            # single comparison against (i, j) with i < j covers both orders.
            top2 = np.sort(np.argpartition(w_tilde, -2, axis=1)[:, -2:], axis=1)
            for i, j in pairs:
                vertice_count = np.sum((top1 == i) | (top1 == j))
                if vertice_count > 0:
                    between = w_tilde[(top2[:, 0] == i) & (top2[:, 1] == j), :]
                    # Share of the larger weight within the pair; a value
                    # below 0.55 means the two weights are close to balanced.
                    dominance = np.maximum(between[:, i], between[:, j]) \
                        / (between[:, i] + between[:, j])
                    graph[i, j] = np.sum(dominance < 0.55) / vertice_count
        else:
            raise ValueError("Invalid method, must be one of 'mean', 'modified_mean', 'map', 'modified_map','raw_map','w_base', and 'modified_w_base'.")

        graph[graph <= cutoff] = 0
        G = nx.from_numpy_array(graph)

        if self.no_loop and not nx.is_tree(G):
            # Loops are not allowed: keep only the maximum spanning tree.
            G = nx.maximum_spanning_tree(G)

        return G

    def modify_wtilde(self, w_tilde, edges):
        r'''Project \(\tilde{w}\) to the estimated backbone.

        Parameters
        ----------
        w_tilde : np.array
            \([N, k]\) The estimated \(\tilde{w}\).
        edges : np.array
            \([|\mathcal{E}(\widehat{\mathcal{B}})|, 2]\). Any sequence of
            vertex index pairs is accepted (e.g. a list of tuples); it may be
            empty.

        Returns
        ----------
        w : np.array
            The projected \(\tilde{w}\).
        '''
        # Accept plain Python sequences as well as arrays; the fancy indexing
        # below requires ndarrays.
        w_tilde = np.asarray(w_tilde)
        edges = np.asarray(edges)

        w = np.zeros_like(w_tilde)
        sq_norm = np.sum(w_tilde**2, axis=-1)

        # projection on nodes
        best_proj_err_node = sq_norm - 2 * np.max(w_tilde, axis=-1) + 1
        best_proj_err_node_ind = np.argmax(w_tilde, axis=-1)

        if len(edges) > 0:
            # projection on edges
            pair_w = w_tilde[:, edges]
            pair_sum = np.sum(pair_w, axis=-1)
            # Maximising this score picks, per cell, the edge with the
            # smallest projection error (they differ by constants only).
            score = pair_sum**2 - 4 * np.prod(pair_w, axis=-1) + 2 * pair_sum
            ide = edges[np.argmax(score, axis=-1)]
            idc = np.tile(np.arange(w.shape[0]), (2, 1)).T

            chosen = w_tilde[idc, ide]
            # Mass missing from the chosen pair, split equally between the
            # two endpoints so that the projected weights sum to one.
            gap = 1 - np.sum(chosen, axis=-1, keepdims=True)
            w[idc, ide] = chosen + gap / 2
            best_proj_err_edge = sq_norm - np.sum(chosen**2, axis=-1) + gap[:, 0]**2 / 2

            # Cells that are closer to a vertex than to their best edge.
            on_node = (best_proj_err_node < best_proj_err_edge)
            w[on_node, :] = np.eye(w_tilde.shape[-1])[best_proj_err_node_ind[on_node]]
        else:
            # No edges at all: every cell is assigned to its nearest vertex.
            w[np.arange(w.shape[0]), best_proj_err_node_ind] = 1
        return w

    def build_milestone_net(self, subgraph, init_node: int):
        '''Build the milestone network.

        Parameters
        ----------
        subgraph : nx.Graph
            The connected component of the backbone given the root vertex.
        init_node : int
            The root vertex.

        Returns
        ----------
        milestone_net : np.array
            The milestone network, one row `[from, to, distance]` per edge.
            An empty list is returned (with a warning) when `subgraph` has a
            single node.
        '''
        if len(subgraph) == 1:
            warnings.warn('Singular node.')
            return []

        milestone_net = []
        # NOTE(review): networkx reports False here for undirected graphs, so
        # this branch is presumably only reached for directed input - confirm.
        if nx.is_directed_acyclic_graph(subgraph):
            for edge in list(subgraph.edges):
                if edge[0] == init_node:
                    dist = 1
                elif edge[1] == init_node:
                    paths_0 = nx.all_simple_paths(subgraph, source=init_node, target=edge[0])
                    # Fixed: this used to read `paths_1`, which is undefined
                    # (or an exhausted generator from an earlier iteration).
                    # NOTE(review): if no path from the root to edge[0]
                    # exists, np.max raises ValueError on the empty list.
                    dist = - (np.max([len(p) for p in paths_0]) - 1)
                else:
                    paths_0 = nx.all_simple_paths(subgraph, source=init_node, target=edge[0])
                    paths_1 = nx.all_simple_paths(subgraph, source=init_node, target=edge[1])
                    # Difference of the longest path lengths to both endpoints.
                    dist = np.max([len(p) for p in paths_1]) - np.max([len(p) for p in paths_0])
                milestone_net.append([edge[0], edge[1], dist])
        else:
            # Dijkstra's Algorithm to find the shortest path
            unvisited = {node: {'parent': None,
                                'score': np.inf,
                                'distance': np.inf} for node in subgraph.nodes}
            current = init_node
            currentScore = 0
            currentDistance = 0
            unvisited[current]['score'] = currentScore

            while True:
                for neighbour in subgraph.neighbors(current):
                    if neighbour not in unvisited:
                        continue
                    newScore = currentScore + subgraph[current][neighbour]['weight']
                    if unvisited[neighbour]['score'] > newScore:
                        unvisited[neighbour]['score'] = newScore
                        unvisited[neighbour]['parent'] = current
                        unvisited[neighbour]['distance'] = currentDistance + 1

                # Every node except the root contributes one edge, from its
                # parent, labelled with its hop count from the root.
                if len(unvisited) < len(subgraph):
                    milestone_net.append([unvisited[current]['parent'],
                                          current,
                                          unvisited[current]['distance']])
                del unvisited[current]
                if not unvisited:
                    break
                # Next node to settle: the unvisited one with the lowest score
                # (first one wins on ties, in dict order).
                current, currentScore, currentDistance = min(
                    ((node, info['score'], info['distance'])
                     for node, info in unvisited.items()),
                    key=lambda item: item[1])
        return np.array(milestone_net)

    def comp_pseudotime(self, milestone_net, init_node: int, w):
        r'''Compute pseudotime.

        Parameters
        ----------
        milestone_net : np.array
            The milestone network, one row `[from, to, distance]` per edge.
        init_node : int
            The root vertex.
        w : np.array
            \([N, k]\) The projected \(\tilde{w}\).

        Returns
        ----------
        pseudotime : np.array
            \([N]\) The estimated pseudotime; NaN for cells that lie neither
            on the root nor on any edge of `milestone_net`.
        '''
        pseudotime = np.full(w.shape[0], np.nan)
        # Cells sitting exactly on the root start at time zero.
        pseudotime[w[:, init_node] == 1] = 0

        for row in milestone_net:
            _from, _to = int(row[0]), int(row[1])

            # Cells strictly inside the edge, or exactly on its end vertex.
            idc = ((w[:, _from] > 0) & (w[:, _to] > 0)) | (w[:, _to] == 1)
            pseudotime[idc] = w[idc, _to] + row[-1] - 1

        return pseudotime
|
|
244 |
|
|
|
245 |
|