Diff of /eval.py [000000] .. [9e0229]

Switch to unified view

a b/eval.py
1
import argparse
2
import metric
3
from sklearn.cluster import KMeans
4
from sklearn.metrics.cluster import normalized_mutual_info_score, adjusted_rand_score
5
from sklearn.metrics.cluster import homogeneity_score, adjusted_mutual_info_score
6
import numpy as np
7
import random
8
import sys,os
9
from scipy.io import loadmat
10
from sklearn.metrics import confusion_matrix
11
import pandas as pd 
12
import matplotlib
13
matplotlib.use('agg')
14
import matplotlib.pyplot as plt
15
import seaborn as sns
16
sns.set_style("whitegrid", {'axes.grid' : False})
17
18
def plot_embedding(X, labels, classes=None, method='tSNE', cmap='tab20', figsize=(8, 8), markersize=15, dpi=300,marker=None,
                   return_emb=False, save=False, save_emb=False, show_legend=True, show_axis_label=True, **legend_params):
    """Scatter-plot a two-column view of an embedding, one colour per class.

    Parameters
    ----------
    X : array with a ``shape`` attribute, one row per sample.
        If it does not already have exactly two columns it is reduced to two
        with the technique named by ``method``.
    labels : sequence with one entry per row of ``X``.
    classes : optional sequence of label values to draw, in legend order.
        Defaults to ``np.unique(labels)``.
    method : 'tSNE', 'PCA' or 'UMAP'. Also used as the axis-label prefix.
    cmap, return_emb, save_emb : accepted for backward compatibility only.
        This function does not use them; colours always come from
        ``sns.husl_palette``.
    figsize, markersize, dpi : figure size, scatter marker size and the
        resolution used when saving.
    marker : optional array of extra rows, appended to ``X`` before the
        reduction and drawn as large black stars.
    save : falsy to skip saving, otherwise the PNG output path.
    show_legend, show_axis_label : toggle the legend / the axis labels.
    **legend_params : overrides for the default ``plt.legend`` keyword args.

    Raises
    ------
    ValueError
        If ``X`` needs to be reduced and ``method`` is not a supported name.

    Returns nothing; the figure is left open on the current pyplot state.
    """
    if marker is not None:
        # Marker rows are reduced jointly with the data so they share the
        # same two-dimensional space; they are split off again via N below.
        X = np.concatenate([X, marker], axis=0)
    N = len(labels)

    needs_reduction = X.shape[1] != 2
    if needs_reduction and method not in ('tSNE', 'PCA', 'UMAP'):
        # Previously an unknown method fell through every branch and the
        # first two raw columns were plotted under a misleading axis label.
        raise ValueError(
            "Unsupported method %r; expected 'tSNE', 'PCA' or 'UMAP'" % (method,))

    # These calls update matplotlib's global rc settings.
    matplotlib.rc('xtick', labelsize=20)
    matplotlib.rc('ytick', labelsize=20)
    matplotlib.rcParams.update({'font.size': 22})

    if needs_reduction:
        # Reducers are imported lazily so that only the chosen one is needed.
        if method == 'tSNE':
            from sklearn.manifold import TSNE
            X = TSNE(n_components=2, random_state=124).fit_transform(X)
        elif method == 'PCA':
            from sklearn.decomposition import PCA
            X = PCA(n_components=2, random_state=124).fit_transform(X)
        else:
            # No random_state is passed to UMAP here.
            from umap import UMAP
            X = UMAP(n_neighbors=15, min_dist=0.1, metric='correlation').fit_transform(X)

    labels = np.array(labels)
    plt.figure(figsize=figsize)
    if classes is None:
        classes = np.unique(labels)

    # One husl colour per class.
    colors = sns.husl_palette(len(classes), s=.8)
    points = X[:N]
    for i, c in enumerate(classes):
        mask = labels == c
        plt.scatter(points[mask, 0], points[mask, 1], s=markersize, color=colors[i], label=c)
    if marker is not None:
        plt.scatter(X[N:, 0], X[N:, 1], s=10*markersize, color='black', marker='*')

    legend_params_ = {'loc': 'center left',
                      'bbox_to_anchor': (1.0, 0.45),
                      'fontsize': 20,
                      'ncol': 1,
                      'frameon': False,
                      'markerscale': 1.5
                      }
    legend_params_.update(**legend_params)
    if show_legend:
        plt.legend(**legend_params_)
    sns.despine(offset=10, trim=True)
    if show_axis_label:
        plt.xlabel(method+' dim 1', fontsize=12)
        plt.ylabel(method+' dim 2', fontsize=12)

    if save:
        plt.savefig(save, format='png', bbox_inches='tight',dpi=dpi)
73
74
def cluster_eval(labels_true,labels_infer):
    """Score an inferred clustering against reference labels.

    Computes purity (via the project's ``metric.compute_purity``), NMI, ARI,
    homogeneity and AMI, prints all five on one line, and returns the
    ``(nmi, ari, homogeneity)`` triple.
    """
    # Note the argument order: purity takes the inferred labels first.
    purity = metric.compute_purity(labels_infer, labels_true)
    scorers = (
        normalized_mutual_info_score,
        adjusted_rand_score,
        homogeneity_score,
        adjusted_mutual_info_score,
    )
    nmi, ari, homogeneity, ami = (score(labels_true, labels_infer) for score in scorers)
    report = 'NMI = {}, ARI = {}, Purity = {},AMI = {}, Homogeneity = {}'
    print(report.format(nmi, ari, purity, ami, homogeneity))
    return nmi, ari, homogeneity
82
83
def get_best_epoch(exp_dir, dataset, measurement='NMI'):
    """Return the file name of the best-scoring snapshot of an experiment.

    Every entry of ``results/DATASET/EXP_DIR`` whose name starts with
    ``data`` is loaded with ``np.load``; ``arr_1`` is arg-maxed along axis 1
    to get inferred labels and ``arr_2`` is used as the reference labels.
    Each snapshot is scored with ``cluster_eval`` and the one with the highest
    value of ``measurement`` wins (the first one listed wins ties).

    Parameters
    ----------
    exp_dir : name of the experiment directory under ``results/DATASET``.
    dataset : dataset name, i.e. the directory under ``results``.
    measurement : 'NMI', 'ARI' or 'HOMO'.

    Returns
    -------
    The file name (not the full path) of the best snapshot.

    Raises
    ------
    SystemExit
        If ``measurement`` is not one of the supported names. This is checked
        before any file is read and exits with status 1.
    IndexError
        If the directory contains no entry starting with ``data``.
    """
    # Position of each score inside a [name, nmi, ari, homo] result row.
    score_column = {'NMI': 1, 'ARI': 2, 'HOMO': 3}
    if measurement not in score_column:
        # Fail fast instead of scoring every snapshot first.
        print('Wrong indicated metric')
        sys.exit(1)
    column = score_column[measurement]

    run_dir = 'results/%s/%s'%(dataset,exp_dir)
    results = []
    for each in os.listdir(run_dir):
        if not each.startswith('data'):
            continue
        # The context manager closes the archive's file handle after reading.
        with np.load('%s/%s'%(run_dir,each)) as data:
            data_x_onehot_,label_y = data['arr_1'],data['arr_2']
        label_infer = np.argmax(data_x_onehot_, axis=1)
        nmi,ari,homo = cluster_eval(label_y,label_infer)
        results.append([each,nmi,ari,homo])

    if not results:
        # Same exception type the old results[0] lookup raised, with context.
        raise IndexError("no file starting with 'data' found in %s" % run_dir)

    # max() keeps the first maximal row, like the previous stable sort did.
    best = max(results, key=lambda row: row[column])
    print('NMI = {}\tARI = {}\tHomogeneity = {}'.format(best[1],best[2],best[3]))
    return best[0]
104
105
def save_embedding(emb_feat,save,sep='\t'):
    """Write a two-dimensional embedding to a delimited text file.

    Rows are labelled ``cell1 .. cellN`` and columns ``feat1 .. featM``
    (both 1-based), and the table is written with ``DataFrame.to_csv``.

    Parameters
    ----------
    emb_feat : two-dimensional array-like (ndarray, nested list, DataFrame).
    save : output path handed to ``to_csv``.
    sep : field separator, a tab by default.
    """
    # Coerce to an ndarray first: plain lists have no .shape, and passing a
    # DataFrame together with new index/columns would re-index it rather
    # than relabel it, leaving only NaN values.
    emb_feat = np.asarray(emb_feat)
    index = ['cell%d'%(i+1) for i in range(emb_feat.shape[0])]
    columns = ['feat%d'%(i+1) for i in range(emb_feat.shape[1])]
    data_pd = pd.DataFrame(emb_feat,index = index,columns=columns)
    data_pd.to_csv(save,sep=sep)
110
111
def save_clustering(label,save):
    """Write cluster assignments to the text file ``save``.

    One line per entry of ``label``: ``cell`` followed by the 0-based
    position, a tab, and ``str()`` of the entry. Lines are joined with
    newlines and no trailing newline is written.

    NOTE(review): cell names here are 0-based (cell0, cell1, ...) while
    ``save_embedding`` names rows 1-based (cell1, cell2, ...). Kept as is so
    existing outputs do not change; confirm which convention is intended.
    """
    res_list = ['cell%d\t%s'%(i,str(item)) for i,item in enumerate(label)]
    # The context manager closes the handle even if the write fails.
    with open(save,'w') as f:
        f.write('\n'.join(res_list))
116
117
def _str2bool(value):
    """Parse a command-line boolean such as ``True``, ``false``, ``1`` or ``no``."""
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))


def _find_exp_dir(dataset, timestamp):
    """Return the first entry under results/DATASET whose name starts with ``timestamp``."""
    if timestamp is None:
        print('Provide the timestamp of the experiment to analyze')
        sys.exit(1)
    matches = [item for item in os.listdir('results/%s'%dataset) if item.startswith(timestamp)]
    if not matches:
        print('No experiment starting with %s found in results/%s'%(timestamp,dataset))
        sys.exit(1)
    return matches[0]


def _load_prediction(path):
    """Load an npz snapshot and return ``(arr_0, argmax of arr_1 along axis 1)``."""
    with np.load(path) as data:
        embedding, label_infered_onehot = data['arr_0'],data['arr_1']
    return embedding, np.argmax(label_infered_onehot, axis=1)


def _read_lines(path):
    """Return every line of the text file ``path`` with surrounding whitespace stripped."""
    with open(path) as handle:
        return [item.strip() for item in handle]


if __name__ == '__main__':
    # Command-line entry point: load a saved snapshot, write the inferred
    # clusters and, for labelled datasets, the embedding table and its plot.
    parser = argparse.ArgumentParser(description='Simultaneous deep generative modeling and clustering of single cell genomic data')
    parser.add_argument('--data', '-d', type=str, help='which dataset')
    parser.add_argument('--timestamp', '-t', type=str, help='timestamp')
    parser.add_argument('--epoch', '-e', type=int, help='epoch or batch index')
    # type=bool treated every non-empty string (including "False") as True.
    # A bare --train is also accepted and means True.
    parser.add_argument('--train', type=_str2bool, nargs='?', const=True, default=False)
    # --save is parsed but not read anywhere in this block.
    parser.add_argument('--save', '-s', type=str, help='save latent visualization plot (e.g., t-SNE)')
    parser.add_argument('--no_label', action='store_true',help='whether the dataset has label')
    args = parser.parse_args()

    if args.no_label:
        # Without labels no best snapshot can be selected, so an explicit
        # epoch is mandatory and only the cluster assignment is written.
        if args.epoch is None:
            print('Provide the epoch or batch index to analyze')
            sys.exit(1)
        exp_dir = _find_exp_dir(args.data, args.timestamp)
        out_dir = 'results/%s/%s'%(args.data,exp_dir)
        _, label_infered = _load_prediction('%s/data_at_%s.npz'%(out_dir,args.epoch))
        save_clustering(label_infered,save='%s/scDEC_cluster.txt'%out_dir)
    elif args.train:
        exp_dir = _find_exp_dir(args.data, args.timestamp)
        out_dir = 'results/%s/%s'%(args.data,exp_dir)
        if args.epoch is None:
            epoch_file = get_best_epoch(exp_dir,args.data,'ARI')
        else:
            # Use the same snapshot naming as the no-label branch; the bare
            # epoch number alone is not a file name.
            epoch_file = 'data_at_%s.npz'%args.epoch
        embedding, label_infered = _load_prediction('%s/%s'%(out_dir,epoch_file))
        label_true = _read_lines('datasets/%s/label.txt'%args.data)
        save_clustering(label_infered,save='%s/scDEC_cluster.txt'%out_dir)
        save_embedding(embedding,save='%s/scDEC_embedding.csv'%out_dir,sep='\t')
        plot_embedding(embedding,label_true,save='%s/scDEC_embedding.png'%out_dir)
    else:
        out_dir = 'results/%s'%args.data
        embedding, label_infered = _load_prediction('%s/data_pre.npz'%out_dir)
        if args.data == 'PBMC10k':
            # Map barcode -> annotation, skipping the header line.
            with open('datasets/%s/labels_annot.txt'%args.data) as handle:
                next(handle, None)
                barcode2label = {}
                for item in handle:
                    fields = item.split('\t')
                    barcode2label[fields[0]] = fields[1].strip()
            barcodes = _read_lines('datasets/%s/barcodes.tsv'%args.data)
            # only evaluated on cells with annotation labels
            select_idx = [i for i,item in enumerate(barcodes) if item in barcode2label]
            labels_plot = [barcode2label[barcodes[i]] for i in select_idx]
            embedding = embedding[select_idx,:]
            label_infered = label_infered[select_idx]
            # Encode each annotation as its index among the sorted unique values.
            _, Y = np.unique(labels_plot, return_inverse=True)
            cluster_eval(Y,label_infered)
        else:
            labels_plot = _read_lines('datasets/%s/label.txt'%args.data)
        save_clustering(label_infered,save='%s/scDEC_cluster.txt'%out_dir)
        save_embedding(embedding,save='%s/scDEC_embedding.csv'%out_dir,sep='\t')
        plot_embedding(embedding,labels_plot,save='%s/scDEC_embedding.png'%out_dir)
179
                
180
181
182
    
183