#!/usr/bin/env python2
# -*- coding: utf-8 -*-
from utils import *
import argparse
from networks import *
from monai.inferers import sliding_window_inference
from monai.metrics import DiceMetric
from monai.data import NiftiSaver, create_test_image_3d, list_data_collate
def _build_val_transforms(with_label, resolution, patch_size):
    """Build the preprocessing pipeline applied to the case before inference.

    Args:
        with_label: when True the 'label' entry is loaded and carried through
            every spatial transform together with 'image'.
        resolution: target voxel spacing passed to ``Spacingd``; ``None``
            skips the resampling step entirely.
        patch_size: minimum spatial size; the volume is padded at the end of
            each axis up to this size.

    Returns:
        A ``Compose`` of the dictionary transforms.
    """
    keys = ['image', 'label'] if with_label else ['image']
    steps = [
        LoadImaged(keys=keys),
        AddChanneld(keys=keys),
        # Clamp image intensities to [-135, 215] (CT threshold).
        ThresholdIntensityd(keys=['image'], threshold=-135, above=True, cval=-135),
        ThresholdIntensityd(keys=['image'], threshold=215, above=False, cval=215),
        # Crop away the background, using the image as the foreground source.
        CropForegroundd(keys=keys, source_key='image'),
        NormalizeIntensityd(keys=['image']),
        ScaleIntensityd(keys=['image']),
    ]
    if resolution is not None:
        # Image is interpolated bilinearly; the label (if any) uses nearest
        # so that its values are not blended.
        mode = ('bilinear', 'nearest') if with_label else 'bilinear'
        steps.append(Spacingd(keys=keys, pixdim=resolution, mode=mode))
    # Pad if the image is smaller than the patch.
    steps.append(SpatialPadd(keys=keys, spatial_size=patch_size, method='end'))
    steps.append(ToTensord(keys=keys))
    return Compose(steps)


def _resample_to_original(result_array, resolution, original_resolution, crop_shape):
    """Resample a predicted mask back to the original spacing and crop shape.

    The mask is written to a temporary NIfTI file ('temp_seg.nii' in the
    current working directory) tagged with ``resolution``, reloaded through
    ``Spacingd``/``Resized`` with nearest interpolation, and rounded. The
    temporary file is removed even if one of the steps raises.

    Returns:
        The resampled mask as a NumPy array rounded to integers.
    """
    temp_path = 'temp_seg.nii'
    try:
        # Axis order is reversed before handing the array to SimpleITK
        # (presumably x,y,z -> z,y,x as GetImageFromArray expects; confirm).
        result_array_np = np.transpose(result_array, (2, 1, 0))
        temp_image = sitk.GetImageFromArray(result_array_np)
        temp_image.SetSpacing(resolution)

        writer = sitk.ImageFileWriter()
        writer.SetFileName(temp_path)
        writer.Execute(temp_image)

        files = [{"image": temp_path}]
        files_transforms = Compose([
            LoadImaged(keys=['image']),
            AddChanneld(keys=['image']),
            Spacingd(keys=['image'], pixdim=original_resolution, mode=('nearest')),
            Resized(keys=['image'], spatial_size=crop_shape, mode=('nearest')),
        ])
        files_ds = Dataset(data=files, transform=files_transforms)
        files_loader = DataLoader(files_ds, batch_size=1, num_workers=0)

        for files_data in files_loader:
            res = files_data["image"].squeeze().data.numpy()
    finally:
        # Never leave the scratch file behind, even on failure.
        if os.path.exists(temp_path):
            os.remove(temp_path)
    return np.rint(res)


def segment(image, label, result, weights, resolution, patch_size, gpu_ids):
    """Segment a single image with the trained network and save the mask.

    Args:
        image: path of the image to segment.
        label: path of the reference label, or ``None``. When given, the
            Dice score of the prediction is printed.
        result: path of the output file written with SimpleITK.
        weights: network weights, forwarded to ``new_state_dict`` /
            ``new_state_dict_cpu``.
        resolution: voxel spacing to resample to before inference, or
            ``None`` to keep the native spacing.
        patch_size: sliding-window ROI size; also the minimum padded size.
        gpu_ids: string assigned to ``CUDA_VISIBLE_DEVICES``; ``'-1'``
            forces CPU execution.
    """
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    has_label = label is not None
    if has_label:
        uniform_img_dimensions_internal(image, label, True)
        files = [{"image": image, "label": label}]
    else:
        files = [{"image": image}]

    # Original size, size after background crop, cropped ROI coordinates,
    # cropped+resampled ROI size and the native resolution.
    original_shape, crop_shape, coord1, coord2, resampled_size, original_resolution = \
        statistics_crop(image, resolution)

    val_transforms = _build_val_transforms(has_label, resolution, patch_size)
    val_ds = monai.data.Dataset(data=files, transform=val_transforms)
    val_loader = DataLoader(val_ds, batch_size=1, num_workers=0,
                            collate_fn=list_data_collate, pin_memory=False)
    dice_metric = DiceMetric(include_background=True, reduction="mean")
    post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold_values=True)])

    if gpu_ids != '-1':
        # Expose only the requested GPUs; fall back to CPU if CUDA is missing.
        os.environ['CUDA_VISIBLE_DEVICES'] = gpu_ids
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device("cpu")

    net = build_net()
    net = net.to(device)

    # NOTE(review): the loader is chosen from gpu_ids, not from the resolved
    # device; if CUDA was requested but is unavailable the non-CPU loader is
    # still used -- confirm new_state_dict copes with that.
    if gpu_ids == '-1':
        net.load_state_dict(new_state_dict_cpu(weights))
    else:
        net.load_state_dict(new_state_dict(weights))

    # Sliding-window size and number of windows evaluated per batch.
    roi_size = patch_size
    sw_batch_size = 4

    net.eval()
    with torch.no_grad():
        metric_sum = 0.0
        metric_count = 0
        for val_data in val_loader:
            # Move inputs to the same device as the network; a hard-coded
            # .cuda() would crash on CPU-only runs.
            val_images = val_data["image"].to(device)
            val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, net)
            val_outputs = post_trans(val_outputs)

            if has_label:
                val_labels = val_data["label"].to(device)
                value, _ = dice_metric(y_pred=val_outputs, y=val_labels)
                metric_count += len(value)
                metric_sum += value.item() * len(value)

        if has_label:
            metric = metric_sum / metric_count
            print("Evaluation Metric (Dice):", metric)

        result_array = val_outputs.squeeze().data.cpu().numpy()

    # Remove the pad if the image was smaller than the patch in some directions.
    result_array = result_array[0:resampled_size[0], 0:resampled_size[1], 0:resampled_size[2]]

    # Resample back to the original resolution.
    if resolution is not None:
        result_array = _resample_to_original(result_array, resolution,
                                             original_resolution, crop_shape)

    # Recover the cropped background before saving the image.
    empty_array = np.zeros(original_shape)
    empty_array[coord1[0]:coord2[0], coord1[1]:coord2[1], coord1[2]:coord2[2]] = result_array

    result_seg = from_numpy_to_itk(empty_array, image)

    # Save the label.
    writer = sitk.ImageFileWriter()
    writer.SetFileName(result)
    writer.Execute(result_seg)
    print("Saved Result at:", str(result))
def build_arg_parser():
    """Return the command-line parser for single-image segmentation.

    ``--resolution`` takes three floats and ``--patch_size`` three ints, so
    values given on the command line reach ``segment`` as numeric sequences
    rather than raw strings.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--image", type=str, default='./Data_folder/CT/0.nii', help='source image')
    parser.add_argument("--label", type=str, default=None,
                        help='source label, if you want to compute dice. None for new case')
    parser.add_argument("--result", type=str, default='./Data_folder/test_0.nii',
                        help='path to the .nii result to save')
    parser.add_argument("--weights", type=str, default='./best_metric_model.pth',
                        help='network weights to load')
    # type/nargs are required: without them a user-supplied value arrives as a
    # single string and "--resolution 2.25 2.25 3" is rejected outright.
    parser.add_argument("--resolution", type=float, nargs=3, default=[2.25, 2.25, 3],
                        help='Resolution used in training phase')
    parser.add_argument("--patch_size", type=int, nargs=3, default=(160, 160, 32),
                        help="Input dimension for the generator, same of training")
    parser.add_argument('--gpu_ids', type=str, default='0',
                        help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')
    return parser


if __name__ == "__main__":
    args = build_arg_parser().parse_args()
    segment(args.image, args.label, args.result, args.weights,
            args.resolution, args.patch_size, args.gpu_ids)