a b/scripts/train.py
1
#!/usr/bin/env python
2
3
from __future__ import division, print_function
4
5
import os
6
import argparse
7
import logging
8
9
from keras import losses, optimizers, utils
10
from keras.optimizers import SGD, RMSprop, Adagrad, Adadelta, Adam, Adamax, Nadam
11
from keras.callbacks import ModelCheckpoint
12
from keras import backend as K
13
14
from rvseg import dataset, models, loss, opts
15
16
17
def select_optimizer(optimizer_name, optimizer_args):
    """Instantiate the Keras optimizer registered under `optimizer_name`.

    Args:
        optimizer_name: One of 'sgd', 'rmsprop', 'adagrad', 'adadelta',
            'adam', 'adamax' or 'nadam'.
        optimizer_args: Dict of keyword arguments forwarded unchanged to
            the optimizer constructor.

    Returns:
        The constructed optimizer instance.

    Raises:
        ValueError: If `optimizer_name` is not one of the names above.
    """
    # Deliberately not called `optimizers`: that name is the keras module
    # imported at file level and would be shadowed inside this function.
    known_optimizers = {
        'sgd': SGD,
        'rmsprop': RMSprop,
        'adagrad': Adagrad,
        'adadelta': Adadelta,
        'adam': Adam,
        'adamax': Adamax,
        'nadam': Nadam,
    }
    if optimizer_name not in known_optimizers:
        # The message previously formatted an undefined variable (`name`),
        # so this path died with NameError instead of reporting the problem.
        raise ValueError(
            "Unknown optimizer ({}).".format(optimizer_name))

    return known_optimizers[optimizer_name](**optimizer_args)
31
32
def train():
    """Train an rvseg model as configured by `opts.parse_arguments()`.

    Steps, in order: build the training/validation generators, construct
    the selected model (optionally loading saved weights), build the
    optimizer and loss, compile, run `fit_generator`, and save the final
    model to `args.outdir/args.outfile`. When `args.checkpoint` is set, a
    `ModelCheckpoint` callback also writes the best weights seen so far
    into `args.outdir`.

    Raises:
        KeyError: If `args.model` is not a known model name.
        Exception: If `args.loss` is not 'pixel', 'dice' or 'jaccard', or
            if `select_optimizer` rejects `args.optimizer`.
    """
    logging.basicConfig(level=logging.INFO)

    args = opts.parse_arguments()

    logging.info("Loading dataset...")
    augmentation_args = {
        'rotation_range': args.rotation_range,
        'width_shift_range': args.width_shift_range,
        'height_shift_range': args.height_shift_range,
        'shear_range': args.shear_range,
        'zoom_range': args.zoom_range,
        'fill_mode': args.fill_mode,
        'alpha': args.alpha,
        'sigma': args.sigma,
    }
    (train_generator, train_steps_per_epoch,
     val_generator, val_steps_per_epoch) = dataset.create_generators(
        args.datadir, args.batch_size,
        validation_split=args.validation_split,
        mask=args.classes,
        shuffle_train_val=args.shuffle_train_val,
        shuffle=args.shuffle,
        seed=args.seed,
        normalize_images=args.normalize,
        augment_training=args.augment_training,
        augment_validation=args.augment_validation,
        augmentation_args=augmentation_args)

    # Get image dimensions from the first batch. Note this pulls one batch
    # from the training generator before fitting starts.
    images, masks = next(train_generator)
    _, height, width, channels = images.shape
    _, _, _, classes = masks.shape

    logging.info("Building model...")
    string_to_model = {
        "unet": models.unet,
        "dilated-unet": models.dilated_unet,
        "dilated-densenet": models.dilated_densenet,
        "dilated-densenet2": models.dilated_densenet2,
        "dilated-densenet3": models.dilated_densenet3,
    }
    if args.model not in string_to_model:
        # Still a KeyError (what the plain lookup used to raise), but with
        # a message in line with the unknown-loss/optimizer errors.
        raise KeyError("Unknown model ({})".format(args.model))
    model = string_to_model[args.model]
    m = model(height=height, width=width, channels=channels, classes=classes,
              features=args.features, depth=args.depth, padding=args.padding,
              temperature=args.temperature, batchnorm=args.batchnorm,
              dropout=args.dropout)

    m.summary()

    if args.load_weights:
        logging.info("Loading saved weights from file: %s", args.load_weights)
        m.load_weights(args.load_weights)

    # Instantiate the optimizer, passing only the args that have been set
    # (not all optimizers accept args like `momentum' or `decay').
    candidate_args = (
        ('lr', args.learning_rate),
        ('momentum', args.momentum),
        ('decay', args.decay),
    )
    optimizer_args = {
        key: value for key, value in candidate_args if value is not None
    }
    optimizer = select_optimizer(args.optimizer, optimizer_args)

    # Select the loss function: pixel-wise crossentropy, soft dice or soft
    # jaccard coefficient. Each entry also names the validation metric the
    # checkpoint callback monitors for that loss, so the two choices
    # cannot drift apart.
    # NOTE(review): 'val_acc' is assumed to be the key the installed Keras
    # logs for the 'accuracy' metric -- confirm against the Keras version.
    loss_table = {
        'pixel': (loss.weighted_categorical_crossentropy, 'val_acc'),
        'dice': (loss.sorensen_dice_loss, 'val_dice'),
        'jaccard': (loss.jaccard_loss, 'val_jaccard'),
    }
    if args.loss not in loss_table:
        raise Exception("Unknown loss ({})".format(args.loss))
    weighted_loss, monitor = loss_table[args.loss]

    def lossfunc(y_true, y_pred):
        return weighted_loss(y_true, y_pred, args.loss_weights)

    # The metric function names `dice` and `jaccard` are expected to line
    # up with the 'val_dice' / 'val_jaccard' keys monitored above, so they
    # must not be renamed.
    def dice(y_true, y_pred):
        batch_dice_coefs = loss.sorensen_dice(y_true, y_pred, axis=[1, 2])
        dice_coefs = K.mean(batch_dice_coefs, axis=0)
        return dice_coefs[1]    # HACK for 2-class case

    def jaccard(y_true, y_pred):
        batch_jaccard_coefs = loss.jaccard(y_true, y_pred, axis=[1, 2])
        jaccard_coefs = K.mean(batch_jaccard_coefs, axis=0)
        return jaccard_coefs[1]  # HACK for 2-class case

    metrics = ['accuracy', dice, jaccard]

    m.compile(optimizer=optimizer, loss=lossfunc, metrics=metrics)

    # Automatic saving of the best model during training.
    callbacks = []
    if args.checkpoint:
        # Yields e.g. "weights-{epoch:02d}-{val_dice:.4f}.hdf5"; the braces
        # are left for ModelCheckpoint to fill in.
        filename = "weights-{epoch:02d}-{" + monitor + ":.4f}.hdf5"
        filepath = os.path.join(args.outdir, filename)
        checkpoint = ModelCheckpoint(
            filepath, monitor=monitor, verbose=1,
            save_best_only=True, mode='max')
        callbacks.append(checkpoint)

    # train
    logging.info("Begin training.")
    m.fit_generator(train_generator,
                    epochs=args.epochs,
                    steps_per_epoch=train_steps_per_epoch,
                    validation_data=val_generator,
                    validation_steps=val_steps_per_epoch,
                    callbacks=callbacks,
                    verbose=2)

    m.save(os.path.join(args.outdir, args.outfile))
162
163
# Script entry point: start training only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    train()