[6fe801]: / utils / landmarks.py

Download this file

114 lines (87 with data), 4.4 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import nibabel as nib
import os
from matplotlib.colors import LinearSegmentedColormap
def write_landmarks_to_list(landmarks, file_path):
    '''
    Save landmarks to a text file, one landmark per line with tab-separated values.

    Args:
        landmarks ('list'): Iterable of landmark rows. Every value in a row is
            converted with str() and the values are joined with a tab.
        file_path ('str'): Destination path, including file name and extension.
            An existing file at this path is overwritten.
    Returns:
        None
    '''
    # Lazily format each row; evaluation happens while writing, after the file is opened
    formatted_rows = ('\t'.join(str(value) for value in row) + '\n' for row in landmarks)
    with open(file_path, 'w') as out_file:
        out_file.writelines(formatted_rows)
def get_landmarks_from_txt(transformed_file_path, search_key='OutputIndexFixed'):
    '''
    Get landmark indices from a points text file using a given search_key column.

    Each line of the file is split on a tab followed by ';' into six columns:
    'Point', 'InputIndex', 'InputPoint', 'OutputIndexFixed', 'OutputPoint' and
    'Deformation'. The search_key column selects which landmarks are returned.
    It can be 'OutputIndexFixed' (the default, the transformed landmarks) or
    'InputIndex' (the input landmarks). No other value is accepted.

    Args:
        transformed_file_path ('str'): Path to the transformed text file.
        search_key ('str'): Column name in the text file where the landmarks are stored.
    Returns:
        landmarks_list ('list'): One list of ints per line of the file, holding the
            values written between '[' and ']' in the selected column.
    Raises:
        ValueError: If search_key is not 'OutputIndexFixed' or 'InputIndex'.
    '''
    valid_keys = ('OutputIndexFixed', 'InputIndex')
    # Raise explicitly: an assert would be stripped when Python runs with -O.
    if search_key not in valid_keys:
        raise ValueError("The search_key must be either 'OutputIndexFixed' or 'InputIndex'.")

    # Column names matching the layout of each line in the text file
    columns = [
        'Point', 'InputIndex', 'InputPoint',
        'OutputIndexFixed', 'OutputPoint', 'Deformation',
    ]
    # A multi-character separator is interpreted by pandas as a regex, which needs the python engine.
    df = pd.read_csv(transformed_file_path, sep='\t;', comment=';', header=None,
                     names=columns, engine='python')

    landmarks_list = []
    for cell in df[search_key]:
        # A cell looks like e.g. ' OutputIndexFixed = [ 129 93 75 ]'. Only the numbers
        # between the brackets are parsed, so the result does not depend on exact spacing
        # or on the number of coordinates.
        inside_brackets = cell.partition('[')[2].partition(']')[0]
        landmarks_list.append([int(value) for value in inside_brackets.split()])
    return landmarks_list
def visualize_landmarks(slice_index=70, subject='copd1', split='train'):
    '''
    Visualize the landmarks of one slice on top of the reference image and its mask.

    The reference image ('{subject}_iBHCT.nii.gz'), its mask ('{subject}_iBHCT_lung.nii.gz')
    and the landmark file ('{subject}_300_iBH_xyz_r1.txt') are read from
    '../dataset/{split}/{subject}/' relative to the current working directory.

    Args:
        slice_index ('int'): 1-based slice number (to match the MATLAB visualizer).
            Landmarks whose third coordinate equals this value are drawn.
        subject ('str'): Subject identifier, used as folder name and file-name prefix.
        split ('str'): Dataset split folder name (e.g. 'train').
    Returns:
        None. The figure is displayed with plt.show().
    '''
    # Define the paths
    subject_dir = os.path.join(os.getcwd(), f'../dataset/{split}/{subject}')
    landmarks_path = os.path.join(subject_dir, f'{subject}_300_iBH_xyz_r1.txt')
    reference_image_path = os.path.join(subject_dir, f'{subject}_iBHCT.nii.gz')
    reference_mask_path = os.path.join(subject_dir, f'{subject}_iBHCT_lung.nii.gz')

    # slice_index is 1-based (MATLAB convention); numpy arrays are 0-based
    array_index = slice_index - 1

    # Load the volumes and reverse the axis order so the slice axis (originally the
    # third one) comes first, giving a 2-D slice for imshow.
    reference_image = nib.load(reference_image_path).get_fdata().transpose(2, 1, 0)
    reference_mask = nib.load(reference_mask_path).get_fdata().transpose(2, 1, 0)
    image_slice = reference_image[array_index, :, :]
    mask_slice = reference_mask[array_index, :, :]

    # Load 3D landmarks. ndmin=2 keeps a row-per-landmark shape even when the file
    # holds a single landmark (otherwise loadtxt returns a 1-D array).
    # NOTE(review): the first two lines are skipped, presumably a header -- confirm.
    landmarks_data = np.loadtxt(landmarks_path, skiprows=2, ndmin=2)
    # Landmark z values are compared with the 1-based slice number
    slice_landmarks = [row for row in landmarks_data if row[2] == slice_index]

    # Red-to-green colormap for the mask overlay
    cmap = LinearSegmentedColormap.from_list('red_green', ['red', 'green'], N=256)

    plt.figure(figsize=(8, 8))
    plt.imshow(image_slice, cmap='gray')
    # Hide mask == 0 pixels so only the mask region is tinted
    mask_overlay = np.ma.masked_where(mask_slice == 0, mask_slice)
    plt.imshow(mask_overlay, cmap=cmap, alpha=0.5)

    if slice_landmarks:
        points = np.array(slice_landmarks)
        # NOTE(review): x/y are plotted exactly as read, while z is treated as 1-based;
        # if x/y are also 1-based they may be off by one pixel -- confirm.
        x_coords = points[:, 0].astype(int)
        y_coords = points[:, 1].astype(int)
        label = os.path.splitext(os.path.basename(landmarks_path))[0]
        plt.scatter(x_coords, y_coords, c='b', marker='+', label=f'{label} landmarks')
        plt.legend()
    plt.axis('off')
    plt.show()