[405042]: / cardiac_motion / utils / flow_utils.py

Download this file

312 lines (241 with data), 11.3 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
import torch
import torch.nn.functional as F
import cv2
import numpy as np
import matplotlib.pyplot as plt
import os
import nibabel as nib
import imageio
from utils.imageio_utils import save_gif, save_png, save_nifti
from model.submodules import resample_transform_cpu
def flow_line_integral(op_flow):
    """
    Perform approximated line integral of frame-to-frame optical flow
    using Pytorch, on the device (CPU or GPU) that ``op_flow`` lives on.

    Frame 0's flow displaces an identity sampling grid. Each later frame's
    flow is bilinearly sampled at the current (already displaced) grid
    positions and added on, so output frame k holds the displacement
    accumulated over frames 0..k.

    Args:
        op_flow: optical flow, Tensor, shape (N, 2, H, W), in the same
            [-1, 1] normalised units as the identity grid; channel 0 is added
            to the H (row) coordinate and channel 1 to the W (column) coordinate
    Returns:
        accum_flow: line integrated optical flow, same shape as input
    """
    n_frames = op_flow.size(0)
    if n_frames == 0:
        # nothing to integrate (torch.stack on an empty list would raise)
        return torch.zeros_like(op_flow)

    # identity grid in normalised [-1, 1] coordinates, channel order (h, w)
    h, w = op_flow.size()[-2:]
    grid_h = torch.linspace(-1, 1, h, dtype=op_flow.dtype, device=op_flow.device).view(h, 1).expand(h, w)
    grid_w = torch.linspace(-1, 1, w, dtype=op_flow.dtype, device=op_flow.device).view(1, w).expand(h, w)
    std_grid = torch.stack([grid_h, grid_w])  # (2, H, W)

    new_grid = std_grid + op_flow[0]  # (2, H, W)
    accum_flow = [new_grid - std_grid]
    for fr in range(1, n_frames):
        # grid_sample reads the last grid dim as (x, y) = (w, h), so flip the
        # (h, w) channel order before moving channels last -> (1, H, W, 2)
        sample_grid = new_grid.flip(0).permute(1, 2, 0).unsqueeze(0)
        # NOTE(review): the identity grid uses linspace(-1, 1) (align_corners=True
        # semantics) while grid_sample runs with its version-dependent default
        # align_corners (False since PyTorch 1.3) -- confirm which is intended.
        sampled = F.grid_sample(op_flow[fr].unsqueeze(0), sample_grid).squeeze(0)  # (2, H, W)
        # out-of-place update: new_grid was an input to grid_sample, so an
        # in-place change would break autograd's saved tensors
        new_grid = new_grid + sampled
        accum_flow.append(new_grid - std_grid)
    return torch.stack(accum_flow)  # (N, 2, H, W)
def flow_to_hsv(opt_flow, max_mag=0.1, white_bg=False):
    """
    Encode optical flow to HSV and convert it to RGB for visualisation.

    Hue encodes flow direction, saturation is fixed at 255, and value
    encodes flow magnitude linearly, reaching full brightness at ``max_mag``.

    Args:
        opt_flow: 2D optical flow in (dx, dy) encoding, shape (H, W, 2)
        max_mag: flow magnitude mapped to full brightness (255); larger
            magnitudes are clipped to 255
        white_bg: if True, invert the colours so zero flow shows as white
    Returns:
        hsv_flow_rgb: HSV encoded flow converted to RGB, uint8, shape (H, W, 3)
    """
    # convert to polar coordinates (angle in radians)
    mag, ang = cv2.cartToPolar(opt_flow[..., 0], opt_flow[..., 1])

    # hsv encoding (8-bit OpenCV hue spans [0, 180), hence the division by 2)
    hsv_flow = np.zeros((opt_flow.shape[0], opt_flow.shape[1], 3))
    hsv_flow[..., 0] = ang * 180 / np.pi / 2  # hue = angle
    hsv_flow[..., 1] = 255.0  # saturation = 255
    # clip before the uint8 cast: magnitudes above max_mag would otherwise
    # wrap around (e.g. 300 -> 44) and show as dark pixels
    hsv_flow[..., 2] = np.clip(255.0 * mag / max_mag, 0, 255)

    # convert hsv encoding to rgb for visualisation
    # ([..., ::-1] converts from BGR to RGB)
    hsv_flow_rgb = cv2.cvtColor(hsv_flow.astype(np.uint8), cv2.COLOR_HSV2BGR)[..., ::-1]
    # astype copies, turning the negative-stride view into a contiguous array
    hsv_flow_rgb = hsv_flow_rgb.astype(np.uint8)
    if white_bg:
        hsv_flow_rgb = 255 - hsv_flow_rgb
    return hsv_flow_rgb
def blend_image_seq(images1, images2, alpha=0.7):
    """
    Alpha-blend two image sequences frame by frame.

    Used in this project to overlay HSV-encoded flow on the images. When one
    sequence has one dimension fewer (no channel axis), it is repeated along
    a new channel axis to match the other.

    Args:
        images1: numpy array, shape (H, W, Ch, Frames) or (H, W, Frames)
        images2: numpy array, shape (H, W, Ch, Frames) or (H, W, Frames)
        alpha: weight of images2; result is (1 - alpha) * images1 + alpha * images2
    Returns:
        blended_images: uint8 numpy array, shape (H, W, Ch, Frames)
    Raises:
        AssertionError: if the (channel-expanded) shapes still differ
    """
    if images2.ndim < images1.ndim:
        channels = images1.shape[2]
        images2 = np.repeat(images2[:, :, None, :], channels, axis=2)
    elif images1.ndim < images2.ndim:
        channels = images2.shape[2]
        images1 = np.repeat(images1[:, :, None, :], channels, axis=2)

    assert images1.shape == images2.shape, "Blending: images being blended have different shapes, {} vs {}".format(
        images1.shape, images2.shape
    )

    weight_first = 1 - alpha
    mixed = images1 * weight_first + images2 * alpha
    return mixed.astype(np.uint8)
def save_flow_hsv(op_flow, background, save_result_dir, fps=20, max_mag=0.1):
    """
    Save HSV-encoded optical flow, alone and overlaid on a background
    sequence, as GIFs plus per-frame PNGs under ``<save_result_dir>/hsv_flow``.

    Args:
        op_flow: numpy array of shape (N, H, W, 2)
        background: numpy array of shape (N, H, W)
        save_result_dir: path of the parent result dir (the "hsv_flow"
            sub-directory is created if missing)
        fps: frames per second, for the GIFs
        max_mag: flow magnitude mapped to full brightness in the HSV encoding
    Returns:
        None
    """
    # one (H, W, 3) HSV-encoded RGB image per frame, stacked frames-last
    frames = [flow_to_hsv(op_flow[idx, :, :, :], max_mag=max_mag) for idx in range(op_flow.shape[0])]
    hsv_seq = np.array(frames).transpose(1, 2, 3, 0)  # (H, W, 3, N)

    # overlay on the (H, W, N) background sequence
    blended_seq = blend_image_seq(background.transpose(1, 2, 0), hsv_seq)

    out_dir = os.path.join(save_result_dir, "hsv_flow")
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    save_gif(hsv_seq, os.path.join(out_dir, "flow.gif"), fps=fps)
    save_gif(blended_seq, os.path.join(out_dir, "flow_blend.gif"), fps=fps)
    save_png(blended_seq, out_dir)
    print("HSV flow saved to: {}".format(out_dir))
def save_flow_quiver(op_flow, background, save_result_dir, scale=1, interval=3, fps=20):
    """
    Plot the flow as quiver arrows over each background frame and save the
    frames as PNGs plus a GIF under ``<save_result_dir>/quiver``.

    Args:
        op_flow: numpy array of shape (N, H, W, 2); channel 1 is drawn as the
            horizontal (x / W) arrow component and channel 0 as the vertical
            (y / H) component
        background: numpy array of shape (N, H, W)
        save_result_dir: path of the parent result dir (the "quiver"
            sub-directory is created if missing)
        scale: arrow scale passed to ``quiver`` (with scale_units="xy")
        interval: spacing in pixels between arrow origins
        fps: frames per second, for the GIF
    Returns:
        None
    """
    # set up saving directory
    save_result_dir = os.path.join(save_result_dir, "quiver")
    if not os.path.exists(save_result_dir):
        os.makedirs(save_result_dir)

    # create mesh grid of vector origins
    # note: numpy uses x-y order in generating mesh grid, i.e. (x, y) = (w, h),
    # so x must range over W (axis 2) and y over H (axis 1)
    height, width = background.shape[1], background.shape[2]
    mesh_x, mesh_y = np.meshgrid(range(0, width - 1, interval), range(0, height - 1, interval))

    png_list = []
    for fr in range(background.shape[0]):
        fig, ax = plt.subplots(figsize=(5, 5))
        ax.imshow(background[fr, :, :], cmap="gray")
        ax.quiver(
            mesh_x,
            mesh_y,
            op_flow[fr, mesh_y, mesh_x, 1],
            op_flow[fr, mesh_y, mesh_x, 0],
            angles="xy",
            scale_units="xy",
            scale=scale,
            color="g",
        )
        ax.axis("off")
        save_path = os.path.join(save_result_dir, "frame_{}.png".format(fr))
        fig.savefig(save_path, bbox_inches="tight")
        plt.close(fig)
        # read it back to make gif
        png_list.append(imageio.imread(save_path))

    # save gif
    imageio.mimwrite(os.path.join(save_result_dir, "quiver.gif"), png_list, fps=fps)
    print("Flow quiver plots saved to: {}".format(save_result_dir))
def save_warp_n_error(warped_source, target, source, save_result_dir, fps=20):
    """
    Save GIFs of the source, target and warped source sequences, and of the
    absolute intensity error before and after warping.

    Args:
        warped_source: source images warped to target images, numpy array shaped (N, H, W)
        target: target images, numpy array shaped (N, H, W)
        source: source images, numpy array shaped (N, H, W)
        save_result_dir: (str) directory to write the GIFs to; created if missing
        fps: frames per second, for the GIFs
    Returns:
        None
    """
    # transpose all to (H, W, N)
    warped_source = warped_source.transpose(1, 2, 0)
    target = target.transpose(1, 2, 0)
    source = source.transpose(1, 2, 0)

    # absolute error |a - b| computed as max(a, b) - min(a, b): identical for
    # float inputs, but for unsigned integer images (e.g. uint8) a plain
    # a - b would wrap around before abs() is applied
    error = np.maximum(warped_source, target) - np.minimum(warped_source, target)
    error_before = np.maximum(source, target) - np.minimum(source, target)

    if not os.path.exists(save_result_dir):
        os.makedirs(save_result_dir)

    save_gif(error, os.path.join(save_result_dir, "error.gif"), fps=fps)
    save_gif(error_before, os.path.join(save_result_dir, "error_before.gif"), fps=fps)
    save_gif(target, os.path.join(save_result_dir, "target.gif"), fps=fps)
    # NOTE: output filename spelling ("wapred") kept so existing consumers still find it
    save_gif(warped_source, os.path.join(save_result_dir, "wapred_source.gif"), fps=fps)
    save_gif(source, os.path.join(save_result_dir, "source.gif"), fps=fps)
    print("Warping and error saved to: {}".format(save_result_dir))
def dof_to_dvf(target_img, dofin, dvfout, output_dir):
    """
    Convert MIRTK format DOF file to dense Deformation Vector Field (DVF)
    by warping meshgrid using MIRTK transform-image
    Note: this function only handles 2D deformation a.t.m.

    The DVF is also saved as ``<output_dir>/<dvfout>.nii.gz``. Only the
    intermediate mesh files created by this call are removed afterwards,
    including when an error occurs.

    Args:
        target_img: (str) full path of target image of the registration
        dofin: (str) full path to input DOF file
        dvfout: (str) output DVF file name under output_dir (without ".nii.gz")
        output_dir: (str) full path of output dir
    Returns:
        dvf: (ndarray of shape HxWx2) dense DVF, numpy array coordinate system
    Raises:
        subprocess.CalledProcessError: if a ``mirtk transform-image`` call fails
    """
    import subprocess

    # read a target image to get the affine transformation
    # this is to ensure the mirtk transform-image apply the DOF under the same transformation
    nim = nib.load(target_img)

    # generate the meshgrids
    # NOTE(review): default 'xy' indexing yields arrays of shape (shape[1], shape[0]),
    # transposed w.r.t. the target; presumably targets are square in-plane, otherwise
    # the subtraction below is unlikely to broadcast. Changing the indexing would also
    # swap the DVF channels, so confirm before touching.
    mesh_x, mesh_y = np.meshgrid(range(0, nim.shape[0], 1), range(0, nim.shape[1], 1))
    mesh_x = mesh_x.astype("float")
    mesh_y = mesh_y.astype("float")

    meshx_path = os.path.join(output_dir, "{}_meshx.nii.gz".format(dvfout))
    meshy_path = os.path.join(output_dir, "{}_meshy.nii.gz".format(dvfout))
    warped_meshx_path = os.path.join(output_dir, "{}_meshx_warped.nii.gz".format(dvfout))
    warped_meshy_path = os.path.join(output_dir, "{}_meshy_warped.nii.gz".format(dvfout))
    temp_paths = [meshx_path, meshy_path, warped_meshx_path, warped_meshy_path]

    try:
        # save both (x and y) into nifti files, using the target image header
        nib.save(nib.Nifti1Image(mesh_x, nim.affine), meshx_path)
        nib.save(nib.Nifti1Image(mesh_y, nim.affine), meshy_path)

        # use mirtk to transform them with the DOF file as input; argument lists
        # (no shell) keep paths with spaces/metacharacters intact, and a failing
        # call raises instead of being silently ignored
        for src_path, dst_path in ((meshx_path, warped_meshx_path), (meshy_path, warped_meshy_path)):
            subprocess.check_call(
                ["mirtk", "transform-image", src_path, dst_path, "-dofin", dofin, "-target", target_img]
            )

        # read in the warped mesh grids; asanyarray(dataobj) loads the data now
        # (the replacement for get_data(), which nibabel 5 removed)
        warp_meshx = np.asanyarray(nib.load(warped_meshx_path).dataobj)[:, :, 0]
        warp_meshy = np.asanyarray(nib.load(warped_meshy_path).dataobj)[:, :, 0]
    finally:
        # clean up only the mesh files created above (a "*mesh*" glob could also
        # delete unrelated files, or the DVF output if dvfout contains "mesh")
        for path in temp_paths:
            if os.path.exists(path):
                os.remove(path)

    # calculate the DVF by subtracting the initial mesh grid from it
    dvf_x = warp_meshx - mesh_x
    dvf_y = warp_meshy - mesh_y
    dvf = np.array([dvf_y, dvf_x]).transpose(1, 2, 0)  # (H, W, 2), notice the x-y swap

    # save flow to nifti
    nib.save(nib.Nifti1Image(dvf, nim.affine), os.path.join(output_dir, "{}.nii.gz".format(dvfout)))
    return dvf
def warp_numpy_cpu(source_img, dvf):
    """
    Resample a stack of 2D numpy images with a dense displacement field on
    the CPU, using the Pytorch-based ``resample_transform_cpu``.

    Args:
        source_img: (ndarray shape: H,W,N) source images; the displacement is
            normalised by H only, so square images are assumed
        dvf: (ndarray shape: H,W,N,1,2) dense displacement vector field in
            pixels, not normalised to [-1,1]
    Returns:
        warped_source_img: (ndarray shape: H,W,N) resampled deformed source images
    """
    image_size = source_img.shape[0]
    # pixel displacement -> Pytorch's [-1, 1] normalised coordinate system
    scaled_dvf = 2 * dvf / image_size

    # (H, W, N, 2) -> (N, 2, H, W)
    offsets = torch.from_numpy(scaled_dvf[:, :, :, 0, :].transpose(2, 3, 0, 1)).float()
    # (H, W, N) -> (N, 1, H, W)
    images = torch.from_numpy(source_img.transpose(2, 0, 1)).unsqueeze(1)

    warped = resample_transform_cpu(images, offsets)  # (N, 1, H, W)
    # drop the singleton channel and move frames last -> (H, W, N)
    return warped[:, 0].numpy().transpose(1, 2, 0)