Diff of /fetal_net/training.py [000000] .. [ccb1dd]

import itertools
import math
import os
from functools import partial

import keras
from keras import backend as K
from keras.callbacks import ModelCheckpoint, CSVLogger, LearningRateScheduler, ReduceLROnPlateau, EarlyStopping, \
    LambdaCallback
from keras.models import load_model, Model

import fetal_net.model
from fetal_net.metrics import (dice_coefficient, dice_coefficient_loss, dice_coef, dice_coef_loss,
                               weighted_dice_coefficient_loss, weighted_dice_coefficient,
                               vod_coefficient, vod_coefficient_loss, focal_loss, dice_and_xent, double_dice_loss)

K.set_image_dim_ordering('th')
from multiprocessing import cpu_count


# learning rate schedule
def step_decay(epoch, initial_lrate, drop, epochs_drop):
    return initial_lrate * math.pow(drop, math.floor((1 + epoch) / float(epochs_drop)))
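# Illustrative schedule (assumed values, not taken from any config): with
# initial_lrate=1e-3, drop=0.5 and epochs_drop=10, step_decay yields
# 1e-3 for epochs 0-8, 5e-4 for epochs 9-18, 2.5e-4 for epochs 19-28, and so on.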


def get_callbacks(model_file, initial_learning_rate=0.0001, learning_rate_drop=0.5, learning_rate_epochs=None,
                  learning_rate_patience=50, logging_file="training.log", verbosity=1,
                  early_stopping_patience=None):
    callbacks = list()
    callbacks.append(
        ModelCheckpoint(model_file + '-epoch{epoch:02d}-loss{val_loss:.3f}-acc{val_binary_accuracy:.3f}.h5',
                        save_best_only=True, verbose=verbosity, monitor='val_loss'))
    callbacks.append(CSVLogger(logging_file, append=True))
    if learning_rate_epochs:
        callbacks.append(LearningRateScheduler(partial(step_decay, initial_lrate=initial_learning_rate,
                                                       drop=learning_rate_drop, epochs_drop=learning_rate_epochs)))
    else:
        callbacks.append(ReduceLROnPlateau(factor=learning_rate_drop, patience=learning_rate_patience,
                                           verbose=verbosity))
    if early_stopping_patience:
        callbacks.append(EarlyStopping(verbose=verbosity, patience=early_stopping_patience))
    return callbacks
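# Usage sketch (hypothetical paths and values): when learning_rate_epochs is None,
# the schedule falls back to ReduceLROnPlateau instead of the step decay, e.g.
#   callbacks = get_callbacks('out/model', initial_learning_rate=1e-4,
#                             learning_rate_drop=0.5, learning_rate_patience=50,
#                             logging_file='out/training.log',
#                             early_stopping_patience=100)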


def load_old_model(model_file, verbose=True, config=None) -> Model:
    print("Loading pre-trained model")
    custom_objects = {'dice_coefficient_loss': dice_coefficient_loss, 'dice_coefficient': dice_coefficient,
                      'dice_coef': dice_coef, 'dice_coef_loss': dice_coef_loss,
                      'weighted_dice_coefficient': weighted_dice_coefficient,
                      'weighted_dice_coefficient_loss': weighted_dice_coefficient_loss,
                      'vod_coefficient': vod_coefficient,
                      'vod_coefficient_loss': vod_coefficient_loss,
                      'focal_loss': focal_loss,
                      'focal_loss_fixed': focal_loss,
                      'dice_and_xent': dice_and_xent,
                      'double_dice_loss': double_dice_loss}
    try:
        from keras_contrib.layers import InstanceNormalization
        custom_objects["InstanceNormalization"] = InstanceNormalization
    except ImportError:
        pass
    try:
        if verbose:
            print('Loading model from {}...'.format(model_file))
        return load_model(model_file, custom_objects=custom_objects)
    except ValueError as error:
        print(error)
        if 'InstanceNormalization' in str(error):
            raise ValueError(str(error) + "\n\nPlease install keras-contrib to use InstanceNormalization:\n"
                                          "'pip install git+https://www.github.com/keras-team/keras-contrib.git'")
        else:
            if config is not None:
                print('Trying to build model manually...')
                loss_func = getattr(fetal_net.metrics, config['loss'])
                model_func = getattr(fetal_net.model, config['model_name'])
                model = model_func(input_shape=config["input_shape"],
                                   initial_learning_rate=config["initial_learning_rate"],
                                   **{'dropout_rate': config['dropout_rate'],
                                      'loss_function': loss_func,
                                      'mask_shape': None if config["weight_mask"] is None else config["input_shape"],
                                      # TODO: change to output shape
                                      'old_model_path': config['old_model']})
                model.load_weights(model_file)
                return model
            else:
                raise
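# Usage sketch (hypothetical path): a plain call only needs the checkpoint file,
#   model = load_old_model('out/model-epoch05-loss0.123-acc0.987.h5')
# The config fallback above is used only when load_model() fails for a reason other
# than a missing InstanceNormalization; it expects the keys read in that branch
# ('loss', 'model_name', 'input_shape', 'initial_learning_rate', 'dropout_rate',
# 'weight_mask', 'old_model').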


def train_model(model, model_file, training_generator, validation_generator, steps_per_epoch, validation_steps,
                initial_learning_rate=0.001, learning_rate_drop=0.5, learning_rate_epochs=None, n_epochs=500,
                learning_rate_patience=20, early_stopping_patience=None, output_folder='.'):
    """
    Train a Keras model.
    :param early_stopping_patience: If set, training will end early if the validation loss does not improve after the
    specified number of epochs.
    :param learning_rate_patience: If learning_rate_epochs is not set, the learning rate will decrease if the validation
    loss does not improve after the specified number of epochs. (default is 20)
    :param model: Keras model that will be trained.
    :param model_file: Where to save the Keras model.
    :param training_generator: Generator that iterates through the training data.
    :param validation_generator: Generator that iterates through the validation data.
    :param steps_per_epoch: Number of batches that the training generator will provide during a given epoch.
    :param validation_steps: Number of batches that the validation generator will provide during a given epoch.
    :param initial_learning_rate: Learning rate at the beginning of training.
    :param learning_rate_drop: Factor by which the learning rate is multiplied when it decays.
    :param learning_rate_epochs: Number of epochs after which the learning rate will drop.
    :param n_epochs: Total number of epochs to train the model.
    :param output_folder: Directory in which the training log is written.
    :return: None. Checkpoints and the training log are written by the callbacks.
    """
    model.fit_generator(generator=training_generator,
                        steps_per_epoch=steps_per_epoch,
                        epochs=n_epochs,
                        validation_data=validation_generator,
                        validation_steps=validation_steps,
                        max_queue_size=15,
                        workers=1,
                        use_multiprocessing=False,
                        callbacks=get_callbacks(model_file,
                                                initial_learning_rate=initial_learning_rate,
                                                learning_rate_drop=learning_rate_drop,
                                                learning_rate_epochs=learning_rate_epochs,
                                                learning_rate_patience=learning_rate_patience,
                                                early_stopping_patience=early_stopping_patience,
                                                logging_file=os.path.join(output_folder, 'training')))
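# Usage sketch (hypothetical generators and step counts): train_model() wires the
# generators and the callbacks from get_callbacks() together, e.g.
#   train_model(model, model_file='out/model',
#               training_generator=train_gen, validation_generator=val_gen,
#               steps_per_epoch=200, validation_steps=50,
#               initial_learning_rate=1e-3, n_epochs=500, output_folder='out')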