|
a |
|
b/train.py |
|
|
1 |
import numpy as np |
|
|
2 |
import random |
|
|
3 |
import json |
|
|
4 |
from glob import glob |
|
|
5 |
from keras.models import model_from_json,load_model |
|
|
6 |
from keras.preprocessing.image import ImageDataGenerator |
|
|
7 |
from keras.callbacks import ModelCheckpoint,Callback,LearningRateScheduler |
|
|
8 |
import keras.backend as K |
|
|
9 |
from model import Unet_model |
|
|
10 |
from losses import * |
|
|
11 |
#from keras.utils.visualize_util import plot |
|
|
12 |
|
|
|
13 |
|
|
|
14 |
|
|
|
15 |
class SGDLearningRateTracker(Callback):
    """Keras callback that rescales the optimizer's lr and decay each epoch.

    At the start of every epoch the optimizer's current learning rate is
    divided by 10 and its current decay is multiplied by 10. Both new values
    are written back to the optimizer and printed.
    """

    def on_epoch_begin(self, epoch, logs=None):
        """Divide the learning rate by 10 and multiply the decay by 10.

        INPUT (1) int 'epoch': index of the epoch about to start (unused)
              (2) dict 'logs': metrics passed by Keras (unused); defaults to
                  None rather than a shared mutable dict
        """
        optimizer = self.model.optimizer

        # NOTE(review): this fires on every epoch, including the first one, so
        # the change compounds (lr/10, lr/100, ...) - confirm this is intended.
        lr = K.get_value(optimizer.lr) / 10
        decay = K.get_value(optimizer.decay) * 10

        K.set_value(optimizer.lr, lr)
        K.set_value(optimizer.decay, decay)
        print('LR changed to:', lr)
        print('Decay changed to:', decay)
|
|
26 |
|
|
|
27 |
|
|
|
28 |
|
|
|
29 |
class Training(object):
    """Holds a Keras U-Net model together with helpers to train, save and reload it."""

    def __init__(self, batch_size, nb_epoch, load_model_resume_training=None):
        """Store the training settings and obtain the model.

        INPUT (1) int 'batch_size': number of samples per training batch
              (2) int 'nb_epoch': number of epochs to train for
              (3) string 'load_model_resume_training': path to a saved model
                  to resume from; when None a new Unet_model is built
        """
        self.batch_size = batch_size
        self.nb_epoch = nb_epoch

        if load_model_resume_training is not None:
            # Loading model from path to resume previous training without
            # recompiling the whole model. The custom loss and metrics must be
            # supplied by name so Keras can resolve them while deserializing.
            # `load_model` here is the function imported from keras.models,
            # not the method of the same name defined further down.
            self.model = load_model(
                load_model_resume_training,
                custom_objects={
                    'gen_dice_loss': gen_dice_loss,
                    'dice_whole_metric': dice_whole_metric,
                    'dice_core_metric': dice_core_metric,
                    'dice_en_metric': dice_en_metric,
                },
            )
            print("pre-trained model loaded!")
        else:
            unet = Unet_model(img_shape=(128, 128, 4))
            self.model = unet.model
            print("U-net CNN compiled!")

    def fit_unet(self, X33_train, Y_train, X_patches_valid=None, Y_labels_valid=None):
        """Train the model on augmented batches, checkpointing after each epoch.

        INPUT (1) array 'X33_train': training patches
              (2) array 'Y_train': training targets (masks) matching X33_train
              (3) array 'X_patches_valid': optional validation patches
              (4) array 'Y_labels_valid': optional validation targets
        Validation is only used when both validation arrays are given.
        """
        has_validation = X_patches_valid is not None and Y_labels_valid is not None
        if has_validation:
            validation_data = (X_patches_valid, Y_labels_valid)
            checkpoint_path = 'brain_segmentation/ResUnet.{epoch:02d}_{val_loss:.3f}.hdf5'
        else:
            # Without validation data there is no 'val_loss' entry to format
            # into the file name, and a (None, None) tuple is not valid input.
            validation_data = None
            checkpoint_path = 'brain_segmentation/ResUnet.{epoch:02d}.hdf5'

        train_generator = self.img_msk_gen(X33_train, Y_train, 9999)
        checkpointer = ModelCheckpoint(filepath=checkpoint_path, verbose=1)
        self.model.fit_generator(
            train_generator,
            steps_per_epoch=len(X33_train) // self.batch_size,
            epochs=self.nb_epoch,
            validation_data=validation_data,
            verbose=1,
            callbacks=[checkpointer, SGDLearningRateTracker()],
        )

    def img_msk_gen(self, X33_train, Y_train, seed):
        """Yield (patches, masks) batches with horizontal-flip augmentation.

        A custom generator that performs data augmentation on both patches
        and their corresponding targets (masks).

        INPUT (1) array 'X33_train': training patches
              (2) array 'Y_train': training targets (masks)
              (3) int 'seed': random seed shared by both generators
        OUTPUT: endless generator of (patch_batch, mask_batch) tuples
        """
        datagen = ImageDataGenerator(horizontal_flip=True, data_format="channels_last")
        datagen_msk = ImageDataGenerator(horizontal_flip=True, data_format="channels_last")

        # Both flows get the same seed, which is intended to keep the
        # shuffling and flipping of patches and masks in step. The batch size
        # follows self.batch_size so it agrees with the steps_per_epoch
        # computed in fit_unet.
        image_generator = datagen.flow(X33_train, batch_size=self.batch_size, seed=seed)
        y_generator = datagen_msk.flow(Y_train, batch_size=self.batch_size, seed=seed)
        while True:
            yield next(image_generator), next(y_generator)

    def save_model(self, model_name):
        """Save the current model as JSON and its weights as an HDF5 file.

        INPUT string 'model_name': path where to save model and weights,
              without extension ('.json' and '.hdf5' are appended)
        """
        model_tosave = '{}.json'.format(model_name)
        weights = '{}.hdf5'.format(model_name)
        json_string = self.model.to_json()
        self.model.save_weights(weights)
        # The architecture string is itself JSON-encoded on write, so
        # load_model has to decode it before handing it to model_from_json.
        with open(model_tosave, 'w') as f:
            json.dump(json_string, f)
        print ('Model saved.')

    def load_model(self, model_name):
        """Load a model architecture and weights written by save_model.

        INPUT (1) string 'model_name': filepath to model and weights, not
              including extension
        OUTPUT: Model with loaded weights; it is also stored in self.model.
        NOTE(review): a model rebuilt from JSON presumably still needs to be
        compiled before it can be trained further - confirm before fitting.
        """
        print ('Loading model {}'.format(model_name))
        model_toload = '{}.json'.format(model_name)
        weights = '{}.hdf5'.format(model_name)
        # Decode the whole file rather than only its first line.
        with open(model_toload) as f:
            model_comp = model_from_json(json.load(f))
        model_comp.load_weights(weights)
        print ('Model loaded.')
        self.model = model_comp
        return model_comp
|
|
97 |
|
|
|
98 |
|
|
|
99 |
|
|
|
100 |
if __name__ == "__main__":
    # Checkpoint of an already trained model from which training is resumed.
    model_to_load = "Models/ResUnet.04_0.646.hdf5"

    # Build the trainer around the restored model.
    brain_seg = Training(
        batch_size=4,
        nb_epoch=3,
        load_model_resume_training=model_to_load,
    )
    print("number of trainabale parameters:", brain_seg.model.count_params())

    # Load the training and validation arrays from disk: labels are cast to
    # uint8 and patches to float32.
    train_labels = np.load("y_training.npy").astype(np.uint8)
    train_patches = np.load("x_training.npy").astype(np.float32)
    valid_labels = np.load("y_valid.npy").astype(np.uint8)
    valid_patches = np.load("x_valid.npy").astype(np.float32)
    print("loading patches done\n")

    # Fit the model. Checkpoints are written by fit_unet itself; call
    # brain_seg.save_model(<path>) afterwards to also export JSON + weights.
    brain_seg.fit_unet(
        train_patches,
        train_labels,
        valid_patches,
        valid_labels,
    )
|
|
126 |
|
|
|
127 |
|
|
|
128 |
|
|
|
129 |
|