a b/load_dataset/utils.py
1
#utils.py
2
#Copyright (c) 2020 Rachel Lea Ballantyne Draelos
3
4
#MIT License
5
6
#Permission is hereby granted, free of charge, to any person obtaining a copy
7
#of this software and associated documentation files (the "Software"), to deal
8
#in the Software without restriction, including without limitation the rights
9
#to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10
#copies of the Software, and to permit persons to whom the Software is
11
#furnished to do so, subject to the following conditions:
12
13
#The above copyright notice and this permission notice shall be included in all
14
#copies or substantial portions of the Software.
15
16
#THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17
#IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18
#FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19
#AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20
#LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21
#OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22
#SOFTWARE
23
24
import torch
25
import numpy as np
26
27
"""CT volume preprocessing functions"""
28
29
#############################################
30
# Pixel Values (on torch Tensors for speed) #-----------------------------------
31
#############################################
32
def normalize(ctvol, lower_bound, upper_bound): #Done testing
    """Clip <ctvol> to [<lower_bound>, <upper_bound>] and rescale it linearly
    so that <lower_bound> maps to 0 and <upper_bound> maps to 1.

    <ctvol> is a torch Tensor. A new Tensor is returned; the input is not
    modified in place."""
    #min-max scaling, see
    #https://stats.stackexchange.com/questions/70801/how-to-normalize-data-to-0-1-range
    value_range = upper_bound - lower_bound
    clipped = ctvol.clamp(min=lower_bound, max=upper_bound)
    return (clipped - lower_bound) / value_range
38
39
def torchify_pixelnorm_pixelcenter(ctvol, pixel_bounds):
    """Convert numpy array <ctvol> to a float torch Tensor, clip and normalize
    it using <pixel_bounds>, then subtract 0.449 to center on the ImageNet
    mean. Used in 2019_10 dataset preparation

    <pixel_bounds> is indexable as [lower, upper], e.g. [-1000,200]
    Returns the centered torch Tensor."""
    lower_bound, upper_bound = pixel_bounds[0], pixel_bounds[1]

    #Do the arithmetic on a torch Tensor rather than a numpy array (the
    #original author notes elementwise math is faster on Tensors)
    as_tensor = torch.from_numpy(ctvol).float()

    #Clip to the pixel bounds and rescale to the range [0, 1]
    scaled = normalize(as_tensor, lower_bound, upper_bound)

    #Center on the ImageNet mean because the downstream feature extractor
    #is ImageNet pretrained
    return scaled - 0.449
54
55
###########
56
# Padding #---------------------------------------------------------------------
57
###########
58
def pad_slices(ctvol, max_slices): #Done testing
    """Pad axis 0 of <ctvol>, shape (slices, side, side), so that the result
    has shape (max_slices, side, side).

    The padding is split between the front and the back (the back gets the
    extra slice when the amount is odd) and is filled with the minimum value
    of <ctvol>. If no padding is needed <ctvol> is returned unchanged.
    Asserts that <ctvol> does not already exceed <max_slices>."""
    deficit = max_slices - ctvol.shape[0]
    assert (deficit >= 0), 'Image slices exceed max_slices by'+str(-1*deficit)
    if deficit <= 0:
        return ctvol
    front = deficit // 2
    back = deficit - front
    padded = np.pad(ctvol, pad_width=[(front, back), (0, 0), (0, 0)],
                    mode='constant', constant_values=np.amin(ctvol))
    assert padded.shape[0]==max_slices
    return padded
70
71
def pad_sides(ctvol, max_side_length): #Done testing
    """Pad axes 1 and 2 of <ctvol>, shape (slices, side, side), up to
    <max_side_length>.

    Each axis is handled independently: an axis that is already at least
    <max_side_length> long is left alone. Padding is split between both ends
    of the axis (the far end gets the extra element when the amount is odd)
    and is filled with the minimum value of <ctvol>."""
    sides_padded = 0
    for axis in (1, 2):
        deficit = max_side_length - ctvol.shape[axis]
        if deficit <= 0:
            continue
        front = deficit // 2
        #Only the current axis receives non-zero padding
        widths = [(0, 0), (0, 0), (0, 0)]
        widths[axis] = (front, deficit - front)
        ctvol = np.pad(ctvol, pad_width=widths, mode='constant',
                       constant_values=np.amin(ctvol))
        sides_padded += 1
    #When both sides were padded they must now both equal max_side_length.
    #(If a side was already too large it is not expected to match.)
    if sides_padded == 2:
        assert ctvol.shape[1]==ctvol.shape[2]==max_side_length
    return ctvol
94
95
def pad_volume(ctvol, max_slices, max_side_length):
    """Pad <ctvol> to a minimum size of
    [max_slices, max_side_length, max_side_length], e.g. [402, 308, 308]
    Used in 2019_10 dataset preparation

    <ctvol> is a 3D numpy array of shape (slices, side, side).
    Every axis is checked independently, so a volume whose axis 1 is already
    large enough still gets axis 2 padded when axis 2 is too small. Axes that
    already meet or exceed their target are left untouched (nothing is
    cropped here).

    Padding is split between both ends of an axis (the far end gets the extra
    element when the amount is odd) and is filled with the minimum value of
    <ctvol>. If no axis needs padding, <ctvol> itself is returned."""
    targets = (max_slices, max_side_length, max_side_length)
    pad_width = []
    needs_padding = False
    for axis, target in enumerate(targets):
        deficit = max(target - ctvol.shape[axis], 0)
        front = deficit // 2
        pad_width.append((front, deficit - front))
        if deficit > 0:
            needs_padding = True
    if not needs_padding:
        return ctvol
    #A single np.pad call covers all three axes, so only one copy is made
    return np.pad(ctvol, pad_width=pad_width, mode='constant',
                  constant_values=np.amin(ctvol))
104
105
###########################
106
# Reshaping to 3 Channels #-----------------------------------------------------
107
###########################
108
def sliceify(ctvol): #Done testing
    """Given a numpy array <ctvol> with shape [slices, square, square]
    reshape to 'RGB' [slices/3, 3, square, square], i.e. each group of 3
    consecutive slices becomes one 3-channel image.

    The number of slices must be divisible by 3, otherwise np.reshape raises
    a ValueError."""
    num_slices, height, width = ctvol.shape
    #The shape is passed positionally because the 'newshape' keyword of
    #np.reshape is deprecated in recent NumPy releases
    return np.reshape(ctvol, (num_slices // 3, 3, height, width))
112
113
def reshape_3_channels(ctvol):
    """Reshape grayscale <ctvol> of shape [slices, side, side] into a stack
    of 3-channel images via sliceify().
    Used in 2019_10 dataset preparation

    If the number of slices is not a multiple of 3, the last 1 or 2 slices
    are dropped first so that the remaining slices divide evenly."""
    leftover = ctvol.shape[0] % 3
    if leftover:
        #Discard the trailing slices that do not fill a complete group of 3
        ctvol = ctvol[:-leftover,:,:]
    return sliceify(ctvol)
124
125
##################################
126
# Cropping and Data Augmentation #----------------------------------------------
127
##################################
128
def crop_specified_axis(ctvol, max_dim, axis): #Done testing
    """Center-crop <ctvol> to at most <max_dim> elements along <axis>.

    <axis> may be any valid axis index of <ctvol>, including negative ones.
    If <ctvol> is already no longer than <max_dim> along <axis> it is
    returned unchanged. When the amount to remove is odd, the extra element
    is removed from the far end of the axis.
    The cropped result is produced by basic slicing of <ctvol>."""
    dim = ctvol.shape[axis]
    if dim <= max_dim:
        return ctvol
    start = (dim - max_dim) // 2
    #Build a full-slice index and narrow only the requested axis, so that
    #no axis value can fall through without a return value
    window = [slice(None)] * ctvol.ndim
    window[axis] = slice(start, start + max_dim)
    return ctvol[tuple(window)]
143
144
def single_crop_3d_fixed(ctvol, max_slices, max_side_length):
    """Center-crop a single 3D volume <ctvol> so that it is no larger than
    [max_slices, max_side_length, max_side_length].

    Axes that are already within their limit are left as they are."""
    limits = (max_slices, max_side_length, max_side_length)
    for axis, limit in enumerate(limits):
        ctvol = crop_specified_axis(ctvol, limit, axis)
    return ctvol
151
152
def single_crop_3d_augment(ctvol, max_slices, max_side_length):
    """Augmented version of single_crop_3d_fixed(): crop a single 3D volume
    to at most [max_slices, max_side_length, max_side_length] with a randomly
    shifted center, followed by a random flip and a random rotation.

    Returns a contiguous numpy array."""
    #Random padding on every face shifts where the center crop lands
    jittered = rand_pad(ctvol)
    cropped = single_crop_3d_fixed(jittered, max_slices, max_side_length)

    #Random flip first, then random rotation
    flipped = rand_flip(cropped)
    rotated = rand_rotate(flipped)

    #Flips/rotations can leave a non-contiguous array, which causes a
    #Pytorch error later on, so make it contiguous here
    return np.ascontiguousarray(rotated)
168
169
def rand_pad(ctvol):
    """Pad each of the 6 faces of 3D <ctvol> by an independent random amount
    of 0 to 14 pixels (the upper bound of 15 passed to np.random.randint is
    exclusive), filling with the minimum value of <ctvol>."""
    amounts = np.random.randint(low=0,high=15,size=(6))
    #Consecutive pairs are the (before, after) padding of axes 0, 1 and 2
    per_axis = [(amounts[2*axis], amounts[2*axis + 1]) for axis in range(3)]
    fill_value = np.amin(ctvol)
    return np.pad(ctvol, pad_width=per_axis, mode='constant',
                  constant_values=fill_value)
176
    
177
def rand_flip(ctvol):
    """With 50% probability, flip <ctvol> along one randomly chosen axis
    (0, 1, or 2); otherwise return <ctvol> unchanged."""
    coin = np.random.randint(low=0,high=100)
    if coin >= 50:
        return ctvol
    #The axis is drawn only when a flip actually happens
    axis_to_flip = np.random.randint(low=0,high=3)
    return np.flip(ctvol, axis=axis_to_flip)
183
184
def rand_rotate(ctvol):
    """With 50% probability, rotate <ctvol> in the plane of axes 1 and 2 by a
    random multiple of 90 degrees (0, 1, 2, or 3 quarter turns); otherwise
    return <ctvol> unchanged."""
    coin = np.random.randint(low=0,high=100)
    if coin >= 50:
        return ctvol
    #The number of quarter turns is drawn only when a rotation happens
    quarter_turns = np.random.randint(low=0,high=4)
    return np.rot90(ctvol, k=quarter_turns, axes=(1,2))
190
191
###########################################
192
# 2019_10 Dataset Preprocessing Sequences #-------------------------------------
193
###########################################
194
def prepare_ctvol_2019_10_dataset(ctvol, pixel_bounds, data_augment, num_channels,
                                  crop_type):
    """Full preprocessing sequence for one CT volume: pad, crop (with data
    augmentation if requested), optionally reshape to 3 channels, then cast
    to a torch tensor, clip and normalize the pixel values, and center on
    the ImageNet mean.

    <ctvol> is a 3D numpy array.
    <pixel_bounds> is a list of ints e.g. [-1000,200] Hounsfield units. Used for
        pixel value clipping and normalization.
    <data_augment> is True to employ data augmentation, and False otherwise
    <num_channels> is an int, either 3 to reshape the grayscale volume into
        a volume of 3-channel images, or 1 to leave it as is
    <crop_type>: must be 'single', meaning the volume is returned as one
        array.

    Padding and cropping use fixed targets of 402 slices and a side length
    of 420. Returns a torch tensor."""
    max_slices = 402
    max_side_length = 420
    assert num_channels == 3 or num_channels == 1
    assert crop_type == 'single'

    #Bring the volume up to the minimum size before cropping
    ctvol = pad_volume(ctvol, max_slices, max_side_length)

    if crop_type == 'single':
        #Pick the augmenting or the deterministic center crop
        crop_fn = single_crop_3d_augment if data_augment is True else single_crop_3d_fixed
        ctvol = crop_fn(ctvol, max_slices, max_side_length)
        #Group slices into 3-channel images if requested
        if num_channels == 3:
            ctvol = reshape_3_channels(ctvol)
        #Torch conversion plus pixel clipping, normalization and centering
        output = torchify_pixelnorm_pixelcenter(ctvol, pixel_bounds)

    return output