Diff of /src/multivelo/pyWNN.py [000000] .. [5d6472]

Switch to unified view

a b/src/multivelo/pyWNN.py
1
# pyWNN is a package developed by Dylan Kotliar (GitHub username: dylkot), published under the MIT license.
2
3
# The original release, including tutorials, can be found here: https://github.com/dylkot/pyWNN
4
5
import sys
import time

import numpy as np
import scanpy as sc
from scipy.sparse import csr_matrix, lil_matrix, diags
from sklearn import preprocessing
10
11
12
13
def get_nearestneighbor(knn, neighbor=1):
    '''Return, for every row of the CSR matrix `knn`, the column index of
    the `neighbor`-th smallest stored value.

    With the default ``neighbor=1`` this is the column holding the lowest
    stored value of each row, i.e. the nearest neighbor. Only explicitly
    stored entries of the sparse matrix are considered.'''
    col_index, row_ptr, values = knn.indices, knn.indptr, knn.data
    rank = neighbor - 1  # 1-based neighbor rank -> 0-based sort position

    nearest = []
    # Consecutive indptr entries delimit the stored entries of one row.
    for start, stop in zip(row_ptr[:-1], row_ptr[1:]):
        order = np.argsort(values[start:stop])
        nearest.append(col_index[start:stop][order[rank]])
    return np.array(nearest)
26
27
28
def compute_bw(knn_adj, embedding, n_neighbors=20):
    '''Compute one kernel bandwidth per row of `knn_adj`.

    The shared-neighbor counts ``knn_adj . knn_adj.T`` are turned into a
    similarity ``shared / (2 * n_neighbors - shared)`` (a Jaccard index if
    every row of `knn_adj` holds exactly `n_neighbors` ones -- assumption,
    confirm against the caller). For each row, the stored entries are sorted
    by ascending similarity and the leading ones are kept, extended through
    any run of values tied with the entry at sorted position `n_neighbors`.
    The bandwidth is the mean of the `n_neighbors` largest Euclidean
    distances (in `embedding`) between the row's cell and the kept cells.

    Parameters
    ----------
    knn_adj : scipy sparse matrix
        Adjacency matrix; its product with its own transpose must expose
        CSR-style ``indices`` / ``indptr`` / ``data``.
    embedding : array indexable as ``embedding[rows, :]``
        Per-cell coordinates used for the Euclidean distances.
    n_neighbors : int
        Neighborhood size used in the similarity and in the final mean.

    Returns
    -------
    numpy.ndarray with one bandwidth per row.

    Exits through ``sys.exit`` (SystemExit) when a row has fewer than
    `n_neighbors` stored similarity entries.
    '''
    intersect = knn_adj.dot(knn_adj.T)
    indices = intersect.indices
    indptr = intersect.indptr
    similarity = intersect.data / ((n_neighbors * 2) - intersect.data)

    bandwidth = []
    for i in range(intersect.shape[0]):
        start, stop = indptr[i], indptr[i + 1]
        cols = indices[start:stop]
        rowvals = similarity[start:stop]
        numinset = len(cols)
        if numinset < n_neighbors:
            sys.exit('Fewer than %d cells with Jacard sim > 0' % n_neighbors)

        order = np.argsort(rowvals)
        valssort = rowvals[order]

        if numinset == n_neighbors:
            # Exactly n_neighbors candidates: there is no entry at sorted
            # position n_neighbors to compare against, so keep all of them.
            num = numinset
        else:
            # Keep everything up to and including the run of values tied
            # with the entry at sorted position n_neighbors.
            curval = valssort[n_neighbors]
            num = n_neighbors
            while num < numinset and valssort[num] == curval:
                num += 1

        minjacinset = cols[order][:num]
        euc_dist = ((embedding[minjacinset, :] - embedding[i, :]) ** 2).sum(axis=1) ** .5
        # Average the n_neighbors largest distances within the kept set.
        euc_dist_sorted = np.sort(euc_dist)[::-1]
        bandwidth.append(np.mean(euc_dist_sorted[:n_neighbors]))
    return np.array(bandwidth)
59
60
61
def compute_affinity(dist_to_predict, dist_to_nn, bw):
    '''Turn prediction distances into affinities through an exponential kernel.

    The amount by which `dist_to_predict` exceeds `dist_to_nn` is floored at
    zero, negated, divided by ``bw - dist_to_nn`` and exponentiated, so a
    prediction no farther away than the nearest neighbor yields exactly 1.
    The inputs are not modified.'''
    excess = dist_to_predict - dist_to_nn
    excess[excess < 0] = 0  # predictions closer than the nearest neighbor count as zero excess
    kernel_width = bw - dist_to_nn
    return np.exp(np.negative(excess) / kernel_width)
67
68
69
def dist_from_adj(adjacency, embed1, embed2, nndist1, nndist2):
    '''Build two sparse matrices of nearest-neighbor-shifted distances.

    For every stored entry ``(i, col)`` of the CSR matrix `adjacency`, the
    Euclidean distance between rows ``i`` and ``col`` of an embedding is
    computed and the row's own offset (``nndist1[i]`` / ``nndist2[i]``) is
    subtracted. A result of exactly 0 is stored as NaN so that the entry
    stays explicitly stored in the sparse output.

    A progress line is printed every 2000 rows. Returns the pair of CSR
    matrices ``(dist1, dist2)`` for `embed1` and `embed2` respectively.'''

    def _shifted_distance(embed, offsets, row, col):
        # Distance between the two cells minus the row's offset; an exact
        # zero becomes NaN so it is not dropped from the sparse matrix.
        gap = (((embed[row, :] - embed[col, :]) ** 2).sum() ** .5) - offsets[row]
        return np.nan if gap == 0 else gap

    dist1 = lil_matrix(adjacency.shape)
    dist2 = lil_matrix(adjacency.shape)
    col_index = adjacency.indices
    row_ptr = adjacency.indptr
    ncells = adjacency.shape[0]

    started = time.perf_counter()
    for i in range(ncells):
        for col in col_index[row_ptr[i]:row_ptr[i + 1]]:
            dist1[i, col] = _shifted_distance(embed1, nndist1, i, col)
            dist2[i, col] = _shifted_distance(embed2, nndist2, i, col)

        if (i % 2000) == 0:
            elapsed = time.perf_counter() - started
            print('%d out of %d %.2f seconds elapsed' % (i, ncells, elapsed))

    return (csr_matrix(dist1), csr_matrix(dist2))
94
95
96
def select_topK(dist, n_neighbors=20):
    '''Keep, in every row of the CSR matrix `dist`, only the `n_neighbors`
    largest stored values.

    Rows holding fewer than `n_neighbors` stored entries keep all of them
    (previously such a row made the row/column/data arrays differ in length
    and the sparse constructor failed).

    Parameters
    ----------
    dist : scipy.sparse CSR matrix
    n_neighbors : int
        Maximum number of entries retained per row.

    Returns
    -------
    scipy.sparse.csr_matrix with the same shape as `dist`.
    '''
    indices = dist.indices
    indptr = dist.indptr
    data = dist.data
    nrows = dist.shape[0]

    if nrows == 0:
        # Nothing to select; np.concatenate would reject an empty list.
        return csr_matrix(dist.shape, dtype=dist.dtype)

    final_data = []
    final_col_ind = []
    kept_per_row = []

    for i in range(nrows):
        cols = indices[indptr[i]:indptr[i + 1]]
        rowvals = data[indptr[i]:indptr[i + 1]]
        order = np.argsort(rowvals)
        # Explicit start offset instead of order[-n_neighbors:], so that
        # n_neighbors == 0 keeps nothing rather than the whole row.
        keep = max(min(n_neighbors, order.size), 0)
        top = order[order.size - keep:]
        final_data.append(rowvals[top])
        final_col_ind.append(cols[top])
        kept_per_row.append(keep)

    final_data = np.concatenate(final_data)
    final_col_ind = np.concatenate(final_col_ind)
    # One row index per retained entry, honoring the actual per-row counts.
    final_row_ind = np.repeat(np.arange(nrows), kept_per_row)

    result = csr_matrix((final_data, (final_row_ind, final_col_ind)), shape=(nrows, dist.shape[1]))

    return result
120
121
122
class pyWNN():

    def __init__(self, adata, reps=['X_pca', 'X_apca'], n_neighbors=20, npcs=[20, 20], seed=14, distances=None):
        """\
        Class for running weighted nearest neighbors analysis as described in Hao
        et al 2021.

        Parameters
        ----------
        adata
            Annotated data object providing ``obsm`` / ``obsp`` / ``uns`` and
            ``copy()``. A copy is stored on the instance; the argument itself
            is not modified here.
        reps
            Keys in ``adata.obsm`` of the two modality embeddings. Exactly two
            are required. The default list is never mutated.
        n_neighbors
            Neighborhood size for the per-modality KNN graphs and for the
            final WNN graph.
        npcs
            Number of leading columns to keep from each embedding.
        seed
            Passed to ``np.random.seed`` (global NumPy state) and recorded
            in the output parameters.
        distances
            Optional list of four ``obsp`` keys: KNN distances for modality
            1 and 2, followed by the wider (200-neighbor when computed here)
            distance matrices for modality 1 and 2. When None they are
            computed with ``sc.pp.neighbors``.
        """

        self.seed = seed
        np.random.seed(seed)

        # Everything below indexes reps[0] and reps[1]; any other count used
        # to slip through (for a single rep) and fail later with IndexError.
        if len(reps) != 2:
            sys.exit('WNN currently only implemented for 2 modalities')

        self.adata = adata.copy()
        self.reps = [r+'_norm' for r in reps]
        self.npcs = npcs
        # Row-normalize the truncated embeddings and store them under '<rep>_norm'.
        for (i,r) in enumerate(reps):
            self.adata.obsm[self.reps[i]] = preprocessing.normalize(adata.obsm[r][:,0:npcs[i]])

        self.n_neighbors = n_neighbors
        if distances is None:
            print('Computing KNN distance matrices using default Scanpy implementation')
            sc.pp.neighbors(self.adata, n_neighbors=n_neighbors, n_pcs=npcs[0], use_rep=self.reps[0], metric='euclidean', key_added='1')
            sc.pp.neighbors(self.adata, n_neighbors=n_neighbors, n_pcs=npcs[1], use_rep=self.reps[1], metric='euclidean', key_added='2')
            sc.pp.neighbors(self.adata, n_neighbors=200, n_pcs=npcs[0], use_rep=self.reps[0], metric='euclidean', key_added='1_200')
            sc.pp.neighbors(self.adata, n_neighbors=200, n_pcs=npcs[1], use_rep=self.reps[1], metric='euclidean', key_added='2_200')
            self.distances = ['1_distances', '2_distances', '1_200_distances', '2_200_distances']
        else:
            self.distances = distances

        # The helper functions read .indices / .indptr / .data directly,
        # so every distance matrix is converted to CSR.
        for d in self.distances:
            if type(self.adata.obsp[d]) is not csr_matrix:
                self.adata.obsp[d] = csr_matrix(self.adata.obsp[d])

        self.NNdist = []        # per modality: distance from each cell to its nearest neighbor
        self.NNidx = []         # per modality: index of each cell's nearest neighbor
        self.NNadjacency = []   # per modality: 0/1 KNN adjacency (no diagonal added)
        self.BWs = []           # per modality: kernel bandwidth per cell

        for (i,r) in enumerate(self.reps):
            nn = get_nearestneighbor(self.adata.obsp[self.distances[i]])
            dist_to_nn = ((self.adata.obsm[r]-self.adata.obsm[r][nn, :])**2).sum(axis=1)**.5
            nn_adj = (self.adata.obsp[self.distances[i]]>0).astype(int)
            # The bandwidth computation uses an adjacency with the diagonal set to 1.
            nn_adj_wdiag = nn_adj.copy()
            nn_adj_wdiag.setdiag(1)
            bw = compute_bw(nn_adj_wdiag, self.adata.obsm[r], n_neighbors=self.n_neighbors)
            self.NNidx.append(nn)
            self.NNdist.append(dist_to_nn)
            self.NNadjacency.append(nn_adj)
            self.BWs.append(bw)

        self.weights = []
        self.WNN = None

    def compute_weights(self):
        """Compute the per-cell modality weights and store them in ``self.weights``.

        For each modality, every cell's embedding is predicted as the sum of
        its neighbors' embeddings divided by ``n_neighbors - 1``, once with
        the modality's own adjacency ("within") and once with the other
        modality's adjacency ("cross"). The prediction distances are turned
        into affinities with ``compute_affinity`` and their ratio
        ``within / (cross + 0.0001)`` is kept. ``weights[0]`` is the logistic
        function of ``ratio[0] - ratio[1]`` and ``weights[1]`` is its
        complement. The prediction distances are also stored in
        ``self.within`` and ``self.cross``.
        """
        cmap = {0:1, 1:0}  # maps a modality index to the other modality
        affinity_ratios = []
        self.within = []
        self.cross = []
        # Start from a clean list so repeated calls do not keep appending
        # additional weight arrays after the first pair.
        self.weights = []
        for (i,r) in enumerate(self.reps):
            within_predict = self.NNadjacency[i].dot(self.adata.obsm[r]) / (self.n_neighbors-1)
            cross_predict = self.NNadjacency[cmap[i]].dot(self.adata.obsm[r]) / (self.n_neighbors-1)

            within_predict_dist = ((self.adata.obsm[r] - within_predict)**2).sum(axis=1)**.5
            cross_predict_dist = ((self.adata.obsm[r] - cross_predict)**2).sum(axis=1)**.5
            within_affinity = compute_affinity(within_predict_dist, self.NNdist[i], self.BWs[i])
            cross_affinity = compute_affinity(cross_predict_dist, self.NNdist[i], self.BWs[i])
            # The small constant keeps the ratio finite when the cross affinity is 0.
            affinity_ratios.append(within_affinity / (cross_affinity + 0.0001))
            self.within.append(within_predict_dist)
            self.cross.append(cross_predict_dist)

        self.weights.append( 1 / (1+ np.exp(affinity_ratios[1]-affinity_ratios[0])) )
        self.weights.append( 1 - self.weights[0] )

    def compute_wnn(self, adata):
        """Build the weighted nearest neighbor graph and write it into `adata`.

        Steps: compute the modality weights; take the union of the two wider
        neighbor graphs (``self.distances[2]`` and ``self.distances[3]``);
        compute nearest-neighbor-shifted distances for both modalities on
        that union; convert them to kernel values, weight each row by the
        cell's modality weight and add the two modalities; finally keep the
        ``n_neighbors`` largest values per row.

        Writes ``obsp['WNN']``, ``obsp['WNN_distance']``, the two normalized
        embeddings in ``obsm`` and ``uns['WNN']`` on the passed `adata`,
        and returns that same object.
        """
        print('Computing modality weights')
        self.compute_weights()
        union_adj_mat = ((self.adata.obsp[self.distances[2]]+self.adata.obsp[self.distances[3]]) > 0).astype(int)

        print('Computing weighted distances for union of 200 nearest neighbors between modalities')
        full_dists = dist_from_adj(union_adj_mat, self.adata.obsm[self.reps[0]], self.adata.obsm[self.reps[1]],
                                   self.NNdist[0], self.NNdist[1])
        weighted_dist = csr_matrix(union_adj_mat.shape)
        for (i,dist) in enumerate(full_dists):
            # Row-scale by -1 / (bandwidth - nearest-neighbor distance), then exponentiate.
            dist = diags(-1 / (self.BWs[i] - self.NNdist[i]), format='csr').dot(dist)
            dist.data = np.exp(dist.data)
            # dist_from_adj stores exact-zero shifted distances as NaN; exp(0) == 1.
            ind = np.isnan(dist.data)
            dist.data[ind] = 1
            # Row-scale by the cell's weight for this modality.
            dist = diags(self.weights[i]).dot(dist)
            weighted_dist += dist

        print('Selecting top K neighbors')
        self.WNN = select_topK(weighted_dist,  n_neighbors=self.n_neighbors)
        # Convert the retained similarities s into sqrt((1 - s) / 2), clipped to [0, 1].
        WNNdist = self.WNN.copy()
        x = (1-WNNdist.data) / 2
        x[x<0]=0
        x[x>1]=1
        WNNdist.data = np.sqrt(x)
        self.WNNdist = WNNdist

        adata.obsp['WNN'] = self.WNN
        adata.obsp['WNN_distance'] = self.WNNdist
        adata.obsm[self.reps[0]] = self.adata.obsm[self.reps[0]]
        adata.obsm[self.reps[1]] = self.adata.obsm[self.reps[1]]
        adata.uns['WNN'] = {'connectivities_key': 'WNN',
                                     'distances_key': 'WNN_distance',
                                     'params': {'n_neighbors': self.n_neighbors,
                                      'method': 'WNN',
                                      'random_state': self.seed,
                                      'metric': 'euclidean',
                                      'use_rep': self.reps[0],
                                      'n_pcs': self.npcs[0]}}
        return(adata)