[ccb1dd]: fetal_net/training.py

import itertools
import math
import os
from functools import partial
from multiprocessing import cpu_count

import keras
from keras import backend as K
from keras.callbacks import (ModelCheckpoint, CSVLogger, LearningRateScheduler, ReduceLROnPlateau,
                             EarlyStopping, LambdaCallback)
from keras.models import load_model, Model

import fetal_net.model
from fetal_net.metrics import (dice_coefficient, dice_coefficient_loss, dice_coef, dice_coef_loss,
                               weighted_dice_coefficient_loss, weighted_dice_coefficient,
                               vod_coefficient, vod_coefficient_loss, focal_loss, dice_and_xent,
                               double_dice_loss)

# Use Theano-style ('th', channels-first) image dimension ordering throughout.
K.set_image_dim_ordering('th')


def step_decay(epoch, initial_lrate, drop, epochs_drop):
    """Step-decay learning rate schedule: multiply the initial rate by
    `drop` once every `epochs_drop` epochs."""
    return initial_lrate * math.pow(drop, math.floor((1 + epoch) / float(epochs_drop)))
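

# For example (hypothetical values): with initial_lrate=1e-4, drop=0.5 and
# epochs_drop=50, epochs 0-48 train at 1e-4, epochs 49-98 at 5e-5, and so on:
#   step_decay(0, 1e-4, 0.5, 50)   -> 1e-4
#   step_decay(49, 1e-4, 0.5, 50)  -> 5e-5
#   step_decay(99, 1e-4, 0.5, 50)  -> 2.5e-5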


def get_callbacks(model_file, initial_learning_rate=0.0001, learning_rate_drop=0.5, learning_rate_epochs=None,
                  learning_rate_patience=50, logging_file="training.log", verbosity=1,
                  early_stopping_patience=None):
    callbacks = list()
    # Checkpoint the best model (by validation loss), embedding epoch and metrics in the filename.
    callbacks.append(
        ModelCheckpoint(model_file + '-epoch{epoch:02d}-loss{val_loss:.3f}-acc{val_binary_accuracy:.3f}.h5',
                        save_best_only=True, verbose=verbosity, monitor='val_loss'))
    callbacks.append(CSVLogger(logging_file, append=True))
    # Either decay the learning rate on a fixed epoch schedule, or reduce it
    # when the validation loss plateaus.
    if learning_rate_epochs:
        callbacks.append(LearningRateScheduler(partial(step_decay, initial_lrate=initial_learning_rate,
                                                       drop=learning_rate_drop,
                                                       epochs_drop=learning_rate_epochs)))
    else:
        callbacks.append(ReduceLROnPlateau(factor=learning_rate_drop, patience=learning_rate_patience,
                                           verbose=verbosity))
    if early_stopping_patience:
        callbacks.append(EarlyStopping(verbose=verbosity, patience=early_stopping_patience))
    return callbacks
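

# Example (hypothetical usage): build the callback list for a run that logs to
# 'run1.log' and stops early after 50 stagnant epochs:
#   callbacks = get_callbacks('run1-model', initial_learning_rate=1e-4,
#                             learning_rate_patience=20, logging_file='run1.log',
#                             early_stopping_patience=50)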


def load_old_model(model_file, verbose=True, config=None) -> Model:
    print("Loading pre-trained model")
    # Custom losses/metrics must be registered so Keras can deserialize the saved model.
    custom_objects = {'dice_coefficient_loss': dice_coefficient_loss, 'dice_coefficient': dice_coefficient,
                      'dice_coef': dice_coef, 'dice_coef_loss': dice_coef_loss,
                      'weighted_dice_coefficient': weighted_dice_coefficient,
                      'weighted_dice_coefficient_loss': weighted_dice_coefficient_loss,
                      'vod_coefficient': vod_coefficient,
                      'vod_coefficient_loss': vod_coefficient_loss,
                      'focal_loss': focal_loss,
                      'focal_loss_fixed': focal_loss,
                      'dice_and_xent': dice_and_xent,
                      'double_dice_loss': double_dice_loss}
    try:
        from keras_contrib.layers import InstanceNormalization
        custom_objects["InstanceNormalization"] = InstanceNormalization
    except ImportError:
        pass
    try:
        if verbose:
            print('Loading model from {}...'.format(model_file))
        return load_model(model_file, custom_objects=custom_objects)
    except ValueError as error:
        print(error)
        if 'InstanceNormalization' in str(error):
            raise ValueError(str(error) + "\n\nPlease install keras-contrib to use InstanceNormalization:\n"
                                          "'pip install git+https://www.github.com/keras-team/keras-contrib.git'")
        if config is not None:
            # Fall back to rebuilding the architecture from the config and
            # loading only the weights from the checkpoint file.
            print('Trying to build model manually...')
            loss_func = getattr(fetal_net.metrics, config['loss'])
            model_func = getattr(fetal_net.model, config['model_name'])
            model = model_func(input_shape=config["input_shape"],
                               initial_learning_rate=config["initial_learning_rate"],
                               **{'dropout_rate': config['dropout_rate'],
                                  'loss_function': loss_func,
                                  'mask_shape': None if config["weight_mask"] is None else config["input_shape"],
                                  # TODO: change to output shape
                                  'old_model_path': config['old_model']})
            model.load_weights(model_file)
            return model
        raise
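

# Example (hypothetical filename, matching the ModelCheckpoint pattern above):
#   model = load_old_model('run1-model-epoch05-loss0.123-acc0.987.h5')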


def train_model(model, model_file, training_generator, validation_generator, steps_per_epoch, validation_steps,
                initial_learning_rate=0.001, learning_rate_drop=0.5, learning_rate_epochs=None, n_epochs=500,
                learning_rate_patience=20, early_stopping_patience=None, output_folder='.'):
    """
    Train a Keras model.

    :param model: Keras model that will be trained.
    :param model_file: Where to save the Keras model.
    :param training_generator: Generator that iterates through the training data.
    :param validation_generator: Generator that iterates through the validation data.
    :param steps_per_epoch: Number of batches that the training generator will provide during a given epoch.
    :param validation_steps: Number of batches that the validation generator will provide during a given epoch.
    :param initial_learning_rate: Learning rate at the beginning of training.
    :param learning_rate_drop: Factor by which the learning rate will decay.
    :param learning_rate_epochs: Number of epochs after which the learning rate will drop.
    :param n_epochs: Total number of epochs to train the model.
    :param learning_rate_patience: If learning_rate_epochs is not set, the learning rate will decrease when the
        validation loss does not improve for the specified number of epochs (default is 20).
    :param early_stopping_patience: If set, training will end early if the validation loss does not improve after
        the specified number of epochs.
    :param output_folder: Directory into which the CSV training log is written.
    """
    model.fit_generator(generator=training_generator,
                        steps_per_epoch=steps_per_epoch,
                        epochs=n_epochs,
                        validation_data=validation_generator,
                        validation_steps=validation_steps,
                        max_queue_size=15,
                        workers=1,
                        use_multiprocessing=False,
                        callbacks=get_callbacks(model_file,
                                                initial_learning_rate=initial_learning_rate,
                                                learning_rate_drop=learning_rate_drop,
                                                learning_rate_epochs=learning_rate_epochs,
                                                learning_rate_patience=learning_rate_patience,
                                                early_stopping_patience=early_stopping_patience,
                                                logging_file=os.path.join(output_folder, 'training')))
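

# Example (hypothetical; assumes `build_model`, `train_gen` and `val_gen` are
# defined elsewhere in the project):
#   model = build_model(...)
#   train_model(model, model_file=os.path.join('runs', 'fetal_net'),
#               training_generator=train_gen, validation_generator=val_gen,
#               steps_per_epoch=100, validation_steps=20,
#               initial_learning_rate=1e-3, n_epochs=50, output_folder='runs')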