# 3D/resume.py -- resume training of a 3D U-Net from saved checkpoint weights
from __future__ import print_function

# model builder and Keras utilities
from model import unet_model_3d
from keras.utils import plot_model
from keras import callbacks
from keras.callbacks import ModelCheckpoint, CSVLogger, LearningRateScheduler, ReduceLROnPlateau, EarlyStopping

# import load data
from data_handling import load_train_data, load_validatation_data

# import configurations
import configs

# init configs
patch_size = configs.PATCH_SIZE
batch_size = configs.BATCH_SIZE
# Hyper-parameter / architecture configuration shared by resume().
config = dict()
config["pool_size"] = (2, 2, 2)  # pool size for the max pooling operations
config["image_shape"] = (256, 128, 256)  # This determines what shape the images will be cropped/resampled to.
config["input_shape"] = (patch_size, patch_size, patch_size, 1)  # switch to None to train on the whole image (64, 64, 64) (64, 64, 64)
config["n_labels"] = 4  # number of segmentation classes
config["all_modalities"] = ['t1']#]["t1", "t1Gd", "flair", "t2"]
config["training_modalities"] = config["all_modalities"]  # change this if you want to only use some of the modalities
config["nb_channels"] = len(config["training_modalities"])
config["deconvolution"] = False  # if False, will use upsampling instead of deconvolution
config["batch_size"] = batch_size
config["n_epochs"] = 500  # cutoff the training after this many epochs
config["patience"] = 10  # learning rate will be reduced after this many epochs if the validation loss is not improving
config["early_stop"] = 30  # training will be stopped after this many epochs without the validation loss improving
config["initial_learning_rate"] = 0.00005
config["depth"] = configs.DEPTH
config["learning_rate_drop"] = 0.5  # factor applied by ReduceLROnPlateau

# tag used to name the CSV log and the saved model files
image_type = '3d_patches'
# resume training
def resume():
    """Resume training the 3D U-Net from previously saved best weights.

    Loads the training and validation patch arrays, rebuilds the model
    from the module-level ``config``, restores ``outputs/best_weights.h5``,
    and continues fitting with CSV logging (append mode), best-only
    checkpointing, LR reduction on plateau, and early stopping. The final
    model is saved to ``outputs/<image_type>_model_last``.
    """
    print('-'*30)
    print('Loading and preprocessing train data...')
    print('-'*30)
    imgs_train, imgs_gtruth_train = load_train_data()

    print('-'*30)
    print('Loading and preprocessing validation data...')
    print('-'*30)
    imgs_val, imgs_gtruth_val = load_validatation_data()

    print('-'*30)
    print('Creating and compiling model...')
    print('-'*30)

    # create a model
    model = unet_model_3d(input_shape=config["input_shape"],
                          depth=config["depth"],
                          pool_size=config["pool_size"],
                          n_labels=config["n_labels"],
                          initial_learning_rate=config["initial_learning_rate"],
                          deconvolution=config["deconvolution"])

    model.summary()

    # Weights to resume from. (The original code reassigned this variable
    # four times with older experiment checkpoints; only the last value ever
    # took effect, so the dead assignments were removed.)
    checkpoint_filepath_best = 'outputs/' + 'best_weights.h5'
    model.load_weights(checkpoint_filepath_best)

    print('*'*50)
    print('Load model: ', checkpoint_filepath_best)
    print('*'*50)

    print('-'*30)
    print('Fitting model...')
    print('-'*30)

    #============================================================================
    print('training starting..')
    log_filename = 'outputs/' + image_type +'_model_train.csv' 

    # append=True so the CSV log continues across resumed runs
    csv_log = callbacks.CSVLogger(log_filename, separator=',', append=True)

    checkpoint_filepath = 'outputs/' + 'weights.h5'
    checkpoint = callbacks.ModelCheckpoint(checkpoint_filepath, monitor='val_loss', verbose=1, save_best_only=True, mode='min')

    callbacks_list = [csv_log, checkpoint]
    callbacks_list.append(ReduceLROnPlateau(factor=config["learning_rate_drop"], patience=config["patience"],
                                            verbose=True))
    callbacks_list.append(EarlyStopping(verbose=True, patience=config["early_stop"]))

    #============================================================================
    # NOTE(review): ``nb_epoch`` is the legacy Keras 1.x keyword (``epochs`` in
    # Keras 2+) -- kept as-is to match the installed Keras version; confirm
    # before upgrading Keras. The unused ``hist`` binding was dropped.
    model.fit(imgs_train, imgs_gtruth_train, batch_size=config["batch_size"],
              nb_epoch=config["n_epochs"], verbose=1,
              validation_data=(imgs_val, imgs_gtruth_val), shuffle=True,
              callbacks=callbacks_list)

    model_name = 'outputs/' + image_type + '_model_last'
    model.save(model_name)  # creates a HDF5 file 'my_model.h5'
# script entry point: resume training when run directly
if __name__ == '__main__':
    resume()