Diff of /source/run_experiments.py [000000] .. [8af014]

Switch to unified view

a b/source/run_experiments.py
1
from __future__ import print_function
2
#import matplotlib
3
#matplotlib.use('Agg')
4
import numpy as np
5
import tensorflow as tf
6
import random as rn
7
8
### We modified Pahikkala et al. (2014) source code for cross-val process ###
9
10
import os
11
os.environ['PYTHONHASHSEED'] = '0'
12
13
np.random.seed(1)
14
rn.seed(1)
15
16
session_conf = tf.ConfigProto(intra_op_parallelism_threads=1, inter_op_parallelism_threads=1)
17
import keras
18
from keras import backend as K
19
tf.set_random_seed(0)
20
sess = tf.Session(graph=tf.get_default_graph(), config=session_conf)
21
K.set_session(sess)
22
23
24
from datahelper import *
25
#import logging
26
from itertools import product
27
from arguments import argparser, logging
28
29
import keras
30
from keras.models import Model
31
from keras.preprocessing import sequence
32
from keras.models import Sequential, load_model
33
from keras.layers import Dense, Dropout, Activation
34
from keras.layers import Embedding
35
from keras.layers import Conv1D, GlobalMaxPooling1D, MaxPooling1D
36
from keras.layers.normalization import BatchNormalization
37
from keras.layers import Conv2D, GRU
38
from keras.layers import Input, Embedding, LSTM, Dense, TimeDistributed, Masking, RepeatVector, merge, Flatten
39
from keras.models import Model
40
from keras.utils import plot_model
41
from keras.layers import Bidirectional
42
from keras.callbacks import ModelCheckpoint, EarlyStopping
43
from keras import optimizers, layers
44
45
46
import sys, pickle, os
47
import math, json, time
48
import decimal
49
import matplotlib.pyplot as plt
50
import matplotlib.mlab as mlab
51
from random import shuffle
52
from copy import deepcopy
53
from sklearn import preprocessing
54
from emetrics import get_aupr, get_cindex, get_rm2
55
56
57
58
TABSY = "\t"
59
figdir = "figures/"
60
61
def build_combined_onehot(FLAGS, NUM_FILTERS, FILTER_LENGTH1, FILTER_LENGTH2):
    """Build the two-branch CNN over one-hot encoded SMILES and protein inputs.

    Each branch stacks three Conv1D layers (widths NUM_FILTERS, 2x, 3x) and a
    global max pool; the pooled representations are concatenated and fed to a
    Dense(1024)->Dropout->Dense(1024)->Dropout->Dense(512) head ending in a
    single linear regression output.  The model is compiled with Adam + MSE and
    reports the concordance index (cindex_score) as a metric.
    """
    smiles_in = Input(shape=(FLAGS.max_smi_len, FLAGS.charsmiset_size))
    protein_in = Input(shape=(FLAGS.max_seq_len, FLAGS.charseqset_size))

    # SMILES branch: three stacked convolutions with growing filter counts.
    smiles_repr = smiles_in
    for depth in (1, 2, 3):
        smiles_repr = Conv1D(filters=NUM_FILTERS * depth, kernel_size=FILTER_LENGTH1,
                             activation='relu', padding='valid', strides=1)(smiles_repr)
    smiles_repr = GlobalMaxPooling1D()(smiles_repr)

    # Protein branch: same shape of stack, protein-specific kernel size.
    protein_repr = protein_in
    for depth in (1, 2, 3):
        protein_repr = Conv1D(filters=NUM_FILTERS * depth, kernel_size=FILTER_LENGTH2,
                              activation='relu', padding='valid', strides=1)(protein_repr)
    protein_repr = GlobalMaxPooling1D()(protein_repr)

    joint = keras.layers.concatenate([smiles_repr, protein_repr])

    # Fully connected regression head.
    hidden = Dense(1024, activation='relu')(joint)
    hidden = Dropout(0.1)(hidden)
    hidden = Dense(1024, activation='relu')(hidden)
    hidden = Dropout(0.1)(hidden)
    hidden = Dense(512, activation='relu')(hidden)

    predictions = Dense(1, kernel_initializer='normal')(hidden)

    interactionModel = Model(inputs=[smiles_in, protein_in], outputs=[predictions])
    interactionModel.compile(optimizer='adam', loss='mean_squared_error', metrics=[cindex_score])

    print(interactionModel.summary())
    plot_model(interactionModel, to_file='figures/build_combined_onehot.png')

    return interactionModel
100
101
102
103
104
105
def build_combined_categorical(FLAGS, NUM_FILTERS, FILTER_LENGTH1, FILTER_LENGTH2):
    """Build the two-branch CNN over integer-encoded SMILES and protein inputs.

    Identical topology to build_combined_onehot except that each branch starts
    with a 128-dimensional Embedding over the respective character vocabulary
    (size + 1 to leave room for the padding index 0).  Compiled with Adam + MSE
    and the concordance index as a metric.
    """
    smiles_in = Input(shape=(FLAGS.max_smi_len,), dtype='int32')
    protein_in = Input(shape=(FLAGS.max_seq_len,), dtype='int32')

    # SMILES branch: embed tokens, then three convolutions and a global pool.
    smiles_repr = Embedding(input_dim=FLAGS.charsmiset_size + 1, output_dim=128,
                            input_length=FLAGS.max_smi_len)(smiles_in)
    for depth in (1, 2, 3):
        smiles_repr = Conv1D(filters=NUM_FILTERS * depth, kernel_size=FILTER_LENGTH1,
                             activation='relu', padding='valid', strides=1)(smiles_repr)
    smiles_repr = GlobalMaxPooling1D()(smiles_repr)

    # Protein branch: same structure with the protein vocabulary/kernel size.
    protein_repr = Embedding(input_dim=FLAGS.charseqset_size + 1, output_dim=128,
                             input_length=FLAGS.max_seq_len)(protein_in)
    for depth in (1, 2, 3):
        protein_repr = Conv1D(filters=NUM_FILTERS * depth, kernel_size=FILTER_LENGTH2,
                              activation='relu', padding='valid', strides=1)(protein_repr)
    protein_repr = GlobalMaxPooling1D()(protein_repr)

    joint = keras.layers.concatenate([smiles_repr, protein_repr], axis=-1)

    # Fully connected regression head.
    hidden = Dense(1024, activation='relu')(joint)
    hidden = Dropout(0.1)(hidden)
    hidden = Dense(1024, activation='relu')(hidden)
    hidden = Dropout(0.1)(hidden)
    hidden = Dense(512, activation='relu')(hidden)

    # Single linear output: real-valued binding affinity.
    predictions = Dense(1, kernel_initializer='normal')(hidden)

    interactionModel = Model(inputs=[smiles_in, protein_in], outputs=[predictions])

    interactionModel.compile(optimizer='adam', loss='mean_squared_error', metrics=[cindex_score])
    print(interactionModel.summary())
    plot_model(interactionModel, to_file='figures/build_combined_categorical.png')

    return interactionModel
145
146
147
148
def build_single_drug(FLAGS, NUM_FILTERS, FILTER_LENGTH1, FILTER_LENGTH2):
    """Build the drug-only model: a SMILES CNN combined with raw target features.

    BUG FIX: the original Sequential implementation called ``Merge(...)``, a
    Keras 1 layer that was removed in Keras 2 and is not imported anywhere in
    this file, so calling this builder raised ``NameError``.  Rewritten with
    the functional API (the style already used by build_combined_*) while
    keeping the same architecture: Embedding -> 3x Conv1D -> GlobalMaxPooling1D
    for SMILES, an identity pass-through for the target feature vector
    (the original target branch was only a linear Activation), concatenation,
    then the shared FC head.

    Args:
        FLAGS: parsed arguments; reads max_smi_len, charsmiset_size, target_count.
        NUM_FILTERS: base filter count (doubled/tripled in deeper conv layers).
        FILTER_LENGTH1: SMILES convolution kernel size.
        FILTER_LENGTH2: unused here; kept so all builders share one signature.

    Returns:
        Compiled Keras Model over inputs [smiles_tokens, target_features].
    """
    XDinput = Input(shape=(FLAGS.max_smi_len,), dtype='int32')
    XTinput = Input(shape=(FLAGS.target_count,))

    # SMILES branch (same stack as the combined models).
    encode_smiles = Embedding(input_dim=FLAGS.charsmiset_size+1, output_dim=128, input_length=FLAGS.max_smi_len)(XDinput)
    encode_smiles = Conv1D(filters=NUM_FILTERS, kernel_size=FILTER_LENGTH1, activation='relu', padding='valid', strides=1)(encode_smiles)
    encode_smiles = Conv1D(filters=NUM_FILTERS*2, kernel_size=FILTER_LENGTH1, activation='relu', padding='valid', strides=1)(encode_smiles)
    encode_smiles = Conv1D(filters=NUM_FILTERS*3, kernel_size=FILTER_LENGTH1, activation='relu', padding='valid', strides=1)(encode_smiles)
    encode_smiles = GlobalMaxPooling1D()(encode_smiles)

    # Target features are used as-is (original branch was an identity Activation).
    encode_interaction = keras.layers.concatenate([encode_smiles, XTinput], axis=1)

    # Fully connected regression head, identical to the other builders.
    FC = Dense(1024, activation='relu')(encode_interaction)
    FC = Dropout(0.1)(FC)
    FC = Dense(1024, activation='relu')(FC)
    FC = Dropout(0.1)(FC)
    FC = Dense(512, activation='relu')(FC)

    predictions = Dense(1, kernel_initializer='normal')(FC)

    interactionModel = Model(inputs=[XDinput, XTinput], outputs=[predictions])
    interactionModel.compile(optimizer='adam', loss='mean_squared_error', metrics=[cindex_score])

    print(interactionModel.summary())
    plot_model(interactionModel, to_file='figures/build_single_drug.png')

    return interactionModel
181
182
183
def build_single_prot(FLAGS, NUM_FILTERS, FILTER_LENGTH1, FILTER_LENGTH2):
    """Build the protein-only model: a sequence CNN combined with raw drug features.

    BUG FIXES relative to the original Sequential implementation:
      * It called ``Merge(...)``, a Keras 1 layer removed in Keras 2 and never
        imported in this file, so the builder raised ``NameError``.
      * It read ``FLAGS.drugcount`` while experiment() sets ``FLAGS.drug_count``,
        which would raise ``AttributeError``; the correct attribute is used now.
    Rewritten with the functional API while preserving the architecture:
    identity pass-through for the drug feature vector, Embedding -> 3x Conv1D
    -> GlobalMaxPooling1D for the protein sequence, concatenation, shared FC head.

    Args:
        FLAGS: parsed arguments; reads drug_count, max_seq_len, charseqset_size.
        NUM_FILTERS: base filter count (doubled/tripled in deeper conv layers).
        FILTER_LENGTH1: unused here; kept so all builders share one signature.
        FILTER_LENGTH2: protein convolution kernel size.

    Returns:
        Compiled Keras Model over inputs [drug_features, protein_tokens].
    """
    XDinput = Input(shape=(FLAGS.drug_count,))
    XTinput = Input(shape=(FLAGS.max_seq_len,), dtype='int32')

    # Protein branch (same stack as the combined models).
    encode_protein = Embedding(input_dim=FLAGS.charseqset_size+1, output_dim=128, input_length=FLAGS.max_seq_len)(XTinput)
    encode_protein = Conv1D(filters=NUM_FILTERS, kernel_size=FILTER_LENGTH2, activation='relu', padding='valid', strides=1)(encode_protein)
    encode_protein = Conv1D(filters=NUM_FILTERS*2, kernel_size=FILTER_LENGTH2, activation='relu', padding='valid', strides=1)(encode_protein)
    encode_protein = Conv1D(filters=NUM_FILTERS*3, kernel_size=FILTER_LENGTH2, activation='relu', padding='valid', strides=1)(encode_protein)
    encode_protein = GlobalMaxPooling1D()(encode_protein)

    # Drug features are used as-is (original branch was an identity Activation).
    encode_interaction = keras.layers.concatenate([XDinput, encode_protein], axis=1)

    # Fully connected regression head, identical to the other builders.
    FC = Dense(1024, activation='relu')(encode_interaction)
    FC = Dropout(0.1)(FC)
    FC = Dense(1024, activation='relu')(FC)
    FC = Dropout(0.1)(FC)
    FC = Dense(512, activation='relu')(FC)

    predictions = Dense(1, kernel_initializer='normal')(FC)

    interactionModel = Model(inputs=[XDinput, XTinput], outputs=[predictions])
    interactionModel.compile(optimizer='adam', loss='mean_squared_error', metrics=[cindex_score])

    print(interactionModel.summary())
    plot_model(interactionModel, to_file='figures/build_single_protein.png')

    return interactionModel
214
215
def build_baseline(FLAGS, NUM_FILTERS, FILTER_LENGTH1, FILTER_LENGTH2):
    """Build the simplest baseline: linear projections of precomputed drug and
    target feature vectors, concatenated and fed through the shared FC head.

    BUG FIX: the original Sequential implementation called ``Merge(...)``, a
    Keras 1 layer removed in Keras 2 and never imported in this file, so the
    builder raised ``NameError``.  Rewritten with the functional API while
    keeping the architecture: Dense(1, linear) per branch, concatenation,
    Dense(1024)->Dropout->Dense(1024)->Dropout->Dense(512)->Dense(1).

    Args:
        FLAGS: parsed arguments; reads drug_count and target_count.
        NUM_FILTERS, FILTER_LENGTH1, FILTER_LENGTH2: unused; kept so all
            builders share one signature.

    Returns:
        Compiled Keras Model over inputs [drug_features, target_features].
    """
    XDinput = Input(shape=(FLAGS.drug_count,))
    XTinput = Input(shape=(FLAGS.target_count,))

    # One linear unit per branch, as in the original Sequential submodels.
    encode_drug = Dense(1, activation='linear')(XDinput)
    encode_target = Dense(1, activation='linear')(XTinput)

    encode_interaction = keras.layers.concatenate([encode_drug, encode_target], axis=1)

    # Fully connected regression head, identical to the other builders.
    FC = Dense(1024, activation='relu')(encode_interaction)
    FC = Dropout(0.1)(FC)
    FC = Dense(1024, activation='relu')(FC)
    FC = Dropout(0.1)(FC)
    FC = Dense(512, activation='relu')(FC)

    predictions = Dense(1, kernel_initializer='normal')(FC)

    interactionModel = Model(inputs=[XDinput, XTinput], outputs=[predictions])
    interactionModel.compile(optimizer='adam', loss='mean_squared_error', metrics=[cindex_score])

    print(interactionModel.summary())
    plot_model(interactionModel, to_file='figures/build_baseline.png')

    return interactionModel
241
242
def nfold_1_2_3_setting_sample(XD, XT,  Y, label_row_inds, label_col_inds, measure, runmethod,  FLAGS, dataset):
    """Nested cross-validation driver for problem settings 1-3.

    Reads pre-split fold index sets from `dataset`, runs a hyper-parameter grid
    search on train/validation folds (first general_nfold_cv call), then
    re-evaluates every parameter combination on the held-out test set (second
    call) and reports the test performance of the combination that was best on
    validation.

    Args:
        XD, XT: drug and target feature arrays.
        Y: affinity matrix (NaN for unknown pairs).
        label_row_inds, label_col_inds: flat [drug, target] indices of the
            known (non-NaN) affinities.
        measure: performance function (higher is better), e.g. get_cindex.
        runmethod: model-builder function, e.g. build_combined_categorical.
        FLAGS: parsed arguments (also used by logging).
        dataset: DataSet instance providing read_sets().

    Returns:
        (avg test performance, avg test MSE, std of per-fold test performance).
    """
    bestparamlist = []  # NOTE(review): never used.
    test_set, outer_train_sets = dataset.read_sets(FLAGS) 
    
    foldinds = len(outer_train_sets)

    test_sets = []
    ## TRAIN AND VAL
    val_sets = []
    train_sets = []

    # Rotate each outer fold in as the validation set; the remaining folds are
    # flattened into the corresponding training set.  The test set is the same
    # held-out split for every rotation.
    #logger.info('Start training')
    for val_foldind in range(foldinds):
        val_fold = outer_train_sets[val_foldind]
        val_sets.append(val_fold)
        otherfolds = deepcopy(outer_train_sets)
        otherfolds.pop(val_foldind)
        otherfoldsinds = [item for sublist in otherfolds for item in sublist]
        train_sets.append(otherfoldsinds)
        test_sets.append(test_set)
        print("val set", str(len(val_fold)))
        print("train set", str(len(otherfoldsinds)))


    # Pass 1: pick the best hyper-parameter combination on the validation folds.
    bestparamind, best_param_list, bestperf, all_predictions_not_need, losses_not_need = general_nfold_cv(XD, XT,  Y, label_row_inds, label_col_inds, 
                                                                                                measure, runmethod, FLAGS, train_sets, val_sets)
   
    #print("Test Set len", str(len(test_set)))
    #print("Outer Train Set len", str(len(outer_train_sets)))
    # Pass 2: retrain and evaluate every combination against the test split;
    # only the results at `bestparamind` (chosen in pass 1) are reported.
    bestparam, best_param_list, bestperf, all_predictions, all_losses = general_nfold_cv(XD, XT,  Y, label_row_inds, label_col_inds, 
                                                                                                measure, runmethod, FLAGS, train_sets, test_sets)
    
    testperf = all_predictions[bestparamind]##pointer pos 

    logging("---FINAL RESULTS-----", FLAGS)
    logging("best param index = %s,  best param = %.5f" % 
            (bestparamind, bestparam), FLAGS)


    testperfs = []
    testloss= []

    avgperf = 0.

    # Aggregate the best combination's per-fold test CI and MSE.
    for test_foldind in range(len(test_sets)):
        foldperf = all_predictions[bestparamind][test_foldind]
        foldloss = all_losses[bestparamind][test_foldind]
        testperfs.append(foldperf)
        testloss.append(foldloss)
        avgperf += foldperf

    avgperf = avgperf / len(test_sets)
    avgloss = np.mean(testloss)
    teststd = np.std(testperfs)

    logging("Test Performance CI", FLAGS)
    logging(testperfs, FLAGS)
    logging("Test Performance MSE", FLAGS)
    logging(testloss, FLAGS)

    return avgperf, avgloss, teststd
305
306
307
308
309
def general_nfold_cv(XD, XT,  Y, label_row_inds, label_col_inds, prfmeasure, runmethod, FLAGS, labeled_sets, val_sets): ## FLAGS is needed here too????
    """Grid search over (num_windows x smi_window_lengths x seq_window_lengths).

    For every fold, trains one model per parameter combination with early
    stopping and records its evaluation-set concordance index and MSE.  The
    combination with the best fold-averaged CI is returned.

    Args:
        XD, XT, Y: feature arrays and affinity matrix (NaN = unknown pair).
        label_row_inds, label_col_inds: flat indices of the known affinities.
        prfmeasure: performance function, higher is better (e.g. get_cindex).
        runmethod: model-builder, called as runmethod(FLAGS, p1, p2, p3).
        FLAGS: parsed arguments supplying the parameter grids, epochs, batch size.
        labeled_sets: per-fold training index lists.
        val_sets: per-fold evaluation index lists (validation or test).

    Returns:
        (best combination index, [p1, p2, p3] indices, best avg performance,
         per-combination-per-fold performance table, matching loss table).
    """
    paramset1 = FLAGS.num_windows                              #[32]#[32,  512] #[32, 128]  # filter numbers
    paramset2 = FLAGS.smi_window_lengths                               #[4, 8]#[4,  32] #[4,  8] #filter length smi
    paramset3 = FLAGS.seq_window_lengths                               #[8, 12]#[64,  256] #[64, 192]#[8, 192, 384]
    epoch = FLAGS.num_epoch                                 #100
    batchsz = FLAGS.batch_size                             #256

    logging("---Parameter Search-----", FLAGS)

    # Result tables: one row per parameter combination, one column per fold.
    w = len(val_sets)
    h = len(paramset1) * len(paramset2) * len(paramset3)

    all_predictions = [[0 for x in range(w)] for y in range(h)] 
    all_losses = [[0 for x in range(w)] for y in range(h)] 
    print(all_predictions)

    for foldind in range(len(val_sets)):
        valinds = val_sets[foldind]
        labeledinds = labeled_sets[foldind]

        Y_train = np.mat(np.copy(Y))  # NOTE(review): unused below.

        params = {}  # NOTE(review): unused below.
        XD_train = XD
        XT_train = XT
        # Map fold indices to (drug row, target column) coordinates.
        trrows = label_row_inds[labeledinds]
        trcols = label_col_inds[labeledinds]

        # NOTE(review): XD_train/XT_train are overwritten here and never read.
        XD_train = XD[trrows]
        XT_train = XT[trcols]

        train_drugs, train_prots,  train_Y = prepare_interaction_pairs(XD, XT, Y, trrows, trcols)
        
        terows = label_row_inds[valinds]
        tecols = label_col_inds[valinds]
        #print("terows", str(terows), str(len(terows)))
        #print("tecols", str(tecols), str(len(tecols)))

        val_drugs, val_prots,  val_Y = prepare_interaction_pairs(XD, XT,  Y, terows, tecols)

        # `pointer` enumerates parameter combinations in the same nested order
        # as the selection loop below, so rows of the tables line up.
        pointer = 0
       
        for param1ind in range(len(paramset1)): #filter counts
            param1value = paramset1[param1ind]
            for param2ind in range(len(paramset2)): #SMILES kernel sizes
                param2value = paramset2[param2ind]

                for param3ind in range(len(paramset3)): #protein kernel sizes
                    param3value = paramset3[param3ind]

                    gridmodel = runmethod(FLAGS, param1value, param2value, param3value)
                    # Stop training when validation loss stalls for 15 epochs.
                    es = EarlyStopping(monitor='val_loss', mode='min', verbose=1, patience=15)
                    gridres = gridmodel.fit(([np.array(train_drugs),np.array(train_prots) ]), np.array(train_Y), batch_size=batchsz, epochs=epoch, 
                            validation_data=( ([np.array(val_drugs), np.array(val_prots) ]), np.array(val_Y)),  shuffle=False, callbacks=[es] ) 


                    predicted_labels = gridmodel.predict([np.array(val_drugs), np.array(val_prots) ])
                    # loss = MSE, rperf2 = Keras-side cindex_score metric.
                    loss, rperf2 = gridmodel.evaluate(([np.array(val_drugs),np.array(val_prots) ]), np.array(val_Y), verbose=0)
                    rperf = prfmeasure(val_Y, predicted_labels)
                    rperf = rperf[0]


                    logging("P1 = %d,  P2 = %d, P3 = %d, Fold = %d, CI-i = %f, CI-ii = %f, MSE = %f" % 
                    (param1ind, param2ind, param3ind, foldind, rperf, rperf2, loss), FLAGS)

                    plotLoss(gridres, param1ind, param2ind, param3ind, foldind)

                    all_predictions[pointer][foldind] =rperf #TODO FOR EACH VAL SET allpredictions[pointer][foldind]
                    all_losses[pointer][foldind]= loss

                    pointer +=1

    bestperf = -float('Inf')
    bestpointer = None


    best_param_list = []
    ##Take average according to folds, then choose best params
    pointer = 0
    for param1ind in range(len(paramset1)):
            for param2ind in range(len(paramset2)):
                for param3ind in range(len(paramset3)):
                
                    avgperf = 0.
                    for foldind in range(len(val_sets)):
                        foldperf = all_predictions[pointer][foldind]
                        avgperf += foldperf
                    avgperf /= len(val_sets)
                    #print(epoch, batchsz, avgperf)
                    if avgperf > bestperf:
                        bestperf = avgperf
                        bestpointer = pointer
                        best_param_list = [param1ind, param2ind, param3ind]

                    pointer +=1
        
    return  bestpointer, best_param_list, bestperf, all_predictions, all_losses
408
409
410
411
def cindex_score(y_true, y_pred):
    """Differentiable-ish concordance index computed with TF ops.

    Counts, over all ordered pairs where the true affinity strictly increases
    (lower-triangular mask `f`), how often the predictions agree — full credit
    when the predicted difference is positive, half credit for exact ties —
    and divides by the number of such pairs.

    BUG FIX: the original used ``g == 0.0`` on a TF1 tensor; tensors did not
    overload ``==`` elementwise in TF1, so that expression is plain Python
    object comparison and always evaluates to the scalar ``False``, silently
    dropping the 0.5 tie credit.  ``tf.equal`` performs the intended
    elementwise comparison.

    Returns:
        Scalar tensor: the concordance index, or 0 when there are no
        concordant contributions (avoids 0/0).
    """
    # Pairwise prediction differences: g[i, j] = y_pred[i] - y_pred[j].
    g = tf.subtract(tf.expand_dims(y_pred, -1), y_pred)
    # 1.0 where prediction order is correct, 0.5 for ties, 0.0 otherwise.
    g = tf.cast(tf.equal(g, 0.0), tf.float32) * 0.5 + tf.cast(g > 0.0, tf.float32)

    # Mask of pairs whose true affinities are strictly ordered; the band-part
    # keeps only the lower triangle so each unordered pair is counted once.
    f = tf.subtract(tf.expand_dims(y_true, -1), y_true) > 0.0
    f = tf.matrix_band_part(tf.cast(f, tf.float32), -1, 0)

    g = tf.reduce_sum(tf.multiply(g, f))
    f = tf.reduce_sum(f)

    # Guard the 0/0 case when no comparable pairs exist.
    return tf.where(tf.equal(g, 0), 0.0, g / f)
423
424
425
   
426
def plotLoss(history, batchind, epochind, param3ind, foldind):
    """Save the loss and concordance-index training curves for one grid run.

    Writes two PNGs under figures/: `<tag>.png` with train/val loss and
    `<tag>_acc.png` with train/val cindex_score, where the tag encodes the
    three parameter indices, the fold index and a timestamp so repeated runs
    never overwrite each other.
    """
    tag = "b"+str(batchind) + "_e" + str(epochind) + "_" + str(param3ind) + "_"  + str( foldind) + "_" + str(time.time()) 

    # Shared savefig options (kept identical for both figures).
    save_opts = dict(dpi=None, facecolor='w', edgecolor='w', orientation='portrait',
                     papertype=None, format=None, transparent=False,
                     bbox_inches=None, pad_inches=0.1, frameon=None)

    # Loss curves.
    plt.figure()
    plt.plot(history.history['loss'])
    plt.plot(history.history['val_loss'])
    plt.title('model loss')
    plt.ylabel('loss')
    plt.xlabel('epoch')
    #plt.legend(['trainloss', 'valloss', 'cindex', 'valcindex'], loc='upper left')
    plt.legend(['trainloss', 'valloss'], loc='upper left')
    plt.savefig("figures/" + tag + ".png", **save_opts)
    plt.close()

    # Concordance-index curves.
    plt.figure()
    plt.title('model concordance index')
    plt.ylabel('cindex')
    plt.xlabel('epoch')
    plt.plot(history.history['cindex_score'])
    plt.plot(history.history['val_cindex_score'])
    plt.legend(['traincindex', 'valcindex'], loc='upper left')
    plt.savefig("figures/" + tag + "_acc.png", **save_opts)
    plt.close()
440
441
442
    ## PLOT CINDEX
443
    plt.figure()
444
    plt.title('model concordance index')
445
    plt.ylabel('cindex')
446
    plt.xlabel('epoch')
447
    plt.plot(history.history['cindex_score'])
448
    plt.plot(history.history['val_cindex_score'])
449
    plt.legend(['traincindex', 'valcindex'], loc='upper left')
450
    plt.savefig("figures/"+figname + "_acc.png" , dpi=None, facecolor='w', edgecolor='w', orientation='portrait', 
451
                            papertype=None, format=None,transparent=False, bbox_inches=None, pad_inches=0.1,frameon=None)
452
    plt.close()
453
454
455
456
def prepare_interaction_pairs(XD, XT,  Y, rows, cols):
    """Assemble aligned (drug, target, affinity) training triples.

    For each pair index i, picks drug XD[rows[i]], target XT[cols[i]] and the
    affinity Y[rows[i], cols[i]].

    Improvements over the original: removed the dead local ``targetscls`` and
    replaced the append loops with comprehensions; behavior is unchanged.

    Args:
        XD: array of drug representations, indexed by row.
        XT: array of target representations, indexed by column.
        Y: affinity matrix.
        rows, cols: equal-length index sequences selecting known pairs.

    Returns:
        (stacked drug array, stacked target array, list of affinities) — the
        affinities stay a plain Python list, matching the original contract.
    """
    drugs = [XD[row] for row in rows]
    targets = [XT[col] for col in cols]
    affinity = [Y[row, col] for row, col in zip(rows, cols)]

    drug_data = np.stack(drugs)
    target_data = np.stack(targets)

    return drug_data, target_data,  affinity
475
476
477
       
478
def experiment(FLAGS, perfmeasure, deepmethod, foldcount=6): #5-fold cross validation + test
    """Load the dataset, prepare index bookkeeping and run nested CV.

    Side effects: mutates FLAGS (charseqset_size, charsmiset_size, drug_count,
    target_count), creates the figures directory, and writes results via
    logging().  `foldcount` is accepted for interface compatibility but is not
    read here — fold structure comes from dataset.read_sets().
    """
    #Input
    #XD: [drugs, features] sized array (features may also be similarities with other drugs
    #XT: [targets, features] sized array (features may also be similarities with other targets
    #Y: interaction values, can be real values or binary (+1, -1), insert value float("nan") for unknown entries
    #perfmeasure: function that takes as input a list of correct and predicted outputs, and returns performance
    #higher values should be better, so if using error measures use instead e.g. the inverse -error(Y, P)
    #foldcount: number of cross-validation folds for settings 1-3, setting 4 always runs 3x3 cross-validation


    dataset = DataSet( fpath = FLAGS.dataset_path, ### update this in the args module too
                      setting_no = FLAGS.problem_type, ## add this to args
                      seqlen = FLAGS.max_seq_len,
                      smilen = FLAGS.max_smi_len,
                      need_shuffle = False )
    # set character set size
    FLAGS.charseqset_size = dataset.charseqset_size 
    FLAGS.charsmiset_size = dataset.charsmiset_size 

    XD, XT, Y = dataset.parse_data(FLAGS)

    XD = np.asarray(XD)
    XT = np.asarray(XT)
    Y = np.asarray(Y)

    drugcount = XD.shape[0]
    print(drugcount)
    targetcount = XT.shape[0]
    print(targetcount)

    # Expose the dataset dimensions to the model builders through FLAGS.
    FLAGS.drug_count = drugcount
    FLAGS.target_count = targetcount

    label_row_inds, label_col_inds = np.where(np.isnan(Y)==False)  #basically finds the point address of affinity [x,y]

    # Ensure the figure output directory exists before plotLoss/plot_model run.
    if not os.path.exists(figdir):
        os.makedirs(figdir)

    print(FLAGS.log_dir)
    S1_avgperf, S1_avgloss, S1_teststd = nfold_1_2_3_setting_sample(XD, XT, Y, label_row_inds, label_col_inds,
                                                                     perfmeasure, deepmethod, FLAGS, dataset)

    logging("Setting " + str(FLAGS.problem_type), FLAGS)
    logging("avg_perf = %.5f,  avg_mse = %.5f, std = %.5f" % 
            (S1_avgperf, S1_avgloss, S1_teststd), FLAGS)
524
525
526
527
528
def run_regression( FLAGS ):
    """Run the affinity-regression experiment.

    Uses the concordance index as the performance measure and the two-branch
    categorical (embedding-based) CNN as the model builder, exactly as the
    original wiring did.
    """
    experiment(FLAGS, get_cindex, build_combined_categorical)
534
535
536
537
538
if __name__=="__main__":
    # Parse CLI arguments, then give this run its own timestamped log
    # directory so concurrent/repeated runs never clobber each other's logs.
    FLAGS = argparser()
    FLAGS.log_dir = FLAGS.log_dir + str(time.time()) + "/"

    if not os.path.exists(FLAGS.log_dir):
        os.makedirs(FLAGS.log_dir)

    # Record the full configuration before training starts.
    logging(str(FLAGS), FLAGS)
    run_regression( FLAGS )
    run_regression( FLAGS )