# utils.py

import os
import re
import argparse
import random
import logging
import sys
import tempfile
import time
from collections import OrderedDict
from glob import glob

import numpy as np
import nibabel as nib
import torch
import monai
import SimpleITK as sitk
from ignite.metrics import Accuracy
# from networks import build_net

from monai.data import CacheDataset, DataLoader, Dataset
from monai.data import NiftiSaver, create_test_image_3d, list_data_collate
from monai.inferers import sliding_window_inference, SimpleInferer, SlidingWindowInferer
from monai.metrics import DiceMetric
from monai.utils import set_determinism
from monai.handlers import (MeanDice, StatsHandler, ValidationHandler, CheckpointSaver, LrScheduleHandler, CheckpointLoader,
                            SegmentationSaver, TensorBoardImageHandler, TensorBoardStatsHandler)
from monai.transforms import (Activationsd, MeanEnsembled, GaussianSmoothd, CropForegroundd, ThresholdIntensityd,
                              Activations, AsDiscrete, LoadImaged, AsChannelFirstd, VoteEnsembled, AsDiscreted, Compose,
                              AddChanneld, Transpose, ConcatItemsd, ScaleIntensityd, Resized, ToTensord, RandSpatialCropd,
                              Rand3DElasticd, RandAffined, RandGaussianSmoothd, SpatialPadd, Spacingd, Orientationd,
                              RandShiftIntensityd, BorderPadd, RandGaussianNoised, RandAdjustContrastd, NormalizeIntensityd,
                              RandFlipd, KeepLargestConnectedComponent)
from monai.engines import (
    EnsembleEvaluator,
    SupervisedEvaluator,
    SupervisedTrainer
)

from skimage.measure import label


def getLargestCC(segmentation):
    # keep only the largest connected component of a binary segmentation
    labels = label(segmentation)
    unique, counts = np.unique(labels, return_counts=True)
    list_seg = list(zip(unique, counts))[1:]  # the 0 label is by default background, so take the rest
    largest = max(list_seg, key=lambda x: x[1])[0]
    labels_max = (labels == largest).astype(int)
    return labels_max
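

# Usage sketch (not part of the original code): getLargestCC expects a binary array and
# returns a mask of its largest connected component. The toy array below is made up
# purely for illustration.
def _example_largest_component():
    toy = np.zeros((1, 10, 10), dtype=int)
    toy[0, 1:3, 1:3] = 1   # small blob (4 voxels)
    toy[0, 5:9, 5:9] = 1   # large blob (16 voxels)
    largest_only = getLargestCC(toy)
    assert largest_only.sum() == 16  # only the large blob survives
    return largest_only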


def Padding(image, reference):
    # resample `image` onto the grid (size, spacing, origin, direction) of `reference`
    size_new = reference.GetSize()

    output_size = tuple(size_new)

    resampler = sitk.ResampleImageFilter()
    resampler.SetOutputSpacing(reference.GetSpacing())
    resampler.SetSize(output_size)

    # resample on label
    resampler.SetInterpolator(sitk.sitkNearestNeighbor)
    resampler.SetOutputOrigin(reference.GetOrigin())
    resampler.SetOutputDirection(reference.GetDirection())

    image = resampler.Execute(image)

    return image


def resize(img, new_size, interpolator):
    # img = sitk.ReadImage(img)
    dimension = img.GetDimension()

    # Physical image size corresponds to the largest physical size in the training set, or any other arbitrary size.
    reference_physical_size = np.zeros(dimension)

    reference_physical_size[:] = [(sz - 1) * spc if sz * spc > mx else mx for sz, spc, mx in
                                  zip(img.GetSize(), img.GetSpacing(), reference_physical_size)]

    # Create the reference image with a zero origin, identity direction cosine matrix and dimension
    reference_origin = np.zeros(dimension)
    reference_direction = np.identity(dimension).flatten()
    reference_size = new_size
    reference_spacing = [phys_sz / (sz - 1) for sz, phys_sz in zip(reference_size, reference_physical_size)]

    reference_image = sitk.Image(reference_size, img.GetPixelIDValue())
    reference_image.SetOrigin(reference_origin)
    reference_image.SetSpacing(reference_spacing)
    reference_image.SetDirection(reference_direction)

    # Always use the TransformContinuousIndexToPhysicalPoint to compute an indexed point's physical coordinates, as
    # this takes into account size, spacing and direction cosines. For the vast majority of images the direction
    # cosines are the identity matrix, but when this isn't the case simply multiplying the central index by the
    # spacing will not yield the correct coordinates, resulting in a long debugging session.
    reference_center = np.array(
        reference_image.TransformContinuousIndexToPhysicalPoint(np.array(reference_image.GetSize()) / 2.0))

    # Transform which maps from the reference_image to the current img with the translation mapping the image
    # origins to each other.
    transform = sitk.AffineTransform(dimension)
    transform.SetMatrix(img.GetDirection())
    transform.SetTranslation(np.array(img.GetOrigin()) - reference_origin)
    # Modify the transformation to align the centers of the original and reference image instead of their origins.
    centering_transform = sitk.TranslationTransform(dimension)
    img_center = np.array(img.TransformContinuousIndexToPhysicalPoint(np.array(img.GetSize()) / 2.0))
    centering_transform.SetOffset(np.array(transform.GetInverse().TransformPoint(img_center) - reference_center))

    # centered_transform = sitk.Transform(transform)
    # centered_transform.AddTransform(centering_transform)

    centered_transform = sitk.CompositeTransform([transform, centering_transform])

    # Using the linear interpolator as these are intensity images; if there is a need to resample a ground truth
    # segmentation then the segmentation image should be resampled using the NearestNeighbor interpolator so that
    # no new labels are introduced.

    return sitk.Resample(img, reference_image, centered_transform, interpolator, 0.0)
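

# Usage sketch (not part of the original code): `resize` maps an image onto a fixed voxel
# grid (e.g. 64x64x64) while preserving its physical extent. The file name below is
# hypothetical.
def _example_resize_to_fixed_grid(in_path='volume.nii.gz'):
    img = sitk.ReadImage(in_path)
    fixed = resize(img, (64, 64, 64), sitk.sitkLinear)  # use sitk.sitkNearestNeighbor for label maps
    return fixed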


def resample_sitk_image(sitk_image, spacing=None, interpolator=None, fill_value=0):
    # https://github.com/SimpleITK/SlicerSimpleFilters/blob/master/SimpleFilters/SimpleFilters.py
    _SITK_INTERPOLATOR_DICT = {
        'nearest': sitk.sitkNearestNeighbor,
        'linear': sitk.sitkLinear,
        'gaussian': sitk.sitkGaussian,
        'label_gaussian': sitk.sitkLabelGaussian,
        'bspline': sitk.sitkBSpline,
        'hamming_sinc': sitk.sitkHammingWindowedSinc,
        'cosine_windowed_sinc': sitk.sitkCosineWindowedSinc,
        'welch_windowed_sinc': sitk.sitkWelchWindowedSinc,
        'lanczos_windowed_sinc': sitk.sitkLanczosWindowedSinc
    }

    if isinstance(sitk_image, str):
        sitk_image = sitk.ReadImage(sitk_image)
    num_dim = sitk_image.GetDimension()

    if not interpolator:
        interpolator = 'linear'
        pixelid = sitk_image.GetPixelIDValue()

        if pixelid not in [1, 2, 4]:
            raise NotImplementedError(
                'Set `interpolator` manually, '
                'can only infer for 8-bit unsigned or 16, 32-bit signed integers')
        if pixelid == 1:  # 8-bit unsigned int
            interpolator = 'nearest'

    orig_pixelid = sitk_image.GetPixelIDValue()
    orig_origin = sitk_image.GetOrigin()
    orig_direction = sitk_image.GetDirection()
    orig_spacing = np.array(sitk_image.GetSpacing())
    orig_size = np.array(sitk_image.GetSize(), dtype=int)  # np.int is deprecated; use the builtin int

    if not spacing:
        min_spacing = orig_spacing.min()
        new_spacing = [min_spacing] * num_dim
    else:
        new_spacing = [float(s) for s in spacing]

    assert interpolator in _SITK_INTERPOLATOR_DICT.keys(), \
        '`interpolator` should be one of {}'.format(_SITK_INTERPOLATOR_DICT.keys())

    sitk_interpolator = _SITK_INTERPOLATOR_DICT[interpolator]

    new_size = orig_size * (orig_spacing / new_spacing)
    new_size = np.ceil(new_size).astype(int)  # Image dimensions are in integers
    new_size = [int(s) for s in new_size]  # SimpleITK expects lists, not ndarrays

    resample_filter = sitk.ResampleImageFilter()

    resample_filter.SetOutputSpacing(new_spacing)
    resample_filter.SetSize(new_size)
    resample_filter.SetOutputDirection(orig_direction)
    resample_filter.SetOutputOrigin(orig_origin)
    resample_filter.SetTransform(sitk.Transform())
    resample_filter.SetInterpolator(sitk_interpolator)
    resample_filter.SetDefaultPixelValue(fill_value)  # the default pixel value is the fill value, not the pixel type ID

    resampled_sitk_image = resample_filter.Execute(sitk_image)

    return resampled_sitk_image
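

# Usage sketch (not part of the original code): resample a volume to 1 mm isotropic spacing.
# The file names are hypothetical; resample_sitk_image also accepts a path string directly.
def _example_resample_isotropic(in_path='image.nii.gz', out_path='image_1mm.nii.gz'):
    iso = resample_sitk_image(in_path, spacing=(1.0, 1.0, 1.0), interpolator='linear')
    sitk.WriteImage(iso, out_path)
    return iso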


def numericalSort(value):
    # split a string into text and integer chunks so that 'img10' sorts after 'img2'
    numbers = re.compile(r'(\d+)')
    parts = numbers.split(value)
    parts[1::2] = map(int, parts[1::2])
    return parts


def lstFiles(Path):

    images_list = []  # create an empty list; the raw image data files are stored here
    for dirName, subdirList, fileList in os.walk(Path):
        for filename in fileList:
            if ".nii.gz" in filename.lower():
                images_list.append(os.path.join(dirName, filename))
            elif ".nii" in filename.lower():
                images_list.append(os.path.join(dirName, filename))
            elif ".mhd" in filename.lower():
                images_list.append(os.path.join(dirName, filename))

    images_list = sorted(images_list, key=numericalSort)

    return images_list
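

# Usage sketch (not part of the original code): collect NIfTI/MHD files under a directory in
# natural numeric order. The directory name is hypothetical.
def _example_list_cases(data_dir='./Data_folder/images'):
    files = lstFiles(data_dir)  # e.g. [..., 'image_2.nii', 'image_10.nii', ...]
    return [{"image": f} for f in files]  # dictionaries as expected by the MONAI dict transforms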


def new_state_dict(file_name):
    # strip the 'module.' prefix that torch.nn.DataParallel adds to parameter names
    state_dict = torch.load(file_name)
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        if k[:6] == 'module':
            name = k[7:]
            new_state_dict[name] = v
        else:
            new_state_dict[k] = v
    return new_state_dict


def new_state_dict_cpu(file_name):
    # same as new_state_dict, but forces the tensors onto the CPU
    state_dict = torch.load(file_name, map_location='cpu')
    new_state_dict_cpu = OrderedDict()
    for k, v in state_dict.items():
        if k[:6] == 'module':
            name = k[7:]
            new_state_dict_cpu[name] = v
        else:
            new_state_dict_cpu[k] = v
    return new_state_dict_cpu
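

# Usage sketch (not part of the original code): load a checkpoint that was saved from a
# DataParallel model into a single-device network. The checkpoint path is hypothetical.
def _example_load_checkpoint(net, ckpt_path='best_metric_model.pth'):
    net.load_state_dict(new_state_dict_cpu(ckpt_path))
    net.eval()
    return net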


def from_numpy_to_itk(image_np, image_itk):
    # read the reference image file to copy its spatial metadata
    reader = sitk.ImageFileReader()
    reader.SetFileName(image_itk)
    image_itk = reader.Execute()

    image_np = np.transpose(image_np, (2, 1, 0))  # numpy x,y,z back to itk z,y,x
    image = sitk.GetImageFromArray(image_np)
    image.SetDirection(image_itk.GetDirection())
    image.SetSpacing(image_itk.GetSpacing())
    image.SetOrigin(image_itk.GetOrigin())
    return image


# function to keep track of the cropped area and coordinates
def statistics_crop(image, resolution):

    files = [{"image": image}]

    reader = sitk.ImageFileReader()
    reader.SetFileName(image)
    image_itk = reader.Execute()
    original_resolution = image_itk.GetSpacing()

    # original size
    transforms = Compose([
        LoadImaged(keys=['image']),
        AddChanneld(keys=['image']),
        ToTensord(keys=['image'])])
    data = monai.data.Dataset(data=files, transform=transforms)
    loader = DataLoader(data, batch_size=1, num_workers=0, pin_memory=torch.cuda.is_available())
    loader = monai.utils.misc.first(loader)
    im, = (loader['image'][0])
    vol = im.numpy()
    original_shape = vol.shape

    # cropped foreground size
    transforms = Compose([
        LoadImaged(keys=['image']),
        AddChanneld(keys=['image']),
        CropForegroundd(keys=['image'], source_key='image', start_coord_key='foreground_start_coord',
                        end_coord_key='foreground_end_coord'),  # crop to the foreground bounding box
        ToTensord(keys=['image', 'foreground_start_coord', 'foreground_end_coord'])])

    data = monai.data.Dataset(data=files, transform=transforms)
    loader = DataLoader(data, batch_size=1, num_workers=0, pin_memory=torch.cuda.is_available())
    loader = monai.utils.misc.first(loader)
    im, coord1, coord2 = (loader['image'][0], loader['foreground_start_coord'][0], loader['foreground_end_coord'][0])
    vol = im[0].numpy()
    coord1 = coord1.numpy()
    coord2 = coord2.numpy()
    crop_shape = vol.shape

    if resolution is not None:

        transforms = Compose([
            LoadImaged(keys=['image']),
            AddChanneld(keys=['image']),
            CropForegroundd(keys=['image'], source_key='image'),  # crop to the foreground bounding box
            Spacingd(keys=['image'], pixdim=resolution, mode='bilinear'),  # resample to the working resolution
            ToTensord(keys=['image'])])

        data = monai.data.Dataset(data=files, transform=transforms)
        loader = DataLoader(data, batch_size=1, num_workers=0, pin_memory=torch.cuda.is_available())
        loader = monai.utils.misc.first(loader)
        im, = (loader['image'][0])
        vol = im.numpy()
        resampled_size = vol.shape

    else:

        resampled_size = original_shape

    return original_shape, crop_shape, coord1, coord2, resampled_size, original_resolution
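

# Usage sketch (not part of the original code): record the shapes and crop coordinates of a
# case so a prediction can later be mapped back onto the original grid. The file name is
# hypothetical.
def _example_case_statistics(image_path='case_0.nii.gz', resolution=(1.0, 1.0, 1.0)):
    (original_shape, crop_shape, coord1, coord2,
     resampled_size, original_resolution) = statistics_crop(image_path, resolution)
    return {'original_shape': original_shape, 'crop_shape': crop_shape,
            'start': coord1, 'end': coord2,
            'resampled_size': resampled_size, 'spacing': original_resolution}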


def build_net_CT(patch_size, resolution):

    from monai.networks.layers import Norm

    sizes, spacings = patch_size, resolution

    strides, kernels = [], []

    # derive the per-level kernel sizes and strides from the patch size and voxel spacing
    while True:
        spacing_ratio = [sp / min(spacings) for sp in spacings]
        stride = [2 if ratio <= 2 and size >= 8 else 1 for (ratio, size) in zip(spacing_ratio, sizes)]
        kernel = [3 if ratio <= 2 else 1 for ratio in spacing_ratio]
        if all(s == 1 for s in stride):
            break
        sizes = [i / j for i, j in zip(sizes, stride)]
        spacings = [i * j for i, j in zip(spacings, stride)]
        kernels.append(kernel)
        strides.append(stride)
    strides.insert(0, len(spacings) * [1])
    kernels.append(len(spacings) * [3])

    # create the U-Net
    nn_Unet = monai.networks.nets.DynUNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        kernel_size=kernels,
        strides=strides,
        upsample_kernel_size=strides[1:],
        res_block=True,
    )

    return nn_Unet
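

# Usage sketch (not part of the original code): build the DynUNet for a 128x128x128 patch at
# 1 mm isotropic spacing; the numbers are illustrative only.
def _example_build_network(device='cpu'):
    net = build_net_CT(patch_size=(128, 128, 128), resolution=(1.0, 1.0, 1.0))
    return net.to(device)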


def crop_window(prostate_contour):
    # Cut the data, restricted to the prostate contour plus a `pitch` margin per direction per dimension.
    """
    The array follows the nrrd convention, assuming the patient is viewed from the front:
    (x, y, z)
    x: left to right (ascending)
    y: front to back (ascending)
    z: bottom to top (ascending)
    """
    pitch = 5
    pattern = np.where(prostate_contour == 1)

    minx = np.min(pattern[0]) - pitch
    maxx = np.max(pattern[0]) + pitch
    miny = np.min(pattern[1]) - pitch
    maxy = np.max(pattern[1]) + pitch
    minz = np.min(pattern[2]) - pitch
    maxz = np.max(pattern[2]) + pitch

    # make the extent of each dimension even
    if (maxx - minx) % 2 != 0:
        maxx += 1
    if (maxy - miny) % 2 != 0:
        maxy += 1
    if (maxz - minz) % 2 != 0:
        maxz += 1

    # grow each dimension symmetrically until it spans at least `limit` voxels
    limit = 32

    while maxx - minx < limit:
        maxx += 1
        minx -= 1

    while maxy - miny < limit:
        maxy += 1
        miny -= 1

    while maxz - minz < limit:
        maxz += 1
        minz -= 1

    return minx, maxx, miny, maxy, minz, maxz
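

# Usage sketch (not part of the original code): compute the crop window for a toy binary mask.
def _example_crop_window():
    mask = np.zeros((64, 64, 64), dtype=np.uint8)
    mask[20:30, 25:35, 30:40] = 1  # synthetic "prostate" blob
    minx, maxx, miny, maxy, minz, maxz = crop_window(mask)
    # each extent is even and at least 32 voxels wide
    return (maxx - minx, maxy - miny, maxz - minz)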


def uniform_img_dimensions(image, label, nearest):
    # resample `label` onto the grid of `image` so both share size, spacing, origin and direction
    image_array = sitk.GetArrayFromImage(image)
    image_array = np.transpose(image_array, axes=(2, 1, 0))  # reshape array from itk z,y,x to x,y,z
    image_shape = image_array.shape

    if nearest is True:
        label = resample_sitk_image(label, spacing=image.GetSpacing(), interpolator='nearest')
        res = resize(label, image_shape, sitk.sitkNearestNeighbor)
        res = (np.rint(sitk.GetArrayFromImage(res)))
        res = sitk.GetImageFromArray(res.astype('uint8'))
        # print(res.GetSize())

    else:
        label = resample_sitk_image(label, spacing=image.GetSpacing(), interpolator='linear')
        res = resize(label, image_shape, sitk.sitkLinear)
        res = (np.rint(sitk.GetArrayFromImage(res)))
        res = sitk.GetImageFromArray(res.astype('float'))

    res.SetDirection(image.GetDirection())
    res.SetOrigin(image.GetOrigin())
    res.SetSpacing(image.GetSpacing())

    return image, res


def uniform_img_dimensions_internal(image, label, nearest):
    # same as uniform_img_dimensions, but works on file paths and overwrites the label file in place
    name_label = label

    image = sitk.ReadImage(image)
    label = sitk.ReadImage(label)
    image_array = sitk.GetArrayFromImage(image)
    image_array = np.transpose(image_array, axes=(2, 1, 0))  # reshape array from itk z,y,x to x,y,z
    image_shape = image_array.shape

    if nearest is True:
        label = resample_sitk_image(label, spacing=image.GetSpacing(), interpolator='nearest')
        res = resize(label, image_shape, sitk.sitkNearestNeighbor)
        res = (np.rint(sitk.GetArrayFromImage(res)))
        res = sitk.GetImageFromArray(res.astype('uint8'))
        # print(res.GetSize())

    else:
        label = resample_sitk_image(label, spacing=image.GetSpacing(), interpolator='linear')
        res = resize(label, image_shape, sitk.sitkLinear)
        res = (np.rint(sitk.GetArrayFromImage(res)))
        res = sitk.GetImageFromArray(res.astype('float'))

    res.SetDirection(image.GetDirection())
    res.SetOrigin(image.GetOrigin())
    res.SetSpacing(image.GetSpacing())

    sitk.WriteImage(res, name_label)


def normalize_PET(image_itk, value):
    # divide the PET intensities by a reference value (e.g. the mean uptake in a reference region)
    image_np = sitk.GetArrayFromImage(image_itk)
    image_np = image_np / value
    image = sitk.GetImageFromArray(image_np)
    image.SetDirection(image_itk.GetDirection())
    image.SetSpacing(image_itk.GetSpacing())
    image.SetOrigin(image_itk.GetOrigin())
    return image


def processing_itk(label_CT, image_PET, label_PET, gluteus, new_resolution, patch_size):

    gluteus = sitk.ReadImage(gluteus)
    label_CT = sitk.ReadImage(label_CT)
    image_PET = sitk.ReadImage(image_PET)

    if label_PET is not None:
        label_PET = sitk.ReadImage(label_PET)

    if new_resolution is not None:
        image_PET = resample_sitk_image(image_PET, spacing=new_resolution, interpolator='linear')

    label_CT = Padding(label_CT, image_PET)
    gluteus = Padding(gluteus, image_PET)
    image_PET, label_CT = uniform_img_dimensions(image_PET, label_CT, True)
    image_PET, gluteus = uniform_img_dimensions(image_PET, gluteus, True)

    # PET tumour/background normalization: divide by the mean uptake inside the gluteus ROI
    gluteos_ROI_array = sitk.GetArrayFromImage(gluteus)
    gluteos_ROI_index = np.where(gluteos_ROI_array == 1)
    PET_array = sitk.GetArrayFromImage(image_PET)
    avg = np.mean(PET_array[gluteos_ROI_index])
    image_PET = normalize_PET(image_PET, avg)
    # end normalization

    if label_PET is not None:
        label_PET = Padding(label_PET, image_PET)
        image_PET, label_PET = uniform_img_dimensions(image_PET, label_PET, True)

    label_CT_array = sitk.GetArrayFromImage(label_CT)

    minx, maxx, miny, maxy, minz, maxz = crop_window(label_CT_array)

    roiFilter = sitk.RegionOfInterestImageFilter()
    roiFilter.SetSize(patch_size)
    roiFilter.SetIndex([int(minz), int(miny), int(minx)])

    label_CT = roiFilter.Execute(label_CT)
    image_PET = roiFilter.Execute(image_PET)

    if label_PET is not None:
        label_PET = roiFilter.Execute(label_PET)
    else:
        label_PET = None

    sitk.WriteImage(label_CT, 'mask_crop.nii')
    sitk.WriteImage(image_PET, 'result.nii')

    if label_PET is not None:
        sitk.WriteImage(label_PET, 'label_crop.nii')
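

# Usage sketch (not part of the original code): run the PET/CT preprocessing for one case and
# write mask_crop.nii / result.nii (and label_crop.nii when a PET label is given). All file
# names and the numeric settings below are hypothetical.
def _example_preprocess_case():
    processing_itk(label_CT='prostate_mask_ct.nii.gz',
                   image_PET='pet.nii.gz',
                   label_PET=None,
                   gluteus='gluteus_mask.nii.gz',
                   new_resolution=(2.0, 2.0, 2.0),
                   patch_size=(64, 64, 64))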


def gaussian2(image):
    # rescale, Gaussian-smooth (sigma=3) and binarize the image
    resacleFilter = sitk.RescaleIntensityImageFilter()
    resacleFilter.SetOutputMaximum(255)
    resacleFilter.SetOutputMinimum(0)
    image = resacleFilter.Execute(image)  # set intensity 0-255

    gaussianFilter = sitk.SmoothingRecursiveGaussianImageFilter()
    gaussianFilter.SetSigma(3)
    image = gaussianFilter.Execute(image)

    resacleFilter = sitk.RescaleIntensityImageFilter()
    resacleFilter.SetOutputMaximum(1)
    resacleFilter.SetOutputMinimum(0)
    image = resacleFilter.Execute(image)  # set intensity 0-1

    thresholdFilter = sitk.BinaryThresholdImageFilter()
    thresholdFilter.SetLowerThreshold(0.5)
    thresholdFilter.SetUpperThreshold(2)
    thresholdFilter.SetInsideValue(1)
    thresholdFilter.SetOutsideValue(0)
    image = thresholdFilter.Execute(image)

    return image
return image |