Diff of /data_transforms.py [000000] .. [70b6b3]

Switch to unified view

a b/data_transforms.py
1
from collections import namedtuple
2
import numpy as np
3
import scipy.ndimage
4
import math
5
import utils_lung
6
7
MAX_HU = 400.
8
MIN_HU = -1000.
9
rng = np.random.RandomState(317070)
10
11
12
13
14
def hu2normHU(x):
15
    """
16
    Modifies input data
17
    :param x:
18
    :return:
19
    """
20
    x = (x - MIN_HU) / (MAX_HU - MIN_HU)
21
    x = np.clip(x, 0., 1., out=x)
22
    return x
23
24
def hu2normHU_low_clip(x):
25
    """
26
    Modifies input data
27
    :param x:
28
    :return:
29
    """
30
    x = (x - MIN_HU) / (MAX_HU - MIN_HU)
31
    x = np.clip(x, 0., 10., out=x)
32
    return x
33
34
def pixelnormHU(x):
35
    x = (x - MIN_HU) / (MAX_HU - MIN_HU)
36
    x = np.clip(x, 0., 1., out=x)
37
    return (x - 0.5) / 0.5
38
39
40
def histogram_equalization(x, hist=None, bins=None):
41
    # hist is a normalized histogram, which means that the sum of the counts has to be one
42
    if hist is None and bins is None:
43
        # For the case no target histogram is given
44
        bins = np.arange(-950,500,100)
45
        n_bins = bins.shape[0] -1
46
        hist = 1. * np.ones(n_bins) / n_bins
47
    elif hist is None or bins is None:
48
        raise
49
        
50
    assert(len(bins) == (len(hist)+1))
51
52
    # init our target array 
53
    z = np.empty(x.shape)
54
55
    # copy the values outside of the bins from the original
56
    z[x<=bins[0]] = x[x<=bins[0]] 
57
    z[x>=bins[-1]] = x[x>=bins[-1]] 
58
59
    inside_bins = np.logical_and(x>bins[0], x<bins[-1])
60
61
    n_bins = bins.shape[0] -1
62
    prev_percentile = 0
63
    for i in range(n_bins):
64
        target_count = hist[i]
65
        lower_bound = bins[i]
66
        upper_bound = bins[i+1]
67
        new_percentile = prev_percentile + target_count*100
68
        low_orig = np.percentile(x[inside_bins], prev_percentile)
69
        if i == n_bins-1:
70
            high_orig = bins[-1]
71
        else:
72
            high_orig = np.percentile(x[inside_bins], new_percentile)
73
74
        prev_percentile = new_percentile
75
76
        elements_to_rescale = np.logical_and(x>=low_orig, x<high_orig)
77
        y = x[elements_to_rescale]
78
        y_r = (y - low_orig)/(high_orig-low_orig)*(upper_bound-lower_bound) + lower_bound
79
        print 'y_r', np.isnan(y_r).any()
80
        z[elements_to_rescale] = y_r
81
82
    return z
83
84
def get_rescale_params_hist_eq(x, hist=None, bins=None):
85
    # hist is a normalized histogram, which means that the sum of the counts has to be one
86
    if hist is None and bins is None:
87
        # For the case no target histogram is given
88
        bins = np.arange(-950,500,100)
89
        n_bins = bins.shape[0] -1
90
        hist = 1. * np.ones(n_bins) / n_bins
91
    elif hist is None or bins is None:
92
        raise
93
        
94
    assert(len(bins) == (len(hist)+1))
95
96
    inside_bins = np.logical_and(x>bins[0], x<bins[-1])
97
98
    n_bins = bins.shape[0] -1
99
    prev_percentile = 0
100
    original_borders = []
101
    for i in range(n_bins):
102
        target_count = hist[i]
103
        lower_bound = bins[i]
104
        upper_bound = bins[i+1]
105
        new_percentile = prev_percentile + target_count*100
106
        low_orig = np.percentile(x[inside_bins], prev_percentile)
107
        original_borders.append(low_orig)
108
        prev_percentile = new_percentile
109
    original_borders.append(bins[-1])
110
        
111
112
    return bins, original_borders
113
114
def apply_hist_eq_patch(x, bins, original_borders):
115
116
    # init our target array 
117
    z = np.empty(x.shape)
118
119
    # if np.isnan(z).any():
120
    #     print '1 np.isnan(z).any()', np.isnan(z).any()
121
122
    # copy the values outside of the bins from the original
123
    z[x<=bins[0]] = x[x<=bins[0]] 
124
    z[x>=bins[-1]] = x[x>=bins[-1]] 
125
    # print 'x.shape', x.shape, x.shape[0] * x.shape[1] * x.shape[2] * x.shape[3]
126
    # print 'np.sum(x<=bins[0])', np.sum(x<=bins[0])
127
    # print 'np.sum(x>=bins[-1])', np.sum(x>=bins[-1])
128
129
    # if np.isnan(z).any():
130
    #     print '2 np.isnan(z).any()', np.isnan(z).any()
131
132
    inside_bins = np.logical_and(x>bins[0], x<bins[-1])
133
    # print 'np.sum(inside_bins)', np.sum(inside_bins)
134
135
    n_total_elements_replaced = 0
136
    n_bins = bins.shape[0] -1
137
    for i in range(n_bins):
138
        lower_bound = bins[i]
139
        upper_bound = bins[i+1]
140
        low_orig = original_borders[i]
141
        high_orig = original_borders[i+1]
142
143
        elements_to_rescale = np.logical_and(x>=low_orig, x<high_orig)
144
        n_total_elements_replaced += np.sum(elements_to_rescale)    
145
        # print 'np.sum(elements_to_rescale)', np.sum(elements_to_rescale)  
146
        y = x[elements_to_rescale]
147
        y_r = (y - low_orig)/(high_orig-low_orig)*(upper_bound-lower_bound) + lower_bound
148
149
        z[elements_to_rescale] = y_r
150
151
    #     if np.isnan(z).any():
152
    #         print 'np.isnan(z).any()', np.isnan(z).any()
153
154
    # print 'n_total_elements_replaced', n_total_elements_replaced
155
        
156
    return z
157
158
159
def sample_augmentation_parameters(transformation):
160
    shift_z = rng.uniform(*transformation.get('translation_range_z', [0., 0.]))
161
    shift_y = rng.uniform(*transformation.get('translation_range_y', [0., 0.]))
162
    shift_x = rng.uniform(*transformation.get('translation_range_x', [0., 0.]))
163
    translation = (shift_z, shift_y, shift_x)
164
165
    rotation_z = rng.uniform(*transformation.get('rotation_range_z', [0., 0.]))
166
    rotation_y = rng.uniform(*transformation.get('rotation_range_y', [0., 0.]))
167
    rotation_x = rng.uniform(*transformation.get('rotation_range_x', [0., 0.]))
168
    rotation = (rotation_z, rotation_y, rotation_x)
169
170
    return namedtuple('Params', ['translation', 'rotation'])(translation, rotation)
171
172
173
def transform_scan3d(data, pixel_spacing, p_transform,
174
                     luna_annotations=None,
175
                     luna_origin=None,
176
                     p_transform_augment=None,
177
                     world_coord_system=True,
178
                     lung_mask=None):
179
    mm_patch_size = np.asarray(p_transform['mm_patch_size'], dtype='float32')
180
    out_pixel_spacing = np.asarray(p_transform['pixel_spacing'])
181
182
    input_shape = np.asarray(data.shape)
183
    mm_shape = input_shape * pixel_spacing / out_pixel_spacing
184
    output_shape = p_transform['patch_size']
185
186
    # here we give parameters to affine transform as if it's T in
187
    # output = T.dot(input)
188
    # https://www.cs.mtu.edu/~shene/COURSES/cs3621/NOTES/geometry/geo-tran.html
189
    # but the affine_transform() makes it reversed for scipy
190
    tf_mm_scale = affine_transform(scale=mm_shape / input_shape)
191
    tf_shift_center = affine_transform(translation=-mm_shape / 2.)
192
193
    tf_shift_uncenter = affine_transform(translation=mm_patch_size / 2.)
194
    tf_output_scale = affine_transform(scale=output_shape / mm_patch_size)
195
196
    if p_transform_augment:
197
        augment_params_sample = sample_augmentation_parameters(p_transform_augment)
198
        tf_augment = affine_transform(translation=augment_params_sample.translation,
199
                                      rotation=augment_params_sample.rotation)
200
        tf_total = tf_mm_scale.dot(tf_shift_center).dot(tf_augment).dot(tf_shift_uncenter).dot(tf_output_scale)
201
    else:
202
        tf_total = tf_mm_scale.dot(tf_shift_center).dot(tf_shift_uncenter).dot(tf_output_scale)
203
204
    data_out = apply_affine_transform(data, tf_total, order=1, output_shape=output_shape)
205
206
    if lung_mask is not None:
207
        lung_mask_out = apply_affine_transform(lung_mask, tf_total, order=1, output_shape=output_shape)
208
        lung_mask_out[lung_mask_out > 0.] = 1.
209
    if luna_annotations is not None:
210
        annotatations_out = []
211
        for zyxd in luna_annotations:
212
            zyx = np.array(zyxd[:3])
213
            voxel_coords = utils_lung.world2voxel(zyx, luna_origin, pixel_spacing) if world_coord_system else zyx
214
            voxel_coords = np.append(voxel_coords, [1])
215
            voxel_coords_out = np.linalg.inv(tf_total).dot(voxel_coords)[:3]
216
            diameter_mm = zyxd[-1]
217
            diameter_out = diameter_mm * output_shape[1] / mm_patch_size[1] / out_pixel_spacing[1]
218
            zyxd_out = np.rint(np.append(voxel_coords_out, diameter_out))
219
            annotatations_out.append(zyxd_out)
220
        annotatations_out = np.asarray(annotatations_out)
221
        if lung_mask is None:
222
            return data_out, annotatations_out, tf_total
223
        else:
224
            return data_out, annotatations_out, tf_total, lung_mask_out
225
226
    if lung_mask is None:
227
        return data_out, tf_total
228
    else:
229
        return data_out, tf_total, lung_mask_out
230
231
232
def transform_patch3d(data, pixel_spacing, p_transform,
233
                      patch_center,
234
                      luna_origin,
235
                      luna_annotations=None,
236
                      p_transform_augment=None,
237
                      world_coord_system=True):
238
    mm_patch_size = np.asarray(p_transform['mm_patch_size'], dtype='float32')
239
    out_pixel_spacing = np.asarray(p_transform['pixel_spacing'])
240
241
    input_shape = np.asarray(data.shape)
242
    mm_shape = input_shape * pixel_spacing / out_pixel_spacing
243
    output_shape = p_transform['patch_size']
244
245
    zyx = np.array(patch_center[:3])
246
    voxel_coords = utils_lung.world2voxel(zyx, luna_origin, pixel_spacing) if world_coord_system else zyx
247
    voxel_coords_mm = voxel_coords * mm_shape / input_shape
248
249
    # here we give parameters to affine transform as if it's T in
250
    # output = T.dot(input)
251
    # https://www.cs.mtu.edu/~shene/COURSES/cs3621/NOTES/geometry/geo-tran.html
252
    # but the affine_transform() makes it reversed for scipy
253
    tf_mm_scale = affine_transform(scale=mm_shape / input_shape)
254
    tf_shift_center = affine_transform(translation=-voxel_coords_mm)
255
256
    tf_shift_uncenter = affine_transform(translation=mm_patch_size / 2.)
257
    tf_output_scale = affine_transform(scale=output_shape / mm_patch_size)
258
259
    if p_transform_augment:
260
        augment_params_sample = sample_augmentation_parameters(p_transform_augment)
261
        # print 'augmentation parameters', augment_params_sample
262
        tf_augment = affine_transform(translation=augment_params_sample.translation,
263
                                      rotation=augment_params_sample.rotation)
264
        tf_total = tf_mm_scale.dot(tf_shift_center).dot(tf_augment).dot(tf_shift_uncenter).dot(tf_output_scale)
265
    else:
266
        tf_total = tf_mm_scale.dot(tf_shift_center).dot(tf_shift_uncenter).dot(tf_output_scale)
267
268
    data_out = apply_affine_transform(data, tf_total, order=1, output_shape=output_shape)
269
270
    # transform patch annotations
271
    diameter_mm = patch_center[-1]
272
    diameter_out = diameter_mm * output_shape[1] / mm_patch_size[1] / out_pixel_spacing[1]
273
    voxel_coords = np.append(voxel_coords, [1])
274
    voxel_coords_out = np.linalg.inv(tf_total).dot(voxel_coords)[:3]
275
    patch_annotation_out = np.rint(np.append(voxel_coords_out, diameter_out))
276
    # print 'pathch_center_after_transform', patch_annotation_out
277
278
    if luna_annotations is not None:
279
        annotatations_out = []
280
        for zyxd in luna_annotations:
281
            zyx = np.array(zyxd[:3])
282
            voxel_coords = utils_lung.world2voxel(zyx, luna_origin, pixel_spacing) if world_coord_system else zyx
283
            voxel_coords = np.append(voxel_coords, [1])
284
            voxel_coords_out = np.linalg.inv(tf_total).dot(voxel_coords)[:3]
285
            diameter_mm = zyxd[-1]
286
            diameter_out = diameter_mm * output_shape[1] / mm_patch_size[1] / out_pixel_spacing[1]
287
            zyxd_out = np.rint(np.append(voxel_coords_out, diameter_out))
288
            annotatations_out.append(zyxd_out)
289
        annotatations_out = np.asarray(annotatations_out)
290
        return data_out, patch_annotation_out, annotatations_out
291
292
    return data_out, patch_annotation_out
293
294
295
def transform_patch3d_ls(data, pixel_spacing, p_transform,
296
                      patch_center,
297
                      luna_origin,
298
                      p_transform_augment=None,
299
                      world_coord_system=True):
300
    mm_patch_size = np.asarray(p_transform['mm_patch_size'], dtype='float32')
301
    out_pixel_spacing = np.asarray(p_transform['pixel_spacing'])
302
303
    input_shape = np.asarray(data.shape)
304
    mm_shape = input_shape * pixel_spacing / out_pixel_spacing
305
    output_shape = p_transform['patch_size']
306
307
    zyx = np.array(patch_center[:3])
308
    # voxel_coords = utils_lung.world2voxel(zyx, luna_origin, pixel_spacing) if world_coord_system else zyx
309
    # voxel_coords_mm = voxel_coords * mm_shape / input_shape
310
    voxel_coords_mm = zyx * mm_shape / input_shape
311
312
    # here we give parameters to affine transform as if it's T in
313
    # output = T.dot(input)
314
    # https://www.cs.mtu.edu/~shene/COURSES/cs3621/NOTES/geometry/geo-tran.html
315
    # but the affine_transform() makes it reversed for scipy
316
    tf_mm_scale = affine_transform(scale=mm_shape / input_shape)
317
    tf_shift_center = affine_transform(translation=-voxel_coords_mm)
318
319
    tf_shift_uncenter = affine_transform(translation=mm_patch_size / 2.)
320
    tf_output_scale = affine_transform(scale=output_shape / mm_patch_size)
321
322
    if p_transform_augment:
323
        augment_params_sample = sample_augmentation_parameters(p_transform_augment)
324
        # print 'augmentation parameters', augment_params_sample
325
        tf_augment = affine_transform(translation=augment_params_sample.translation,
326
                                      rotation=augment_params_sample.rotation)
327
        tf_total = tf_mm_scale.dot(tf_shift_center).dot(tf_augment).dot(tf_shift_uncenter).dot(tf_output_scale)
328
    else:
329
        tf_total = tf_mm_scale.dot(tf_shift_center).dot(tf_shift_uncenter).dot(tf_output_scale)
330
331
332
    print 'data min,max', np.amin(data), np.amax(data)
333
    data_out = apply_affine_transform(data, tf_total, order=1, output_shape=output_shape)
334
    print 'data_out min,max', np.amin(data_out), np.amax(data_out)
335
336
    # transform patch annotations
337
    # voxel_coords = np.append(voxel_coords, [1])
338
    # voxel_coords_out = np.linalg.inv(tf_total).dot(voxel_coords)[:3]
339
    # patch_annotation_out = np.rint(voxel_coords_out)
340
    # print 'pathch_center_after_transform', patch_annotation_out
341
342
    return data_out #, patch_annotation_out
343
344
345
def transform_dsb_candidates(data, patch_centers, pixel_spacing, p_transform,
346
                             p_transform_augment=None):
347
    input_shape = np.asarray(data.shape)
348
    output_shape = np.asarray(p_transform['patch_size'])
349
350
    patches_out = []
351
    for zyxd in patch_centers:
352
        if -1 in zyxd:
353
            patch_out = np.zeros(output_shape)
354
        elif 'affine_tf' in p_transform and not p_transform['affine_tf']:
355
            assert(output_shape[0] == output_shape[1])
356
            assert(output_shape[0] == output_shape[2])
357
358
            zyx = np.round(np.array(zyxd[:3])).astype('int32')
359
360
            z_in = zyx[0] > output_shape[0]/2 and zyx[0] < input_shape[0]-output_shape[0]/2
361
            y_in = zyx[1] > output_shape[1]/2 and zyx[1] < input_shape[1]-output_shape[1]/2
362
            x_in = zyx[2] > output_shape[2]/2 and zyx[2] < input_shape[2]-output_shape[2]/2
363
364
            patch_inside_tensor = z_in and y_in and x_in
365
366
            if patch_inside_tensor:
367
                patch_out = data[zyx[0]-output_shape[0]/2:zyx[0]+output_shape[0]/2,
368
                                 zyx[1]-output_shape[1]/2:zyx[1]+output_shape[1]/2,
369
                                 zyx[2]-output_shape[2]/2:zyx[2]+output_shape[2]/2] 
370
            else:
371
                data_pad = np.empty((input_shape[0]+output_shape[0], 
372
                                     input_shape[1]+output_shape[1], 
373
                                     input_shape[2]+output_shape[2]))
374
375
                data_pad[0:output_shape[0]/2,:,:] = 0
376
                data_pad[output_shape[0]/2+input_shape[0]:,:,:] = 0
377
378
                data_pad[:,0:output_shape[1]/2,:] = 0
379
                data_pad[:,output_shape[1]/2+input_shape[1]:,:] = 0
380
381
                data_pad[:,:,0:output_shape[2]/2] = 0
382
                data_pad[:,:,output_shape[2]/2+input_shape[2]:] = 0
383
384
                data_pad[output_shape[0]/2:output_shape[0]/2+input_shape[0],
385
                         output_shape[1]/2:output_shape[1]/2+input_shape[1],
386
                         output_shape[2]/2:output_shape[2]/2+input_shape[2],] = data
387
388
                #too slow data_pad = np.lib.pad(data, output_shape[0], mode='constant', constant_values = MIN_HU)
389
390
                zyx_pad = zyx + output_shape/2
391
                patch_out = data_pad[zyx_pad[0]-output_shape[0]/2:zyx_pad[0]+output_shape[0]/2,
392
                                     zyx_pad[1]-output_shape[1]/2:zyx_pad[1]+output_shape[1]/2,
393
                                     zyx_pad[2]-output_shape[2]/2:zyx_pad[2]+output_shape[2]/2] 
394
        else:
395
            mm_patch_size = np.asarray(p_transform['mm_patch_size'], dtype='float32')
396
            out_pixel_spacing = np.asarray(p_transform['pixel_spacing'])
397
            mm_shape = input_shape * pixel_spacing / out_pixel_spacing
398
399
            zyx = np.array(zyxd[:3])
400
            zyx_mm = zyx * mm_shape / input_shape
401
402
            tf_mm_scale = affine_transform(scale=mm_shape / input_shape)
403
            tf_shift_center = affine_transform(translation=-zyx_mm)
404
            tf_shift_uncenter = affine_transform(translation=mm_patch_size / 2.)
405
            tf_output_scale = affine_transform(scale=output_shape / mm_patch_size)
406
407
            if p_transform_augment:
408
                augment_params_sample = sample_augmentation_parameters(p_transform_augment)
409
                tf_augment = affine_transform(translation=augment_params_sample.translation,
410
                                              rotation=augment_params_sample.rotation)
411
                tf_total = tf_mm_scale.dot(tf_shift_center).dot(tf_augment).dot(tf_shift_uncenter).dot(tf_output_scale)
412
            else:
413
                tf_total = tf_mm_scale.dot(tf_shift_center).dot(tf_shift_uncenter).dot(tf_output_scale)
414
415
            patch_out = apply_affine_transform(data, tf_total, order=p_transform['order'], output_shape=output_shape)
416
        
417
        patches_out.append(patch_out[None, :, :, :])
418
    return np.concatenate(patches_out, axis=0)
419
420
421
def build_dsb_can_heatmap(data, candidates, pixel_spacing, p_transform,
422
                             p_transform_augment=None):
423
424
    assert(candidates.shape[1]>3)
425
426
    mm_patch_size = np.asarray(p_transform['mm_patch_size'], dtype='float32')
427
    out_pixel_spacing = np.asarray(p_transform['pixel_spacing'])
428
429
    input_shape = np.asarray(data.shape)
430
    mm_shape = input_shape * pixel_spacing / out_pixel_spacing
431
432
    output_shape = p_transform['heatmap_size']
433
    max_shape = p_transform['max_shape']
434
435
    # Constructing heatmap
436
    heatmap = np.zeros(output_shape)
437
    max_dims = np.zeros(3)
438
    min_dims = 99999*np.ones(3)
439
    for can in candidates:
440
        value = can[-1]
441
        zyx = np.array(can[:3])
442
        zyx_mm = zyx * mm_shape / input_shape
443
        #only for analyse purpose
444
        for idx, d in enumerate(zyx_mm):
445
            if d>max_dims[idx]:
446
                max_dims[idx] = d
447
            if d<min_dims[idx]:
448
                min_dims[idx] = d
449
        zyx_hm = zyx_mm / max_shape * output_shape
450
        heatmap[zyx_hm.astype('int')] += value 
451
452
    # print 'max_dims', max_dims
453
    # print 'min_dims', min_dims
454
    # print 'heatmap max', np.amax(heatmap)
455
    # print 'heatmap min', np.amin(heatmap)
456
457
    # augmentation
458
    if p_transform_augment:
459
        augment_params_sample = sample_augmentation_parameters(p_transform_augment)
460
        tf_augment = affine_transform(translation=augment_params_sample.translation, rotation=augment_params_sample.rotation)
461
        heatmap = apply_affine_transform(heatmap, tf_augment, order=p_transform['heatmap_order'], output_shape=output_shape)
462
463
    heatmap = heatmap / p_transform['heatmap_norm']
464
465
    return heatmap
466
467
468
def make_3d_mask(img_shape, center, radius, shape='sphere'):
469
    mask = np.zeros(img_shape)
470
    radius = np.rint(radius)
471
    center = np.rint(center)
472
    sz = np.arange(int(max(center[0] - radius, 0)), int(max(min(center[0] + radius + 1, img_shape[0]), 0)))
473
    sy = np.arange(int(max(center[1] - radius, 0)), int(max(min(center[1] + radius + 1, img_shape[1]), 0)))
474
    sx = np.arange(int(max(center[2] - radius, 0)), int(max(min(center[2] + radius + 1, img_shape[2]), 0)))
475
    sz, sy, sx = np.meshgrid(sz, sy, sx)
476
    if shape == 'cube':
477
        mask[sz, sy, sx] = 1.
478
    elif shape == 'sphere':
479
        distance2 = ((center[0] - sz) ** 2
480
                     + (center[1] - sy) ** 2
481
                     + (center[2] - sx) ** 2)
482
        distance_matrix = np.ones_like(mask) * np.inf
483
        distance_matrix[sz, sy, sx] = distance2
484
        mask[(distance_matrix <= radius ** 2)] = 1
485
    elif shape == 'gauss':
486
        z, y, x = np.ogrid[:mask.shape[0], :mask.shape[1], :mask.shape[2]]
487
        distance = ((z - center[0]) ** 2 + (y - center[1]) ** 2 + (x - center[2]) ** 2)
488
        mask = np.exp(- 1. * distance / (2 * radius ** 2))
489
        mask[(distance > 3 * radius ** 2)] = 0
490
    return mask
491
492
493
def make_3d_mask_from_annotations(img_shape, annotations, shape):
494
    mask = np.zeros(img_shape)
495
    for zyxd in annotations:
496
        mask += make_3d_mask(img_shape, zyxd[:3], zyxd[-1] / 2, shape)
497
    mask = np.clip(mask, 0., 1.)
498
    return mask
499
500
501
def make_gaussian_annotation(patch_annotation_tf, patch_size):
502
    radius = patch_annotation_tf[-1] / 2.
503
    zyx = patch_annotation_tf[:3]
504
    distance_z = (zyx[0] - np.arange(patch_size[0])) ** 2
505
    distance_y = (zyx[1] - np.arange(patch_size[1])) ** 2
506
    distance_x = (zyx[2] - np.arange(patch_size[2])) ** 2
507
    z_label = np.exp(- 1. * distance_z / (2 * radius ** 2))
508
    y_label = np.exp(- 1. * distance_y / (2 * radius ** 2))
509
    x_label = np.exp(- 1. * distance_x / (2 * radius ** 2))
510
    label = np.vstack((z_label, y_label, x_label))
511
    return label
512
513
514
def zmuv(x, mean, std):
515
    if mean is not None and std is not None:
516
        return (x - mean) / std
517
    else:
518
        return x
519
520
521
def affine_transform(scale=None, rotation=None, translation=None):
522
    """
523
    rotation and shear in degrees
524
    """
525
    matrix = np.eye(4)
526
527
    if translation is not None:
528
        matrix[:3, 3] = -np.asarray(translation, np.float)
529
530
    if scale is not None:
531
        matrix[0, 0] = 1. / scale[0]
532
        matrix[1, 1] = 1. / scale[1]
533
        matrix[2, 2] = 1. / scale[2]
534
535
    if rotation is not None:
536
        rotation = np.asarray(rotation, np.float)
537
        rotation = map(math.radians, rotation)
538
        cos = map(math.cos, rotation)
539
        sin = map(math.sin, rotation)
540
541
        mz = np.eye(4)
542
        mz[1, 1] = cos[0]
543
        mz[2, 1] = sin[0]
544
        mz[1, 2] = -sin[0]
545
        mz[2, 2] = cos[0]
546
547
        my = np.eye(4)
548
        my[0, 0] = cos[1]
549
        my[0, 2] = -sin[1]
550
        my[2, 0] = sin[1]
551
        my[2, 2] = cos[1]
552
553
        mx = np.eye(4)
554
        mx[0, 0] = cos[2]
555
        mx[0, 1] = sin[2]
556
        mx[1, 0] = -sin[2]
557
        mx[1, 1] = cos[2]
558
559
        matrix = mx.dot(my).dot(mz).dot(matrix)
560
    return matrix
561
562
563
def apply_affine_transform(_input, matrix, order=1, output_shape=None):
564
    # output.dot(T) + s = input
565
    T = matrix[:3, :3]
566
    s = matrix[:3, 3]
567
    return scipy.ndimage.interpolation.affine_transform(
568
        _input, matrix=T, offset=s, order=order, output_shape=output_shape)