import argparse
import metric
from sklearn.cluster import KMeans
from sklearn.metrics.cluster import normalized_mutual_info_score, adjusted_rand_score
from sklearn.metrics.cluster import homogeneity_score, adjusted_mutual_info_score
import numpy as np
import random
import sys,os
from scipy.io import loadmat
from sklearn.metrics import confusion_matrix
import pandas as pd
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_style("whitegrid", {'axes.grid' : False})
def plot_embedding(X, labels, classes=None, method='tSNE', cmap='tab20', figsize=(8, 8), markersize=15, dpi=300,marker=None,
                return_emb=False, save=False, save_emb=False, show_legend=True, show_axis_label=True, **legend_params):
    """Scatter-plot an embedding in 2D, colored by per-point labels.

    If X has more than 2 columns it is first reduced with `method`
    ('tSNE', 'PCA' or 'UMAP').

    Parameters
    ----------
    X : array-like, shape (n_points, n_dims)
        Embedding to plot; reduced to 2D when n_dims != 2.
    labels : sequence, length n_points
        Per-point labels used for coloring and the legend.
    classes : sequence or None
        Explicit label ordering; defaults to np.unique(labels).
    method : str
        Dimensionality-reduction method when X is not already 2D.
    cmap : accepted for backward compatibility but unused — colors always
        come from a husl palette. (The original cmap-selection chain was
        dead code: the truthy 'tab20' default made the first branch, a
        no-op self-assignment, unconditionally taken.)
    marker : array-like or None
        Extra points appended to X and drawn as black stars, not colored
        by label.
    return_emb : bool
        If True, return the (possibly reduced) 2D coordinates.  This flag
        previously existed in the signature but was silently ignored.
    save : str or False
        Path to write the figure as a PNG (falsy disables saving).
    save_emb : accepted for API compatibility; not handled here.
        NOTE(review): the original also ignored it — confirm intended use.
    show_legend, show_axis_label : bool
        Toggle the legend / axis labels.
    **legend_params
        Overrides merged into the default plt.legend settings.
    """
    if marker is not None:
        X = np.concatenate([X, marker], axis=0)
    N = len(labels)
    matplotlib.rc('xtick', labelsize=20)
    matplotlib.rc('ytick', labelsize=20)
    matplotlib.rcParams.update({'font.size': 22})
    if X.shape[1] != 2:
        if method == 'tSNE':
            from sklearn.manifold import TSNE
            X = TSNE(n_components=2, random_state=124).fit_transform(X)
        if method == 'PCA':
            from sklearn.decomposition import PCA
            X = PCA(n_components=2, random_state=124).fit_transform(X)
        if method == 'UMAP':
            from umap import UMAP
            X = UMAP(n_neighbors=15, min_dist=0.1, metric='correlation').fit_transform(X)
    labels = np.array(labels)
    plt.figure(figsize=figsize)
    if classes is None:
        classes = np.unique(labels)
    # One distinct husl color per class (s=.8 keeps colors readable).
    colors = sns.husl_palette(len(classes), s=.8)
    for i, c in enumerate(classes):
        # Only the first N rows are labeled data; the tail (if any) is `marker`.
        plt.scatter(X[:N][labels==c, 0], X[:N][labels==c, 1], s=markersize, color=colors[i], label=c)
    if marker is not None:
        plt.scatter(X[N:, 0], X[N:, 1], s=10*markersize, color='black', marker='*')
    legend_params_ = {'loc': 'center left',
                     'bbox_to_anchor':(1.0, 0.45),
                     'fontsize': 20,
                     'ncol': 1,
                     'frameon': False,
                     'markerscale': 1.5
                      }
    legend_params_.update(**legend_params)
    if show_legend:
        plt.legend(**legend_params_)
    sns.despine(offset=10, trim=True)
    if show_axis_label:
        plt.xlabel(method+' dim 1', fontsize=12)
        plt.ylabel(method+' dim 2', fontsize=12)
    if save:
        plt.savefig(save, format='png', bbox_inches='tight',dpi=dpi)
    if return_emb:
        return X
def cluster_eval(labels_true, labels_infer):
    """Report agreement between true and inferred cluster labels.

    Computes purity (via the project-local `metric` module) plus the
    sklearn NMI / ARI / AMI / homogeneity scores, prints a one-line
    summary, and returns (nmi, ari, homogeneity).
    """
    purity = metric.compute_purity(labels_infer, labels_true)
    scorers = (
        normalized_mutual_info_score,
        adjusted_rand_score,
        homogeneity_score,
        adjusted_mutual_info_score,
    )
    nmi, ari, homogeneity, ami = (fn(labels_true, labels_infer) for fn in scorers)
    print('NMI = {}, ARI = {}, Purity = {},AMI = {}, Homogeneity = {}'.format(nmi,ari,purity,ami,homogeneity))
    return nmi, ari, homogeneity
def get_best_epoch(exp_dir, dataset, measurement='NMI'):
    """Return the snapshot filename with the best clustering score.

    Scans every file starting with 'data' under results/<dataset>/<exp_dir>,
    evaluates each saved clustering against the stored true labels, and
    picks the best by `measurement` ('NMI', 'ARI' or 'HOMO').
    """
    base = 'results/%s/%s' % (dataset, exp_dir)
    scored = []
    for fname in os.listdir(base):
        if not fname.startswith('data'):
            continue
        snapshot = np.load('%s/%s' % (base, fname))
        # arr_1: one-hot inferred cluster assignments; arr_2: true labels.
        onehot, truth = snapshot['arr_1'], snapshot['arr_2']
        inferred = np.argmax(onehot, axis=1)
        nmi, ari, homo = cluster_eval(truth, inferred)
        scored.append((fname, nmi, ari, homo))
    # Map the requested measurement to its column in each scored record.
    column = {'NMI': 1, 'ARI': 2, 'HOMO': 3}.get(measurement)
    if column is None:
        print('Wrong indicated metric')
        sys.exit()
    best = max(scored, key=lambda row: row[column])
    print('NMI = {}\tARI = {}\tHomogeneity = {}'.format(best[1], best[2], best[3]))
    return best[0]
def save_embedding(emb_feat, save, sep='\t'):
    """Write a (cells x features) embedding matrix to a delimited text file.

    Rows are named cell1..cellN and columns feat1..featM (1-based), with
    the row names in the first column of the output.
    """
    n_cells, n_feats = emb_feat.shape
    frame = pd.DataFrame(
        emb_feat,
        index=['cell%d' % (i + 1) for i in range(n_cells)],
        columns=['feat%d' % (j + 1) for j in range(n_feats)],
    )
    frame.to_csv(save, sep=sep)
def save_clustering(label, save):
    """Write one 'cell<i>\\t<label>' line per cell to `save` (no trailing newline).

    NOTE(review): cell indices here are 0-based ('cell0'...) while
    save_embedding uses 1-based row names ('cell1'...); preserved as-is
    because downstream files may depend on the existing naming — confirm
    before unifying.
    """
    lines = ['cell%d\t%s' % (i, str(item)) for i, item in enumerate(label)]
    # Use a context manager so the handle is closed even if the write fails
    # (the original left the file open on exception).
    with open(save, 'w') as fh:
        fh.write('\n'.join(lines))
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Simultaneous deep generative modeling and clustering of single cell genomic data')
    parser.add_argument('--data', '-d', type=str, help='which dataset')
    parser.add_argument('--timestamp', '-t', type=str, help='timestamp')
    parser.add_argument('--epoch', '-e', type=int, help='epoch or batch index')
    # BUG FIX: argparse's type=bool returns True for ANY non-empty string,
    # so '--train False' used to mean True.  Parse the text explicitly;
    # '--train True' keeps working and '--train False' now really disables.
    parser.add_argument('--train', type=lambda s: s.lower() in ('true', '1', 'yes'), default=False)
    parser.add_argument('--save', '-s', type=str, help='save latent visualization plot (e.g., t-SNE)')
    parser.add_argument('--no_label', action='store_true',help='whether the dataset has label')
    args = parser.parse_args()
    has_label = not args.no_label
    if has_label:
        if args.train:
            # The run directory is the (first) one whose name begins with the timestamp.
            exp_dir = [item for item in os.listdir('results/%s'%args.data) if item.startswith(args.timestamp)][0]
            if args.epoch is None:
                # No epoch given: pick the snapshot with the best ARI.
                epoch = get_best_epoch(exp_dir,args.data,'ARI')
            else:
                # BUG FIX: args.epoch is an int but a snapshot *filename* is
                # needed below (cf. 'data_at_%s.npz' in the no-label branch);
                # previously the bare int was used as the path component.
                epoch = 'data_at_%s.npz'%args.epoch
            data = np.load('results/%s/%s/%s'%(args.data,exp_dir,epoch))
            # arr_0: latent embedding; arr_1: one-hot inferred clusters.
            embedding, label_infered_onehot = data['arr_0'],data['arr_1']
            embedding_before_softmax = embedding[:,-label_infered_onehot.shape[1]:]
            label_infered = np.argmax(label_infered_onehot, axis=1)
            label_true = [item.strip() for item in open('datasets/%s/label.txt'%args.data).readlines()]
            save_clustering(label_infered,save='results/%s/%s/scDEC_cluster.txt'%(args.data,exp_dir))
            save_embedding(embedding,save='results/%s/%s/scDEC_embedding.csv'%(args.data,exp_dir),sep='\t')
            plot_embedding(embedding,label_true,save='results/%s/%s/scDEC_embedding.png'%(args.data,exp_dir))
        else:
            if args.data == 'PBMC10k':
                # PBMC10k ships per-barcode annotations instead of a label.txt.
                data = np.load('results/%s/data_pre.npz'%args.data)
                embedding, label_infered_onehot = data['arr_0'],data['arr_1']
                embedding_before_softmax = embedding[:,-label_infered_onehot.shape[1]:]
                label_infered = np.argmax(label_infered_onehot, axis=1)
                # barcode -> annotation label; [1:] skips the header line.
                barcode2label = {item.split('\t')[0]:item.split('\t')[1].strip() for item in open('datasets/%s/labels_annot.txt'%args.data).readlines()[1:]}
                barcodes = [item.strip() for item in open('datasets/%s/barcodes.tsv'%args.data).readlines()]
                labels_annot = [barcode2label[item] for i,item in enumerate(barcodes) if item in barcode2label.keys()]
                select_idx = [i for i,item in enumerate(barcodes) if item in barcode2label.keys()]
                embedding = embedding[select_idx,:] # only evaluated on cells with annotation labels
                label_infered = label_infered[select_idx]
                uniq_label = list(np.unique(labels_annot))
                # Encode the annotation strings as integer class ids for scoring.
                Y = np.array([uniq_label.index(item) for item in labels_annot])
                cluster_eval(Y,label_infered)
                save_clustering(label_infered,save='results/%s/scDEC_cluster.txt'%args.data)
                save_embedding(embedding,save='results/%s/scDEC_embedding.csv'%args.data,sep='\t')
                plot_embedding(embedding,labels_annot,save='results/%s/scDEC_embedding.png'%args.data)
            else:
                data = np.load('results/%s/data_pre.npz'%args.data)
                embedding, label_infered_onehot = data['arr_0'],data['arr_1']
                embedding_before_softmax = embedding[:,-label_infered_onehot.shape[1]:]
                label_infered = np.argmax(label_infered_onehot, axis=1)
                label_true = [item.strip() for item in open('datasets/%s/label.txt'%args.data).readlines()]
                save_clustering(label_infered,save='results/%s/scDEC_cluster.txt'%args.data)
                save_embedding(embedding,save='results/%s/scDEC_embedding.csv'%args.data,sep='\t')
                plot_embedding(embedding,label_true,save='results/%s/scDEC_embedding.png'%args.data)
    else:
        # Unlabeled data: an explicit epoch/batch index is required since no
        # metric-based "best snapshot" selection is possible.
        if args.epoch is None:
            print('Provide the epoch or batch index to analyze')
            sys.exit()
        else:
            exp_dir = [item for item in os.listdir('results/%s'%args.data) if item.startswith(args.timestamp)][0]
            data = np.load('results/%s/%s/data_at_%s.npz'%(args.data,exp_dir,args.epoch))
            embedding, label_infered_onehot = data['arr_0'],data['arr_1']
            label_infered = np.argmax(label_infered_onehot, axis=1)
            save_clustering(label_infered,save='results/%s/%s/scDEC_cluster.txt'%(args.data,exp_dir))