a b/train.py
1
import os
2
import tables
3
import numpy as np
4
from config import cfg
5
from model import unet_model
6
from data_generator import CustomDataGenerator
7
from tensorflow.keras.callbacks import CSVLogger, ModelCheckpoint, TensorBoard
8
9
def train_model(hdf5_dir, brains_idx_dir, view, modified_unet=True, batch_size=16, val_batch_size=32,
                lr=0.01, epochs=100, hor_flip=False, ver_flip=False, zoom_range=0.0, save_dir='./save/',
                start_chs=64, levels=3, multiprocessing=False, load_model_dir=None):
    """
    Build (and optionally warm-start) a UNet model, set up the training/validation data
    generators, and train the model.

    Args:
        hdf5_dir: path of the HDF5 file opened with PyTables and handed to both generators.
        brains_idx_dir: path of a file loaded with ``np.load`` and passed to both generators.
            The first five characters of its basename also name the run directory.
        view: view identifier forwarded to the generators; also part of the run directory name.
        modified_unet, lr, start_chs, levels: forwarded to ``unet_model``.
        batch_size: batch size of the training generator.
        val_batch_size: batch size of the validation generator.
        lr: learning rate forwarded to ``unet_model``.
        epochs: number of training epochs.
        hor_flip, ver_flip, zoom_range: augmentation options, applied to the training
            generator only (validation data is not augmented).
        save_dir: parent directory. The run's CSV log, best checkpoint (``model.hdf5``) and
            TensorBoard logs go to ``<save_dir>/<view>_<first 5 chars of brains_idx basename>``,
            which is created (including missing parents) if absent.
        multiprocessing: forwarded to Keras as ``use_multiprocessing``.
        load_model_dir: optional weights to initialise the model with before training. Either
            an HDF5 weights/model file, or a directory containing ``model.hdf5`` (the filename
            written by this function's checkpointer). ``None`` trains from scratch.
    """
    # NOTE(review): opened read-write ('r+') as before; if CustomDataGenerator never writes,
    # mode='r' would protect the dataset from corruption on an interrupted run — confirm.
    hdf5_file = tables.open_file(hdf5_dir, mode='r+')
    try:
        # preparing generators
        brain_idx = np.load(brains_idx_dir)
        datagen_train = CustomDataGenerator(hdf5_file, brain_idx, batch_size, view, 'train',
                                            hor_flip, ver_flip, zoom_range, shuffle=True)
        datagen_val = CustomDataGenerator(hdf5_file, brain_idx, val_batch_size, view, 'validation',
                                          shuffle=False)

        # per-run output directory; makedirs also creates a missing parent save_dir
        run_dir = os.path.join(save_dir, '{}_{}'.format(view, os.path.basename(brains_idx_dir)[:5]))
        os.makedirs(run_dir, exist_ok=True)

        # add callbacks
        callbacks = [
            CSVLogger(os.path.join(run_dir, 'log.txt')),
            ModelCheckpoint(filepath=os.path.join(run_dir, 'model.hdf5'), verbose=1, save_best_only=True),
            TensorBoard(os.path.join(run_dir, 'tensorboard')),
        ]

        # building the model; the first axis of data_shape is dropped (presumably the sample axis)
        model_input_shape = datagen_train.data_shape[1:]
        model = unet_model(model_input_shape, modified_unet, lr, start_chs, levels)

        # optional warm start: load weights into the freshly built (and compiled) architecture,
        # which avoids needing custom_objects for any custom losses/metrics. Optimizer state
        # is not restored.
        if load_model_dir is not None:
            weights_path = load_model_dir
            if os.path.isdir(weights_path):
                weights_path = os.path.join(weights_path, 'model.hdf5')
            model.load_weights(weights_path)

        # training the model
        # NOTE: fit_generator is deprecated in TF 2.x; kept because the TF version in use is unknown.
        model.fit_generator(datagen_train, epochs=epochs, use_multiprocessing=multiprocessing,
                            callbacks=callbacks, validation_data=datagen_val)
    finally:
        # release the HDF5 handle even if loading, building or training fails
        hdf5_file.close()
41
42
   
43
if __name__ == '__main__':
    # Script entry point: every training option is read from ``cfg`` (imported from config).
    # Options are passed by keyword so the call does not rely on train_model's parameter order.
    train_model(
        hdf5_dir=cfg['hdf5_dir'],
        brains_idx_dir=cfg['brains_idx_dir'],
        view=cfg['view'],
        modified_unet=cfg['modified_unet'],
        batch_size=cfg['batch_size'],
        val_batch_size=cfg['val_batch_size'],
        lr=cfg['lr'],
        epochs=cfg['epochs'],
        hor_flip=cfg['hor_flip'],
        ver_flip=cfg['ver_flip'],
        zoom_range=cfg['zoom_range'],
        save_dir=cfg['save_dir'],
        start_chs=cfg['start_chs'],
        levels=cfg['levels'],
        multiprocessing=cfg['multiprocessing'],
        load_model_dir=cfg['load_model_dir'],
    )
50
    
51
    
52
    
53
    
54