|
a |
|
b/src/multivelo/pyWNN.py |
|
|
1 |
# pyWNN is a package developed by Dylan Kotliar (GitHub username: dylkot), published under the MIT license. |
|
|
2 |
|
|
|
3 |
# The original release, including tutorials, can be found here: https://github.com/dylkot/pyWNN |
|
|
4 |
|
|
|
5 |
import scanpy as sc |
|
|
6 |
import numpy as np |
|
|
7 |
from sklearn import preprocessing |
|
|
8 |
from scipy.sparse import csr_matrix, lil_matrix, diags |
|
|
9 |
import time |
|
|
10 |
|
|
|
11 |
|
|
|
12 |
|
|
|
13 |
def get_nearestneighbor(knn, neighbor=1): |
|
|
14 |
'''For each row of knn, returns the column with the lowest value |
|
|
15 |
I.e. the nearest neighbor''' |
|
|
16 |
indices = knn.indices |
|
|
17 |
indptr = knn.indptr |
|
|
18 |
data = knn.data |
|
|
19 |
nn_idx = [] |
|
|
20 |
for i in range(knn.shape[0]): |
|
|
21 |
cols = indices[indptr[i]:indptr[i+1]] |
|
|
22 |
rowvals = data[indptr[i]:indptr[i+1]] |
|
|
23 |
idx = np.argsort(rowvals) |
|
|
24 |
nn_idx.append(cols[idx[neighbor-1]]) |
|
|
25 |
return(np.array(nn_idx)) |
|
|
26 |
|
|
|
27 |
|
|
|
28 |
def compute_bw(knn_adj, embedding, n_neighbors=20): |
|
|
29 |
intersect = knn_adj.dot(knn_adj.T) |
|
|
30 |
indices = intersect.indices |
|
|
31 |
indptr = intersect.indptr |
|
|
32 |
data = intersect.data |
|
|
33 |
data = data / ((n_neighbors*2) - data) |
|
|
34 |
bandwidth = [] |
|
|
35 |
for i in range(intersect.shape[0]): |
|
|
36 |
cols = indices[indptr[i]:indptr[i+1]] |
|
|
37 |
rowvals = data[indptr[i]:indptr[i+1]] |
|
|
38 |
idx = np.argsort(rowvals) |
|
|
39 |
valssort = rowvals[idx] |
|
|
40 |
numinset = len(cols) |
|
|
41 |
if numinset<n_neighbors: |
|
|
42 |
sys.exit('Fewer than 20 cells with Jacard sim > 0') |
|
|
43 |
else: |
|
|
44 |
curval = valssort[n_neighbors] |
|
|
45 |
for num in range(n_neighbors, numinset): |
|
|
46 |
if valssort[num]!=curval: |
|
|
47 |
break |
|
|
48 |
else: |
|
|
49 |
num+=1 |
|
|
50 |
minjacinset = cols[idx][:num] |
|
|
51 |
if num <n_neighbors: |
|
|
52 |
print('shouldnt end up here') |
|
|
53 |
sys.exit(-1) |
|
|
54 |
else: |
|
|
55 |
euc_dist = ((embedding[minjacinset,:]-embedding[i,:])**2).sum(axis=1)**.5 |
|
|
56 |
euc_dist_sorted = np.sort(euc_dist)[::-1] |
|
|
57 |
bandwidth.append( np.mean(euc_dist_sorted[:n_neighbors]) ) |
|
|
58 |
return(np.array(bandwidth)) |
|
|
59 |
|
|
|
60 |
|
|
|
61 |
def compute_affinity(dist_to_predict, dist_to_nn, bw): |
|
|
62 |
affinity = dist_to_predict-dist_to_nn |
|
|
63 |
affinity[affinity<0]=0 |
|
|
64 |
affinity = affinity * -1 |
|
|
65 |
affinity = np.exp(affinity / (bw-dist_to_nn)) |
|
|
66 |
return(affinity) |
|
|
67 |
|
|
|
68 |
|
|
|
69 |
def dist_from_adj(adjacency, embed1, embed2, nndist1, nndist2): |
|
|
70 |
dist1 = lil_matrix(adjacency.shape) |
|
|
71 |
dist2 = lil_matrix(adjacency.shape) |
|
|
72 |
|
|
|
73 |
count = 0 |
|
|
74 |
indices = adjacency.indices |
|
|
75 |
indptr = adjacency.indptr |
|
|
76 |
ncells = adjacency.shape[0] |
|
|
77 |
|
|
|
78 |
tic = time.perf_counter() |
|
|
79 |
for i in range(ncells): |
|
|
80 |
for j in range(indptr[i], indptr[i+1]): |
|
|
81 |
col = indices[j] |
|
|
82 |
a = (((embed1[i,:] - embed1[col,:])**2).sum()**.5) - nndist1[i] |
|
|
83 |
if a == 0: dist1[i,col] = np.nan |
|
|
84 |
else: dist1[i,col] = a |
|
|
85 |
b = (((embed2[i,:] - embed2[col,:])**2).sum()**.5) - nndist2[i] |
|
|
86 |
if b == 0: dist2[i,col] = np.nan |
|
|
87 |
else: dist2[i,col] = b |
|
|
88 |
|
|
|
89 |
if (i % 2000) == 0: |
|
|
90 |
toc = time.perf_counter() |
|
|
91 |
print('%d out of %d %.2f seconds elapsed' % (i, ncells, toc-tic)) |
|
|
92 |
|
|
|
93 |
return(csr_matrix(dist1), csr_matrix(dist2)) |
|
|
94 |
|
|
|
95 |
|
|
|
96 |
def select_topK(dist, n_neighbors=20): |
|
|
97 |
indices = dist.indices |
|
|
98 |
indptr = dist.indptr |
|
|
99 |
data = dist.data |
|
|
100 |
nrows = dist.shape[0] |
|
|
101 |
|
|
|
102 |
final_data = [] |
|
|
103 |
final_col_ind = [] |
|
|
104 |
|
|
|
105 |
tic = time.perf_counter() |
|
|
106 |
for i in range(nrows): |
|
|
107 |
cols = indices[indptr[i]:indptr[i+1]] |
|
|
108 |
rowvals = data[indptr[i]:indptr[i+1]] |
|
|
109 |
idx = np.argsort(rowvals) |
|
|
110 |
final_data.append(rowvals[idx[(-1*n_neighbors):]]) |
|
|
111 |
final_col_ind.append(cols[idx[(-1*n_neighbors):]]) |
|
|
112 |
|
|
|
113 |
final_data = np.concatenate(final_data) |
|
|
114 |
final_col_ind = np.concatenate(final_col_ind) |
|
|
115 |
final_row_ind = np.tile(np.arange(nrows), (n_neighbors, 1)).reshape(-1, order='F') |
|
|
116 |
|
|
|
117 |
result = csr_matrix((final_data, (final_row_ind, final_col_ind)), shape=(nrows, dist.shape[1])) |
|
|
118 |
|
|
|
119 |
return(result) |
|
|
120 |
|
|
|
121 |
|
|
|
122 |
class pyWNN(): |
|
|
123 |
|
|
|
124 |
def __init__(self, adata, reps=['X_pca', 'X_apca'], n_neighbors=20, npcs=[20, 20], seed=14, distances=None): |
|
|
125 |
"""\ |
|
|
126 |
Class for running weighted nearest neighbors analysis as described in Hao |
|
|
127 |
et al 2021. |
|
|
128 |
""" |
|
|
129 |
|
|
|
130 |
self.seed = seed |
|
|
131 |
np.random.seed(seed) |
|
|
132 |
|
|
|
133 |
if len(reps)>2: |
|
|
134 |
sys.exit('WNN currently only implemented for 2 modalities') |
|
|
135 |
|
|
|
136 |
self.adata = adata.copy() |
|
|
137 |
self.reps = [r+'_norm' for r in reps] |
|
|
138 |
self.npcs = npcs |
|
|
139 |
for (i,r) in enumerate(reps): |
|
|
140 |
self.adata.obsm[self.reps[i]] = preprocessing.normalize(adata.obsm[r][:,0:npcs[i]]) |
|
|
141 |
|
|
|
142 |
self.n_neighbors = n_neighbors |
|
|
143 |
if distances is None: |
|
|
144 |
print('Computing KNN distance matrices using default Scanpy implementation') |
|
|
145 |
sc.pp.neighbors(self.adata, n_neighbors=n_neighbors, n_pcs=npcs[0], use_rep=self.reps[0], metric='euclidean', key_added='1') |
|
|
146 |
sc.pp.neighbors(self.adata, n_neighbors=n_neighbors, n_pcs=npcs[1], use_rep=self.reps[1], metric='euclidean', key_added='2') |
|
|
147 |
sc.pp.neighbors(self.adata, n_neighbors=200, n_pcs=npcs[0], use_rep=self.reps[0], metric='euclidean', key_added='1_200') |
|
|
148 |
sc.pp.neighbors(self.adata, n_neighbors=200, n_pcs=npcs[1], use_rep=self.reps[1], metric='euclidean', key_added='2_200') |
|
|
149 |
self.distances = ['1_distances', '2_distances', '1_200_distances', '2_200_distances'] |
|
|
150 |
else: |
|
|
151 |
self.distances = distances |
|
|
152 |
|
|
|
153 |
for d in self.distances: |
|
|
154 |
if type(self.adata.obsp[d]) is not csr_matrix: |
|
|
155 |
self.adata.obsp[d] = csr_matrix(self.adata.obsp[d]) |
|
|
156 |
|
|
|
157 |
self.NNdist = [] |
|
|
158 |
self.NNidx = [] |
|
|
159 |
self.NNadjacency = [] |
|
|
160 |
self.BWs = [] |
|
|
161 |
|
|
|
162 |
for (i,r) in enumerate(self.reps): |
|
|
163 |
nn = get_nearestneighbor(self.adata.obsp[self.distances[i]]) |
|
|
164 |
dist_to_nn = ((self.adata.obsm[r]-self.adata.obsm[r][nn, :])**2).sum(axis=1)**.5 |
|
|
165 |
nn_adj = (self.adata.obsp[self.distances[i]]>0).astype(int) |
|
|
166 |
nn_adj_wdiag = nn_adj.copy() |
|
|
167 |
nn_adj_wdiag.setdiag(1) |
|
|
168 |
bw = compute_bw(nn_adj_wdiag, self.adata.obsm[r], n_neighbors=self.n_neighbors) |
|
|
169 |
self.NNidx.append(nn) |
|
|
170 |
self.NNdist.append(dist_to_nn) |
|
|
171 |
self.NNadjacency.append(nn_adj) |
|
|
172 |
self.BWs.append(bw) |
|
|
173 |
|
|
|
174 |
self.weights = [] |
|
|
175 |
self.WNN = None |
|
|
176 |
|
|
|
177 |
def compute_weights(self): |
|
|
178 |
cmap = {0:1, 1:0} |
|
|
179 |
affinity_ratios = [] |
|
|
180 |
self.within = [] |
|
|
181 |
self.cross = [] |
|
|
182 |
for (i,r) in enumerate(self.reps): |
|
|
183 |
within_predict = self.NNadjacency[i].dot(self.adata.obsm[r]) / (self.n_neighbors-1) |
|
|
184 |
cross_predict = self.NNadjacency[cmap[i]].dot(self.adata.obsm[r]) / (self.n_neighbors-1) |
|
|
185 |
|
|
|
186 |
within_predict_dist = ((self.adata.obsm[r] - within_predict)**2).sum(axis=1)**.5 |
|
|
187 |
cross_predict_dist = ((self.adata.obsm[r] - cross_predict)**2).sum(axis=1)**.5 |
|
|
188 |
within_affinity = compute_affinity(within_predict_dist, self.NNdist[i], self.BWs[i]) |
|
|
189 |
cross_affinity = compute_affinity(cross_predict_dist, self.NNdist[i], self.BWs[i]) |
|
|
190 |
affinity_ratios.append(within_affinity / (cross_affinity + 0.0001)) |
|
|
191 |
self.within.append(within_predict_dist) |
|
|
192 |
self.cross.append(cross_predict_dist) |
|
|
193 |
|
|
|
194 |
self.weights.append( 1 / (1+ np.exp(affinity_ratios[1]-affinity_ratios[0])) ) |
|
|
195 |
self.weights.append( 1 - self.weights[0] ) |
|
|
196 |
|
|
|
197 |
|
|
|
198 |
def compute_wnn(self, adata): |
|
|
199 |
print('Computing modality weights') |
|
|
200 |
self.compute_weights() |
|
|
201 |
union_adj_mat = ((self.adata.obsp[self.distances[2]]+self.adata.obsp[self.distances[3]]) > 0).astype(int) |
|
|
202 |
|
|
|
203 |
print('Computing weighted distances for union of 200 nearest neighbors between modalities') |
|
|
204 |
full_dists = dist_from_adj(union_adj_mat, self.adata.obsm[self.reps[0]], self.adata.obsm[self.reps[1]], |
|
|
205 |
self.NNdist[0], self.NNdist[1]) |
|
|
206 |
weighted_dist = csr_matrix(union_adj_mat.shape) |
|
|
207 |
for (i,dist) in enumerate(full_dists): |
|
|
208 |
dist = diags(-1 / (self.BWs[i] - self.NNdist[i]), format='csr').dot(dist) |
|
|
209 |
dist.data = np.exp(dist.data) |
|
|
210 |
ind = np.isnan(dist.data) |
|
|
211 |
dist.data[ind] = 1 |
|
|
212 |
dist = diags(self.weights[i]).dot(dist) |
|
|
213 |
weighted_dist += dist |
|
|
214 |
|
|
|
215 |
print('Selecting top K neighbors') |
|
|
216 |
self.WNN = select_topK(weighted_dist, n_neighbors=self.n_neighbors) |
|
|
217 |
WNNdist = self.WNN.copy() |
|
|
218 |
x = (1-WNNdist.data) / 2 |
|
|
219 |
x[x<0]=0 |
|
|
220 |
x[x>1]=1 |
|
|
221 |
WNNdist.data = np.sqrt(x) |
|
|
222 |
self.WNNdist = WNNdist |
|
|
223 |
|
|
|
224 |
|
|
|
225 |
adata.obsp['WNN'] = self.WNN |
|
|
226 |
adata.obsp['WNN_distance'] = self.WNNdist |
|
|
227 |
adata.obsm[self.reps[0]] = self.adata.obsm[self.reps[0]] |
|
|
228 |
adata.obsm[self.reps[1]] = self.adata.obsm[self.reps[1]] |
|
|
229 |
adata.uns['WNN'] = {'connectivities_key': 'WNN', |
|
|
230 |
'distances_key': 'WNN_distance', |
|
|
231 |
'params': {'n_neighbors': self.n_neighbors, |
|
|
232 |
'method': 'WNN', |
|
|
233 |
'random_state': self.seed, |
|
|
234 |
'metric': 'euclidean', |
|
|
235 |
'use_rep': self.reps[0], |
|
|
236 |
'n_pcs': self.npcs[0]}} |
|
|
237 |
return(adata) |