# deepdta-toy/run_experiments.py
1
from __future__ import print_function
2
#import matplotlib
3
#matplotlib.use('Agg')
4
import numpy as np
5
import tensorflow as tf
6
import random as rn
7
8
### We modified Pahikkala et al. (2014) source code for cross-val process ###
9
10
import os
11
# --- Reproducibility setup -------------------------------------------------
# Fix the seeds of every random number generator in play (Python hashing,
# NumPy, the `random` module and TensorFlow) and run TensorFlow on a single
# thread. The statements are order-sensitive: the NumPy/`random` seeds are
# set before Keras is imported.
# NOTE(review): PYTHONHASHSEED is normally read at interpreter start-up;
# setting it here probably only affects child processes -- confirm.
os.environ['PYTHONHASHSEED'] = '0'

np.random.seed(1)
rn.seed(1)

session_conf = tf.ConfigProto(
    intra_op_parallelism_threads=1,
    inter_op_parallelism_threads=1,
)
import keras
from keras import backend as K
tf.set_random_seed(0)
# Hand Keras a session bound to the default graph and the config above.
sess = tf.Session(graph=tf.get_default_graph(), config=session_conf)
K.set_session(sess)
22
23
24
from datahelper import *
25
#import logging
26
from itertools import product
27
from arguments import argparser, logging
28
29
import keras
30
from keras.models import Model
31
from keras.preprocessing import sequence
32
from keras.models import Sequential, load_model
33
from keras.layers import Dense, Dropout, Activation, Merge
34
from keras.layers import Embedding
35
from keras.layers import Conv1D, GlobalMaxPooling1D, MaxPooling1D
36
from keras.layers.normalization import BatchNormalization
37
from keras.layers import Conv2D, GRU
38
from keras.layers import Input, Embedding, LSTM, Dense, TimeDistributed, Masking, RepeatVector, merge, Flatten
39
from keras.models import Model
40
from keras.utils import plot_model
41
from keras.layers import Bidirectional
42
from keras.callbacks import ModelCheckpoint, EarlyStopping
43
from keras import optimizers, layers
44
45
46
import sys, pickle, os
47
import math, json, time
48
import decimal
49
import matplotlib.pyplot as plt
50
import matplotlib.mlab as mlab
51
from random import shuffle
52
from copy import deepcopy
53
from sklearn import preprocessing
54
from emetrics import get_aupr, get_cindex, get_rm2
55
import pandas as pd
56
from testdatahelper import *
57
58
59
60
# Directory that experiment() creates before training starts.
figdir = "figures/"
# Tab separator constant.
TABSY = "\t"
62
63
def build_combined_onehot(FLAGS, NUM_FILTERS, FILTER_LENGTH1, FILTER_LENGTH2):
    """Build and compile the two-branch CNN for one-hot encoded inputs.

    Each branch stacks three Conv1D layers (NUM_FILTERS, 2*NUM_FILTERS and
    3*NUM_FILTERS filters) followed by global max pooling. The pooled
    vectors are concatenated and fed to a 1024-1024-512 dense stack with
    dropout, ending in a single linear output unit.

    Args:
        FLAGS: settings object; reads max_smi_len, charsmiset_size,
            max_seq_len and charseqset_size.
        NUM_FILTERS: base number of convolution filters.
        FILTER_LENGTH1: kernel size of the drug (SMILES) branch.
        FILTER_LENGTH2: kernel size of the protein branch.

    Returns:
        The compiled Keras model (adam, mean squared error, cindex_score).
    """
    XDinput = Input(shape=(FLAGS.max_smi_len, FLAGS.charsmiset_size))
    XTinput = Input(shape=(FLAGS.max_seq_len, FLAGS.charseqset_size))

    # Drug branch: three stacked convolutions with growing filter counts.
    smiles_branch = XDinput
    for multiplier in (1, 2, 3):
        smiles_branch = Conv1D(filters=NUM_FILTERS * multiplier, kernel_size=FILTER_LENGTH1,
                               activation='relu', padding='valid', strides=1)(smiles_branch)
    smiles_branch = GlobalMaxPooling1D()(smiles_branch)

    # Protein branch: same layout, with its own kernel size.
    protein_branch = XTinput
    for multiplier in (1, 2, 3):
        protein_branch = Conv1D(filters=NUM_FILTERS * multiplier, kernel_size=FILTER_LENGTH2,
                                activation='relu', padding='valid', strides=1)(protein_branch)
    protein_branch = GlobalMaxPooling1D()(protein_branch)

    hidden = keras.layers.concatenate([smiles_branch, protein_branch])

    # Fully connected part: Dense -> Dropout twice, then a narrower Dense.
    for units in (1024, 1024):
        hidden = Dense(units, activation='relu')(hidden)
        hidden = Dropout(0.1)(hidden)
    hidden = Dense(512, activation='relu')(hidden)

    # Single linear unit: the model regresses a real-valued score.
    predictions = Dense(1, kernel_initializer='normal')(hidden)

    interactionModel = Model(inputs=[XDinput, XTinput], outputs=[predictions])
    interactionModel.compile(optimizer='adam', loss='mean_squared_error', metrics=[cindex_score])

    print(interactionModel.summary())
    plot_model(interactionModel, to_file='figures/build_combined_onehot.png')

    return interactionModel
102
103
104
105
106
107
def build_combined_categorical(FLAGS, NUM_FILTERS, FILTER_LENGTH1, FILTER_LENGTH2):
    """Build and compile the two-branch CNN for integer-label encoded inputs.

    Each input sequence of integer labels goes through a 128-dimensional
    Embedding, three Conv1D layers (NUM_FILTERS, 2*NUM_FILTERS and
    3*NUM_FILTERS filters) and global max pooling. The two pooled vectors
    are concatenated and passed through a 1024-1024-512 dense stack with
    dropout to a single linear output unit.

    Args:
        FLAGS: settings object; reads max_smi_len, charsmiset_size,
            max_seq_len and charseqset_size.
        NUM_FILTERS: base number of convolution filters.
        FILTER_LENGTH1: kernel size of the drug (SMILES) branch.
        FILTER_LENGTH2: kernel size of the protein branch.

    Returns:
        The compiled Keras model (adam, mean squared error, cindex_score).
    """
    # The input sizes are taken from FLAGS.
    XDinput = Input(shape=(FLAGS.max_smi_len,), dtype='int32')
    XTinput = Input(shape=(FLAGS.max_seq_len,), dtype='int32')

    # TODO: the embedding dimension (128) should also come from FLAGS.
    smiles_branch = Embedding(input_dim=FLAGS.charsmiset_size+1, output_dim=128,
                              input_length=FLAGS.max_smi_len)(XDinput)
    for multiplier in (1, 2, 3):
        smiles_branch = Conv1D(filters=NUM_FILTERS * multiplier, kernel_size=FILTER_LENGTH1,
                               activation='relu', padding='valid', strides=1)(smiles_branch)
    smiles_branch = GlobalMaxPooling1D()(smiles_branch)

    protein_branch = Embedding(input_dim=FLAGS.charseqset_size+1, output_dim=128,
                               input_length=FLAGS.max_seq_len)(XTinput)
    for multiplier in (1, 2, 3):
        protein_branch = Conv1D(filters=NUM_FILTERS * multiplier, kernel_size=FILTER_LENGTH2,
                                activation='relu', padding='valid', strides=1)(protein_branch)
    protein_branch = GlobalMaxPooling1D()(protein_branch)

    hidden = keras.layers.concatenate([smiles_branch, protein_branch], axis=-1)

    # Fully connected part: Dense -> Dropout twice, then a narrower Dense.
    for units in (1024, 1024):
        hidden = Dense(units, activation='relu')(hidden)
        hidden = Dropout(0.1)(hidden)
    hidden = Dense(512, activation='relu')(hidden)

    # Linear output unit (no activation), so predictions are unbounded.
    predictions = Dense(1, kernel_initializer='normal')(hidden)

    interactionModel = Model(inputs=[XDinput, XTinput], outputs=[predictions])

    interactionModel.compile(optimizer='adam', loss='mean_squared_error', metrics=[cindex_score])
    print(interactionModel.summary())
    plot_model(interactionModel, to_file='figures/build_combined_categorical.png')

    return interactionModel
147
148
149
150
def build_single_drug(FLAGS, NUM_FILTERS, FILTER_LENGTH1, FILTER_LENGTH2):
    """Build and compile a model that learns features for the drug side only.

    The drug input goes through an Embedding + three Conv1D layers + global
    max pooling; the target input (a vector of length FLAGS.target_count) is
    passed through unchanged. Both are concatenated with the legacy `Merge`
    layer and fed to a 1024-1024-512 dense stack with a linear output unit.

    Args:
        FLAGS: settings object; reads target_count, charsmiset_size and
            max_smi_len.
        NUM_FILTERS: base number of convolution filters.
        FILTER_LENGTH1: kernel size of the drug (SMILES) branch.
        FILTER_LENGTH2: not used by this builder; present so that all
            builders share one signature.

    Returns:
        The compiled Keras Sequential model.
    """
    interactionModel = Sequential()

    # Identity pass-through for the target representation.
    XTmodel = Sequential()
    XTmodel.add(Activation('linear', input_shape=(FLAGS.target_count,)))

    encode_smiles = Sequential()
    encode_smiles.add(Embedding(input_dim=FLAGS.charsmiset_size+1, output_dim=128,
                                input_length=FLAGS.max_smi_len))
    for multiplier in (1, 2, 3):
        encode_smiles.add(Conv1D(filters=NUM_FILTERS * multiplier, kernel_size=FILTER_LENGTH1,
                                 activation='relu', padding='valid', strides=1))
    encode_smiles.add(GlobalMaxPooling1D())

    interactionModel.add(Merge([encode_smiles, XTmodel], mode='concat', concat_axis=1))

    # Fully connected part: Dense -> Dropout twice, then a narrower Dense.
    for units in (1024, 1024):
        interactionModel.add(Dense(units, activation='relu'))
        interactionModel.add(Dropout(0.1))
    interactionModel.add(Dense(512, activation='relu'))

    interactionModel.add(Dense(1, kernel_initializer='normal'))
    interactionModel.compile(optimizer='adam', loss='mean_squared_error', metrics=[cindex_score])

    print(interactionModel.summary())
    plot_model(interactionModel, to_file='figures/build_single_drug.png')

    return interactionModel
183
184
185
def build_single_prot(FLAGS, NUM_FILTERS, FILTER_LENGTH1, FILTER_LENGTH2):
    """Build and compile a model that learns features for the protein side only.

    The protein input goes through an Embedding + three Conv1D layers +
    global max pooling; the drug input (a vector whose length is the drug
    count) is passed through unchanged. Both are concatenated with the
    legacy `Merge` layer and fed to a 1024-1024-512 dense stack with a
    linear output unit.

    Args:
        FLAGS: settings object; reads drug_count, charseqset_size and
            max_seq_len.
        NUM_FILTERS: base number of convolution filters.
        FILTER_LENGTH1: not used by this builder; present so that all
            builders share one signature.
        FILTER_LENGTH2: kernel size of the protein branch.

    Returns:
        The compiled Keras Sequential model.
    """
    # experiment() stores the number of drugs as FLAGS.drug_count (the name
    # build_baseline also reads). The previous spelling `drugcount` is kept
    # only as a fallback for callers that still set it.
    drug_count = getattr(FLAGS, 'drug_count', None)
    if drug_count is None:
        drug_count = FLAGS.drugcount

    interactionModel = Sequential()

    # Identity pass-through for the drug representation.
    XDmodel = Sequential()
    XDmodel.add(Activation('linear', input_shape=(drug_count,)))

    XTmodel1 = Sequential()
    XTmodel1.add(Embedding(input_dim=FLAGS.charseqset_size+1, output_dim=128,
                           input_length=FLAGS.max_seq_len))
    XTmodel1.add(Conv1D(filters=NUM_FILTERS, kernel_size=FILTER_LENGTH2,
                        activation='relu', padding='valid', strides=1))
    XTmodel1.add(Conv1D(filters=NUM_FILTERS*2, kernel_size=FILTER_LENGTH2,
                        activation='relu', padding='valid', strides=1))
    XTmodel1.add(Conv1D(filters=NUM_FILTERS*3, kernel_size=FILTER_LENGTH2,
                        activation='relu', padding='valid', strides=1))
    XTmodel1.add(GlobalMaxPooling1D())

    interactionModel.add(Merge([XDmodel, XTmodel1], mode='concat', concat_axis=1))

    # Fully connected part.
    interactionModel.add(Dense(1024, activation='relu'))
    interactionModel.add(Dropout(0.1))
    interactionModel.add(Dense(1024, activation='relu'))
    interactionModel.add(Dropout(0.1))
    interactionModel.add(Dense(512, activation='relu'))

    interactionModel.add(Dense(1, kernel_initializer='normal'))
    interactionModel.compile(optimizer='adam', loss='mean_squared_error', metrics=[cindex_score])

    print(interactionModel.summary())
    plot_model(interactionModel, to_file='figures/build_single_protein.png')

    return interactionModel
216
217
def build_baseline(FLAGS, NUM_FILTERS, FILTER_LENGTH1, FILTER_LENGTH2):
    """Build and compile the baseline model without convolutional encoders.

    The drug vector (length FLAGS.drug_count) and the target vector (length
    FLAGS.target_count) are each reduced to one value by a linear Dense
    layer, concatenated with the legacy `Merge` layer and fed to a
    1024-1024-512 dense stack with a linear output unit.

    Args:
        FLAGS: settings object; reads drug_count and target_count.
        NUM_FILTERS, FILTER_LENGTH1, FILTER_LENGTH2: not used by this
            builder; present so that all builders share one signature.

    Returns:
        The compiled Keras Sequential model.
    """
    interactionModel = Sequential()

    XDmodel = Sequential()
    XDmodel.add(Dense(1, activation='linear', input_shape=(FLAGS.drug_count, )))

    XTmodel = Sequential()
    XTmodel.add(Dense(1, activation='linear', input_shape=(FLAGS.target_count,)))

    interactionModel.add(Merge([XDmodel, XTmodel], mode='concat', concat_axis=1))

    # Fully connected part: Dense -> Dropout twice, then a narrower Dense.
    for units in (1024, 1024):
        interactionModel.add(Dense(units, activation='relu'))
        interactionModel.add(Dropout(0.1))
    interactionModel.add(Dense(512, activation='relu'))

    interactionModel.add(Dense(1, kernel_initializer='normal'))
    interactionModel.compile(optimizer='adam', loss='mean_squared_error', metrics=[cindex_score])

    print(interactionModel.summary())
    plot_model(interactionModel, to_file='figures/build_baseline.png')

    return interactionModel
243
244
def nfold_1_2_3_setting_sample(tr_XD, tr_XT,  tr_Y, te_XD, te_XT, te_Y,  measure, runmethod,  FLAGS, dataset):
    """Run the parameter search on a single train/test split and report it.

    The index sets are read from `dataset.read_sets(FLAGS)`, handed to
    general_nfold_cv together with the data, and the score and loss of the
    best parameter combination are logged.

    Args:
        tr_XD, tr_XT, tr_Y: training drugs, targets and affinity matrix.
        te_XD, te_XT, te_Y: test drugs, targets and affinity matrix.
        measure: performance function forwarded to general_nfold_cv.
        runmethod: model builder forwarded to general_nfold_cv.
        FLAGS: settings object, also used by `logging`.
        dataset: object providing `read_sets(FLAGS)`.

    Returns:
        Tuple (avgperf, avgloss, teststd). With a single split, the standard
        deviation is taken over one value.
    """
    # Single train/test split: the training indices are used as they come,
    # without flattening a list of folds.
    test_set, train_set = dataset.read_sets(FLAGS)

    bestparamind, _, _, all_predictions, all_losses = general_nfold_cv(
        tr_XD, tr_XT, tr_Y, te_XD, te_XT, te_Y,
        measure, runmethod, FLAGS, train_set, test_set)

    # Score of the winning parameter combination.
    foldperf = all_predictions[bestparamind]

    logging("---FINAL RESULTS-----", FLAGS)
    logging("best param index = %s" % bestparamind, FLAGS)

    foldloss = all_losses[bestparamind]
    testperfs = [foldperf]
    testloss = [foldloss]

    # One split only, so the mean is over a single entry.
    avgperf = sum(testperfs, 0.) / 1
    avgloss = np.mean(testloss)
    teststd = np.std(testperfs)

    logging("Test Performance CI", FLAGS)
    logging(testperfs, FLAGS)
    logging("Test Performance MSE", FLAGS)
    logging(testloss, FLAGS)

    return avgperf, avgloss, teststd
284
285
286
287
288
def general_nfold_cv(tr_XD, tr_XT,  tr_Y, te_XD, te_XT, te_Y,  prfmeasure, runmethod, FLAGS, labeled_sets, val_sets):
    """Grid-search the model hyper-parameters on one train/evaluation split.

    For every combination of FLAGS.num_windows, FLAGS.smi_window_lengths and
    FLAGS.seq_window_lengths a fresh model is built with `runmethod`, fitted
    on the labelled training pairs and scored on the evaluation pairs. The
    predictions of each combination are written to
    "predicted_labels_<k>.txt" as JSON, where <k> is the running index of
    the combination.

    Args:
        tr_XD, tr_XT, tr_Y: training drugs, targets and affinity matrix;
            NaN entries of tr_Y are treated as unknown and skipped.
        te_XD, te_XT, te_Y: the same for the evaluation data.
        prfmeasure: callable `(true_values, predictions) -> score`, where a
            higher score is better.
        runmethod: callable `(FLAGS, p1, p2, p3)` returning a compiled model.
        FLAGS: settings object; also reads num_epoch and batch_size.
        labeled_sets: positions (into the non-NaN entries of tr_Y) used for
            training.
        val_sets: positions (into the non-NaN entries of te_Y) used for
            evaluation.

    Returns:
        Tuple (bestpointer, best_param_list, bestperf, all_predictions,
        all_losses). `bestpointer` is None and `best_param_list` is empty
        when no combination scored above -inf.
    """
    paramset1 = FLAGS.num_windows           # numbers of convolution filters
    paramset2 = FLAGS.smi_window_lengths    # kernel sizes, drug branch
    paramset3 = FLAGS.seq_window_lengths    # kernel sizes, protein branch
    epoch = FLAGS.num_epoch
    batchsz = FLAGS.batch_size

    logging("---Parameter Search-----", FLAGS)

    # Coordinates of the known (non-NaN) affinities in each matrix.
    tr_label_row_inds, tr_label_col_inds = np.where(np.isnan(tr_Y)==False)
    te_label_row_inds, te_label_col_inds = np.where(np.isnan(te_Y)==False)

    train_drugs, train_prots, train_Y = prepare_interaction_pairs(
        tr_XD, tr_XT, tr_Y,
        tr_label_row_inds[labeled_sets], tr_label_col_inds[labeled_sets])
    val_drugs, val_prots, val_Y = prepare_interaction_pairs(
        te_XD, te_XT, te_Y,
        te_label_row_inds[val_sets], te_label_col_inds[val_sets])

    # The arrays are identical for every grid point, so build them once.
    train_inputs = [np.array(train_drugs), np.array(train_prots)]
    train_targets = np.array(train_Y)
    val_inputs = [np.array(val_drugs), np.array(val_prots)]
    val_targets = np.array(val_Y)

    # Every grid point as ((index, value), (index, value), (index, value)).
    grid = list(product(enumerate(paramset1), enumerate(paramset2), enumerate(paramset3)))
    all_predictions = [0] * len(grid)
    all_losses = [0] * len(grid)

    for pointer, ((param1ind, param1value),
                  (param2ind, param2value),
                  (param3ind, param3value)) in enumerate(grid):

        gridmodel = runmethod(FLAGS, param1value, param2value, param3value)

        # No validation data is passed to fit(), so early stopping on
        # 'val_loss' cannot be used; training runs for all epochs.
        gridmodel.fit(train_inputs, train_targets, batch_size=batchsz, epochs=epoch,
                      shuffle=False)

        predicted_labels = gridmodel.predict(val_inputs)
        # Context manager so the file is flushed and closed right away.
        with open("predicted_labels_"+str(pointer)+ ".txt", "w") as outfile:
            json.dump(predicted_labels.tolist(), outfile)

        loss, rperf2 = gridmodel.evaluate(val_inputs, val_targets, verbose=0)
        rperf = prfmeasure(val_Y, predicted_labels)

        logging("P1 = %d,  P2 = %d, P3 = %d,  CI-i = %f, CI-ii = %f, MSE = %f" %
                (param1ind, param2ind, param3ind,  rperf, rperf2, loss), FLAGS)

        all_predictions[pointer] = rperf
        all_losses[pointer] = loss

    # Pick the combination with the highest score; ties keep the earliest.
    bestperf = -float('Inf')
    bestpointer = None
    best_param_list = []
    for pointer, indexed_values in enumerate(grid):
        perf = all_predictions[pointer]
        if perf > bestperf:
            bestperf = perf
            bestpointer = pointer
            best_param_list = [index for index, _ in indexed_values]

    return  bestpointer, best_param_list, bestperf, all_predictions, all_losses
401
402
403
404
def cindex_score(y_true, y_pred):
    """Concordance-index style metric built from TensorFlow ops.

    Over the pairs whose true values are strictly ordered, the metric sums
    1 for each pair whose predictions are ordered the same way and 0.5 for
    each pair whose predictions are tied, then divides by the number of
    such pairs.

    Args:
        y_true: tensor of true values (presumably shape (batch, 1) as
            supplied by Keras -- TODO confirm).
        y_pred: tensor of predicted values, same shape as `y_true`.

    Returns:
        Scalar tensor with the score; 0.0 when the numerator is zero.
    """
    # Pairwise prediction differences.
    g = tf.subtract(tf.expand_dims(y_pred, -1), y_pred)
    # tf.equal is required here: with TensorFlow 1.x, `g == 0.0` compares
    # Python object identity and is always False, so tied predictions would
    # never receive their 0.5 credit.
    g = tf.cast(tf.equal(g, 0.0), tf.float32) * 0.5 + tf.cast(g > 0.0, tf.float32)

    # Mask of the pairs whose true values are strictly ordered.
    f = tf.subtract(tf.expand_dims(y_true, -1), y_true) > 0.0
    f = tf.matrix_band_part(tf.cast(f, tf.float32), -1, 0)

    g = tf.reduce_sum(tf.multiply(g, f))
    f = tf.reduce_sum(f)

    # Return 0.0 instead of g/f when the numerator is zero.
    return tf.where(tf.equal(g, 0), 0.0, g/f) #select
416
417
418
   
419
def _save_history_curves(curves, keys, labels, title, ylabel, outpath):
    """Plot the series of `curves` named in `keys` (if present) and save the figure."""
    plt.figure()
    shown_labels = []
    for key, label in zip(keys, labels):
        # Validation series only exist when fit() was given validation data.
        if key in curves:
            plt.plot(curves[key])
            shown_labels.append(label)
    plt.title(title)
    plt.ylabel(ylabel)
    plt.xlabel('epoch')
    plt.legend(shown_labels, loc='upper left')
    # `papertype` and `frameon` were previously passed with their default
    # value None; they are left out because newer Matplotlib releases no
    # longer accept them.
    plt.savefig(outpath, dpi=None, facecolor='w', edgecolor='w', orientation='portrait',
                format=None, transparent=False, bbox_inches=None, pad_inches=0.1)
    plt.close()


def plotLoss(history, batchind, epochind, param3ind, foldind):
    """Save the loss and concordance-index training curves as PNG files.

    Two files are written under "figures/": "<name>.png" with the loss
    curves and "<name>_acc.png" with the cindex_score curves, where <name>
    is built from the four index arguments and the current time. Series
    that are missing from `history.history` (for example the validation
    ones when training ran without validation data) are skipped.

    Args:
        history: object with a `history` dict of per-epoch series, as
            returned by a Keras `fit` call.
        batchind, epochind, param3ind, foldind: values used only to build
            the file name.
    """
    figname = "b"+str(batchind) + "_e" + str(epochind) + "_" + str(param3ind) + "_"  + str( foldind) + "_" + str(time.time())
    curves = history.history

    _save_history_curves(curves, ('loss', 'val_loss'), ('trainloss', 'valloss'),
                         'model loss', 'loss', "figures/"+figname +".png")

    _save_history_curves(curves, ('cindex_score', 'val_cindex_score'),
                         ('traincindex', 'valcindex'),
                         'model concordance index', 'cindex',
                         "figures/"+figname + "_acc.png")
446
447
448
449
def prepare_interaction_pairs(XD, XT,  Y, rows, cols):
    """Gather the drug, target and affinity of every (row, col) pair.

    Pair i consists of drug `XD[rows[i]]`, target `XT[cols[i]]` and the
    affinity `Y[rows[i], cols[i]]`.

    Args:
        XD: array-like of drug representations, indexed along axis 0.
        XT: array-like of target representations, indexed along axis 0.
        Y: 2-D array-like of affinities, indexed as [drug, target].
        rows: drug indices, one per pair.
        cols: target indices, one per pair; same length as `rows`.

    Returns:
        Tuple (drug_data, target_data, affinity): two NumPy arrays with one
        leading entry per pair, and a list of affinity values. Empty index
        sequences give empty arrays and an empty list.

    Raises:
        ValueError: if `rows` and `cols` differ in length.
    """
    rows = np.asarray(rows, dtype=np.intp)
    cols = np.asarray(cols, dtype=np.intp)
    if rows.shape != cols.shape:
        raise ValueError("rows and cols must have the same length, got %d and %d"
                         % (rows.size, cols.size))

    # Integer-array indexing gathers all pairs at once and, unlike
    # np.stack on an empty list, also works when there are no pairs.
    drug_data = np.asarray(XD)[rows]
    target_data = np.asarray(XT)[cols]
    affinity = list(np.asarray(Y)[rows, cols])

    return drug_data,target_data,  affinity
469
470
471
       
472
def experiment(FLAGS, perfmeasure, deepmethod, foldcount=6): #5-fold cross validation + test
    """Load the train/test data, run the parameter search and log the result.

    Args:
        FLAGS: settings object. Reads train_path, test_path, problem_type,
            max_seq_len, max_smi_len and log_dir; receives charseqset_size,
            charsmiset_size, drug_count and target_count.
        perfmeasure: function taking the true and the predicted values and
            returning a performance score where higher is better (for an
            error measure pass e.g. its negative).
        deepmethod: model builder called for every parameter combination.
        foldcount: not used in this single train/test variant.
    """
    dataset = DataSet(fpath=FLAGS.train_path,
                      fpath_test=FLAGS.test_path,
                      setting_no=FLAGS.problem_type,
                      seqlen=FLAGS.max_seq_len,
                      smilen=FLAGS.max_smi_len,
                      need_shuffle=False)

    # The model builders need the character-set sizes of the data set.
    FLAGS.charseqset_size = dataset.charseqset_size
    FLAGS.charsmiset_size = dataset.charsmiset_size

    tr_XD, tr_XT, tr_Y, te_XD, te_XT, te_Y = dataset.parse_train_test_data(FLAGS)
    tr_XD, tr_XT, tr_Y = np.asarray(tr_XD), np.asarray(tr_XT), np.asarray(tr_Y)
    te_XD, te_XT, te_Y = np.asarray(te_XD), np.asarray(te_XT), np.asarray(te_Y)

    tr_drugcount, tr_targetcount = tr_XD.shape[0], tr_XT.shape[0]
    te_drugcount, te_targetcount = te_XD.shape[0], te_XT.shape[0]
    for caption, count in (("train drugs: ", tr_drugcount),
                           ("train targets: ", tr_targetcount),
                           ("test drugs: ", te_drugcount),
                           ("test targets: ", te_targetcount)):
        print(caption, count)

    # The training-set sizes are stored on FLAGS for the model builders.
    FLAGS.drug_count = tr_drugcount
    FLAGS.target_count = tr_targetcount

    if not os.path.exists(figdir):
        os.makedirs(figdir)

    print(FLAGS.log_dir)
    S1_avgperf, S1_avgloss, S1_teststd = nfold_1_2_3_setting_sample(
        tr_XD, tr_XT, tr_Y, te_XD, te_XT, te_Y,
        perfmeasure, deepmethod, FLAGS, dataset)

    logging("Setting " + str(FLAGS.problem_type), FLAGS)
    summary = "avg_perf = %.5f,  avg_mse = %.5f, std = %.5f" % (
        S1_avgperf, S1_avgloss, S1_teststd)
    logging(summary, FLAGS)
529
530
531
532
533
def run_regression( FLAGS ):
    """Run the experiment with get_cindex as score and the categorical-input model."""
    experiment(FLAGS, get_cindex, build_combined_categorical)
539
540
541
542
543
if __name__=="__main__":
    # Parse the command line and give this run its own time-stamped log
    # directory.
    FLAGS = argparser()
    FLAGS.log_dir = "".join((FLAGS.log_dir, str(time.time()), "/"))

    log_dir_exists = os.path.exists(FLAGS.log_dir)
    if not log_dir_exists:
        os.makedirs(FLAGS.log_dir)

    # Preprocess the new test data. When new training data is supplied as
    # well, also call: prepare_new_data(FLAGS.train_path, test=False)
    prepare_new_data(FLAGS.test_path, test=True)

    logging(str(FLAGS), FLAGS)
    run_regression( FLAGS )