--- /dev/null
+++ b/train_msk_seg.py
@@ -0,0 +1,149 @@
+# Authors:
+# Akshay Chaudhari and Zhongnan Fang
+# May 2018
+# akshaysc@stanford.edu
+
+from __future__ import print_function, division
+
+import numpy as np
+import pickle
+import math
+import os
+
+from keras.optimizers import Adam
+from keras import backend as K
+import keras.callbacks as kc
+
+from keras.callbacks import ModelCheckpoint
+from keras.callbacks import LearningRateScheduler as lrs
+from keras.callbacks import TensorBoard as tfb
+
+from utils.generator_msk_seg import calc_generator_info, img_generator_oai
+from utils.models import unet_2d_model
+from utils.losses import dice_loss
+
+# Training and validation data locations
+train_path = '/bmrNAS/people/akshay/dl/oai_data/unet_2d/train_aug/'
+valid_path = '/bmrNAS/people/akshay/dl/oai_data/unet_2d/valid/'
+test_path = '/bmrNAS/people/akshay/dl/oai_data/unet_2d/test'
+train_batch_size = 35
+valid_batch_size = 35
+
+# Locations and names for saving training checkpoints
+cp_save_path = '/bmrNAS/people/akshay/dl/oai_data/unet_2d/weights'
+cp_save_tag = 'unet_2d_men'
+pik_save_path = './checkpoint/' + cp_save_tag + '.dat'
+
+# Model parameters
+n_epochs = 20
+file_types = ['im']
+# Tissues are in the following order:
+# 0. Femoral 1. Lat Tib 2. Med Tib 3. Pat 4. Lat Men 5. Med Men
+tissue = np.arange(0, 1)
+# Pre-trained model weights to load (set to None to train from scratch)
+model_weights = '/bmrNAS/people/akshay/dl/oai_data/unet_2d/weights/unet_2d_men_weights.009--0.7682.h5'
+
+# Training and validation image size
+img_size = (288, 288, len(file_types))
+# Which dataset are we training on? 'dess' or 'oai'
+tag = 'oai_aug'
+
+# Restrict the number of files learned; the default ([]) uses all files
+learn_files = []
+# Indices of layers to freeze for transfer learning
+layers_to_freeze = []
+
+# Learning rate schedule
+# Implementing a step decay for now
+def step_decay(epoch):
+    initial_lrate = 1e-4
+    drop = 0.8
+    epochs_drop = 1.0
+    lrate = initial_lrate * math.pow(drop, math.floor((1 + epoch) / epochs_drop))
+    return lrate
+
+def train_seg(img_size, train_path, valid_path, train_batch_size, valid_batch_size,
+              cp_save_path, cp_save_tag, n_epochs, file_types, pik_save_path,
+              tag, tissue, learn_files, layers_to_freeze):
+
+    # Set the image format to be (N, dim1, dim2, ch)
+    K.set_image_data_format('channels_last')
+    train_files, train_nbatches = calc_generator_info(train_path, train_batch_size, learn_files)
+    valid_files, valid_nbatches = calc_generator_info(valid_path, valid_batch_size)
+
+    # Print some useful debugging information
+    print('INFO: Train size: %d, batch size: %d' % (len(train_files), train_batch_size))
+    print('INFO: Valid size: %d, batch size: %d' % (len(valid_files), valid_batch_size))
+    print('INFO: Image size: %s' % (img_size,))
+    print('INFO: Image types included in training: %s' % (file_types,))
+    print('INFO: Number of tissues being segmented: %d' % len(tissue))
+    print('INFO: Number of frozen layers: %d' % len(layers_to_freeze))
+
+    # Create the U-Net model and load pre-trained weights if available
+    model = unet_2d_model(img_size)
+    if model_weights is not None:
+        model.load_weights(model_weights, by_name=True)
+
+    # Set up the optimizer (the LearningRateScheduler callback below overrides this lr each epoch)
+    model.compile(optimizer=Adam(lr=1e-9, beta_1=0.99, beta_2=0.995, epsilon=1e-08, decay=0.0),
+                  loss=dice_loss)
+
+    # Optional: freeze selected layers for transfer learning
+    for lyr in layers_to_freeze:
+        model.layers[lyr].trainable = False
+
+    # Model callbacks per epoch
+    cp_cb = ModelCheckpoint(cp_save_path + '/' + cp_save_tag +
+                            '_weights.{epoch:03d}-{val_loss:.4f}.h5', save_best_only=True)
+    tfb_cb = tfb('./tf_log',
+                 histogram_freq=1,
+                 write_grads=False,
+                 write_images=False)
+    lr_cb = lrs(step_decay)
+    hist_cb = LossHistory()
+
+    callbacks_list = [tfb_cb, cp_cb, hist_cb, lr_cb]
+
+    # Start the training
+    model.fit_generator(
+        img_generator_oai(train_path, train_batch_size, img_size, tissue, tag),
+        train_nbatches,
+        epochs=n_epochs,
+        validation_data=img_generator_oai(valid_path, valid_batch_size, img_size, tissue, tag),
+        validation_steps=valid_nbatches,
+        callbacks=callbacks_list)
+
+    # Save the training history to disk
+    data = [hist_cb.epoch, hist_cb.lr, hist_cb.losses, hist_cb.val_losses]
+    with open(pik_save_path, "wb") as f:
+        pickle.dump(data, f)
+
+    return hist_cb
+
+
+# Record and save the training history
+class LossHistory(kc.Callback):
+    def on_train_begin(self, logs=None):
+        self.val_losses = []
+        self.losses = []
+        self.lr = []
+        self.epoch = []
+
+    def on_epoch_end(self, epoch, logs=None):
+        logs = logs or {}
+        self.val_losses.append(logs.get('val_loss'))
+        self.losses.append(logs.get('loss'))
+        self.lr.append(step_decay(epoch))
+        self.epoch.append(len(self.losses))
+
+if __name__ == '__main__':
+
+    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
+    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
+    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
+
+    # Uncomment to inspect the architecture before training:
+    # print(unet_2d_model(img_size).summary())
+    train_seg(img_size, train_path, valid_path, train_batch_size, valid_batch_size,
+              cp_save_path, cp_save_tag, n_epochs, file_types, pik_save_path,
+              tag, tissue, learn_files, layers_to_freeze)
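
For reference, a minimal standalone sketch (plain Python, no Keras required) of the learning rates the step_decay schedule in the patch produces: with drop = 0.8 and epochs_drop = 1.0 it cuts the rate by 20% every epoch, i.e. lr(epoch) = 1e-4 * 0.8 ** (epoch + 1), which is also the value LearningRateScheduler sets at the start of each epoch.

import math

def step_decay(epoch):
    # Same schedule as in train_msk_seg.py above
    initial_lrate = 1e-4
    drop = 0.8
    epochs_drop = 1.0
    return initial_lrate * math.pow(drop, math.floor((1 + epoch) / epochs_drop))

for epoch in range(20):
    print('epoch %2d: lr = %.3e' % (epoch, step_decay(epoch)))

# epoch  0: lr = 8.000e-05
# epoch  1: lr = 6.400e-05
# ...
# epoch 19: lr = 1.153e-06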