Diff of /src/scotv2.py [000000] .. [090c8c]

Switch to unified view

a b/src/scotv2.py
1
"""
2
Author: Pinar Demetci
3
Principal Investigator: Ritambhara Singh, Ph.D. from Brown University
4
08 August 2021
5
Updated: 23 February 2023
6
SCOTv2 algorithm: Single Cell alignment using Optimal Transport version 2
7
Correspondence: pinar_demetci@brown.edu, ritambhara@brown.edu
8
"""
9
10
### Import python packages we depend on:
11
import numpy as np
12
import torch
13
import ot
14
import scipy
15
# For computing graph distances:
16
from scipy.sparse.csgraph import dijkstra
17
from scipy.sparse import csr_matrix
18
from sklearn.neighbors import kneighbors_graph
19
20
# For pre-processing, normalization
21
from sklearn.preprocessing import StandardScaler, normalize
22
23
24
class SCOTv2(object):
    """
    SCOTv2: unsupervised alignment of single-cell multi-omic datasets using
    (unbalanced) Gromov-Wasserstein optimal transport.
    https://www.biorxiv.org/content/10.1101/2020.04.28.066787v2 (original preprint)
    https://www.liebertpub.com/doi/full/10.1089/cmb.2021.0446 (Journal of Computational Biology publication through RECOMB 2021 conference)

    Input: a list holding two or more numpy arrays/matrices, where the rows
    correspond to samples (cells) and the columns correspond to features.
    The first dataset in the list is the anchor: every other dataset is
    coupled to it.
    Returns: a list with one aligned array per input dataset, in input order.

    Example use:
    # Given numpy matrices domain1 and domain2 (rows are cells, columns are features):
    scot = SCOTv2([domain1, domain2])
    aligned_domain1, aligned_domain2 = scot.align(k=20, eps=0.01, rho=1.0)

    Parameters of .align():
    - normalize: whether to normalize the input data ahead of alignment. Default=True.
    - norm: which normalization to run: "l2", "l1", "max" or "zscore". Default="l2".
    - bySample: normalize each sample (row) if True/None, each feature (column) otherwise. Default=True.
    - k: number of neighbors used when constructing the kNN graphs. Default=20.
    - mode: "connectivity" (0/1 adjacency) or "distance" (distance-weighted adjacency). Default="connectivity".
    - metric: metric used for the nearest-neighbor graphs, e.g. "correlation" or "minkowski". Default="correlation".
    - eps: coefficient of the entropic regularization term. Default=0.01.
    - rho, rho2: coefficients of the KL marginal-relaxation terms; rho2 defaults to rho.
    - projMethod: "embedding" (co-embed in a new shared space) or "barycentric"
      (project every dataset onto the first one). Default="embedding".
    - Lambda, out_dim: trade-off weight and output dimensionality of the co-embedding.

    Note: to use non-uniform marginal distributions, set the attribute `marginals`
    to a list with one distribution per dataset (each as long as the number of
    samples in that dataset) after creating the instance and before running
    .align(). Marginals that do not match the datasets are replaced by uniform ones.
    """

    def __init__(self, data):
        assert isinstance(data, list) and len(data) >= 2, "As input, SCOTv2 requires a list, containing at least two numpy arrays to be aligned.  \
                Each numpy array/matrix corresponds to a dataset, with samples (cells) in rows and features (latent representations or genomic features) in columns. \
                We recommend using latent representations (e.g. principal components for RNA-seq and topics - via cisTopic- for ATAC-seq/Methyl-seq)."
        # Shallow copy: _normalize() rebinds entries of self.data, which must
        # not overwrite the entries of the list owned by the caller.
        self.data = list(data)
        self.marginals = []  # Holds the empirical probability distributions over samples in each dataset
        self.graphs = []  # Holds graphs per dataset
        self.graphDists = []  # Holds intra-domain graph distances for each input dataset
        self.couplings = []  # Holds coupling matrices (anchor dataset x other dataset)
        self.gwdists = []  # Gromov-Wasserstein distances between domains after alignment
        self.flags = []  # Holds alignment convergence flags (booleans: True/False)

        self.aligned_data = []  # Output of the most recent .align() call
        self.integrated_data = []  # Same object as aligned_data (historical name)

    def _init_marginals(self):
        """Set a uniform distribution over the samples of every dataset.

        Any previously stored marginals are discarded so that repeated calls
        do not accumulate stale entries.

        Returns: the list of marginals (one 1-D torch tensor per dataset).
        """
        # Without any prior information, we set the probabilities to what we
        # observe empirically: uniform over all observed samples.
        self.marginals = []
        for dataset in self.data:
            num_cells = dataset.shape[0]
            self.marginals.append(torch.ones(num_cells) / num_cells)
        return self.marginals

    def _marginals_match_data(self):
        """Return True when there is exactly one marginal per dataset and each
        marginal has one entry per sample of its dataset."""
        if len(self.marginals) != len(self.data):
            return False
        return all(
            len(marginal) == np.shape(dataset)[0]
            for marginal, dataset in zip(self.marginals, self.data)
        )

    def _normalize(self, norm="l2", bySample=True):
        """Normalize every dataset and store the result back in self.data.

        - norm: "l1", "l2", "max" (sklearn `normalize`) or "zscore" (sklearn `StandardScaler`).
        - bySample: for the non-zscore norms, normalize rows (True/None) or columns (otherwise).

        Returns: the list of normalized datasets.
        """
        assert (norm in ["l1","l2","max", "zscore"]), "Norm argument has to be either one of 'max', 'l1', 'l2' or 'zscore'.\
         If you would like to perform another type of normalization, please give SCOT the normalized data and set the argument 'normalize=False' when running the algorithm. \
         We have found l2 normalization to empirically perform better with single-cell sequencing datasets, including when using latent representations. "

        for i in range(len(self.data)):
            if norm == "zscore":
                self.data[i] = StandardScaler().fit_transform(self.data[i])
            else:
                axis = 1 if (bySample is None or bySample == True) else 0  # noqa: E712
                self.data[i] = normalize(self.data[i], norm=norm, axis=axis)
        return self.data  # Normalized data

    def construct_graph(self, k=20, mode="connectivity", metric="correlation"):
        """Build one k-nearest-neighbor graph per dataset (stored in self.graphs).

        Graphs from a previous call are discarded, so re-running with another
        k/mode/metric never mixes old and new graphs.

        Returns: the list of sparse kNN graphs.
        """
        assert (mode in ["connectivity", "distance"]), "Mode argument has to be either one of 'connectivity', or 'distance'. "
        # Self-loops are only included for the 0/1 connectivity graph.
        include_self = (mode == "connectivity")

        self.graphs = []
        for dataset in self.data:
            self.graphs.append(kneighbors_graph(dataset, n_neighbors=k, mode=mode, metric=metric, include_self=include_self))

        return self.graphs

    def init_graph_distances(self):
        """Compute normalized shortest-path distance matrices from self.graphs.

        Unreachable pairs (infinite distance) are assigned the largest finite
        distance of that graph, then each matrix is divided by its maximum so
        values lie in [0, 1]. Results replace any previous self.graphDists.

        Returns: the list of dense distance matrices (one per dataset).
        """
        self.graphDists = []
        for i in range(len(self.data)):
            # Compute shortest distances
            shortestPath = dijkstra(csgraph=csr_matrix(self.graphs[i]), directed=False, return_predecessors=False)
            # Deal with unconnected stuff (infinities):
            max_dist = shortestPath[np.isfinite(shortestPath)].max()
            shortestPath[shortestPath > max_dist] = max_dist
            # Finally, normalize the distance matrix. A graph whose largest
            # finite distance is 0 is left as all zeros instead of dividing by 0.
            if max_dist > 0:
                shortestPath = shortestPath / max_dist
            self.graphDists.append(shortestPath)

        return self.graphDists

    def _exp_sinkhorn_solver(self, ecost, u, v, a, b, mass, eps, rho, rho2, nits_sinkhorn, tol_sinkhorn):
        """
        Unbalanced Sinkhorn iterations in exponential (scaling) form.

        Parameters
        ----------
        - ecost: torch.Tensor of size [size_X, size_Y]
                 Exponential kernel generated from the local cost based on the current coupling.
        - u: torch.Tensor of size [size_X] or None.
             First scaling vector (defined on X); None starts from all ones.
        - v: torch.Tensor of size [size_Y] or None.
             Second scaling vector (defined on Y); None starts from all ones.
        - a, b: torch.Tensor of size [size_X] and [size_Y].
                Marginal weights of the two datasets.
        - mass: torch.Tensor of size [1].
                Mass of the current coupling.
        - eps: float. Entropic regularization coefficient.
        - rho, rho2: float. Marginal relaxation coefficients.
        - nits_sinkhorn: int.
                         Maximum number of iterations to update Sinkhorn scalings in inner loop.
        - tol_sinkhorn: float
                        Tolerance on convergence of Sinkhorn scalings.

        Returns
        ----------
        u: torch.Tensor of size [size_X]
           First scaling vector after the updates
        v: torch.Tensor of size [size_Y]
           Second scaling vector after the updates
        pi: torch.Tensor of size [size_X, size_Y]
            Transport plan u_i * v_j * ecost_ij * a_i * b_j (NOT in log-space).
        """
        # Initialize scalings by finding best translation
        # NOTE(review): this translation uses `rho` for both u and v, while the
        # updates below pair `rho` with v and `rho2` with u. The distinction
        # vanishes when rho2 == rho (the default) -- confirm the intended
        # pairing before relying on rho2 != rho.
        if u is None or v is None:
            u, v = torch.ones_like(a), torch.ones_like(b)
        k = (a * u ** (-eps / rho)).sum() + (b * v ** (-eps / rho)).sum()
        k = k / (2 * (u[:, None] * v[None, :] * ecost * a[:, None] * b[None, :]).sum())
        z = (0.5 * mass * eps) / (2.0 + 0.5 * (eps / rho) + 0.5 * (eps / rho2))
        k = k ** z
        u, v = u * k, v * k

        # Perform the Sinkhorn updates
        for _ in range(nits_sinkhorn):
            u_prev = u.clone()
            v = torch.einsum("ij,i->j", ecost, a * u) ** (-1.0 / (1.0 + eps / rho))
            u = torch.einsum("ij,j->i", ecost, b * v) ** (-1.0 / (1.0 + eps / rho2))
            err = (u.log() - u_prev.log()).abs().max().item() * eps
            # A NaN error can never drop below the tolerance; stop right away
            # instead of spinning through every remaining iteration.
            if np.isnan(err) or err < tol_sinkhorn:
                break
        pi = u[:, None] * v[None, :] * ecost * a[:, None] * b[None, :]
        return u, v, pi

    def exp_unbalanced_gw(self, a, dx, b, dy, eps=0.01, rho=1.0, rho2=None, nits_plan=3000, tol_plan=1e-6, nits_sinkhorn=3000, tol_sinkhorn=1e-6):
        """Solve entropic unbalanced Gromov-Wasserstein between two domains.

        - a, b: torch tensors with the marginal weights of domain X and Y.
        - dx, dy: torch tensors with the intra-domain distance matrices.
        - eps: entropic regularization coefficient.
        - rho, rho2: KL marginal relaxation coefficients (rho2 defaults to rho).
        - nits_plan, tol_plan: iteration cap / tolerance of the outer loop.
        - nits_sinkhorn, tol_sinkhorn: iteration cap / tolerance of the inner Sinkhorn loop.

        Returns: (pi, flag) where pi is the coupling of size [size_X, size_Y]
        and flag is False when the solver produced NaN values, True otherwise.
        """
        if rho2 is None:
            rho2 = rho  # KL divergence coefficient doesn't have to be the same for both couplings.
                        # But, to keep #hyperparameters low, we default to using the same coefficient.
                        # Someone else playing with our code could assign a rho2 different than rho, though.

        # Initialize the coupling and local costs
        pi = a[:, None] * b[None, :] / (a.sum() * b.sum()).sqrt()
        up, vp = None, None
        flag = True  # Defined up front so it exists even when nits_plan == 0

        for i in range(nits_plan):
            pi_prev = pi.clone()
            mp = pi.sum()

            # Compute the current local cost:
            distxy = torch.einsum("ij,kj->ik", dx, torch.einsum("kl,jl->kj", dy, pi))
            kl_pi = torch.sum(pi * (pi / (a[:, None] * b[None, :]) + 1e-10).log())
            mu, nu = torch.sum(pi, dim=1), torch.sum(pi, dim=0)
            distxx = torch.einsum("ij,j->i", dx ** 2, mu)
            distyy = torch.einsum("kl,l->k", dy ** 2, nu)
            lcost = (distxx[:, None] + distyy[None, :] - 2 * distxy) + eps * kl_pi
            if rho < float("Inf"):
                lcost = (lcost + rho * torch.sum(mu * (mu / a + 1e-10).log()))
            if rho2 < float("Inf"):
                lcost = (lcost + rho2 * torch.sum(nu * (nu / b + 1e-10).log()))
            ecost = (-lcost / (mp * eps)).exp()

            if (i % 10) == 0:
                print("Unbalanced GW step:", i)
            # Compute the coupling via sinkhorn
            up, vp, pi = self._exp_sinkhorn_solver(ecost, up, vp, a, b, mp, eps, rho, rho2, nits_sinkhorn, tol_sinkhorn)

            if torch.any(torch.isnan(pi)):
                # NaNs propagate into every later iterate, so further
                # iterations cannot recover; report failure immediately.
                flag = False
                break

            pi = (mp / pi.sum()).sqrt() * pi
            if (pi - pi_prev).abs().max().item() < tol_plan:
                break
        return pi, flag

    def find_correspondences(self, normalize=True, norm="l2", bySample=True, k=20, mode="connectivity", metric="correlation", eps=0.01, rho=1.0, rho2=None):
        """Couple every dataset to the first one with unbalanced GW.

        Couplings and flags from a previous run are discarded first. Marginals
        already stored in self.marginals are kept when they match the datasets
        (one per dataset, one entry per sample); otherwise uniform marginals
        are created.

        Returns: self.couplings, where couplings[i] couples data[0] (rows)
        with data[i+1] (columns).
        Raises: Exception when the solver returns a plan containing NaN.
        """
        # Normalize
        if normalize:
            self._normalize(norm=norm, bySample=bySample)
        # Initialize inputs for (unbalanced) Gromov-Wasserstein optimal transport:
        if not self._marginals_match_data():
            self._init_marginals()
        print("computing intra-domain graph distances")
        self.construct_graph(k=k, mode=mode, metric=metric)
        self.init_graph_distances()
        # Run pairwise dataset alignments:
        self.couplings = []
        self.flags = []
        for i in range(len(self.data) - 1):
            print("running pairwise dataset alignments")
            a, b = torch.Tensor(self.marginals[0]), torch.Tensor(self.marginals[i + 1])
            dx, dy = torch.Tensor(self.graphDists[0]), torch.Tensor(self.graphDists[i + 1])
            coupling, flag = self.exp_unbalanced_gw(a, dx, b, dy, eps=eps, rho=rho, rho2=rho2, nits_plan=3000, tol_plan=1e-6, nits_sinkhorn=3000, tol_sinkhorn=1e-6)
            self.couplings.append(coupling)
            self.flags.append(flag)
            if not flag:
                raise Exception(
                    f"Solver got NaN plan with params (eps, rho, rho2) "
                    f" = {eps, rho, rho2}. Try increasing argument eps")
        return self.couplings

    def barycentric_projection(self):
        """Project every non-anchor dataset onto the first dataset.

        Each sample of data[i+1] becomes the coupling-weighted average of the
        samples of data[0].

        Returns: [data[0], projected data[1], projected data[2], ...].
        """
        aligned_datasets = [self.data[0]]
        for plan in self.couplings:
            coupling = np.transpose(plan.numpy())  # rows now index the projected dataset
            weights = np.sum(coupling, axis=1)
            projected_data = np.matmul((coupling / weights[:, None]), self.data[0])
            aligned_datasets.append(projected_data)
        return aligned_datasets

    @staticmethod
    def _as_float_array(matrix):
        """Return `matrix` (torch tensor or array-like) as a float64 numpy array."""
        if torch.is_tensor(matrix):
            matrix = matrix.detach().cpu().numpy()
        return np.asarray(matrix, dtype=float)

    @staticmethod
    def _inverse_sqrt(sym_matrix):
        """Return S^(-1/2) for a symmetric positive semi-definite matrix S."""
        # eigh is the symmetric eigensolver: real eigenvalues and orthonormal
        # eigenvectors, which Q diag(.) Q^T below relies on.
        eigvals, eigvecs = np.linalg.eigh(sym_matrix)
        eigvals = eigvals + 1e-12  # keep zero eigenvalues invertible
        return np.dot(eigvecs * eigvals ** (-0.5), np.transpose(eigvecs))

    def _graph_laplacian(self, graph):
        """Return the dense Laplacian D - W of one (sparse) kNN graph.

        The graph is symmetrized by taking the element-wise maximum of the
        graph and its transpose; positive weights are then inverted.
        """
        transposed = graph.T
        transposed_is_larger = transposed > graph
        symmetric = graph + transposed.multiply(transposed_is_larger) - graph.multiply(transposed_is_larger)
        W = np.array(symmetric.todense(), dtype=float)
        positive = np.where(W > 0)
        W[positive] = 1 / W[positive]
        D = np.diag(W.sum(axis=1))
        return D - W

    def coembed_datasets(self, Lambda=1.0, out_dim=10):
        """
        Co-embeds datasets in a shared space.
        Implementation is based on Cao et al 2022 (Pamona), arranged so that
        the first dataset is the anchor shared by all couplings, matching the
        couplings produced by find_correspondences (data[0] x data[i+1]).

        - Lambda: weight of the coupling (correspondence) term relative to the
          graph Laplacian (structure-preserving) term.
        - out_dim: number of embedding dimensions to keep.

        self.couplings is left unmodified, so the method can be called
        repeatedly.

        Returns: list with one embedded array per dataset, in input order.
        """
        # Local import: block_diag is not among the module-level imports.
        from scipy.linalg import block_diag

        n_datasets = len(self.data)
        n_anchor = np.shape(self.data[0])[0]

        # Work on rescaled copies so the stored couplings stay untouched.
        scaled = [self._as_float_array(plan) * n_anchor for plan in self.couplings[:n_datasets - 1]]
        laplacians = [self._graph_laplacian(graph) for graph in self.graphs[:n_datasets]]

        # Anchor block: its Laplacian plus the mass it sends to every other dataset.
        S_xx = laplacians[0] + Lambda * np.diag(sum(plan.sum(axis=1) for plan in scaled))
        # One block per non-anchor dataset: Laplacian plus the mass it receives.
        S_yy = block_diag(*[
            laplacians[i + 1] + Lambda * np.diag(plan.sum(axis=0))
            for i, plan in enumerate(scaled)
        ])
        # Couplings side by side: anchor samples in rows, all other samples in columns.
        S_xy = np.hstack(scaled)

        H_x = self._inverse_sqrt(S_xx)
        H_y = self._inverse_sqrt(S_yy)

        H = np.dot(H_x, np.dot(S_xy, H_y))
        U, sigma, Vt = np.linalg.svd(H)
        U, V = U[:, :out_dim], np.transpose(Vt)[:, :out_dim]

        fx = np.dot(H_x, U)  # embedding of the anchor dataset
        fy = np.dot(H_y, V)  # stacked embeddings of all other datasets

        integrated_data = [fx]
        start = 0
        for plan in scaled:
            stop = start + plan.shape[1]
            integrated_data.append(fy[start:stop])
            start = stop

        return integrated_data

    def align(self, normalize=True, norm="l2", bySample=True, k=20, mode="connectivity", metric="correlation", eps=0.01, rho=1.0, rho2=None, projMethod="embedding", Lambda=1.0, out_dim=10):
        """Run the full pipeline: couple the datasets, then bring them into a
        common space with either a co-embedding or a barycentric projection.

        See the class docstring for the meaning of the parameters.

        Returns: list with one aligned array per dataset, in input order
        (also stored in self.integrated_data and self.aligned_data).
        """
        assert projMethod in ["embedding", "barycentric"], "The input to the parameter 'projMethod' needs to be one of \
                                'embedding' (if co-embedding them in a new shared space) or 'barycentric' (if using barycentric projection)"
        self.find_correspondences(normalize=normalize, norm=norm, bySample=bySample, k=k, mode=mode, metric=metric, eps=eps, rho=rho, rho2=rho2)
        print("FLAGS", self.flags)
        if projMethod == "embedding":
            integrated_data = self.coembed_datasets(Lambda=Lambda, out_dim=out_dim)
        else:
            integrated_data = self.barycentric_projection()
        self.integrated_data = integrated_data
        self.aligned_data = integrated_data
        return integrated_data
317
    
318
# X=np.load("../data/SNARE/SNAREseq_atac_feat.npy")[0:1000,:]
319
# Y=np.load("../data/SNARE/SNAREseq_rna_feat.npy")
320
# print(X.shape, Y.shape)
321
# SCOT=SCOTv2([Y,X])
322
# aligned_datasets=SCOT.align(normalize=True, k=50, eps=0.005, rho=0.1, projMethod="barycentric")
323
# print(len(aligned_datasets))
324
# print(aligned_datasets[0].shape)
325
# print(aligned_datasets[1].shape)
326
# # np.save("aligned_atac.npy", aligned_datasets[1])
327
# np.save("aligned_rna.npy", aligned_datasets[0])
328
329