Diff of /train_msk_seg.py [000000] .. [06a92b]

--- a
+++ b/train_msk_seg.py
@@ -0,0 +1,174 @@
+# Authors:
+# Akshay Chaudhari and Zhongnan Fang
+# May 2018
+# akshaysc@stanford.edu
+
+from __future__ import print_function, division
+
+import numpy as np
+import pickle
+import math
+import os
+
+from keras.optimizers import Adam
+from keras import backend as K
+import keras.callbacks as kc
+
+from keras.callbacks import ModelCheckpoint, History
+from keras.callbacks import LambdaCallback as lcb
+from keras.callbacks import LearningRateScheduler as lrs
+from keras.callbacks import TensorBoard as tfb
+
+from utils.generator_msk_seg import calc_generator_info, img_generator_oai
+from utils.models import unet_2d_model
+from utils.losses import dice_loss
+
+# Training and validation data locations
+train_path = '/bmrNAS/people/akshay/dl/oai_data/unet_2d/train_aug/'
+valid_path = '/bmrNAS/people/akshay/dl/oai_data/unet_2d/valid/'
+test_path  = '/bmrNAS/people/akshay/dl/oai_data/unet_2d/test'
+train_batch_size = 35
+valid_batch_size = 35
+
+# Locations and names for saving training checkpoints
+cp_save_path = '/bmrNAS/people/akshay/dl/oai_data/unet_2d/weights'
+cp_save_tag = 'unet_2d_men'
+pik_save_path = './checkpoint/' + cp_save_tag + '.dat'
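+# NB: the ./checkpoint directory must already exist; the pickle.dump at the
+# end of training opens pik_save_path for writing and fails otherwise.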
+
+# Model parameters
+n_epochs = 20
+file_types = ['im']
+# Tissues are in the following order
+# 0. Femoral 1. Lat Tib 2. Med Tib. 3. Pat 4. Lat Men 5. Med Men
+tissue = np.arange(0,1)
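+# e.g. np.arange(0, 1) selects the femoral mask only, while np.arange(4, 6)
+# would select both menisci (indices 4 and 5 above)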
+# Load pre-trained model
+model_weights = '/bmrNAS/people/akshay/dl/oai_data/unet_2d/weights/unet_2d_men_weights.009--0.7682.h5'
+
+# training and validation image size
+img_size = (288,288,len(file_types))
+# Which dataset tag to train on, e.g. 'dess' or 'oai'
+tag = 'oai_aug'
+
+# Restrict the number of files used for training; [] (the default) uses all
+learn_files = []
+# Freeze layers in transfer learning
+layers_to_freeze = []
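+# e.g. layers_to_freeze = list(range(12)) would freeze the first 12 layers;
+# which indices are sensible depends on the unet_2d_model layer ordering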
+
+# learning rate schedule
+# Implementing a step decay for now
+def step_decay(epoch):
+    initial_lrate = 1e-4
+    drop = 0.8
+    epochs_drop = 1.0
+    lrate = initial_lrate * math.pow(drop, math.floor((1+epoch)/epochs_drop))
+    return lrate
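+
+# For reference, with the constants above the schedule evaluates to
+#   step_decay(0) = 1e-4 * 0.8**1 = 8.00e-5
+#   step_decay(1) = 1e-4 * 0.8**2 = 6.40e-5
+#   step_decay(2) = 1e-4 * 0.8**3 = 5.12e-5
+# i.e. a 20% multiplicative drop every epoch (epochs_drop = 1.0)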
+
+def train_seg(img_size, train_path, valid_path, train_batch_size, valid_batch_size, 
+                cp_save_path, cp_save_tag, n_epochs, file_types, pik_save_path, 
+                tag, tissue, learn_files, layers_to_freeze):
+
+    # set image format to be (N, dim1, dim2, ch)
+    K.set_image_data_format('channels_last')
+    train_files, train_nbatches = calc_generator_info(train_path, train_batch_size, learn_files)
+    valid_files, valid_nbatches = calc_generator_info(valid_path, valid_batch_size)
+
+    # Print some useful debugging information
+    print('INFO: Train size: %d, batch size: %d' % (len(train_files), train_batch_size))
+    print('INFO: Valid size: %d, batch size: %d' % (len(valid_files), valid_batch_size))
+    print('INFO: Image size: %s' % (img_size,))
+    print('INFO: Image types included in training: %s' % (file_types,))
+    print('INFO: Number of tissues being segmented: %d' % len(tissue))
+    print('INFO: Number of frozen layers: %d' % len(layers_to_freeze))
+
+    # create the unet model
+    model = unet_2d_model(img_size)
+    if model_weights is not None:
+        model.load_weights(model_weights, by_name=True)
+
+    # Optional: freeze layers for transfer learning. Note that trainable
+    # flags must be set *before* compile() for Keras to honor them.
+    for lyr in layers_to_freeze:
+        model.layers[lyr].trainable = False
+
+    # Set up the optimizer (the LearningRateScheduler callback below
+    # overrides this initial lr at the start of every epoch)
+    model.compile(optimizer=Adam(lr=1e-9, beta_1=0.99, beta_2=0.995, epsilon=1e-08, decay=0.0),
+                  loss=dice_loss)
+
+    # model callbacks per epoch
+    cp_cb   = ModelCheckpoint(cp_save_path + '/' + cp_save_tag + '_weights.{epoch:03d}-{val_loss:.4f}.h5',
+                              save_best_only=True)
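+    # This pattern yields names like 'unet_2d_men_weights.009--0.7682.h5'
+    # (a double '-' when val_loss is negative), matching the form of the
+    # pre-trained weights file loaded above.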
+    # histogram_freq is set to 0 because Keras cannot compute activation
+    # histograms when validation data comes from a generator, as it does below
+    tfb_cb  = tfb('./tf_log',
+                  histogram_freq=0,
+                  write_grads=False,
+                  write_images=False)
+    lr_cb   = lrs(step_decay)
+    hist_cb = LossHistory()
+
+    callbacks_list = [tfb_cb, cp_cb, hist_cb, lr_cb]
+
+    # Start the training
+    model.fit_generator(
+            img_generator_oai(train_path, train_batch_size, img_size, tissue, tag),
+            train_nbatches,
+            epochs=n_epochs,
+            validation_data=img_generator_oai(valid_path, valid_batch_size, img_size, tissue, tag),
+            validation_steps=valid_nbatches,
+            callbacks=callbacks_list)
+
+    # Save the recorded training history to disk
+    data = [hist_cb.epoch, hist_cb.lr, hist_cb.losses, hist_cb.val_losses]
+    with open(pik_save_path, "wb") as f:
+        pickle.dump(data, f)
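+    # The saved list can be reloaded later in the same order, e.g.:
+    #   with open(pik_save_path, 'rb') as f:
+    #       epoch, lr, losses, val_losses = pickle.load(f)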
+
+    return hist_cb
+
+
+# Record and save the training history per epoch
+class LossHistory(kc.Callback):
+    def on_train_begin(self, logs={}):
+        self.val_losses = []
+        self.losses = []
+        self.lr = []
+        self.epoch = []
+
+    def on_epoch_end(self, epoch, logs={}):
+        self.val_losses.append(logs.get('val_loss'))
+        self.losses.append(logs.get('loss'))
+        # record the rate the scheduler actually applied for this (0-indexed) epoch
+        self.lr.append(step_decay(epoch))
+        self.epoch.append(len(self.losses))
+
+if __name__ == '__main__':
+
+    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
+    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
+    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
+
+    # Optionally build the model here just to inspect its architecture;
+    # train_seg constructs its own instance, so this is not needed for training.
+    # model = unet_2d_model(img_size); print(model.summary())
+    train_seg(img_size, train_path, valid_path, train_batch_size, valid_batch_size, 
+                cp_save_path, cp_save_tag, n_epochs, file_types, pik_save_path, 
+                tag, tissue, learn_files, layers_to_freeze)
+