--- a +++ b/test_msk_seg.py @@ -0,0 +1,114 @@ +# Authors: +# Akshay Chaudhari and Zhongnan Fang +# May 2018 +# akshaysc@stanford.edu + +from __future__ import print_function, division + +import numpy as np +import h5py +import time +import os +import tensorflow as tf +from keras import backend as K + +from utils.generator_msk_seg import calc_generator_info, img_generator_oai +from utils.models import unet_2d_model +from utils.losses import dice_loss_test +import utils.utils_msk_seg as segutils + +# Specify directories +test_result_path = './results/' +test_path = '/bmrNAS/people/akshay/dl/oai_data/unet_2d/test' + +# Tissue Type +tissue = np.arange(4,6) + +# Parameters for the model testing +img_size = (288,288,1) +file_types = ['im'] +test_batch_size = 1 +save_file = False +tag = 'oai_aug' + +# Test with pre-trained weights +model_weights = '/bmrNAS/people/akshay/dl/oai_data/unet_2d/weights/unet_2d_men_weights.012--0.7692.h5' + +def test_seg(test_result_path, test_path, tissue, img_size, + file_types, test_batch_size, save_file, model_weights): + + img_cnt = 0 + + # set image format to be (N, dim1, dim2, dim3, ch) + K.set_image_data_format('channels_last') + + # create the unet model + model = unet_2d_model(img_size) + model.load_weights(model_weights); + + # All of the testing currently assumes that there is a reference to test against. + # Comment out these lines if testing on reference-les data + dice_losses = np.array([]) + cv_values = np.array([]) + voe_values = np.array([]) + vd_values = np.array([]) + + start = time.time() + + # Read the files that will be segmented + test_files,ntest = calc_generator_info(test_path, test_batch_size) + print('INFO: Test size: %d, Number of batches: %d' % (len(test_files), ntest)) + + # Iterature through the files to be segmented + for x_test, y_test, fname in img_generator_oai(test_path, test_batch_size, + img_size, tissue, tag, + testing= True, shuffle_epoch=False): + + # Perform the actual segmentation using pre-loaded model + recon = model.predict(x_test, batch_size = test_batch_size) + + # Calculate real time metrics + dl = np.mean(segutils.calc_dice(recon,y_test)) + dice_losses = np.append(dice_losses,dl) + + cv = np.mean(segutils.calc_cv(recon,y_test)) + cv_values = np.append(cv_values,cv) + + voe = np.mean(segutils.calc_voe(y_test, recon)) + voe_values = np.append(voe_values,voe) + + vd = np.mean(segutils.calc_vd(y_test, recon)) + vd_values = np.append(vd_values,vd) + + # print('Image #%0.2d (%s). Dice = %0.3f CV = %2.1f VOE = %2.1f VD = %2.1f' % ( img_cnt, fname[0:11], dl, cv, voe, vd) ) + + # Write output file per batch + if save_file is True: + save_name = '%s/%s.pred' %(test_result_path,fname) + with h5py.File(save_name,'w') as h5f: + h5f.create_dataset('recon',data=recon) + + img_cnt += 1 + if img_cnt == ntest: + break + + end = time.time() + + # Print some summary statistics + print('--'*20) + print('Overall Summary:') + print('Dice Mean= %0.4f Std = %0.3f' % (np.mean(dice_losses) , np.std(dice_losses) )) + print('CV Mean= %0.4f Std = %0.3f' % (np.mean(cv_values) , np.std(cv_values) )) + print('VOE Mean= %0.4f Std = %0.3f' % (np.mean(voe_values) , np.std(voe_values) )) + print('VD Mean= %0.4f Std = %0.3f' % (np.mean(vd_values) , np.std(vd_values) )) + print('Time required = %0.1f seconds.' % (end-start)) + print('--'*20) + +if __name__ == '__main__': + + os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" + os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3" + + test_seg(test_result_path, test_path, tissue, img_size, + file_types, test_batch_size, save_file, model_weights) \ No newline at end of file