Diff of /utils/dataset.py [000000] .. [6fe801]

Switch to side-by-side view

--- a
+++ b/utils/dataset.py
@@ -0,0 +1,398 @@
+import SimpleITK as sitk
+import os
+import tempfile
+import numpy as np
+import matplotlib.pyplot as plt
+from scipy.ndimage import binary_closing
+from skimage import measure
+
+def read_raw(
+    binary_file_name,
+    image_size,
+    sitk_pixel_type,
+    image_spacing=None,
+    image_origin=None,
+    big_endian=False,
+):
+    """
+    Read a raw binary scalar image.
+
+    Source: https://simpleitk.readthedocs.io/en/master/link_RawImageReading_docs.html
+
+    Parameters
+    ----------
+    binary_file_name (str): Raw, binary image file content.
+    image_size (tuple like): Size of image (e.g. [2048,2048])
+    sitk_pixel_type (SimpleITK pixel type: Pixel type of data (e.g.
+        sitk.sitkUInt16).
+    image_spacing (tuple like): Optional image spacing, if none given assumed
+        to be [1]*dim.
+    image_origin (tuple like): Optional image origin, if none given assumed to
+        be [0]*dim.
+    big_endian (bool): Optional byte order indicator, if True big endian, else
+        little endian.
+
+    Returns
+    -------
+    SimpleITK image or None if fails.
+    """
+
+    pixel_dict = {
+        sitk.sitkUInt8: "MET_UCHAR",
+        sitk.sitkInt8: "MET_CHAR",
+        sitk.sitkUInt16: "MET_USHORT",
+        sitk.sitkInt16: "MET_SHORT",
+        sitk.sitkUInt32: "MET_UINT",
+        sitk.sitkInt32: "MET_INT",
+        sitk.sitkUInt64: "MET_ULONG_LONG",
+        sitk.sitkInt64: "MET_LONG_LONG",
+        sitk.sitkFloat32: "MET_FLOAT",
+        sitk.sitkFloat64: "MET_DOUBLE",
+    }
+    direction_cosine = [
+        "1 0 0 1",
+        "1 0 0 0 1 0 0 0 1",
+        "1 0 0 0 0 1 0 0 0 0 1 0 0 0 0 1",
+    ]
+    dim = len(image_size)
+    header = [
+        "ObjectType = Image\n".encode(),
+        (f"NDims = {dim}\n").encode(),
+        (
+            "DimSize = " + " ".join([str(v) for v in image_size]) + "\n"
+        ).encode(),
+        (
+            "ElementSpacing = "
+            + (
+                " ".join([str(v) for v in image_spacing])
+                if image_spacing
+                else " ".join(["1"] * dim)
+            )
+            + "\n"
+        ).encode(),
+        (
+            "Offset = "
+            + (
+                " ".join([str(v) for v in image_origin])
+                if image_origin
+                else " ".join(["0"] * dim) + "\n"
+            )
+        ).encode(),
+        ("TransformMatrix = " + direction_cosine[dim - 2] + "\n").encode(),
+        ("ElementType = " + pixel_dict[sitk_pixel_type] + "\n").encode(),
+        "BinaryData = True\n".encode(),
+        ("BinaryDataByteOrderMSB = " + str(big_endian) + "\n").encode(),
+        # ElementDataFile must be the last entry in the header
+        (
+            "ElementDataFile = " + os.path.abspath(binary_file_name) + "\n"
+        ).encode(),
+    ]
+    fp = tempfile.NamedTemporaryFile(suffix=".mhd", delete=False)
+
+    # print(header)
+
+    # Not using the tempfile with a context manager and auto-delete
+    # because on windows we can't open the file a second time for ReadImage.
+    fp.writelines(header)
+    fp.close()
+    img = sitk.ReadImage(fp.name)
+    os.remove(fp.name)
+    return img
+
+def convert_raw_to_nifti(
+        dataset_dir: list, 
+        output_dir: str,
+        metadata: dict):
+    '''
+    Convert the dataset images from raw (.raw) to nifti (nii.gz) format and export it to a 
+    given output directory. 
+
+    Args:
+        dataset_dir (list): list of paths to the dataset directories
+        output_dir (str): path to the output directory
+        metadata (dict): dictionary containing the metadata of the dataset
+
+    Returns:
+        None
+    '''
+
+    pass
+
+def create_mask(volume, threshold = 700):
+    '''
+    Create a mask from a given volume using a given threshold.
+
+    Args:
+        volume (numpy array): volume to be masked
+        threshold (int): threshold to be used for the masking
+
+    Returns:
+        numpy array: masked volume
+    '''
+    return np.where(volume <= threshold, 1, 0)
+
+def label_regions(mask):
+    '''
+    Label connected components in a binary mask.
+
+    Args:
+        mask (numpy array): Binary mask.
+
+    Returns:
+        tuple: A tuple containing labeled mask and the number of labels.
+    '''
+    labeled_mask, num_labels = measure.label(mask, connectivity=2, return_num=True, background=0)
+    return labeled_mask, num_labels
+
+def get_largest_regions(labeled_mask, num_regions=2):
+    '''
+    Get the largest connected regions in a labeled mask.
+
+    Args:
+        labeled_mask (numpy array): Labeled mask.
+        num_regions (int): Number of largest regions to retrieve.
+
+    Returns:
+        list: List of region properties for the largest regions.
+    '''
+    regions = measure.regionprops(labeled_mask)
+    regions.sort(key=lambda x: x.area, reverse=True)
+    regions = regions[:min(num_regions, len(regions))]
+    # print([regions[i].axis_major_length for i in range(len(regions))])
+    # print([regions[i].axis_minor_length for i in range(len(regions))])
+    return regions
+
+def create_masks(labeled_mask, regions):
+    '''
+    Create masks for specific regions in a labeled mask.
+
+    Args:
+        labeled_mask (numpy array): Labeled mask.
+        regions (list): List of region properties for which masks need to be created.
+
+    Returns:
+        list: List of masks corresponding to the specified regions.
+    '''
+    masks = [labeled_mask == region.label for region in regions]
+    return masks
+
+def fill_holes_and_erode(mask, structure=(7, 7, 5)):
+    '''
+    Fill holes in a binary mask and perform erosion.
+
+    Args:
+        mask (numpy array): Binary mask.
+        dilation_structure (tuple): Dilation structure for binary dilation.
+        erosion_structure (tuple): Erosion structure for binary erosion.
+
+    Returns:
+        numpy array: Processed mask after filling holes and erosion.
+    '''
+    processed_mask = binary_closing(mask, structure=np.ones(structure))
+
+    return processed_mask
+
+def remove_trachea(largest_masks, get_largest_regions, create_masks):
+    '''
+    Remove the trachea from a set of largest masks with a shape (Slice, H, W).
+
+    Args:
+        largest_masks (numpy array): 3D array of largest masks.
+        get_largest_regions (function): Function to get largest regions.
+        create_masks (function): Function to create masks.
+
+    Returns:
+        numpy array: 3D array of masks with trachea removed.
+    '''
+    # Find bounding boxes for each region in the 3D mask
+    labeled_mask_slices = np.array([label_regions(largest_masks[idx, :, :])[0] for idx in range(largest_masks.shape[0])])
+    # labeled_mask_slices = np.transpose(labeled_mask_slices, (1, 2, 0))
+
+    largest_regions_slices = [
+        get_largest_regions(labeled_mask_slices[idx, :, :], num_regions=3)
+        for idx in range(labeled_mask_slices.shape[0])
+    ]
+
+    largest_regions_masks = [
+        # we filter the trachea by checking the difference between the major and minor axis length when there is only 1 region
+        create_masks(labeled_mask_slices[idx, :, :], region)[0] if (len(region) == 1 and (abs(region[0].axis_major_length - region[0].axis_minor_length) > 30))
+
+        # this handles the very first few slices with trachea that has a very small difference between the major and minor axis length
+        else np.zeros_like(labeled_mask_slices[idx, :, :]) if (len(region) == 1 and (abs(region[0].axis_major_length - region[0].axis_minor_length) < 30)) 
+        
+        # remove the trachea if there are 3 regions, it will be the 3rd region as we sort by area (highest to lowest)
+        else create_masks(labeled_mask_slices[idx, :, :], region)[0] + create_masks(labeled_mask_slices[idx, :, :], region)[1] if len(region) == 3 
+
+        # when there are only 2 regions, we check the difference in the area (area of the first region has to be atleast 50 more than the second region) to indicate that it is a lung not a trachea
+        # also check if the minor axis of the second region (trachea) is less than 100
+        # this condition happens when both lungs are touching each other as a region, and trachea as another region
+        else create_masks(labeled_mask_slices[idx, :, :], region)[0] if len(region) == 2 and (getattr(region[0], 'area') - getattr(region[1], 'area') > 50) and (region[1].axis_minor_length < 100) 
+
+        # when there are only 2 regions, we combine them. This is after the previous condition is met (when only 2 lungs are detected)
+        else create_masks(labeled_mask_slices[idx, :, :], region)[0] + create_masks(labeled_mask_slices[idx, :, :], region)[1] if len(region) == 2 
+
+        else np.zeros_like(labeled_mask_slices[idx, :, :])
+        for idx, region in enumerate(largest_regions_slices)
+    ]
+    # largest_regions_masks = np.transpose(largest_regions_masks, (1, 2, 0))
+
+    return largest_regions_masks
+
+def segment_lungs_and_remove_trachea(volume, threshold=700, structure=(7, 7, 5), fill_holes_before_trachea_removal=False):
+    '''
+    Segment lungs and remove trachea from a given 3D volume with shape (Slice, H, W). Note that this shape is a must for 
+    the internal functions to compute as expected.
+
+    Args:
+        volume (numpy array): 3D volume shape (slice, H, W).
+        threshold (int): Threshold for creating the initial mask.
+        dilation_structure (tuple): Dilation structure for binary dilation.
+        erosion_structure (tuple): Erosion structure for binary erosion.
+
+    Returns:
+        initial_mask (numpy array): Initial mask created from the volume.
+        labeled_mask (numpy array): Labeled mask.
+        largest_masks (numpy array): 3D array of largest masks.
+        processed_mask_without_trachea (numpy array): 3D binary array of masks with trachea removed.
+    '''
+    # create a mask
+    initial_mask = create_mask(volume, threshold=threshold)
+
+    # Label connected components
+    labeled_mask, _ = label_regions(initial_mask)
+
+    # Get the largest three regions (two lungs and trachea)
+    largest_regions = get_largest_regions(labeled_mask, num_regions=3)
+
+    # Create masks for the largest three regions
+    largest_masks = create_masks(labeled_mask, largest_regions)[1]
+
+    # fill holes of the largest mask
+    if fill_holes_before_trachea_removal:
+        largest_masks = fill_holes_and_erode(largest_masks, structure=tuple([2*x for x in structure]))
+
+    # remove the trachea
+    largest_masks_without_trachea = remove_trachea(largest_masks, get_largest_regions, create_masks)
+
+    # Exclude the trachea by subtracting it from the processed mask
+    processed_mask_without_trachea = fill_holes_and_erode(largest_masks_without_trachea, structure=structure)
+
+    return initial_mask, labeled_mask, largest_masks, processed_mask_without_trachea.astype(np.uint8)
+
+def segment_body(image, threshold=700):
+    '''
+    Segment the body from a given 3D volume with shape (Slice, H, W). Note that this shape is a must for
+    the internal functions to compute as expected.
+
+    Args:
+        image (numpy array): 3D volume shape (slice, H, W).
+        threshold (int): Threshold for creating the initial mask.
+
+    Returns:
+        mask (numpy array): Initial mask created from the volume.
+        labeled_mask (numpy array): Labeled mask.
+        largest_masks (numpy array): 3D array of largest masks.
+        body_segmented (numpy array): 3D binary array of masks with body segmented.
+
+    '''
+    mask = create_mask(image, threshold=threshold)
+    labeled_mask, _ = label_regions(mask)
+    largest_regions = get_largest_regions(labeled_mask, num_regions=3)
+    largest_masks = create_masks(labeled_mask, largest_regions)[0]
+
+    # to have zeros and ones instead of binary false and true
+    largest_masks = largest_masks.astype(np.int8)
+
+    body_segmented = np.zeros_like(image)
+    body_segmented[largest_masks == 0] = image[largest_masks == 0]
+
+    return mask, labeled_mask, largest_masks, body_segmented
+
+
+def display_two_volumes(volume1, volume2, title1, title2, slice=70):
+    '''
+    Display two volumes side by side.
+
+    Args:
+        volume1 (numpy array): first volume to be displayed
+        volume2 (numpy array): second volume to be displayed
+        title1 (str): title of the first volume
+        title2 (str): title of the second volume
+        slice (int): slice to be displayed
+
+    Returns:
+        None
+    '''
+    plt.figure(figsize=(9, 6))
+
+    plt.subplot(1, 2, 1)
+    plt.imshow(volume1[slice, :, :], cmap='gray') 
+    plt.title(title1)
+    plt.axis('off')
+
+
+    plt.subplot(1, 2, 2)
+    plt.imshow(volume2[slice, :, :], cmap='gray') 
+    plt.title(title2)
+    plt.axis('off')
+
+    plt.show()
+
+def display_volumes(*volumes, **titles_and_slices):
+    '''
+    Display multiple volumes side by side.
+
+    Args:
+        volumes (tuple of numpy arrays): volumes to be displayed
+        titles_and_slices (dict): titles and slices for each volume
+        
+    Returns:
+        None
+    '''
+    num_volumes = len(volumes)
+    
+    plt.figure(figsize=(6 * num_volumes, 6))
+
+    for i, volume in enumerate(volumes, start=1):
+        title = titles_and_slices.get(f'title{i}', f'Title {i}')
+        slice_val = titles_and_slices.get(f'slice{i}', 70)
+
+        plt.subplot(1, num_volumes, i)
+        plt.imshow(volume[slice_val, :, :], cmap='gray') #gray
+        plt.title(title)
+        plt.axis('off')
+
+    plt.show()
+
+
+def min_max_normalization(image, mask = None, max_value=None):
+    '''
+    Perform min-max normalization on a given image.
+
+    Args:
+        image ('np.array'): Input image to normalize.
+        mask ('np.array'): Mask to be applied to the image.
+        max_value ('float'): Maximum value for normalization.
+
+    Returns:
+        normalized_image ('np.array'): Min-max normalized image.
+    '''
+
+    if max_value is None:
+        max_value = np.iinfo(image.dtype).max
+        print(f"The maximum value for this volume {image.dtype} is: {max_value}")
+    
+    print("Using mask for normalization" if mask is not None else "Not using mask for normalization")
+
+    # Ensure the image is a NumPy array for efficient calculations
+    image = np.array(image)
+
+    # Calculate the minimum and maximum pixel values
+    min_value = np.min(image[mask == 1]) if mask is not None else np.min(image)
+    max_actual = np.max(image[mask == 1]) if mask is not None else np.max(image)
+    
+    # Perform min-max normalization
+    normalized_image = (image - min_value) / (max_actual - min_value) * max_value
+    normalized_image = np.clip(normalized_image, 0, max_value)
+    
+    return normalized_image.astype(image.dtype)