--- a
+++ b/tools/test_df_vis.py
@@ -0,0 +1,243 @@
+import torch
+import torch.nn as nn
+from torch.utils.data import DataLoader
+
+import os
+import argparse
+import numpy as np
+import cv2
+import logging
+import math
+
+import _init_paths
+from libs.network import U_Net, U_NetDF
+from libs.datasets import AcdcDataset
+import libs.datasets.joint_augment as joint_augment
+import libs.datasets.augment as standard_augment
+from libs.datasets.collate_batch import BatchCollator
+
+from libs.configs.config_acdc import cfg
+from train_utils.train_utils import load_checkpoint
+from utils.metrics import dice
+from utils.vis_utils import mask2png, masks_to_contours, apply_mask, img_mask_png
+
+parser = argparse.ArgumentParser(description="arg parser")
+parser.add_argument('--used_df', type=str, default=None, help="which network to use: 'U_NetDF' or 'U_Net'")
+parser.add_argument('--selfeat', action='store_true', default=False, help='whether to use feature select')
+parser.add_argument('--mgpus', type=str, default=None, required=True, help='GPU id(s) to use, e.g. "0" or "0,1"')
+parser.add_argument('--model_path1', type=str, default=None, help='path to the model checkpoint to evaluate')
+parser.add_argument('--model_path2', type=str, default=None, help='path to an optional second checkpoint (not used by this script)')
+parser.add_argument('--batch_size', type=int, default=1, help='batch size for evaluation')
+parser.add_argument('--workers', type=int, default=4, help='number of workers for dataloader')
+parser.add_argument('--output_dir', type=str, default=None, required=True, help='specify an output directory if needed')
+parser.add_argument('--log_file', type=str, default="../log_evaluation.txt", help="file to write evaluation logs to")
+parser.add_argument('--vis', action='store_true', default=False, help="whether to save test result images")
+args = parser.parse_args()
+
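+# make only the requested GPUs visible; this must happen before any CUDA context is created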
+if args.mgpus is not None:
+    os.environ["CUDA_VISIBLE_DEVICES"] = args.mgpus
+
+def create_logger(log_file):
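+    """Create a logger that writes to log_file and echoes to the console."""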
+    log_format = '%(asctime)s  %(levelname)5s  %(message)s'
+    logging.basicConfig(level=logging.DEBUG, format=log_format, filename=log_file)
+    console = logging.StreamHandler()
+    console.setLevel(logging.DEBUG)
+    console.setFormatter(logging.Formatter(log_format))
+    logging.getLogger(__name__).addHandler(console)
+    return logging.getLogger(__name__)
+
+def create_dataloader():
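+    """Build the evaluation transforms, dataset, and DataLoader (batch_size is fixed to 1 here)."""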
+    eval_transform = joint_augment.Compose([
+                    joint_augment.To_PIL_Image(),
+                    joint_augment.FixResize(256),
+                    joint_augment.To_Tensor()])
+    evalImg_transform = standard_augment.Compose([
+                        standard_augment.normalize([cfg.DATASET.MEAN], [cfg.DATASET.STD])])
+
+    if cfg.DATASET.NAME == "acdc":
+        test_set = AcdcDataset(cfg.DATASET.TEST_LIST, df_used=True, joint_augment=eval_transform,
+                               augment=evalImg_transform)
+    else:
+        raise ValueError("Unsupported dataset: {}".format(cfg.DATASET.NAME))
+
+    test_loader = DataLoader(test_set, batch_size=1, pin_memory=True,
+                             num_workers=args.workers, shuffle=False,
+                             collate_fn=BatchCollator(size_divisible=32, df_used=True))
+    return test_loader, test_set
+
+def cal_perfer(preds, masks, tb_dict):
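+    """Compute per-class Dice (LV, RV, MYO) for a batch of one-hot predictions and masks,
+    append the batch means to tb_dict, and return them."""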
+    LV_dice = []   # class index 1
+    MYO_dice = []  # class index 2
+    RV_dice = []   # class index 3
+
+    for i in range(preds.shape[0]):
+        LV_dice.append(dice(preds[i,1,:,:],masks[i,1,:,:]))
+        RV_dice.append(dice(preds[i, 3, :, :], masks[i, 3, :, :]))
+        MYO_dice.append(dice(preds[i, 2, :, :], masks[i, 2, :, :]))
+        # LV_dice.append(dice(preds[i, 3,:,:],masks[i,1,:,:]))
+        # RV_dice.append(dice(preds[i, 1, :, :], masks[i, 3, :, :]))
+        # MYO_dice.append(dice(preds[i, 2, :, :], masks[i, 2, :, :]))
+    
+    tb_dict["LV_dice"].append(np.mean(LV_dice))
+    tb_dict["RV_dice"].append(np.mean(RV_dice))
+    tb_dict["MYO_dice"].append(np.mean(MYO_dice))
+    return np.mean(LV_dice), np.mean(RV_dice), np.mean(MYO_dice)
+
+def make_one_hot(input, num_classes):
+    """Convert class index tensor to one hot encoding tensor.
+    Args:
+         input: a tensor of class indices with shape [N, 1, *]
+         num_classes: the number of classes
+    Returns:
+        A tensor of shape [N, num_classes, *]
+    """
+    shape = np.array(input.shape)
+    shape[1] = num_classes
+    shape = tuple(shape)
+    result = torch.zeros(shape).scatter_(1, input.cpu().long(), 1)
+
+    return result
+
+def test_it(model, data, device="cuda"):
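+    """Run one forward pass on a batch and return the hard predictions (N, 1, H, W)
+    and the predicted direction field (or None)."""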
+    model.eval()
+    imgs, gts = data[:2]
+    gts_df = data[2]
+
+    imgs = imgs.to(device)
+    gts = gts.to(device)
+
+    # run inference without tracking gradients
+    with torch.no_grad():
+        net_out = model(imgs)
+    if len(net_out) > 1:
+        preds_out = net_out[0]
+        preds_df = net_out[1]
+    else:
+        preds_out = net_out[0]
+        preds_df = None
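+    # convert logits to probabilities and take the per-pixel argmax as the prediction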
+    preds_out = nn.functional.softmax(preds_out, dim=1)
+    _, preds = torch.max(preds_out, 1)
+    preds = preds.unsqueeze(1)  # (N, 1, *)
+
+    return preds, preds_df
+
+def vis_it(pred, gt, img=None, filename=None, infos=None):
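+    """Save the (normalized) input image, the prediction, and the ground truth as separate PNGs."""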
+    h, w = pred.shape
+    # gt_contours = masks_to_contours(gt)
+    # mask = np.hstack([pred, np.zeros((h, 1)), gt])
+    # gt_contours = np.hstack([gt_contours, np.zeros((h, 1)), np.zeros_like(gt)])
+    # im_rgb = mask2png(mask).astype(np.int16)
+    # im_rgb[:, w, :] = [255, 255, 255]
+    # im_rgb = apply_mask(im_rgb, gt_contours, [255, 255, 255], 0.8)
+    pred_im = mask2png(pred).astype(np.int16)
+    gt_im = mask2png(gt).astype(np.int16)
+
+    img = (img - img.min()) / (img.max() - img.min()) * 255
+    img = np.stack([img, img, img], axis=2)
+    # img = img_mask_png(img, gt, alpha=0.1)
+
+    # im_rgb = np.hstack([im_rgb, 255*np.ones((h, 1, 3)), img])
+
+    # cv2.putText(im_rgb, "prediction", (2,h-4),
+    #             fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=0.4, color=(255, 255, 255), thickness=1)
+    # cv2.putText(im_rgb, "ground truth", (w, h-4),
+    #             fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=0.4, color=(255, 255, 255), thickness=1)
+
+    # st_pos = 15
+    # if infos is not None:
+    #     for info in infos:
+    #         cv2.putText(im_rgb, info+": {}".format(infos[info]), (2, st_pos),
+    #                     fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=0.4, color=(255, 255, 255), thickness=1)
+    #         st_pos += 10
+
+    cv2.imwrite(filename+"_img.png", img[:,:,::-1])
+    cv2.imwrite(filename+"_pred.png", pred_im[:,:,::-1])
+    cv2.imwrite(filename+"_gt.png", gt_im[:,:,::-1])
+
+def vis_df(pred_df, gt_df, filename, infos=None):
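+    """Dump the predicted and ground-truth direction fields to .npy and save their
+    magnitude/angle maps as color-mapped PNGs."""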
+    _, h, w = pred_df.shape
+
+    # save .npy files
+    np.save(filename+'.npy', [pred_df, gt_df])
+    
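+    # visualize the direction field: angle via arctan2 and squared magnitude, both scaled to [0, 255]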
+    theta = np.arctan2(gt_df[1,...], gt_df[0,...])
+    degree_gt = (theta - theta.min()) / (theta.max() - theta.min()) * 255
+    # degree_gt = theta * 360
+    mag_gt = np.sum(gt_df ** 2, axis=0, keepdims=False)
+    mag_gt = mag_gt / mag_gt.max() * 255
+
+    theta = np.arctan2(pred_df[1,...], pred_df[0,...])
+    degree_df = (theta - theta.min()) / (theta.max() - theta.min()) * 255
+    # degree_df = theta * 360
+    magnitude = np.sum(pred_df ** 2, axis=0, keepdims=False)
+    magnitude = magnitude / magnitude.max() * 255
+
+    im = np.hstack([magnitude, np.zeros((h, 1)), mag_gt, np.zeros((h, 1)), degree_df, np.zeros((h, 1)), degree_gt]).astype(np.uint8)
+    im = cv2.applyColorMap(im, cv2.COLORMAP_JET)
+    cv2.imwrite(filename+"_df_pred_mag.png", im[:h, :w, ...])
+    cv2.imwrite(filename+"_df_gt_mag.png", im[:h, w+1:2*w+1, ...])
+    cv2.imwrite(filename+"_df_pred_degree.png", im[:h, 2*w+2:3*w+2, ...])
+    cv2.imwrite(filename+"_df_gt_degree.png", im[:h, 3*w+3:, ...])
+
+
+def test():
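+    """Load the requested checkpoint, evaluate it on the test set, log per-class Dice,
+    and optionally save visualizations when --vis is given."""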
+    root_result_dir = args.output_dir
+    os.makedirs(root_result_dir, exist_ok=True)
+
+    log_file = os.path.join(root_result_dir, args.log_file)
+    logger = create_logger(log_file)
+
+    for key, val in vars(args).items():
+        logger.info("{:16} {}".format(key, val))
+
+    # create dataset & dataloader & network
+    if args.used_df == 'U_NetDF':
+        model = U_NetDF(selfeat=args.selfeat, num_class=4, auxseg=True)
+    elif args.used_df == 'U_Net':
+        model = U_Net(num_class=4)
+    else:
+        raise ValueError("--used_df must be 'U_NetDF' or 'U_Net', got {}".format(args.used_df))
+
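+    # more than one GPU id (e.g. "0,1") was given -> run the model with DataParallel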
+    if args.mgpus is not None and len(args.mgpus) > 2:
+        model = nn.DataParallel(model)
+    model.cuda()
+
+    test_loader, test_set = create_dataloader()
+    
+    checkpoint = torch.load(args.model_path1, map_location='cpu')
+    model.load_state_dict(checkpoint['model_state'])
+
+    dice_dict = {"LV_dice": [],
+                 "RV_dice": [],
+                 "MYO_dice": []}
+    for i, data in enumerate(test_loader):
+
+        pred, pred_df = test_it(model, data[:3])
+
+        _, gt, gt_df = data[:3]
+        gt = gt.to("cuda")
+
+        L, R, MYO = cal_perfer(make_one_hot(pred, 4), make_one_hot(gt, 4), dice_dict)
+
+        data_info = test_set.data_infos[i]
+        if args.vis:
+        # if 0.7 <= (L + R + MYO) / 3 < 0.8:
+            vis_it(pred.cpu().numpy()[0, 0], gt.cpu().numpy()[0, 0], data[0].cpu().numpy()[0, 0],
+                    filename=os.path.join(root_result_dir, str(i)))
+            if pred_df is not None:
+                vis_df(pred_df.detach().cpu().numpy()[0], gt_df.cpu().numpy()[0], 
+                        filename=os.path.join(root_result_dir, str(i)))
+
+        print("\r{}/{} {:.0%}   {}".format(i, len(test_set), i/len(test_set), 
+                                np.mean(list(dice_dict.values()))), end="")
+    print()
+
+    logger.info("2D Dice Metirc:")
+    logger.info("Total {}".format(len(test_set)))
+    logger.info("LV_dice: {}".format(np.mean(dice_dict["LV_dice"])))
+    logger.info("RV_dice: {}".format(np.mean(dice_dict["RV_dice"])))
+    logger.info("MYO_dice: {}".format(np.mean(dice_dict["MYO_dice"])))
+    logger.info("Mean_dice: {}".format(np.mean(list(dice_dict.values()))))
+
+
+if __name__ == "__main__":
+    test()
\ No newline at end of file