a b/src/scotv1.py
1
"""
2
Authors: Pinar Demetci, Rebecca Santorella
3
Principal Investigator: Ritambhara Singh, Ph.D. from Brown University
4
12 February 2020
5
Updated: 27 November 2020
6
SCOT algorithm (version 1): Single Cell alignment using Optimal Transport
7
Correspondence: pinar_demetci@brown.edu, rebecca_santorella@brown.edu, ritambhara@brown.edu
8
"""
9
10
### Import python packages we depend on:
11
# For regular matrix operations:
12
import numpy as np
13
# For optimal transport operations:
14
import ot
15
# For computing graph distances:
16
from scipy.sparse.csgraph import dijkstra
17
from scipy.sparse import csr_matrix
18
from sklearn.neighbors import kneighbors_graph
19
20
# For pre-processing, normalization
21
from sklearn.preprocessing import StandardScaler, normalize
22
23
24
class SCOT(object):
25
    """
26
    SCOT algorithm for unsupervised alignment of single-cell multi-omic data.
27
    https://www.biorxiv.org/content/10.1101/2020.04.28.066787v2 (original preprint)
28
    https://www.liebertpub.com/doi/full/10.1089/cmb.2021.0446 (Journal of Computational Biology publication through RECOMB 2021 conference)
29
    Input: domain1, domain2 in form of numpy arrays/matrices, where the rows correspond to samples and columns correspond to features.
30
    Returns: aligned domain 1, aligned domain 2 in form of numpy arrays/matrices. By default (XontoY=True) domain 1 is projected onto domain 2; with XontoY=False domain 2 is projected onto domain 1.
31
    Example use:
32
    # Given two numpy matrices, domain1 and domain2, where the rows are cells and columns are different genomic features:
33
    scot= SCOT(domain1, domain2)
34
    aligned_domain1, aligned_domain2 = scot.align(k=20, e=1e-3)
35
    #If you can't pick the parameters k and e, you can try out our unsupervised self-tuning heuristic by running:
36
    scot= SCOT(domain1, domain2)
37
    aligned_domain1, aligned_domain2 = scot.align(selfTune=True)
38
    Required parameters:
39
    - k: Number of neighbors to be used when constructing kNN graphs. Default= min(0.2*n_1, 0.2*n_2, 50) with the fractions rounded down, where n_i, for i=1,2, corresponds to the number of samples in the i^th domain.
40
    - e: Regularization constant for the entropic regularization term in entropic Gromov-Wasserstein optimal transport formulation. Default= 1e-3 
41
   
42
    Optional parameters:
43
    - normalize= Determines whether to normalize input data ahead of alignment. True or False (boolean parameter). Default = True.
44
    - norm= Determines what sort of normalization to run, "l2", "l1", "max", "zscore". Default="l2" 
45
    - mode: "connectivity" or "distance". Determines whether to use a connectivity graph (adjacency matrix of 1s/0s based on whether nodes are connected) or a distance graph (adjacency matrix entries weighted by distances between nodes). Default="connectivity"  
46
    - metric: Sets the metric to use while constructing nearest neighbor graphs. Some possible choices are "correlation" and "minkowski". "correlation" is Pearson's correlation and "minkowski" is equivalent to Euclidean distance in its default form (p=2). Default= "correlation". 
47
    - verbose: Prints loss while optimizing the optimal transport formulation. Default=True
48
    - XontoY: Determines the direction of barycentric projection. True or False (boolean parameter). If True, projects domain1 onto domain2. If False, projects domain2 onto domain1. Default=True.
49
    Note: If you want to specify the marginal distributions of the input domains and not use uniform distribution, please set the attributes p and q to the distributions of your choice (for domain 1, and 2, respectively) 
50
            after initializing a SCOT class instance and before running alignment and set init_marginals=False in .align() parameters
51
    """
52
53
    def __init__(self, domain1, domain2):
54
55
        self.X=domain1
56
        self.y=domain2
57
58
        self.p= None #empirical probability distribution for domain 1 (X)
59
        self.q= None #empirical probability distribution for domain 2 (y)
60
61
        self.Cx=None #intra-domain graph distances for domain 1 (X)
62
        self.Cy=None #intra-domain graph distances for domain 2 (y)
63
64
        self.coupling=None # Coupling matrix that relates domain 1 and domain 2, ..., m
65
        self.gwdist=None # Gromov-Wasserstein distance between domains after alignment
66
        self.flag = None # convergence flag
67
68
        self.X_aligned=None #aligned datasets to return: domain1
69
        self.y_aligned=None #aligned datasets to return: domain2
70
71
    def init_marginals(self):
        """
        Initialise the marginals self.p and self.q from the number of samples
        (rows) in self.X and self.y.
        """
        # With no prior information every observed sample gets the same weight:
        # a uniform distribution over the samples of each domain.
        n_samples_x = self.X.shape[0]
        self.p = ot.unif(n_samples_x)
        n_samples_y = self.y.shape[0]
        self.q = ot.unif(n_samples_y)
75
76
    def normalize(self, norm="l2", bySample=True):
77
        assert (norm in ["l1","l2","max", "zscore"]), "Norm argument has to be either one of 'max', 'l1', 'l2' or 'zscore'. If you would like to perform another type of normalization, please give SCOT the normalize data and set the argument normalize=False when running the algorithm."
78
79
        if (bySample==True or bySample==None):
80
            axis=1
81
        else:
82
            axis=0
83
84
        if norm=="zscore":
85
            scaler=StandardScaler()
86
            self.X, self.y=scaler.fit_transform(self.X), scaler.fit_transform(self.y)
87
88
        else:
89
            self.X, self.y =normalize(self.X, norm=norm, axis=axis), normalize(self.y, norm=norm, axis=axis)
90
91
    def construct_graph(self, k, mode= "connectivity", metric="correlation"):
92
        assert (mode in ["connectivity", "distance"]), "Norm argument has to be either one of 'connectivity', or 'distance'. "
93
        if mode=="connectivity":
94
            include_self=True
95
        else:
96
            include_self=False
97
98
        self.Xgraph=kneighbors_graph(self.X, k, mode=mode, metric=metric, include_self=include_self)
99
        self.ygraph=kneighbors_graph(self.y, k, mode=mode, metric=metric, include_self=include_self)
100
101
        return self.Xgraph, self.ygraph
102
103
    def init_distances(self):
104
        # Compute shortest distances
105
        X_shortestPath=dijkstra(csgraph= csr_matrix(self.Xgraph), directed=False, return_predecessors=False)
106
        y_shortestPath=dijkstra(csgraph= csr_matrix(self.ygraph), directed=False, return_predecessors=False)
107
108
        # Deal with unconnected stuff (infinities):
109
        X_max=np.nanmax(X_shortestPath[X_shortestPath != np.inf])
110
        y_max=np.nanmax(y_shortestPath[y_shortestPath != np.inf])
111
        X_shortestPath[X_shortestPath > X_max] = X_max
112
        y_shortestPath[y_shortestPath > y_max] = y_max
113
114
        # Finnally, normalize the distance matrix:
115
        self.Cx=X_shortestPath/X_shortestPath.max()
116
        self.Cy=y_shortestPath/y_shortestPath.max()
117
118
        return self.Cx, self.Cy
119
120
    def find_correspondences(self, e, verbose=True):
        """
        Run entropic Gromov-Wasserstein optimal transport on (self.Cx, self.Cy)
        with marginals (self.p, self.q) and record the outcome.

        e: regularization constant, passed as epsilon.
        verbose: forwarded to the solver.

        Sets self.coupling, self.gwdist and self.flag. self.flag is False when
        the coupling contains NaN, has an all-zero row or column, or its entries
        sum to less than 0.95; otherwise it is True.

        Returns: self.gwdist
        """
        coupling, solver_log = ot.gromov.entropic_gromov_wasserstein(self.Cx, self.Cy, self.p, self.q, loss_fun='square_loss', epsilon=e, log=True, verbose=verbose)
        self.coupling = coupling
        self.gwdist = solver_log['gw_dist']

        # Convergence check: any one of these symptoms marks the run as failed
        broken = (
            np.isnan(self.coupling).any()
            or np.any(~self.coupling.any(axis=1))   # some row is entirely zero
            or np.any(~self.coupling.any(axis=0))   # some column is entirely zero
            or sum(sum(self.coupling)) < .95        # too little total mass
        )
        self.flag = not broken

        return self.gwdist
131
132
    def barycentric_projection(self, XontoY=True):
133
        if XontoY:
134
            #Projecting the first domain onto the second domain
135
            self.y_aligned=self.y
136
            weights=np.sum(self.coupling, axis = 1)
137
            self.X_aligned=np.matmul(self.coupling, self.y) / weights[:, None]
138
        else:
139
            #Projecting the second domain onto the first domain
140
            self.X_aligned=self.X
141
            weights=np.sum(self.coupling, axis = 0)
142
            self.y_aligned=np.matmul(np.transpose(self.coupling), self.X) / weights[:, None]
143
        return self.X_aligned, self.y_aligned
144
145
    def align(self, k=None, e=1e-3, mode="connectivity", metric="correlation", verbose=True, normalize=True, norm="l2", XontoY=True, selfTune=False, init_marginals=True):
        """
        Run the full alignment pipeline.

        k: number of neighbors for the kNN graphs. If None, defaults to
           min(20% of the samples in X, 20% of the samples in y, 50).
        e: regularization constant passed to find_correspondences.
        mode, metric: forwarded to construct_graph.
        verbose: forwarded to find_correspondences.
        normalize, norm: if normalize is True, self.normalize(norm=norm) is run
           first; this overwrites self.X and self.y.
        XontoY: direction of the barycentric projection.
        selfTune: if True, k and e are ignored and self.unsupervised_scot()
           chooses the hyperparameters instead.
        init_marginals: if True, self.init_marginals() sets self.p and self.q;
           pass False to keep marginals assigned by the caller.

        Returns: (self.X_aligned, self.y_aligned), or None (after printing a
        message) when the optimization did not converge on the manual path.
        """
        if normalize:
            self.normalize(norm=norm)
        if init_marginals:
            self.init_marginals()

        if selfTune:
            X_aligned, y_aligned= self.unsupervised_scot()
        else:
            if k is None:
                # The three candidates must be separate arguments to min(); wrapping
                # the first two in a tuple compares a tuple with an int and raises
                # TypeError on Python 3.
                k = min(int(self.X.shape[0]*0.2), int(self.y.shape[0]*0.2), 50)

            # Forward the caller's graph settings rather than fixed defaults.
            self.construct_graph(k, mode=mode, metric=metric)
            self.init_distances()
            self.find_correspondences(e=e, verbose=verbose)

            if not self.flag:
                print("CONVERGENCE ERROR: Optimization procedure runs into numerical errors with the hyperparameters specified. Please try aligning with higher values of epsilon.")
                return

            X_aligned, y_aligned = self.barycentric_projection(XontoY=XontoY)

        self.X_aligned, self.y_aligned=X_aligned, y_aligned
        return self.X_aligned, self.y_aligned
170
171
    def search_scot(self, ks, es, all_values = False,  mode= "connectivity", metric="correlation", normalize=True, norm="l2", init_marginals=True):
        """
        Performs a hyperparameter sweep for given values of k and epsilon.

        ks, es: values of k and epsilon to try; every (k, e) pair is run.
        all_values: selects what is returned (see below).
        mode, metric: forwarded to construct_graph.
        normalize, norm: if normalize is True, self.normalize(norm=norm) is run first.
        init_marginals: if True, self.init_marginals() is run first.

        Returns:
          all_values=False: X_aligned, y_aligned, gmin, k_best, e_best for the
              converged run with the lowest GW distance. Only distances below 1
              are accepted; if none qualifies the alignments and best
              parameters are None and gmin is 1.
          all_values=True: X_aligned, y_aligned, gw_sweep, k_sweep, e_sweep.
              The three arrays have len(ks)*len(es) entries; converged runs are
              stored from index 0 in the order they were run and unused
              trailing entries stay 0.
        """
        # initialize alignment
        if normalize:
            self.normalize(norm=norm)
        if init_marginals:
            self.init_marginals()

        # Note to self: Incorporate multiprocessing here to speed things up
        # store values of k, epsilon, and gw distance
        total=len(es)*len(ks)
        k_sweep=np.zeros(total)
        e_sweep=np.zeros(total)
        gw_sweep=np.zeros(total)

        gmin = 1
        run = 0      # number of (k, e) pairs attempted so far, for the progress print
        stored = 0   # next free slot in the sweep arrays (0-based)

        X_aligned,y_aligned=None, None
        e_best,k_best=None, None
        # search in k first to reduce graph computation
        for k in ks:
            self.construct_graph(k, mode= mode, metric=metric)
            self.init_distances()
            for e in es:
                run += 1
                print(run, "/", total)
                print("Aligning k: ",k, " and e: ",e)
                # run alignment / optimize correspondence matrix:
                self.find_correspondences(e=e, verbose=False)

                # runs that did not converge are neither recorded nor considered
                if not self.flag:
                    continue

                # save values; indexing starts at 0 so slot 0 is used and a
                # fully converged sweep cannot write past the end of the arrays
                if all_values:
                    k_sweep[stored]=k
                    e_sweep[stored]=e
                    gw_sweep[stored] = self.gwdist
                    stored += 1

                    print(self.gwdist)
                # save the alignment if it is lower
                if self.gwdist < gmin:
                    X_aligned, y_aligned = self.barycentric_projection()
                    gmin =self.gwdist
                    e_best, k_best= e, k

        if all_values:
            # return alignment and all values
            return X_aligned, y_aligned, gw_sweep, k_sweep, e_sweep
        else:
            # return  alignment and the parameters corresponding to the lowest GW distance
            return X_aligned, y_aligned, gmin, k_best, e_best
226
227
228
    def unsupervised_scot(self, normalize=False, norm='l2'):
229
        '''
230
        Unsupervised hyperparameter tuning algorithm to find an alignment
231
        by using the GW distance as a measure of alignment
232
        '''
233
234
        # use k = 20% of # sample or k = 50 if dataset is large
235
        n = min(self.X.shape[0], self.y.shape[0])
236
        k_start = min(n // 5, 50)
237
238
        num_eps = 12
239
        num_k = 5
240
241
        # define search space
242
        es = np.logspace(-1, -3, num_eps)
243
        if ( n > 250):
244
            ks = np.linspace(20, 100, num_k)
245
        else:
246
            ks = np.linspace(n//20, n//6, num_k)
247
        ks = ks.astype(int)
248
        
249
        # search parameter space
250
        X_aligned, y_aligned, g_best, k_best, e_best = self.search_scot(ks, es, all_values=False, normalize=normalize, norm=norm, init_marginals=False)
251
252
        print("Alignment completed. Hyperparameters selected from the unsupervised hyperparameter sweep are: %d for number of neighbors k and %f for epsilon" %(k_best, e_best))
253
254
        return X_aligned,