Diff of /monai 0.5.0/utils.py [000000] .. [83198a]


--- a
+++ b/monai 0.5.0/utils.py
@@ -0,0 +1,541 @@
+import argparse
+import logging
+import os
+import random
+import re
+import sys
+import tempfile
+import time
+from collections import OrderedDict
+from glob import glob
+
+import nibabel as nib
+import numpy as np
+import SimpleITK as sitk
+import torch
+from ignite.metrics import Accuracy
+from skimage.measure import label
+
+import monai
+from monai.data import CacheDataset, DataLoader, Dataset, NiftiSaver, create_test_image_3d, list_data_collate
+from monai.engines import EnsembleEvaluator, SupervisedEvaluator, SupervisedTrainer
+from monai.handlers import (CheckpointLoader, CheckpointSaver, LrScheduleHandler, MeanDice, SegmentationSaver,
+                            StatsHandler, TensorBoardImageHandler, TensorBoardStatsHandler, ValidationHandler)
+from monai.inferers import SimpleInferer, SlidingWindowInferer, sliding_window_inference
+from monai.metrics import DiceMetric
+from monai.transforms import (Activations, Activationsd, AddChanneld, AsChannelFirstd, AsDiscrete, AsDiscreted,
+                              BorderPadd, Compose, ConcatItemsd, CropForegroundd, GaussianSmoothd,
+                              KeepLargestConnectedComponent, LoadImaged, MeanEnsembled, NormalizeIntensityd,
+                              Orientationd, Rand3DElasticd, RandAdjustContrastd, RandAffined, RandFlipd,
+                              RandGaussianNoised, RandGaussianSmoothd, RandShiftIntensityd, RandSpatialCropd,
+                              Resized, ScaleIntensityd, Spacingd, SpatialPadd, ThresholdIntensityd, ToTensord,
+                              Transpose, VoteEnsembled)
+from monai.utils import set_determinism
+
+# from networks import build_net
+
+
+def getLargestCC(segmentation):
+    """Return a binary mask containing only the largest connected component of `segmentation`."""
+    labels = label(segmentation)
+    unique, counts = np.unique(labels, return_counts=True)
+    list_seg = list(zip(unique, counts))[1:]  # label 0 is background, so keep only the foreground components
+    largest = max(list_seg, key=lambda x: x[1])[0]
+    labels_max = (labels == largest).astype(int)
+    return labels_max
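+
+# Usage sketch (illustrative only, not called elsewhere in this module; `probability_map` is a
+# hypothetical model output): keep only the largest connected component of a binary prediction:
+#
+#     pred = (probability_map > 0.5).astype(np.uint8)
+#     pred_clean = getLargestCC(pred)   # binary mask of the largest component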
+
+
+def Padding(image, reference):
+    """Resample `image` onto the grid (size, spacing, origin, direction) of `reference`."""
+
+    output_size = tuple(reference.GetSize())
+
+    resampler = sitk.ResampleImageFilter()
+    resampler.SetOutputSpacing(reference.GetSpacing())
+    resampler.SetSize(output_size)
+
+    # nearest-neighbour interpolation so that label values are preserved
+    resampler.SetInterpolator(sitk.sitkNearestNeighbor)
+    resampler.SetOutputOrigin(reference.GetOrigin())
+    resampler.SetOutputDirection(reference.GetDirection())
+
+    image = resampler.Execute(image)
+
+    return image
+
+
+def resize(img, new_size, interpolator):
+    """Resample `img` onto a `new_size` reference grid (zero origin, identity direction), keeping the
+    image centred so that its physical content is preserved."""
+    # img = sitk.ReadImage(img)
+    dimension = img.GetDimension()
+
+    # Physical image size corresponds to the largest physical size in the training set, or any other arbitrary size.
+    reference_physical_size = np.zeros(dimension)
+
+    reference_physical_size[:] = [(sz - 1) * spc if sz * spc > mx else mx for sz, spc, mx in
+                                  zip(img.GetSize(), img.GetSpacing(), reference_physical_size)]
+
+    # Create the reference image with a zero origin, identity direction cosine matrix and dimension
+    reference_origin = np.zeros(dimension)
+    reference_direction = np.identity(dimension).flatten()
+    reference_size = new_size
+    reference_spacing = [phys_sz / (sz - 1) for sz, phys_sz in zip(reference_size, reference_physical_size)]
+
+    reference_image = sitk.Image(reference_size, img.GetPixelIDValue())
+    reference_image.SetOrigin(reference_origin)
+    reference_image.SetSpacing(reference_spacing)
+    reference_image.SetDirection(reference_direction)
+
+    # Always use the TransformContinuousIndexToPhysicalPoint to compute an indexed point's physical coordinates as
+    # this takes into account size, spacing and direction cosines. For the vast majority of images the direction
+    # cosines are the identity matrix, but when this isn't the case simply multiplying the central index by the
+    # spacing will not yield the correct coordinates resulting in a long debugging session.
+    reference_center = np.array(
+        reference_image.TransformContinuousIndexToPhysicalPoint(np.array(reference_image.GetSize()) / 2.0))
+
+    # Transform which maps from the reference_image to the current img with the translation mapping the image
+    # origins to each other.
+    transform = sitk.AffineTransform(dimension)
+    transform.SetMatrix(img.GetDirection())
+    transform.SetTranslation(np.array(img.GetOrigin()) - reference_origin)
+    # Modify the transformation to align the centers of the original and reference image instead of their origins.
+    centering_transform = sitk.TranslationTransform(dimension)
+    img_center = np.array(img.TransformContinuousIndexToPhysicalPoint(np.array(img.GetSize()) / 2.0))
+    centering_transform.SetOffset(np.array(transform.GetInverse().TransformPoint(img_center) - reference_center))
+
+    # centered_transform = sitk.Transform(transform)
+    # centered_transform.AddTransform(centering_transform)
+
+    centered_transform = sitk.CompositeTransform([transform, centering_transform])
+
+    # Using the linear interpolator as these are intensity images, if there is a need to resample a ground truth
+    # segmentation then the segmentation image should be resampled using the NearestNeighbor interpolator so that
+    # no new labels are introduced.
+
+    return sitk.Resample(img, reference_image, centered_transform, interpolator, 0.0)
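+
+# Usage sketch (illustrative, with hypothetical `img` and `seg` SimpleITK images): resize an intensity
+# image and its segmentation to 128^3 voxels, using a linear interpolator for the image and
+# nearest-neighbour for the label so that no new label values are introduced:
+#
+#     img_small = resize(img, [128, 128, 128], sitk.sitkLinear)
+#     seg_small = resize(seg, [128, 128, 128], sitk.sitkNearestNeighbor)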
+
+
+def resample_sitk_image(sitk_image, spacing=None, interpolator=None, fill_value=0):
+    """Resample a SimpleITK image (or a path to one) to a new voxel `spacing`; by default the image is
+    made isotropic at the smallest original spacing, using the named `interpolator`."""
+    # https://github.com/SimpleITK/SlicerSimpleFilters/blob/master/SimpleFilters/SimpleFilters.py
+    _SITK_INTERPOLATOR_DICT = {
+        'nearest': sitk.sitkNearestNeighbor,
+        'linear': sitk.sitkLinear,
+        'gaussian': sitk.sitkGaussian,
+        'label_gaussian': sitk.sitkLabelGaussian,
+        'bspline': sitk.sitkBSpline,
+        'hamming_sinc': sitk.sitkHammingWindowedSinc,
+        'cosine_windowed_sinc': sitk.sitkCosineWindowedSinc,
+        'welch_windowed_sinc': sitk.sitkWelchWindowedSinc,
+        'lanczos_windowed_sinc': sitk.sitkLanczosWindowedSinc
+    }
+
+    if isinstance(sitk_image, str):
+        sitk_image = sitk.ReadImage(sitk_image)
+    num_dim = sitk_image.GetDimension()
+
+    if not interpolator:
+        interpolator = 'linear'
+        pixelid = sitk_image.GetPixelIDValue()
+
+        if pixelid not in [1, 2, 4]:
+            raise NotImplementedError(
+                'Set `interpolator` manually, '
+                'can only infer for 8-bit unsigned or 16, 32-bit signed integers')
+        if pixelid == 1:  # 8-bit unsigned int
+            interpolator = 'nearest'
+
+    orig_pixelid = sitk_image.GetPixelIDValue()
+    orig_origin = sitk_image.GetOrigin()
+    orig_direction = sitk_image.GetDirection()
+    orig_spacing = np.array(sitk_image.GetSpacing())
+    orig_size = np.array(sitk_image.GetSize(), dtype=int)
+
+    if not spacing:
+        min_spacing = orig_spacing.min()
+        new_spacing = [min_spacing] * num_dim
+    else:
+        new_spacing = [float(s) for s in spacing]
+
+    assert interpolator in _SITK_INTERPOLATOR_DICT.keys(), \
+        '`interpolator` should be one of {}'.format(_SITK_INTERPOLATOR_DICT.keys())
+
+    sitk_interpolator = _SITK_INTERPOLATOR_DICT[interpolator]
+
+    new_size = orig_size * (orig_spacing / new_spacing)
+    new_size = np.ceil(new_size).astype(int)  # image dimensions are integers
+    new_size = [int(s) for s in new_size]  # SimpleITK expects lists, not ndarrays
+
+    resample_filter = sitk.ResampleImageFilter()
+
+    resample_filter.SetOutputSpacing(new_spacing)
+    resample_filter.SetSize(new_size)
+    resample_filter.SetOutputDirection(orig_direction)
+    resample_filter.SetOutputOrigin(orig_origin)
+    resample_filter.SetTransform(sitk.Transform())
+    resample_filter.SetOutputPixelType(orig_pixelid)  # keep the original pixel type
+    resample_filter.SetInterpolator(sitk_interpolator)
+    resample_filter.SetDefaultPixelValue(fill_value)
+
+    resampled_sitk_image = resample_filter.Execute(sitk_image)
+
+    return resampled_sitk_image
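+
+# Usage sketch (illustrative, hypothetical file path): resample a PET volume to isotropic 2 mm spacing
+# with linear interpolation; a label map would use interpolator='nearest' instead:
+#
+#     pet_iso = resample_sitk_image('pet.nii.gz', spacing=(2.0, 2.0, 2.0), interpolator='linear')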
+
+
+def numericalSort(value):
+    """Sort key that orders strings by the integer values embedded in them (natural sort)."""
+    numbers = re.compile(r'(\d+)')
+    parts = numbers.split(value)
+    parts[1::2] = map(int, parts[1::2])
+    return parts
+
+
+def lstFiles(Path):
+    """Recursively collect all .nii.gz, .nii and .mhd files under `Path`, naturally sorted."""
+
+    images_list = []  # the raw image data files are stored here
+    for dirName, subdirList, fileList in os.walk(Path):
+        for filename in fileList:
+            if filename.lower().endswith((".nii.gz", ".nii", ".mhd")):
+                images_list.append(os.path.join(dirName, filename))
+
+    images_list = sorted(images_list, key=numericalSort)
+
+    return images_list
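+
+# Usage sketch (illustrative, hypothetical folders): collect image/label paths, relying on
+# numericalSort so that 'case_2.nii' comes before 'case_10.nii' (plain string sorting would not):
+#
+#     images = lstFiles('./data/images')
+#     labels = lstFiles('./data/labels')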
+
+
+def new_state_dict(file_name):
+    """Load a checkpoint and strip the 'module.' prefix that torch.nn.DataParallel adds to parameter names."""
+    state_dict = torch.load(file_name)
+    new_state_dict = OrderedDict()
+    for k, v in state_dict.items():
+        if k.startswith('module.'):
+            name = k[len('module.'):]
+            new_state_dict[name] = v
+        else:
+            new_state_dict[k] = v
+    return new_state_dict
+
+
+def new_state_dict_cpu(file_name):
+    """Same as `new_state_dict`, but maps the loaded tensors to CPU memory."""
+    state_dict = torch.load(file_name, map_location='cpu')
+    new_state_dict_cpu = OrderedDict()
+    for k, v in state_dict.items():
+        if k.startswith('module.'):
+            name = k[len('module.'):]
+            new_state_dict_cpu[name] = v
+        else:
+            new_state_dict_cpu[k] = v
+    return new_state_dict_cpu
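+
+# Usage sketch (illustrative, hypothetical checkpoint path; assumes the file stores a raw state_dict
+# saved from a DataParallel-wrapped model): load the weights into a plain, single-device network:
+#
+#     net.load_state_dict(new_state_dict_cpu('best_metric_model.pth'))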
+
+
+def from_numpy_to_itk(image_np, image_itk):
+    """Convert a numpy array in (x, y, z) order to a SimpleITK image, copying the spatial metadata
+    from the reference image file `image_itk`."""
+
+    # read the reference image file
+    reader = sitk.ImageFileReader()
+    reader.SetFileName(image_itk)
+    image_itk = reader.Execute()
+
+    image_np = np.transpose(image_np, (2, 1, 0))  # (x, y, z) -> (z, y, x) as expected by GetImageFromArray
+    image = sitk.GetImageFromArray(image_np)
+    image.SetDirection(image_itk.GetDirection())
+    image.SetSpacing(image_itk.GetSpacing())
+    image.SetOrigin(image_itk.GetOrigin())
+    return image
+
+
+# function to keep track of the cropped area and coordinates
+def statistics_crop(image, resolution):
+
+    files = [{"image": image}]
+
+    reader = sitk.ImageFileReader()
+    reader.SetFileName(image)
+    image_itk = reader.Execute()
+    original_resolution = image_itk.GetSpacing()
+
+    # original size
+    transforms = Compose([
+        LoadImaged(keys=['image']),
+        AddChanneld(keys=['image']),
+        ToTensord(keys=['image'])])
+    data = monai.data.Dataset(data=files, transform=transforms)
+    loader = DataLoader(data, batch_size=1, num_workers=0, pin_memory=torch.cuda.is_available())
+    loader = monai.utils.misc.first(loader)
+    im = loader['image'][0]  # shape (1, D, H, W) after AddChanneld
+    vol = im[0].numpy()
+    original_shape = vol.shape
+
+    # cropped foreground size
+    transforms = Compose([
+        LoadImaged(keys=['image']),
+        AddChanneld(keys=['image']),
+        CropForegroundd(keys=['image'], source_key='image', start_coord_key='foreground_start_coord',
+                        end_coord_key='foreground_end_coord'),  # crop to the foreground bounding box and record its coordinates
+        ToTensord(keys=['image', 'foreground_start_coord', 'foreground_end_coord'])])
+
+    data = monai.data.Dataset(data=files, transform=transforms)
+    loader = DataLoader(data, batch_size=1, num_workers=0, pin_memory=torch.cuda.is_available())
+    loader = monai.utils.misc.first(loader)
+    im, coord1, coord2 = (loader['image'][0], loader['foreground_start_coord'][0], loader['foreground_end_coord'][0])
+    vol = im[0].numpy()
+    coord1 = coord1.numpy()
+    coord2 = coord2.numpy()
+    crop_shape = vol.shape
+
+    if resolution is not None:
+
+        transforms = Compose([
+            LoadImaged(keys=['image']),
+            AddChanneld(keys=['image']),
+            CropForegroundd(keys=['image'], source_key='image'),  # crop to the foreground bounding box
+            Spacingd(keys=['image'], pixdim=resolution, mode='bilinear'),  # resample to the target resolution
+            ToTensord(keys=['image'])])
+
+        data = monai.data.Dataset(data=files, transform=transforms)
+        loader = DataLoader(data, batch_size=1, num_workers=0, pin_memory=torch.cuda.is_available())
+        loader = monai.utils.misc.first(loader)
+        im = loader['image'][0]  # shape (1, D, H, W) after AddChanneld
+        vol = im[0].numpy()
+        resampled_size = vol.shape
+
+    else:
+
+        resampled_size = original_shape
+
+    return original_shape, crop_shape, coord1, coord2, resampled_size, original_resolution
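+
+# Usage sketch (illustrative, hypothetical file path): inspect how much a scan is cropped and
+# resampled before inference:
+#
+#     orig_shape, crop_shape, start, end, resampled_shape, orig_spacing = statistics_crop(
+#         'image.nii.gz', resolution=(2.0, 2.0, 2.0))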
+
+
+def build_net_CT(patch_size, resolution):
+    """Build a 3D DynUNet whose kernel sizes and strides are derived from the patch size and voxel
+    spacing, following the nnU-Net heuristic: downsample (stride 2) along an axis while its spacing is
+    at most twice the finest spacing and its size is still at least 8 voxels."""
+
+    sizes, spacings = patch_size, resolution
+
+    strides, kernels = [], []
+
+    while True:
+        spacing_ratio = [sp / min(spacings) for sp in spacings]
+        stride = [2 if ratio <= 2 and size >= 8 else 1 for (ratio, size) in zip(spacing_ratio, sizes)]
+        kernel = [3 if ratio <= 2 else 1 for ratio in spacing_ratio]
+        if all(s == 1 for s in stride):
+            break
+        sizes = [i / j for i, j in zip(sizes, stride)]
+        spacings = [i * j for i, j in zip(spacings, stride)]
+        kernels.append(kernel)
+        strides.append(stride)
+    strides.insert(0, len(spacings) * [1])
+    kernels.append(len(spacings) * [3])
+
+    # create the DynUNet
+
+    nn_Unet = monai.networks.nets.DynUNet(
+        spatial_dims=3,
+        in_channels=1,
+        out_channels=1,
+        kernel_size=kernels,
+        strides=strides,
+        upsample_kernel_size=strides[1:],
+        res_block=True,
+    )
+
+    return nn_Unet
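+
+# Worked example (illustrative, hypothetical inputs): build_net_CT((96, 96, 24), (1.0, 1.0, 4.0))
+# derives the following layer configuration from the heuristic above:
+#
+#     kernels = [[3, 3, 1], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3]]
+#     strides = [[1, 1, 1], [2, 2, 1], [2, 2, 2], [2, 2, 2], [2, 2, 1]]
+#
+# i.e. the coarse z axis (4 mm vs 1 mm in-plane) is neither downsampled nor convolved across slices
+# in the first block, and an axis stops being downsampled once its current size falls below 8.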
+
+
+def crop_window(prostate_contour):
+    # Compute a crop window restricted to the prostate contour plus a margin (`pitch`) on each side of every dimension.
+    """
+    nrrd has the following format, assuming the patient is viewed from the front:
+    (x, y, z)
+    x: left to right (ascending)
+    y: front to back (ascending)
+    z: bottom to top (ascending)
+    """
+    pitch = 5
+    pattern = np.where(prostate_contour == 1)
+
+    minx = np.min(pattern[0]) - pitch
+    maxx = np.max(pattern[0]) + pitch
+    miny = np.min(pattern[1]) - pitch
+    maxy = np.max(pattern[1]) + pitch
+    minz = np.min(pattern[2]) - pitch
+    maxz = np.max(pattern[2]) + pitch
+
+    if (maxx - minx) % 2 != 0:
+        maxx += 1
+    if (maxy - miny) % 2 != 0:
+        maxy += 1
+    if (maxz - minz) % 2 != 0:
+        maxz += 1
+
+    """
+    Choose all tensors to have size of 64x64x64
+    """
+    limit = 32
+
+    while maxx - minx < limit:
+        maxx += 1
+        minx -= 1
+
+    while maxy - miny < limit:
+        maxy += 1
+        miny -= 1
+
+    while maxz - minz < limit:
+        maxz += 1
+        minz -= 1
+
+    return minx, maxx, miny, maxy, minz, maxz
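+
+# Usage sketch (illustrative, hypothetical arrays): compute a crop window from a binary prostate mask
+# and use it to slice an image array that shares the same axis layout (the caller processing_itk below
+# instead passes the bounds to a SimpleITK RegionOfInterestImageFilter):
+#
+#     minx, maxx, miny, maxy, minz, maxz = crop_window(mask_array)
+#     cropped = image_array[minx:maxx, miny:maxy, minz:maxz]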
+
+
+def uniform_img_dimensions(image, label, nearest):
+    """Resample `label` onto the grid of `image` so that both share size, spacing, origin and direction.
+    Use nearest-neighbour interpolation when `nearest` is True (segmentation labels), linear otherwise."""
+
+    image_array = sitk.GetArrayFromImage(image)
+    image_array = np.transpose(image_array, axes=(2, 1, 0))  # reshape array from itk z,y,x to x,y,z
+    image_shape = image_array.shape
+
+    if nearest:
+        label = resample_sitk_image(label, spacing=image.GetSpacing(), interpolator='nearest')
+        res = resize(label, image_shape, sitk.sitkNearestNeighbor)
+        res = np.rint(sitk.GetArrayFromImage(res))
+        res = sitk.GetImageFromArray(res.astype('uint8'))
+
+    else:
+        label = resample_sitk_image(label, spacing=image.GetSpacing(), interpolator='linear')
+        res = resize(label, image_shape, sitk.sitkLinear)
+        res = np.rint(sitk.GetArrayFromImage(res))
+        res = sitk.GetImageFromArray(res.astype('float'))
+
+    res.SetDirection(image.GetDirection())
+    res.SetOrigin(image.GetOrigin())
+    res.SetSpacing(image.GetSpacing())
+
+    return image, res
+
+
+def uniform_img_dimensions_internal(image, label, nearest):
+    """Same as `uniform_img_dimensions`, but `image` and `label` are file paths and the resampled
+    label is written back to the label file."""
+
+    name_label = label
+
+    image = sitk.ReadImage(image)
+    label = sitk.ReadImage(label)
+    image_array = sitk.GetArrayFromImage(image)
+    image_array = np.transpose(image_array, axes=(2, 1, 0))  # reshape array from itk z,y,x to x,y,z
+    image_shape = image_array.shape
+
+    if nearest:
+        label = resample_sitk_image(label, spacing=image.GetSpacing(), interpolator='nearest')
+        res = resize(label, image_shape, sitk.sitkNearestNeighbor)
+        res = np.rint(sitk.GetArrayFromImage(res))
+        res = sitk.GetImageFromArray(res.astype('uint8'))
+
+    else:
+        label = resample_sitk_image(label, spacing=image.GetSpacing(), interpolator='linear')
+        res = resize(label, image_shape, sitk.sitkLinear)
+        res = np.rint(sitk.GetArrayFromImage(res))
+        res = sitk.GetImageFromArray(res.astype('float'))
+
+    res.SetDirection(image.GetDirection())
+    res.SetOrigin(image.GetOrigin())
+    res.SetSpacing(image.GetSpacing())
+
+    sitk.WriteImage(res, name_label)
+
+
+def normalize_PET(image_itk, value):
+    """Divide the PET intensities by `value` (e.g. the mean uptake in a reference region),
+    preserving the spatial metadata."""
+
+    image_np = sitk.GetArrayFromImage(image_itk)
+    image_np = image_np / value
+    image = sitk.GetImageFromArray(image_np)
+    image.SetDirection(image_itk.GetDirection())
+    image.SetSpacing(image_itk.GetSpacing())
+    image.SetOrigin(image_itk.GetOrigin())
+    return image
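+
+# Usage sketch (illustrative, hypothetical `pet_image` and `roi_mask` SimpleITK images): normalize a
+# PET image by the mean uptake inside a reference ROI, as done in processing_itk below with the gluteus mask:
+#
+#     roi = sitk.GetArrayFromImage(roi_mask) == 1
+#     mean_uptake = np.mean(sitk.GetArrayFromImage(pet_image)[roi])
+#     pet_norm = normalize_PET(pet_image, mean_uptake)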
+
+
+def processing_itk(label_CT, image_PET, label_PET, gluteus, new_resolution, patch_size):
+    """Preprocess one PET/CT case: resample the PET, align the CT label, gluteus ROI and (optional) PET
+    label to the PET grid, normalize the PET by the mean gluteus uptake, crop a patch around the CT label
+    and write 'mask_crop.nii', 'result.nii' and optionally 'label_crop.nii' to the working directory."""
+
+    gluteus = sitk.ReadImage(gluteus)
+    label_CT = sitk.ReadImage(label_CT)
+    image_PET = sitk.ReadImage(image_PET)
+
+    if label_PET is not None:
+        label_PET = sitk.ReadImage(label_PET)
+
+    if new_resolution is not None:
+        image_PET = resample_sitk_image(image_PET, spacing=new_resolution, interpolator='linear')
+
+    label_CT = Padding(label_CT, image_PET)
+    gluteus = Padding(gluteus, image_PET)
+    image_PET, label_CT = uniform_img_dimensions(image_PET, label_CT, True)
+    image_PET, gluteus = uniform_img_dimensions(image_PET, gluteus, True)
+
+    # PET tumor-to-background normalization: divide the PET by the mean uptake in the gluteus ROI
+
+    gluteos_ROI_array = sitk.GetArrayFromImage(gluteus)
+    gluteos_ROI_index = np.where(gluteos_ROI_array == 1)
+    PET_array = sitk.GetArrayFromImage(image_PET)
+    avg = np.mean(PET_array[gluteos_ROI_index])
+    image_PET = normalize_PET(image_PET, avg)
+
+    # end normalization
+
+    if label_PET is not None:
+        label_PET = Padding(label_PET, image_PET)
+        image_PET, label_PET = uniform_img_dimensions(image_PET, label_PET, True)
+
+    label_CT_array = sitk.GetArrayFromImage(label_CT)
+
+    minx, maxx, miny, maxy, minz, maxz = crop_window(label_CT_array)
+
+    roiFilter = sitk.RegionOfInterestImageFilter()
+    roiFilter.SetSize(patch_size)
+    # crop_window works on the numpy (z, y, x) array, so the bounds are reordered here
+    # to match the SimpleITK (x, y, z) index convention
+    roiFilter.SetIndex([int(minz), int(miny), int(minx)])
+
+    label_CT = roiFilter.Execute(label_CT)
+    image_PET = roiFilter.Execute(image_PET)
+
+    if label_PET is not None:
+        label_PET = roiFilter.Execute(label_PET)
+
+    sitk.WriteImage(label_CT, 'mask_crop.nii')
+    sitk.WriteImage(image_PET, 'result.nii')
+
+    if label_PET is not None:
+
+        sitk.WriteImage(label_PET, 'label_crop.nii')
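+
+# Usage sketch (illustrative, hypothetical file paths): crop and normalize one PET/CT case with a
+# 64x64x64 patch after resampling the PET to 2 mm isotropic spacing:
+#
+#     processing_itk('ct_label.nii', 'pet.nii', None, 'gluteus_roi.nii',
+#                    new_resolution=(2.0, 2.0, 2.0), patch_size=[64, 64, 64])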
+
+
+def gaussian2(image):
+    """Smooth an image with a recursive Gaussian filter (sigma=3) and binarise the result by
+    thresholding the rescaled output at 0.5."""
+
+    rescaleFilter = sitk.RescaleIntensityImageFilter()
+    rescaleFilter.SetOutputMaximum(255)
+    rescaleFilter.SetOutputMinimum(0)
+    image = rescaleFilter.Execute(image)  # rescale intensities to [0, 255]
+
+    gaussianFilter = sitk.SmoothingRecursiveGaussianImageFilter()
+    gaussianFilter.SetSigma(3)
+    image = gaussianFilter.Execute(image)
+
+    rescaleFilter = sitk.RescaleIntensityImageFilter()
+    rescaleFilter.SetOutputMaximum(1)
+    rescaleFilter.SetOutputMinimum(0)
+    image = rescaleFilter.Execute(image)  # rescale intensities to [0, 1]
+
+    thresholdFilter = sitk.BinaryThresholdImageFilter()
+    thresholdFilter.SetLowerThreshold(0.5)
+    thresholdFilter.SetUpperThreshold(2)
+    thresholdFilter.SetInsideValue(1)
+    thresholdFilter.SetOutsideValue(0)
+    image = thresholdFilter.Execute(image)
+
+    return image
\ No newline at end of file