cardiac_motion/model/datasets.py

import os
import os.path as path
import random
import datetime

import numpy as np
import nibabel as nib
import torch
import torch.nn.functional as F
import torch.utils.data as data


class CardiacMR_2D_UKBB(data.Dataset):
    """
    Training dataset for UKBB. Loads the subject-specific ED frame file as the target.
    """

    def __init__(self, data_path, seq="sa", seq_length=30, augment=False, transform=None):
        super().__init__()
        self.data_path = data_path
        self.seq = seq
        self.seq_length = seq_length
        self.augment = augment
        self.transform = transform

        # keep only subjects that have both the full sequence and the ED frame file
        self.dir_list = []
        if path.exists(self.data_path):
            for subj_dir in sorted(os.listdir(self.data_path)):
                if path.exists(path.join(data_path, subj_dir, seq + ".nii.gz")) and path.exists(
                    path.join(data_path, subj_dir, seq + "_ED.nii.gz")
                ):
                    self.dir_list += [subj_dir]
        else:
            raise RuntimeError(f"Data path does not exist: {self.data_path}")
    def __getitem__(self, index):
        """
        Load and pre-process the input image.

        Args:
            index: index into the dir list

        Returns:
            target: target image, Tensor of shape (1, H, W)
            source: source image sequence, Tensor of shape (seq_length, H, W)
        """
        # update the seed to avoid workers sampling the same augmentation parameters
        # if self.augment:
        #     np.random.seed(datetime.datetime.now().second + datetime.datetime.now().microsecond)

        # load NIfTI files into arrays
        subj_dir = os.path.join(self.data_path, self.dir_list[index])
        image_seq_path = os.path.join(subj_dir, self.seq + ".nii.gz")
        image_raw = nib.load(image_seq_path).get_fdata()
        image_ed_path = os.path.join(subj_dir, self.seq + "_ED.nii.gz")
        image_ed = nib.load(image_ed_path).get_fdata()

        if self.seq == "sa":
            # randomly select a z-axis slice, then transpose into (seq_length, H, W)
            slice_num = random.randint(0, image_raw.shape[-2] - 1)
        else:
            slice_num = 0
        image = image_raw[:, :, slice_num, :].transpose(2, 0, 1).astype(np.float32)
        image_ed = image_ed[:, :, slice_num]

        # the target image is the ED frame (expanded below via a PyTorch view)
        target = image_ed[np.newaxis, :, :]  # extend dim to (1, H, W)

        # source images are a sequence of seq_length frames
        if image.shape[0] > self.seq_length:
            start_frame_idx = random.randint(0, image.shape[0] - self.seq_length)
            end_frame_idx = start_frame_idx + self.seq_length
            source = image[start_frame_idx:end_frame_idx, :, :]  # (seq_length, H, W)
        else:
            # if the sequence is shorter than seq_length, use the whole sequence
            # apart from the first (ED) frame
            source = image[1:, :, :]  # (T-1, H, W)

        # transformation functions expect input of shape (N, H, W)
        if self.transform:
            target = self.transform(target)
            source = self.transform(source)

        # expand target (1, H, W) -> (N, H, W) to match the source (a view, no copy)
        target = target.expand(source.size()[0], -1, -1)
        return {"target": target, "source": source}

    def __len__(self):
        return len(self.dir_list)
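
# Example usage (a minimal sketch; the data path and `to_tensor` callable are
# hypothetical). Note that `transform` is expected to return torch Tensors,
# since `__getitem__` calls `target.expand(...)`:
#
#   from torch.utils.data import DataLoader
#   to_tensor = lambda a: torch.from_numpy(a.astype(np.float32))
#   dataset = CardiacMR_2D_UKBB("/path/to/ukbb_train", seq="sa", seq_length=30, transform=to_tensor)
#   loader = DataLoader(dataset, batch_size=1, shuffle=True)
#   batch = next(iter(loader))
#   batch["source"].shape  # torch.Size([1, seq_length, H, W])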


class CardiacMR_2D_Eval_UKBB(data.Dataset):
    """Validation and evaluation dataset for UKBB.
    Fetches ED and ES frame images and segmentation labels."""

    def __init__(
        self,
        data_path,
        seq="sa",
        label_prefix="label",
        augment=False,
        transform=None,
        label_transform=None,
    ):
        super().__init__()
        self.data_path = data_path
        self.seq = seq
        self.label_prefix = label_prefix
        self.augment = augment
        self.transform = transform
        self.label_transform = label_transform

        # keep only subjects that have all required data files
        self.dir_list = []
        for subj_dir in sorted(os.listdir(self.data_path)):
            if (
                path.exists(path.join(data_path, subj_dir, f"{seq}_ES.nii.gz"))
                and path.exists(path.join(data_path, subj_dir, f"{seq}_ED.nii.gz"))
                and path.exists(path.join(data_path, subj_dir, f"{label_prefix}_{seq}_ED.nii.gz"))
                and path.exists(path.join(data_path, subj_dir, f"{label_prefix}_{seq}_ES.nii.gz"))
            ):
                self.dir_list += [subj_dir]

    def __getitem__(self, index):
        """
        Load and pre-process input images and label maps.
        For now the batch size is expected to be 1, and each batch contains
        the images and labels of one subject at ED and ES (slice stacks).

        Args:
            index: index into the dir list

        Returns:
            image_ed, image_es, label_ed, label_es: Tensors of shape (N, H, W)
        """
        # update the seed to avoid workers sampling the same augmentation parameters
        if self.augment:
            np.random.seed(datetime.datetime.now().second + datetime.datetime.now().microsecond)

        # load NIfTI files into arrays
        subj_dir = os.path.join(self.data_path, self.dir_list[index])
        image_path_ed = os.path.join(subj_dir, f"{self.seq}_ED.nii.gz")
        image_path_es = os.path.join(subj_dir, f"{self.seq}_ES.nii.gz")
        label_path_ed = os.path.join(subj_dir, f"{self.label_prefix}_{self.seq}_ED.nii.gz")
        label_path_es = os.path.join(subj_dir, f"{self.label_prefix}_{self.seq}_ES.nii.gz")

        # images and labels are loaded in shape (H, W, N)
        image_ed = nib.load(image_path_ed).get_fdata()
        image_es = nib.load(image_path_es).get_fdata()
        label_ed = nib.load(label_path_ed).get_fdata()
        label_es = nib.load(label_path_es).get_fdata()

        # transpose into (N, H, W)
        image_ed = image_ed.transpose(2, 0, 1)
        image_es = image_es.transpose(2, 0, 1)
        label_ed = label_ed.transpose(2, 0, 1)
        label_es = label_es.transpose(2, 0, 1)

        # transformation functions expect input of shape (N, H, W);
        # transform ED and ES together so they receive the same parameters
        N = image_ed.shape[0]
        images = np.concatenate((image_ed, image_es), axis=0)
        if self.transform:
            images = self.transform(images)
        image_ed = images[:N]
        image_es = images[N:]

        if self.label_transform:
            label_ed = self.label_transform(label_ed)
            label_es = self.label_transform(label_es)

        return image_ed, image_es, label_ed, label_es

    def __len__(self):
        return len(self.dir_list)
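
# Example usage (a sketch; the path is hypothetical). With the expected batch
# size of 1, each item yields the full ED/ES slice stacks of one subject:
#
#   eval_set = CardiacMR_2D_Eval_UKBB("/path/to/ukbb_val", seq="sa", label_prefix="label")
#   image_ed, image_es, label_ed, label_es = eval_set[0]  # each shaped (N, H, W)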


class CardiacMR_2D_Inference_UKBB(data.Dataset):
    """Inference dataset; works with UKBB data or data with segmentation.
    Loops over the frames of one subject."""

    def __init__(self, data_path, seq="sa", transform=None):
        """data_path is the path to the directory containing the NIfTI files"""
        super().__init__()
        self.data_path = data_path
        self.seq = seq
        self.transform = transform

        # load the sequence image NIfTI and record the sequence length
        file_path = os.path.join(self.data_path, self.seq + ".nii.gz")
        nim = nib.load(file_path)
        self.image_seq = nim.get_fdata()
        self.seq_length = self.image_seq.shape[-1]

    def __getitem__(self, idx):
        """Returns a volume pair: the first frame (target) and frame `idx` (source)"""
        target = self.image_seq[:, :, :, 0].transpose(2, 0, 1)
        source = self.image_seq[:, :, :, idx].transpose(2, 0, 1)
        if self.transform:
            target = self.transform(target)
            source = self.transform(source)
        return target, source

    def __len__(self):
        return self.seq_length
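
# Example usage (a sketch; the path is hypothetical). Each item pairs frame 0
# with frame `idx` of the same subject:
#
#   infer_set = CardiacMR_2D_Inference_UKBB("/path/to/subject_dir", seq="sa")
#   target, source = infer_set[5]  # frame 0 vs frame 5, each shaped (N, H, W)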


class CardiacMR_2D_UKBB_SynthDeform(CardiacMR_2D_UKBB):
    """Training dataset that warps UKBB frames with synthetic random deformations."""

    def __init__(self, *args, scales=(8, 16, 32), min_std=0.0, max_std=2.0, seed=None, norm_dvf=True, **kwargs):
        super().__init__(*args, **kwargs)
        self.scales = scales
        self.min_std = min_std
        self.max_std = max_std
        self.norm_dvf = norm_dvf
        self.np_rand = np.random.default_rng(seed)  # seed is set by the worker_init_fn() passed to the dataloader

    @staticmethod
    def normalise_disp(disp):
        """
        Spatially normalise a DVF to the [-1, 1] coordinate system used by PyTorch `grid_sample()`.
        Assumes the disp size is the same as the corresponding image size.

        Args:
            disp: (numpy.ndarray or torch.Tensor, shape (N, ndim, *size)) displacement field

        Returns:
            disp: normalised displacement field
        """
        ndim = disp.ndim - 2
        if isinstance(disp, np.ndarray):
            norm_factors = 2.0 / np.array(disp.shape[2:])
            norm_factors = norm_factors.reshape(1, ndim, *(1,) * ndim)
        elif isinstance(disp, torch.Tensor):
            norm_factors = torch.tensor(2.0) / torch.tensor(disp.size()[2:], dtype=disp.dtype, device=disp.device)
            norm_factors = norm_factors.view(1, ndim, *(1,) * ndim)
        else:
            raise RuntimeError("Input data type not recognised, expected numpy.ndarray or torch.Tensor")
        return disp * norm_factors
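
    # For example (a sketch): for a disp of shape (1, 2, 64, 128), the i-axis
    # displacements are scaled by 2/64 and the j-axis displacements by 2/128,
    # mapping voxel units onto the [-1, 1] grid_sample() coordinate space.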
    @classmethod
    def warp(cls, x, disp, interp_mode="bilinear", norm_disp=True, align_corners=True):
        """
        Spatially transform an image by sampling at transformed locations (2D and 3D).

        Args:
            x: (Tensor float, shape (N, ndim, *sizes)) input image
            disp: (Tensor float, shape (N, ndim, *sizes)) dense displacement field in i-j-k order
            interp_mode: (string) mode of interpolation in grid_sample()
            norm_disp: (bool) if True, normalise the disp from voxel space to [-1, 1] before applying

        Returns:
            deformed x, Tensor of the same shape as the input
        """
        ndim = x.ndim - 2
        size = x.size()[2:]
        disp = disp.type_as(x)
        if norm_disp:
            disp = cls.normalise_disp(disp)

        # generate the standard mesh grid in [-1, 1]
        grid = torch.meshgrid([torch.linspace(-1, 1, size[i]).type_as(disp) for i in range(ndim)], indexing="ij")
        grid = [grid[i].requires_grad_(False) for i in range(ndim)]

        # apply displacements in each direction (N, *size)
        warped_grid = [grid[i] + disp[:, i, ...] for i in range(ndim)]

        # swap i-j-k order to x-y-z (k-j-i) order for grid_sample()
        warped_grid = [warped_grid[ndim - 1 - i] for i in range(ndim)]
        warped_grid = torch.stack(warped_grid, -1)  # (N, *size, ndim)
        return F.grid_sample(x, warped_grid, mode=interp_mode, align_corners=align_corners)
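
    # Sanity check (a sketch, not part of the original file): a zero displacement
    # field should reproduce the input up to interpolation error.
    #
    #   x = torch.rand(2, 1, 64, 64)
    #   zero_disp = torch.zeros(2, 2, 64, 64)
    #   y = CardiacMR_2D_UKBB_SynthDeform.warp(x, zero_disp)
    #   assert torch.allclose(x, y, atol=1e-5)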
    @classmethod
    def svf_exp(cls, flow, scale=1, steps=5, sampling="bilinear"):
        """Exponential of a velocity field by scaling and squaring"""
        disp = flow * (scale / (2 ** steps))
        for _ in range(steps):
            disp = disp + cls.warp(x=disp, disp=disp, interp_mode=sampling, norm_disp=True)
        return disp
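
    # With the defaults (scale=1, steps=5), the velocity field is first divided
    # by 2**5 = 32, and the resulting small displacement is then composed with
    # itself 5 times, approximating the group exponential of the SVF.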
    def _synthesis_deformation(
        self, out_shape, scales=16, min_std=0.0, max_std=0.5, num_batch=1, zoom_mode="bilinear"
    ) -> torch.Tensor:
        """Synthesise a random dense displacement field.
        `out_shape` is the spatial shape; the batch dimension is kept
        (drop it in the dataset if needed).

        Returns:
            dvf: (torch.Tensor, shape (num_batch, ndims, *out_shape)) displacement field
        """
        ndims = len(out_shape)

        # voxelmorph-style: generate at half resolution, then upsample
        out_shape = np.asarray(out_shape, dtype=np.int32)
        gen_shape = out_shape // 2
        if np.isscalar(scales):
            scales = [scales]
        scales = [s // 2 for s in scales]

        # sum multi-scale Gaussian noise into a stationary velocity field (SVF)
        svf = torch.zeros(num_batch, ndims, *gen_shape)
        for scale in scales:
            sample_shape = np.int32(np.ceil(gen_shape / scale))
            sample_shape = (num_batch, ndims, *sample_shape)
            std = self.np_rand.uniform(min_std, max_std, size=sample_shape)
            gauss = self.np_rand.normal(0, std, size=sample_shape)
            gauss = torch.from_numpy(gauss).float()
            if scale > 1:
                zoom = [o / s for o, s in zip(gen_shape, sample_shape[2:])]
                gauss = F.interpolate(gauss, scale_factor=zoom, mode=zoom_mode, align_corners=True)
            svf += gauss

        # integrate the SVF, then upsample back to full resolution
        # (displacements are doubled to match the doubled voxel spacing)
        dvf = self.svf_exp(svf)
        dvf = F.interpolate(dvf, scale_factor=2, mode=zoom_mode, align_corners=True) * 2
        if self.norm_dvf:
            dvf = self.normalise_disp(dvf)
        return dvf

    def __getitem__(self, index):
        x_data = super().__getitem__(index)
        source = x_data["source"]
        out_shape = source.shape[1:]
        dvf = self._synthesis_deformation(
            out_shape,
            scales=self.scales,
            min_std=self.min_std,
            max_std=self.max_std,
            num_batch=source.shape[0],
        )
        # warp each frame independently with its own synthetic deformation
        target = self.warp(source.unsqueeze(1), dvf, interp_mode="bilinear", norm_disp=False).squeeze(1)
        return {"target": target, "source": source, "dvf": dvf}
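

# Example usage (a minimal sketch; the path and `to_tensor` are hypothetical,
# as in the CardiacMR_2D_UKBB example above):
#
#   dataset = CardiacMR_2D_UKBB_SynthDeform(
#       "/path/to/ukbb_train", seq="sa", seq_length=30, transform=to_tensor,
#       scales=(8, 16, 32), min_std=0.0, max_std=2.0,
#   )
#   sample = dataset[0]
#   sample["source"].shape  # (seq_length, H, W)
#   sample["dvf"].shape     # (seq_length, 2, H, W), in [-1, 1] units when norm_dvf=True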