a b/utils/landmarks.py
1
import pandas as pd
2
import numpy as np
3
import matplotlib.pyplot as plt
4
import nibabel as nib
5
import os
6
from matplotlib.colors import LinearSegmentedColormap
7
8
def write_landmarks_to_list(landmarks, file_path):
    '''
    Save landmarks to a tab-separated text file, one landmark per line.

    Args:
        landmarks ('list'): Iterable of landmark rows. Every value of a row is converted
            with str() and the values are joined with tab characters.
        file_path ('str'): Destination path, including file name and extension. An existing
            file at this path is overwritten (the file is opened in 'w' mode).

    Returns:
        None
    '''
    # Lazily format each row as "v1\tv2\t...\n"; rows are consumed only once the file is open.
    formatted_rows = ('\t'.join(str(value) for value in row) + '\n' for row in landmarks)
    with open(file_path, 'w') as out_file:
        out_file.writelines(formatted_rows)
23
24
25
def get_landmarks_from_txt(transformed_file_path, search_key='OutputIndexFixed'):
    '''
    Get landmarks from a point text file using a given search_key column.

    The search_key is the column name in the text file where the landmarks (input or transformed)
    are stored. The default value is 'OutputIndexFixed', which holds the transformed landmarks.
    Only two values are possible: 'OutputIndexFixed' or 'InputIndex'.

    Each data line is split on a tab followed by ';' into six fields, in the order
    Point, InputIndex, InputPoint, OutputIndexFixed, OutputPoint, Deformation.
    The selected field is expected to hold its values in square brackets, e.g.
    " OutputIndexFixed = [ 104 75 78 ]" (assumed transformix-style output — confirm).

    Args:
        transformed_file_path ('str'): Path to the transformed text file.
        search_key ('str'): Column name in the text file where the landmarks are stored.

    Returns:
        landmarks_list ('list'): One list of ints per point, in file order.

    Raises:
        ValueError: If search_key is not 'OutputIndexFixed' or 'InputIndex', or if a field
            does not contain a bracketed list of integers.
    '''
    # Validate explicitly: an `assert` would be stripped when Python runs with -O.
    if search_key not in ('OutputIndexFixed', 'InputIndex'):
        raise ValueError("The search_key must be either 'OutputIndexFixed' or 'InputIndex'.")

    # Define the column names based on the data in the text file
    columns = [
        'Point', 'InputIndex', 'InputPoint',
        'OutputIndexFixed', 'OutputPoint', 'Deformation'
    ]

    # The multi-character separator '\t;' is treated as a regex, which requires the python engine.
    df = pd.read_csv(transformed_file_path, sep='\t;', comment=';', header=None, names=columns, engine='python')

    landmarks_list = []
    for cell in df[search_key]:
        # Keep only the text between the brackets and split on any whitespace, so irregular
        # spacing and any number of dimensions (2-D or 3-D points) are handled.
        inside_brackets = cell[cell.index('[') + 1:cell.rindex(']')]
        landmarks_list.append([int(value) for value in inside_brackets.split()])

    return landmarks_list
58
59
def visualize_landmarks(slice_index=70, subject='copd1', split='train', dataset_dir=None):
    '''
    Visualize the landmarks of one slice on top of the reference image and its lung mask.

    Args:
        slice_index ('int'): 1-based slice number (MATLAB-style). The image/mask slice shown is
            slice_index - 1 along the first axis after transposing. Only landmarks whose third
            coordinate equals slice_index are drawn.
        subject ('str'): Subject identifier, e.g. 'copd1'; used to build the file names.
        split ('str'): Dataset split sub-directory, e.g. 'train'.
        dataset_dir ('str', optional): Root directory of the dataset. Defaults to '../dataset'
            relative to the current working directory.

    Returns:
        None
    '''
    if dataset_dir is None:
        dataset_dir = os.path.join(os.getcwd(), '../dataset')
    subject_dir = os.path.join(dataset_dir, split, subject)

    # Define the paths
    landmarks_path = os.path.join(subject_dir, f'{subject}_300_iBH_xyz_r1.txt')
    reference_image_path = os.path.join(subject_dir, f'{subject}_iBHCT.nii.gz')
    reference_mask_path = os.path.join(subject_dir, f'{subject}_iBHCT_lung.nii.gz')

    # slice_index is 1-based (to match the MATLAB visualizer); arrays are indexed 0-based.
    slice_number = slice_index
    slice_idx = slice_index - 1

    # Load the reference image and mask, transposing the axes to rotate for visualization
    reference_image = nib.load(reference_image_path).get_fdata().transpose(2, 1, 0)
    reference_mask = nib.load(reference_mask_path).get_fdata().transpose(2, 1, 0)

    # Load 3D landmarks (x, y, z per row); ndmin=2 keeps a single-landmark file as one row
    # instead of a flat 1-D array, so the row-wise filter below still works.
    landmarks_data = np.loadtxt(landmarks_path, skiprows=2, ndmin=2)
    slice_landmarks = np.array([row for row in landmarks_data if row[2] == slice_number])

    # Red-green colormap for the mask; transparency comes from alpha in imshow below
    cmap = LinearSegmentedColormap.from_list('red_green', ['red', 'green'], N=256)

    # Visualize the selected slice
    plt.figure(figsize=(8, 8))
    plt.imshow(reference_image[slice_idx, :, :], cmap='gray')

    # Overlay the non-zero mask voxels in color on top of the image
    mask_slice = reference_mask[slice_idx, :, :]
    plt.imshow(np.ma.masked_where(mask_slice == 0, mask_slice), cmap=cmap, alpha=0.5)

    if len(slice_landmarks) > 0:
        # NOTE(review): x/y are plotted as stored; if they are 1-based like the slice number,
        # markers may be offset by one pixel — confirm against the MATLAB visualizer.
        x_coords = slice_landmarks[:, 0].astype(int)
        y_coords = slice_landmarks[:, 1].astype(int)
        label = os.path.splitext(os.path.basename(landmarks_path))[0]
        plt.scatter(x_coords, y_coords, c='b', marker='+', label=f'{label} landmarks')
        # Only add a legend when a labelled artist exists; otherwise matplotlib warns
        # that there are no artists with labels.
        plt.legend()

    plt.axis('off')
    plt.show()
113