Diff of /utils/test_patch.py [000000] .. [903821]

Switch to side-by-side view

--- a
+++ b/utils/test_patch.py
@@ -0,0 +1,240 @@
+import h5py
+import math
+import nibabel as nib
+import numpy as np
+from medpy import metric
+import torch
+import torch.nn.functional as F
+from tqdm import tqdm
+from skimage.measure import label
+
+def getLargestCC(segmentation):
+    labels = label(segmentation)
+    assert( labels.max() != 0 ) # assume at least 1 CC
+    largestCC = labels == np.argmax(np.bincount(labels.flat)[1:])+1
+    return largestCC
+
+def var_all_case(model, num_classes, patch_size=(112, 112, 80), stride_xy=18, stride_z=4, dataset_name="LA"):
+    if dataset_name == "LA":
+        p = '/data/omnisky/postgraduate/Yb/data_set/LASet'
+        with open(p+'/test.list', 'r') as f:
+            image_list = f.readlines()
+        image_list = [p+"/2018LA_Seg_Training Set/" + item.replace('\n', '') + "/mri_norm2.h5" for item in image_list]
+    elif dataset_name == "Pancreas_CT":
+        with open('./data/Pancreas/test.list', 'r') as f:
+            image_list = f.readlines()
+        image_list = ["./data/Pancreas/Pancreas_h5/" + item.replace('\n', '') + "_norm.h5" for item in image_list]
+    loader = tqdm(image_list)
+    total_dice = 0.0
+    for image_path in loader:
+        h5f = h5py.File(image_path, 'r')
+        image = h5f['image'][:]
+        label = h5f['label'][:]
+        prediction, score_map = test_single_case_first_output(model, image, stride_xy, stride_z, patch_size, num_classes=num_classes)
+        if np.sum(prediction)==0:
+            dice = 0
+        else:
+            dice = metric.binary.dc(prediction, label)
+        total_dice += dice
+    avg_dice = total_dice / len(image_list)
+    print('average metric is {}'.format(avg_dice))
+    return avg_dice
+
+def test_all_case(model_name, num_outputs, model, image_list, num_classes, patch_size=(112, 112, 80), stride_xy=18, stride_z=4, save_result=True, test_save_path=None, preproc_fn=None, metric_detail=1, nms=0):
+
+    loader = tqdm(image_list) if not metric_detail else image_list
+    ith = 0
+    total_metric = 0.0
+    total_metric_average = 0.0
+    for image_path in loader:
+        h5f = h5py.File(image_path, 'r')
+        image = h5f['image'][:]
+        label = h5f['label'][:]
+        if preproc_fn is not None:
+            image = preproc_fn(image)
+        prediction, score_map = test_single_case_first_output(model, image, stride_xy, stride_z, patch_size, num_classes=num_classes)
+        if num_outputs > 1:
+            prediction_average, score_map_average = test_single_case_average_output(model, image, stride_xy, stride_z, patch_size, num_classes=num_classes)
+        if nms:
+            prediction = getLargestCC(prediction)
+            if num_outputs > 1:
+                 prediction_average = getLargestCC(prediction_average)
+            
+        if np.sum(prediction)==0:
+            single_metric = (0,0,0,0)
+            if num_outputs > 1:
+                single_metric_average = (0,0,0,0)
+        else:
+            single_metric = calculate_metric_percase(prediction, label[:])
+            if num_outputs > 1:
+                single_metric_average  = calculate_metric_percase(prediction_average, label[:])
+            
+        if metric_detail:
+            print('%02d,\t%.5f, %.5f, %.5f, %.5f' % (ith, single_metric[0], single_metric[1], single_metric[2], single_metric[3]))
+            if num_outputs > 1:
+                print('%02d,\t%.5f, %.5f, %.5f, %.5f' % (ith, single_metric_average[0], single_metric_average[1], single_metric_average[2], single_metric_average[3]))
+        
+        total_metric += np.asarray(single_metric)
+        if num_outputs > 1:
+            total_metric_average += np.asarray(single_metric_average) 
+        
+        if save_result:
+            nib.save(nib.Nifti1Image(prediction.astype(np.float32), np.eye(4)), test_save_path +  "%02d_pred.nii.gz" % ith)
+            nib.save(nib.Nifti1Image(score_map[0].astype(np.float32), np.eye(4)), test_save_path +  "%02d_scores.nii.gz" % ith)
+            if num_outputs > 1:
+                nib.save(nib.Nifti1Image(prediction_average.astype(np.float32), np.eye(4)), test_save_path +  "%02d_pred_average.nii.gz" % ith)
+                nib.save(nib.Nifti1Image(score_map_average[0].astype(np.float32), np.eye(4)), test_save_path +  "%02d_scores_average.nii.gz" % ith)
+            nib.save(nib.Nifti1Image(image[:].astype(np.float32), np.eye(4)), test_save_path +  "%02d_img.nii.gz" % ith)
+            nib.save(nib.Nifti1Image(label[:].astype(np.float32), np.eye(4)), test_save_path +  "%02d_gt.nii.gz" % ith)
+        
+        ith += 1
+
+    avg_metric = total_metric / len(image_list)
+    print('average metric is decoder 1 {}'.format(avg_metric))
+    if num_outputs > 1:
+        avg_metric_average = total_metric_average / len(image_list)
+        print('average metric of all decoders is {}'.format(avg_metric_average))
+    
+    with open(test_save_path+'../{}_performance.txt'.format(model_name), 'w') as f:
+        f.writelines('average metric of decoder 1 is {} \n'.format(avg_metric))
+        if num_outputs > 1:
+            f.writelines('average metric of all decoders is {} \n'.format(avg_metric_average))
+    return avg_metric
+
+
+def test_single_case_first_output(model, image, stride_xy, stride_z, patch_size, num_classes=1):
+    w, h, d = image.shape
+
+    # if the size of image is less than patch_size, then padding it
+    add_pad = False
+    if w < patch_size[0]:
+        w_pad = patch_size[0]-w
+        add_pad = True
+    else:
+        w_pad = 0
+    if h < patch_size[1]:
+        h_pad = patch_size[1]-h
+        add_pad = True
+    else:
+        h_pad = 0
+    if d < patch_size[2]:
+        d_pad = patch_size[2]-d
+        add_pad = True
+    else:
+        d_pad = 0
+    wl_pad, wr_pad = w_pad//2,w_pad-w_pad//2
+    hl_pad, hr_pad = h_pad//2,h_pad-h_pad//2
+    dl_pad, dr_pad = d_pad//2,d_pad-d_pad//2
+    if add_pad:
+        image = np.pad(image, [(wl_pad,wr_pad),(hl_pad,hr_pad), (dl_pad, dr_pad)], mode='constant', constant_values=0)
+    ww,hh,dd = image.shape
+
+    sx = math.ceil((ww - patch_size[0]) / stride_xy) + 1
+    sy = math.ceil((hh - patch_size[1]) / stride_xy) + 1
+    sz = math.ceil((dd - patch_size[2]) / stride_z) + 1
+    # print("{}, {}, {}".format(sx, sy, sz))
+    score_map = np.zeros((num_classes, ) + image.shape).astype(np.float32)
+    cnt = np.zeros(image.shape).astype(np.float32)
+
+    for x in range(0, sx):
+        xs = min(stride_xy*x, ww-patch_size[0])
+        for y in range(0, sy):
+            ys = min(stride_xy * y,hh-patch_size[1])
+            for z in range(0, sz):
+                zs = min(stride_z * z, dd-patch_size[2])
+                test_patch = image[xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]]
+                test_patch = np.expand_dims(np.expand_dims(test_patch,axis=0),axis=0).astype(np.float32)
+                test_patch = torch.from_numpy(test_patch).cuda()
+
+                with torch.no_grad():
+                    y = model(test_patch)
+                    if len(y) > 1:
+                        y = y[0]
+                    y = F.softmax(y, dim=1)
+                y = y.cpu().data.numpy()
+                y = y[0,1,:,:,:]
+                score_map[:, xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] \
+                  = score_map[:, xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] + y
+                cnt[xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] \
+                  = cnt[xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] + 1
+
+    score_map = score_map/np.expand_dims(cnt,axis=0)
+    label_map = (score_map[0]>0.5).astype(np.int)
+    if add_pad:
+        label_map = label_map[wl_pad:wl_pad+w,hl_pad:hl_pad+h,dl_pad:dl_pad+d]
+        score_map = score_map[:,wl_pad:wl_pad+w,hl_pad:hl_pad+h,dl_pad:dl_pad+d]
+    return label_map, score_map
+
+def test_single_case_average_output(net, image, stride_xy, stride_z, patch_size, num_classes=1):
+    w, h, d = image.shape
+
+    # if the size of image is less than patch_size, then padding it
+    add_pad = False
+    if w < patch_size[0]:
+        w_pad = patch_size[0]-w
+        add_pad = True
+    else:
+        w_pad = 0
+    if h < patch_size[1]:
+        h_pad = patch_size[1]-h
+        add_pad = True
+    else:
+        h_pad = 0
+    if d < patch_size[2]:
+        d_pad = patch_size[2]-d
+        add_pad = True
+    else:
+        d_pad = 0
+    wl_pad, wr_pad = w_pad//2,w_pad-w_pad//2
+    hl_pad, hr_pad = h_pad//2,h_pad-h_pad//2
+    dl_pad, dr_pad = d_pad//2,d_pad-d_pad//2
+    if add_pad:
+        image = np.pad(image, [(wl_pad,wr_pad),(hl_pad,hr_pad), (dl_pad, dr_pad)], mode='constant', constant_values=0)
+    ww,hh,dd = image.shape
+
+    sx = math.ceil((ww - patch_size[0]) / stride_xy) + 1
+    sy = math.ceil((hh - patch_size[1]) / stride_xy) + 1
+    sz = math.ceil((dd - patch_size[2]) / stride_z) + 1
+    # print("{}, {}, {}".format(sx, sy, sz))
+    score_map = np.zeros((num_classes, ) + image.shape).astype(np.float32)
+    cnt = np.zeros(image.shape).astype(np.float32)
+
+    for x in range(0, sx):
+        xs = min(stride_xy*x, ww-patch_size[0])
+        for y in range(0, sy):
+            ys = min(stride_xy * y,hh-patch_size[1])
+            for z in range(0, sz):
+                zs = min(stride_z * z, dd-patch_size[2])
+                test_patch = image[xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]]
+                test_patch = np.expand_dims(np.expand_dims(test_patch,axis=0),axis=0).astype(np.float32)
+                test_patch = torch.from_numpy(test_patch).cuda()
+
+                with torch.no_grad():
+                    y_logit = net(test_patch)
+                    num_outputs = len(y_logit)
+                    y=torch.zeros(y_logit[0].shape).cuda()
+                    for idx in range(num_outputs):
+                        y += y_logit[idx]
+                    y/=num_outputs
+                    
+                y = y.cpu().data.numpy()
+                y = y[0,1,:,:,:]
+                score_map[:, xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] \
+                  = score_map[:, xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] + y
+                cnt[xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] \
+                  = cnt[xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] + 1
+
+    score_map = score_map/np.expand_dims(cnt,axis=0)
+    label_map = (score_map[0]>0.5).astype(np.int)
+    if add_pad:
+        label_map = label_map[wl_pad:wl_pad+w,hl_pad:hl_pad+h,dl_pad:dl_pad+d]
+        score_map = score_map[:,wl_pad:wl_pad+w,hl_pad:hl_pad+h,dl_pad:dl_pad+d]
+    return label_map, score_map
+
+def calculate_metric_percase(pred, gt):
+    dice = metric.binary.dc(pred, gt)
+    jc = metric.binary.jc(pred, gt)
+    hd = metric.binary.hd95(pred, gt)
+    asd = metric.binary.asd(pred, gt)
+
+    return dice, jc, hd, asd