|
a |
|
b/eval.py |
|
|
1 |
import argparse |
|
|
2 |
import metric |
|
|
3 |
from sklearn.cluster import KMeans |
|
|
4 |
from sklearn.metrics.cluster import normalized_mutual_info_score, adjusted_rand_score |
|
|
5 |
from sklearn.metrics.cluster import homogeneity_score, adjusted_mutual_info_score |
|
|
6 |
import numpy as np |
|
|
7 |
import random |
|
|
8 |
import sys,os |
|
|
9 |
from scipy.io import loadmat |
|
|
10 |
from sklearn.metrics import confusion_matrix |
|
|
11 |
import pandas as pd |
|
|
12 |
import matplotlib |
|
|
13 |
matplotlib.use('agg') |
|
|
14 |
import matplotlib.pyplot as plt |
|
|
15 |
import seaborn as sns |
|
|
16 |
sns.set_style("whitegrid", {'axes.grid' : False}) |
|
|
17 |
|
|
|
18 |
def plot_embedding(X, labels, classes=None, method='tSNE', cmap='tab20', figsize=(8, 8), markersize=15, dpi=300,marker=None,
                   return_emb=False, save=False, save_emb=False, show_legend=True, show_axis_label=True, **legend_params):
    """Scatter-plot a two-column view of an embedding, one color per class.

    Args:
        X: 2-D array with one row per point. When it does not already have
            exactly two columns it is reduced to two with ``method``.
        labels: one label per point; ``len(labels)`` rows of ``X`` are drawn
            as class-colored points.
        classes: labels to draw, in legend order. Defaults to the sorted
            unique values of ``labels``.
        method: 'tSNE', 'PCA' or 'UMAP'. Also used as the axis-label prefix.
            Any other value leaves ``X`` untouched, so its first two columns
            are what gets plotted.
        cmap: accepted for backward compatibility only; colors are always
            taken from seaborn's husl palette, as before.
        figsize: figure size passed to ``plt.figure``.
        markersize: scatter marker area for regular points (markers rows are
            drawn ten times larger).
        dpi: resolution used when saving the figure.
        marker: optional extra rows appended to ``X`` before the projection
            and drawn as black stars.
        return_emb: when true, return the two-column coordinates that were
            plotted (including the rows coming from ``marker``).
        save: path of the PNG to write; falsy to skip saving.
        save_emb: path to write the plotted coordinates to with
            ``np.savetxt``; falsy to skip.
        show_legend: draw the class legend.
        show_axis_label: label the axes '<method> dim 1/2'.
        **legend_params: overrides for the default ``plt.legend`` arguments.

    Returns:
        The plotted coordinates when ``return_emb`` is true, otherwise None.
    """
    if marker is not None:
        X = np.concatenate([X, marker], axis=0)
    # Rows beyond the first N (if any) are the marker rows appended above.
    N = len(labels)

    # NOTE: these calls change matplotlib's global rc settings.
    matplotlib.rc('xtick', labelsize=20)
    matplotlib.rc('ytick', labelsize=20)
    matplotlib.rcParams.update({'font.size': 22})

    if X.shape[1] != 2:
        # Reducers are imported lazily so only the requested one is required.
        if method == 'tSNE':
            from sklearn.manifold import TSNE
            X = TSNE(n_components=2, random_state=124).fit_transform(X)
        elif method == 'PCA':
            from sklearn.decomposition import PCA
            X = PCA(n_components=2, random_state=124).fit_transform(X)
        elif method == 'UMAP':
            from umap import UMAP
            X = UMAP(n_neighbors=15, min_dist=0.1, metric='correlation').fit_transform(X)

    labels = np.array(labels)
    plt.figure(figsize=figsize)
    if classes is None:
        classes = np.unique(labels)

    # The original code picked a colormap name from `cmap` / the class count
    # but never used it; the husl palette below has always been the one drawn.
    colors = sns.husl_palette(len(classes), s=.8)

    points = X[:N]
    for i, c in enumerate(classes):
        mask = labels == c
        plt.scatter(points[mask, 0], points[mask, 1], s=markersize, color=colors[i], label=c)
    if marker is not None:
        plt.scatter(X[N:, 0], X[N:, 1], s=10*markersize, color='black', marker='*')

    legend_params_ = {'loc': 'center left',
                      'bbox_to_anchor':(1.0, 0.45),
                      'fontsize': 20,
                      'ncol': 1,
                      'frameon': False,
                      'markerscale': 1.5
                      }
    legend_params_.update(**legend_params)
    if show_legend:
        plt.legend(**legend_params_)
    sns.despine(offset=10, trim=True)
    if show_axis_label:
        plt.xlabel(method+' dim 1', fontsize=12)
        plt.ylabel(method+' dim 2', fontsize=12)

    if save:
        plt.savefig(save, format='png', bbox_inches='tight',dpi=dpi)

    # These two flags used to be accepted and silently ignored.
    if save_emb:
        np.savetxt(save_emb, X)
    if return_emb:
        return X
|
|
73 |
|
|
|
74 |
def cluster_eval(labels_true,labels_infer):
    """Score an inferred clustering against reference labels.

    Computes purity (via ``metric.compute_purity``), NMI, ARI, homogeneity
    and AMI, prints all five on one line, and returns three of them.

    Args:
        labels_true: reference label per sample.
        labels_infer: inferred cluster label per sample.

    Returns:
        Tuple ``(nmi, ari, homogeneity)``.
    """
    # Purity comes from the project helper, which takes the inferred labels first.
    purity = metric.compute_purity(labels_infer, labels_true)

    # The sklearn scores all share the (true, inferred) calling convention.
    sklearn_scores = (
        normalized_mutual_info_score,
        adjusted_rand_score,
        homogeneity_score,
        adjusted_mutual_info_score,
    )
    nmi, ari, homogeneity, ami = [score(labels_true, labels_infer) for score in sklearn_scores]

    summary = 'NMI = {}, ARI = {}, Purity = {},AMI = {}, Homogeneity = {}'
    print(summary.format(nmi, ari, purity, ami, homogeneity))
    return nmi, ari, homogeneity
|
|
82 |
|
|
|
83 |
def get_best_epoch(exp_dir, dataset, measurement='NMI'):
    """Return the name of the saved snapshot with the best clustering score.

    Every file in ``results/<dataset>/<exp_dir>`` whose name starts with
    'data' is loaded with ``np.load``; ``arr_1`` is arg-maxed along axis 1 to
    get inferred labels and scored against ``arr_2`` with ``cluster_eval``.

    Args:
        exp_dir: experiment directory name under ``results/<dataset>``.
        dataset: dataset directory name under ``results``.
        measurement: score to rank by: 'NMI', 'ARI' or 'HOMO'.

    Returns:
        The file name (not the full path) of the best-scoring snapshot. On
        ties the first one encountered wins.

    Raises:
        SystemExit: ``measurement`` is not one of the supported names.
        IndexError: the directory holds no file starting with 'data'.
    """
    # Position of each score inside a [name, nmi, ari, homo] result row.
    score_column = {'NMI': 1, 'ARI': 2, 'HOMO': 3}
    # Validate before doing any work: the original only noticed a bad metric
    # after loading and scoring every snapshot, and then exited with status 0.
    if measurement not in score_column:
        print('Wrong indicated metric')
        sys.exit(1)

    run_dir = 'results/%s/%s' % (dataset, exp_dir)
    results = []
    for each in os.listdir(run_dir):
        if not each.startswith('data'):
            continue
        # Index the arrays inside the `with` so the archive is closed promptly.
        with np.load('%s/%s' % (run_dir, each)) as data:
            data_x_onehot_, label_y = data['arr_1'], data['arr_2']
        label_infer = np.argmax(data_x_onehot_, axis=1)
        nmi, ari, homo = cluster_eval(label_y, label_infer)
        results.append([each, nmi, ari, homo])

    if not results:
        # Same exception type the original `results[0]` produced, with a
        # message that says what is actually wrong.
        raise IndexError("no file starting with 'data' found in %s" % run_dir)

    column = score_column[measurement]
    best = max(results, key=lambda row: row[column])
    print('NMI = {}\tARI = {}\tHomogeneity = {}'.format(best[1], best[2], best[3]))
    return best[0]
|
|
104 |
|
|
|
105 |
def save_embedding(emb_feat,save,sep='\t'):
    """Write an embedding matrix to a delimited text file.

    Rows are named 'cell1', 'cell2', ... and columns 'feat1', 'feat2', ...
    (both counted from 1).

    Args:
        emb_feat: 2-D array, one row per cell and one column per feature.
        save: output path handed to ``DataFrame.to_csv``.
        sep: field separator, tab by default.
    """
    n_cells = emb_feat.shape[0]
    row_names = ['cell%d' % k for k in range(1, n_cells + 1)]
    n_feats = emb_feat.shape[1]
    col_names = ['feat%d' % k for k in range(1, n_feats + 1)]
    pd.DataFrame(emb_feat, index=row_names, columns=col_names).to_csv(save, sep=sep)
|
|
110 |
|
|
|
111 |
def save_clustering(label,save):
    """Write one 'cell<i><TAB><label>' line per entry of ``label``.

    Cells are numbered from 0 in iteration order, and the file has no
    trailing newline.

    Args:
        label: iterable of cluster assignments; each is written via ``str``.
        save: output path; an existing file is overwritten.
    """
    # NOTE(review): save_embedding in this file numbers cells from 1, so the
    # ids in the two outputs are offset by one; left as is to keep the
    # existing file format.
    # `with` closes the handle even if formatting or writing raises.
    with open(save, 'w') as f:
        f.write('\n'.join('cell%d\t%s' % (i, str(item)) for i, item in enumerate(label)))
|
|
116 |
|
|
|
117 |
if __name__ == '__main__':
    # Command-line entry point: load a saved scDEC result, then write the
    # inferred clusters and, for labelled datasets, the embedding table and
    # an embedding plot.

    def _str2bool(value):
        """Parse a flag value; only false-like spellings give False.

        ``type=bool`` made every non-empty string (including 'False') True.
        """
        return str(value).strip().lower() not in ('', '0', 'false', 'f', 'no', 'n', 'off')

    def _read_lines(path):
        """Return the stripped lines of a text file, closing it afterwards."""
        with open(path) as handle:
            return [line.strip() for line in handle]

    def _load_prediction(path):
        """Return (embedding, inferred labels) from a saved .npz result."""
        with np.load(path) as data:
            embedding, onehot = data['arr_0'], data['arr_1']
        return embedding, np.argmax(onehot, axis=1)

    def _find_exp_dir(dataset, timestamp):
        """Return the first entry of results/<dataset> starting with timestamp."""
        if timestamp is None:
            sys.exit('Provide the timestamp of the experiment to analyze')
        matches = [item for item in os.listdir('results/%s' % dataset) if item.startswith(timestamp)]
        if not matches:
            sys.exit('No experiment starting with %s found in results/%s' % (timestamp, dataset))
        return matches[0]

    parser = argparse.ArgumentParser(description='Simultaneous deep generative modeling and clustering of single cell genomic data')
    parser.add_argument('--data', '-d', type=str, help='which dataset')
    parser.add_argument('--timestamp', '-t', type=str, help='timestamp')
    parser.add_argument('--epoch', '-e', type=int, help='epoch or batch index')
    parser.add_argument('--train', type=_str2bool, default=False)
    parser.add_argument('--save', '-s', type=str, help='save latent visualization plot (e.g., t-SNE)')
    parser.add_argument('--no_label', action='store_true',help='whether the dataset has label')
    args = parser.parse_args()
    has_label = not args.no_label

    if has_label:
        if args.train:
            exp_dir = _find_exp_dir(args.data, args.timestamp)
            out_dir = 'results/%s/%s' % (args.data, exp_dir)
            if args.epoch is None:
                # Already a file name such as the ones listed in the run directory.
                data_file = get_best_epoch(exp_dir, args.data, 'ARI')
            else:
                # Same naming scheme the unlabeled branch below uses; the bare
                # integer used before did not name a snapshot file.
                data_file = 'data_at_%s.npz' % args.epoch
        else:
            out_dir = 'results/%s' % args.data
            data_file = 'data_pre.npz'

        embedding, label_infered = _load_prediction('%s/%s' % (out_dir, data_file))

        if not args.train and args.data == 'PBMC10k':
            # Annotation file: first line is a header, then barcode<TAB>label.
            with open('datasets/%s/labels_annot.txt' % args.data) as handle:
                annot_lines = handle.readlines()[1:]
            barcode2label = {}
            for line in annot_lines:
                fields = line.split('\t')
                barcode2label[fields[0]] = fields[1].strip()
            barcodes = _read_lines('datasets/%s/barcodes.tsv' % args.data)
            select_idx = [i for i, barcode in enumerate(barcodes) if barcode in barcode2label]
            plot_labels = [barcode2label[barcodes[i]] for i in select_idx]
            embedding = embedding[select_idx,:] # only evaluated on cells with annotation labels
            label_infered = label_infered[select_idx]
            uniq_label = list(np.unique(plot_labels))
            Y = np.array([uniq_label.index(item) for item in plot_labels])
            cluster_eval(Y,label_infered)
        else:
            plot_labels = _read_lines('datasets/%s/label.txt' % args.data)

        save_clustering(label_infered,save='%s/scDEC_cluster.txt' % out_dir)
        save_embedding(embedding,save='%s/scDEC_embedding.csv' % out_dir,sep='\t')
        plot_embedding(embedding,plot_labels,save='%s/scDEC_embedding.png' % out_dir)
    else:
        if args.epoch is None:
            sys.exit('Provide the epoch or batch index to analyze')
        exp_dir = _find_exp_dir(args.data, args.timestamp)
        out_dir = 'results/%s/%s' % (args.data, exp_dir)
        embedding, label_infered = _load_prediction('%s/data_at_%s.npz' % (out_dir, args.epoch))
        save_clustering(label_infered,save='%s/scDEC_cluster.txt' % out_dir)
|
|
179 |
|
|
|
180 |
|
|
|
181 |
|
|
|
182 |
|
|
|
183 |
|