Diff of /utils/landmarks.py [000000] .. [6fe801]

Switch to side-by-side view

--- a
+++ b/utils/landmarks.py
@@ -0,0 +1,113 @@
+import pandas as pd
+import numpy as np
+import matplotlib.pyplot as plt
+import nibabel as nib
+import os
+from matplotlib.colors import LinearSegmentedColormap
+
+def write_landmarks_to_list(landmarks, file_path):
+    '''
+    Write the landmarks to a text file.
+
+    Args:
+        landmarks ('list'): List of landmarks.
+        file_path ('str'): Path to the text file. It includes the file name and extension.
+
+    Returns:
+        None
+    '''
+    # Write the landmarks to a text file
+    with open(file_path, 'w') as file:
+        for row in landmarks:
+            file.write('\t'.join(map(str, row)) + '\n')
+
+
+def get_landmarks_from_txt(transformed_file_path, search_key='OutputIndexFixed'):
+    '''
+    Get the transformed landmarks from the text file using a given search_key column index.
+    The searcu_key column index is the column name in the text file, where the landmarks (input or transformed)
+    are stored. The default value is 'OutputIndexFixed', which is the transformed landmarks. Only two values are
+    possible: 'OutputIndexFixed' or 'InputIndex'.
+
+    Args:
+        transformed_file_path ('str'): Path to the transformed text file.
+        search_key ('str'): Column name in the text file where the landmarks are stored.
+
+    Returns:
+        landmarks_list ('list'): List of transformed landmarks.
+    '''
+    # validate the search key
+    assert search_key in ['OutputIndexFixed', 'InputIndex'], "The search_key must be either 'OutputIndexFixed' or 'InputIndex'."
+
+    # Define the column names based on the data in the text file
+    columns = [
+        'Point', 'InputIndex', 'InputPoint',
+        'OutputIndexFixed', 'OutputPoint', 'Deformation'
+    ]
+
+    # Read the text file into a pandas DataFrame
+    df = pd.read_csv(transformed_file_path, sep='\t;', comment=';', header=None, names=columns, engine='python')
+
+    # select the required column
+    df_col = df[search_key]
+    
+    # convert the column values to a list of lists
+    landmarks_list = [list(map(int, df_col[idx].split(' ')[-4:-1])) for idx in range(len(df_col))]
+
+    return landmarks_list
+
+def visualize_landmarks(slice_index=70, subject='copd1', split='train'):
+    '''
+    Visualize the landmarks on the reference image or a mask.
+    '''
+    # Define the paths
+    landmarks_path = os.path.join(os.getcwd(),f'../dataset/{split}/{subject}/{subject}_300_iBH_xyz_r1.txt')
+    reference_image_path = os.path.join(os.getcwd(),f'../dataset/{split}/{subject}/{subject}_iBHCT.nii.gz') # _lung
+    reference_mask_path = os.path.join(os.getcwd(),f'../dataset/{split}/{subject}/{subject}_iBHCT_lung.nii.gz') # _lung
+    # reference_mask_path = os.path.join(os.getcwd(),f'../dataset/segmentation_trails/mask1/train/{subject}_eBHCT_lung.nii.gz') # _lung
+
+    # -1 to match the MATLAB visualizer indexing result
+    slice_index = slice_index - 1
+
+    # Load the reference image
+    nii_image = nib.load(reference_image_path)
+    reference_image = nii_image.get_fdata()
+
+    # Load the reference mask
+    nii_mask = nib.load(reference_mask_path)
+    reference_mask = nii_mask.get_fdata()
+
+    # transpose the axis to rotate for visualization
+    reference_image = reference_image.transpose(2, 1, 0)
+    reference_mask = reference_mask.transpose(2, 1, 0)
+    
+    # Load 3D landmarks from the file
+    landmarks_data = np.loadtxt(landmarks_path, skiprows=2)
+    slice_landmarsk = np.array([inner_list for inner_list in landmarks_data if inner_list[2] == slice_index+1])
+    
+    # Create a red-green colormap with opacity
+    cmap = LinearSegmentedColormap.from_list('red_green', ['red', 'green'], N=256)
+    
+    # Visualize a specific slice
+    plt.figure(figsize=(8, 8))
+    plt.imshow(reference_image[slice_index, :, :], cmap='gray')
+
+    # Overlay the mask with color on the landmarks
+    mask_overlay = np.ma.masked_where(reference_mask[slice_index, :, :] == 0, reference_mask[slice_index, :, :])
+    plt.imshow(mask_overlay, cmap=cmap, alpha=0.5)
+
+    if len(slice_landmarsk) > 0:
+            
+        # Extract x, y, z coordinates
+        x_coords = slice_landmarsk[:, 0].astype(int)
+        y_coords = slice_landmarsk[:, 1].astype(int)
+        z_coords = slice_landmarsk[:, 2].astype(int)
+
+        # Plot landmarks on the current slice
+        plt.scatter(x_coords, y_coords, c='b', marker='+', label=f'{landmarks_path.split("/")[-1].split(".txt")[0]} landmarks')
+
+    plt.legend()
+    plt.axis('off')
+    # plt.title(f"Slice {slice_index+1}")
+    plt.show()
+