--- /dev/null
+++ b/main_clustering.py
@@ -0,0 +1,381 @@
+from __future__ import division
+import os,sys
+import time
+import dateutil.tz
+import datetime
+import argparse
+import importlib
+import tensorflow as tf
+import numpy as np
+import random
+import copy
+import math
+import util
+import metric
+from sklearn.cluster import KMeans
+from sklearn.metrics.cluster import normalized_mutual_info_score, adjusted_rand_score,homogeneity_score
+import pandas as pd 
+
+
+tf.reset_default_graph()
+tf.set_random_seed(0)
+
+
+'''
+Instructions: scDEC model
+    x,y - data drawn from the base density (e.g., Gaussian) and the observation data
+    x_onehot - data drawn from the categorical distribution
+    y_  - generated data where y_ = G(x,x_onehot)
+    x_latent_,x_onehot_  -  embedding and inferred clustering label where x_latent_, x_onehot_ = H(y)
+    y__ - reconstructed data, y__ = G(H(y))
+    x__ - reconstructed latent code, x__ = H(G(x))
+    G(.)  - generator network mapping the latent space to the data space
+    H(.)  - generator network mapping the data space to the latent space (embedding) and clustering, simultaneously
+    Dx(.) - discriminator network in x space (latent space)
+    Dy(.) - discriminator network in y space (observation space)
+'''
+class scDEC(object):
+    def __init__(self, g_net, h_net, dx_net, dy_net, x_sampler, y_sampler, nb_classes, data, pool, batch_size, alpha, beta, is_train):
+        self.data = data
+        self.g_net = g_net
+        self.h_net = h_net
+        self.dx_net = dx_net
+        self.dy_net = dy_net
+        self.x_sampler = x_sampler
+        self.y_sampler = y_sampler
+        self.nb_classes = nb_classes
+        self.batch_size = batch_size
+        self.alpha = alpha
+        self.beta = beta
+        self.pool = pool
+        self.x_dim = self.dx_net.input_dim
+        self.y_dim = self.dy_net.input_dim
+
+
+        self.x = tf.placeholder(tf.float32, [None, self.x_dim], name='x')
+        self.x_onehot = tf.placeholder(tf.float32, [None, self.nb_classes], name='x_onehot')
+        self.x_combine = tf.concat([self.x,self.x_onehot],axis=1,name='x_combine')
+
+        self.y = tf.placeholder(tf.float32, [None, self.y_dim], name='y')
+
+        self.y_ = self.g_net(self.x_combine,reuse=False)
+
+        self.x_latent_, self.x_onehot_ = self.h_net(self.y,reuse=False) # x_latent_ = [continuous embedding, pre-softmax logits]; x_onehot_ = softmax over the logits
+        self.x_ = self.x_latent_[:,:self.x_dim]
+        self.x_logits_ = self.x_latent_[:,self.x_dim:]
+        
+        self.x_latent__, self.x_onehot__ = self.h_net(self.y_)
+        self.x__ = self.x_latent__[:,:self.x_dim]
+        self.x_logits__ = self.x_latent__[:,self.x_dim:]
+
+        self.x_combine_ = tf.concat([self.x_, self.x_onehot_],axis=1)
+        self.y__ = self.g_net(self.x_combine_)
+
+        self.dy_ = self.dy_net(self.y_, reuse=False)
+        self.dx_ = self.dx_net(self.x_, reuse=False)
+
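+        # Round-trip (cycle-consistency) reconstruction losses in the latent space and the data space.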
+        self.l2_loss_x = tf.reduce_mean((self.x - self.x__)**2)
+        self.l2_loss_y = tf.reduce_mean((self.y - self.y__)**2)
+
+        #self.CE_loss_x = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=self.x_onehot, logits=self.x_logits__))
+        self.CE_loss_x = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=self.x_logits__,labels=self.x_onehot))
+        
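+        # Adversarial (Wasserstein-style) generator losses: push the critics' scores on generated samples up.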
+        self.g_loss_adv = -tf.reduce_mean(self.dy_)
+        self.h_loss_adv = -tf.reduce_mean(self.dx_)
+        
+
+        self.g_loss = self.g_loss_adv + self.alpha*self.l2_loss_x + self.beta*self.l2_loss_y
+        self.h_loss = self.h_loss_adv + self.alpha*self.l2_loss_x + self.beta*self.l2_loss_y
+        self.g_h_loss = self.g_loss_adv + self.h_loss_adv + self.alpha*(self.l2_loss_x + self.l2_loss_y) + self.beta*self.CE_loss_x
+       
+        self.dx = self.dx_net(self.x)
+        self.dy = self.dy_net(self.y)
+
+        self.dx_loss = -tf.reduce_mean(self.dx) + tf.reduce_mean(self.dx_)
+        self.dy_loss = -tf.reduce_mean(self.dy) + tf.reduce_mean(self.dy_)
+
+        #gradient penalty for x
+        epsilon_x = tf.random_uniform([], 0.0, 1.0)
+        x_hat = epsilon_x * self.x + (1 - epsilon_x) * self.x_
+        dx_hat = self.dx_net(x_hat)
+        grad_x = tf.gradients(dx_hat, x_hat)[0] #(bs,x_dim)
+        grad_norm_x = tf.sqrt(tf.reduce_sum(tf.square(grad_x), axis=1))#(bs,)
+        self.gpx_loss = tf.reduce_mean(tf.square(grad_norm_x - 1.0))
+
+        #gradient penalty for y
+        epsilon_y = tf.random_uniform([], 0.0, 1.0)
+        y_hat = epsilon_y * self.y + (1 - epsilon_y) * self.y_
+        dy_hat = self.dy_net(y_hat)
+        grad_y = tf.gradients(dy_hat, y_hat)[0] #(bs,y_dim)
+        grad_norm_y = tf.sqrt(tf.reduce_sum(tf.square(grad_y), axis=1))#(bs,)
+        self.gpy_loss = tf.reduce_mean(tf.square(grad_norm_y - 1.0))
+
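+        # Total critic objective: Wasserstein losses plus gradient penalties weighted by 10, as in WGAN-GP.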
+        self.d_loss = self.dx_loss + self.dy_loss + 10*(self.gpx_loss + self.gpy_loss)
+
+        self.lr = tf.placeholder(tf.float32, None, name='learning_rate')
+        self.g_h_optim = tf.train.AdamOptimizer(learning_rate=self.lr, beta1=0.5, beta2=0.9) \
+                .minimize(self.g_h_loss, var_list=self.g_net.vars+self.h_net.vars)
+        #self.d_optim = tf.train.GradientDescentOptimizer(learning_rate=self.lr) \
+        #        .minimize(self.d_loss, var_list=self.dx_net.vars+self.dy_net.vars)
+        self.d_optim = tf.train.AdamOptimizer(learning_rate=self.lr, beta1=0.5, beta2=0.9) \
+                .minimize(self.d_loss, var_list=self.dx_net.vars+self.dy_net.vars)
+
+        now = datetime.datetime.now(dateutil.tz.tzlocal())
+        self.timestamp = now.strftime('%Y%m%d_%H%M%S')
+
+        self.g_loss_adv_summary = tf.summary.scalar('g_loss_adv',self.g_loss_adv)
+        self.h_loss_adv_summary = tf.summary.scalar('h_loss_adv',self.h_loss_adv)
+        self.l2_loss_x_summary = tf.summary.scalar('l2_loss_x',self.l2_loss_x)
+        self.l2_loss_y_summary = tf.summary.scalar('l2_loss_y',self.l2_loss_y)
+        self.dx_loss_summary = tf.summary.scalar('dx_loss',self.dx_loss)
+        self.dy_loss_summary = tf.summary.scalar('dy_loss',self.dy_loss)
+        self.gpx_loss_summary = tf.summary.scalar('gpx_loss',self.gpx_loss)
+        self.gpy_loss_summary = tf.summary.scalar('gpy_loss',self.gpy_loss)
+        self.g_merged_summary = tf.summary.merge([self.g_loss_adv_summary, self.h_loss_adv_summary,\
+            self.l2_loss_x_summary,self.l2_loss_y_summary,self.gpx_loss_summary,self.gpy_loss_summary])
+        self.d_merged_summary = tf.summary.merge([self.dx_loss_summary,self.dy_loss_summary])
+
+        #graph path for tensorboard visualization
+        self.graph_dir = 'graph/{}/{}_x_dim={}_y_dim={}_alpha={}_beta={}_ratio={}'.format(self.data,self.timestamp,self.x_dim, self.y_dim, self.alpha, self.beta, ratio)
+        if not os.path.exists(self.graph_dir) and is_train:
+            os.makedirs(self.graph_dir)
+        
+        #save path for saving predicted data
+        self.save_dir = 'results/{}/{}_x_dim={}_y_dim={}_alpha={}_beta={}_ratio={}'.format(self.data,self.timestamp,self.x_dim, self.y_dim, self.alpha, self.beta, ratio)
+        if not os.path.exists(self.save_dir) and is_train:
+            os.makedirs(self.save_dir)
+
+        self.saver = tf.train.Saver(max_to_keep=5000)
+
+        #run_config = tf.ConfigProto(intra_op_parallelism_threads=1,inter_op_parallelism_threads=1)
+        run_config = tf.ConfigProto()
+        run_config.gpu_options.per_process_gpu_memory_fraction = 1.0
+        run_config.gpu_options.allow_growth = True
+
+        self.sess = tf.Session(config=run_config)
+
+
+    def train(self, nb_batches):
+        data_y = self.y_sampler.load_all()[0] if has_label else self.y_sampler.load_all()
+        self.sess.run(tf.global_variables_initializer())
+        self.summary_writer=tf.summary.FileWriter(self.graph_dir,graph=tf.get_default_graph())
+        batches_per_eval = 100
+        start_time = time.time()
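+        # Cluster proportions for the categorical sampler start uniform and are re-estimated every batches_per_eval batches.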
+        weights = np.ones(self.nb_classes, dtype=np.float64) / float(self.nb_classes)
+        last_weights = np.ones(self.nb_classes, dtype=np.float64) / float(self.nb_classes)
+        diff_history=[]
+        for batch_idx in range(nb_batches):
+            lr = 2e-4
+            # update the critics Dx and Dy (five critic steps per generator step)
+            for _ in range(5):
+                bx, bx_onehot = self.x_sampler.train(self.batch_size,weights)
+                by = random.sample(data_y,self.batch_size)
+                d_summary,_ = self.sess.run([self.d_merged_summary, self.d_optim], feed_dict={self.x: bx, self.x_onehot: bx_onehot, self.y: by, self.lr:lr})
+            self.summary_writer.add_summary(d_summary,batch_idx)
+
+            bx, bx_onehot = self.x_sampler.train(self.batch_size,weights)
+            by = random.sample(data_y,self.batch_size)
+
+            # update the generators G and H jointly
+            g_summary, _ = self.sess.run([self.g_merged_summary ,self.g_h_optim], feed_dict={self.x: bx, self.x_onehot: bx_onehot, self.y: by, self.lr:lr})
+            self.summary_writer.add_summary(g_summary,batch_idx)
+            #quick test on a random batch data
+            if batch_idx % batches_per_eval == 0:
+                g_loss_adv, h_loss_adv, CE_loss, l2_loss_x, l2_loss_y, g_loss, \
+                    h_loss, g_h_loss, gpx_loss, gpy_loss = self.sess.run(
+                    [self.g_loss_adv, self.h_loss_adv, self.CE_loss_x, self.l2_loss_x, self.l2_loss_y, \
+                    self.g_loss, self.h_loss, self.g_h_loss, self.gpx_loss, self.gpy_loss],
+                    feed_dict={self.x: bx, self.x_onehot: bx_onehot, self.y: by}
+                )
+                dx_loss, dy_loss, d_loss = self.sess.run([self.dx_loss, self.dy_loss, self.d_loss], \
+                    feed_dict={self.x: bx, self.x_onehot: bx_onehot, self.y: by})
+
+                print('Batch_idx [%d] Time [%.4f] g_loss_adv [%.4f] h_loss_adv [%.4f] CE_loss [%.4f] gpx_loss [%.4f] gpy_loss [%.4f] '
+                    'l2_loss_x [%.4f] l2_loss_y [%.4f] g_loss [%.4f] h_loss [%.4f] g_h_loss [%.4f] dx_loss [%.4f] dy_loss [%.4f] d_loss [%.4f]' %
+                    (batch_idx, time.time() - start_time, g_loss_adv, h_loss_adv, CE_loss, gpx_loss, gpy_loss, l2_loss_x, l2_loss_y,
+                    g_loss, h_loss, g_h_loss, dx_loss, dy_loss, d_loss))
+
+            if (batch_idx+1) % batches_per_eval == 0:
+                self.evaluate(timestamp,batch_idx)
+                self.save(batch_idx)
+
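+                # Blend the newly estimated cluster proportions with the previous ones (EMA with mixing parameter ratio) and renormalize.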
+                tol = 0.02
+                estimated_weights = self.estimate_weights(use_kmeans=False)
+                weights = ratio*weights + (1-ratio)*estimated_weights
+                weights = weights/np.sum(weights)
+                diff_weights = np.mean(np.abs(last_weights-weights))
+                diff_history.append(diff_weights)
+                if np.min(weights)<tol:
+                    weights = self.adjust_tiny_weights(weights,tol)
+                last_weights = copy.copy(weights)
+            
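+            # Stop early once the proportion estimates have stabilized (small recent changes after at least 30000 batches).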
+            if len(diff_history)>100 and np.mean(diff_history[-10:]) < 5e-3 and batch_idx>30000:
+                print('Reach a stable cluster')
+                self.evaluate(timestamp,batch_idx)
+                sys.exit()
+
+    def adjust_tiny_weights(self,weights,tol):
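+        # Reassign any cluster weight below tol to a random value in [2*tol, 1/nb_classes) and rescale the rest so the weights still sum to one.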
+        idx_less = np.where(weights<tol)[0]
+        idx_greater = np.where(weights>=tol)[0]
+        weights[idx_less] = np.array([np.random.uniform(2*tol,1./self.nb_classes) for item in idx_less])
+        weights[idx_greater] = weights[idx_greater]*(1-np.sum(weights[idx_less]))/np.sum(weights[idx_greater])
+        return weights    
+
+    def estimate_weights(self,use_kmeans=False):
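+        # Estimate the current cluster proportions from H's assignments, either via KMeans on the concatenated (embedding, softmax) features or via argmax of the softmax head.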
+        data_y = self.y_sampler.load_all()[0] if has_label else self.y_sampler.load_all()
+        data_x_, data_x_onehot_ = self.predict_x(data_y)
+        if use_kmeans:
+            km = KMeans(n_clusters=self.nb_classes, random_state=0).fit(np.concatenate([data_x_,data_x_onehot_],axis=1))
+            label_infer = km.labels_
+        else:
+            label_infer = np.argmax(data_x_onehot_, axis=1)
+        weights = np.bincount(label_infer, minlength=self.nb_classes).astype(np.float32)
+        return weights/np.sum(weights)
+
+    def evaluate(self,timestamp,batch_idx):
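+        # Infer cluster assignments for all cells; when labels are available, report NMI/ARI/Homogeneity and log them, and save embeddings and assignments to disk.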
+        if has_label:
+            data_y, label_y = self.y_sampler.load_all()
+        else:
+            data_y = self.y_sampler.load_all()
+        data_x_, data_x_onehot_ = self.predict_x(data_y)
+        label_infer = np.argmax(data_x_onehot_, axis=1)
+        if has_label:
+            purity = metric.compute_purity(label_infer, label_y)
+            nmi = normalized_mutual_info_score(label_y, label_infer)
+            ari = adjusted_rand_score(label_y, label_infer)
+            homo = homogeneity_score(label_y,label_infer)
+            print('scDEC: NMI = {}, ARI = {}, Homogeneity = {}'.format(nmi,ari,homo))
+            if is_train:
+                with open('%s/log.txt'%self.save_dir,'a+') as f:
+                    f.write('NMI = {}\tARI = {}\tHomogeneity = {}\t batch_idx = {}\n'.format(nmi,ari,homo,batch_idx))
+                np.savez('{}/data_at_{}.npz'.format(self.save_dir, batch_idx+1),data_x_,data_x_onehot_,label_y)
+            else:
+                np.savez('results/{}/data_pre.npz'.format(self.data),data_x_,data_x_onehot_,label_y)    
+        else:
+            if is_train:
+                np.savez('{}/data_at_{}.npz'.format(self.save_dir, batch_idx+1),data_x_,data_x_onehot_)
+            else:
+                np.savez('results/{}/data_pre.npz'.format(self.data),data_x_,data_x_onehot_)
+
+    #predict with y_=G(x)
+    def predict_y(self, x, x_onehot, bs=256):
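+        # Map latent codes (and one-hot cluster indicators) to the data space in minibatches of size bs.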
+        assert x.shape[-1] == self.x_dim
+        N = x.shape[0]
+        y_pred = np.zeros(shape=(N, self.y_dim)) 
+        for b in range(int(np.ceil(N*1.0 / bs))):
+            if (b+1)*bs > N:
+               ind = np.arange(b*bs, N)
+            else:
+               ind = np.arange(b*bs, (b+1)*bs)
+            batch_x = x[ind, :]
+            batch_x_onehot = x_onehot[ind, :]
+            batch_y_ = self.sess.run(self.y_, feed_dict={self.x:batch_x, self.x_onehot:batch_x_onehot})
+            y_pred[ind, :] = batch_y_
+        return y_pred
+    
+    #predict with x_=H(y)
+    def predict_x(self,y,bs=256):
+        assert y.shape[-1] == self.y_dim
+        N = y.shape[0]
+        x_pred = np.zeros(shape=(N, self.x_dim+self.nb_classes)) 
+        x_onehot = np.zeros(shape=(N, self.nb_classes)) 
+        for b in range(int(np.ceil(N*1.0 / bs))):
+            if (b+1)*bs > N:
+               ind = np.arange(b*bs, N)
+            else:
+               ind = np.arange(b*bs, (b+1)*bs)
+            batch_y = y[ind, :]
+            batch_x_,batch_x_onehot_ = self.sess.run([self.x_latent_, self.x_onehot_], feed_dict={self.y:batch_y})
+            x_pred[ind, :] = batch_x_
+            x_onehot[ind, :] = batch_x_onehot_
+        return x_pred, x_onehot
+
+
+    def save(self,batch_idx):
+
+        checkpoint_dir = 'checkpoint/{}_{}_x_dim={}_y_dim={}_alpha={}_beta={}'.format(self.timestamp,self.data,self.x_dim, self.y_dim, self.alpha, self.beta)
+        if not os.path.exists(checkpoint_dir):
+            os.makedirs(checkpoint_dir)
+
+        self.saver.save(self.sess, os.path.join(checkpoint_dir, 'model.ckpt'),global_step=batch_idx)
+
+    def load(self, pre_trained = False, timestamp='',batch_idx=999):
+
+        if pre_trained:
+            print('Loading Pre-trained Model...')
+            checkpoint_dir = 'pre_trained_models/{}'.format(self.data)
+            self.saver.restore(self.sess, os.path.join(checkpoint_dir, 'model.ckpt-best'))
+        else:
+            if timestamp == '':
+                print('Best Timestamp not provided.')
+                checkpoint_dir = ''
+            else:
+                checkpoint_dir = 'checkpoint/{}_{}_x_dim={}_y_dim={}_alpha={}_beta={}'.format(timestamp,self.data,self.x_dim, self.y_dim, self.alpha, self.beta)
+                self.saver.restore(self.sess, os.path.join(checkpoint_dir, 'model.ckpt-%d'%batch_idx))
+                print('Restored model weights.')
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser('')
+    parser.add_argument('--data', type=str, default='Splenocyte',help='name of dataset')
+    parser.add_argument('--model', type=str, default='model',help='model definition')
+    parser.add_argument('--K', type=int, default=11,help='number of clusters')
+    parser.add_argument('--dx', type=int, default=10,help='dimension of Gaussian distribution')
+    parser.add_argument('--dy', type=int, default=20,help='dimension of preprocessed data')
+    parser.add_argument('--bs', type=int, default=64,help='batch size')
+    parser.add_argument('--nb_batches', type=int, default=50000,help='total number of training batches (also used to locate the checkpoint when restoring)')
+    parser.add_argument('--alpha', type=float, default=10.0,help='coefficient of the reconstruction (cycle-consistency) loss terms')
+    parser.add_argument('--beta', type=float, default=10.0,help='coefficient of the cross-entropy loss term')
+    parser.add_argument('--ratio', type=float, default=0.2,help='mixing parameter for updating the category distribution')
+    parser.add_argument('--low', type=float, default=0.03,help='low ratio for filtering peaks')
+    parser.add_argument('--timestamp', type=str, default='')
+    parser.add_argument('--train', action='store_true',help='whether to train from scratch')
+    parser.add_argument('--no_label', action='store_true',help='whether the dataset has label')
+    parser.add_argument('--mode', type=int, default=1,help='mode for 10x paired data')
+    
+    args = parser.parse_args()
+    data = args.data
+    model = importlib.import_module(args.model)
+    nb_classes = args.K
+    x_dim = args.dx
+    y_dim = args.dy
+    batch_size = args.bs
+    nb_batches = args.nb_batches
+    alpha = args.alpha
+    beta = args.beta
+    ratio = args.ratio
+    low = args.low
+    timestamp = args.timestamp
+    is_train = args.train
+    has_label = not args.no_label
+
+    g_net = model.Generator(input_dim=x_dim,output_dim = y_dim,name='g_net',nb_layers=10,nb_units=512,concat_every_fcl=False)
+    h_net = model.Encoder(input_dim=y_dim,output_dim = x_dim+nb_classes,feat_dim=x_dim,name='h_net',nb_layers=10,nb_units=256)
+    dx_net = model.Discriminator(input_dim=x_dim,name='dx_net',nb_layers=2,nb_units=256)
+    dy_net = model.Discriminator(input_dim=y_dim,name='dy_net',nb_layers=2,nb_units=256)
+    pool = util.DataPool(10)
+
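+    # xs draws (Gaussian latent, one-hot cluster) pairs from the base/categorical distributions; ys supplies the preprocessed observation data.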
+    xs = util.Mixture_sampler(nb_classes=nb_classes,N=10000,dim=x_dim,sd=1)
+    if data == "PBMC10k":
+        mode = args.mode
+        ys = util.ARC_Sampler(name=data,n_components=int(y_dim/2),mode=mode)
+    else:
+        ys = util.scATAC_Sampler(data,y_dim,low,has_label)
+
+    model = scDEC(g_net, h_net, dx_net, dy_net, xs, ys, nb_classes, data, pool, batch_size, alpha, beta, is_train)
+
+    if args.train:
+        model.train(nb_batches=nb_batches)
+    else:
+        print('Attempting to Restore Model ...')
+        if timestamp == '':
+            model.load(pre_trained=True)
+            timestamp = 'pre-trained'
+        else:
+            model.load(pre_trained=False, timestamp = timestamp, batch_idx = nb_batches-1)
+        model.evaluate(timestamp,nb_batches-1)