Diff of /utils.py [000000] .. [83198a]

Switch to unified view

a b/utils.py
1
import os
2
import re
3
import argparse
4
import numpy as np
5
import random
6
import monai
7
import time
8
# from networks import build_net
9
import logging
10
import os
11
import sys
12
import tempfile
13
from glob import glob
14
from ignite.metrics import Accuracy
15
import nibabel as nib
16
import torch
17
import argparse
18
from monai.data import CacheDataset, DataLoader, Dataset
19
import SimpleITK as sitk
20
from monai.inferers import sliding_window_inference
21
from monai.metrics import DiceMetric
22
from monai.data import NiftiSaver, create_test_image_3d, list_data_collate
23
from collections import OrderedDict
24
from monai.handlers import (MeanDice, StatsHandler, ValidationHandler, CheckpointSaver, LrScheduleHandler, CheckpointLoader,
25
                         SegmentationSaver, TensorBoardImageHandler, TensorBoardStatsHandler)
26
from monai.inferers import SimpleInferer, SlidingWindowInferer
27
from monai.utils import set_determinism
28
import re
29
from monai.data import create_test_image_3d, list_data_collate
30
from monai.inferers import sliding_window_inference
31
from monai.transforms import (Activationsd,MeanEnsembled, GaussianSmoothd, CropForegroundd, ThresholdIntensityd, Activations,AsDiscrete, LoadImaged, AsChannelFirstd, VoteEnsembled, AsDiscreted, Compose, AddChanneld, Transpose, ConcatItemsd,
32
                              ScaleIntensityd, Resized,ToTensord, RandSpatialCropd, Rand3DElasticd, RandAffined, RandGaussianSmoothd, SpatialPadd,
33
    Spacingd, Orientationd, RandShiftIntensityd, BorderPadd, RandGaussianNoised, RandAdjustContrastd,NormalizeIntensityd,RandFlipd, KeepLargestConnectedComponent)
34
35
from monai.engines import (
36
    EnsembleEvaluator,
37
    SupervisedEvaluator,
38
    SupervisedTrainer
39
)
40
41
from skimage.measure import label
42
def getLargestCC(segmentation):
43
    labels = label(segmentation)
44
    unique, counts = np.unique(labels, return_counts=True)
45
    list_seg=list(zip(unique, counts))[1:] # the 0 label is by default background so take the rest
46
    largest=max(list_seg, key=lambda x:x[1])[0]
47
    labels_max=(labels == largest).astype(int)
48
    return labels_max
49
50
51
def Padding(image, reference):
52
53
54
    size_new = reference.GetSize()
55
56
    output_size = tuple(size_new)
57
58
    resampler = sitk.ResampleImageFilter()
59
    resampler.SetOutputSpacing(reference.GetSpacing())
60
    resampler.SetSize(output_size)
61
62
    # resample on label
63
    resampler.SetInterpolator(sitk.sitkNearestNeighbor)
64
    resampler.SetOutputOrigin(reference.GetOrigin())
65
    resampler.SetOutputDirection(reference.GetDirection())
66
67
    image = resampler.Execute(image)
68
69
    return image
70
71
72
def resize(img, new_size, interpolator):
73
    # img = sitk.ReadImage(img)
74
    dimension = img.GetDimension()
75
76
    # Physical image size corresponds to the largest physical size in the training set, or any other arbitrary size.
77
    reference_physical_size = np.zeros(dimension)
78
79
    reference_physical_size[:] = [(sz - 1) * spc if sz * spc > mx else mx for sz, spc, mx in
80
                                  zip(img.GetSize(), img.GetSpacing(), reference_physical_size)]
81
82
    # Create the reference image with a zero origin, identity direction cosine matrix and dimension
83
    reference_origin = np.zeros(dimension)
84
    reference_direction = np.identity(dimension).flatten()
85
    reference_size = new_size
86
    reference_spacing = [phys_sz / (sz - 1) for sz, phys_sz in zip(reference_size, reference_physical_size)]
87
88
    reference_image = sitk.Image(reference_size, img.GetPixelIDValue())
89
    reference_image.SetOrigin(reference_origin)
90
    reference_image.SetSpacing(reference_spacing)
91
    reference_image.SetDirection(reference_direction)
92
93
    # Always use the TransformContinuousIndexToPhysicalPoint to compute an indexed point's physical coordinates as
94
    # this takes into account size, spacing and direction cosines. For the vast majority of images the direction
95
    # cosines are the identity matrix, but when this isn't the case simply multiplying the central index by the
96
    # spacing will not yield the correct coordinates resulting in a long debugging session.
97
    reference_center = np.array(
98
        reference_image.TransformContinuousIndexToPhysicalPoint(np.array(reference_image.GetSize()) / 2.0))
99
100
    # Transform which maps from the reference_image to the current img with the translation mapping the image
101
    # origins to each other.
102
    transform = sitk.AffineTransform(dimension)
103
    transform.SetMatrix(img.GetDirection())
104
    transform.SetTranslation(np.array(img.GetOrigin()) - reference_origin)
105
    # Modify the transformation to align the centers of the original and reference image instead of their origins.
106
    centering_transform = sitk.TranslationTransform(dimension)
107
    img_center = np.array(img.TransformContinuousIndexToPhysicalPoint(np.array(img.GetSize()) / 2.0))
108
    centering_transform.SetOffset(np.array(transform.GetInverse().TransformPoint(img_center) - reference_center))
109
110
    # centered_transform = sitk.Transform(transform)
111
    # centered_transform.AddTransform(centering_transform)
112
113
    centered_transform = sitk.CompositeTransform([transform, centering_transform])
114
115
    # Using the linear interpolator as these are intensity images, if there is a need to resample a ground truth
116
    # segmentation then the segmentation image should be resampled using the NearestNeighbor interpolator so that
117
    # no new labels are introduced.
118
119
    return sitk.Resample(img, reference_image, centered_transform, interpolator, 0.0)
120
121
122
def resample_sitk_image(sitk_image, spacing=None, interpolator=None, fill_value=0):
123
    # https://github.com/SimpleITK/SlicerSimpleFilters/blob/master/SimpleFilters/SimpleFilters.py
124
    _SITK_INTERPOLATOR_DICT = {
125
        'nearest': sitk.sitkNearestNeighbor,
126
        'linear': sitk.sitkLinear,
127
        'gaussian': sitk.sitkGaussian,
128
        'label_gaussian': sitk.sitkLabelGaussian,
129
        'bspline': sitk.sitkBSpline,
130
        'hamming_sinc': sitk.sitkHammingWindowedSinc,
131
        'cosine_windowed_sinc': sitk.sitkCosineWindowedSinc,
132
        'welch_windowed_sinc': sitk.sitkWelchWindowedSinc,
133
        'lanczos_windowed_sinc': sitk.sitkLanczosWindowedSinc
134
    }
135
136
    if isinstance(sitk_image, str):
137
        sitk_image = sitk.ReadImage(sitk_image)
138
    num_dim = sitk_image.GetDimension()
139
140
    if not interpolator:
141
        interpolator = 'linear'
142
        pixelid = sitk_image.GetPixelIDValue()
143
144
        if pixelid not in [1, 2, 4]:
145
            raise NotImplementedError(
146
                'Set `interpolator` manually, '
147
                'can only infer for 8-bit unsigned or 16, 32-bit signed integers')
148
        if pixelid == 1:  # 8-bit unsigned int
149
            interpolator = 'nearest'
150
151
    orig_pixelid = sitk_image.GetPixelIDValue()
152
    orig_origin = sitk_image.GetOrigin()
153
    orig_direction = sitk_image.GetDirection()
154
    orig_spacing = np.array(sitk_image.GetSpacing())
155
    orig_size = np.array(sitk_image.GetSize(), dtype=np.int)
156
157
    if not spacing:
158
        min_spacing = orig_spacing.min()
159
        new_spacing = [min_spacing] * num_dim
160
    else:
161
        new_spacing = [float(s) for s in spacing]
162
163
    assert interpolator in _SITK_INTERPOLATOR_DICT.keys(), \
164
        '`interpolator` should be one of {}'.format(_SITK_INTERPOLATOR_DICT.keys())
165
166
    sitk_interpolator = _SITK_INTERPOLATOR_DICT[interpolator]
167
168
    new_size = orig_size * (orig_spacing / new_spacing)
169
    new_size = np.ceil(new_size).astype(np.int)  # Image dimensions are in integers
170
    new_size = [int(s) for s in new_size]  # SimpleITK expects lists, not ndarrays
171
172
    resample_filter = sitk.ResampleImageFilter()
173
174
    resample_filter.SetOutputSpacing(new_spacing)
175
    resample_filter.SetSize(new_size)
176
    resample_filter.SetOutputDirection(orig_direction)
177
    resample_filter.SetOutputOrigin(orig_origin)
178
    resample_filter.SetTransform(sitk.Transform())
179
    resample_filter.SetDefaultPixelValue(orig_pixelid)
180
    resample_filter.SetInterpolator(sitk_interpolator)
181
    resample_filter.SetDefaultPixelValue(fill_value)
182
183
    resampled_sitk_image = resample_filter.Execute(sitk_image)
184
185
    return resampled_sitk_image
186
187
188
def numericalSort(value):
189
    numbers = re.compile(r'(\d+)')
190
    parts = numbers.split(value)
191
    parts[1::2] = map(int, parts[1::2])
192
    return parts
193
194
195
def lstFiles(Path):
196
197
    images_list = []  # create an empty list, the raw image data files is stored here
198
    for dirName, subdirList, fileList in os.walk(Path):
199
        for filename in fileList:
200
            if ".nii.gz" in filename.lower():
201
                images_list.append(os.path.join(dirName, filename))
202
            elif ".nii" in filename.lower():
203
                images_list.append(os.path.join(dirName, filename))
204
            elif ".mhd" in filename.lower():
205
                images_list.append(os.path.join(dirName, filename))
206
207
    images_list = sorted(images_list, key=numericalSort)
208
209
    return images_list
210
211
212
def new_state_dict(file_name):
213
    state_dict = torch.load(file_name)
214
    new_state_dict = OrderedDict()
215
    for k, v in state_dict.items():
216
        if k[:6] == 'module':
217
            name = k[7:]
218
            new_state_dict[name] = v
219
        else:
220
            new_state_dict[k] = v
221
    return new_state_dict
222
223
224
def new_state_dict_cpu(file_name):
225
    state_dict = torch.load(file_name, map_location='cpu')
226
    new_state_dict_cpu = OrderedDict()
227
    for k, v in state_dict.items():
228
        if k[:6] == 'module':
229
            name = k[7:]
230
            new_state_dict_cpu[name] = v
231
        else:
232
            new_state_dict_cpu[k] = v
233
    return new_state_dict_cpu
234
235
236
def from_numpy_to_itk(image_np, image_itk):
237
238
    # read image file
239
    reader = sitk.ImageFileReader()
240
    reader.SetFileName(image_itk)
241
    image_itk = reader.Execute()
242
243
    image_np = np.transpose(image_np, (2, 1, 0))
244
    image = sitk.GetImageFromArray(image_np)
245
    image.SetDirection(image_itk.GetDirection())
246
    image.SetSpacing(image_itk.GetSpacing())
247
    image.SetOrigin(image_itk.GetOrigin())
248
    return image
249
250
251
# function to keep track of the cropped area and coordinates
252
def statistics_crop(image, resolution):
253
254
    files = [{"image": image}]
255
256
    reader = sitk.ImageFileReader()
257
    reader.SetFileName(image)
258
    image_itk = reader.Execute()
259
    original_resolution = image_itk.GetSpacing()
260
261
    # original size
262
    transforms = Compose([
263
        LoadImaged(keys=['image']),
264
        AddChanneld(keys=['image']),
265
        ToTensord(keys=['image'])])
266
    data = monai.data.Dataset(data=files, transform=transforms)
267
    loader = DataLoader(data, batch_size=1, num_workers=0, pin_memory=torch.cuda.is_available())
268
    loader = monai.utils.misc.first(loader)
269
    im, = (loader['image'][0])
270
    vol = im.numpy()
271
    original_shape = vol.shape
272
273
    # cropped foreground size
274
    transforms = Compose([
275
        LoadImaged(keys=['image']),
276
        AddChanneld(keys=['image']),
277
        CropForegroundd(keys=['image'], source_key='image', start_coord_key='foreground_start_coord',
278
                        end_coord_key='foreground_end_coord', ),  # crop CropForeground
279
        ToTensord(keys=['image', 'foreground_start_coord', 'foreground_end_coord'])])
280
281
    data = monai.data.Dataset(data=files, transform=transforms)
282
    loader = DataLoader(data, batch_size=1, num_workers=0, pin_memory=torch.cuda.is_available())
283
    loader = monai.utils.misc.first(loader)
284
    im, coord1, coord2 = (loader['image'][0], loader['foreground_start_coord'][0], loader['foreground_end_coord'][0])
285
    vol = im[0].numpy()
286
    coord1 = coord1.numpy()
287
    coord2 = coord2.numpy()
288
    crop_shape = vol.shape
289
290
    if resolution is not None:
291
292
        transforms = Compose([
293
            LoadImaged(keys=['image']),
294
            AddChanneld(keys=['image']),
295
            CropForegroundd(keys=['image'], source_key='image'),  # crop CropForeground
296
            Spacingd(keys=['image'], pixdim=resolution, mode=('bilinear')),  # resolution
297
            ToTensord(keys=['image'])])
298
299
        data = monai.data.Dataset(data=files, transform=transforms)
300
        loader = DataLoader(data, batch_size=1, num_workers=0, pin_memory=torch.cuda.is_available())
301
        loader = monai.utils.misc.first(loader)
302
        im, = (loader['image'][0])
303
        vol = im.numpy()
304
        resampled_size = vol.shape
305
306
    else:
307
308
        resampled_size = original_shape
309
310
    return original_shape, crop_shape, coord1, coord2, resampled_size, original_resolution
311
312
313
def build_net_CT(patch_size,resolution):
314
315
    from monai.networks.layers import Norm
316
317
    sizes, spacings = patch_size, resolution
318
319
    strides, kernels = [], []
320
321
    while True:
322
        spacing_ratio = [sp / min(spacings) for sp in spacings]
323
        stride = [2 if ratio <= 2 and size >= 8 else 1 for (ratio, size) in zip(spacing_ratio, sizes)]
324
        kernel = [3 if ratio <= 2 else 1 for ratio in spacing_ratio]
325
        if all(s == 1 for s in stride):
326
            break
327
        sizes = [i / j for i, j in zip(sizes, stride)]
328
        spacings = [i * j for i, j in zip(spacings, stride)]
329
        kernels.append(kernel)
330
        strides.append(stride)
331
    strides.insert(0, len(spacings) * [1])
332
    kernels.append(len(spacings) * [3])
333
334
    # # create Unet
335
336
    nn_Unet = monai.networks.nets.DynUNet(
337
        spatial_dims=3,
338
        in_channels=1,
339
        out_channels=1,
340
        kernel_size=kernels,
341
        strides=strides,
342
        upsample_kernel_size=strides[1:],
343
        res_block=True,
344
    )
345
346
    return nn_Unet
347
348
349
def crop_window(prostate_contour):
350
    # Cut data, restricted to the prostate contours + a pitch per direction per dimension.
351
    """
352
    nrrd has the following format, assuming to watch the patient from the front:
353
    (x, y, z)
354
    x: left to right (ascending)
355
    y: front to back (ascending)
356
    z: bottom to top (ascending)
357
    """
358
    pitch = 5
359
    pattern = np.where(prostate_contour == 1)
360
361
    minx = np.min(pattern[0]) - pitch
362
    maxx = np.max(pattern[0]) + pitch
363
    miny = np.min(pattern[1]) - pitch
364
    maxy = np.max(pattern[1]) + pitch
365
    minz = np.min(pattern[2]) - pitch
366
    maxz = np.max(pattern[2]) + pitch
367
368
    if (maxx - minx) % 2 != 0:
369
        maxx += 1
370
    if (maxy - miny) % 2 != 0:
371
        maxy += 1
372
    if (maxz - minz) % 2 != 0:
373
        maxz += 1
374
375
    """
376
    Choose all tensors to have size of 64x64x64
377
    """
378
    limit = 32
379
380
    while maxx - minx < limit:
381
        maxx += 1
382
        minx -= 1
383
384
    while maxy - miny < limit:
385
        maxy += 1
386
        miny -= 1
387
388
    while maxz - minz < limit:
389
        maxz += 1
390
        minz -= 1
391
392
    return minx, maxx, miny, maxy, minz, maxz
393
394
395
def uniform_img_dimensions(image, label, nearest):
396
397
    image_array = sitk.GetArrayFromImage(image)
398
    image_array = np.transpose(image_array, axes=(2, 1, 0))  # reshape array from itk z,y,x  to  x,y,z
399
    image_shape = image_array.shape
400
401
    if nearest is True:
402
        label = resample_sitk_image(label, spacing=image.GetSpacing(), interpolator='nearest')
403
        res = resize(label,image_shape,sitk.sitkNearestNeighbor)
404
        res = (np.rint(sitk.GetArrayFromImage(res)))
405
        res = sitk.GetImageFromArray(res.astype('uint8'))
406
        # print(res.GetSize())
407
408
    else:
409
        label = resample_sitk_image(label, spacing=image.GetSpacing(), interpolator='linear')
410
        res = resize(label, image_shape, sitk.sitkLinear)
411
        res = (np.rint(sitk.GetArrayFromImage(res)))
412
        res = sitk.GetImageFromArray(res.astype('float'))
413
414
    res.SetDirection(image.GetDirection())
415
    res.SetOrigin(image.GetOrigin())
416
    res.SetSpacing(image.GetSpacing())
417
418
    return image, res
419
420
421
def uniform_img_dimensions_internal(image, label, nearest):
422
423
    name_label = label
424
425
    image = sitk.ReadImage(image)
426
    label = sitk.ReadImage(label)
427
    image_array = sitk.GetArrayFromImage(image)
428
    image_array = np.transpose(image_array, axes=(2, 1, 0))  # reshape array from itk z,y,x  to  x,y,z
429
    image_shape = image_array.shape
430
431
    if nearest is True:
432
        label = resample_sitk_image(label, spacing=image.GetSpacing(), interpolator='nearest')
433
        res = resize(label,image_shape,sitk.sitkNearestNeighbor)
434
        res = (np.rint(sitk.GetArrayFromImage(res)))
435
        res = sitk.GetImageFromArray(res.astype('uint8'))
436
        # print(res.GetSize())
437
438
    else:
439
        label = resample_sitk_image(label, spacing=image.GetSpacing(), interpolator='linear')
440
        res = resize(label, image_shape, sitk.sitkLinear)
441
        res = (np.rint(sitk.GetArrayFromImage(res)))
442
        res = sitk.GetImageFromArray(res.astype('float'))
443
444
    res.SetDirection(image.GetDirection())
445
    res.SetOrigin(image.GetOrigin())
446
    res.SetSpacing(image.GetSpacing())
447
448
    sitk.WriteImage(res, name_label)
449
450
451
def normalize_PET(image_itk, value):
452
453
    # read image file
454
    image_np = sitk.GetArrayFromImage(image_itk)
455
    image_np = image_np/value
456
    image = sitk.GetImageFromArray(image_np)
457
    image.SetDirection(image_itk.GetDirection())
458
    image.SetSpacing(image_itk.GetSpacing())
459
    image.SetOrigin(image_itk.GetOrigin())
460
    return image
461
462
463
def processing_itk(label_CT, image_PET, label_PET, gluteus, new_resolution, patch_size):
464
465
    gluteus = sitk.ReadImage(gluteus)
466
    label_CT = sitk.ReadImage(label_CT)
467
    image_PET = sitk.ReadImage(image_PET)
468
469
    if label_PET is not None:
470
        label_PET = sitk.ReadImage(label_PET)
471
472
    if new_resolution is not None:
473
        image_PET = resample_sitk_image(image_PET, spacing=new_resolution, interpolator='linear')
474
475
    label_CT = Padding(label_CT, image_PET)
476
    gluteus = Padding(gluteus, image_PET)
477
    image_PET, label_CT = uniform_img_dimensions(image_PET, label_CT, True)
478
    image_PET, gluteus = uniform_img_dimensions(image_PET, gluteus, True)
479
480
    # new part for Pet tumor_background normalization
481
482
    gluteos_ROI_array = sitk.GetArrayFromImage(gluteus)
483
    gluteos_ROI_index = np.where(gluteos_ROI_array == 1)
484
    PET_array = sitk.GetArrayFromImage(image_PET)
485
    avg = np.mean(PET_array[gluteos_ROI_index])
486
    image_PET = normalize_PET(image_PET, avg)
487
488
    # end normalization
489
490
    if label_PET is not None:
491
        label_PET = Padding(label_PET, image_PET)
492
        image_PET, label_PET = uniform_img_dimensions(image_PET, label_PET, True)
493
494
    label_CT_array = sitk.GetArrayFromImage(label_CT)
495
496
    minx, maxx, miny, maxy, minz, maxz = crop_window(label_CT_array)
497
498
    roiFilter = sitk.RegionOfInterestImageFilter()
499
    roiFilter.SetSize(patch_size)
500
    roiFilter.SetIndex([int(minz), int(miny), int(minx)])
501
502
    label_CT = roiFilter.Execute(label_CT)
503
    image_PET = roiFilter.Execute(image_PET)
504
505
    if label_PET is not None:
506
        label_PET = roiFilter.Execute(label_PET)
507
    else:
508
        label_PET = None
509
510
    sitk.WriteImage(label_CT, 'mask_crop.nii')
511
    sitk.WriteImage(image_PET, 'result.nii')
512
513
    if label_PET is not None:
514
515
        sitk.WriteImage(label_PET, 'label_crop.nii')
516
517
518
def gaussian2(image):
519
520
    resacleFilter = sitk.RescaleIntensityImageFilter()
521
    resacleFilter.SetOutputMaximum(255)
522
    resacleFilter.SetOutputMinimum(0)
523
    image = resacleFilter.Execute(image)  # set intensity 0-255
524
525
    gaussianFilter = sitk.SmoothingRecursiveGaussianImageFilter()
526
    gaussianFilter.SetSigma(3)
527
    image = gaussianFilter.Execute(image)
528
529
    resacleFilter = sitk.RescaleIntensityImageFilter()
530
    resacleFilter.SetOutputMaximum(1)
531
    resacleFilter.SetOutputMinimum(0)
532
    image = resacleFilter.Execute(image)  # set intensity 0-255
533
534
    thresholdFilter = sitk.BinaryThresholdImageFilter()
535
    thresholdFilter.SetLowerThreshold(0.5)
536
    thresholdFilter.SetUpperThreshold(2)
537
    thresholdFilter.SetInsideValue(1)
538
    thresholdFilter.SetOutsideValue(0)
539
    image = thresholdFilter.Execute(image)
540
541
    return image