[405042]: / cardiac_motion / utils / visualise.py

Download this file

171 lines (139 with data), 6.3 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
import numpy as np
import torch
import os
import random
from matplotlib import pyplot as plt
def plot_warped_grid(ax, disp, bg_img=None, interval=3, title=r"$\mathcal{T}_\phi$", fontsize=30, color="c"):
    """Draw a regular grid deformed by a 2D displacement field onto ``ax``.

    Args:
        ax: axes to draw on (presumably a matplotlib Axes; this function calls
            plot, imshow, set_title, grid, set_xticks, set_yticks, set_frame_on)
        disp: displacement field shaped (2, H, W); channel 0 is added to the row (H)
            coordinates of the grid nodes, channel 1 to the column (W) coordinates
        bg_img: image shown underneath the grid; an all-zero image of shape
            ``disp.shape[1:]`` is used when None
        interval: (int) spacing in pixels between neighbouring grid lines
        title: (string) axes title (LaTeX math allowed)
        fontsize: (int) title font size
        color: line colour of the grid
    """
    background = bg_img if bg_img is not None else np.zeros(disp.shape[1:])

    # grid nodes every `interval` pixels; the range stop is size - 1, so the last
    # pixel row/column is never used as a node
    nodes_h, nodes_w = np.meshgrid(
        range(0, background.shape[0] - 1, interval),
        range(0, background.shape[1] - 1, interval),
        indexing="ij",
    )
    warped_h = nodes_h + disp[0, nodes_h, nodes_w]
    warped_w = nodes_w + disp[1, nodes_h, nodes_w]

    line_style = {"linewidth": 1.0, "color": color}
    # matplotlib's plot() takes (x, y) = (column, row)
    for xs, ys in zip(warped_w, warped_h):  # one warped horizontal line per grid row
        ax.plot(xs, ys, **line_style)
    for xs, ys in zip(warped_w.T, warped_h.T):  # one warped vertical line per grid column
        ax.plot(xs, ys, **line_style)

    ax.set_title(title, fontsize=fontsize)
    ax.imshow(background, cmap="gray")
    ax.grid(False)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_frame_on(False)
def plot_result_fig(vis_data_dict, save_path=None, title_font_size=20, dpi=100, show=False, close=False):
    """Plot registration results as a 2x4 grid of subplots in a single figure.

    Images should be shaped (*sizes); displacement fields (ndim, *sizes).
    Required keys of ``vis_data_dict``:
        'target', 'source', 'target_pred', 'warped_source', 'disp_pred'
    Optional keys (defaults are filled in on a copy; the caller's dict is not modified):
        'source_ref' -- defaults to 'source'
        'disp_gt'    -- defaults to an all-zero field shaped like 'disp_pred'

    Args:
        vis_data_dict: (dict) arrays to plot, see above
        save_path: (string) if not None, the figure is saved to this path
        title_font_size: (int) font size of the subplot titles
        dpi: (int) resolution of the saved figure
        show: (bool) call ``plt.show()`` after plotting
        close: (bool) close the figure after saving/showing

    Returns:
        The created matplotlib figure.
    """
    # shallow copy: filling in defaults must not leak back into the caller's dict
    vis_data_dict = dict(vis_data_dict)
    # "source_ref" is the source image for monomodal registration, or the image in the
    # target modality that is originally aligned with the source image if available
    if "source_ref" not in vis_data_dict:
        vis_data_dict["source_ref"] = vis_data_dict["source"]
    if "disp_gt" not in vis_data_dict:
        vis_data_dict["disp_gt"] = np.zeros_like(vis_data_dict["disp_pred"])

    fig = plt.figure(figsize=(30, 18))
    title_pad = 10

    # intensity error before and after registration
    error_before = vis_data_dict["target"] - vis_data_dict["source_ref"]
    error_after = vis_data_dict["target"] - vis_data_dict["target_pred"]

    grey_style = {"cmap": "gray"}
    # fixed symmetric range so zero error sits at the centre of the diverging colormap.
    # NOTE(review): the original comment assumed images normalised to [0, 1], whose
    # differences lie in [-1, 1]; confirm whether +/-2 is the intended range.
    error_style = {"vmin": -2, "vmax": 2, "cmap": "seismic"}

    # (subplot index, image, imshow kwargs, title)
    image_panels = [
        (1, vis_data_dict["target"], grey_style, "Target"),
        (2, vis_data_dict["source_ref"], grey_style, "Source reference"),
        (3, error_before, error_style, "Error before"),
        (4, error_after, error_style, "Error after"),
        (5, vis_data_dict["target_pred"], grey_style, "Target predict"),
        (6, vis_data_dict["warped_source"], grey_style, "Warped source"),
    ]
    for index, image, imshow_kwargs, panel_title in image_panels:
        ax = fig.add_subplot(2, 4, index)
        ax.imshow(image, **imshow_kwargs)
        ax.axis("off")
        ax.set_title(panel_title, fontsize=title_font_size, pad=title_pad)

    # warped grids drawn over the source image: ground truth, then prediction
    bg_img = vis_data_dict["source"]
    grid_panels = [
        (7, vis_data_dict["disp_gt"], r"$\phi_{GT}$"),
        (8, vis_data_dict["disp_pred"], r"$\phi_{pred}$"),
    ]
    for index, disp, panel_title in grid_panels:
        ax = fig.add_subplot(2, 4, index)
        plot_warped_grid(ax, disp, bg_img, interval=3, title=panel_title, fontsize=title_font_size)

    # adjust subplot placements and spacing
    fig.subplots_adjust(left=0.0001, right=0.99, top=0.9, bottom=0.1, wspace=0.001, hspace=0.1)

    if save_path is not None:
        fig.savefig(save_path, bbox_inches="tight", dpi=dpi)
    if show:
        plt.show()
    if close:
        # close this figure explicitly rather than whichever figure happens to be current
        plt.close(fig)
    return fig
def visualise_result(data_dict, axis=0, save_result_dir=None, epoch=None, dpi=50):
    """Plot (and optionally save) one validation visualisation figure.

    - 2D data: one randomly chosen sample of the N-slice stack (not a sequence)
    - 3D data: the middle slice along ``axis`` of the first sample in the batch

    Args:
        data_dict: (dict) NumPy arrays or torch tensors; images shaped (N, 1, *sizes),
            displacements ("disp_pred", "disp_gt") shaped (N, ndim, *sizes). Must
            contain "target". The dict itself is not modified.
        axis: (int) 3D only: the plotted 2D plane is orthogonal to this spatial axis
        save_result_dir: (string) directory to save the figure in; nothing is saved if None
        epoch: (int) epoch number, used only in the saved file name
        dpi: (int) resolution of the saved figure

    Returns:
        The figure returned by ``plot_result_fig``.
    """
    # convert into a new dict so the caller's tensors are left untouched;
    # detach() lets tensors that require grad be converted instead of raising
    arrays = {
        name: d.detach().cpu().numpy() if isinstance(d, torch.Tensor) else d
        for name, d in data_dict.items()
    }
    target = arrays["target"]
    ndim = target.ndim - 2

    vis_data_dict = {}
    if ndim == 2:
        # randomly choose one sample from the stack
        z = random.randint(0, target.shape[0] - 1)
        for name, d in arrays.items():
            sample = d[z]  # (C, H, W)
            # drop a singleton image channel but keep multi-channel displacements.
            # Indexing the batch axis explicitly (instead of squeeze()) stays correct when N == 1.
            vis_data_dict[name] = sample[0] if sample.shape[0] == 1 else sample
    else:  # 3D
        sizes = target.shape[2:]
        # visualise the middle slice of the chosen axis
        z = int(sizes[axis] // 2)
        # keep the displacement channels for the two axes other than `axis`
        # (presumably channel k is the displacement along spatial axis k)
        in_plane = [a for a in range(3) if a != axis]
        for name, d in arrays.items():
            if name in ("disp_pred", "disp_gt"):
                vis_data_dict[name] = d[0, in_plane, ...].take(z, axis=axis + 1)  # (2, X, X)
            else:
                vis_data_dict[name] = d[0, 0, ...].take(z, axis=axis)  # (X, X)

    fig_save_path = None
    if save_result_dir is not None:
        fig_save_path = os.path.join(save_result_dir, f"epoch{epoch}_axis_{axis}_slice_{z}.png")
    return plot_result_fig(vis_data_dict, save_path=fig_save_path, dpi=dpi)