fetal_net/training.py

import itertools
import math
import os
from functools import partial
from multiprocessing import cpu_count

import keras
from keras import backend as K
from keras.callbacks import ModelCheckpoint, CSVLogger, LearningRateScheduler, ReduceLROnPlateau, EarlyStopping, \
    LambdaCallback
from keras.models import load_model, Model

import fetal_net.model
import fetal_net.metrics
from fetal_net.metrics import (dice_coefficient, dice_coefficient_loss, dice_coef, dice_coef_loss,
                               weighted_dice_coefficient_loss, weighted_dice_coefficient,
                               vod_coefficient, vod_coefficient_loss, focal_loss, dice_and_xent, double_dice_loss)

# use Theano-style (channels-first) image dimension ordering
K.set_image_dim_ordering('th')
|
|
# learning rate schedule
def step_decay(epoch, initial_lrate, drop, epochs_drop):
    return initial_lrate * math.pow(drop, math.floor((1 + epoch) / float(epochs_drop)))
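
# Example with illustrative numbers: initial_lrate=1e-3, drop=0.5 and
# epochs_drop=10 give 1e-3 for epochs 0-8, 5e-4 for epochs 9-18, 2.5e-4 for
# epochs 19-28, and so on; the rate steps down whenever
# floor((1 + epoch) / epochs_drop) increments.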
|
|
def get_callbacks(model_file, initial_learning_rate=0.0001, learning_rate_drop=0.5, learning_rate_epochs=None,
                  learning_rate_patience=50, logging_file="training.log", verbosity=1,
                  early_stopping_patience=None):
    callbacks = list()
    # checkpoint whenever the validation loss improves; the filename template
    # assumes `binary_accuracy` is among the model's compiled metrics
    callbacks.append(
        ModelCheckpoint(model_file + '-epoch{epoch:02d}-loss{val_loss:.3f}-acc{val_binary_accuracy:.3f}.h5',
                        save_best_only=True, verbose=verbosity, monitor='val_loss'))
    callbacks.append(CSVLogger(logging_file, append=True))
    if learning_rate_epochs:
        # fixed schedule: drop the learning rate every `learning_rate_epochs` epochs
        callbacks.append(LearningRateScheduler(partial(step_decay, initial_lrate=initial_learning_rate,
                                                       drop=learning_rate_drop, epochs_drop=learning_rate_epochs)))
    else:
        # adaptive schedule: drop the learning rate when the validation loss plateaus
        callbacks.append(ReduceLROnPlateau(factor=learning_rate_drop, patience=learning_rate_patience,
                                           verbose=verbosity))
    if early_stopping_patience:
        callbacks.append(EarlyStopping(verbose=verbosity, patience=early_stopping_patience))
    return callbacks
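
# Usage sketch (values are illustrative): halve the learning rate every 20
# epochs and stop after 50 epochs without improvement:
#   callbacks = get_callbacks('out/model', initial_learning_rate=1e-4,
#                             learning_rate_drop=0.5, learning_rate_epochs=20,
#                             early_stopping_patience=50)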
|
|
def load_old_model(model_file, verbose=True, config=None) -> Model:
    print("Loading pre-trained model")
    custom_objects = {'dice_coefficient_loss': dice_coefficient_loss, 'dice_coefficient': dice_coefficient,
                      'dice_coef': dice_coef, 'dice_coef_loss': dice_coef_loss,
                      'weighted_dice_coefficient': weighted_dice_coefficient,
                      'weighted_dice_coefficient_loss': weighted_dice_coefficient_loss,
                      'vod_coefficient': vod_coefficient,
                      'vod_coefficient_loss': vod_coefficient_loss,
                      'focal_loss': focal_loss,
                      'focal_loss_fixed': focal_loss,
                      'dice_and_xent': dice_and_xent,
                      'double_dice_loss': double_dice_loss}
    try:
        from keras_contrib.layers import InstanceNormalization
        custom_objects["InstanceNormalization"] = InstanceNormalization
    except ImportError:
        pass
    try:
        if verbose:
            print('Loading model from {}...'.format(model_file))
        return load_model(model_file, custom_objects=custom_objects)
    except ValueError as error:
        print(error)
        if 'InstanceNormalization' in str(error):
            raise ValueError(str(error) + "\n\nPlease install keras-contrib to use InstanceNormalization:\n"
                                          "'pip install git+https://www.github.com/keras-team/keras-contrib.git'")
        elif config is not None:
            # fall back to rebuilding the architecture from the config and
            # loading only the weights from the checkpoint
            print('Trying to build model manually...')
            loss_func = getattr(fetal_net.metrics, config['loss'])
            model_func = getattr(fetal_net.model, config['model_name'])
            model = model_func(input_shape=config["input_shape"],
                               initial_learning_rate=config["initial_learning_rate"],
                               **{'dropout_rate': config['dropout_rate'],
                                  'loss_function': loss_func,
                                  'mask_shape': None if config["weight_mask"] is None else config["input_shape"],
                                  # TODO: change to output shape
                                  'old_model_path': config['old_model']})
            model.load_weights(model_file)
            return model
        else:
            raise
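
# Usage sketch (path is illustrative): resume from a checkpoint written by the
# ModelCheckpoint callback above:
#   model = load_old_model('out/model-epoch42-loss0.123-acc0.987.h5')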
|
|
def train_model(model, model_file, training_generator, validation_generator, steps_per_epoch, validation_steps,
                initial_learning_rate=0.001, learning_rate_drop=0.5, learning_rate_epochs=None, n_epochs=500,
                learning_rate_patience=20, early_stopping_patience=None, output_folder='.'):
    """
    Train a Keras model.
    :param model: Keras model that will be trained.
    :param model_file: Where to save the Keras model.
    :param training_generator: Generator that iterates through the training data.
    :param validation_generator: Generator that iterates through the validation data.
    :param steps_per_epoch: Number of batches that the training generator will provide during a given epoch.
    :param validation_steps: Number of batches that the validation generator will provide during a given epoch.
    :param initial_learning_rate: Learning rate at the beginning of training.
    :param learning_rate_drop: Factor by which the learning rate will decay.
    :param learning_rate_epochs: Number of epochs after which the learning rate will drop.
    :param n_epochs: Total number of epochs to train the model.
    :param learning_rate_patience: If learning_rate_epochs is not set, the learning rate will decrease if the
    validation loss does not improve after the specified number of epochs. (default is 20)
    :param early_stopping_patience: If set, training will end early if the validation loss does not improve after the
    specified number of epochs.
    :param output_folder: Directory in which to write the CSV training log.
    :return: None
    """
    model.fit_generator(generator=training_generator,
                        steps_per_epoch=steps_per_epoch,
                        epochs=n_epochs,
                        validation_data=validation_generator,
                        validation_steps=validation_steps,
                        max_queue_size=15,
                        workers=1,
                        use_multiprocessing=False,
                        callbacks=get_callbacks(model_file,
                                                initial_learning_rate=initial_learning_rate,
                                                learning_rate_drop=learning_rate_drop,
                                                learning_rate_epochs=learning_rate_epochs,
                                                learning_rate_patience=learning_rate_patience,
                                                early_stopping_patience=early_stopping_patience,
                                                logging_file=os.path.join(output_folder, 'training.log')))
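
# Minimal end-to-end sketch. `build_model`, `train_gen` and `val_gen` are
# placeholders for whatever builds the compiled model and the batch
# generators; they are not defined in this module:
#   model = build_model(...)
#   train_model(model, os.path.join('out', 'model'),
#               training_generator=train_gen, validation_generator=val_gen,
#               steps_per_epoch=200, validation_steps=50,
#               initial_learning_rate=1e-3, n_epochs=50, output_folder='out')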