tools/test_df_vis.py

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

import os
import argparse
import numpy as np
import cv2
import logging
import math

import _init_paths
from libs.network import U_Net, U_NetDF
from libs.datasets import AcdcDataset
import libs.datasets.joint_augment as joint_augment
import libs.datasets.augment as standard_augment
from libs.datasets.collate_batch import BatchCollator

from libs.configs.config_acdc import cfg
from train_utils.train_utils import load_checkpoint
from utils.metrics import dice
from utils.vis_utils import mask2png, masks_to_contours, apply_mask, img_mask_png

parser = argparse.ArgumentParser(description="arg parser")
parser.add_argument('--used_df', type=str, default=None, help='which network to evaluate: U_NetDF or U_Net')
parser.add_argument('--selfeat', action='store_true', default=False, help='whether to use feature selection (U_NetDF only)')
parser.add_argument('--mgpus', type=str, default=None, required=True, help='GPU id(s) to use, e.g. "0" or "0,1"')
parser.add_argument('--model_path1', type=str, default=None, help='path to the checkpoint to evaluate')
parser.add_argument('--model_path2', type=str, default=None, help='path to a second checkpoint (not used by this script)')
parser.add_argument('--batch_size', type=int, default=1, help='batch size for evaluation')
parser.add_argument('--workers', type=int, default=4, help='number of workers for the dataloader')
parser.add_argument('--output_dir', type=str, default=None, required=True, help='directory for logs and visualization results')
parser.add_argument('--log_file', type=str, default="../log_evalation.txt", help='file to write logging to')
parser.add_argument('--vis', action='store_true', default=False, help='whether to save test result images')
args = parser.parse_args()

if args.mgpus is not None:
    os.environ["CUDA_VISIBLE_DEVICES"] = args.mgpus

def create_logger(log_file):
    """Set up logging to both the given file and the console."""
    log_format = '%(asctime)s %(levelname)5s %(message)s'
    logging.basicConfig(level=logging.DEBUG, format=log_format, filename=log_file)
    console = logging.StreamHandler()
    console.setLevel(logging.DEBUG)
    console.setFormatter(logging.Formatter(log_format))
    logging.getLogger(__name__).addHandler(console)
    return logging.getLogger(__name__)

def create_dataloader():
    """Build the ACDC test set (with direction fields) and its dataloader."""
    eval_transform = joint_augment.Compose([
        joint_augment.To_PIL_Image(),
        joint_augment.FixResize(256),
        joint_augment.To_Tensor()])
    evalImg_transform = standard_augment.Compose([
        standard_augment.normalize([cfg.DATASET.MEAN], [cfg.DATASET.STD])])

    if cfg.DATASET.NAME == "acdc":
        test_set = AcdcDataset(cfg.DATASET.TEST_LIST, df_used=True, joint_augment=eval_transform,
                               augment=evalImg_transform)

    # Note: the loader is fixed to batch_size=1; the --batch_size flag is not applied here.
    test_loader = DataLoader(test_set, batch_size=1, pin_memory=True,
                             num_workers=args.workers, shuffle=False,
                             collate_fn=BatchCollator(size_divisible=32, df_used=True))
    return test_loader, test_set

def cal_perfer(preds, masks, tb_dict):
    """Compute per-class Dice scores for a one-hot batch and accumulate them in tb_dict."""
    LV_dice = []   # class 1
    MYO_dice = []  # class 2
    RV_dice = []   # class 3

    for i in range(preds.shape[0]):
        LV_dice.append(dice(preds[i, 1, :, :], masks[i, 1, :, :]))
        RV_dice.append(dice(preds[i, 3, :, :], masks[i, 3, :, :]))
        MYO_dice.append(dice(preds[i, 2, :, :], masks[i, 2, :, :]))
        # LV_dice.append(dice(preds[i, 3, :, :], masks[i, 1, :, :]))
        # RV_dice.append(dice(preds[i, 1, :, :], masks[i, 3, :, :]))
        # MYO_dice.append(dice(preds[i, 2, :, :], masks[i, 2, :, :]))

    tb_dict["LV_dice"].append(np.mean(LV_dice))
    tb_dict["RV_dice"].append(np.mean(RV_dice))
    tb_dict["MYO_dice"].append(np.mean(MYO_dice))
    return np.mean(LV_dice), np.mean(RV_dice), np.mean(MYO_dice)

def make_one_hot(input, num_classes):
    """Convert a class index tensor to a one-hot encoded tensor.

    Args:
        input: A tensor of shape [N, 1, *] containing class indices.
        num_classes: Number of classes.
    Returns:
        A tensor of shape [N, num_classes, *].
    """
    shape = np.array(input.shape)
    shape[1] = num_classes
    shape = tuple(shape)
    result = torch.zeros(shape).scatter_(1, input.cpu().long(), 1)
    # result = result.scatter_(1, input.cpu(), 1)

    return result

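# Hypothetical shape check (illustrative sketch, not part of the original script):
#   labels = torch.randint(0, 4, (2, 1, 8, 8))     # (N, 1, H, W) class indices
#   onehot = make_one_hot(labels, num_classes=4)   # (N, 4, H, W); onehot.sum(dim=1) is all ones
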
def test_it(model, data, device="cuda"):
    """Run one forward pass and return the predicted label map and direction field."""
    model.eval()
    imgs, gts = data[:2]
    gts_df = data[2]

    imgs = imgs.to(device)
    gts = gts.to(device)

    net_out = model(imgs)
    if len(net_out) > 1:
        preds_out = net_out[0]
        preds_df = net_out[1]
    else:
        preds_out = net_out[0]
        preds_df = None
    preds_out = nn.functional.softmax(preds_out, dim=1)
    _, preds = torch.max(preds_out, 1)
    preds = preds.unsqueeze(1)  # (N, 1, *)

    return preds, preds_df

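# Shape sketch (assumed, based on the 4-class setup used in this script):
#   imgs (N, 1, H, W) -> logits (N, 4, H, W) -> softmax/argmax -> preds (N, 1, H, W);
#   preds_df is the predicted direction field (likely (N, 2, H, W)) when the model
#   returns one (U_NetDF), otherwise None.
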
def vis_it(pred, gt, img=None, filename=None, infos=None):
    """Save the input image, prediction mask, and ground-truth mask as PNG files."""
    h, w = pred.shape
    # gt_contours = masks_to_contours(gt)
    # mask = np.hstack([pred, np.zeros((h, 1)), gt])
    # gt_contours = np.hstack([gt_contours, np.zeros((h, 1)), np.zeros_like(gt)])
    # im_rgb = mask2png(mask).astype(np.int16)
    # im_rgb[:, w, :] = [255, 255, 255]
    # im_rgb = apply_mask(im_rgb, gt_contours, [255, 255, 255], 0.8)
    pred_im = mask2png(pred).astype(np.int16)
    gt_im = mask2png(gt).astype(np.int16)

    img = (img - img.min()) / (img.max() - img.min()) * 255
    img = np.stack([img, img, img], axis=2)
    # img = img_mask_png(img, gt, alpha=0.1)

    # im_rgb = np.hstack([im_rgb, 255*np.ones((h, 1, 3)), img])

    # cv2.putText(im_rgb, "prediction", (2, h-4),
    #             fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=0.4, color=(255, 255, 255), thickness=1)
    # cv2.putText(im_rgb, "ground truth", (w, h-4),
    #             fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=0.4, color=(255, 255, 255), thickness=1)

    # st_pos = 15
    # if infos is not None:
    #     for info in infos:
    #         cv2.putText(im_rgb, info+": {}".format(infos[info]), (2, st_pos),
    #                     fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=0.4, color=(255, 255, 255), thickness=1)
    #         st_pos += 10

    cv2.imwrite(filename+"_img.png", img[:, :, ::-1])
    cv2.imwrite(filename+"_pred.png", pred_im[:, :, ::-1])
    cv2.imwrite(filename+"_gt.png", gt_im[:, :, ::-1])

def vis_df(pred_df, gt_df, filename, infos=None):
    """Save the predicted and ground-truth direction fields as magnitude and angle maps."""
    _, h, w = pred_df.shape

    # save .npy files
    np.save(filename+'.npy', [pred_df, gt_df])

    theta = np.arctan2(gt_df[1, ...], gt_df[0, ...])
    degree_gt = (theta - theta.min()) / (theta.max() - theta.min()) * 255
    # degree_gt = theta * 360
    mag_gt = np.sum(gt_df ** 2, axis=0, keepdims=False)
    mag_gt = mag_gt / mag_gt.max() * 255

    theta = np.arctan2(pred_df[1, ...], pred_df[0, ...])
    degree_df = (theta - theta.min()) / (theta.max() - theta.min()) * 255
    # degree_df = theta * 360
    magnitude = np.sum(pred_df ** 2, axis=0, keepdims=False)
    magnitude = magnitude / magnitude.max() * 255

    # Pack the four maps side by side (separated by 1-pixel black columns), apply the
    # JET colormap once, then slice the packed image back apart when writing to disk.
    im = np.hstack([magnitude, np.zeros((h, 1)), mag_gt, np.zeros((h, 1)), degree_df,
                    np.zeros((h, 1)), degree_gt]).astype(np.uint8)
    im = cv2.applyColorMap(im, cv2.COLORMAP_JET)
    cv2.imwrite(filename+"_df_pred_mag.png", im[:h, :w, ...])
    cv2.imwrite(filename+"_df_gt_mag.png", im[:h, w+1:2*w+1, ...])
    cv2.imwrite(filename+"_df_pred_degree.png", im[:h, 2*w+2:3*w+2, ...])
    cv2.imwrite(filename+"_df_gt_degree.png", im[:h, 3*w+3:, ...])

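# Output files per case index <i> (as written by vis_it/vis_df above):
#   <i>_img.png, <i>_pred.png, <i>_gt.png, <i>.npy,
#   <i>_df_pred_mag.png, <i>_df_gt_mag.png, <i>_df_pred_degree.png, <i>_df_gt_degree.png
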
def test():
    root_result_dir = args.output_dir
    os.makedirs(root_result_dir, exist_ok=True)

    log_file = os.path.join(root_result_dir, args.log_file)
    logger = create_logger(log_file)

    for key, val in vars(args).items():
        logger.info("{:16} {}".format(key, val))

    # create dataset & dataloader & network
    if args.used_df == 'U_NetDF':
        model = U_NetDF(selfeat=args.selfeat, num_class=4, auxseg=True)
    elif args.used_df == 'U_Net':
        model = U_Net(num_class=4)
    else:
        raise ValueError("Unknown --used_df option: {}".format(args.used_df))

    if args.mgpus is not None and len(args.mgpus) > 2:
        model = nn.DataParallel(model)
    model.cuda()

    test_loader, test_set = create_dataloader()

    checkpoint = torch.load(args.model_path1, map_location='cpu')
    model.load_state_dict(checkpoint['model_state'])

    dice_dict = {"LV_dice": [],
                 "RV_dice": [],
                 "MYO_dice": []}
    for i, data in enumerate(test_loader):
        if i != 23: continue  # debug filter kept from the original: only case 23 is evaluated
        # i = 5405
        # data = test_set[5405]
        # data = [data[0][None], data[1][None], data[2][None]]

        pred, pred_df = test_it(model, data[:3])

        _, gt, gt_df = data[:3]
        gt = gt.to("cuda")

        L, R, MYO = cal_perfer(make_one_hot(pred, 4), make_one_hot(gt, 4), dice_dict)

        data_info = test_set.data_infos[i]
        if args.vis:
            # if 0.7 <= (L + R + MYO) / 3 < 0.8:
            vis_it(pred.cpu().numpy()[0, 0], gt.cpu().numpy()[0, 0], data[0].cpu().numpy()[0, 0],
                   filename=os.path.join(root_result_dir, str(i)))
            if pred_df is not None:
                vis_df(pred_df.detach().cpu().numpy()[0], gt_df.cpu().numpy()[0],
                       filename=os.path.join(root_result_dir, str(i)))

        print("\r{}/{} {:.0%} {}".format(i, len(test_set), i/len(test_set),
                                         np.mean(list(dice_dict.values()))), end="")
    print()

    logger.info("2D Dice Metric:")
    logger.info("Total {}".format(len(test_set)))
    logger.info("LV_dice: {}".format(np.mean(dice_dict["LV_dice"])))
    logger.info("RV_dice: {}".format(np.mean(dice_dict["RV_dice"])))
    logger.info("MYO_dice: {}".format(np.mean(dice_dict["MYO_dice"])))
    logger.info("Mean_dice: {}".format(np.mean(list(dice_dict.values()))))

if __name__ == "__main__":
    test()
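
# Example invocation (illustrative sketch; checkpoint and output paths are placeholders):
#   python test_df_vis.py --used_df U_NetDF --selfeat --mgpus 0 \
#       --model_path1 /path/to/checkpoint.pth --output_dir ../output/df_vis --vis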