--- a
+++ b/src/multivelo/pyWNN.py
@@ -0,0 +1,237 @@
+# pyWNN is a package developed by Dylan Kotliar (GitHub username: dylkot),
+# published under the MIT license.
+
+# The original release, including tutorials, can be found here:
+# https://github.com/dylkot/pyWNN
+
+import sys
+import time
+
+import scanpy as sc
+import numpy as np
+from sklearn import preprocessing
+from scipy.sparse import csr_matrix, lil_matrix, diags
+
+
+def get_nearestneighbor(knn, neighbor=1):
+    '''For each row of the sparse distance matrix knn, return the column with
+    the lowest value, i.e. the nearest neighbor.'''
+    indices = knn.indices
+    indptr = knn.indptr
+    data = knn.data
+    nn_idx = []
+    for i in range(knn.shape[0]):
+        cols = indices[indptr[i]:indptr[i+1]]
+        rowvals = data[indptr[i]:indptr[i+1]]
+        idx = np.argsort(rowvals)
+        nn_idx.append(cols[idx[neighbor-1]])
+    return np.array(nn_idx)
+
+
+def compute_bw(knn_adj, embedding, n_neighbors=20):
+    '''Compute a per-cell kernel bandwidth: the mean embedding distance to the
+    n_neighbors furthest cells among those with the lowest nonzero Jaccard
+    similarity of kNN neighborhoods.'''
+    intersect = knn_adj.dot(knn_adj.T)
+    indices = intersect.indices
+    indptr = intersect.indptr
+    data = intersect.data
+    data = data / ((n_neighbors*2) - data)
+    bandwidth = []
+    for i in range(intersect.shape[0]):
+        cols = indices[indptr[i]:indptr[i+1]]
+        rowvals = data[indptr[i]:indptr[i+1]]
+        idx = np.argsort(rowvals)
+        valssort = rowvals[idx]
+        numinset = len(cols)
+        if numinset < n_neighbors:
+            sys.exit('Fewer than %d cells with Jaccard similarity > 0' % n_neighbors)
+        else:
+            curval = valssort[n_neighbors]
+            for num in range(n_neighbors, numinset):
+                if valssort[num] != curval:
+                    break
+                else:
+                    num += 1
+            minjacinset = cols[idx][:num]
+            if num < n_neighbors:
+                print('Should not end up here')
+                sys.exit(-1)
+            else:
+                euc_dist = ((embedding[minjacinset, :]-embedding[i, :])**2).sum(axis=1)**.5
+                euc_dist_sorted = np.sort(euc_dist)[::-1]
+                bandwidth.append(np.mean(euc_dist_sorted[:n_neighbors]))
+    return np.array(bandwidth)
+
+
+def compute_affinity(dist_to_predict, dist_to_nn, bw):
+    '''Affinity = exp(-max(dist_to_predict - dist_to_nn, 0) / (bw - dist_to_nn)).'''
+    affinity = dist_to_predict - dist_to_nn
+    affinity[affinity < 0] = 0
+    affinity = affinity * -1
+    affinity = np.exp(affinity / (bw-dist_to_nn))
+    return affinity
+
+
+def dist_from_adj(adjacency, embed1, embed2, nndist1, nndist2):
+    '''For every edge of the adjacency matrix, compute the Euclidean distance in
+    each embedding minus the cell's nearest-neighbor distance. Exact zeros are
+    stored as NaN so they are not dropped from the sparse matrices.'''
+    dist1 = lil_matrix(adjacency.shape)
+    dist2 = lil_matrix(adjacency.shape)
+
+    indices = adjacency.indices
+    indptr = adjacency.indptr
+    ncells = adjacency.shape[0]
+
+    tic = time.perf_counter()
+    for i in range(ncells):
+        for j in range(indptr[i], indptr[i+1]):
+            col = indices[j]
+            a = (((embed1[i, :] - embed1[col, :])**2).sum()**.5) - nndist1[i]
+            if a == 0:
+                dist1[i, col] = np.nan
+            else:
+                dist1[i, col] = a
+            b = (((embed2[i, :] - embed2[col, :])**2).sum()**.5) - nndist2[i]
+            if b == 0:
+                dist2[i, col] = np.nan
+            else:
+                dist2[i, col] = b
+
+        if (i % 2000) == 0:
+            toc = time.perf_counter()
+            print('%d out of %d, %.2f seconds elapsed' % (i, ncells, toc-tic))
+
+    return csr_matrix(dist1), csr_matrix(dist2)
+
+
+def select_topK(dist, n_neighbors=20):
+    '''For each row of dist, keep only the n_neighbors entries with the largest
+    values and return the result as a CSR matrix.'''
+    indices = dist.indices
+    indptr = dist.indptr
+    data = dist.data
+    nrows = dist.shape[0]
+
+    final_data = []
+    final_col_ind = []
+
+    for i in range(nrows):
+        cols = indices[indptr[i]:indptr[i+1]]
+        rowvals = data[indptr[i]:indptr[i+1]]
+        idx = np.argsort(rowvals)
+        final_data.append(rowvals[idx[(-1*n_neighbors):]])
+        final_col_ind.append(cols[idx[(-1*n_neighbors):]])
+
+    final_data = np.concatenate(final_data)
+    final_col_ind = np.concatenate(final_col_ind)
+    final_row_ind = np.tile(np.arange(nrows), (n_neighbors, 1)).reshape(-1, order='F')
+
+    result = csr_matrix((final_data, (final_row_ind, final_col_ind)),
+                        shape=(nrows, dist.shape[1]))
+
+    return result
+
+
+class pyWNN():
+
+    def __init__(self, adata, reps=['X_pca', 'X_apca'], n_neighbors=20,
+                 npcs=[20, 20], seed=14, distances=None):
+        """\
+        Class for running the weighted nearest neighbors (WNN) analysis
+        described in Hao et al. (2021).
+        """
+
+        self.seed = seed
+        np.random.seed(seed)
+
+        if len(reps) > 2:
+            sys.exit('WNN is currently only implemented for 2 modalities')
+
+        self.adata = adata.copy()
+        self.reps = [r+'_norm' for r in reps]
+        self.npcs = npcs
+        # L2-normalize the first npcs components of each representation.
+        for (i, r) in enumerate(reps):
+            self.adata.obsm[self.reps[i]] = preprocessing.normalize(adata.obsm[r][:, 0:npcs[i]])
+
+        self.n_neighbors = n_neighbors
+        if distances is None:
+            print('Computing KNN distance matrices using default Scanpy implementation')
+            sc.pp.neighbors(self.adata, n_neighbors=n_neighbors, n_pcs=npcs[0],
+                            use_rep=self.reps[0], metric='euclidean', key_added='1')
+            sc.pp.neighbors(self.adata, n_neighbors=n_neighbors, n_pcs=npcs[1],
+                            use_rep=self.reps[1], metric='euclidean', key_added='2')
+            sc.pp.neighbors(self.adata, n_neighbors=200, n_pcs=npcs[0],
+                            use_rep=self.reps[0], metric='euclidean', key_added='1_200')
+            sc.pp.neighbors(self.adata, n_neighbors=200, n_pcs=npcs[1],
+                            use_rep=self.reps[1], metric='euclidean', key_added='2_200')
+            self.distances = ['1_distances', '2_distances',
+                              '1_200_distances', '2_200_distances']
+        else:
+            self.distances = distances
+
+        for d in self.distances:
+            if type(self.adata.obsp[d]) is not csr_matrix:
+                self.adata.obsp[d] = csr_matrix(self.adata.obsp[d])
+
+        self.NNdist = []
+        self.NNidx = []
+        self.NNadjacency = []
+        self.BWs = []
+
+        for (i, r) in enumerate(self.reps):
+            nn = get_nearestneighbor(self.adata.obsp[self.distances[i]])
+            dist_to_nn = ((self.adata.obsm[r]-self.adata.obsm[r][nn, :])**2).sum(axis=1)**.5
+            nn_adj = (self.adata.obsp[self.distances[i]] > 0).astype(int)
+            nn_adj_wdiag = nn_adj.copy()
+            nn_adj_wdiag.setdiag(1)
+            bw = compute_bw(nn_adj_wdiag, self.adata.obsm[r], n_neighbors=self.n_neighbors)
+            self.NNidx.append(nn)
+            self.NNdist.append(dist_to_nn)
+            self.NNadjacency.append(nn_adj)
+            self.BWs.append(bw)
+
+        self.weights = []
+        self.WNN = None
+
+    def compute_weights(self):
+        cmap = {0: 1, 1: 0}
+        affinity_ratios = []
+        self.within = []
+        self.cross = []
+        for (i, r) in enumerate(self.reps):
+            # Predict each cell from its within-modality and cross-modality neighbors.
+            within_predict = self.NNadjacency[i].dot(self.adata.obsm[r]) / (self.n_neighbors-1)
+            cross_predict = self.NNadjacency[cmap[i]].dot(self.adata.obsm[r]) / (self.n_neighbors-1)
+
+            within_predict_dist = ((self.adata.obsm[r] - within_predict)**2).sum(axis=1)**.5
+            cross_predict_dist = ((self.adata.obsm[r] - cross_predict)**2).sum(axis=1)**.5
+            within_affinity = compute_affinity(within_predict_dist, self.NNdist[i], self.BWs[i])
+            cross_affinity = compute_affinity(cross_predict_dist, self.NNdist[i], self.BWs[i])
+            affinity_ratios.append(within_affinity / (cross_affinity + 0.0001))
+            self.within.append(within_predict_dist)
+            self.cross.append(cross_predict_dist)
+
+        # A sigmoid of the difference in affinity ratios gives the per-cell weight
+        # of the first modality; the second modality gets the complement.
+        self.weights.append(1 / (1 + np.exp(affinity_ratios[1]-affinity_ratios[0])))
+        self.weights.append(1 - self.weights[0])
+
+    def compute_wnn(self, adata):
+        print('Computing modality weights')
+        self.compute_weights()
+        union_adj_mat = ((self.adata.obsp[self.distances[2]] +
+                          self.adata.obsp[self.distances[3]]) > 0).astype(int)
+
+        print('Computing weighted distances for union of 200 nearest neighbors between modalities')
+        full_dists = dist_from_adj(union_adj_mat,
+                                   self.adata.obsm[self.reps[0]], self.adata.obsm[self.reps[1]],
+                                   self.NNdist[0], self.NNdist[1])
+        weighted_dist = csr_matrix(union_adj_mat.shape)
+        for (i, dist) in enumerate(full_dists):
+            # Convert distances to affinities with the per-cell bandwidth kernel,
+            # then scale each row by that cell's weight for this modality.
+            dist = diags(-1 / (self.BWs[i] - self.NNdist[i]), format='csr').dot(dist)
+            dist.data = np.exp(dist.data)
+            ind = np.isnan(dist.data)
+            dist.data[ind] = 1
+            dist = diags(self.weights[i]).dot(dist)
+            weighted_dist += dist
+
+        print('Selecting top K neighbors')
+        self.WNN = select_topK(weighted_dist, n_neighbors=self.n_neighbors)
+        WNNdist = self.WNN.copy()
+        x = (1 - WNNdist.data) / 2
+        x[x < 0] = 0
+        x[x > 1] = 1
+        WNNdist.data = np.sqrt(x)
+        self.WNNdist = WNNdist
+
+        adata.obsp['WNN'] = self.WNN
+        adata.obsp['WNN_distance'] = self.WNNdist
+        adata.obsm[self.reps[0]] = self.adata.obsm[self.reps[0]]
+        adata.obsm[self.reps[1]] = self.adata.obsm[self.reps[1]]
+        adata.uns['WNN'] = {'connectivities_key': 'WNN',
+                            'distances_key': 'WNN_distance',
+                            'params': {'n_neighbors': self.n_neighbors,
+                                       'method': 'WNN',
+                                       'random_state': self.seed,
+                                       'metric': 'euclidean',
+                                       'use_rep': self.reps[0],
+                                       'n_pcs': self.npcs[0]}}
+        return adata
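A minimal usage sketch for the class above, assuming an AnnData object (adata) that already carries the two per-modality embeddings in .obsm['X_pca'] and .obsm['X_apca']; the import path multivelo.pyWNN is inferred from the file location, and the downstream Scanpy calls are only one possible way to consume the resulting graph:

import scanpy as sc
from multivelo.pyWNN import pyWNN  # assumed import path for src/multivelo/pyWNN.py

# adata: AnnData with, e.g., RNA PCA in .obsm['X_pca'] and ATAC PCA in .obsm['X_apca']
wnn = pyWNN(adata, reps=['X_pca', 'X_apca'], n_neighbors=20, npcs=[20, 20])
adata = wnn.compute_wnn(adata)

# compute_wnn stores the weighted graph in adata.obsp['WNN'] / adata.obsp['WNN_distance']
# and registers it in adata.uns['WNN'], so Scanpy tools can pick it up via neighbors_key.
sc.tl.umap(adata, neighbors_key='WNN')
sc.tl.leiden(adata, key_added='WNN_leiden', neighbors_key='WNN')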