|
a |
|
b/utils/dataset.py |
|
|
1 |
import SimpleITK as sitk |
|
|
2 |
import os |
|
|
3 |
import tempfile |
|
|
4 |
import numpy as np |
|
|
5 |
import matplotlib.pyplot as plt |
|
|
6 |
from scipy.ndimage import binary_closing |
|
|
7 |
from skimage import measure |
|
|
8 |
|
|
|
9 |
def read_raw(
    binary_file_name,
    image_size,
    sitk_pixel_type,
    image_spacing=None,
    image_origin=None,
    big_endian=False,
):
    """
    Read a raw binary scalar image.

    A temporary MetaImage (.mhd) header describing the raw file is written,
    handed to ``sitk.ReadImage`` and removed again, whether or not the read
    succeeds.

    Source: https://simpleitk.readthedocs.io/en/master/link_RawImageReading_docs.html

    Parameters
    ----------
    binary_file_name (str): Path to the raw, binary image file.
    image_size (tuple like): Size of image (e.g. [2048,2048]); 2, 3 or 4 entries.
    sitk_pixel_type (SimpleITK pixel type): Pixel type of data (e.g.
        sitk.sitkUInt16).
    image_spacing (tuple like): Optional image spacing, if none given (or
        empty) assumed to be [1]*dim.
    image_origin (tuple like): Optional image origin, if none given (or
        empty) assumed to be [0]*dim.
    big_endian (bool): Optional byte order indicator, if True big endian, else
        little endian.

    Returns
    -------
    SimpleITK image.

    Raises
    ------
    ValueError: If ``image_size`` does not have 2, 3 or 4 entries.
    KeyError: If ``sitk_pixel_type`` is not one of the supported scalar types.
    Any exception raised by ``sitk.ReadImage`` is propagated unchanged.
    """

    pixel_dict = {
        sitk.sitkUInt8: "MET_UCHAR",
        sitk.sitkInt8: "MET_CHAR",
        sitk.sitkUInt16: "MET_USHORT",
        sitk.sitkInt16: "MET_SHORT",
        sitk.sitkUInt32: "MET_UINT",
        sitk.sitkInt32: "MET_INT",
        sitk.sitkUInt64: "MET_ULONG_LONG",
        sitk.sitkInt64: "MET_LONG_LONG",
        sitk.sitkFloat32: "MET_FLOAT",
        sitk.sitkFloat64: "MET_DOUBLE",
    }
    # Identity direction matrices, keyed by image dimension.
    direction_cosine = {
        2: "1 0 0 1",
        3: "1 0 0 0 1 0 0 0 1",
        4: "1 0 0 0 0 1 0 0 0 0 1 0 0 0 0 1",
    }

    dim = len(image_size)
    if dim not in direction_cosine:
        # A list lookup with dim - 2 would silently pick the wrong matrix
        # for dim 0 or 1 (negative index), so reject it explicitly.
        raise ValueError(
            f"image_size must have 2, 3 or 4 entries, got {dim}"
        )

    def _joined(values, fallback):
        # Accept any iterable (including numpy arrays, whose truth value is
        # ambiguous); fall back to the default when nothing is given.
        items = list(values) if values is not None else []
        if not items:
            items = [fallback] * dim
        return " ".join(str(v) for v in items)

    header_lines = [
        "ObjectType = Image",
        f"NDims = {dim}",
        "DimSize = " + " ".join(str(v) for v in image_size),
        "ElementSpacing = " + _joined(image_spacing, 1),
        # Every entry must end with a newline, also when an origin is given;
        # otherwise "Offset" and "TransformMatrix" end up on the same line.
        "Offset = " + _joined(image_origin, 0),
        "TransformMatrix = " + direction_cosine[dim],
        "ElementType = " + pixel_dict[sitk_pixel_type],
        "BinaryData = True",
        "BinaryDataByteOrderMSB = " + str(big_endian),
        # ElementDataFile must be the last entry in the header
        "ElementDataFile = " + os.path.abspath(binary_file_name),
    ]
    header = "".join(line + "\n" for line in header_lines).encode()

    # Not using the tempfile with auto-delete because on windows we can't
    # open the file a second time for ReadImage. The file is closed before
    # reading and always removed afterwards.
    fp = tempfile.NamedTemporaryFile(suffix=".mhd", delete=False)
    try:
        with fp:
            fp.write(header)
        return sitk.ReadImage(fp.name)
    finally:
        os.remove(fp.name)
|
|
101 |
|
|
|
102 |
def convert_raw_to_nifti(
        dataset_dir: list,
        output_dir: str,
        metadata: dict):
    '''
    Placeholder for converting the dataset images from raw (.raw) to nifti
    (nii.gz) format and exporting them to an output directory.

    The function is not implemented yet: it ignores its arguments and does
    nothing.

    Args:
        dataset_dir (list): list of paths to the dataset directories
        output_dir (str): path to the output directory
        metadata (dict): dictionary containing the metadata of the dataset

    Returns:
        None
    '''
    return None
|
|
120 |
|
|
|
121 |
def create_mask(volume, threshold = 700):
    '''
    Build a 0/1 mask marking the low-intensity part of a volume.

    Args:
        volume (numpy array): volume to be thresholded
        threshold (int): intensity limit; values less than or equal to it are selected

    Returns:
        numpy array: integer array of the same shape as ``volume`` holding 1
        where ``volume <= threshold`` and 0 elsewhere
    '''
    at_or_below = volume <= threshold
    mask = np.where(at_or_below, 1, 0)
    return mask
|
|
133 |
|
|
|
134 |
def label_regions(mask):
    '''
    Assign an integer label to every connected component of a binary mask.

    Labeling is delegated to ``skimage.measure.label`` with ``connectivity=2``
    and the value 0 treated as background.

    Args:
        mask (numpy array): Binary mask.

    Returns:
        tuple: ``(labeled_mask, num_labels)`` - the labeled array and the
        number of labels found.
    '''
    labeling = measure.label(
        mask,
        background=0,
        connectivity=2,
        return_num=True,
    )
    components, count = labeling
    return components, count
|
|
146 |
|
|
|
147 |
def get_largest_regions(labeled_mask, num_regions=2):
    '''
    Return the properties of the biggest labeled regions, largest first.

    Args:
        labeled_mask (numpy array): Labeled mask.
        num_regions (int): Maximum number of regions to return.

    Returns:
        list: Region properties (from ``skimage.measure.regionprops``) sorted
        by decreasing area, containing at most ``num_regions`` entries.
    '''
    by_area = sorted(
        measure.regionprops(labeled_mask),
        key=lambda props: props.area,
        reverse=True,
    )
    # Never ask for more entries than there are regions.
    keep = min(num_regions, len(by_area))
    return by_area[:keep]
|
|
164 |
|
|
|
165 |
def create_masks(labeled_mask, regions):
    '''
    Build one boolean mask per requested region of a labeled mask.

    Args:
        labeled_mask (numpy array): Labeled mask.
        regions (list): Region properties; only the ``label`` attribute of
            each entry is used.

    Returns:
        list: One boolean array per region, in the order of ``regions``, that
        is True where ``labeled_mask`` equals the region's label.
    '''
    selected = []
    for props in regions:
        selected.append(labeled_mask == props.label)
    return selected
|
|
178 |
|
|
|
179 |
def fill_holes_and_erode(mask, structure=(7, 7, 5)):
    '''
    Close small holes and gaps in a binary mask.

    The mask is passed through ``scipy.ndimage.binary_closing`` (a dilation
    followed by an erosion) using an all-ones structuring element.

    Args:
        mask (numpy array): Binary mask.
        structure (tuple): Shape of the all-ones structuring element; it needs
            one entry per dimension of ``mask``.

    Returns:
        numpy array: Boolean mask after the closing operation.
    '''
    kernel = np.ones(structure)
    closed = binary_closing(mask, structure=kernel)
    return closed
|
|
194 |
|
|
|
195 |
def _select_lung_mask(labeled_slice, regions, create_masks):
    '''
    Choose the lung mask of one labeled 2D slice from its largest regions.

    Args:
        labeled_slice (numpy array): Labeled 2D slice.
        regions (list): Region properties of the slice, sorted by decreasing area.
        create_masks (function): Function building one mask per region.

    Returns:
        numpy array: Mask of the kept region(s), or an all-zero array shaped
        like ``labeled_slice`` when nothing is kept.
    '''
    count = len(regions)

    if count == 1:
        # A single region is kept only when it is clearly elongated (major and
        # minor axis differ by more than 30). A nearly round single region is
        # treated as trachea (first few slices) and dropped.
        only = regions[0]
        if abs(only.axis_major_length - only.axis_minor_length) > 30:
            return create_masks(labeled_slice, regions)[0]
        return np.zeros_like(labeled_slice)

    if count == 2:
        masks = create_masks(labeled_slice, regions)
        first, second = regions
        # Both lungs touching each other form one big region and the trachea
        # the other: the first region is more than 50 larger in area and the
        # second region has a minor axis below 100. Keep only the first one.
        if first.area - second.area > 50 and second.axis_minor_length < 100:
            return masks[0]
        # Otherwise the two regions are taken to be the two lungs.
        return masks[0] + masks[1]

    if count == 3:
        # Regions are sorted by area (highest to lowest), so the third one is
        # treated as the trachea and left out.
        masks = create_masks(labeled_slice, regions)
        return masks[0] + masks[1]

    # No region at all (or an unexpected number of regions): keep nothing.
    return np.zeros_like(labeled_slice)


def remove_trachea(largest_masks, get_largest_regions, create_masks):
    '''
    Remove the trachea from a set of largest masks with a shape (Slice, H, W).

    Every slice is labeled on its own, its (up to) three largest regions are
    retrieved and a per-slice heuristic on region count, area and axis
    lengths decides which regions are kept.

    Args:
        largest_masks (numpy array): 3D array of largest masks.
        get_largest_regions (function): Function to get largest regions; called
            as ``get_largest_regions(labeled_slice, num_regions=3)``.
        create_masks (function): Function to create masks; called as
            ``create_masks(labeled_slice, regions)``.

    Returns:
        list: One 2D mask per slice with the trachea removed. Slices where
        regions are kept hold the arrays produced by ``create_masks``; slices
        where nothing is kept hold an all-zero array with the label dtype.
    '''
    masks_without_trachea = []
    for idx in range(largest_masks.shape[0]):
        labeled_slice = label_regions(largest_masks[idx, :, :])[0]
        regions = get_largest_regions(labeled_slice, num_regions=3)
        masks_without_trachea.append(
            _select_lung_mask(labeled_slice, regions, create_masks)
        )
    return masks_without_trachea
|
|
240 |
|
|
|
241 |
def segment_lungs_and_remove_trachea(volume, threshold=700, structure=(7, 7, 5), fill_holes_before_trachea_removal=False):
    '''
    Segment the lungs of a 3D volume with shape (Slice, H, W) and remove the trachea.
    The (Slice, H, W) layout is required because the trachea removal works slice by slice
    along the first axis.

    Args:
        volume (numpy array): 3D volume shape (slice, H, W).
        threshold (int): Threshold for creating the initial mask.
        structure (tuple): Shape of the structuring element used for the final closing.
        fill_holes_before_trachea_removal (bool): When True, the selected region is
            additionally closed, with a structuring element twice as large in every
            dimension, before the trachea is removed.

    Returns:
        initial_mask (numpy array): Initial mask created from the volume.
        labeled_mask (numpy array): Labeled mask.
        largest_masks (numpy array): Mask of the second-largest labeled region (closed
            first when ``fill_holes_before_trachea_removal`` is set).
        processed_mask_without_trachea (numpy array): 3D uint8 array of masks with trachea removed.
    '''
    # Threshold the volume and label its connected components.
    initial_mask = create_mask(volume, threshold=threshold)
    labeled_mask = label_regions(initial_mask)[0]

    # Of the three largest regions the second largest one is used.
    # NOTE(review): presumably the largest region is the air around the body - confirm.
    candidates = get_largest_regions(labeled_mask, num_regions=3)
    largest_masks = create_masks(labeled_mask, candidates)[1]

    # Optional closing of the selected region with a doubled structuring element.
    if fill_holes_before_trachea_removal:
        doubled = tuple([2 * extent for extent in structure])
        largest_masks = fill_holes_and_erode(largest_masks, structure=doubled)

    # Drop the trachea slice by slice, then close the remaining masks.
    without_trachea = remove_trachea(largest_masks, get_largest_regions, create_masks)
    closed = fill_holes_and_erode(without_trachea, structure=structure)

    return initial_mask, labeled_mask, largest_masks, closed.astype(np.uint8)
|
|
281 |
|
|
|
282 |
def segment_body(image, threshold=700):
    '''
    Segment the body from a 3D volume with shape (Slice, H, W).

    The largest connected region of values at or below the threshold is
    blanked out; all other voxels keep their original intensity.

    Args:
        image (numpy array): 3D volume shape (slice, H, W).
        threshold (int): Threshold for creating the initial mask.

    Returns:
        mask (numpy array): Initial mask created from the volume.
        labeled_mask (numpy array): Labeled mask.
        largest_masks (numpy array): int8 mask (0/1) of the largest labeled region.
        body_segmented (numpy array): Copy of ``image`` that is zero inside the
            largest region and unchanged elsewhere.
    '''
    mask = create_mask(image, threshold=threshold)
    labeled_mask = label_regions(mask)[0]
    top_regions = get_largest_regions(labeled_mask, num_regions=3)

    # Keep the largest region only, as zeros and ones rather than False/True.
    largest_masks = create_masks(labeled_mask, top_regions)[0].astype(np.int8)

    # Copy over every voxel that lies outside the largest region.
    outside_largest = largest_masks == 0
    body_segmented = np.zeros_like(image)
    body_segmented[outside_largest] = image[outside_largest]

    return mask, labeled_mask, largest_masks, body_segmented
|
|
310 |
|
|
|
311 |
|
|
|
312 |
def display_two_volumes(volume1, volume2, title1, title2, slice=70):
    '''
    Show the same slice of two volumes next to each other in grayscale.

    Args:
        volume1 (numpy array): volume shown on the left
        volume2 (numpy array): volume shown on the right
        title1 (str): title of the left panel
        title2 (str): title of the right panel
        slice (int): index along the first axis of the slice to show

    Returns:
        None
    '''
    plt.figure(figsize=(9, 6))

    panels = ((volume1, title1), (volume2, title2))
    for position, (volume, title) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.imshow(volume[slice, :, :], cmap='gray')
        plt.title(title)
        plt.axis('off')

    plt.show()
|
|
340 |
|
|
|
341 |
def display_volumes(*volumes, **titles_and_slices):
    '''
    Show one slice of each given volume in a single row, in grayscale.

    Args:
        volumes (tuple of numpy arrays): volumes to be displayed
        titles_and_slices (dict): optional per-volume settings, looked up by
            1-based position: ``title<i>`` (default ``'Title <i>'``) and
            ``slice<i>`` (default 70)

    Returns:
        None
    '''
    count = len(volumes)
    plt.figure(figsize=(6 * count, 6))

    for position, volume in enumerate(volumes, start=1):
        panel_title = titles_and_slices.get(f'title{position}', f'Title {position}')
        slice_index = titles_and_slices.get(f'slice{position}', 70)

        plt.subplot(1, count, position)
        plt.imshow(volume[slice_index, :, :], cmap='gray')
        plt.title(panel_title)
        plt.axis('off')

    plt.show()
|
|
366 |
|
|
|
367 |
|
|
|
368 |
def min_max_normalization(image, mask=None, max_value=None):
    '''
    Perform min-max normalization on a given image.

    The minimum and maximum are taken from the whole image or, when a mask is
    given, only from the pixels where the mask equals 1. The image is then
    rescaled so that this range maps onto [0, max_value]; values falling
    outside of it are clipped.

    Args:
        image ('np.array'): Input image to normalize.
        mask ('np.array'): Optional mask; only pixels where it equals 1 define
            the minimum and maximum.
        max_value ('float'): Maximum value for normalization. Defaults to the
            largest value of the image dtype for integer images and to 1.0
            for all other dtypes.

    Returns:
        normalized_image ('np.array'): Min-max normalized image with the same
        dtype as the input. All zeros when the reference pixels share a single
        value (the range is empty).
    '''
    # Convert first so that ``dtype`` is available for any array-like input.
    image = np.asarray(image)

    if max_value is None:
        # np.iinfo only accepts integer dtypes.
        if np.issubdtype(image.dtype, np.integer):
            max_value = np.iinfo(image.dtype).max
        else:
            max_value = 1.0
        print(f"The maximum value for this volume {image.dtype} is: {max_value}")

    print("Using mask for normalization" if mask is not None else "Not using mask for normalization")

    # Calculate the minimum and maximum pixel values
    reference = image[np.asarray(mask) == 1] if mask is not None else image
    min_value = float(np.min(reference))
    max_actual = float(np.max(reference))

    value_range = max_actual - min_value
    if value_range == 0:
        # Nothing to stretch; avoid dividing by zero.
        return np.zeros_like(image)

    # Work in float64: with unsigned integers, pixels below ``min_value``
    # (possible outside the mask) would wrap around to huge values instead of
    # becoming negative and being clipped to 0.
    normalized_image = (image.astype(np.float64) - min_value) / value_range * max_value
    normalized_image = np.clip(normalized_image, 0, max_value)

    return normalized_image.astype(image.dtype)