|
a |
|
b/train.py |
|
|
1 |
import os |
|
|
2 |
import tables |
|
|
3 |
import numpy as np |
|
|
4 |
from config import cfg |
|
|
5 |
from model import unet_model |
|
|
6 |
from data_generator import CustomDataGenerator |
|
|
7 |
from tensorflow.keras.callbacks import CSVLogger, ModelCheckpoint, TensorBoard |
|
|
8 |
|
|
|
9 |
def train_model(hdf5_dir, brains_idx_dir, view, modified_unet=True, batch_size=16, val_batch_size=32,
                lr=0.01, epochs=100, hor_flip=False, ver_flip=False, zoom_range=0.0, save_dir='./save/',
                start_chs=64, levels=3, multiprocessing=False, load_model_dir=None):
    """
    Build (and optionally warm-start) a UNet model, set up the train/validation generators, and train it.

    Args:
        hdf5_dir: path of the HDF5 data file, opened with ``tables.open_file`` and handed to both generators.
        brains_idx_dir: path of a ``.npy`` file loaded with ``np.load`` and handed to both generators.
            The first five characters of its basename are also used in the run-directory name.
        view: view identifier forwarded to the generators and used in the run-directory name.
        modified_unet, lr, start_chs, levels: forwarded to ``unet_model``.
        batch_size: batch size of the training generator.
        val_batch_size: batch size of the validation generator.
        epochs: number of training epochs.
        hor_flip, ver_flip, zoom_range: augmentation options, applied to the training generator only.
        save_dir: parent directory for run outputs; created (with any missing parents) if absent.
        multiprocessing: forwarded to ``fit`` as ``use_multiprocessing``.
        load_model_dir: optional path of an HDF5 weights/model file (e.g. the ``model.hdf5`` written by
            a previous run's ModelCheckpoint). When given, its weights are loaded into the freshly built
            model before training. NOTE(review): assumes the file matches the architecture built from
            ``start_chs``/``levels``/``modified_unet`` -- confirm.

    Returns:
        The trained Keras model. The best checkpoint is also written to ``<run_dir>/model.hdf5``.

    Side effects:
        Creates ``<save_dir>/<view>_<prefix>/`` containing ``log.txt``, ``model.hdf5`` and ``tensorboard/``.
    """
    # Load the index array before opening the HDF5 file so a failure here cannot leak the file handle.
    brain_idx = np.load(brains_idx_dir)

    # One output directory per (view, index-file prefix) combination.
    run_dir = os.path.join(save_dir, '{}_{}'.format(view, os.path.basename(brains_idx_dir)[:5]))
    # makedirs (unlike mkdir) also creates a missing parent save_dir.
    os.makedirs(run_dir, exist_ok=True)

    # NOTE(review): opened read/write as in the original; nothing visible here writes to the file,
    # so mode='r' may suffice if CustomDataGenerator only reads -- confirm before changing.
    hdf5_file = tables.open_file(hdf5_dir, mode='r+')
    try:
        # preparing generators (augmentation only for training; validation order is fixed)
        datagen_train = CustomDataGenerator(hdf5_file, brain_idx, batch_size, view, 'train',
                                            hor_flip, ver_flip, zoom_range, shuffle=True)
        datagen_val = CustomDataGenerator(hdf5_file, brain_idx, val_batch_size, view, 'validation',
                                          shuffle=False)

        # add callbacks
        callbacks = [
            CSVLogger(os.path.join(run_dir, 'log.txt')),
            ModelCheckpoint(filepath=os.path.join(run_dir, 'model.hdf5'), verbose=1, save_best_only=True),
            TensorBoard(os.path.join(run_dir, 'tensorboard')),
        ]

        # building the model; data_shape[0] is dropped (presumably the sample axis -- verify)
        model_input_shape = datagen_train.data_shape[1:]
        model = unet_model(model_input_shape, modified_unet, lr, start_chs, levels)
        if load_model_dir is not None:
            # Warm start: reuse weights from a previous run instead of ignoring this argument.
            model.load_weights(load_model_dir)

        # training the model; `fit` accepts a keras Sequence directly (fit_generator is deprecated)
        model.fit(datagen_train, epochs=epochs, use_multiprocessing=multiprocessing,
                  callbacks=callbacks, validation_data=datagen_val)
    finally:
        # Always release the HDF5 handle, even if building or training fails.
        hdf5_file.close()

    return model
|
|
40 |
|
|
|
41 |
|
|
|
42 |
|
|
|
43 |
if __name__ == '__main__':
    # The config keys mirror train_model's parameter names, so forward them by keyword
    # (looked up in the same order as the function signature).
    _ARG_NAMES = (
        'hdf5_dir', 'brains_idx_dir', 'view', 'modified_unet', 'batch_size',
        'val_batch_size', 'lr', 'epochs', 'hor_flip', 'ver_flip', 'zoom_range',
        'save_dir', 'start_chs', 'levels', 'multiprocessing', 'load_model_dir',
    )
    train_model(**{name: cfg[name] for name in _ARG_NAMES})
|
|
50 |
|
|
|
51 |
|
|
|
52 |
|
|
|
53 |
|
|
|
54 |
|