# cardiac_motion/test.py
import argparse
import logging
import os

import numpy as np
import pandas as pd
import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

from model.networks import BaseNet
from model.submodules import resample_transform
from model.dataset_utils import CenterCrop, Normalise, ToTensor
from model.datasets import CardiacMR_2D_Eval_UKBB
from utils.metrics import categorical_dice_stack, contour_distances_stack, detJac_stack, bending_energy_stack
from utils import xutils
from utils.visualise import visualise_result

STRUCTURES = ["lv", "myo", "rv"]
SEG_METRICS = ["dice", "mcd", "hd"]
DVF_METRICS = ["mean_mag_grad_detJ", "negative_detJ", "bending_energy"]
METRICS = [f"{metric}_{struct}" for metric in SEG_METRICS for struct in STRUCTURES] + DVF_METRICS
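# For reference (derived from the definitions above, not new behavior), METRICS expands to:
#   ["dice_lv", "dice_myo", "dice_rv",
#    "mcd_lv", "mcd_myo", "mcd_rv",
#    "hd_lv", "hd_myo", "hd_rv",
#    "mean_mag_grad_detJ", "negative_detJ", "bending_energy"]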


def test(
    model,
    dataloader,
    model_dir,
    pixel_size=1.0,
    all_slices=False,
    run_inference=True,
    run_eval=True,
    save_output=False,
    save_metric_results=False,
    log_visual_tb=False,
    summary_writer=None,
    device=torch.device("cpu"),
):
"""Run model inference on test dataset"""
model.eval()
# initialise metric result dictionary
metric_results_lists = {metric: [] for metric in METRICS}
# set up saved output dir
test_output_dir = f"{model_dir}/test_output"
if save_output:
logging.info(f"Inference output will be saved at: {test_output_dir}")
if not os.path.exists(test_output_dir):
os.makedirs(test_output_dir)
    with tqdm(total=len(dataloader)) as t:
        for idx, (
            image_ed_batch,
            image_es_batch,
            label_ed_batch,
            label_es_batch,
        ) in enumerate(dataloader):
            # (c, N, H, W) to (N, c, H, W)
            image_ed_batch = image_ed_batch.permute(1, 0, 2, 3).to(device=device)
            image_es_batch = image_es_batch.permute(1, 0, 2, 3).to(device=device)
            label_es_batch = label_es_batch.permute(1, 0, 2, 3).to(device=device)

            if run_inference:
                # run model inference
                with torch.no_grad():
                    dvf_pred = model(image_ed_batch, image_es_batch)
                if save_output:
                    test_output_dir_subj = f"{test_output_dir}/{dataloader.dataset.dir_list[idx]}"
                    if not os.path.exists(test_output_dir_subj):
                        os.makedirs(test_output_dir_subj)
                    dvf_save = dvf_pred.detach().cpu().numpy()
                    np.save(f"{test_output_dir_subj}/dvf.npy", dvf_save)
            else:
                # load saved output from disk
                assert os.path.exists(
                    test_output_dir
                ), f"Test output dir {test_output_dir} doesn't exist, have you run inference?"
                test_output_dir_subj = f"{test_output_dir}/{dataloader.dataset.dir_list[idx]}"
                dvf_loaded = np.load(f"{test_output_dir_subj}/dvf.npy")
                dvf_pred = torch.from_numpy(dvf_loaded).to(device)

            # transform label mask of ES frame
            warped_label_es_batch = resample_transform(label_es_batch.float(), dvf_pred, interp="nearest")

            # move data to cpu and numpy
            warped_label_es_batch = warped_label_es_batch.squeeze(1).cpu().numpy().transpose(1, 2, 0)  # (H, W, N)
            label_ed_batch = label_ed_batch.squeeze(0).numpy().transpose(1, 2, 0)  # (H, W, N)
            dvf = dvf_pred.data.cpu().numpy().transpose(0, 2, 3, 1)  # (N, H, W, 2)
            # visualise in TensorBoard (requires a summary_writer when log_visual_tb is True)
            if log_visual_tb:
                warped_source = resample_transform(image_es_batch, dvf_pred)
                vis_data_dict = {
                    "target": image_ed_batch.cpu().numpy(),
                    "source": image_es_batch.cpu().numpy(),
                    "target_original": image_es_batch.cpu().numpy(),
                    "target_pred": warped_source.cpu().numpy(),
                    "warped_source": warped_source.cpu().numpy(),
                    "disp_pred": dvf.transpose(0, 3, 1, 2) * image_ed_batch.shape[-1] / 2,
                }
                fig = visualise_result(vis_data_dict)
                summary_writer.add_figure("val_test_fig", fig, close=True)
            if run_eval:
                if not all_slices:
                    # extract 3 slices (apical, mid-ventricle and basal)
                    num_slices = label_ed_batch.shape[-1]
                    apical_idx = int(round((num_slices - 1) * 0.75))  # 75% from basal
                    mid_ven_idx = int(round((num_slices - 1) * 0.5))  # 50% from basal
                    basal_idx = int(round((num_slices - 1) * 0.25))  # 25% from basal
                    slices_idx = [apical_idx, mid_ven_idx, basal_idx]
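                    # Worked example (illustrative only): for a 12-slice stack,
                    # slices_idx == [round(11 * 0.75), round(11 * 0.5), round(11 * 0.25)]
                    #            == [8, 6, 3]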
                    warped_label_es_batch = warped_label_es_batch[:, :, slices_idx]
                    label_ed_batch = label_ed_batch[:, :, slices_idx]
                    dvf = dvf[slices_idx, :, :, :]  # needed for detJac

                # accumulate metric results
                metrics_result_per_batch = evaluate_per_batch(
                    warped_label_es_batch, label_ed_batch, dvf, pixel_size=pixel_size
                )
                for metric in metric_results_lists.keys():
                    metric_results_lists[metric].append(metrics_result_per_batch[metric])
            t.update()
    if run_eval:
        logging.info("Metrics evaluated...")

        # reduce metrics results to mean and std
        metric_results_mean_std = {}
        for metric, result_list in metric_results_lists.items():
            metric_results_mean_std[f"{metric}_mean"] = np.mean(result_list)
            metric_results_mean_std[f"{metric}_std"] = np.std(result_list)

        # average each segmentation metric over structures, e.g. dice_mean is the
        # mean of dice_lv_mean, dice_myo_mean and dice_rv_mean
        # (np.nanmean skips the entries masked out with NaN)
        for metric in SEG_METRICS:
            metric_results_mean_std[f"{metric}_mean"] = np.nanmean(
                [
                    metric_results_mean_std[k] if metric in k and "_mean" in k else np.nan
                    for k in metric_results_mean_std.keys()
                ]
            )
        if save_metric_results:
            # save all metrics evaluated for all test subjects in pandas dataframes
            test_result_dir = os.path.join(model_dir, "test_results")
            if not os.path.exists(test_result_dir):
                os.makedirs(test_result_dir)
            logging.info(f"Saving metric results at: {test_result_dir}")

            # save metrics results mean & std
            xutils.save_dict_to_json(
                metric_results_mean_std,
                f"{test_result_dir}/test_results_3slices_{not all_slices}.json",
            )

            # save accuracy metrics of every subject
            subj_id_buffer = dataloader.dataset.dir_list
            df_buffer = []
            column_method = ["DL"] * len(subj_id_buffer)
            for struct in STRUCTURES:
                ls_struct = [struct] * len(subj_id_buffer)
                seg_metric_data = {
                    "Method": column_method,
                    "ID": subj_id_buffer,
                    "Structure": ls_struct,
                }
                for metric in SEG_METRICS:
                    seg_metric_data[metric] = metric_results_lists[f"{metric}_{struct}"]
                df_buffer += [pd.DataFrame(data=seg_metric_data)]
            # concatenate df and save
            metrics_df = pd.concat(df_buffer, axis=0)
            metrics_df.to_pickle(f"{test_result_dir}/test_accuracy_results_3slices_{not all_slices}.pkl")

            # save detJac metrics for every subject
            jac_metric_data = {
                "Method": column_method,
                "ID": subj_id_buffer,
            }
            for metric in DVF_METRICS:
                jac_metric_data[metric] = metric_results_lists[metric]
            jac_df = pd.DataFrame(data=jac_metric_data)
            jac_df.to_pickle(f"{test_result_dir}/test_Jacobian_results_3slices_{not all_slices}.pkl")

        return metric_results_mean_std


def evaluate_per_batch(warped_label_es_batch, label_ed_batch, dvf, pixel_size=1.0):
    """Evaluate segmentation and DVF metrics for one subject's stack of slices."""
    metric_results = {metric: 0.0 for metric in METRICS}

    # Dice per structure (label classes follow STRUCTURES order, 1-indexed: lv=1, myo=2, rv=3)
    for cls, struct in enumerate(STRUCTURES):
        metric_results[f"dice_{struct}"] = categorical_dice_stack(
            warped_label_es_batch, label_ed_batch, label_class=cls + 1
        )

    # contour distances (MCD and HD) per structure
    for cls, struct in enumerate(STRUCTURES):
        (
            metric_results[f"mcd_{struct}"],
            metric_results[f"hd_{struct}"],
        ) = contour_distances_stack(
            warped_label_es_batch,
            label_ed_batch,
            label_class=cls + 1,
            dx=pixel_size,
        )

    # dvf regularity and smoothness metrics
    (
        metric_results["mean_mag_grad_detJ"],
        metric_results["negative_detJ"],
    ) = detJac_stack(dvf)
    metric_results["bending_energy"] = bending_energy_stack(dvf, rescaleFlow=True)
    return metric_results
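
# Illustrative call, shapes only (random placeholder arrays and pixel size, not real
# data; shapes follow the comments in test() above):
#
#   warped = np.random.randint(0, 4, size=(128, 128, 3))  # (H, W, N) warped ES labels
#   target = np.random.randint(0, 4, size=(128, 128, 3))  # (H, W, N) ED labels
#   flow = np.random.randn(3, 128, 128, 2)                # (N, H, W, 2) DVF
#   results = evaluate_per_batch(warped, target, flow, pixel_size=1.8)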


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_dir",
        default=None,
        help="Main directory for the model (with params.json)",
    )
    parser.add_argument(
        "--restore_file",
        default="best.pth.tar",
        help="Name of the file in --model_dir storing the model to load before training",
    )
    parser.add_argument(
        "--all_slices",
        action="store_true",
        help="Evaluate metrics on all slices instead of 3 (75%%/50%%/25%%) trans-axial slices.",
    )
    parser.add_argument("--no_inference", action="store_true")
    parser.add_argument("--no_eval", action="store_true")
    parser.add_argument("--no_save_output", action="store_true")
    parser.add_argument("--no_save_metrics", action="store_true")
    parser.add_argument("--no_cuda", action="store_true")
    parser.add_argument("--gpu", default=0, help="Choose GPU")
    parser.add_argument(
        "--num_workers",
        default=8,
        type=int,
        help="Number of dataloader workers, 0 for main process only",
    )
    args = parser.parse_args()
    # set device
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    if args.cuda:
        args.device = torch.device("cuda")
    else:
        args.device = torch.device("cpu")

    # check whether the trained model exists (before writing the log into it)
    assert os.path.exists(args.model_dir), f"No model dir found at {args.model_dir}"

    # set up logger
    xutils.set_logger(os.path.join(args.model_dir, "eval.log"))
    logging.info(f"Running evaluation of model: \n\t{args.model_dir}")

    # load setting parameters from a JSON file
    json_path = os.path.join(args.model_dir, "params.json")
    assert os.path.isfile(json_path), f"No json configuration file found at {json_path}"
    params = xutils.Params(json_path)

    # set up dataset and DataLoader
    logging.info(f"Eval data path: \n\t{params.eval_data_path}")
    eval_dataset = CardiacMR_2D_Eval_UKBB(
        params.eval_data_path,
        seq=params.seq,
        label_prefix=params.label_prefix,
        transform=transforms.Compose([CenterCrop(params.crop_size), Normalise(), ToTensor()]),
        label_transform=transforms.Compose([CenterCrop(params.crop_size), ToTensor()]),
    )
    eval_dataloader = DataLoader(
        eval_dataset,
        batch_size=params.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=args.cuda,
    )
    # set up the model
    model = BaseNet()
    model = model.to(device=args.device)

    # load network parameters from saved checkpoint
    if not args.no_inference:
        logging.info(f"Loading model from saved file: \n\t{os.path.join(args.model_dir, args.restore_file)}")
        xutils.load_checkpoint(os.path.join(args.model_dir, args.restore_file), model)

    logging.info("Start running testing...")
    if args.no_inference:
        logging.info("Loading outputs from disk instead of running inference...")
    else:
        logging.info("Running model inference...")

    test(
        model,
        eval_dataloader,
        args.model_dir,
        pixel_size=params.pixel_size,
        all_slices=args.all_slices,
        run_inference=(not args.no_inference),
        run_eval=(not args.no_eval),
        save_output=(not args.no_save_output),
        save_metric_results=(not args.no_save_metrics),
        device=args.device,
    )
    logging.info("Testing complete.")