Diff of /src/compnet.py [000000] .. [e918fa]

Switch to unified view

a b/src/compnet.py
1
import argparse
2
import numpy as np
3
import os
4
import sys
5
import warnings
6
with warnings.catch_warnings():
7
    warnings.filterwarnings("ignore", category=FutureWarning)
8
    import tensorflow as tf
9
    
10
import keras
11
from keras.models import Model
12
from keras.layers import Input,merge, concatenate, Conv2D, MaxPooling2D, Activation, UpSampling2D,Dropout,Conv2DTranspose,add,multiply
13
from keras.layers.normalization import BatchNormalization as bn
14
from keras.optimizers import RMSprop, Adam
15
from keras import regularizers, losses, backend as K
16
from keras.callbacks import EarlyStopping, ModelCheckpoint, CSVLogger, ModelCheckpoint, TensorBoard
17
os.environ['CUDA_VISIBLE_DEVICES']="0"
18
19
smooth = 1.
20
def dice_coef(y_true, y_pred):
21
22
    y_true_f = K.flatten(y_true)
23
    y_pred_f = K.flatten(y_pred)
24
    intersection = K.sum(y_true_f * y_pred_f)
25
    return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)
26
27
def dice_coef_test(y_true, y_pred):
28
29
    y_true_f = np.array(y_true).flatten()
30
    y_pred_f =np.array(y_pred).flatten()
31
    intersection = np.sum(y_true_f * y_pred_f)
32
    return (2. * intersection + smooth) / (np.sum(y_true_f) + np.sum(y_pred_f) + smooth)
33
34
35
def dice_coef_loss(y_true, y_pred):
36
    return -dice_coef(y_true, y_pred)
37
38
def neg_dice_coef_loss(y_true, y_pred):
39
    return dice_coef(y_true, y_pred)
40
41
42
#define the model
43
def Comp_U_Net(input_shape,learn_rate=1e-3):
44
45
    l2_lambda = 0.0002
46
    DropP = 0.3
47
    kernel_size=3
48
49
    inputs = Input(input_shape,name='ip0')
50
    
51
52
    conv0a = Conv2D(32, (kernel_size, kernel_size), activation='relu', padding='same', 
53
                   kernel_regularizer=regularizers.l2(l2_lambda) )(inputs)
54
    
55
    
56
    conv0a = bn()(conv0a)
57
    
58
    conv0b = Conv2D(32, (kernel_size, kernel_size), activation='relu', padding='same',
59
                   kernel_regularizer=regularizers.l2(l2_lambda) )(conv0a)
60
61
    conv0b = bn()(conv0b)
62
63
    
64
    pool0 = MaxPooling2D(pool_size=(2, 2))(conv0b)
65
66
    pool0 = Dropout(DropP)(pool0)
67
68
69
    conv1a = Conv2D(32, (kernel_size, kernel_size), activation='relu', padding='same', 
70
                   kernel_regularizer=regularizers.l2(l2_lambda) )(pool0)
71
    
72
    
73
    conv1a = bn()(conv1a)
74
    
75
    conv1b = Conv2D(32, (kernel_size, kernel_size), activation='relu', padding='same',
76
                   kernel_regularizer=regularizers.l2(l2_lambda) )(conv1a)
77
78
    conv1b = bn()(conv1b)
79
80
81
    
82
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1b)
83
84
    pool1 = Dropout(DropP)(pool1)
85
86
87
88
    
89
90
    conv2a = Conv2D(64, (kernel_size, kernel_size), activation='relu', padding='same',
91
                   kernel_regularizer=regularizers.l2(l2_lambda) )(pool1)
92
    
93
    conv2a = bn()(conv2a)
94
95
    conv2b = Conv2D(64, (kernel_size, kernel_size), activation='relu', padding='same',
96
                   kernel_regularizer=regularizers.l2(l2_lambda) )(conv2a)
97
98
    conv2b = bn()(conv2b)
99
100
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2b)
101
102
    pool2 = Dropout(DropP)(pool2)
103
104
105
106
107
108
109
110
    conv3a = Conv2D(128, (kernel_size, kernel_size), activation='relu', padding='same',
111
                   kernel_regularizer=regularizers.l2(l2_lambda) )(pool2)
112
    
113
    conv3a = bn()(conv3a)
114
115
    conv3b = Conv2D(128, (kernel_size, kernel_size), activation='relu', padding='same',
116
                   kernel_regularizer=regularizers.l2(l2_lambda) )(conv3a)
117
118
    conv3b = bn()(conv3b)
119
120
121
122
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3b)
123
124
    pool3 = Dropout(DropP)(pool3)
125
126
    
127
    conv4a = Conv2D(256, (kernel_size, kernel_size), activation='relu', padding='same',
128
                   kernel_regularizer=regularizers.l2(l2_lambda) )(pool3)
129
    
130
    conv4a = bn()(conv4a)
131
132
    conv4b = Conv2D(256, (kernel_size, kernel_size), activation='relu', padding='same',
133
                   kernel_regularizer=regularizers.l2(l2_lambda) )(conv4a)
134
135
    conv4b = bn()(conv4b)
136
137
    pool4 = MaxPooling2D(pool_size=(2, 2))(conv4b)
138
139
    pool4 = Dropout(DropP)(pool4)
140
141
142
143
144
145
    conv5a = Conv2D(512, (kernel_size, kernel_size), activation='relu', padding='same',
146
                   kernel_regularizer=regularizers.l2(l2_lambda) )(pool4)
147
    
148
    conv5a = bn()(conv5a)
149
150
    conv5b = Conv2D(512, (kernel_size, kernel_size), activation='relu', padding='same',
151
                   kernel_regularizer=regularizers.l2(l2_lambda) )(conv5a)
152
153
    conv5b = bn()(conv5b)
154
155
    
156
157
158
159
    up6 = concatenate([Conv2DTranspose(256,(2, 2), strides=(2, 2), padding='same')(conv5b), (conv4b)],name='up6', axis=3)
160
161
162
    up6 = Dropout(DropP)(up6)
163
164
    conv6a = Conv2D(256, (kernel_size, kernel_size), activation='relu', padding='same',
165
                   kernel_regularizer=regularizers.l2(l2_lambda) )(up6)
166
    
167
    conv6a = bn()(conv6a)
168
169
    conv6b = Conv2D(256, (kernel_size, kernel_size), activation='relu', padding='same',
170
                   kernel_regularizer=regularizers.l2(l2_lambda) )(conv6a)
171
172
    conv6b = bn()(conv6b)
173
174
175
176
177
178
    up7 = concatenate([Conv2DTranspose(128,(2, 2), strides=(2, 2), padding='same')(conv6b),(conv3b)],name='up7', axis=3)
179
180
    up7 = Dropout(DropP)(up7)
181
    #add second output here
182
183
    conv7a = Conv2D(128, (kernel_size, kernel_size), activation='relu', padding='same',
184
                   kernel_regularizer=regularizers.l2(l2_lambda) )(up7)
185
    
186
    conv7a = bn()(conv7a)
187
188
 
189
190
    conv7b = Conv2D(128, (kernel_size, kernel_size), activation='relu', padding='same',
191
                   kernel_regularizer=regularizers.l2(l2_lambda) )(conv7a)
192
193
    conv7b = bn()(conv7b)
194
195
   
196
197
198
199
200
    up8 = concatenate([Conv2DTranspose(64,(2, 2), strides=(2, 2), padding='same')(conv7b), (conv2b)],name='up8', axis=3)
201
202
    up8 = Dropout(DropP)(up8)
203
 
204
    conv8a = Conv2D(64, (kernel_size, kernel_size), activation='relu', padding='same',
205
                   kernel_regularizer=regularizers.l2(l2_lambda) )(up8)
206
    
207
    conv8a = bn()(conv8a)
208
209
    
210
    conv8b = Conv2D(64, (kernel_size, kernel_size), activation='relu', padding='same',
211
                   kernel_regularizer=regularizers.l2(l2_lambda) )(conv8a)
212
213
    conv8b = bn()(conv8b)
214
215
216
    
217
    up9 = concatenate([Conv2DTranspose(32,(2, 2), strides=(2, 2), padding='same')(conv8b),(conv1b)],name='up9',axis=3)
218
219
220
    conv9a = Conv2D(32, (kernel_size, kernel_size), activation='relu', padding='same',
221
                   kernel_regularizer=regularizers.l2(l2_lambda) )(up9)
222
    
223
    conv9a = bn()(conv9a)
224
225
    conv9b = Conv2D(12, (kernel_size, kernel_size), activation='relu', padding='same',
226
                   kernel_regularizer=regularizers.l2(l2_lambda) )(conv9a)
227
228
    conv9b = bn()(conv9b)
229
230
231
232
233
    up10 = concatenate([Conv2DTranspose(32,(2, 2), strides=(2, 2), padding='same')(conv9b),(conv0b)],name='up10',axis=3)
234
235
    conv10a = Conv2D(32, (kernel_size, kernel_size), activation='relu', padding='same',
236
                   kernel_regularizer=regularizers.l2(l2_lambda) )(up10)
237
    
238
    conv10a = bn()(conv10a)
239
240
   
241
242
    conv10b = Conv2D(32, (kernel_size, kernel_size), activation='relu', padding='same',
243
                   kernel_regularizer=regularizers.l2(l2_lambda) )(conv10a)
244
245
    conv10b = bn()(conv10b)
246
247
248
    
249
    final_op=Conv2D(1, (1, 1), activation='sigmoid',name='final_op')(conv10b)
250
    
251
252
253
    #----------------------------------------------------------------------------------------------------------------------------------
254
255
    #second branch - brain
256
    xup6 = concatenate([Conv2DTranspose(256,(2, 2), strides=(2, 2), padding='same')(conv5b), (conv4b)],name='xup6', axis=3)
257
258
    
259
260
    xup6 = Dropout(DropP)(xup6)
261
262
    xconv6a = Conv2D(256, (kernel_size, kernel_size), activation='relu', padding='same',
263
                   kernel_regularizer=regularizers.l2(l2_lambda) )(xup6)
264
    
265
    xconv6a = bn()(xconv6a)
266
267
    
268
269
    xconv6b = Conv2D(256, (kernel_size, kernel_size), activation='relu', padding='same',
270
                   kernel_regularizer=regularizers.l2(l2_lambda) )(xconv6a)
271
272
    xconv6b = bn()(xconv6b)
273
274
275
276
277
278
    xup7 = concatenate([Conv2DTranspose(128,(2, 2), strides=(2, 2), padding='same')(xconv6b),(conv3b)],name='xup7', axis=3)
279
280
    xup7 = Dropout(DropP)(xup7)
281
    
282
    xconv7a = Conv2D(128, (kernel_size, kernel_size), activation='relu', padding='same',
283
                   kernel_regularizer=regularizers.l2(l2_lambda) )(xup7)
284
    
285
    xconv7a = bn()(xconv7a)
286
287
288
    xconv7b = Conv2D(128, (kernel_size, kernel_size), activation='relu', padding='same',
289
                   kernel_regularizer=regularizers.l2(l2_lambda) )(xconv7a)
290
291
    xconv7b = bn()(xconv7b)
292
293
294
    xup8 = concatenate([Conv2DTranspose(64,(2, 2), strides=(2, 2), padding='same')(xconv7b),(conv2b)],name='xup8', axis=3)
295
296
    xup8 = Dropout(DropP)(xup8)
297
    #add third xoutxout here
298
    
299
    xconv8a = Conv2D(64, (kernel_size, kernel_size), activation='relu', padding='same',
300
                   kernel_regularizer=regularizers.l2(l2_lambda) )(xup8)
301
    
302
    xconv8a = bn()(xconv8a)
303
304
305
    xconv8b = Conv2D(64, (kernel_size, kernel_size), activation='relu', padding='same',
306
                   kernel_regularizer=regularizers.l2(l2_lambda) )(xconv8a)
307
308
    xconv8b = bn()(xconv8b)
309
310
311
312
    
313
    xup9 = concatenate([Conv2DTranspose(32,(2, 2), strides=(2, 2), padding='same')(xconv8b), (conv1b)],name='xup9',axis=3)
314
315
    xup9 = Dropout(DropP)(xup9)
316
    
317
318
    xconv9a = Conv2D(32, (kernel_size, kernel_size), activation='relu', padding='same',
319
                   kernel_regularizer=regularizers.l2(l2_lambda) )(xup9)
320
    
321
    xconv9a = bn()(xconv9a)
322
323
    
324
    xconv9b = Conv2D(32, (kernel_size, kernel_size), activation='relu', padding='same',
325
                   kernel_regularizer=regularizers.l2(l2_lambda) )(xconv9a)
326
327
    xconv9b = bn()(xconv9b)
328
329
 
330
    
331
    xup10 = concatenate([Conv2DTranspose(32,(2, 2), strides=(2, 2), padding='same')(xconv9b), (conv0b)],name='xup10',axis=3)
332
333
    xup10 = Dropout(DropP)(xup10)
334
    
335
336
    xconv10a = Conv2D(32, (kernel_size, kernel_size), activation='relu', padding='same',
337
                   kernel_regularizer=regularizers.l2(l2_lambda) )(xup10)
338
    
339
    xconv10a = bn()(xconv10a)
340
341
342
    xconv10b = Conv2D(32, (kernel_size, kernel_size), activation='relu', padding='same',
343
                   kernel_regularizer=regularizers.l2(l2_lambda) )(xconv10a)
344
345
    xconv10b = bn()(xconv10b)
346
347
    
348
349
350
   
351
    
352
    xfinal_op=Conv2D(1, (1, 1), activation='sigmoid',name='xfinal_op')(xconv10b)
353
354
355
    #-----------------------------third branch
356
357
358
359
    #Concatenation fed to the reconstruction layer of all 3
360
   
361
    x_u_net_op0=keras.layers.concatenate([final_op,xfinal_op,keras.layers.add([final_op,xfinal_op])],name='res_a')
362
    
363
364
    
365
366
367
368
369
370
371
    res_1_conv0a = Conv2D( 32, (kernel_size, kernel_size), activation='relu', padding='same', 
372
                   kernel_regularizer=regularizers.l2(l2_lambda) )(x_u_net_op0)
373
    
374
    
375
    res_1_conv0a = bn()(res_1_conv0a)
376
    
377
    res_1_conv0b = Conv2D(32, (kernel_size, kernel_size), activation='relu', padding='same',
378
                   kernel_regularizer=regularizers.l2(l2_lambda) )(res_1_conv0a)
379
380
    res_1_conv0b = bn()(res_1_conv0b)
381
382
    res_1_pool0 = MaxPooling2D(pool_size=(2, 2))(res_1_conv0b)
383
384
    res_1_pool0 = Dropout(DropP)(res_1_pool0)
385
386
387
388
389
    res_1_conv1a = Conv2D( 32, (kernel_size, kernel_size), activation='relu', padding='same', 
390
                   kernel_regularizer=regularizers.l2(l2_lambda) )(res_1_pool0)
391
    
392
    
393
    res_1_conv1a = bn()(res_1_conv1a)
394
    
395
    res_1_conv1b = Conv2D(32, (kernel_size, kernel_size), activation='relu', padding='same',
396
                   kernel_regularizer=regularizers.l2(l2_lambda) )(res_1_conv1a)
397
398
    res_1_conv1b = bn()(res_1_conv1b)
399
400
    res_1_pool1 = MaxPooling2D(pool_size=(2, 2))(res_1_conv1b)
401
402
    res_1_pool1 = Dropout(DropP)(res_1_pool1)
403
404
405
406
    
407
408
    res_1_conv2a = Conv2D(64, (kernel_size, kernel_size), activation='relu', padding='same',
409
                   kernel_regularizer=regularizers.l2(l2_lambda) )(res_1_pool1)
410
    
411
    res_1_conv2a = bn()(res_1_conv2a)
412
413
    res_1_conv2b = Conv2D(64, (kernel_size, kernel_size), activation='relu', padding='same',
414
                   kernel_regularizer=regularizers.l2(l2_lambda) )(res_1_conv2a)
415
416
    res_1_conv2b = bn()(res_1_conv2b)
417
418
    
419
    res_1_pool2 = MaxPooling2D(pool_size=(2, 2))(res_1_conv2b)
420
421
    res_1_pool2 = Dropout(DropP)(res_1_pool2)
422
423
424
425
426
427
428
429
    res_1_conv3a = Conv2D(128, (kernel_size, kernel_size), activation='relu', padding='same',
430
                   kernel_regularizer=regularizers.l2(l2_lambda) )(res_1_pool2)
431
    
432
    res_1_conv3a = bn()(res_1_conv3a)
433
434
    res_1_conv3b = Conv2D(128, (kernel_size, kernel_size), activation='relu', padding='same',
435
                   kernel_regularizer=regularizers.l2(l2_lambda) )(res_1_conv3a)
436
437
    res_1_conv3b = bn()(res_1_conv3b)
438
439
    res_1_pool3 = MaxPooling2D(pool_size=(2, 2))(res_1_conv3b)
440
441
    res_1_pool3 = Dropout(DropP)(res_1_pool3)
442
443
    
444
    res_1_conv4a = Conv2D(256, (kernel_size, kernel_size), activation='relu', padding='same',
445
                   kernel_regularizer=regularizers.l2(l2_lambda) )(res_1_pool3)
446
    
447
    res_1_conv4a = bn()(res_1_conv4a)
448
449
    res_1_conv4b = Conv2D(256, (kernel_size, kernel_size), activation='relu', padding='same',
450
                   kernel_regularizer=regularizers.l2(l2_lambda) )(res_1_conv4a)
451
452
    res_1_conv4b = bn()(res_1_conv4b)
453
454
    
455
    res_1_pool4 = MaxPooling2D(pool_size=(2, 2))(res_1_conv4b)
456
457
    res_1_pool4 = Dropout(DropP)(res_1_pool4)
458
459
460
461
462
463
    res_1_conv5a = Conv2D(512, (kernel_size, kernel_size), activation='relu', padding='same',
464
                   kernel_regularizer=regularizers.l2(l2_lambda) )(res_1_pool4)
465
    
466
    res_1_conv5a = bn()(res_1_conv5a)
467
468
    res_1_conv5b = Conv2D(512, (kernel_size, kernel_size), activation='relu', padding='same',
469
                   kernel_regularizer=regularizers.l2(l2_lambda) )(res_1_conv5a)
470
471
    res_1_conv5b = bn()(res_1_conv5b)
472
473
474
475
476
    res_1_up6 = concatenate([Conv2DTranspose(256,(2, 2), strides=(2, 2), padding='same')(res_1_conv5b), (res_1_conv4b)],name='res_1_up6', axis=3)
477
478
479
    res_1_up6 = Dropout(DropP)(res_1_up6)
480
481
    res_1_conv6a = Conv2D(256, (kernel_size, kernel_size), activation='relu', padding='same',
482
                   kernel_regularizer=regularizers.l2(l2_lambda) )(res_1_up6)
483
    
484
    res_1_conv6a = bn()(res_1_conv6a)
485
486
487
    res_1_conv6b = Conv2D(256, (kernel_size, kernel_size), activation='relu', padding='same',
488
                   kernel_regularizer=regularizers.l2(l2_lambda) )(res_1_conv6a)
489
490
    res_1_conv6b = bn()(res_1_conv6b)
491
492
493
494
    res_1_up7 = concatenate([Conv2DTranspose(128,(2, 2), strides=(2, 2), padding='same')(res_1_conv6b),(res_1_conv3b)],name='res_1_up7', axis=3)
495
496
    res_1_up7 = Dropout(DropP)(res_1_up7)
497
    #add second res_1_output here
498
    res_1_conv7a = Conv2D(128, (kernel_size, kernel_size), activation='relu', padding='same',
499
                   kernel_regularizer=regularizers.l2(l2_lambda) )(res_1_up7)
500
    
501
    res_1_conv7a = bn()(res_1_conv7a)
502
503
    
504
    res_1_conv7b = Conv2D(128, (kernel_size, kernel_size), activation='relu', padding='same',
505
                   kernel_regularizer=regularizers.l2(l2_lambda) )(res_1_conv7a)
506
507
    res_1_conv7b = bn()(res_1_conv7b)
508
509
510
511
    res_1_up8 = concatenate([Conv2DTranspose(64,(2, 2), strides=(2, 2), padding='same')(res_1_conv7b),(res_1_conv2b)],name='res_1_up8', axis=3)
512
513
    res_1_up8 = Dropout(DropP)(res_1_up8)
514
    #add third outout here
515
    res_1_conv8a = Conv2D(64, (kernel_size, kernel_size), activation='relu', padding='same',
516
                   kernel_regularizer=regularizers.l2(l2_lambda) )(res_1_up8)
517
    
518
    res_1_conv8a = bn()(res_1_conv8a)
519
520
521
    res_1_conv8b = Conv2D(64, (kernel_size, kernel_size), activation='relu', padding='same',
522
                   kernel_regularizer=regularizers.l2(l2_lambda) )(res_1_conv8a)
523
524
    res_1_conv8b = bn()(res_1_conv8b)
525
526
527
    res_1_up9 = concatenate([Conv2DTranspose(32,(2, 2), strides=(2, 2), padding='same')(res_1_conv8b), (res_1_conv1b)],name='res_1_up9',axis=3)
528
529
    res_1_up9 = Dropout(DropP)(res_1_up9)
530
531
    res_1_conv9a = Conv2D(32, (kernel_size, kernel_size), activation='relu', padding='same',
532
                   kernel_regularizer=regularizers.l2(l2_lambda) )(res_1_up9)
533
    
534
    res_1_conv9a = bn()(res_1_conv9a)
535
536
537
    res_1_conv9b = Conv2D(32, (kernel_size, kernel_size), activation='relu', padding='same',
538
                   kernel_regularizer=regularizers.l2(l2_lambda) )(res_1_conv9a)
539
540
    res_1_conv9b = bn()(res_1_conv9b)
541
542
543
544
545
    res_1_up10 = concatenate([Conv2DTranspose(32,(2, 2), strides=(2, 2), padding='same')(res_1_conv9b),(res_1_conv0b)],name='res_1_up10',axis=3)
546
547
    res_1_up10 = Dropout(DropP)(res_1_up10)
548
    
549
550
    res_1_conv10a = Conv2D(32, (kernel_size, kernel_size), activation='relu', padding='same',
551
                   kernel_regularizer=regularizers.l2(l2_lambda) )(res_1_up10)
552
    
553
    res_1_conv10a = bn()(res_1_conv10a)
554
555
   
556
    res_1_conv10b = Conv2D(32, (kernel_size, kernel_size), activation='relu', padding='same',
557
                   kernel_regularizer=regularizers.l2(l2_lambda) )(res_1_conv10a)
558
559
    res_1_conv10b = bn()(res_1_conv10b)
560
561
562
    res_1_final_op=Conv2D(1, (1, 1), activation='sigmoid',name='res_1_final_op')(res_1_conv10b)
563
564
565
    model=Model(inputs=[inputs],outputs=[final_op,
566
                                         xfinal_op,
567
                                         res_1_final_op,
568
                                     ])
569
570
    model.compile(optimizer=keras.optimizers.Adam(lr=1e-5),loss={'final_op':dice_coef_loss,
571
                                                'xfinal_op':neg_dice_coef_loss,
572
                                                'res_1_final_op':'mse'})
573
574
    return model
575
576
#----------------------------------------------------Main--------------------------------------------------#
577
578
579
def train_model(data_params, train_params, common_params):
580
581
582
    training_data_folder = data_params['data_dir'].rstrip('/')
583
584
    train_x = training_data_folder + '/' + data_params['train_data_file']
585
    train_y = training_data_folder + '/' + data_params['train_label_file']
586
587
    model = Comp_U_Net(input_shape=(256,256,1), learn_rate=train_params['learning_rate'])
588
    # print(model.summary())
589
590
    x_train = np.load(train_x)
591
    y_train = np.load(train_y)
592
593
    x_train=x_train.reshape(x_train.shape+(1,))
594
    y_train=y_train.reshape(y_train.shape+(1,))
595
596
    # Log output
597
    print ("Training dwi volume shape: ", x_train.shape)
598
    print ("Training dwi mask volume shape: ", y_train.shape)
599
600
    view = train_params['principal_axis']
601
602
    os.makedirs(common_params['log_dir'], exist_ok= True)
603
    csv_logger = CSVLogger(common_params['log_dir'] + '/' + view + '.csv', append=True, separator=';')
604
605
    # checkpoint
606
    os.makedirs(common_params['save_model_dir'], exist_ok= True)
607
    filepath = common_params['save_model_dir'] + "/weights-" + view + "-improvement-{epoch:02d}.h5"
608
    checkpoint = ModelCheckpoint(filepath, monitor='val_acc', verbose=1, save_best_only=False, save_weights_only=True)
609
610
    # Trains the model for a given number of epochs (iterations on a dataset).
611
    history_callback = model.fit([x_train],
612
                                 [y_train,y_train,y_train],
613
                                 validation_split=train_params['validation_split'],
614
                                 batch_size=train_params['train_batch_size'],
615
                                 epochs=train_params['num_epochs'],
616
                                 shuffle=train_params['shuffle_data'],
617
                                 verbose=1,
618
                                 callbacks=[csv_logger, checkpoint])
619
620
    import h5py
621
    # serialize model to JSON
622
    model_json = model.to_json()
623
    with open(common_params['save_model_dir'] + "/CompNetBasicModel.json", "w") as json_file:
624
        json_file.write(model_json)
625
    # serialize weights to HDF5
626
    model.save_weights(common_params['save_model_dir'] + "/" + view + "-compnet_final_weight.h5")
627
    print("Saved model to disk location: ", common_params['save_model_dir'])