|
a |
|
b/flair-segmentation/train.py |
|
|
1 |
from __future__ import print_function |
|
|
2 |
|
|
|
3 |
import os |
|
|
4 |
import sys |
|
|
5 |
|
|
|
6 |
import numpy as np |
|
|
7 |
import tensorflow as tf |
|
|
8 |
from keras import backend as K |
|
|
9 |
from keras.callbacks import TensorBoard |
|
|
10 |
from keras.optimizers import Adam |
|
|
11 |
|
|
|
12 |
from data import load_data |
|
|
13 |
from data import oversample |
|
|
14 |
from net import dice_coef |
|
|
15 |
from net import dice_coef_loss |
|
|
16 |
from net import unet |
|
|
17 |
|
|
|
18 |
# Dataset directories read by data.load_data() inside train().
train_images_path, valid_images_path = "./data/train/", "./data/valid/"
# Optional warm-start weights; train() loads them only if this file exists.
init_weights_path = "./weights_128.h5"
# Output locations for the final weights file and the TensorBoard logs.
weights_path = log_path = "."

# Default GPU index; the __main__ block overrides it from sys.argv[1].
gpu = "0"

# Optimization hyperparameters consumed by Adam and model.fit in train().
epochs, batch_size, base_lr = 128, 32, 1e-5
|
|
29 |
|
|
|
30 |
|
|
|
31 |
def _global_stats(images):
    """Return the scalar (mean, std) computed over every element of *images*.

    A zero standard deviation (a constant input) is replaced by 1.0 so the
    later division cannot fill the arrays with inf/NaN.
    """
    mean = np.mean(images)
    std = np.std(images)
    if std == 0:
        std = 1.0
    return mean, std


def _standardize(images, mean, std):
    """Shift and scale *images* in place by the given statistics; return it.

    Done in place (as the original script did) to avoid allocating a second
    copy of a potentially large image array.
    """
    images -= mean
    images /= std
    return images


def _ensure_dir(path):
    """Create *path*, including any missing parent directories, if absent."""
    # os.makedirs (without exist_ok) keeps Python 2 compatibility while, unlike
    # os.mkdir, also creating intermediate directories of nested paths.
    if not os.path.exists(path):
        os.makedirs(path)


def train():
    """Train the U-Net and save its final weights.

    Steps:
      * load training and validation data with ``load_data``;
      * standardize both using the training set's global mean/std;
      * oversample the training set with ``oversample``;
      * build ``unet()`` and warm-start from ``init_weights_path`` if present;
      * compile with Adam(``base_lr``), ``dice_coef_loss`` and ``dice_coef``;
      * fit for ``epochs`` epochs, logging to TensorBoard under ``log_path``;
      * save weights to ``weights_path/weights_<epochs>.h5``.
    """
    imgs_train, imgs_mask_train, _ = load_data(train_images_path)

    # Statistics come from the training set only; validation data is scaled
    # with the same numbers so it is treated exactly like unseen data.
    mean, std = _global_stats(imgs_train)
    imgs_train = _standardize(imgs_train, mean, std)

    imgs_valid, imgs_mask_valid, _ = load_data(valid_images_path)
    imgs_valid = _standardize(imgs_valid, mean, std)

    # Oversample only after the statistics were computed, so duplicated
    # samples do not bias the normalization.
    imgs_train, imgs_mask_train = oversample(imgs_train, imgs_mask_train)

    model = unet()
    if os.path.exists(init_weights_path):
        model.load_weights(init_weights_path)

    # NOTE(review): newer Keras releases renamed `lr` to `learning_rate`;
    # update this keyword if the Keras dependency is upgraded.
    optimizer = Adam(lr=base_lr)
    model.compile(optimizer=optimizer, loss=dice_coef_loss, metrics=[dice_coef])

    _ensure_dir(log_path)
    training_log = TensorBoard(log_dir=log_path)

    model.fit(
        imgs_train,
        imgs_mask_train,
        validation_data=(imgs_valid, imgs_mask_valid),
        batch_size=batch_size,
        epochs=epochs,
        shuffle=True,
        callbacks=[training_log],
    )

    _ensure_dir(weights_path)
    model.save_weights(os.path.join(weights_path, "weights_{}.h5".format(epochs)))
|
|
72 |
|
|
|
73 |
|
|
|
74 |
if __name__ == "__main__":
    # TF1-style session: on-demand GPU memory growth plus soft device placement.
    session_config = tf.ConfigProto(allow_soft_placement=True)
    session_config.gpu_options.allow_growth = True
    sess = tf.Session(config=session_config)
    K.set_session(sess)

    # The first command-line argument, when given, overrides the default GPU index.
    gpu = sys.argv[1] if len(sys.argv) > 1 else gpu
    device = "/gpu:{}".format(gpu)

    with tf.device(device):
        train()