# utils/dataset.py -- extracted from a diff view (revisions 000000 .. 6fe801).
1
import SimpleITK as sitk
2
import os
3
import tempfile
4
import numpy as np
5
import matplotlib.pyplot as plt
6
from scipy.ndimage import binary_closing
7
from skimage import measure
8
9
def read_raw(
    binary_file_name,
    image_size,
    sitk_pixel_type,
    image_spacing=None,
    image_origin=None,
    big_endian=False,
):
    """
    Read a raw binary scalar image.

    A temporary MetaImage (.mhd) header describing the raw file is written,
    the image is loaded through ``sitk.ReadImage`` and the temporary header
    is deleted again, even when reading fails.

    Source: https://simpleitk.readthedocs.io/en/master/link_RawImageReading_docs.html

    Parameters
    ----------
    binary_file_name (str): Path to the raw, binary image file.
    image_size (tuple like): Size of image (e.g. [2048,2048])
    sitk_pixel_type (SimpleITK pixel type): Pixel type of data (e.g.
        sitk.sitkUInt16).
    image_spacing (tuple like): Optional image spacing, if none given assumed
        to be [1]*dim.
    image_origin (tuple like): Optional image origin, if none given assumed to
        be [0]*dim.
    big_endian (bool): Optional byte order indicator, if True big endian, else
        little endian.

    Returns
    -------
    SimpleITK image.

    Raises
    ------
    KeyError: If ``sitk_pixel_type`` is not one of the supported pixel types.
    Any exception raised by ``sitk.ReadImage`` is propagated unchanged.
    """

    pixel_dict = {
        sitk.sitkUInt8: "MET_UCHAR",
        sitk.sitkInt8: "MET_CHAR",
        sitk.sitkUInt16: "MET_USHORT",
        sitk.sitkInt16: "MET_SHORT",
        sitk.sitkUInt32: "MET_UINT",
        sitk.sitkInt32: "MET_INT",
        sitk.sitkUInt64: "MET_ULONG_LONG",
        sitk.sitkInt64: "MET_LONG_LONG",
        sitk.sitkFloat32: "MET_FLOAT",
        sitk.sitkFloat64: "MET_DOUBLE",
    }
    dim = len(image_size)

    def _values_or_default(values, default):
        # Explicit None/length test so array-like inputs do not trigger an
        # ambiguous truth-value check.
        if values is None or len(values) == 0:
            values = [default] * dim
        return " ".join(str(v) for v in values)

    # Row-major identity matrix; yields the same strings as the former
    # hard-coded 2D/3D/4D table and also works for other dimensions.
    identity = " ".join(
        "1" if row == col else "0" for row in range(dim) for col in range(dim)
    )

    header_lines = [
        "ObjectType = Image",
        f"NDims = {dim}",
        "DimSize = " + " ".join(str(v) for v in image_size),
        "ElementSpacing = " + _values_or_default(image_spacing, 1),
        # Every entry gets its own line terminator below. Previously the
        # newline was lost whenever an explicit origin was supplied, fusing
        # this entry with TransformMatrix.
        "Offset = " + _values_or_default(image_origin, 0),
        "TransformMatrix = " + identity,
        "ElementType = " + pixel_dict[sitk_pixel_type],
        "BinaryData = True",
        "BinaryDataByteOrderMSB = " + str(big_endian),
        # ElementDataFile must be the last entry in the header
        "ElementDataFile = " + os.path.abspath(binary_file_name),
    ]
    header = "".join(line + "\n" for line in header_lines).encode()

    # delete=False because on windows we can't open the file a second time
    # for ReadImage while it is still open here; it is removed manually below.
    with tempfile.NamedTemporaryFile(suffix=".mhd", delete=False) as fp:
        fp.write(header)
        header_path = fp.name

    try:
        return sitk.ReadImage(header_path)
    finally:
        # Clean up the temporary header even when ReadImage raises.
        os.remove(header_path)
101
102
def convert_raw_to_nifti(
        dataset_dir: list,
        output_dir: str,
        metadata: dict) -> None:
    '''
    Convert the dataset images from raw (.raw) to nifti (nii.gz) format and export it to a
    given output directory.

    NOTE: this is currently a placeholder; the function body performs no work and
    none of the arguments are read.

    Args:
        dataset_dir (list): list of paths to the dataset directories
        output_dir (str): path to the output directory
        metadata (dict): dictionary containing the metadata of the dataset

    Returns:
        None
    '''
    # TODO: conversion is not implemented yet.
    return None
120
121
def create_mask(volume, threshold=700):
    '''
    Create a binary mask from a given volume using a given threshold.

    Elements whose value is less than or equal to the threshold are marked
    with 1, all other elements with 0.

    Args:
        volume (numpy array or array-like): volume to be thresholded
        threshold (int): threshold to be used for the masking (inclusive)

    Returns:
        numpy array: integer mask of the same shape as the volume (1 where
        volume <= threshold, 0 elsewhere)
    '''
    # asarray lets plain nested lists be thresholded as well.
    return np.where(np.asarray(volume) <= threshold, 1, 0)
133
134
def label_regions(mask):
    '''
    Label connected components in a binary mask.

    Labelling is delegated to ``skimage.measure.label`` with ``connectivity=2``
    and value 0 treated as background.

    Args:
        mask (numpy array): Binary mask.

    Returns:
        tuple: A tuple containing labeled mask and the number of labels.
    '''
    return measure.label(mask, connectivity=2, return_num=True, background=0)
146
147
def get_largest_regions(labeled_mask, num_regions=2):
    '''
    Get the largest connected regions in a labeled mask.

    Args:
        labeled_mask (numpy array): Labeled mask.
        num_regions (int): Number of largest regions to retrieve.

    Returns:
        list: Region properties of at most ``num_regions`` regions, sorted by
        area from largest to smallest. Fewer entries are returned when the
        mask contains fewer regions.
    '''
    by_area = sorted(
        measure.regionprops(labeled_mask),
        key=lambda region: region.area,
        reverse=True,
    )
    # Slicing already clamps to the list length, no explicit min() needed.
    return by_area[:num_regions]
164
165
def create_masks(labeled_mask, regions):
    '''
    Create masks for specific regions in a labeled mask.

    Args:
        labeled_mask (numpy array): Labeled mask.
        regions (list): Region objects (anything exposing a ``label``
            attribute) for which masks need to be created.

    Returns:
        list: One boolean mask per region, in the same order as ``regions``;
        each mask is True where ``labeled_mask`` equals that region's label.
    '''
    return [labeled_mask == region.label for region in regions]
178
179
def fill_holes_and_erode(mask, structure=(7, 7, 5)):
    '''
    Fill holes in a binary mask using a morphological closing.

    The closing (a dilation followed by an erosion, performed by
    ``scipy.ndimage.binary_closing``) uses a box of ones of the given shape
    as structuring element. No additional erosion is applied.

    Args:
        mask (numpy array): Binary mask.
        structure (tuple): Shape of the all-ones structuring element; it needs
            one entry per dimension of ``mask``.

    Returns:
        numpy array: Boolean mask after the closing.
    '''
    structuring_element = np.ones(structure)
    return binary_closing(mask, structure=structuring_element)
194
195
def _select_lung_mask(labeled_slice, regions, create_masks,
                      axis_gap=30, area_gap=50, minor_axis_limit=100):
    '''Pick the lung part of one labeled 2D slice, dropping the region judged to be trachea.'''
    nothing = np.zeros_like(labeled_slice)
    count = len(regions)

    if count == 1:
        # A single region is kept only when it is clearly elongated; a
        # near-round single region is treated as trachea (first few slices).
        elongation = abs(regions[0].axis_major_length - regions[0].axis_minor_length)
        if elongation > axis_gap:
            return create_masks(labeled_slice, regions)[0]
        return nothing

    if count == 3:
        # Regions are sorted by area (largest first): the third one is taken
        # to be the trachea and dropped, the two largest are merged.
        masks = create_masks(labeled_slice, regions)
        return masks[0] + masks[1]

    if count == 2:
        masks = create_masks(labeled_slice, regions)
        # Both lungs touching (one big region) plus a small trachea region:
        # keep only the big region.
        second_is_trachea = (
            regions[0].area - regions[1].area > area_gap
            and regions[1].axis_minor_length < minor_axis_limit
        )
        if second_is_trachea:
            return masks[0]
        # Otherwise the two regions are taken to be the two lungs.
        return masks[0] + masks[1]

    # No region at all (or an unexpected count): nothing to keep.
    return nothing


def remove_trachea(largest_masks, get_largest_regions, create_masks):
    '''
    Remove the trachea from a set of largest masks with a shape (Slice, H, W).

    Every slice is labeled on its own, its (up to) three largest regions are
    collected and a per-slice heuristic based on region count, area and
    major/minor axis lengths decides which regions are kept.

    Args:
        largest_masks (numpy array): 3D array of largest masks.
        get_largest_regions (function): Function to get largest regions; called
            as ``get_largest_regions(labeled_slice, num_regions=3)``.
        create_masks (function): Function to create masks; called as
            ``create_masks(labeled_slice, regions)``.

    Returns:
        list: One 2D mask per slice with the trachea removed. Slices where
        nothing is kept are all-zero arrays with the dtype of the labeled
        slice; kept slices have whatever dtype ``create_masks`` produces.
    '''
    cleaned_slices = []
    for mask_slice in largest_masks:
        labeled_slice = label_regions(mask_slice)[0]
        regions = get_largest_regions(labeled_slice, num_regions=3)
        cleaned_slices.append(_select_lung_mask(labeled_slice, regions, create_masks))
    return cleaned_slices
240
241
def segment_lungs_and_remove_trachea(volume, threshold=700, structure=(7, 7, 5), fill_holes_before_trachea_removal=False):
    '''
    Segment lungs and remove trachea from a given 3D volume with shape (Slice, H, W). Note that this shape is a must for
    the internal functions to compute as expected.

    Args:
        volume (numpy array): 3D volume shape (slice, H, W).
        threshold (int): Threshold for creating the initial mask.
        structure (tuple): Shape of the structuring element used for the final closing.
        fill_holes_before_trachea_removal (bool): If True, the selected mask is closed with a
            structuring element twice the size of ``structure`` before the trachea is removed.

    Returns:
        initial_mask (numpy array): Initial mask created from the volume.
        labeled_mask (numpy array): Labeled mask.
        largest_masks (numpy array): Mask of the second-largest connected region.
        processed_mask_without_trachea (numpy array): 3D binary array (uint8) of masks with trachea removed.

    Raises:
        IndexError: If the thresholded volume contains fewer than two connected regions.
    '''
    # create a mask
    initial_mask = create_mask(volume, threshold=threshold)

    # Label connected components
    labeled_mask, _ = label_regions(initial_mask)

    # Only the second-largest region is used below, so two regions suffice.
    # NOTE(review): presumably the largest region is the air around the body
    # and the second largest the lungs + trachea -- confirm.
    largest_regions = get_largest_regions(labeled_mask, num_regions=2)
    if len(largest_regions) < 2:
        raise IndexError(
            "expected at least two connected regions below the threshold, "
            f"found {len(largest_regions)}"
        )

    # Build the mask for that single region only instead of for every region.
    largest_masks = create_masks(labeled_mask, [largest_regions[1]])[0]

    # fill holes of the largest mask
    if fill_holes_before_trachea_removal:
        doubled_structure = tuple(2 * x for x in structure)
        largest_masks = fill_holes_and_erode(largest_masks, structure=doubled_structure)

    # remove the trachea
    largest_masks_without_trachea = remove_trachea(largest_masks, get_largest_regions, create_masks)

    # Close the remaining holes in the trachea-free masks
    processed_mask_without_trachea = fill_holes_and_erode(largest_masks_without_trachea, structure=structure)

    return initial_mask, labeled_mask, largest_masks, processed_mask_without_trachea.astype(np.uint8)
281
282
def segment_body(image, threshold=700):
    '''
    Segment the body from a given 3D volume with shape (Slice, H, W). Note that this shape is a must for
    the internal functions to compute as expected.

    The largest connected region of the thresholded volume is blanked out;
    every voxel outside that region keeps its original intensity.

    Args:
        image (numpy array): 3D volume shape (slice, H, W).
        threshold (int): Threshold for creating the initial mask.

    Returns:
        mask (numpy array): Initial mask created from the volume.
        labeled_mask (numpy array): Labeled mask.
        largest_masks (numpy array): int8 mask (0/1) of the largest connected region.
        body_segmented (numpy array): Copy of the image with the largest region set to zero.

    Raises:
        IndexError: If the thresholded volume contains no connected region.
    '''
    mask = create_mask(image, threshold=threshold)
    labeled_mask, _ = label_regions(mask)

    # Only the single largest region is needed here.
    largest_regions = get_largest_regions(labeled_mask, num_regions=1)
    if not largest_regions:
        raise IndexError("no connected region found below the threshold")

    # to have zeros and ones instead of binary false and true
    largest_masks = create_masks(labeled_mask, largest_regions)[0].astype(np.int8)

    # Copy over everything that lies outside the largest region.
    outside_largest = largest_masks == 0
    body_segmented = np.zeros_like(image)
    body_segmented[outside_largest] = image[outside_largest]

    return mask, labeled_mask, largest_masks, body_segmented
310
311
312
def display_two_volumes(volume1, volume2, title1, title2, slice=70):
    '''
    Display the same slice of two volumes side by side.

    Args:
        volume1 (numpy array): first volume to be displayed
        volume2 (numpy array): second volume to be displayed
        title1 (str): title of the first volume
        title2 (str): title of the second volume
        slice (int): index along the first axis of the slice to be displayed
            (the parameter name shadows the builtin but is kept so keyword
            callers keep working)

    Returns:
        None
    '''
    plt.figure(figsize=(9, 6))

    panels = ((volume1, title1), (volume2, title2))
    for position, (volume, title) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.imshow(volume[slice, :, :], cmap='gray')
        plt.title(title)
        plt.axis('off')

    plt.show()
340
341
def display_volumes(*volumes, **titles_and_slices):
    '''
    Display multiple volumes side by side.

    Args:
        volumes (tuple of numpy arrays): volumes to be displayed
        titles_and_slices (dict): optional keyword arguments ``title<i>`` and
            ``slice<i>`` (1-based, e.g. ``title1``, ``slice2``) giving the title
            and the slice index for the i-th volume. Missing titles default to
            ``'Title <i>'`` and missing slices to 70.

    Returns:
        None

    Raises:
        ValueError: If no volume is given.
    '''
    num_volumes = len(volumes)
    if num_volumes == 0:
        # A zero-width figure cannot be drawn; fail with a clear message.
        raise ValueError("display_volumes needs at least one volume")

    plt.figure(figsize=(6 * num_volumes, 6))

    for i, volume in enumerate(volumes, start=1):
        title = titles_and_slices.get(f'title{i}', f'Title {i}')
        slice_val = titles_and_slices.get(f'slice{i}', 70)

        plt.subplot(1, num_volumes, i)
        plt.imshow(volume[slice_val, :, :], cmap='gray')
        plt.title(title)
        plt.axis('off')

    plt.show()
366
367
368
def min_max_normalization(image, mask=None, max_value=None):
    '''
    Perform min-max normalization on a given image.

    The minimum and maximum are taken over the whole image, or only over the
    pixels where ``mask == 1`` when a mask is given. The image is then scaled
    so that this range maps onto ``[0, max_value]``; values outside the range
    are clipped.

    Args:
        image ('np.array'): Input image to normalize (array-likes are converted).
        mask ('np.array'): Optional mask; pixels equal to 1 define the range.
        max_value ('float'): Maximum value for normalization. Defaults to the
            largest value of the image dtype for integer images and to 1.0
            for any other dtype.

    Returns:
        normalized_image ('np.array'): Min-max normalized image with the same
        dtype as the input. An image whose selected range is constant is
        returned as all zeros.

    Raises:
        ValueError: If a mask is given but selects no pixel.
    '''
    # Convert first so that .dtype exists even for list input.
    image = np.asarray(image)

    if max_value is None:
        if np.issubdtype(image.dtype, np.integer):
            max_value = np.iinfo(image.dtype).max
        else:
            # np.iinfo only supports integer dtypes.
            max_value = 1.0
        print(f"The maximum value for this volume {image.dtype} is: {max_value}")

    print("Using mask for normalization" if mask is not None else "Not using mask for normalization")

    # Pixels that define the normalization range
    selected = image if mask is None else image[np.asarray(mask) == 1]
    if selected.size == 0:
        raise ValueError("mask selects no pixels; cannot compute the normalization range")

    min_value = float(np.min(selected))
    max_actual = float(np.max(selected))
    value_range = max_actual - min_value

    if value_range == 0:
        # Constant input: avoid dividing by zero.
        return np.zeros_like(image)

    # Work in float so that pixels below min_value become negative (and are
    # clipped to 0) instead of wrapping around in unsigned integer dtypes.
    normalized_image = (image.astype(np.float64) - min_value) / value_range * max_value
    normalized_image = np.clip(normalized_image, 0, max_value)

    return normalized_image.astype(image.dtype)