|
a |
|
b/3D/train.py |
|
|
1 |
|
|
|
2 |
from __future__ import print_function |
|
|
3 |
|
|
|
4 |
# import packages |
|
|
5 |
from model import unet_model_3d |
|
|
6 |
from keras.utils import plot_model |
|
|
7 |
from keras import callbacks |
|
|
8 |
from keras.callbacks import ModelCheckpoint, CSVLogger, LearningRateScheduler, ReduceLROnPlateau, EarlyStopping |
|
|
9 |
|
|
|
10 |
# import load data |
|
|
11 |
from data_handling import load_train_data, load_validatation_data |
|
|
12 |
|
|
|
13 |
# import configurations |
|
|
14 |
import configs |
|
|
15 |
|
|
|
16 |
# init configs |
|
|
17 |
patch_size = configs.PATCH_SIZE |
|
|
18 |
batch_size = configs.BATCH_SIZE |
|
|
19 |
|
|
|
20 |
config = dict() |
|
|
21 |
config["pool_size"] = (2, 2, 2) # pool size for the max pooling operations |
|
|
22 |
config["image_shape"] = (256, 128, 256) # This determines what shape the images will be cropped/resampled to. |
|
|
23 |
config["input_shape"] = (patch_size, patch_size, patch_size, 1) # switch to None to train on the whole image (64, 64, 64) (64, 64, 64) |
|
|
24 |
config["n_labels"] = 4 |
|
|
25 |
config["all_modalities"] = ['t1']#]["t1", "t1Gd", "flair", "t2"] |
|
|
26 |
config["training_modalities"] = config["all_modalities"] # change this if you want to only use some of the modalities |
|
|
27 |
config["nb_channels"] = len(config["training_modalities"]) |
|
|
28 |
config["deconvolution"] = False # if False, will use upsampling instead of deconvolution |
|
|
29 |
config["batch_size"] = batch_size |
|
|
30 |
config["n_epochs"] = 500 # cutoff the training after this many epochs |
|
|
31 |
config["patience"] = 10 # learning rate will be reduced after this many epochs if the validation loss is not improving |
|
|
32 |
config["early_stop"] = 31 # training will be stopped after this many epochs without the validation loss improving |
|
|
33 |
config["initial_learning_rate"] = 0.0001 |
|
|
34 |
config["depth"] = configs.DEPTH |
|
|
35 |
config["learning_rate_drop"] = 0.5 |
|
|
36 |
|
|
|
37 |
image_type = '3d_patches' |
|
|
38 |
|
|
|
39 |
# 3D U-net depth=5 |
|
|
40 |
def generate_model(num_classes=4) : |
|
|
41 |
init_input = Input((1, 32, 32, 32)) |
|
|
42 |
|
|
|
43 |
x = Conv3D(25, kernel_size=(3, 3, 3))(init_input) |
|
|
44 |
x = PReLU()(x) |
|
|
45 |
x = Conv3D(25, kernel_size=(3, 3, 3))(x) |
|
|
46 |
x = PReLU()(x) |
|
|
47 |
x = Conv3D(25, kernel_size=(3, 3, 3))(x) |
|
|
48 |
x = PReLU()(x) |
|
|
49 |
|
|
|
50 |
y = Conv3D(50, kernel_size=(3, 3, 3))(x) |
|
|
51 |
y = PReLU()(y) |
|
|
52 |
y = Conv3D(50, kernel_size=(3, 3, 3))(y) |
|
|
53 |
y = PReLU()(y) |
|
|
54 |
y = Conv3D(50, kernel_size=(3, 3, 3))(y) |
|
|
55 |
y = PReLU()(y) |
|
|
56 |
|
|
|
57 |
z = Conv3D(75, kernel_size=(3, 3, 3))(y) |
|
|
58 |
z = PReLU()(z) |
|
|
59 |
z = Conv3D(75, kernel_size=(3, 3, 3))(z) |
|
|
60 |
z = PReLU()(z) |
|
|
61 |
z = Conv3D(75, kernel_size=(3, 3, 3))(z) |
|
|
62 |
z = PReLU()(z) |
|
|
63 |
|
|
|
64 |
x_crop = Cropping3D(cropping=((6, 6), (6, 6), (6, 6)))(x) |
|
|
65 |
y_crop = Cropping3D(cropping=((3, 3), (3, 3), (3, 3)))(y) |
|
|
66 |
|
|
|
67 |
concat = concatenate([x_crop, y_crop, z], axis=1) |
|
|
68 |
|
|
|
69 |
fc = Conv3D(400, kernel_size=(1, 1, 1))(concat) |
|
|
70 |
fc = PReLU()(fc) |
|
|
71 |
fc = Conv3D(200, kernel_size=(1, 1, 1))(fc) |
|
|
72 |
fc = PReLU()(fc) |
|
|
73 |
fc = Conv3D(150, kernel_size=(1, 1, 1))(fc) |
|
|
74 |
fc = PReLU()(fc) |
|
|
75 |
|
|
|
76 |
pred = Conv3D(num_classes, kernel_size=(1, 1, 1))(fc) |
|
|
77 |
pred = PReLU()(pred) |
|
|
78 |
pred = Reshape((num_classes, 9 * 9 * 9))(pred) |
|
|
79 |
pred = Permute((2, 1))(pred) |
|
|
80 |
pred = Activation('softmax')(pred) |
|
|
81 |
|
|
|
82 |
model = Model(inputs=init_input, outputs=pred) |
|
|
83 |
model.compile( |
|
|
84 |
loss='categorical_crossentropy', |
|
|
85 |
optimizer='adam', |
|
|
86 |
metrics=['categorical_accuracy']) |
|
|
87 |
return model |
|
|
88 |
|
|
|
89 |
# train |
|
|
90 |
def train(): |
|
|
91 |
print('-'*30) |
|
|
92 |
print('Loading and preprocessing train data...') |
|
|
93 |
print('-'*30) |
|
|
94 |
imgs_train, imgs_gtruth_train = load_train_data() |
|
|
95 |
|
|
|
96 |
print('-'*30) |
|
|
97 |
print('Loading and preprocessing validation data...') |
|
|
98 |
print('-'*30) |
|
|
99 |
imgs_val, imgs_gtruth_val = load_validatation_data() |
|
|
100 |
|
|
|
101 |
print('-'*30) |
|
|
102 |
print('Creating and compiling model...') |
|
|
103 |
print('-'*30) |
|
|
104 |
|
|
|
105 |
# create a model |
|
|
106 |
model = unet_model_3d(input_shape=config["input_shape"], |
|
|
107 |
depth=config["depth"], |
|
|
108 |
pool_size=config["pool_size"], |
|
|
109 |
n_labels=config["n_labels"], |
|
|
110 |
initial_learning_rate=config["initial_learning_rate"], |
|
|
111 |
deconvolution=config["deconvolution"]) |
|
|
112 |
|
|
|
113 |
model.summary() |
|
|
114 |
|
|
|
115 |
print('-'*30) |
|
|
116 |
print('Fitting model...') |
|
|
117 |
print('-'*30) |
|
|
118 |
|
|
|
119 |
#============================================================================ |
|
|
120 |
print('training starting..') |
|
|
121 |
log_filename = 'outputs/' + image_type +'_model_train.csv' |
|
|
122 |
|
|
|
123 |
|
|
|
124 |
csv_log = callbacks.CSVLogger(log_filename, separator=',', append=True) |
|
|
125 |
|
|
|
126 |
# early_stopping = callbacks.EarlyStopping(monitor='val_loss', min_delta=0, patience=5, verbose=0, mode='min') |
|
|
127 |
|
|
|
128 |
#checkpoint_filepath = 'outputs/' + image_type +"_best_weight_model_{epoch:03d}_{val_loss:.4f}.hdf5" |
|
|
129 |
checkpoint_filepath = 'outputs/' + 'weights.h5' |
|
|
130 |
|
|
|
131 |
checkpoint = callbacks.ModelCheckpoint(checkpoint_filepath, monitor='val_loss', verbose=1, save_best_only=True, mode='min') |
|
|
132 |
|
|
|
133 |
callbacks_list = [csv_log, checkpoint] |
|
|
134 |
callbacks_list.append(ReduceLROnPlateau(factor=config["learning_rate_drop"], patience=config["patience"], |
|
|
135 |
verbose=True)) |
|
|
136 |
callbacks_list.append(EarlyStopping(verbose=True, patience=config["early_stop"])) |
|
|
137 |
|
|
|
138 |
#============================================================================ |
|
|
139 |
hist = model.fit(imgs_train, imgs_gtruth_train, batch_size=config["batch_size"], nb_epoch=config["n_epochs"], verbose=1, validation_data=(imgs_val,imgs_gtruth_val), shuffle=True, callbacks=callbacks_list) # validation_split=0.2, |
|
|
140 |
|
|
|
141 |
|
|
|
142 |
model_name = 'outputs/' + image_type + '_model_last' |
|
|
143 |
model.save(model_name) # creates a HDF5 file 'my_model.h5' |
|
|
144 |
|
|
|
145 |
# main |
|
|
146 |
if __name__ == '__main__': |
|
|
147 |
train() |