[6fe801]: / utils / dataset.py

Download this file

399 lines (318 with data), 14.4 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
import SimpleITK as sitk
import os
import tempfile
import numpy as np
import matplotlib.pyplot as plt
from scipy.ndimage import binary_closing
from skimage import measure
def read_raw(
    binary_file_name,
    image_size,
    sitk_pixel_type,
    image_spacing=None,
    image_origin=None,
    big_endian=False,
):
    """
    Read a raw binary scalar image by describing it with a temporary
    MetaImage (.mhd) header and loading that header with SimpleITK.

    Source: https://simpleitk.readthedocs.io/en/master/link_RawImageReading_docs.html

    Parameters
    ----------
    binary_file_name (str): Path to the raw, binary image file.
    image_size (tuple like): Size of image (e.g. [2048,2048]). Its length is
        the image dimension; only 2, 3 and 4 dimensions are supported.
    sitk_pixel_type (SimpleITK pixel type): Pixel type of data (e.g.
        sitk.sitkUInt16).
    image_spacing (tuple like): Optional image spacing; if None or empty it
        is assumed to be [1]*dim.
    image_origin (tuple like): Optional image origin; if None or empty it is
        assumed to be [0]*dim.
    big_endian (bool): Optional byte order indicator, if True big endian, else
        little endian.

    Returns
    -------
    SimpleITK image.

    Raises
    ------
    ValueError: if the image dimension is not 2, 3 or 4.
    KeyError: if sitk_pixel_type is not one of the supported scalar types.
    Errors raised by sitk.ReadImage are propagated; the temporary header is
    removed in every case.
    """
    pixel_dict = {
        sitk.sitkUInt8: "MET_UCHAR",
        sitk.sitkInt8: "MET_CHAR",
        sitk.sitkUInt16: "MET_USHORT",
        sitk.sitkInt16: "MET_SHORT",
        sitk.sitkUInt32: "MET_UINT",
        sitk.sitkInt32: "MET_INT",
        sitk.sitkUInt64: "MET_ULONG_LONG",
        sitk.sitkInt64: "MET_LONG_LONG",
        sitk.sitkFloat32: "MET_FLOAT",
        sitk.sitkFloat64: "MET_DOUBLE",
    }
    # Identity direction matrices for 2D, 3D and 4D images (indexed by dim - 2).
    direction_cosine = [
        "1 0 0 1",
        "1 0 0 0 1 0 0 0 1",
        "1 0 0 0 0 1 0 0 0 0 1 0 0 0 0 1",
    ]
    dim = len(image_size)
    if not 2 <= dim <= 4:
        raise ValueError(
            f"Only 2D, 3D and 4D images are supported, got {dim} dimension(s)."
        )

    def _join(values):
        return " ".join(str(v) for v in values)

    # Fall back to the defaults for None or an empty sequence. An explicit
    # length check (instead of plain truthiness) also accepts numpy arrays,
    # whose truth value is ambiguous.
    if image_spacing is None or len(image_spacing) == 0:
        image_spacing = [1] * dim
    if image_origin is None or len(image_origin) == 0:
        image_origin = [0] * dim

    # Every entry ends with "\n" so each key starts on its own header line,
    # whether or not a spacing/origin was supplied.
    header = [
        "ObjectType = Image\n",
        f"NDims = {dim}\n",
        f"DimSize = {_join(image_size)}\n",
        f"ElementSpacing = {_join(image_spacing)}\n",
        f"Offset = {_join(image_origin)}\n",
        f"TransformMatrix = {direction_cosine[dim - 2]}\n",
        f"ElementType = {pixel_dict[sitk_pixel_type]}\n",
        "BinaryData = True\n",
        f"BinaryDataByteOrderMSB = {big_endian}\n",
        # ElementDataFile must be the last entry in the header
        f"ElementDataFile = {os.path.abspath(binary_file_name)}\n",
    ]

    # delete=False plus manual removal: on Windows the file cannot be opened
    # a second time (by ReadImage) while the tempfile still holds it open.
    fp = tempfile.NamedTemporaryFile(suffix=".mhd", delete=False)
    try:
        with fp:  # closes the handle before ReadImage reopens the path
            fp.write("".join(header).encode())
        return sitk.ReadImage(fp.name)
    finally:
        # Remove the temporary header even if writing or reading failed.
        os.remove(fp.name)
def convert_raw_to_nifti(dataset_dir: list, output_dir: str, metadata: dict):
    """
    Convert dataset images from raw (.raw) to NIfTI (.nii.gz) format and
    write them to ``output_dir``.

    NOTE(review): not implemented yet -- the body is a placeholder, so calling
    this function currently does nothing and returns None.

    Args:
        dataset_dir (list): paths to the dataset directories.
        output_dir (str): destination directory for the converted files.
        metadata (dict): dataset metadata (expected keys are not defined
            anywhere in this function -- TODO document once implemented).

    Returns:
        None
    """
    return None
def create_mask(volume, threshold=700):
    """
    Binarize ``volume`` with an inclusive upper threshold.

    Args:
        volume (numpy array): intensity volume to threshold.
        threshold (int): voxels with ``volume <= threshold`` are marked.

    Returns:
        numpy array: integer array shaped like ``volume`` containing 1 where
        the voxel is at or below ``threshold`` and 0 elsewhere.
    """
    at_or_below = volume <= threshold
    return np.where(at_or_below, 1, 0)
def label_regions(mask):
    """
    Label the connected components of a binary mask.

    Uses ``skimage.measure.label`` with ``connectivity=2`` (at most two
    orthogonal hops between neighbours: 8-connectivity in 2D, 18-connectivity
    in 3D) and 0 as the background value.

    Args:
        mask (numpy array): binary mask.

    Returns:
        tuple: ``(labeled_mask, num_labels)``.
    """
    return measure.label(mask, connectivity=2, return_num=True, background=0)
def get_largest_regions(labeled_mask, num_regions=2):
    """
    Return the region properties of the largest labeled regions.

    Args:
        labeled_mask (numpy array): labeled (integer) mask.
        num_regions (int): how many regions to keep; fewer are returned when
            the mask contains fewer labeled regions.

    Returns:
        list: ``skimage.measure.regionprops`` entries sorted by ``area``,
        largest first (equal areas keep regionprops' original order).
    """
    by_area_desc = sorted(
        measure.regionprops(labeled_mask),
        key=lambda props: props.area,
        reverse=True,
    )
    keep = min(num_regions, len(by_area_desc))
    return by_area_desc[:keep]
def create_masks(labeled_mask, regions):
    """
    Turn each region's label into a boolean mask.

    Args:
        labeled_mask (numpy array): labeled (integer) mask.
        regions (list): objects exposing a ``label`` attribute, e.g. the
            output of ``get_largest_regions``.

    Returns:
        list: one boolean array per region, True where
        ``labeled_mask == region.label``, in the same order as ``regions``.
    """
    region_masks = []
    for props in regions:
        region_masks.append(labeled_mask == props.label)
    return region_masks
def fill_holes_and_erode(mask, structure=(7, 7, 5)):
    """
    Apply a binary closing (dilation followed by erosion) to a mask.

    This closes holes and gaps smaller than the structuring element.
    NOTE: scipy's default ``border_value=0`` is used, so foreground touching
    the array edge can be eroded away.

    Args:
        mask (numpy array): binary mask.
        structure (tuple): shape of the all-ones structuring element; it
            should have as many entries as ``mask`` has dimensions.

    Returns:
        numpy array: boolean mask after the closing.
    """
    footprint = np.ones(structure)
    closed_mask = binary_closing(mask, structure=footprint)
    return closed_mask
def _lung_mask_for_slice(labeled_slice, regions, make_masks):
    '''
    Select the lung pixels of one labeled 2D slice from its largest regions.

    Args:
        labeled_slice (numpy array): 2D label image of the slice.
        regions (list): up to three region properties of the slice, sorted by
            area (largest first), as returned by ``get_largest_regions``.
        make_masks (function): called as ``make_masks(labeled_slice, regions)``;
            must return one mask per region, in order.

    Returns:
        numpy array: the kept region(s) combined, or an all-zero array shaped
        (and typed) like ``labeled_slice`` when the slice is discarded.
    '''
    count = len(regions)
    if count == 1:
        only = regions[0]
        # A single elongated region is kept. A roundish one (small major/minor
        # axis difference) is treated as trachea; this happens in the very
        # first slices.
        if abs(only.axis_major_length - only.axis_minor_length) > 30:
            return make_masks(labeled_slice, regions)[0]
        return np.zeros_like(labeled_slice)
    if count == 3:
        # Regions are sorted by area, so the third (smallest) is the trachea.
        # `+` on boolean masks is a logical OR.
        masks = make_masks(labeled_slice, regions)
        return masks[0] + masks[1]
    if count == 2:
        masks = make_masks(labeled_slice, regions)
        # Both lungs touching as one region plus the trachea as another: the
        # first region is more than 50 px larger and the second's minor axis
        # is below 100 px -> keep only the first.
        if (regions[0].area - regions[1].area > 50
                and regions[1].axis_minor_length < 100):
            return masks[0]
        # Otherwise the two regions are taken to be the two lungs.
        return masks[0] + masks[1]
    # No regions in this slice.
    return np.zeros_like(labeled_slice)


def remove_trachea(largest_masks, get_largest_regions, create_masks):
    '''
    Remove the trachea from a stack of masks with shape (Slice, H, W).

    Each slice is relabeled independently (``label_regions``), its three
    largest regions are retrieved, and ``_lung_mask_for_slice`` decides which
    of them to keep.

    Args:
        largest_masks (numpy array): 3D binary array shaped (Slice, H, W).
        get_largest_regions (function): called as
            ``get_largest_regions(labeled_slice, num_regions=3)``.
        create_masks (function): called as ``create_masks(labeled_slice, regions)``.

    Returns:
        list: one 2D array per slice -- a boolean mask of the kept regions, or
        an all-zero array of the label dtype when the slice is discarded.
    '''
    lung_masks = []
    # Explicit three-axis indexing enforces the (Slice, H, W) input contract.
    for idx in range(largest_masks.shape[0]):
        labeled_slice = label_regions(largest_masks[idx, :, :])[0]
        regions = get_largest_regions(labeled_slice, num_regions=3)
        lung_masks.append(
            _lung_mask_for_slice(labeled_slice, regions, create_masks)
        )
    return lung_masks
def segment_lungs_and_remove_trachea(
    volume,
    threshold=700,
    structure=(7, 7, 5),
    fill_holes_before_trachea_removal=False,
):
    '''
    Segment the lungs of a (Slice, H, W) volume and drop the trachea.

    Pipeline: threshold -> 3D connected components -> keep the second largest
    component -> optional closing -> per-slice trachea removal -> closing.

    Args:
        volume (numpy array): 3D volume shaped (Slice, H, W). The helpers
            index slices along axis 0, so this layout is required.
        threshold (int): voxels <= threshold go into the initial mask.
        structure (tuple): shape of the all-ones structuring element used for
            the final binary closing.
        fill_holes_before_trachea_removal (bool): if True, first close the
            selected component with a structuring element twice as large.

    Returns:
        initial_mask (numpy array): thresholded 0/1 mask.
        labeled_mask (numpy array): connected-component labels of initial_mask.
        largest_masks (numpy array): boolean mask of the second largest
            component (after the optional closing).
        processed_mask_without_trachea (numpy array): final uint8 0/1 mask.
    '''
    initial_mask = create_mask(volume, threshold=threshold)
    labeled_mask = label_regions(initial_mask)[0]
    top_regions = get_largest_regions(labeled_mask, num_regions=3)
    # NOTE(review): index 1 picks the *second* largest component; presumably
    # the largest is the low-intensity air around the body -- confirm. This
    # raises IndexError if fewer than two components are found.
    largest_masks = create_masks(labeled_mask, top_regions)[1]
    if fill_holes_before_trachea_removal:
        doubled = tuple(2 * size for size in structure)
        largest_masks = fill_holes_and_erode(largest_masks, structure=doubled)
    without_trachea = remove_trachea(largest_masks, get_largest_regions, create_masks)
    # Close the per-slice masks to smooth the result and fill small holes.
    closed = fill_holes_and_erode(without_trachea, structure=structure)
    return initial_mask, labeled_mask, largest_masks, closed.astype(np.uint8)
def segment_body(image, threshold=700):
    '''
    Zero out the largest low-intensity component of a (Slice, H, W) volume.

    NOTE(review): the largest component of the ``<= threshold`` mask is
    presumably the air surrounding the body, so what remains is the body --
    confirm against the data.

    Args:
        image (numpy array): 3D volume shaped (Slice, H, W).
        threshold (int): voxels <= threshold go into the initial mask.

    Returns:
        mask (numpy array): thresholded 0/1 mask.
        labeled_mask (numpy array): connected-component labels of mask.
        largest_masks (numpy array): int8 0/1 mask of the largest component.
        body_segmented (numpy array): array with image's dtype equal to image
            outside the largest component and 0 inside it.
    '''
    mask = create_mask(image, threshold=threshold)
    labeled_mask = label_regions(mask)[0]
    top_regions = get_largest_regions(labeled_mask, num_regions=3)
    # int8 so the mask holds 0/1 instead of False/True
    largest_masks = create_masks(labeled_mask, top_regions)[0].astype(np.int8)
    keep = largest_masks == 0
    body_segmented = np.zeros_like(image)
    body_segmented[keep] = image[keep]
    return mask, labeled_mask, largest_masks, body_segmented
def display_two_volumes(volume1, volume2, title1, title2, slice=70):
    '''
    Show the same axial slice of two volumes side by side in grayscale.

    Args:
        volume1 (numpy array): left volume, indexed as volume1[slice, :, :].
        volume2 (numpy array): right volume, indexed as volume2[slice, :, :].
        title1 (str): caption of the left panel.
        title2 (str): caption of the right panel.
        slice (int): index along axis 0 to display in both panels.

    Returns:
        None
    '''
    panels = ((volume1, title1), (volume2, title2))
    plt.figure(figsize=(9, 6))
    for position, (vol, caption) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.imshow(vol[slice, :, :], cmap='gray')
        plt.title(caption)
        plt.axis('off')
    plt.show()
def display_volumes(*volumes, **titles_and_slices):
    '''
    Show one axial slice of each volume in a single row of grayscale panels.

    Args:
        *volumes (numpy arrays): volumes indexed as volume[slice, :, :].
        **titles_and_slices: optional per-panel settings keyed by 1-based
            position: ``title<i>`` (default ``'Title <i>'``) and ``slice<i>``
            (default 70).

    Returns:
        None
    '''
    def _panel_settings(position):
        caption = titles_and_slices.get(f'title{position}', f'Title {position}')
        slice_index = titles_and_slices.get(f'slice{position}', 70)
        return caption, slice_index

    panel_count = len(volumes)
    plt.figure(figsize=(6 * panel_count, 6))
    for position, volume in enumerate(volumes, start=1):
        caption, slice_index = _panel_settings(position)
        plt.subplot(1, panel_count, position)
        plt.imshow(volume[slice_index, :, :], cmap='gray')
        plt.title(caption)
        plt.axis('off')
    plt.show()
def min_max_normalization(image, mask = None, max_value=None):
    '''
    Linearly rescale an image so that its [min, max] maps to [0, max_value].

    The min/max are taken from the voxels where ``mask == 1`` when a mask is
    given, otherwise from the whole image. Every output value is clipped to
    [0, max_value], so voxels outside the masked range saturate at 0 or
    max_value.

    Args:
        image ('np.array'): input image (any array-like is accepted).
        mask ('np.array'): optional mask; voxels equal to 1 define the range.
        max_value ('float'): upper end of the output range. Defaults to the
            dtype maximum for integer images and to 1.0 for floating images.

    Returns:
        normalized_image ('np.array'): rescaled image cast back to the input
        dtype (fractional parts are truncated for integer dtypes). If the
        reference region is constant, the result is all zeros.
    '''
    # Convert first: the dtype is needed below and plain lists have none.
    image = np.asarray(image)
    if max_value is None:
        if np.issubdtype(image.dtype, np.floating):
            # np.iinfo rejects float dtypes; normalize floats to [0, 1].
            max_value = 1.0
        else:
            max_value = np.iinfo(image.dtype).max
    print(f"The maximum value for this volume {image.dtype} is: {max_value}")
    print("Using mask for normalization" if mask is not None else "Not using mask for normalization")

    reference = image if mask is None else image[np.asarray(mask) == 1]
    # Compute in floating point: subtracting in the input dtype wraps around
    # for unsigned images (voxels outside the mask may lie below the masked
    # minimum) and the range itself can overflow for signed dtypes.
    work_dtype = np.result_type(image.dtype, np.float64)
    min_value = work_dtype.type(reference.min())
    max_actual = work_dtype.type(reference.max())
    value_range = max_actual - min_value

    shifted = image.astype(work_dtype) - min_value
    if value_range != 0:
        normalized_image = shifted / value_range * max_value
    else:
        # Constant reference region: 0/0 would produce NaN; map to 0 instead.
        normalized_image = np.zeros_like(shifted)
    normalized_image = np.clip(normalized_image, 0, max_value)
    return normalized_image.astype(image.dtype)