--- a +++ b/testvis.py @@ -0,0 +1,183 @@ +""" +This code is to test NN model and visualize output +""" +import numpy as np +import sys +import time +import matplotlib.pyplot as plt + +from keras.models import Model, load_model +from keras.layers import Input, Activation, concatenate, Conv2D, MaxPooling2D, Conv2DTranspose, UpSampling2D, ZeroPadding2D, BatchNormalization +from keras.optimizers import Adam, SGD +from keras.callbacks import ModelCheckpoint +from keras import backend as K +import tensorflow as tf + +from data import load_train_data, load_test_data +from utils import * + +K.set_image_data_format('channels_last') # Tensorflow dimension ordering + +data_path = sys.argv[1] + "/" +model_path = data_path + "models/" + +# dir for storing results that contains +rst_path = data_path + "test-records/" +if not os.path.exists(rst_path): + os.makedirs(rst_path) + +model_to_test = sys.argv[2] +cur_fold = sys.argv[3] +plane = sys.argv[4] +im_z = int(sys.argv[5]) +im_y = int(sys.argv[6]) +im_x = int(sys.argv[7]) +high_range = float(sys.argv[8]) +low_range = float(sys.argv[9]) +margin = int(sys.argv[10]) +vis = sys.argv[11] + +# prediction of trained model +pred_path = os.path.join(rst_path, "pred-%s/"%cur_fold) +if not os.path.exists(pred_path): + os.makedirs(pred_path) + +""" +Dice Ceofficient and Cost functions for training +""" +smooth = 1. + +def dice_coef(y_true, y_pred): + y_true_f = K.flatten(y_true) + y_pred_f = K.flatten(y_pred) + intersection = K.sum(y_true_f * y_pred_f) + return (2.0 * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth) + +def dice_coef_loss(y_true, y_pred): + return -dice_coef(y_true, y_pred) + + +def test(model_to_test, current_fold, plane, rst_dir, vis): + print "-"*50 + print "loading model ", model_to_test + print "-"*50 + + model = load_model(model_path + model_to_test + '.h5', custom_objects={'dice_coef_loss': dice_coef_loss, 'dice_coef':dice_coef}) + volume_list = open(testing_set_filename(current_fold), 'r').read().splitlines() + total = len(volume_list) + + dsc = np.zeros((total, 2)) + + # iterate all test cases + for i in range(total): + s = volume_list[i].split(' ') + image = np.load(s[1]) + label = np.load(s[2]) + + case_num = s[1].split("00")[1].split(".")[0] + print "testing case: ", case_num + + image_ = np.transpose(image, (2, 0, 1)) + label_ = np.transpose(label, (2, 0, 1)) + + # standardize test data + image_[image_ < low_range] = low_range + image_[image_ > high_range] = high_range + image_ = (image_ - low_range) / float(high_range - low_range) + + # for creating final prediction visualization + pred = np.zeros_like(image_) + + for sli in range(label_.shape[0]): + try: + # crop each slice according to smallest bounding box of each slice + width = label_[sli].shape[0] + height = label_[sli].shape[1] + + arr = np.nonzero(label_[sli]) + + if len(arr[0]) == 0: + continue + + minA = min(arr[0]) + maxA = max(arr[0]) + minB = min(arr[1]) + maxB = max(arr[1]) + + minAdiff = margin + maxAdiff = margin + minBdiff = margin + maxBdiff = margin + + cropped = image_[sli, max(minA - minAdiff, 0): min(maxA + maxAdiff + 1, width), \ + max(minB - minBdiff, 0): min(maxB + maxBdiff + 1, height)] + cropped_mask = label_[sli, max(minA - minAdiff, 0): min(maxA + maxAdiff + 1, width), \ + max(minB - minBdiff, 0): min(maxB + maxBdiff + 1, height)] + + image_padded_ = pad_2d(cropped, plane, 0, im_x, im_y, im_z) + mask_padded_ = pad_2d(cropped_mask, plane, 0, im_x, im_y, im_z) + + image_padded_prep = preprocess_front(preprocess(image_padded_)) + + out_ori = (model.predict(image_padded_prep) > 0.5).astype(np.uint8) + + out = out_ori[:,0:cropped.shape[0], 0:cropped.shape[1],:].reshape(cropped.shape) + pred[sli, max(minA - minAdiff, 0): min(maxA + maxAdiff + 1, width), max(minB - minBdiff, 0): min(maxB + maxBdiff+ 1, height)] = out + pred_vis = pred[sli, max(minA - minAdiff, 0): min(maxA + maxAdiff + 1, width), max(minB - minBdiff, 0): min(maxB + maxBdiff+ 1, height)] + + if vis == "true": + fig = plt.figure() + ax = fig.add_subplot(1, 3, 1) + ax.set_title("input test image") + ax.imshow(cropped, cmap=plt.cm.gray) + + ax = fig.add_subplot(1, 3, 2) + ax.set_title("prediction") + ax.imshow(pred_vis, cmap=plt.cm.gray) + + ax = fig.add_subplot(1, 3, 3) + ax.set_title("ground truth") + ax.imshow(cropped_mask, cmap=plt.cm.gray) + + # plt.suptitle("slice %s"%sli) + fig.canvas.set_window_title("slice %s"%sli) + plt.axis('off') + plt.show() + + except KeyboardInterrupt: + print 'KeyboardInterrupt caught' + raise ValueError("terminate because of keyboard interruption") + + # ------------ write out for visualization --------------- + np.save(pred_path + case_num + ".npy", pred) # prediction made by the trained model + + # compute DSC + cur_dsc, _, _, _ = DSC_computation(label_, pred) + print cur_dsc + + dsc[i][0] = case_num + dsc[i][1] = cur_dsc + + dsc_mean = np.mean(dsc[:,1]) + dsc_std = np.std(dsc[:,1]) + + # record test dsc mean and standard deviation for each fold in the one file + fd = open(rst_path + 'test_stats.csv','a+') + fd.write("%s,%s,%s,%s\n"%(cur_fold, model_to_test, dsc_mean, dsc_std)) + fd.close() + + print "---------------------------------" + print "mean: ", dsc_mean + print "std: ", dsc_std + + # record test result case by case + np.savetxt(rst_path + model_to_test + ".csv", dsc, fmt = "%i, %.5f", delimiter=",", header="case_num,DSC") + + +if __name__ == "__main__": + + start_time = time.time() + + test(model_to_test, cur_fold, plane, rst_path, vis) + + print "-----------test done, total time used: %s ------------"% (time.time() - start_time)