import argparse
import tensorflow as tf
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score
from tensorflow.contrib.layers import l2_regularizer
from tensorflow.contrib.layers import batch_norm
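# NOTE: This script requires TensorFlow 1.x; tf.contrib was removed in TensorFlow 2.0.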
_VALIDATION_RATIO = 0.1
class Medgan(object):
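    """Implementation of medGAN (Choi et al., 2017): a GAN that generates
    synthetic patient records by producing embeddings in the latent space of a
    pre-trained autoencoder and decoding them back to the record space."""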
def __init__(self,
dataType='binary',
inputDim=615,
embeddingDim=128,
randomDim=128,
generatorDims=(128, 128),
discriminatorDims=(256, 128, 1),
compressDims=(),
decompressDims=(),
bnDecay=0.99,
l2scale=0.001):
self.inputDim = inputDim
self.embeddingDim = embeddingDim
self.generatorDims = list(generatorDims) + [embeddingDim]
self.randomDim = randomDim
self.dataType = dataType
if dataType == 'binary':
self.aeActivation = tf.nn.tanh
else:
self.aeActivation = tf.nn.relu
self.generatorActivation = tf.nn.relu
self.discriminatorActivation = tf.nn.relu
self.discriminatorDims = discriminatorDims
self.compressDims = list(compressDims) + [embeddingDim]
self.decompressDims = list(decompressDims) + [inputDim]
self.bnDecay = bnDecay
self.l2scale = l2scale
def loadData(self, dataPath=''):
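        # Load the patient matrix (.npy). Binary data are clipped into [0, 1];
        # a fixed fraction is held out for validation.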
data = np.load(dataPath, allow_pickle=True)
if self.dataType == 'binary':
data = np.clip(data, 0, 1)
trainX, validX = train_test_split(data, test_size=_VALIDATION_RATIO, random_state=0)
return trainX, validX
def buildAutoencoder(self, x_input):
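        # Encoder/decoder stack. The decoder variables are returned separately so
        # they can be reused to decode generator outputs (see buildDiscriminator
        # and generateData).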
decodeVariables = {}
with tf.variable_scope('autoencoder', regularizer=l2_regularizer(self.l2scale)):
tempVec = x_input
tempDim = self.inputDim
i = 0
for compressDim in self.compressDims:
W = tf.get_variable('aee_W_'+str(i), shape=[tempDim, compressDim])
b = tf.get_variable('aee_b_'+str(i), shape=[compressDim])
tempVec = self.aeActivation(tf.add(tf.matmul(tempVec, W), b))
tempDim = compressDim
i += 1
i = 0
for decompressDim in self.decompressDims[:-1]:
W = tf.get_variable('aed_W_'+str(i), shape=[tempDim, decompressDim])
b = tf.get_variable('aed_b_'+str(i), shape=[decompressDim])
tempVec = self.aeActivation(tf.add(tf.matmul(tempVec, W), b))
tempDim = decompressDim
decodeVariables['aed_W_'+str(i)] = W
decodeVariables['aed_b_'+str(i)] = b
i += 1
W = tf.get_variable('aed_W_'+str(i), shape=[tempDim, self.decompressDims[-1]])
b = tf.get_variable('aed_b_'+str(i), shape=[self.decompressDims[-1]])
decodeVariables['aed_W_'+str(i)] = W
decodeVariables['aed_b_'+str(i)] = b
if self.dataType == 'binary':
x_reconst = tf.nn.sigmoid(tf.add(tf.matmul(tempVec,W),b))
loss = tf.reduce_mean(-tf.reduce_sum(x_input * tf.log(x_reconst + 1e-12) + (1. - x_input) * tf.log(1. - x_reconst + 1e-12), 1), 0)
else:
x_reconst = tf.nn.relu(tf.add(tf.matmul(tempVec,W),b))
loss = tf.reduce_mean((x_input - x_reconst)**2)
return loss, decodeVariables
def buildGenerator(self, x_input, bn_train):
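        # Feed-forward generator with batch normalization and additive shortcut
        # connections. The shortcuts (h3 + tempVec) assume every layer width equals
        # randomDim, which holds for the default sizes (128 throughout).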
tempVec = x_input
tempDim = self.randomDim
with tf.variable_scope('generator', regularizer=l2_regularizer(self.l2scale)):
for i, genDim in enumerate(self.generatorDims[:-1]):
W = tf.get_variable('W_'+str(i), shape=[tempDim, genDim])
h = tf.matmul(tempVec,W)
h2 = batch_norm(h, decay=self.bnDecay, scale=True, is_training=bn_train, updates_collections=None)
h3 = self.generatorActivation(h2)
tempVec = h3 + tempVec
tempDim = genDim
            W = tf.get_variable('W_'+str(len(self.generatorDims)-1), shape=[tempDim, self.generatorDims[-1]])
h = tf.matmul(tempVec,W)
h2 = batch_norm(h, decay=self.bnDecay, scale=True, is_training=bn_train, updates_collections=None)
if self.dataType == 'binary':
h3 = tf.nn.tanh(h2)
else:
h3 = tf.nn.relu(h2)
output = h3 + tempVec
return output
def buildGeneratorTest(self, x_input, bn_train):
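        # Same architecture and variable names as buildGenerator, but the batch-norm
        # parameters are created with trainable=False for inference-time reuse.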
tempVec = x_input
tempDim = self.randomDim
with tf.variable_scope('generator', regularizer=l2_regularizer(self.l2scale)):
for i, genDim in enumerate(self.generatorDims[:-1]):
W = tf.get_variable('W_'+str(i), shape=[tempDim, genDim])
h = tf.matmul(tempVec,W)
h2 = batch_norm(h, decay=self.bnDecay, scale=True, is_training=bn_train, updates_collections=None, trainable=False)
h3 = self.generatorActivation(h2)
tempVec = h3 + tempVec
tempDim = genDim
            W = tf.get_variable('W_'+str(len(self.generatorDims)-1), shape=[tempDim, self.generatorDims[-1]])
h = tf.matmul(tempVec,W)
h2 = batch_norm(h, decay=self.bnDecay, scale=True, is_training=bn_train, updates_collections=None, trainable=False)
if self.dataType == 'binary':
h3 = tf.nn.tanh(h2)
else:
h3 = tf.nn.relu(h2)
output = h3 + tempVec
return output
def getDiscriminatorResults(self, x_input, keepRate, reuse=False):
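        # Minibatch averaging: append the per-feature minibatch mean to every
        # sample so the discriminator can spot low-diversity (mode-collapsed)
        # batches from the generator.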
batchSize = tf.shape(x_input)[0]
inputMean = tf.reshape(tf.tile(tf.reduce_mean(x_input,0), [batchSize]), (batchSize, self.inputDim))
tempVec = tf.concat([x_input, inputMean], 1)
tempDim = self.inputDim * 2
with tf.variable_scope('discriminator', reuse=reuse, regularizer=l2_regularizer(self.l2scale)):
for i, discDim in enumerate(self.discriminatorDims[:-1]):
W = tf.get_variable('W_'+str(i), shape=[tempDim, discDim])
b = tf.get_variable('b_'+str(i), shape=[discDim])
h = self.discriminatorActivation(tf.add(tf.matmul(tempVec,W),b))
h = tf.nn.dropout(h, keepRate)
tempVec = h
tempDim = discDim
W = tf.get_variable('W', shape=[tempDim, 1])
b = tf.get_variable('b', shape=[1])
y_hat = tf.squeeze(tf.nn.sigmoid(tf.add(tf.matmul(tempVec, W), b)))
return y_hat
def buildDiscriminator(self, x_real, x_fake, keepRate, decodeVariables, bn_train):
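        # Builds the standard (non-saturating) GAN cross-entropy losses. Generated
        # embeddings are first decoded back to the record space with the shared
        # decoder weights, so the discriminator always sees record-space inputs.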
#Discriminate for real samples
y_hat_real = self.getDiscriminatorResults(x_real, keepRate, reuse=False)
        #Decompress, then discriminate for fake samples
tempVec = x_fake
i = 0
for _ in self.decompressDims[:-1]:
tempVec = self.aeActivation(tf.add(tf.matmul(tempVec, decodeVariables['aed_W_'+str(i)]), decodeVariables['aed_b_'+str(i)]))
i += 1
if self.dataType == 'binary':
x_decoded = tf.nn.sigmoid(tf.add(tf.matmul(tempVec, decodeVariables['aed_W_'+str(i)]), decodeVariables['aed_b_'+str(i)]))
else:
x_decoded = tf.nn.relu(tf.add(tf.matmul(tempVec, decodeVariables['aed_W_'+str(i)]), decodeVariables['aed_b_'+str(i)]))
y_hat_fake = self.getDiscriminatorResults(x_decoded, keepRate, reuse=True)
loss_d = -tf.reduce_mean(tf.log(y_hat_real + 1e-12)) - tf.reduce_mean(tf.log(1. - y_hat_fake + 1e-12))
loss_g = -tf.reduce_mean(tf.log(y_hat_fake + 1e-12))
return loss_d, loss_g, y_hat_real, y_hat_fake
    def print2file(self, buf, outFile):
        with open(outFile, 'a') as outfd:
            outfd.write(buf + '\n')
def generateData(self,
nSamples=100,
modelFile='model',
batchSize=100,
outFile='out'):
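        # Rebuild the decoder and generator graphs, restore trained weights from
        # modelFile, then decode random noise into synthetic records and save them
        # as a numpy array at outFile.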
x_dummy = tf.placeholder('float', [None, self.inputDim])
_, decodeVariables = self.buildAutoencoder(x_dummy)
x_random = tf.placeholder('float', [None, self.randomDim])
bn_train = tf.placeholder('bool')
x_emb = self.buildGeneratorTest(x_random, bn_train)
tempVec = x_emb
i = 0
for _ in self.decompressDims[:-1]:
tempVec = self.aeActivation(tf.add(tf.matmul(tempVec, decodeVariables['aed_W_'+str(i)]), decodeVariables['aed_b_'+str(i)]))
i += 1
if self.dataType == 'binary':
x_reconst = tf.nn.sigmoid(tf.add(tf.matmul(tempVec, decodeVariables['aed_W_'+str(i)]), decodeVariables['aed_b_'+str(i)]))
else:
x_reconst = tf.nn.relu(tf.add(tf.matmul(tempVec, decodeVariables['aed_W_'+str(i)]), decodeVariables['aed_b_'+str(i)]))
np.random.seed(1234)
saver = tf.train.Saver()
outputVec = []
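        # Burn-in: run the generator with bn_train=True so the batch-norm moving
        # averages settle before sampling with bn_train=False below.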
burn_in = 1000
with tf.Session() as sess:
saver.restore(sess, modelFile)
print('burning in')
for i in range(burn_in):
randomX = np.random.normal(size=(batchSize, self.randomDim))
                sess.run(x_reconst, feed_dict={x_random:randomX, bn_train:True})
print('generating')
            nBatches = int(np.ceil(float(nSamples) / float(batchSize)))
for i in range(nBatches):
randomX = np.random.normal(size=(batchSize, self.randomDim))
output = sess.run(x_reconst, feed_dict={x_random:randomX, bn_train:False})
outputVec.extend(output)
            outputMat = np.array(outputVec)[:nSamples]  # trim any overshoot from the final batch
np.save(outFile, outputMat)
def calculateDiscAuc(self, preds_real, preds_fake):
preds = np.concatenate([preds_real, preds_fake], axis=0)
labels = np.concatenate([np.ones((len(preds_real))), np.zeros((len(preds_fake)))], axis=0)
auc = roc_auc_score(labels, preds)
return auc
def calculateDiscAccuracy(self, preds_real, preds_fake):
total = len(preds_real) + len(preds_fake)
hit = 0
for pred in preds_real:
if pred > 0.5: hit += 1
for pred in preds_fake:
if pred < 0.5: hit += 1
acc = float(hit) / float(total)
return acc
def train(self,
dataPath='data',
modelPath='',
outPath='out',
nEpochs=500,
discriminatorTrainPeriod=2,
generatorTrainPeriod=1,
pretrainBatchSize=100,
batchSize=1000,
pretrainEpochs=100,
saveMaxKeep=0):
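        # Build the autoencoder, generator, and discriminator graphs; pre-train the
        # autoencoder; then alternate discriminator and generator updates. Note the
        # generator optimizer also fine-tunes the shared decoder weights.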
x_raw = tf.placeholder('float', [None, self.inputDim])
        x_random = tf.placeholder('float', [None, self.randomDim])
keep_prob = tf.placeholder('float')
bn_train = tf.placeholder('bool')
loss_ae, decodeVariables = self.buildAutoencoder(x_raw)
x_fake = self.buildGenerator(x_random, bn_train)
loss_d, loss_g, y_hat_real, y_hat_fake = self.buildDiscriminator(x_raw, x_fake, keep_prob, decodeVariables, bn_train)
trainX, validX = self.loadData(dataPath)
t_vars = tf.trainable_variables()
ae_vars = [var for var in t_vars if 'autoencoder' in var.name]
d_vars = [var for var in t_vars if 'discriminator' in var.name]
g_vars = [var for var in t_vars if 'generator' in var.name]
all_regs = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
optimize_ae = tf.train.AdamOptimizer().minimize(loss_ae + sum(all_regs), var_list=ae_vars)
optimize_d = tf.train.AdamOptimizer().minimize(loss_d + sum(all_regs), var_list=d_vars)
decodeVariablesValues = list(decodeVariables.values())
optimize_g = tf.train.AdamOptimizer().minimize(loss_g + sum(all_regs), var_list=g_vars+decodeVariablesValues)
initOp = tf.global_variables_initializer()
nBatches = int(np.ceil(float(trainX.shape[0]) / float(batchSize)))
saver = tf.train.Saver(max_to_keep=saveMaxKeep)
logFile = outPath + '.log'
with tf.Session() as sess:
if modelPath == '': sess.run(initOp)
else: saver.restore(sess, modelPath)
            nTrainBatches = int(np.ceil(float(trainX.shape[0]) / float(pretrainBatchSize)))
            nValidBatches = int(np.ceil(float(validX.shape[0]) / float(pretrainBatchSize)))
            if modelPath == '':
for epoch in range(pretrainEpochs):
idx = np.random.permutation(trainX.shape[0])
trainLossVec = []
for i in range(nTrainBatches):
batchX = trainX[idx[i*pretrainBatchSize:(i+1)*pretrainBatchSize]]
_, loss = sess.run([optimize_ae, loss_ae], feed_dict={x_raw:batchX})
trainLossVec.append(loss)
idx = np.random.permutation(validX.shape[0])
validLossVec = []
for i in range(nValidBatches):
batchX = validX[idx[i*pretrainBatchSize:(i+1)*pretrainBatchSize]]
loss = sess.run(loss_ae, feed_dict={x_raw:batchX})
validLossVec.append(loss)
validReverseLoss = 0.
buf = 'Pretrain_Epoch:%d, trainLoss:%f, validLoss:%f, validReverseLoss:%f' % (epoch, np.mean(trainLossVec), np.mean(validLossVec), validReverseLoss)
print(buf)
self.print2file(buf, logFile)
idx = np.arange(trainX.shape[0])
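            # Adversarial training: for each mini-batch, update the discriminator
            # discriminatorTrainPeriod times, then the generator generatorTrainPeriod
            # times.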
for epoch in range(nEpochs):
                d_loss_vec = []
g_loss_vec = []
for i in range(nBatches):
for _ in range(discriminatorTrainPeriod):
batchIdx = np.random.choice(idx, size=batchSize, replace=False)
batchX = trainX[batchIdx]
randomX = np.random.normal(size=(batchSize, self.randomDim))
_, discLoss = sess.run([optimize_d, loss_d], feed_dict={x_raw:batchX, x_random:randomX, keep_prob:1.0, bn_train:False})
d_loss_vec.append(discLoss)
for _ in range(generatorTrainPeriod):
randomX = np.random.normal(size=(batchSize, self.randomDim))
_, generatorLoss = sess.run([optimize_g, loss_g], feed_dict={x_raw:batchX, x_random:randomX, keep_prob:1.0, bn_train:True})
g_loss_vec.append(generatorLoss)
                # Use a separate index array for validation so the training index
                # array (idx) is not clobbered between epochs.
                validIdx = np.arange(len(validX))
nValidBatches = int(np.ceil(float(len(validX)) / float(batchSize)))
validAccVec = []
validAucVec = []
                for i in range(nValidBatches):
                    batchIdx = np.random.choice(validIdx, size=min(batchSize, len(validIdx)), replace=False)
                    batchX = validX[batchIdx]
randomX = np.random.normal(size=(batchSize, self.randomDim))
                    preds_real, preds_fake = sess.run([y_hat_real, y_hat_fake], feed_dict={x_raw:batchX, x_random:randomX, keep_prob:1.0, bn_train:False})
validAcc = self.calculateDiscAccuracy(preds_real, preds_fake)
validAuc = self.calculateDiscAuc(preds_real, preds_fake)
validAccVec.append(validAcc)
validAucVec.append(validAuc)
buf = 'Epoch:%d, d_loss:%f, g_loss:%f, accuracy:%f, AUC:%f' % (epoch, np.mean(d_loss_vec), np.mean(g_loss_vec), np.mean(validAccVec), np.mean(validAucVec))
print(buf)
self.print2file(buf, logFile)
savePath = saver.save(sess, outPath, global_step=epoch)
print(savePath)
def str2bool(v):
if v.lower() in ('yes', 'true', 't', 'y', '1'):
return True
elif v.lower() in ('no', 'false', 'f', 'n', '0'):
return False
else:
raise argparse.ArgumentTypeError('Boolean value expected.')
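# argparse's type=tuple would split a string like "128,128" into single
# characters, so the layer-size options below use this small parser instead.
# It accepts e.g. "128,128" (or "(128,128)") and returns (128, 128).
def str2tuple(v):
    v = v.strip('()')
    if v.strip() == '':
        return ()
    return tuple(int(x) for x in v.split(','))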
def parse_arguments(parser):
parser.add_argument('--embed_size', type=int, default=128, help='The dimension size of the embedding, which will be generated by the generator. (default value: 128)')
parser.add_argument('--noise_size', type=int, default=128, help='The dimension size of the random noise, on which the generator is conditioned. (default value: 128)')
    parser.add_argument('--generator_size', type=str2tuple, default=(128, 128), help='Comma-separated dimension sizes of the generator layers. Note that another layer of size "--embed_size" is always added. (default value: 128,128)')
    parser.add_argument('--discriminator_size', type=str2tuple, default=(256, 128, 1), help='Comma-separated dimension sizes of the discriminator layers. (default value: 256,128,1)')
    parser.add_argument('--compressor_size', type=str2tuple, default=(), help='Comma-separated dimension sizes of the encoder of the autoencoder. Note that another layer of size "--embed_size" is always added, so this can be empty. (default value: empty)')
    parser.add_argument('--decompressor_size', type=str2tuple, default=(), help='Comma-separated dimension sizes of the decoder of the autoencoder. Note that another layer, whose size equals the dimension of the <patient_matrix>, is always added, so this can be empty. (default value: empty)')
parser.add_argument('--data_type', type=str, default='binary', choices=['binary', 'count'], help='The input data type. The <patient matrix> could either contain binary values or count values. (default value: "binary")')
parser.add_argument('--batchnorm_decay', type=float, default=0.99, help='Decay value for the moving average used in Batch Normalization. (default value: 0.99)')
parser.add_argument('--L2', type=float, default=0.001, help='L2 regularization coefficient for all weights. (default value: 0.001)')
parser.add_argument('data_file', type=str, metavar='<patient_matrix>', help='The path to the numpy matrix containing aggregated patient records.')
parser.add_argument('out_file', type=str, metavar='<out_file>', help='The path to the output models.')
    parser.add_argument('--model_file', type=str, metavar='<model_file>', default='', help='The path to the model file, in case you want to continue training. (default value: empty)')
parser.add_argument('--n_pretrain_epoch', type=int, default=100, help='The number of epochs to pre-train the autoencoder. (default value: 100)')
parser.add_argument('--n_epoch', type=int, default=1000, help='The number of epochs to train medGAN. (default value: 1000)')
    parser.add_argument('--n_discriminator_update', type=int, default=2, help='The number of times to update the discriminator per mini-batch. (default value: 2)')
    parser.add_argument('--n_generator_update', type=int, default=1, help='The number of times to update the generator per mini-batch. (default value: 1)')
parser.add_argument('--pretrain_batch_size', type=int, default=100, help='The size of a single mini-batch for pre-training the autoencoder. (default value: 100)')
parser.add_argument('--batch_size', type=int, default=1000, help='The size of a single mini-batch for training medGAN. (default value: 1000)')
parser.add_argument('--save_max_keep', type=int, default=0, help='The number of models to keep. Setting this to 0 will save models for every epoch. (default value: 0)')
parser.add_argument('--generate_data', type=str2bool, default=False, help='If True the model generates data, if False the model is trained (default value: False)')
args = parser.parse_args()
return args
if __name__ == '__main__':
parser = argparse.ArgumentParser()
args = parse_arguments(parser)
data = np.load(args.data_file, allow_pickle=True)
inputDim = data.shape[1]
mg = Medgan(dataType=args.data_type,
inputDim=inputDim,
embeddingDim=args.embed_size,
randomDim=args.noise_size,
generatorDims=args.generator_size,
discriminatorDims=args.discriminator_size,
compressDims=args.compressor_size,
decompressDims=args.decompressor_size,
bnDecay=args.batchnorm_decay,
l2scale=args.L2)
# True for generation, False for training
if not args.generate_data:
# Training
mg.train(dataPath=args.data_file,
modelPath=args.model_file,
outPath=args.out_file,
pretrainEpochs=args.n_pretrain_epoch,
nEpochs=args.n_epoch,
discriminatorTrainPeriod=args.n_discriminator_update,
generatorTrainPeriod=args.n_generator_update,
pretrainBatchSize=args.pretrain_batch_size,
batchSize=args.batch_size,
saveMaxKeep=args.save_max_keep)
else:
# Generate synthetic data using a trained model
# You must specify "--model_file" and "<out_file>" to generate synthetic data.
mg.generateData(nSamples=10000,
modelFile=args.model_file,
batchSize=args.batch_size,
outFile=args.out_file)
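# Example invocations (hypothetical paths), assuming this file is saved as medgan.py:
#   Train:    python medgan.py patient_matrix.npy ./save/medgan
#   Generate: python medgan.py patient_matrix.npy synthetic.npy --model_file ./save/medgan-999 --generate_data True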