# train.py
import numpy as np
import random
import json
from glob import glob
from keras.models import model_from_json,load_model
from keras.preprocessing.image import ImageDataGenerator
from keras.callbacks import  ModelCheckpoint,Callback,LearningRateScheduler
import keras.backend as K
from model import Unet_model
from losses import *
#from keras.utils.visualize_util import plot


class SGDLearningRateTracker(Callback):
    """Keras callback that rescales the SGD hyperparameters every epoch.

    At the beginning of each epoch the optimizer's learning rate is divided
    by 10 and its decay is multiplied by 10, and both new values are written
    back to the optimizer and printed.
    """

    def on_epoch_begin(self, epoch, logs={}):
        opt = self.model.optimizer
        # Read current values from the backend, rescale, and write them back.
        new_lr = K.get_value(opt.lr) / 10
        new_decay = K.get_value(opt.decay) * 10
        K.set_value(opt.lr, new_lr)
        K.set_value(opt.decay, new_decay)
        print('LR changed to:', new_lr)
        print('Decay changed to:', new_decay)

class Training(object):
    """Compile (or resume) the residual U-net and drive its training.

    Wraps model construction/reloading, an augmentation generator that keeps
    image patches and masks in sync, epoch checkpointing, and JSON+HDF5
    (de)serialization of the model.
    """

    def __init__(self, batch_size, nb_epoch, load_model_resume_training=None):
        """
        INPUT int 'batch_size': mini-batch size used for training
        INPUT int 'nb_epoch': number of epochs to train for
        INPUT string 'load_model_resume_training': optional path to a saved
              .hdf5 model; when given, training resumes from it without
              recompiling the whole model
        """
        self.batch_size = batch_size
        self.nb_epoch = nb_epoch

        # Loading model from path to resume previous training without
        # recompiling the whole model; the custom loss/metrics must be
        # supplied so Keras can deserialize them.
        if load_model_resume_training is not None:
            self.model = load_model(
                load_model_resume_training,
                custom_objects={'gen_dice_loss': gen_dice_loss,
                                'dice_whole_metric': dice_whole_metric,
                                'dice_core_metric': dice_core_metric,
                                'dice_en_metric': dice_en_metric})
            print("pre-trained model loaded!")
        else:
            unet = Unet_model(img_shape=(128, 128, 4))
            self.model = unet.model
            print("U-net CNN compiled!")

    def fit_unet(self, X33_train, Y_train, X_patches_valid=None, Y_labels_valid=None):
        """
        Train the model on augmented patches, checkpointing every epoch.

        INPUT X33_train, Y_train: training patches and their masks
        INPUT X_patches_valid, Y_labels_valid: optional validation data
        """
        train_generator = self.img_msk_gen(X33_train, Y_train, 9999)
        # Save a checkpoint after every epoch, tagged with epoch and val loss.
        checkpointer = ModelCheckpoint(
            filepath='brain_segmentation/ResUnet.{epoch:02d}_{val_loss:.3f}.hdf5',
            verbose=1)
        self.model.fit_generator(
            train_generator,
            steps_per_epoch=len(X33_train) // self.batch_size,
            epochs=self.nb_epoch,
            validation_data=(X_patches_valid, Y_labels_valid),
            verbose=1,
            callbacks=[checkpointer, SGDLearningRateTracker()])
        #self.model.fit(X33_train,Y_train, epochs=self.nb_epoch,batch_size=self.batch_size,validation_data=(X_patches_valid,Y_labels_valid),verbose=1, callbacks = [checkpointer,SGDLearningRateTracker()])

    def img_msk_gen(self, X33_train, Y_train, seed):
        '''
        A custom generator that performs data augmentation on both patches
        and their corresponding targets (masks).  The same seed is passed to
        both generators so the random horizontal flips stay synchronized.
        '''
        datagen = ImageDataGenerator(horizontal_flip=True, data_format="channels_last")
        datagen_msk = ImageDataGenerator(horizontal_flip=True, data_format="channels_last")
        # BUGFIX: the generator batch size was hard-coded to 4; use the
        # configured batch size so the yielded batches agree with the
        # steps_per_epoch computed in fit_unet.
        image_generator = datagen.flow(X33_train, batch_size=self.batch_size, seed=seed)
        y_generator = datagen_msk.flow(Y_train, batch_size=self.batch_size, seed=seed)
        while True:
            # Builtin next() works on both Python 2 and Python 3 iterators,
            # unlike the Py2-only .next() method.
            yield (next(image_generator), next(y_generator))

    def save_model(self, model_name):
        '''
        INPUT string 'model_name': path where to save model and weights, without extension
        Saves current model as json and weights as h5df file
        '''
        model_tosave = '{}.json'.format(model_name)
        weights = '{}.hdf5'.format(model_name)
        json_string = self.model.to_json()
        self.model.save_weights(weights)
        with open(model_tosave, 'w') as f:
            json.dump(json_string, f)
        print('Model saved.')

    def load_model(self, model_name):
        '''
        Load a model
        INPUT  (1) string 'model_name': filepath to model and weights, not including extension
        OUTPUT: Model with loaded weights. can fit on model using loaded_model=True in fit_model method
        '''
        print('Loading model {}'.format(model_name))
        model_toload = '{}.json'.format(model_name)
        weights = '{}.hdf5'.format(model_name)
        with open(model_toload) as f:
            # BUGFIX: next(f) read only the first line of the file, which
            # breaks on any multi-line JSON dump; read the whole file instead.
            m = f.read()
        model_comp = model_from_json(json.loads(m))
        model_comp.load_weights(weights)
        print('Model loaded.')
        self.model = model_comp
        return model_comp

if __name__ == "__main__":
    # Set arguments.

    # Path to an already-trained model to resume training from; set to None
    # to compile a fresh U-net instead of resuming.
    model_to_load = "Models/ResUnet.04_0.646.hdf5"
    #save=None

    # Compile (or reload) the model.
    brain_seg = Training(batch_size=4, nb_epoch=3, load_model_resume_training=model_to_load)

    # BUGFIX: corrected "trainabale" typo in the printed message.
    print("number of trainable parameters:", brain_seg.model.count_params())
    #print(brain_seg.model.summary())
    #plot(brain_seg.model, to_file='model_architecture.png', show_shapes=True)

    # Load pre-extracted training and validation patches from disk.
    Y_labels = np.load("y_training.npy").astype(np.uint8)
    X_patches = np.load("x_training.npy").astype(np.float32)
    Y_labels_valid = np.load("y_valid.npy").astype(np.uint8)
    X_patches_valid = np.load("x_valid.npy").astype(np.float32)
    print("loading patches done\n")

    # Fit the model.
    brain_seg.fit_unet(X_patches, Y_labels, X_patches_valid, Y_labels_valid)

    #if save is not None:
    #    brain_seg.save_model('models/' + save)
