--- a
+++ b/testvis.py
@@ -0,0 +1,183 @@
+"""
+Evaluate a trained NN segmentation model on the test cases of one fold,
+report per-case Dice scores (DSC), and optionally visualize the predictions.
+"""
+import os
+import sys
+import time
+
+import matplotlib.pyplot as plt
+import numpy as np
+import tensorflow as tf
+from keras import backend as K
+from keras.callbacks import ModelCheckpoint
+from keras.layers import Input, Activation, concatenate, Conv2D, MaxPooling2D, Conv2DTranspose, UpSampling2D, ZeroPadding2D, BatchNormalization
+from keras.models import Model, load_model
+from keras.optimizers import Adam, SGD
+
+from data import load_train_data, load_test_data
+from utils import *
+
+K.set_image_data_format('channels_last')  # Tensorflow dimension ordering
+
+data_path  = sys.argv[1] + "/"
+model_path = data_path + "models/"
+
+# dir for storing results that contains
+rst_path = data_path + "test-records/"
+if not os.path.exists(rst_path):
+    os.makedirs(rst_path)
+
+model_to_test = sys.argv[2]
+cur_fold = sys.argv[3]
+plane = sys.argv[4]
+im_z = int(sys.argv[5])
+im_y = int(sys.argv[6])
+im_x = int(sys.argv[7])
+high_range = float(sys.argv[8])
+low_range = float(sys.argv[9])
+margin = int(sys.argv[10])
+vis = sys.argv[11]
+
+# prediction of trained model
+pred_path = os.path.join(rst_path, "pred-%s/"%cur_fold)
+if not os.path.exists(pred_path):
+    os.makedirs(pred_path)
+
+"""
+Dice Ceofficient and Cost functions for training
+"""
+smooth = 1.
+
+def dice_coef(y_true, y_pred):
+    y_true_f = K.flatten(y_true)
+    y_pred_f = K.flatten(y_pred)
+    intersection = K.sum(y_true_f * y_pred_f)
+    return (2.0 * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)
+
+def dice_coef_loss(y_true, y_pred):
+    return  -dice_coef(y_true, y_pred)
+
+
+def test(model_to_test, current_fold, plane, rst_dir, vis):
+    print "-"*50
+    print "loading model ", model_to_test
+    print "-"*50
+
+    model = load_model(model_path + model_to_test + '.h5', custom_objects={'dice_coef_loss': dice_coef_loss, 'dice_coef':dice_coef})
+    volume_list = open(testing_set_filename(current_fold), 'r').read().splitlines()
+    total = len(volume_list)
+
+    dsc = np.zeros((total, 2))
+
+    # iterate all test cases
+    for i in range(total):
+        s = volume_list[i].split(' ')
+        image = np.load(s[1])
+        label = np.load(s[2])
+
+        case_num = s[1].split("00")[1].split(".")[0]
+        print "testing case: ", case_num
+
+        image_ = np.transpose(image, (2, 0, 1))
+        label_ = np.transpose(label, (2, 0, 1))
+
+	    # standardize test data
+        image_[image_ < low_range] = low_range
+        image_[image_ > high_range] = high_range
+        image_ = (image_ - low_range) / float(high_range - low_range)
+
+        # for creating final prediction visualization
+        pred = np.zeros_like(image_)
+
+        for sli in range(label_.shape[0]):
+            try:
+                # crop each slice according to smallest bounding box of each slice
+                width = label_[sli].shape[0]
+                height = label_[sli].shape[1]
+
+                arr = np.nonzero(label_[sli])
+
+                if len(arr[0]) == 0:
+        	        continue
+
+                minA = min(arr[0])
+                maxA = max(arr[0])
+                minB = min(arr[1])
+                maxB = max(arr[1])
+
+                minAdiff = margin
+                maxAdiff = margin
+                minBdiff = margin
+                maxBdiff = margin
+
+                cropped = image_[sli, max(minA - minAdiff, 0): min(maxA + maxAdiff + 1, width), \
+                        max(minB - minBdiff, 0): min(maxB + maxBdiff + 1, height)]
+                cropped_mask = label_[sli, max(minA - minAdiff, 0): min(maxA + maxAdiff + 1, width), \
+                        max(minB - minBdiff, 0): min(maxB + maxBdiff + 1, height)]
+
+                image_padded_ = pad_2d(cropped, plane, 0, im_x, im_y, im_z)
+                mask_padded_ = pad_2d(cropped_mask, plane, 0, im_x, im_y, im_z)
+
+                image_padded_prep = preprocess_front(preprocess(image_padded_))
+
+                out_ori = (model.predict(image_padded_prep) > 0.5).astype(np.uint8)
+
+                out = out_ori[:,0:cropped.shape[0], 0:cropped.shape[1],:].reshape(cropped.shape)
+                pred[sli, max(minA - minAdiff, 0): min(maxA + maxAdiff + 1, width), max(minB - minBdiff, 0): min(maxB + maxBdiff+ 1, height)] = out
+                pred_vis = pred[sli, max(minA - minAdiff, 0): min(maxA + maxAdiff + 1, width), max(minB - minBdiff, 0): min(maxB + maxBdiff+ 1, height)]
+
+                if vis == "true":
+                    fig = plt.figure()
+                    ax = fig.add_subplot(1, 3, 1)
+                    ax.set_title("input test image")
+                    ax.imshow(cropped, cmap=plt.cm.gray)
+
+                    ax = fig.add_subplot(1, 3, 2)
+                    ax.set_title("prediction")
+                    ax.imshow(pred_vis, cmap=plt.cm.gray)
+
+                    ax = fig.add_subplot(1, 3, 3)
+                    ax.set_title("ground truth")
+                    ax.imshow(cropped_mask, cmap=plt.cm.gray)
+
+                    # plt.suptitle("slice %s"%sli)
+                    fig.canvas.set_window_title("slice %s"%sli)
+                    plt.axis('off')
+                    plt.show()
+
+            except KeyboardInterrupt:
+                print 'KeyboardInterrupt caught'
+                raise ValueError("terminate because of keyboard interruption")
+
+        # ------------ write out for visualization ---------------
+        np.save(pred_path + case_num + ".npy", pred) # prediction made by the trained model
+
+	    # compute DSC
+        cur_dsc, _, _, _ = DSC_computation(label_, pred)
+        print cur_dsc
+
+        dsc[i][0] = case_num
+        dsc[i][1] = cur_dsc
+
+    dsc_mean = np.mean(dsc[:,1])
+    dsc_std = np.std(dsc[:,1])
+
+    # record test dsc mean and standard deviation for each fold in the one file
+    fd = open(rst_path + 'test_stats.csv','a+')
+    fd.write("%s,%s,%s,%s\n"%(cur_fold, model_to_test, dsc_mean, dsc_std))
+    fd.close()
+
+    print "---------------------------------"
+    print "mean: ", dsc_mean
+    print "std: ", dsc_std
+
+    # record test result case by case
+    np.savetxt(rst_path + model_to_test + ".csv", dsc, fmt = "%i, %.5f", delimiter=",", header="case_num,DSC")
+
+
+if __name__ == "__main__":
+
+    start_time = time.time()
+
+    test(model_to_test, cur_fold, plane, rst_path, vis)
+
+    print "-----------test done, total time used: %s ------------"% (time.time() - start_time)