Diff of /src/utils.py [000000] .. [ac720d]

Switch to unified view

a b/src/utils.py
1
import os
2
import sys
3
import math
4
import scanpy as sc
5
from scipy import stats, spatial, sparse
6
from scipy.linalg import norm
7
from sklearn.metrics.pairwise import euclidean_distances
8
import numpy as np
9
import random
10
import torch
11
import torch.nn as nn
12
import torch.nn.init as init
13
import torch.utils.data as data
14
from sklearn.neighbors import kneighbors_graph
15
16
def cluster_acc(y_true, y_pred):
    """
    Calculate clustering accuracy under the best one-to-one label matching.

    A contingency table between predicted and true labels is built and the
    assignment problem is solved on it, so that each predicted cluster is
    matched to at most one true class and the number of agreeing samples is
    maximal.

    # Arguments
        y_true: true labels, array-like with shape `(n_samples,)`; values are
            cast to int64 and used as array indices (so they are expected to
            be non-negative).
        y_pred: predicted labels, array-like with shape `(n_samples,)`; cast
            to int64 in the same way.
    # Return
        accuracy, in [0,1]
    # Raises
        ValueError: if the two label arrays differ in size or are empty.
    """
    # scipy is already a dependency of this module; the private
    # sklearn.utils.linear_assignment_ helper used previously is not shipped
    # by newer scikit-learn releases.
    from scipy.optimize import linear_sum_assignment

    y_true = np.asarray(y_true).astype(np.int64)
    y_pred = np.asarray(y_pred).astype(np.int64)
    if y_pred.size != y_true.size:
        raise ValueError(
            "y_true and y_pred must have the same size, got "
            + str(y_true.size) + " and " + str(y_pred.size)
        )
    if y_pred.size == 0:
        raise ValueError("cluster_acc requires at least one sample")

    D = max(y_pred.max(), y_true.max()) + 1
    # w[p, t] = number of samples predicted as cluster p whose true class is t.
    w = np.zeros((D, D), dtype=np.int64)
    np.add.at(w, (y_pred, y_true), 1)

    # The solver minimises cost, so flip counts into costs.
    row_ind, col_ind = linear_sum_assignment(w.max() - w)
    return w[row_ind, col_ind].sum() * 1.0 / y_pred.size
34
35
def GetCluster(X, res, n):
    """
    Estimate the number of clusters in X with Louvain community detection.

    X is wrapped in an AnnData object, a neighbour graph with `n` neighbours
    is built directly on X (`use_rep="X"`), and Louvain clustering is run at
    resolution `res`.

    # Arguments
        X: data matrix accepted by `sc.AnnData`.
        res: resolution passed to `sc.tl.louvain`.
        n: number of neighbours passed to `sc.pp.neighbors`.
    # Return
        number of distinct Louvain clusters found (int)
    # Raises
        SystemExit: if only a single cluster is detected.
    """
    adata0 = sc.AnnData(X)
    n_obs = adata0.shape[0]
    if n_obs > 200000:
        # Subsample very large inputs to 200000 observations.
        # NOTE(review): this reseeds NumPy's *global* RNG (seed = number of
        # observations) as a side effect; kept as-is so the subsample and any
        # downstream random draws stay identical to previous runs.
        np.random.seed(n_obs)
        adata0 = adata0[np.random.choice(n_obs, 200000, replace=False)]
    sc.pp.neighbors(adata0, n_neighbors=n, use_rep="X")
    sc.tl.louvain(adata0, resolution=res)
    Y_pred_init = np.asarray(adata0.obs['louvain'], dtype=int)
    n_clusters = np.unique(Y_pred_init).shape[0]
    if n_clusters <= 1:
        # avoid only a cluster
        # sys.exit raises the same SystemExit as the interactive-only `exit`
        # helper but does not depend on the `site` module being loaded.
        sys.exit("Error: There is only one cluster detected. The resolution: "
                 + str(res) + " is too small, please choose a larger resolution!!")
    print("Estimated n_clusters is: ", n_clusters)
    return n_clusters
50
51
def torch_PCA(X, k, center=True, scale=False):
    """
    Compute the leading `k` principal axes of X by eigendecomposition.

    # Arguments
        X: 2-D torch tensor of shape `(p, n)`; it is transposed internally so
            that the `n` columns are treated as observations and the `p` rows
            as variables.
        k: number of components to return.
        center: if True, subtract the per-variable mean before computing the
            covariance.
        scale: if True, left-multiply the covariance by
            `diag(1 / sqrt(variance))` before the eigendecomposition.
            NOTE(review): this is not the symmetric correlation matrix
            `D C D`; kept unchanged to preserve existing results — confirm
            whether that is intended.
    # Return
        float64 tensor of shape `(p, min(k, p))` on the same device as X,
        whose columns are the eigenvectors ordered by decreasing eigenvalue.
    # Raises
        ValueError: if X holds fewer than two observations.
    """
    # Everything below derives from X, so the computation runs on whatever
    # device X lives on instead of requiring CUDA.
    X = X.t().double()
    n, p = X.size()
    if n < 2:
        raise ValueError("torch_PCA needs at least two observations, got " + str(n))

    # Subtracting the column means equals multiplying by the n x n centering
    # matrix (I - 11^T / n) but needs O(n * p) instead of O(n^2) memory.
    X_center = X - X.mean(dim=0, keepdim=True) if center else X
    covariance = torch.mm(X_center.t(), X_center) / (n - 1)

    if scale:
        scaling = torch.sqrt(1 / torch.diag(covariance))
        scaled_covariance = torch.mm(torch.diag(scaling), covariance)
    else:
        scaled_covariance = covariance

    if hasattr(torch, "linalg") and hasattr(torch.linalg, "eig"):
        # torch.eig is gone from current PyTorch; torch.linalg.eig returns
        # complex tensors, of which only the real parts are kept.
        eigenvalues, eigenvectors = torch.linalg.eig(scaled_covariance)
        eigenvalues, eigenvectors = eigenvalues.real, eigenvectors.real
    else:
        # Older PyTorch: eigenvalues come back as (p, 2) real/imag pairs.
        eigenvalues, eigenvectors = torch.eig(scaled_covariance, True)
        eigenvalues = eigenvalues[:, 0]

    # The general eigensolver gives no ordering guarantee, so sort explicitly
    # before taking the first k columns.
    order = torch.argsort(eigenvalues, descending=True)
    components = eigenvectors[:, order[:k]]
    #explained_variance = eigenvalues[order[:k]]
    return components
65
    
66
def best_map(L1, L2):
    """
    Relabel clustering labels L2 so that they best agree with L1.

    A contingency table between the distinct labels of L1 and L2 is built and
    the assignment problem is solved on it; every cluster label in L2 is then
    replaced by the L1 label it was matched to.

    # Arguments
        L1: ground-truth labels, array-like.
        L2: clustering labels, array-like with the same number of elements.
    # Return
        float numpy array with the shape of L2 holding the remapped labels.
    # Raises
        ValueError: if L1 and L2 differ in size, or if L2 contains more
            distinct labels than L1 (some cluster would have no ground-truth
            label to map to).
    """
    # scipy is already a dependency of this module; the Munkres class used
    # previously was never imported here.
    from scipy.optimize import linear_sum_assignment

    L1 = np.asarray(L1)
    L2 = np.asarray(L2)
    Label1, inv1 = np.unique(L1.ravel(), return_inverse=True)
    Label2, inv2 = np.unique(L2.ravel(), return_inverse=True)
    inv1 = np.asarray(inv1).ravel()
    inv2 = np.asarray(inv2).ravel()
    if inv1.size != inv2.size:
        raise ValueError("L1 and L2 must contain the same number of labels")

    nClass1 = len(Label1)
    nClass2 = len(Label2)
    if nClass2 > nClass1:
        raise ValueError(
            "L2 has " + str(nClass2) + " distinct labels but L1 only has "
            + str(nClass1) + "; cannot map every cluster to a ground-truth label"
        )

    # G[i, j] = number of samples with ground-truth label i and cluster label j,
    # padded with zeros to a square matrix.
    nClass = max(nClass1, nClass2)
    G = np.zeros((nClass, nClass))
    np.add.at(G, (inv1, inv2), 1)

    # Rows of -G.T are clusters, columns are ground-truth classes; minimising
    # the negated counts maximises the total overlap. For a square matrix the
    # returned rows are 0..nClass-1 in order, so c[j] is the class index
    # matched to cluster j.
    _, c = linear_sum_assignment(-G.T)

    newL2 = np.zeros(L2.shape)
    newL2[...] = Label1[c[inv2]].reshape(L2.shape)
    return newL2
89
 
90
def _select_above_curve(zeroRate, meanExpr, decay, xoffset, yoffset):
    """Flag entries whose zero rate lies above exp(-decay*(mean - xoffset)) + yoffset."""
    nonan = ~np.isnan(zeroRate)
    selected = np.zeros(zeroRate.shape, dtype=bool)
    selected[nonan] = zeroRate[nonan] > np.exp(-decay*(meanExpr[nonan] - xoffset)) + yoffset
    return selected


def _plot_gene_selection(meanExpr, zeroRate, selected, threshold, decay, xoffset, yoffset,
                         markers, genes, figsize, markeroffsets, labelsize, alpha):
    """Draw the zero-rate vs. mean-expression scatter with the selection curve."""
    # Imported lazily: plotting is optional and this module has no file-level
    # import of either package.
    import matplotlib.pyplot as plt
    from matplotlib.patches import Polygon
    import seaborn as sns

    curve_color = sns.color_palette()[1]

    if figsize is not None:
        plt.figure(figsize=figsize)
    plt.ylim([0, 1])
    if threshold > 0:
        plt.xlim([np.log2(threshold), np.ceil(np.nanmax(meanExpr))])
    else:
        plt.xlim([0, np.ceil(np.nanmax(meanExpr))])
    x = np.arange(plt.xlim()[0], plt.xlim()[1]+.1, .1)
    y = np.exp(-decay*(x - xoffset)) + yoffset
    if decay == 1:
        plt.text(.4, 0.2, '{} genes selected\ny = exp(-x+{:.2f})+{:.2f}'.format(np.sum(selected), xoffset, yoffset),
                 color='k', fontsize=labelsize, transform=plt.gca().transAxes)
    else:
        plt.text(.4, 0.2, '{} genes selected\ny = exp(-{:.1f}*(x-{:.2f}))+{:.2f}'.format(np.sum(selected), decay, xoffset, yoffset),
                 color='k', fontsize=labelsize, transform=plt.gca().transAxes)

    plt.plot(x, y, color=curve_color, linewidth=2)
    # Shade the region above the curve (closed off at the top-right corner).
    xy = np.concatenate((np.concatenate((x[:, None], y[:, None]), axis=1), np.array([[plt.xlim()[1], 1]])))
    plt.gca().add_patch(Polygon(xy, color=curve_color, alpha=.4))

    plt.scatter(meanExpr, zeroRate, s=1, alpha=alpha, rasterized=True)
    if threshold == 0:
        plt.xlabel('Mean log2 nonzero expression')
        plt.ylabel('Frequency of zero expression')
    else:
        plt.xlabel('Mean log2 nonzero expression')
        plt.ylabel('Frequency of near-zero expression')
    plt.tight_layout()

    if markers is not None and genes is not None:
        if markeroffsets is None:
            markeroffsets = [(0, 0) for g in markers]
        genes = np.asarray(genes)
        for num, g in enumerate(markers):
            dx, dy = markeroffsets[num]
            # One point/label per match; markers absent from `genes` are
            # skipped, and scalar coordinates are passed to plt.text.
            for i in np.where(genes == g)[0]:
                plt.scatter(meanExpr[i], zeroRate[i], s=10, color='k')
                plt.text(meanExpr[i]+dx+.1, zeroRate[i]+dy, g, color='k', fontsize=labelsize)


def geneSelection(data, threshold=0, atleast=10,
                  yoffset=.02, xoffset=5, decay=1.5, n=None,
                  plot=True, markers=None, genes=None, figsize=(6,3.5),
                  markeroffsets=None, labelsize=10, alpha=1, verbose=1):
    """
    Select columns (genes, judging by the plot labels) by zero rate vs. mean expression.

    For every column the fraction of entries not exceeding `threshold`
    ("zero rate") and the mean log2 of the entries exceeding it are computed.
    A column is selected when its zero rate lies above the curve
    `exp(-decay * (mean - xoffset)) + yoffset`.

    # Arguments
        data: dense 2-D numpy array or scipy sparse matrix; statistics are
            taken along axis 0.
        threshold: entries must be strictly greater than this to count as
            expressed.
        atleast: columns with fewer than this many expressed entries are
            never selected.
        yoffset, xoffset, decay: parameters of the selection curve.
        n: if given, `xoffset` is adjusted by bisection (at most 100 steps,
            between 0 and 10) until exactly `n` columns are selected or the
            steps run out.
        plot: if True, draw the scatter plot and curve (needs matplotlib and
            seaborn).
        markers, genes, markeroffsets: optional gene names to highlight, the
            per-column gene names, and per-marker `(dx, dy)` text offsets.
        figsize, labelsize, alpha: plot appearance.
        verbose: if > 0 and `n` is given, print the chosen offset.
    # Return
        boolean numpy array with one entry per column, True where selected
    """
    expressed = data > threshold
    if sparse.issparse(data):
        # ravel (not squeeze) keeps a 1-D result even for a single column.
        zeroRate = 1 - np.asarray(expressed.mean(axis=0)).ravel()
        A = data.multiply(expressed)
        A.data = np.log2(A.data)
        meanExpr = np.full(zeroRate.shape, np.nan)
        detected = zeroRate < 1
        # The column mean runs over all rows, so divide by the expressed
        # fraction to get the mean over expressed entries only.
        meanExpr[detected] = np.asarray(A[:, detected].mean(axis=0)).ravel() / (1-zeroRate[detected])
    else:
        zeroRate = 1 - np.mean(expressed, axis=0)
        meanExpr = np.full(zeroRate.shape, np.nan)
        detected = zeroRate < 1
        sub = data[:, detected]
        mask = sub > threshold
        logs = np.full(sub.shape, np.nan)
        logs[mask] = np.log2(sub[mask])
        meanExpr[detected] = np.nanmean(logs, axis=0)

    # Columns detected in too few rows are excluded from selection.
    lowDetection = np.asarray(np.sum(expressed, axis=0)).ravel() < atleast
    zeroRate[lowDetection] = np.nan
    meanExpr[lowDetection] = np.nan

    if n is not None:
        up = 10
        low = 0
        for t in range(100):
            selected = _select_above_curve(zeroRate, meanExpr, decay, xoffset, yoffset)
            n_selected = np.sum(selected)
            if n_selected == n:
                break
            elif n_selected < n:
                # Too few selected: shift the curve left.
                up = xoffset
                xoffset = (xoffset + low)/2
            else:
                # Too many selected: shift the curve right.
                low = xoffset
                xoffset = (xoffset + up)/2
        if verbose > 0:
            print('Chosen offset: {:.2f}'.format(xoffset))
    else:
        selected = _select_above_curve(zeroRate, meanExpr, decay, xoffset, yoffset)

    if plot:
        _plot_gene_selection(meanExpr, zeroRate, selected, threshold, decay, xoffset, yoffset,
                             markers, genes, figsize, markeroffsets, labelsize, alpha)

    return selected