a b/flair-segmentation/train.py
1
from __future__ import print_function
2
3
import os
4
import sys
5
6
import numpy as np
7
import tensorflow as tf
8
from keras import backend as K
9
from keras.callbacks import TensorBoard
10
from keras.optimizers import Adam
11
12
from data import load_data
13
from data import oversample
14
from net import dice_coef
15
from net import dice_coef_loss
16
from net import unet
17
18
# Directories handed to data.load_data() for the training and validation sets.
train_images_path = "./data/train/"
valid_images_path = "./data/valid/"

# Optional starting checkpoint; train() loads it only when the file exists.
init_weights_path = "./weights_128.h5"

# Where train() writes the final weights file and the TensorBoard logs.
weights_path = "."
log_path = "."

# Default GPU index; the first command-line argument overrides it in __main__.
gpu = "0"

# Training hyper-parameters (epochs also appears in the saved weights filename).
epochs = 128
batch_size = 32
base_lr = 1e-5  # learning rate passed to the Adam optimizer
def train():
32
    imgs_train, imgs_mask_train, _ = load_data(train_images_path)
33
34
    mean = np.mean(imgs_train)
35
    std = np.std(imgs_train)
36
37
    imgs_train -= mean
38
    imgs_train /= std
39
40
    imgs_valid, imgs_mask_valid, _ = load_data(valid_images_path)
41
42
    imgs_valid -= mean
43
    imgs_valid /= std
44
45
    imgs_train, imgs_mask_train = oversample(imgs_train, imgs_mask_train)
46
47
    model = unet()
48
    if os.path.exists(init_weights_path):
49
        model.load_weights(init_weights_path)
50
51
    optimizer = Adam(lr=base_lr)
52
    model.compile(optimizer=optimizer, loss=dice_coef_loss, metrics=[dice_coef])
53
54
    if not os.path.exists(log_path):
55
        os.mkdir(log_path)
56
57
    training_log = TensorBoard(log_dir=log_path)
58
59
    model.fit(
60
        imgs_train,
61
        imgs_mask_train,
62
        validation_data=(imgs_valid, imgs_mask_valid),
63
        batch_size=batch_size,
64
        epochs=epochs,
65
        shuffle=True,
66
        callbacks=[training_log],
67
    )
68
69
    if not os.path.exists(weights_path):
70
        os.mkdir(weights_path)
71
    model.save_weights(os.path.join(weights_path, "weights_{}.h5".format(epochs)))
72
73
74
if __name__ == "__main__":
    # TF1 session setup: grow GPU memory on demand instead of reserving it
    # all up front, and allow ops to be placed on another device when the
    # requested one cannot run them.
    session_config = tf.ConfigProto()
    session_config.gpu_options.allow_growth = True
    session_config.allow_soft_placement = True
    sess = tf.Session(config=session_config)
    K.set_session(sess)

    # An optional first CLI argument overrides the default GPU index.
    if sys.argv[1:]:
        gpu = sys.argv[1]

    with tf.device("/gpu:{}".format(gpu)):
        train()