Diff of /tools/test_df_vis.py [000000] .. [98e649]

Switch to unified view

a b/tools/test_df_vis.py
1
import torch
2
import torch.nn as nn
3
from torch.utils.data import DataLoader
4
5
import os
6
import argparse
7
import numpy as np
8
import cv2
9
import logging
10
import math
11
12
import _init_paths
13
from libs.network import U_Net, U_NetDF 
14
from libs.datasets import AcdcDataset
15
import libs.datasets.joint_augment as joint_augment
16
import libs.datasets.augment as standard_augment
17
from libs.datasets.collate_batch import BatchCollator
18
19
from libs.configs.config_acdc import cfg
20
from train_utils.train_utils import load_checkpoint
21
from utils.metrics import dice
22
from utils.vis_utils import mask2png, masks_to_contours, apply_mask, img_mask_png
23
24
parser = argparse.ArgumentParser(description="arg parser")
25
parser.add_argument('--used_df', type=str, default=False, help='whether to use df')
26
parser.add_argument('--selfeat', action='store_true', default=False, help='whether to use feature select')
27
parser.add_argument('--mgpus', type=str, default=None, required=True, help='whether to use multiple gpu')
28
parser.add_argument('--model_path1', type=str, default=None, help='whether to train with evaluation')
29
parser.add_argument('--model_path2', type=str, default=None, help='whether to train with evaluation')
30
parser.add_argument('--batch_size', type=int, default=1, help='batch size for evaluation')
31
parser.add_argument('--workers', type=int, default=4, help='number of workers for dataloader')
32
parser.add_argument('--output_dir', type=str, default=None, required=True, help='specify an output directory if needed')
33
parser.add_argument('--log_file', type=str, default="../log_evalation.txt", help="the file to write logging")
34
parser.add_argument('--vis', action='store_true', default=False, help="weather to save test result images")
35
args = parser.parse_args()
36
37
if args.mgpus is not None:
38
    os.environ["CUDA_VISIBLE_DEVICES"] = args.mgpus
39
40
def create_logger(log_file):
41
    log_format = '%(asctime)s  %(levelname)5s  %(message)s'
42
    logging.basicConfig(level=logging.DEBUG, format=log_format, filename=log_file)
43
    console = logging.StreamHandler()
44
    console.setLevel(logging.DEBUG)
45
    console.setFormatter(logging.Formatter(log_format))
46
    logging.getLogger(__name__).addHandler(console)
47
    return logging.getLogger(__name__)
48
49
def create_dataloader():
50
    eval_transform = joint_augment.Compose([
51
                    joint_augment.To_PIL_Image(),
52
                    joint_augment.FixResize(256),
53
                    joint_augment.To_Tensor()])
54
    evalImg_transform = standard_augment.Compose([
55
                        standard_augment.normalize([cfg.DATASET.MEAN], [cfg.DATASET.STD])])
56
57
    if cfg.DATASET.NAME == "acdc":
58
        test_set = AcdcDataset(cfg.DATASET.TEST_LIST, df_used=True, joint_augment=eval_transform,
59
                            augment=evalImg_transform)
60
61
    test_loader = DataLoader(test_set, batch_size=1, pin_memory=True,
62
                             num_workers=args.workers, shuffle=False,
63
                             collate_fn=BatchCollator(size_divisible=32, df_used=True))
64
    return test_loader, test_set
65
66
def cal_perfer(preds, masks, tb_dict):
67
    LV_dice = []  # 1
68
    MYO_dice = []  # 2
69
    RV_dice = []  # 3
70
71
    for i in range(preds.shape[0]):
72
        LV_dice.append(dice(preds[i,1,:,:],masks[i,1,:,:]))
73
        RV_dice.append(dice(preds[i, 3, :, :], masks[i, 3, :, :]))
74
        MYO_dice.append(dice(preds[i, 2, :, :], masks[i, 2, :, :]))
75
        # LV_dice.append(dice(preds[i, 3,:,:],masks[i,1,:,:]))
76
        # RV_dice.append(dice(preds[i, 1, :, :], masks[i, 3, :, :]))
77
        # MYO_dice.append(dice(preds[i, 2, :, :], masks[i, 2, :, :]))
78
    
79
    tb_dict["LV_dice"].append(np.mean(LV_dice))
80
    tb_dict["RV_dice"].append(np.mean(RV_dice))
81
    tb_dict["MYO_dice"].append(np.mean(MYO_dice))
82
    return np.mean(LV_dice), np.mean(RV_dice), np.mean(MYO_dice)
83
84
def make_one_hot(input, num_classes):
85
    """Convert class index tensor to one hot encoding tensor.
86
    Args:
87
         input: A tensor of shape [N, 1, *]
88
         num_classes: An int of number of class
89
    Returns:
90
        A tensor of shape [N, num_classes, *]
91
    """
92
    shape = np.array(input.shape)
93
    shape[1] = num_classes
94
    shape = tuple(shape)
95
    result = torch.zeros(shape).scatter_(1, input.cpu().long(), 1)
96
    # result = result.scatter_(1, input.cpu(), 1)
97
98
    return result
99
100
def test_it(model, data, device="cuda"):
101
    model.eval()
102
    imgs, gts = data[:2]
103
    gts_df = data[2]
104
105
    imgs = imgs.to(device)
106
    gts = gts.to(device)
107
108
    net_out = model(imgs)
109
    if len(net_out) > 1:
110
        preds_out = net_out[0]
111
        preds_df = net_out[1]
112
    else:
113
        preds_out = net_out[0]
114
        preds_df = None
115
    preds_out = nn.functional.softmax(preds_out, dim=1)
116
    _, preds = torch.max(preds_out, 1)
117
    preds = preds.unsqueeze(1)  # (N, 1, *)
118
119
    return preds, preds_df
120
121
def vis_it(pred, gt, img=None, filename=None, infos=None):
122
    h, w = pred.shape
123
    # gt_contours = masks_to_contours(gt)
124
    # mask = np.hstack([pred, np.zeros((h, 1)), gt])
125
    # gt_contours = np.hstack([gt_contours, np.zeros((h, 1)), np.zeros_like(gt)])
126
    # im_rgb = mask2png(mask).astype(np.int16)
127
    # im_rgb[:, w, :] = [255, 255, 255]
128
    # im_rgb = apply_mask(im_rgb, gt_contours, [255, 255, 255], 0.8)
129
    pred_im = mask2png(pred).astype(np.int16)
130
    gt_im = mask2png(gt).astype(np.int16)
131
132
    img = (img - img.min()) / (img.max() - img.min()) * 255
133
    img = np.stack([img, img, img], axis=2)
134
    # img = img_mask_png(img, gt, alpha=0.1)
135
136
    # im_rgb = np.hstack([im_rgb, 255*np.ones((h, 1, 3)), img])
137
138
    # cv2.putText(im_rgb, "prediction", (2,h-4),
139
    #             fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=0.4, color=(255, 255, 255), thickness=1)
140
    # cv2.putText(im_rgb, "ground truth", (w, h-4),
141
    #             fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=0.4, color=(255, 255, 255), thickness=1)
142
143
    # st_pos = 15
144
    # if infos is not None:
145
    #     for info in infos:
146
    #         cv2.putText(im_rgb, info+": {}".format(infos[info]), (2, st_pos),
147
    #                     fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=0.4, color=(255, 255, 255), thickness=1)
148
    #         st_pos += 10
149
150
    cv2.imwrite(filename+"_img.png", img[:,:,::-1])
151
    cv2.imwrite(filename+"_pred.png", pred_im[:,:,::-1])
152
    cv2.imwrite(filename+"_gt.png", gt_im[:,:,::-1])
153
154
def vis_df(pred_df, gt_df, filename, infos=None):
155
    _, h, w = pred_df.shape
156
157
    # save .npy files
158
    np.save(filename+'.npy', [pred_df, gt_df])
159
    
160
    theta = np.arctan2(gt_df[1,...], gt_df[0,...])
161
    degree_gt = (theta - theta.min()) / (theta.max() - theta.min()) * 255
162
    # degree_gt = theta * 360
163
    mag_gt = np.sum(gt_df ** 2, axis=0, keepdims=False)
164
    mag_gt = mag_gt / mag_gt.max() * 255
165
166
    theta = np.arctan2(pred_df[1,...], pred_df[0,...])
167
    degree_df = (theta - theta.min()) / (theta.max() - theta.min()) * 255
168
    # degree_df = theta * 360
169
    magnitude = np.sum(pred_df ** 2, axis=0, keepdims=False)
170
    magnitude = magnitude / magnitude.max() * 255
171
172
    im = np.hstack([magnitude, np.zeros((h, 1)), mag_gt, np.zeros((h, 1)), degree_df, np.zeros((h, 1)), degree_gt]).astype(np.uint8)
173
    im = cv2.applyColorMap(im, cv2.COLORMAP_JET)
174
    cv2.imwrite(filename+"_df_pred_mag.png", im[:h, :w, ...])
175
    cv2.imwrite(filename+"_df_gt_mag.png", im[:h, w+1:2*w+1, ...])
176
    cv2.imwrite(filename+"_df_pred_degree.png", im[:h, 2*w+2:3*w+2, ...])
177
    cv2.imwrite(filename+"_df_gt_degree.png", im[:h, 3*w+3:, ...])
178
179
180
def test():
181
    root_result_dir = args.output_dir
182
    os.makedirs(root_result_dir, exist_ok=True)
183
184
    log_file = os.path.join(root_result_dir, args.log_file)
185
    logger = create_logger(log_file)
186
187
    for key, val in vars(args).items():
188
        logger.info("{:16} {}".format(key, val))
189
190
    # create dataset & dataloader & network
191
    if args.used_df == 'U_NetDF':
192
        model = U_NetDF(selfeat=args.selfeat, num_class=4, auxseg=True)
193
    elif args.used_df == 'U_Net':
194
        model = U_Net(num_class=4)
195
196
    if args.mgpus is not None and len(args.mgpus) > 2:
197
        model = nn.DataParallel(model)
198
    model.cuda()
199
200
    test_loader, test_set = create_dataloader()
201
    
202
    checkpoint = torch.load(args.model_path1, map_location='cpu')
203
    model.load_state_dict(checkpoint['model_state'])
204
205
    dice_dict = {"LV_dice": [],
206
                    "RV_dice": [],
207
                    "MYO_dice": []}
208
    for i, data in enumerate(test_loader):
209
        if i != 23: continue
210
    # i = 5405
211
    # data = test_set[5405]
212
        # data = [data[0][None], data[1][None], data[2][None]]
213
214
        pred, pred_df = test_it(model, data[:3])
215
216
        _, gt, gt_df = data[:3]
217
        gt = gt.to("cuda")
218
219
        L, R, MYO = cal_perfer(make_one_hot(pred, 4), make_one_hot(gt, 4), dice_dict)
220
221
        data_info = test_set.data_infos[i]
222
        if args.vis:
223
        # if 0.7 <= (L + R + MYO) / 3 < 0.8:
224
            vis_it(pred.cpu().numpy()[0, 0], gt.cpu().numpy()[0, 0], data[0].cpu().numpy()[0, 0],
225
                    filename=os.path.join(root_result_dir, str(i)))
226
            if pred_df is not None:
227
                vis_df(pred_df.detach().cpu().numpy()[0], gt_df.cpu().numpy()[0], 
228
                        filename=os.path.join(root_result_dir, str(i)))
229
230
        print("\r{}/{} {:.0%}   {}".format(i, len(test_set), i/len(test_set), 
231
                                np.mean(list(dice_dict.values()))), end="")
232
    print()
233
234
    logger.info("2D Dice Metirc:")
235
    logger.info("Total {}".format(len(test_set)))
236
    logger.info("LV_dice: {}".format(np.mean(dice_dict["LV_dice"])))
237
    logger.info("RV_dice: {}".format(np.mean(dice_dict["RV_dice"])))
238
    logger.info("MYO_dice: {}".format(np.mean(dice_dict["MYO_dice"])))
239
    logger.info("Mean_dice: {}".format(np.mean(list(dice_dict.values()))))
240
241
242
if __name__ == "__main__":
243
    test()