#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Segment a single NIfTI image with a trained nnU-Net / UNETR model.

Loads one image (optionally with its ground-truth label for Dice scoring),
runs sliding-window inference, resamples the prediction back to the original
resolution/geometry, and saves the resulting segmentation as a NIfTI file.
"""

from utils import *
import argparse

from networks import build_net, build_UNETR
from monai.inferers import sliding_window_inference
from monai.metrics import DiceMetric
from monai.data import NiftiSaver, create_test_image_3d, list_data_collate, decollate_batch
from monai.transforms import (EnsureType, Compose, LoadImaged, AddChanneld, Transpose, Activations,
                              AsDiscrete, RandGaussianSmoothd, CropForegroundd, SpatialPadd,
                              ScaleIntensityd, ToTensord, RandSpatialCropd, Rand3DElasticd,
                              RandAffined, RandZoomd, Spacingd, Orientationd, Resized,
                              ThresholdIntensityd, RandShiftIntensityd, BorderPadd,
                              RandGaussianNoised, RandAdjustContrastd, NormalizeIntensityd,
                              RandFlipd)

# Scratch file used while resampling the prediction back to original spacing.
TEMP_SEG_PATH = 'temp_seg.nii'


def _build_val_transforms(with_label, resolution, patch_size):
    """Build the validation pre-processing pipeline.

    Args:
        with_label: include the 'label' key in every dict transform.
        resolution: target voxel spacing, or None to keep the native spacing.
        patch_size: minimum spatial size — smaller images are end-padded to it.

    Returns:
        A MONAI ``Compose`` of the deterministic validation transforms.
    """
    keys = ['image', 'label'] if with_label else ['image']
    transforms = [
        LoadImaged(keys=keys),
        AddChanneld(keys=keys),
        # ThresholdIntensityd(keys=['image'], threshold=-135, above=True, cval=-135),  # Threshold CT
        # ThresholdIntensityd(keys=['image'], threshold=215, above=False, cval=215),
        CropForegroundd(keys=keys, source_key='image'),  # crop away background
        NormalizeIntensityd(keys=['image']),             # intensity normalisation
        ScaleIntensityd(keys=['image']),
    ]
    if resolution is not None:
        # Nearest-neighbour for the label so class values stay discrete.
        mode = ('bilinear', 'nearest') if with_label else 'bilinear'
        transforms.append(Spacingd(keys=keys, pixdim=resolution, mode=mode))
    # Pad if the image is smaller than the patch.
    transforms.append(SpatialPadd(keys=keys, spatial_size=patch_size, method='end'))
    transforms.append(ToTensord(keys=keys))
    return Compose(transforms)


def _resample_to_original(result_array, resolution, original_resolution, crop_shape):
    """Resample a prediction array from training spacing back to the original.

    Writes the array to a temporary NIfTI so MONAI's ``Spacingd``/``Resized``
    can operate on it, then reloads and rounds it back to a binary mask.
    The temporary file is always removed, even on failure.
    """
    # MONAI loads arrays (x, y, z) while SimpleITK stores (z, y, x).
    result_array_np = np.transpose(result_array, (2, 1, 0))
    result_itk = sitk.GetImageFromArray(result_array_np)
    result_itk.SetSpacing(resolution)

    writer = sitk.ImageFileWriter()
    writer.SetFileName(TEMP_SEG_PATH)
    writer.Execute(result_itk)

    try:
        files = [{"image": TEMP_SEG_PATH}]
        files_transforms = Compose([
            LoadImaged(keys=['image']),
            AddChanneld(keys=['image']),
            Spacingd(keys=['image'], pixdim=original_resolution, mode='nearest'),
            Resized(keys=['image'], spatial_size=crop_shape, mode='nearest'),
        ])
        files_ds = monai.data.Dataset(data=files, transform=files_transforms)
        files_loader = DataLoader(files_ds, batch_size=1, num_workers=0)

        res = None
        for files_data in files_loader:
            res = files_data["image"].squeeze().data.numpy()
        # Interpolation can yield fractional values; snap back to a hard mask.
        return np.rint(res)
    finally:
        # Always clean up the scratch file (original leaked it on error).
        os.remove(TEMP_SEG_PATH)


def segment(image, label, result, weights, resolution, patch_size, network, gpu_ids):
    """Segment one image and write the result to ``result``.

    Args:
        image: path to the source NIfTI image.
        label: path to the ground-truth label (Dice is printed), or None.
        result: output path for the .nii segmentation.
        weights: path to the trained network weights (.pth).
        resolution: training voxel spacing [x, y, z], or None.
        patch_size: sliding-window ROI size, same as used in training.
        network: 'nnunet' or 'unetr'.
        gpu_ids: CUDA_VISIBLE_DEVICES string; '-1' forces CPU.

    Raises:
        ValueError: if ``network`` is not a known architecture.
    """
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    if label is not None:
        # Make image and label geometrically consistent before loading.
        uniform_img_dimensions_internal(image, label, True)
        files = [{"image": image, "label": label}]
    else:
        files = [{"image": image}]

    # original size, size after crop_background, cropped roi coordinates,
    # cropped resampled roi size, original resolution
    original_shape, crop_shape, coord1, coord2, resampled_size, original_resolution = \
        statistics_crop(image, resolution)

    val_transforms = _build_val_transforms(label is not None, resolution, patch_size)

    val_ds = monai.data.Dataset(data=files, transform=val_transforms)
    val_loader = DataLoader(val_ds, batch_size=1, num_workers=0,
                            collate_fn=list_data_collate, pin_memory=False)

    dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
    # NOTE(review): ``threshold_values`` is the pre-0.8 MONAI spelling
    # (renamed to ``threshold``) — kept as-is to match the pinned version.
    post_trans = Compose([EnsureType(), Activations(sigmoid=True), AsDiscrete(threshold_values=True)])

    if gpu_ids != '-1':
        # Expose the requested GPUs; fall back to CPU if CUDA is unavailable.
        os.environ['CUDA_VISIBLE_DEVICES'] = gpu_ids
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device("cpu")

    # Build the network (fail fast on an unknown name instead of NameError).
    if network == 'nnunet':
        net = build_net()
    elif network == 'unetr':
        net = build_UNETR()
    else:
        raise ValueError("unknown network %r (expected 'nnunet' or 'unetr')" % network)

    net = net.to(device)

    if gpu_ids == '-1':
        net.load_state_dict(new_state_dict_cpu(weights))
    else:
        net.load_state_dict(new_state_dict(weights))

    # Sliding-window size and batch size for windowed inference.
    roi_size = patch_size
    sw_batch_size = 4

    net.eval()
    val_outputs = None
    with torch.no_grad():
        for val_data in val_loader:
            val_images = val_data["image"].to(device)
            val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, net)
            val_outputs = [post_trans(i) for i in decollate_batch(val_outputs)]
            if label is not None:
                val_labels = val_data["label"].to(device)
                dice_metric(y_pred=val_outputs, y=val_labels)

        if label is not None:
            metric = dice_metric.aggregate().item()
            print("Evaluation Metric (Dice):", metric)

    result_array = val_outputs[0].squeeze().data.cpu().numpy()
    # Remove the pad if the image was smaller than the patch in some directions.
    result_array = result_array[0:resampled_size[0], 0:resampled_size[1], 0:resampled_size[2]]

    # Resample back to the original resolution.
    if resolution is not None:
        result_array = _resample_to_original(result_array, resolution,
                                             original_resolution, crop_shape)

    # Recover the cropped background before saving the image.
    empty_array = np.zeros(original_shape)
    empty_array[coord1[0]:coord2[0], coord1[1]:coord2[1], coord1[2]:coord2[2]] = result_array

    result_seg = from_numpy_to_itk(empty_array, image)

    # Save the final label map with the source image's metadata.
    writer = sitk.ImageFileWriter()
    writer.SetFileName(result)
    writer.Execute(result_seg)
    print("Saved Result at:", str(result))


if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument("--image", type=str, default='./Data_folder/T2/3.nii', help='source image')
    parser.add_argument("--label", type=str, default='./Data_folder/T2_labels/3.nii',
                        help='source label, if you want to compute dice. None for new case')
    parser.add_argument("--result", type=str, default='./Data_folder/test_0.nii',
                        help='path to the .nii result to save')
    parser.add_argument("--weights", type=str, default='./best_metric_model.pth',
                        help='network weights to load')
    # type/nargs added so a CLI-supplied resolution parses to 3 floats instead
    # of one string (which would break Spacingd); default is unchanged.
    parser.add_argument("--resolution", type=float, nargs=3, default=[0.7, 0.7, 3],
                        help='Resolution used in training phase')
    parser.add_argument("--patch_size", type=int, nargs=3, default=(256, 256, 16),
                        help="Input dimension for the generator, same of training")
    parser.add_argument('--network', default='unetr', help='nnunet, unetr')
    parser.add_argument('--gpu_ids', type=str, default='0',
                        help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')
    args = parser.parse_args()

    segment(args.image, args.label, args.result, args.weights, args.resolution,
            args.patch_size, args.network, args.gpu_ids)