Diff of /medgan.py [000000] .. [bab239]

b/medgan.py
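"""medGAN: generate synthetic patient records with a GAN.

An autoencoder is first pre-trained on the real patient matrix. The GAN's
generator then learns to produce embeddings that the autoencoder's decoder
maps back to record space, while the discriminator tries to tell the decoded
fake records from real ones.
"""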
import sys, time, argparse
import tensorflow as tf
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score
from tensorflow.contrib.layers import l2_regularizer
from tensorflow.contrib.layers import batch_norm

_VALIDATION_RATIO = 0.1


class Medgan(object):
    def __init__(self,
                 dataType='binary',
                 inputDim=615,
                 embeddingDim=128,
                 randomDim=128,
                 generatorDims=(128, 128),
                 discriminatorDims=(256, 128, 1),
                 compressDims=(),
                 decompressDims=(),
                 bnDecay=0.99,
                 l2scale=0.001):
        self.inputDim = inputDim
        self.embeddingDim = embeddingDim
        # The generator always ends with a layer of size embeddingDim, so its
        # output lives in the autoencoder's embedding space.
        self.generatorDims = list(generatorDims) + [embeddingDim]
        self.randomDim = randomDim
        self.dataType = dataType

        # tanh activations for binary inputs, ReLU for count inputs.
        if dataType == 'binary':
            self.aeActivation = tf.nn.tanh
        else:
            self.aeActivation = tf.nn.relu

        self.generatorActivation = tf.nn.relu
        self.discriminatorActivation = tf.nn.relu
        self.discriminatorDims = discriminatorDims
        # The encoder ends at the embedding; the decoder ends back at inputDim.
        self.compressDims = list(compressDims) + [embeddingDim]
        self.decompressDims = list(decompressDims) + [inputDim]
        self.bnDecay = bnDecay
        self.l2scale = l2scale

    def loadData(self, dataPath=''):
        data = np.load(dataPath, allow_pickle=True)

        # Binary inputs are clipped to {0, 1}.
        if self.dataType == 'binary':
            data = np.clip(data, 0, 1)

        trainX, validX = train_test_split(data, test_size=_VALIDATION_RATIO, random_state=0)
        return trainX, validX

    def buildAutoencoder(self, x_input):
        decodeVariables = {}
        with tf.variable_scope('autoencoder', regularizer=l2_regularizer(self.l2scale)):
            tempVec = x_input
            tempDim = self.inputDim
            i = 0
            # Encoder: inputDim -> ... -> embeddingDim.
            for compressDim in self.compressDims:
                W = tf.get_variable('aee_W_'+str(i), shape=[tempDim, compressDim])
                b = tf.get_variable('aee_b_'+str(i), shape=[compressDim])
                tempVec = self.aeActivation(tf.add(tf.matmul(tempVec, W), b))
                tempDim = compressDim
                i += 1

            # Decoder: embeddingDim -> ... -> inputDim. The decoder weights are
            # collected so the GAN can reuse them to decode generator outputs.
            i = 0
            for decompressDim in self.decompressDims[:-1]:
                W = tf.get_variable('aed_W_'+str(i), shape=[tempDim, decompressDim])
                b = tf.get_variable('aed_b_'+str(i), shape=[decompressDim])
                tempVec = self.aeActivation(tf.add(tf.matmul(tempVec, W), b))
                tempDim = decompressDim
                decodeVariables['aed_W_'+str(i)] = W
                decodeVariables['aed_b_'+str(i)] = b
                i += 1
            W = tf.get_variable('aed_W_'+str(i), shape=[tempDim, self.decompressDims[-1]])
            b = tf.get_variable('aed_b_'+str(i), shape=[self.decompressDims[-1]])
            decodeVariables['aed_W_'+str(i)] = W
            decodeVariables['aed_b_'+str(i)] = b

            if self.dataType == 'binary':
                # Cross-entropy reconstruction loss for binary inputs.
                x_reconst = tf.nn.sigmoid(tf.add(tf.matmul(tempVec, W), b))
                loss = tf.reduce_mean(-tf.reduce_sum(x_input * tf.log(x_reconst + 1e-12) + (1. - x_input) * tf.log(1. - x_reconst + 1e-12), 1), 0)
            else:
                # Mean squared error for count inputs.
                x_reconst = tf.nn.relu(tf.add(tf.matmul(tempVec, W), b))
                loss = tf.reduce_mean((x_input - x_reconst)**2)

        return loss, decodeVariables

    def buildGenerator(self, x_input, bn_train):
        tempVec = x_input
        tempDim = self.randomDim
        with tf.variable_scope('generator', regularizer=l2_regularizer(self.l2scale)):
            for i, genDim in enumerate(self.generatorDims[:-1]):
                W = tf.get_variable('W_'+str(i), shape=[tempDim, genDim])
                h = tf.matmul(tempVec, W)
                h2 = batch_norm(h, decay=self.bnDecay, scale=True, is_training=bn_train, updates_collections=None)
                h3 = self.generatorActivation(h2)
                # Shortcut connection; this requires genDim == tempDim.
                tempVec = h3 + tempVec
                tempDim = genDim
            W = tf.get_variable('W'+str(i), shape=[tempDim, self.generatorDims[-1]])
            h = tf.matmul(tempVec, W)
            h2 = batch_norm(h, decay=self.bnDecay, scale=True, is_training=bn_train, updates_collections=None)

            # tanh matches the binary autoencoder's embedding range; ReLU for counts.
            if self.dataType == 'binary':
                h3 = tf.nn.tanh(h2)
            else:
                h3 = tf.nn.relu(h2)

            output = h3 + tempVec
        return output

    def buildGeneratorTest(self, x_input, bn_train):
        # Same network as buildGenerator, but batch_norm is created with
        # trainable=False for use at generation time.
        tempVec = x_input
        tempDim = self.randomDim
        with tf.variable_scope('generator', regularizer=l2_regularizer(self.l2scale)):
            for i, genDim in enumerate(self.generatorDims[:-1]):
                W = tf.get_variable('W_'+str(i), shape=[tempDim, genDim])
                h = tf.matmul(tempVec, W)
                h2 = batch_norm(h, decay=self.bnDecay, scale=True, is_training=bn_train, updates_collections=None, trainable=False)
                h3 = self.generatorActivation(h2)
                tempVec = h3 + tempVec
                tempDim = genDim
            W = tf.get_variable('W'+str(i), shape=[tempDim, self.generatorDims[-1]])
            h = tf.matmul(tempVec, W)
            h2 = batch_norm(h, decay=self.bnDecay, scale=True, is_training=bn_train, updates_collections=None, trainable=False)

            if self.dataType == 'binary':
                h3 = tf.nn.tanh(h2)
            else:
                h3 = tf.nn.relu(h2)

            output = h3 + tempVec
        return output

    def getDiscriminatorResults(self, x_input, keepRate, reuse=False):
        # Minibatch averaging: append the minibatch mean of the input to every
        # sample so the discriminator can use batch-level statistics.
        batchSize = tf.shape(x_input)[0]
        inputMean = tf.reshape(tf.tile(tf.reduce_mean(x_input, 0), [batchSize]), (batchSize, self.inputDim))
        tempVec = tf.concat([x_input, inputMean], 1)
        tempDim = self.inputDim * 2
        with tf.variable_scope('discriminator', reuse=reuse, regularizer=l2_regularizer(self.l2scale)):
            for i, discDim in enumerate(self.discriminatorDims[:-1]):
                W = tf.get_variable('W_'+str(i), shape=[tempDim, discDim])
                b = tf.get_variable('b_'+str(i), shape=[discDim])
                h = self.discriminatorActivation(tf.add(tf.matmul(tempVec, W), b))
                h = tf.nn.dropout(h, keepRate)
                tempVec = h
                tempDim = discDim
            W = tf.get_variable('W', shape=[tempDim, 1])
            b = tf.get_variable('b', shape=[1])
            y_hat = tf.squeeze(tf.nn.sigmoid(tf.add(tf.matmul(tempVec, W), b)))
        return y_hat

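    # Example of the minibatch-averaging input above: with
    # x_input = [[0., 1.], [1., 1.]] the minibatch mean is [0.5, 1.0], so the
    # discriminator actually sees [[0., 1., 0.5, 1.0], [1., 1., 0.5, 1.0]].
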
    def buildDiscriminator(self, x_real, x_fake, keepRate, decodeVariables, bn_train):
        # Discriminate the real samples
        y_hat_real = self.getDiscriminatorResults(x_real, keepRate, reuse=False)

        # Decompress with the shared decoder, then discriminate the fake samples
        tempVec = x_fake
        i = 0
        for _ in self.decompressDims[:-1]:
            tempVec = self.aeActivation(tf.add(tf.matmul(tempVec, decodeVariables['aed_W_'+str(i)]), decodeVariables['aed_b_'+str(i)]))
            i += 1

        if self.dataType == 'binary':
            x_decoded = tf.nn.sigmoid(tf.add(tf.matmul(tempVec, decodeVariables['aed_W_'+str(i)]), decodeVariables['aed_b_'+str(i)]))
        else:
            x_decoded = tf.nn.relu(tf.add(tf.matmul(tempVec, decodeVariables['aed_W_'+str(i)]), decodeVariables['aed_b_'+str(i)]))

        y_hat_fake = self.getDiscriminatorResults(x_decoded, keepRate, reuse=True)

        # Standard non-saturating GAN losses.
        loss_d = -tf.reduce_mean(tf.log(y_hat_real + 1e-12)) - tf.reduce_mean(tf.log(1. - y_hat_fake + 1e-12))
        loss_g = -tf.reduce_mean(tf.log(y_hat_fake + 1e-12))

        return loss_d, loss_g, y_hat_real, y_hat_fake

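    # In expectation form, the objectives built above are
    #   loss_d = -E[log D(x_real)] - E[log(1 - D(Dec(G(z))))]
    #   loss_g = -E[log D(Dec(G(z)))]
    # where Dec is the autoencoder's decoder applied to the generator output.
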
    def print2file(self, buf, outFile):
        with open(outFile, 'a') as outfd:
            outfd.write(buf + '\n')

    def generateData(self,
                     nSamples=100,
                     modelFile='model',
                     batchSize=100,
                     outFile='out'):
        # The autoencoder graph is built only to create the decoder variables;
        # its loss is unused here.
        x_dummy = tf.placeholder('float', [None, self.inputDim])
        _, decodeVariables = self.buildAutoencoder(x_dummy)
        x_random = tf.placeholder('float', [None, self.randomDim])
        bn_train = tf.placeholder('bool')
        x_emb = self.buildGeneratorTest(x_random, bn_train)
        tempVec = x_emb
        i = 0
        for _ in self.decompressDims[:-1]:
            tempVec = self.aeActivation(tf.add(tf.matmul(tempVec, decodeVariables['aed_W_'+str(i)]), decodeVariables['aed_b_'+str(i)]))
            i += 1

        if self.dataType == 'binary':
            x_reconst = tf.nn.sigmoid(tf.add(tf.matmul(tempVec, decodeVariables['aed_W_'+str(i)]), decodeVariables['aed_b_'+str(i)]))
        else:
            x_reconst = tf.nn.relu(tf.add(tf.matmul(tempVec, decodeVariables['aed_W_'+str(i)]), decodeVariables['aed_b_'+str(i)]))

        np.random.seed(1234)
        saver = tf.train.Saver()
        outputVec = []
        burn_in = 1000
        with tf.Session() as sess:
            saver.restore(sess, modelFile)
            # Burn-in: run the generator with bn_train=True so the batch-norm
            # moving averages adapt before sampling.
            print('burning in')
            for i in range(burn_in):
                randomX = np.random.normal(size=(batchSize, self.randomDim))
                output = sess.run(x_reconst, feed_dict={x_random:randomX, bn_train:True})

            print('generating')
            nBatches = int(np.ceil(float(nSamples) / float(batchSize)))
            for i in range(nBatches):
                randomX = np.random.normal(size=(batchSize, self.randomDim))
                output = sess.run(x_reconst, feed_dict={x_random:randomX, bn_train:False})
                outputVec.extend(output)

        outputMat = np.array(outputVec)
        np.save(outFile, outputMat)

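    # A minimal sketch of programmatic generation (illustrative paths; assumes
    # a checkpoint such as "medGAN-999" saved by train() with matching dims):
    #
    #   mg = Medgan(dataType='binary', inputDim=615)
    #   mg.generateData(nSamples=10000, modelFile='medGAN-999', outFile='synthetic')
    #
    # generateData writes the sampled matrix to "synthetic.npy" via np.save.
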
    def calculateDiscAuc(self, preds_real, preds_fake):
        # AUC of the discriminator, treating real samples as positives.
        preds = np.concatenate([preds_real, preds_fake], axis=0)
        labels = np.concatenate([np.ones((len(preds_real))), np.zeros((len(preds_fake)))], axis=0)
        auc = roc_auc_score(labels, preds)
        return auc

    def calculateDiscAccuracy(self, preds_real, preds_fake):
        # Fraction of samples the discriminator classifies correctly at a 0.5 threshold.
        total = len(preds_real) + len(preds_fake)
        hit = 0
        for pred in preds_real:
            if pred > 0.5: hit += 1
        for pred in preds_fake:
            if pred < 0.5: hit += 1
        acc = float(hit) / float(total)
        return acc

    def train(self,
              dataPath='data',
              modelPath='',
              outPath='out',
              nEpochs=500,
              discriminatorTrainPeriod=2,
              generatorTrainPeriod=1,
              pretrainBatchSize=100,
              batchSize=1000,
              pretrainEpochs=100,
              saveMaxKeep=0):
        x_raw = tf.placeholder('float', [None, self.inputDim])
        x_random = tf.placeholder('float', [None, self.randomDim])
        keep_prob = tf.placeholder('float')
        bn_train = tf.placeholder('bool')

        loss_ae, decodeVariables = self.buildAutoencoder(x_raw)
        x_fake = self.buildGenerator(x_random, bn_train)
        loss_d, loss_g, y_hat_real, y_hat_fake = self.buildDiscriminator(x_raw, x_fake, keep_prob, decodeVariables, bn_train)
        trainX, validX = self.loadData(dataPath)

        t_vars = tf.trainable_variables()
        ae_vars = [var for var in t_vars if 'autoencoder' in var.name]
        d_vars = [var for var in t_vars if 'discriminator' in var.name]
        g_vars = [var for var in t_vars if 'generator' in var.name]

        all_regs = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)

        optimize_ae = tf.train.AdamOptimizer().minimize(loss_ae + sum(all_regs), var_list=ae_vars)
        optimize_d = tf.train.AdamOptimizer().minimize(loss_d + sum(all_regs), var_list=d_vars)
        decodeVariablesValues = list(decodeVariables.values())
        # The generator update also fine-tunes the shared decoder weights.
        optimize_g = tf.train.AdamOptimizer().minimize(loss_g + sum(all_regs), var_list=g_vars+decodeVariablesValues)

        initOp = tf.global_variables_initializer()

        nBatches = int(np.ceil(float(trainX.shape[0]) / float(batchSize)))
        saver = tf.train.Saver(max_to_keep=saveMaxKeep)
        logFile = outPath + '.log'

        with tf.Session() as sess:
            if modelPath == '': sess.run(initOp)
            else: saver.restore(sess, modelPath)
            nTrainBatches = int(np.ceil(float(trainX.shape[0]) / float(pretrainBatchSize)))
            nValidBatches = int(np.ceil(float(validX.shape[0]) / float(pretrainBatchSize)))

            # Pre-train the autoencoder (only when starting from scratch).
            if modelPath == '':
                for epoch in range(pretrainEpochs):
                    idx = np.random.permutation(trainX.shape[0])
                    trainLossVec = []
                    for i in range(nTrainBatches):
                        batchX = trainX[idx[i*pretrainBatchSize:(i+1)*pretrainBatchSize]]
                        _, loss = sess.run([optimize_ae, loss_ae], feed_dict={x_raw:batchX})
                        trainLossVec.append(loss)
                    idx = np.random.permutation(validX.shape[0])
                    validLossVec = []
                    for i in range(nValidBatches):
                        batchX = validX[idx[i*pretrainBatchSize:(i+1)*pretrainBatchSize]]
                        loss = sess.run(loss_ae, feed_dict={x_raw:batchX})
                        validLossVec.append(loss)
                    validReverseLoss = 0.
                    buf = 'Pretrain_Epoch:%d, trainLoss:%f, validLoss:%f, validReverseLoss:%f' % (epoch, np.mean(trainLossVec), np.mean(validLossVec), validReverseLoss)
                    print(buf)
                    self.print2file(buf, logFile)

            # Adversarial training.
            idx = np.arange(trainX.shape[0])
            for epoch in range(nEpochs):
                d_loss_vec = []
                g_loss_vec = []
                for i in range(nBatches):
                    for _ in range(discriminatorTrainPeriod):
                        batchIdx = np.random.choice(idx, size=batchSize, replace=False)
                        batchX = trainX[batchIdx]
                        randomX = np.random.normal(size=(batchSize, self.randomDim))
                        _, discLoss = sess.run([optimize_d, loss_d], feed_dict={x_raw:batchX, x_random:randomX, keep_prob:1.0, bn_train:False})
                        d_loss_vec.append(discLoss)
                    for _ in range(generatorTrainPeriod):
                        randomX = np.random.normal(size=(batchSize, self.randomDim))
                        _, generatorLoss = sess.run([optimize_g, loss_g], feed_dict={x_raw:batchX, x_random:randomX, keep_prob:1.0, bn_train:True})
                        g_loss_vec.append(generatorLoss)

                # Evaluate the discriminator on the validation set. A separate
                # validIdx avoids clobbering the training index array.
                validIdx = np.arange(len(validX))
                nValidBatches = int(np.ceil(float(len(validX)) / float(batchSize)))
                validAccVec = []
                validAucVec = []
                for i in range(nValidBatches):
                    batchIdx = np.random.choice(validIdx, size=batchSize, replace=False)
                    batchX = validX[batchIdx]
                    randomX = np.random.normal(size=(batchSize, self.randomDim))
                    preds_real, preds_fake = sess.run([y_hat_real, y_hat_fake], feed_dict={x_raw:batchX, x_random:randomX, keep_prob:1.0, bn_train:False})
                    validAcc = self.calculateDiscAccuracy(preds_real, preds_fake)
                    validAuc = self.calculateDiscAuc(preds_real, preds_fake)
                    validAccVec.append(validAcc)
                    validAucVec.append(validAuc)
                buf = 'Epoch:%d, d_loss:%f, g_loss:%f, accuracy:%f, AUC:%f' % (epoch, np.mean(d_loss_vec), np.mean(g_loss_vec), np.mean(validAccVec), np.mean(validAucVec))
                print(buf)
                self.print2file(buf, logFile)
                savePath = saver.save(sess, outPath, global_step=epoch)
        print(savePath)

def str2bool(v):
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')

def parse_size(v):
    # argparse 'type=tuple' would split "128,128" into single characters;
    # parse a comma-separated list of ints instead. An empty string yields ().
    if isinstance(v, tuple): return v
    return tuple(int(dim) for dim in v.split(',')) if v else ()

def parse_arguments(parser):
    parser.add_argument('--embed_size', type=int, default=128, help='The dimension size of the embedding, which will be generated by the generator. (default value: 128)')
    parser.add_argument('--noise_size', type=int, default=128, help='The dimension size of the random noise, on which the generator is conditioned. (default value: 128)')
    parser.add_argument('--generator_size', type=parse_size, default=(128, 128), help='The hidden layer sizes of the generator, given as comma-separated ints. Note that another layer of size "--embed_size" is always added. (default value: "128,128")')
    parser.add_argument('--discriminator_size', type=parse_size, default=(256, 128, 1), help='The layer sizes of the discriminator, given as comma-separated ints. (default value: "256,128,1")')
    parser.add_argument('--compressor_size', type=parse_size, default=(), help='The layer sizes of the encoder of the autoencoder, given as comma-separated ints. Note that another layer of size "--embed_size" is always added. Therefore this can be empty. (default value: "")')
    parser.add_argument('--decompressor_size', type=parse_size, default=(), help='The layer sizes of the decoder of the autoencoder, given as comma-separated ints. Note that another layer, whose size is equal to the dimension of the <patient_matrix>, is always added. Therefore this can be empty. (default value: "")')
    parser.add_argument('--data_type', type=str, default='binary', choices=['binary', 'count'], help='The input data type. The <patient_matrix> could either contain binary values or count values. (default value: "binary")')
    parser.add_argument('--batchnorm_decay', type=float, default=0.99, help='Decay value for the moving average used in Batch Normalization. (default value: 0.99)')
    parser.add_argument('--L2', type=float, default=0.001, help='L2 regularization coefficient for all weights. (default value: 0.001)')

    parser.add_argument('data_file', type=str, metavar='<patient_matrix>', help='The path to the numpy matrix containing aggregated patient records.')
    parser.add_argument('out_file', type=str, metavar='<out_file>', help='The path to the output models.')
    parser.add_argument('--model_file', type=str, metavar='<model_file>', default='', help="The path to the model file, in case you want to continue training. (default value: '')")
    parser.add_argument('--n_pretrain_epoch', type=int, default=100, help='The number of epochs to pre-train the autoencoder. (default value: 100)')
    parser.add_argument('--n_epoch', type=int, default=1000, help='The number of epochs to train medGAN. (default value: 1000)')
    parser.add_argument('--n_discriminator_update', type=int, default=2, help='The number of discriminator updates per mini-batch. (default value: 2)')
    parser.add_argument('--n_generator_update', type=int, default=1, help='The number of generator updates per mini-batch. (default value: 1)')
    parser.add_argument('--pretrain_batch_size', type=int, default=100, help='The size of a single mini-batch for pre-training the autoencoder. (default value: 100)')
    parser.add_argument('--batch_size', type=int, default=1000, help='The size of a single mini-batch for training medGAN. (default value: 1000)')
    parser.add_argument('--save_max_keep', type=int, default=0, help='The number of models to keep. Setting this to 0 will keep the models from every epoch. (default value: 0)')
    parser.add_argument('--generate_data', type=str2bool, default=False, help='If True the model generates data, if False the model is trained. (default value: False)')
    args = parser.parse_args()
    return args


if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    args = parse_arguments(parser)

    data = np.load(args.data_file, allow_pickle=True)
    inputDim = data.shape[1]

    mg = Medgan(dataType=args.data_type,
                inputDim=inputDim,
                embeddingDim=args.embed_size,
                randomDim=args.noise_size,
                generatorDims=args.generator_size,
                discriminatorDims=args.discriminator_size,
                compressDims=args.compressor_size,
                decompressDims=args.decompressor_size,
                bnDecay=args.batchnorm_decay,
                l2scale=args.L2)

    # True for generation, False for training
    if not args.generate_data:
        # Training
        mg.train(dataPath=args.data_file,
                 modelPath=args.model_file,
                 outPath=args.out_file,
                 pretrainEpochs=args.n_pretrain_epoch,
                 nEpochs=args.n_epoch,
                 discriminatorTrainPeriod=args.n_discriminator_update,
                 generatorTrainPeriod=args.n_generator_update,
                 pretrainBatchSize=args.pretrain_batch_size,
                 batchSize=args.batch_size,
                 saveMaxKeep=args.save_max_keep)
    else:
        # Generate synthetic data using a trained model.
        # "--model_file" and "<out_file>" must be specified.
        mg.generateData(nSamples=10000,
                        modelFile=args.model_file,
                        batchSize=args.batch_size,
                        outFile=args.out_file)
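
# Example invocations (file names are illustrative):
#
#   Train on a patient matrix, saving checkpoints and a log under "medGAN":
#     python medgan.py patient_matrix.npy medGAN
#
#   Resume training from a saved checkpoint:
#     python medgan.py patient_matrix.npy medGAN --model_file medGAN-499
#
#   Generate synthetic records with a trained model (writes synthetic.npy):
#     python medgan.py patient_matrix.npy synthetic --model_file medGAN-499 --generate_data True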