a b/fetal_net/utils/sitk_utils.py
1
import SimpleITK as sitk
2
import numpy as np
3
4
5
def calculate_origin_offset(new_spacing, old_spacing):
    """Return half the element-wise difference between two spacings.

    Computes ``(new_spacing - old_spacing) / 2``. This is the shift, along each
    index axis, applied to an origin when the voxel spacing changes.

    Args:
        new_spacing: Target spacing (scalar or sequence).
        old_spacing: Current spacing (scalar or sequence, broadcastable with
            ``new_spacing``).

    Returns:
        The half difference, as a numpy scalar or array.
    """
    spacing_delta = np.subtract(new_spacing, old_spacing)
    return spacing_delta / 2
7
8
9
def sitk_resample_to_spacing(image, new_spacing=(1.0, 1.0, 1.0), interpolator=sitk.sitkLinear, default_value=0.):
    """Resample a SimpleITK image onto a grid with voxel spacing ``new_spacing``.

    The output grid has the same direction as ``image``. Its origin is shifted by
    half the spacing difference so both grids cover the same physical extent.
    Its size is ``ceil(old_size * old_spacing / new_spacing)``, rounded to 5
    decimals first to absorb floating-point noise.

    Args:
        image: Input ``sitk.Image``.
        new_spacing: Target spacing, one value per image dimension.
        interpolator: SimpleITK interpolator enum used for resampling.
        default_value: Value for output voxels that map outside ``image``.

    Returns:
        The resampled ``sitk.Image``.
    """
    old_spacing = np.asarray(image.GetSpacing(), dtype=np.float64)
    zoom_factor = np.divide(old_spacing, new_spacing)
    # Round before ceil so float noise (e.g. 63.0000001) does not add a voxel.
    # int64 instead of int16: int16 silently wraps for extents above 32767.
    new_size = np.asarray(np.ceil(np.round(np.multiply(zoom_factor, image.GetSize()), decimals=5)),
                          dtype=np.int64)

    # The half-voxel offset is measured along the image's index axes; map it
    # into physical space through the direction cosine matrix before adding it
    # to the (physical) origin. For an identity direction this is unchanged.
    offset = calculate_origin_offset(new_spacing, old_spacing)
    dimension = image.GetDimension()
    direction_matrix = np.reshape(np.asarray(image.GetDirection(), dtype=np.float64), (dimension, dimension))
    new_origin = np.asarray(image.GetOrigin(), dtype=np.float64) + direction_matrix.dot(offset)

    reference_image = sitk_new_blank_image(size=new_size, spacing=new_spacing, direction=image.GetDirection(),
                                           origin=tuple(float(value) for value in new_origin),
                                           default_value=default_value)
    return sitk_resample_to_image(image, reference_image, interpolator=interpolator, default_value=default_value)
16
17
18
def sitk_resample_to_image(image, reference_image, default_value=0., interpolator=sitk.sitkLinear, transform=None,
                           output_pixel_type=None):
    """Resample ``image`` onto the grid defined by ``reference_image``.

    Args:
        image: Input ``sitk.Image`` to resample.
        reference_image: ``sitk.Image`` whose grid the output adopts, passed to
            ``ResampleImageFilter.SetReferenceImage``.
        default_value: Pixel value for output voxels that map outside ``image``.
        interpolator: SimpleITK interpolator enum.
        transform: ``sitk.Transform`` to apply. ``None`` means a fresh transform
            with ``SetIdentity()`` called on it.
        output_pixel_type: SimpleITK pixel ID for the output. ``None`` means
            the pixel ID of ``image``.

    Returns:
        The resampled ``sitk.Image``.
    """
    pixel_id = image.GetPixelID() if output_pixel_type is None else output_pixel_type

    if transform is not None:
        spatial_transform = transform
    else:
        spatial_transform = sitk.Transform()
        spatial_transform.SetIdentity()

    resampler = sitk.ResampleImageFilter()
    resampler.SetInterpolator(interpolator)
    resampler.SetTransform(spatial_transform)
    resampler.SetOutputPixelType(pixel_id)
    resampler.SetDefaultPixelValue(default_value)
    resampler.SetReferenceImage(reference_image)
    return resampler.Execute(image)
32
33
34
def sitk_new_blank_image(size, spacing, direction, origin, default_value=0.):
    """Create a float64 SimpleITK image with every voxel set to ``default_value``.

    Args:
        size: Image size in SimpleITK index order (the value ``GetSize()`` returns).
        spacing: Voxel spacing passed to ``SetSpacing``.
        direction: Flattened direction cosines passed to ``SetDirection``.
        origin: Physical origin passed to ``SetOrigin``.
        default_value: Fill value for every voxel.

    Returns:
        The new ``sitk.Image``.
    """
    # GetImageFromArray maps numpy axes to image axes in reverse order, so build
    # the array with the reversed shape to get an image with GetSize() == size.
    # np.float64 replaces the ``np.float`` alias, which was removed in NumPy 1.24.
    array_shape = tuple(int(extent) for extent in size)[::-1]
    voxels = np.full(array_shape, default_value, dtype=np.float64)
    image = sitk.GetImageFromArray(voxels)
    image.SetSpacing(spacing)
    image.SetDirection(direction)
    image.SetOrigin(origin)
    return image
40
41
42
def resample_to_spacing(data, spacing, target_spacing, interpolation="linear", default_value=0.):
    """Resample a numpy volume from ``spacing`` to ``target_spacing``.

    The array goes through ``data_to_sitk_image``, is resampled with
    ``sitk_resample_to_spacing``, and comes back through ``sitk_image_to_data``.

    Args:
        data: numpy array to resample.
        spacing: Current voxel spacing of ``data``.
        target_spacing: Desired voxel spacing.
        interpolation: ``"linear"`` or ``"nearest"``.
        default_value: Value for output voxels that map outside the input.

    Returns:
        The resampled numpy array.

    Raises:
        ValueError: If ``interpolation`` is not ``"linear"`` or ``"nearest"``.
    """
    # Use ``==``, not ``is``: identity on strings only works by accident of
    # interning and fails for strings built at runtime. Validate first so a bad
    # argument fails before any image conversion work.
    if interpolation == "linear":
        interpolator = sitk.sitkLinear
    elif interpolation == "nearest":
        interpolator = sitk.sitkNearestNeighbor
    else:
        raise ValueError("'interpolation' must be either 'linear' or 'nearest'. '{}' is not recognized".format(
            interpolation))
    image = data_to_sitk_image(data, spacing=spacing)
    resampled_image = sitk_resample_to_spacing(image, new_spacing=target_spacing, interpolator=interpolator,
                                               default_value=default_value)
    return sitk_image_to_data(resampled_image)
54
55
56
def data_to_sitk_image(data, spacing=(1., 1., 1.)):
    """Convert a numpy array to a SimpleITK image with the given voxel spacing.

    A 3-D array is first rotated with ``np.rot90(data, 1, axes=(0, 2))``;
    ``sitk_image_to_data`` applies the inverse rotation. Arrays of any other
    rank are passed to SimpleITK unchanged.

    Args:
        data: numpy array.
        spacing: Voxel spacing, one value per image dimension.

    Returns:
        The ``sitk.Image``.
    """
    if len(data.shape) == 3:
        data = np.rot90(data, 1, axes=(0, 2))
    image = sitk.GetImageFromArray(data)
    # np.float was an alias of the builtin float and was removed in NumPy 1.24;
    # float64 is the equivalent dtype.
    image.SetSpacing(tuple(float(value) for value in np.asarray(spacing, dtype=np.float64)))
    return image
62
63
64
def sitk_image_to_data(image):
    """Convert a SimpleITK image to a numpy array.

    For 3-D results, the array from ``sitk.GetArrayFromImage`` is rotated with
    ``np.rot90(..., -1, axes=(0, 2))``. This is the inverse of the rotation in
    ``data_to_sitk_image``. Other ranks are returned as-is.

    Args:
        image: ``sitk.Image``.

    Returns:
        numpy array.
    """
    array = sitk.GetArrayFromImage(image)
    if len(array.shape) != 3:
        return array
    return np.rot90(array, -1, axes=(0, 2))