--- a/medgan.py
+++ b/medgan.py
@@ -0,0 +1,409 @@
+import sys, time, argparse
+import tensorflow as tf
+import numpy as np
+from sklearn.model_selection import train_test_split
+from sklearn.metrics import roc_auc_score
+# tf.contrib was removed in TensorFlow 2.x; this script requires TensorFlow 1.x.
+from tensorflow.contrib.layers import l2_regularizer
+from tensorflow.contrib.layers import batch_norm
+
+_VALIDATION_RATIO = 0.1
+
+
+class Medgan(object):
+    def __init__(self,
+                 dataType='binary',
+                 inputDim=615,
+                 embeddingDim=128,
+                 randomDim=128,
+                 generatorDims=(128, 128),
+                 discriminatorDims=(256, 128, 1),
+                 compressDims=(),
+                 decompressDims=(),
+                 bnDecay=0.99,
+                 l2scale=0.001):
+        self.inputDim = inputDim
+        self.embeddingDim = embeddingDim
+        self.generatorDims = list(generatorDims) + [embeddingDim]
+        self.randomDim = randomDim
+        self.dataType = dataType
+
+        if dataType == 'binary':
+            self.aeActivation = tf.nn.tanh
+        else:
+            self.aeActivation = tf.nn.relu
+
+        self.generatorActivation = tf.nn.relu
+        self.discriminatorActivation = tf.nn.relu
+        self.discriminatorDims = discriminatorDims
+        self.compressDims = list(compressDims) + [embeddingDim]
+        self.decompressDims = list(decompressDims) + [inputDim]
+        self.bnDecay = bnDecay
+        self.l2scale = l2scale
+
+    def loadData(self, dataPath=''):
+        data = np.load(dataPath, allow_pickle=True)
+
+        if self.dataType == 'binary':
+            data = np.clip(data, 0, 1)
+
+        trainX, validX = train_test_split(data, test_size=_VALIDATION_RATIO, random_state=0)
+        return trainX, validX
+
+    def buildAutoencoder(self, x_input):
+        decodeVariables = {}
+        with tf.variable_scope('autoencoder', regularizer=l2_regularizer(self.l2scale)):
+            tempVec = x_input
+            tempDim = self.inputDim
+            i = 0
+            for compressDim in self.compressDims:
+                W = tf.get_variable('aee_W_'+str(i), shape=[tempDim, compressDim])
+                b = tf.get_variable('aee_b_'+str(i), shape=[compressDim])
+                tempVec = self.aeActivation(tf.add(tf.matmul(tempVec, W), b))
+                tempDim = compressDim
+                i += 1
+
+            i = 0
+            for decompressDim in self.decompressDims[:-1]:
+                W = tf.get_variable('aed_W_'+str(i), shape=[tempDim, decompressDim])
+                b = tf.get_variable('aed_b_'+str(i), shape=[decompressDim])
+                tempVec = self.aeActivation(tf.add(tf.matmul(tempVec, W), b))
+                tempDim = decompressDim
+                decodeVariables['aed_W_'+str(i)] = W
+                decodeVariables['aed_b_'+str(i)] = b
+                i += 1
+            W = tf.get_variable('aed_W_'+str(i), shape=[tempDim, self.decompressDims[-1]])
+            b = tf.get_variable('aed_b_'+str(i), shape=[self.decompressDims[-1]])
+            decodeVariables['aed_W_'+str(i)] = W
+            decodeVariables['aed_b_'+str(i)] = b
+
+            if self.dataType == 'binary':
+                x_reconst = tf.nn.sigmoid(tf.add(tf.matmul(tempVec, W), b))
+                loss = tf.reduce_mean(-tf.reduce_sum(x_input * tf.log(x_reconst + 1e-12) + (1. - x_input) * tf.log(1. - x_reconst + 1e-12), 1), 0)
+            else:
+                x_reconst = tf.nn.relu(tf.add(tf.matmul(tempVec, W), b))
+                loss = tf.reduce_mean((x_input - x_reconst)**2)
+
+        return loss, decodeVariables
+
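+    # Descriptive note (not in the original source): buildGenerator maps
+    # randomDim-dimensional noise into the autoencoder's embeddingDim-dimensional
+    # code space. Each layer applies a linear map, batch normalization and the
+    # generator activation, then adds a shortcut connection (h3 + tempVec).
+    # The shortcut additions assume every generator layer has the same width as
+    # its input; the defaults (randomDim=128, generatorDims=(128, 128),
+    # embeddingDim=128) satisfy this.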
+    def buildGenerator(self, x_input, bn_train):
+        tempVec = x_input
+        tempDim = self.randomDim
+        with tf.variable_scope('generator', regularizer=l2_regularizer(self.l2scale)):
+            for i, genDim in enumerate(self.generatorDims[:-1]):
+                W = tf.get_variable('W_'+str(i), shape=[tempDim, genDim])
+                h = tf.matmul(tempVec, W)
+                h2 = batch_norm(h, decay=self.bnDecay, scale=True, is_training=bn_train, updates_collections=None)
+                h3 = self.generatorActivation(h2)
+                tempVec = h3 + tempVec
+                tempDim = genDim
+            W = tf.get_variable('W'+str(i), shape=[tempDim, self.generatorDims[-1]])
+            h = tf.matmul(tempVec, W)
+            h2 = batch_norm(h, decay=self.bnDecay, scale=True, is_training=bn_train, updates_collections=None)
+
+            if self.dataType == 'binary':
+                h3 = tf.nn.tanh(h2)
+            else:
+                h3 = tf.nn.relu(h2)
+
+            output = h3 + tempVec
+        return output
+
+    def buildGeneratorTest(self, x_input, bn_train):
+        tempVec = x_input
+        tempDim = self.randomDim
+        with tf.variable_scope('generator', regularizer=l2_regularizer(self.l2scale)):
+            for i, genDim in enumerate(self.generatorDims[:-1]):
+                W = tf.get_variable('W_'+str(i), shape=[tempDim, genDim])
+                h = tf.matmul(tempVec, W)
+                h2 = batch_norm(h, decay=self.bnDecay, scale=True, is_training=bn_train, updates_collections=None, trainable=False)
+                h3 = self.generatorActivation(h2)
+                tempVec = h3 + tempVec
+                tempDim = genDim
+            W = tf.get_variable('W'+str(i), shape=[tempDim, self.generatorDims[-1]])
+            h = tf.matmul(tempVec, W)
+            h2 = batch_norm(h, decay=self.bnDecay, scale=True, is_training=bn_train, updates_collections=None, trainable=False)
+
+            if self.dataType == 'binary':
+                h3 = tf.nn.tanh(h2)
+            else:
+                h3 = tf.nn.relu(h2)
+
+            output = h3 + tempVec
+        return output
+
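+    # Descriptive note (not in the original source): the discriminator sees each
+    # sample concatenated with the minibatch mean ("minibatch averaging"), which
+    # is intended to help it detect a generator that collapses to producing
+    # near-identical samples.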
+    def getDiscriminatorResults(self, x_input, keepRate, reuse=False):
+        batchSize = tf.shape(x_input)[0]
+        inputMean = tf.reshape(tf.tile(tf.reduce_mean(x_input, 0), [batchSize]), (batchSize, self.inputDim))
+        tempVec = tf.concat([x_input, inputMean], 1)
+        tempDim = self.inputDim * 2
+        with tf.variable_scope('discriminator', reuse=reuse, regularizer=l2_regularizer(self.l2scale)):
+            for i, discDim in enumerate(self.discriminatorDims[:-1]):
+                W = tf.get_variable('W_'+str(i), shape=[tempDim, discDim])
+                b = tf.get_variable('b_'+str(i), shape=[discDim])
+                h = self.discriminatorActivation(tf.add(tf.matmul(tempVec, W), b))
+                h = tf.nn.dropout(h, keepRate)
+                tempVec = h
+                tempDim = discDim
+            W = tf.get_variable('W', shape=[tempDim, 1])
+            b = tf.get_variable('b', shape=[1])
+            y_hat = tf.squeeze(tf.nn.sigmoid(tf.add(tf.matmul(tempVec, W), b)))
+        return y_hat
+
+    def buildDiscriminator(self, x_real, x_fake, keepRate, decodeVariables, bn_train):
+        # Discriminate for real samples
+        y_hat_real = self.getDiscriminatorResults(x_real, keepRate, reuse=False)
+
+        # Decompress, then discriminate for fake samples
+        tempVec = x_fake
+        i = 0
+        for _ in self.decompressDims[:-1]:
+            tempVec = self.aeActivation(tf.add(tf.matmul(tempVec, decodeVariables['aed_W_'+str(i)]), decodeVariables['aed_b_'+str(i)]))
+            i += 1
+
+        if self.dataType == 'binary':
+            x_decoded = tf.nn.sigmoid(tf.add(tf.matmul(tempVec, decodeVariables['aed_W_'+str(i)]), decodeVariables['aed_b_'+str(i)]))
+        else:
+            x_decoded = tf.nn.relu(tf.add(tf.matmul(tempVec, decodeVariables['aed_W_'+str(i)]), decodeVariables['aed_b_'+str(i)]))
+
+        y_hat_fake = self.getDiscriminatorResults(x_decoded, keepRate, reuse=True)
+
+        # Vanilla GAN losses: cross-entropy for the discriminator,
+        # non-saturating -log D(G(z)) for the generator.
+        loss_d = -tf.reduce_mean(tf.log(y_hat_real + 1e-12)) - tf.reduce_mean(tf.log(1. - y_hat_fake + 1e-12))
+        loss_g = -tf.reduce_mean(tf.log(y_hat_fake + 1e-12))
+
+        return loss_d, loss_g, y_hat_real, y_hat_fake
+
+    def print2file(self, buf, outFile):
+        outfd = open(outFile, 'a')
+        outfd.write(buf + '\n')
+        outfd.close()
+
+    def generateData(self,
+                     nSamples=100,
+                     modelFile='model',
+                     batchSize=100,
+                     outFile='out'):
+        x_dummy = tf.placeholder('float', [None, self.inputDim])
+        _, decodeVariables = self.buildAutoencoder(x_dummy)
+        x_random = tf.placeholder('float', [None, self.randomDim])
+        bn_train = tf.placeholder('bool')
+        x_emb = self.buildGeneratorTest(x_random, bn_train)
+        tempVec = x_emb
+        i = 0
+        for _ in self.decompressDims[:-1]:
+            tempVec = self.aeActivation(tf.add(tf.matmul(tempVec, decodeVariables['aed_W_'+str(i)]), decodeVariables['aed_b_'+str(i)]))
+            i += 1
+
+        if self.dataType == 'binary':
+            x_reconst = tf.nn.sigmoid(tf.add(tf.matmul(tempVec, decodeVariables['aed_W_'+str(i)]), decodeVariables['aed_b_'+str(i)]))
+        else:
+            x_reconst = tf.nn.relu(tf.add(tf.matmul(tempVec, decodeVariables['aed_W_'+str(i)]), decodeVariables['aed_b_'+str(i)]))
+
+        np.random.seed(1234)
+        saver = tf.train.Saver()
+        outputVec = []
+        burn_in = 1000
+        with tf.Session() as sess:
+            saver.restore(sess, modelFile)
+            print('burning in')
+            # Run the generator in training mode first so the batch-norm moving
+            # averages adapt before sampling.
+            for i in range(burn_in):
+                randomX = np.random.normal(size=(batchSize, self.randomDim))
+                output = sess.run(x_reconst, feed_dict={x_random:randomX, bn_train:True})
+
+            print('generating')
+            nBatches = int(np.ceil(float(nSamples) / float(batchSize)))
+            for i in range(nBatches):
+                randomX = np.random.normal(size=(batchSize, self.randomDim))
+                output = sess.run(x_reconst, feed_dict={x_random:randomX, bn_train:False})
+                outputVec.extend(output)
+
+        outputMat = np.array(outputVec)
+        np.save(outFile, outputMat)
+
+    def calculateDiscAuc(self, preds_real, preds_fake):
+        preds = np.concatenate([preds_real, preds_fake], axis=0)
+        labels = np.concatenate([np.ones((len(preds_real))), np.zeros((len(preds_fake)))], axis=0)
+        auc = roc_auc_score(labels, preds)
+        return auc
+
+    def calculateDiscAccuracy(self, preds_real, preds_fake):
+        total = len(preds_real) + len(preds_fake)
+        hit = 0
+        for pred in preds_real:
+            if pred > 0.5: hit += 1
+        for pred in preds_fake:
+            if pred < 0.5: hit += 1
+        acc = float(hit) / float(total)
+        return acc
+
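+    # Descriptive note (not in the original source): training runs in two phases.
+    # The autoencoder is first pre-trained on the reconstruction loss alone; the
+    # GAN is then trained by alternating discriminatorTrainPeriod discriminator
+    # updates with generatorTrainPeriod generator updates per mini-batch, and the
+    # decoder weights (decodeVariables) are fine-tuned together with the generator.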
+    def train(self,
+              dataPath='data',
+              modelPath='',
+              outPath='out',
+              nEpochs=500,
+              discriminatorTrainPeriod=2,
+              generatorTrainPeriod=1,
+              pretrainBatchSize=100,
+              batchSize=1000,
+              pretrainEpochs=100,
+              saveMaxKeep=0):
+        x_raw = tf.placeholder('float', [None, self.inputDim])
+        x_random = tf.placeholder('float', [None, self.randomDim])
+        keep_prob = tf.placeholder('float')
+        bn_train = tf.placeholder('bool')
+
+        loss_ae, decodeVariables = self.buildAutoencoder(x_raw)
+        x_fake = self.buildGenerator(x_random, bn_train)
+        loss_d, loss_g, y_hat_real, y_hat_fake = self.buildDiscriminator(x_raw, x_fake, keep_prob, decodeVariables, bn_train)
+        trainX, validX = self.loadData(dataPath)
+
+        t_vars = tf.trainable_variables()
+        ae_vars = [var for var in t_vars if 'autoencoder' in var.name]
+        d_vars = [var for var in t_vars if 'discriminator' in var.name]
+        g_vars = [var for var in t_vars if 'generator' in var.name]
+
+        all_regs = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
+
+        optimize_ae = tf.train.AdamOptimizer().minimize(loss_ae + sum(all_regs), var_list=ae_vars)
+        optimize_d = tf.train.AdamOptimizer().minimize(loss_d + sum(all_regs), var_list=d_vars)
+        decodeVariablesValues = list(decodeVariables.values())
+        optimize_g = tf.train.AdamOptimizer().minimize(loss_g + sum(all_regs), var_list=g_vars+decodeVariablesValues)
+
+        initOp = tf.global_variables_initializer()
+
+        nBatches = int(np.ceil(float(trainX.shape[0]) / float(batchSize)))
+        saver = tf.train.Saver(max_to_keep=saveMaxKeep)
+        logFile = outPath + '.log'
+
+        with tf.Session() as sess:
+            if modelPath == '': sess.run(initOp)
+            else: saver.restore(sess, modelPath)
+            nTrainBatches = int(np.ceil(float(trainX.shape[0]) / float(pretrainBatchSize)))
+            nValidBatches = int(np.ceil(float(validX.shape[0]) / float(pretrainBatchSize)))
+
+            if modelPath == '':
+                for epoch in range(pretrainEpochs):
+                    idx = np.random.permutation(trainX.shape[0])
+                    trainLossVec = []
+                    for i in range(nTrainBatches):
+                        batchX = trainX[idx[i*pretrainBatchSize:(i+1)*pretrainBatchSize]]
+                        _, loss = sess.run([optimize_ae, loss_ae], feed_dict={x_raw:batchX})
+                        trainLossVec.append(loss)
+                    idx = np.random.permutation(validX.shape[0])
+                    validLossVec = []
+                    for i in range(nValidBatches):
+                        batchX = validX[idx[i*pretrainBatchSize:(i+1)*pretrainBatchSize]]
+                        loss = sess.run(loss_ae, feed_dict={x_raw:batchX})
+                        validLossVec.append(loss)
+                    validReverseLoss = 0.
+                    buf = 'Pretrain_Epoch:%d, trainLoss:%f, validLoss:%f, validReverseLoss:%f' % (epoch, np.mean(trainLossVec), np.mean(validLossVec), validReverseLoss)
+                    print(buf)
+                    self.print2file(buf, logFile)
+
+            idx = np.arange(trainX.shape[0])
+            for epoch in range(nEpochs):
+                d_loss_vec = []
+                g_loss_vec = []
+                for i in range(nBatches):
+                    for _ in range(discriminatorTrainPeriod):
+                        batchIdx = np.random.choice(idx, size=batchSize, replace=False)
+                        batchX = trainX[batchIdx]
+                        randomX = np.random.normal(size=(batchSize, self.randomDim))
+                        _, discLoss = sess.run([optimize_d, loss_d], feed_dict={x_raw:batchX, x_random:randomX, keep_prob:1.0, bn_train:False})
+                        d_loss_vec.append(discLoss)
+                    for _ in range(generatorTrainPeriod):
+                        randomX = np.random.normal(size=(batchSize, self.randomDim))
+                        _, generatorLoss = sess.run([optimize_g, loss_g], feed_dict={x_raw:batchX, x_random:randomX, keep_prob:1.0, bn_train:True})
+                        g_loss_vec.append(generatorLoss)
+
+                validIdx = np.arange(len(validX))
+                nValidBatches = int(np.ceil(float(len(validX)) / float(batchSize)))
+                validAccVec = []
+                validAucVec = []
+                for i in range(nValidBatches):
+                    batchIdx = np.random.choice(validIdx, size=batchSize, replace=False)
+                    batchX = validX[batchIdx]
+                    randomX = np.random.normal(size=(batchSize, self.randomDim))
+                    preds_real, preds_fake = sess.run([y_hat_real, y_hat_fake], feed_dict={x_raw:batchX, x_random:randomX, keep_prob:1.0, bn_train:False})
+                    validAcc = self.calculateDiscAccuracy(preds_real, preds_fake)
+                    validAuc = self.calculateDiscAuc(preds_real, preds_fake)
+                    validAccVec.append(validAcc)
+                    validAucVec.append(validAuc)
+                buf = 'Epoch:%d, d_loss:%f, g_loss:%f, accuracy:%f, AUC:%f' % (epoch, np.mean(d_loss_vec), np.mean(g_loss_vec), np.mean(validAccVec), np.mean(validAucVec))
+                print(buf)
+                self.print2file(buf, logFile)
+                savePath = saver.save(sess, outPath, global_step=epoch)
+            print(savePath)
+
+
+def str2bool(v):
+    if v.lower() in ('yes', 'true', 't', 'y', '1'):
+        return True
+    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
+        return False
+    else:
+        raise argparse.ArgumentTypeError('Boolean value expected.')
+
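+# Descriptive note (not in the original source): the tuple-typed options below
+# (--generator_size, --discriminator_size, --compressor_size, --decompressor_size)
+# are not usefully parsed from the command line, because argparse's type=tuple
+# turns a string like "128,128" into a tuple of characters; in practice these
+# sizes are configured by editing the defaults below.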
+def parse_arguments(parser):
+    parser.add_argument('--embed_size', type=int, default=128, help='The dimension size of the embedding, which will be generated by the generator. (default value: 128)')
+    parser.add_argument('--noise_size', type=int, default=128, help='The dimension size of the random noise, on which the generator is conditioned. (default value: 128)')
+    parser.add_argument('--generator_size', type=tuple, default=(128, 128), help='The dimension size of the generator. Note that another layer of size "--embed_size" is always added. (default value: (128, 128))')
+    parser.add_argument('--discriminator_size', type=tuple, default=(256, 128, 1), help='The dimension size of the discriminator. (default value: (256, 128, 1))')
+    parser.add_argument('--compressor_size', type=tuple, default=(), help='The dimension size of the encoder of the autoencoder. Note that another layer of size "--embed_size" is always added. Therefore this can be a blank tuple. (default value: ())')
+    parser.add_argument('--decompressor_size', type=tuple, default=(), help='The dimension size of the decoder of the autoencoder. Note that another layer, whose size is equal to the dimension of the <patient_matrix>, is always added. Therefore this can be a blank tuple. (default value: ())')
+    parser.add_argument('--data_type', type=str, default='binary', choices=['binary', 'count'], help='The input data type. The <patient_matrix> can contain either binary values or count values. (default value: "binary")')
+    parser.add_argument('--batchnorm_decay', type=float, default=0.99, help='Decay value for the moving average used in Batch Normalization. (default value: 0.99)')
+    parser.add_argument('--L2', type=float, default=0.001, help='L2 regularization coefficient for all weights. (default value: 0.001)')
+
+    parser.add_argument('data_file', type=str, metavar='<patient_matrix>', help='The path to the numpy matrix containing aggregated patient records.')
+    parser.add_argument('out_file', type=str, metavar='<out_file>', help='The path to the output models.')
+    parser.add_argument('--model_file', type=str, metavar='<model_file>', default='', help='The path to the model file, in case you want to continue training. (default value: \'\')')
+    parser.add_argument('--n_pretrain_epoch', type=int, default=100, help='The number of epochs to pre-train the autoencoder. (default value: 100)')
+    parser.add_argument('--n_epoch', type=int, default=1000, help='The number of epochs to train medGAN. (default value: 1000)')
+    parser.add_argument('--n_discriminator_update', type=int, default=2, help='The number of discriminator updates per mini-batch. (default value: 2)')
+    parser.add_argument('--n_generator_update', type=int, default=1, help='The number of generator updates per mini-batch. (default value: 1)')
+    parser.add_argument('--pretrain_batch_size', type=int, default=100, help='The size of a single mini-batch for pre-training the autoencoder. (default value: 100)')
+    parser.add_argument('--batch_size', type=int, default=1000, help='The size of a single mini-batch for training medGAN. (default value: 1000)')
+    parser.add_argument('--save_max_keep', type=int, default=0, help='The number of models to keep. Setting this to 0 will save models for every epoch. (default value: 0)')
+    parser.add_argument('--generate_data', type=str2bool, default=False, help='If True the model generates data; if False the model is trained. (default value: False)')
+    args = parser.parse_args()
+    return args
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    args = parse_arguments(parser)
+
+    data = np.load(args.data_file, allow_pickle=True)
+    inputDim = data.shape[1]
+
+    mg = Medgan(dataType=args.data_type,
+                inputDim=inputDim,
+                embeddingDim=args.embed_size,
+                randomDim=args.noise_size,
+                generatorDims=args.generator_size,
+                discriminatorDims=args.discriminator_size,
+                compressDims=args.compressor_size,
+                decompressDims=args.decompressor_size,
+                bnDecay=args.batchnorm_decay,
+                l2scale=args.L2)
+
+    # True for generation, False for training
+    if not args.generate_data:
+        # Training
+        mg.train(dataPath=args.data_file,
+                 modelPath=args.model_file,
+                 outPath=args.out_file,
+                 pretrainEpochs=args.n_pretrain_epoch,
+                 nEpochs=args.n_epoch,
+                 discriminatorTrainPeriod=args.n_discriminator_update,
+                 generatorTrainPeriod=args.n_generator_update,
+                 pretrainBatchSize=args.pretrain_batch_size,
+                 batchSize=args.batch_size,
+                 saveMaxKeep=args.save_max_keep)
+    else:
+        # Generate synthetic data using a trained model.
+        # "--model_file" must point to a trained checkpoint and "<out_file>"
+        # is where the synthetic matrix will be written.
+        mg.generateData(nSamples=10000,
+                        modelFile=args.model_file,
+                        batchSize=args.batch_size,
+                        outFile=args.out_file)
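+
+# Example usage (illustrative; file names below are placeholders, not from the original source):
+#   Training:
+#     python medgan.py patient_matrix.npy ./medgan_out
+#   Generation (requires a trained checkpoint, e.g. ./medgan_out-999 saved during training):
+#     python medgan.py patient_matrix.npy synthetic.npy --model_file ./medgan_out-999 --generate_data True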