monai 0.5.0/predict_single_image.py

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
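"""Run sliding-window inference for a single NIfTI volume with a trained network.

The script loads one CT image (and optionally its label), applies the same
preprocessing used at training time, predicts with MONAI's sliding_window_inference,
resamples the prediction back to the original resolution and shape, and writes the
segmentation to disk. If a label is provided, the Dice score is also printed.
"""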
|
|
from utils import *
import argparse
from networks import *
from monai.inferers import sliding_window_inference
from monai.metrics import DiceMetric
from monai.data import NiftiSaver, create_test_image_3d, list_data_collate


def segment(image, label, result, weights, resolution, patch_size, gpu_ids):

    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    if label is not None:
        uniform_img_dimensions_internal(image, label, True)
        files = [{"image": image, "label": label}]
    else:
        files = [{"image": image}]

    # original size, size after crop_background, cropped roi coordinates, cropped resampled roi size
    original_shape, crop_shape, coord1, coord2, resampled_size, original_resolution = statistics_crop(image, resolution)

    # -------------------------------
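    # Note: the four Compose pipelines below build the same preprocessing chain;
    # they differ only in whether a 'label' key is carried along and whether
    # Spacingd resamples the data to the training resolution.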
|
|
    if label is not None:
        if resolution is not None:

            val_transforms = Compose([
                LoadImaged(keys=['image', 'label']),
                AddChanneld(keys=['image', 'label']),
                ThresholdIntensityd(keys=['image'], threshold=-135, above=True, cval=-135),  # Threshold CT
                ThresholdIntensityd(keys=['image'], threshold=215, above=False, cval=215),
                CropForegroundd(keys=['image', 'label'], source_key='image'),  # crop CropForeground

                NormalizeIntensityd(keys=['image']),  # intensity
                ScaleIntensityd(keys=['image']),
                Spacingd(keys=['image', 'label'], pixdim=resolution, mode=('bilinear', 'nearest')),  # resolution

                SpatialPadd(keys=['image', 'label'], spatial_size=patch_size, method='end'),
                ToTensord(keys=['image', 'label'])])
        else:

            val_transforms = Compose([
                LoadImaged(keys=['image', 'label']),
                AddChanneld(keys=['image', 'label']),
                ThresholdIntensityd(keys=['image'], threshold=-135, above=True, cval=-135),  # Threshold CT
                ThresholdIntensityd(keys=['image'], threshold=215, above=False, cval=215),
                CropForegroundd(keys=['image', 'label'], source_key='image'),  # crop CropForeground

                NormalizeIntensityd(keys=['image']),  # intensity
                ScaleIntensityd(keys=['image']),

                SpatialPadd(keys=['image', 'label'], spatial_size=patch_size, method='end'),  # pad if the image is smaller than patch
                ToTensord(keys=['image', 'label'])])

    else:
        if resolution is not None:

            val_transforms = Compose([
                LoadImaged(keys=['image']),
                AddChanneld(keys=['image']),
                ThresholdIntensityd(keys=['image'], threshold=-135, above=True, cval=-135),  # Threshold CT
                ThresholdIntensityd(keys=['image'], threshold=215, above=False, cval=215),
                CropForegroundd(keys=['image'], source_key='image'),  # crop CropForeground

                NormalizeIntensityd(keys=['image']),  # intensity
                ScaleIntensityd(keys=['image']),
                Spacingd(keys=['image'], pixdim=resolution, mode=('bilinear')),  # resolution

                SpatialPadd(keys=['image'], spatial_size=patch_size, method='end'),  # pad if the image is smaller than patch
                ToTensord(keys=['image'])])
        else:

            val_transforms = Compose([
                LoadImaged(keys=['image']),
                AddChanneld(keys=['image']),
                ThresholdIntensityd(keys=['image'], threshold=-135, above=True, cval=-135),  # Threshold CT
                ThresholdIntensityd(keys=['image'], threshold=215, above=False, cval=215),
                CropForegroundd(keys=['image'], source_key='image'),  # crop CropForeground

                NormalizeIntensityd(keys=['image']),  # intensity
                ScaleIntensityd(keys=['image']),

                SpatialPadd(keys=['image'], spatial_size=patch_size, method='end'),  # pad if the image is smaller than patch
                ToTensord(keys=['image'])])

    val_ds = monai.data.Dataset(data=files, transform=val_transforms)
    val_loader = DataLoader(val_ds, batch_size=1, num_workers=0, collate_fn=list_data_collate, pin_memory=False)

    dice_metric = DiceMetric(include_background=True, reduction="mean")
    post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold_values=True)])

    if gpu_ids != '-1':

        # try to use all the available GPUs
        os.environ['CUDA_VISIBLE_DEVICES'] = gpu_ids
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    else:
        device = torch.device("cpu")

    net = build_net()
    net = net.to(device)

    if gpu_ids == '-1':

        net.load_state_dict(new_state_dict_cpu(weights))

    else:

        net.load_state_dict(new_state_dict(weights))

    # define sliding window size and batch size for windows inference
    roi_size = patch_size
    sw_batch_size = 4

    net.eval()
    with torch.no_grad():

        if label is None:
            for val_data in val_loader:
                val_images = val_data["image"].to(device)
                val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, net)
                val_outputs = post_trans(val_outputs)
                # val_outputs = (val_outputs.sigmoid() >= 0.5).float()

        else:
            metric_sum = 0.0
            metric_count = 0
            for val_data in val_loader:
                val_images, val_labels = val_data["image"].to(device), val_data["label"].to(device)
                val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, net)
                val_outputs = post_trans(val_outputs)
                value, _ = dice_metric(y_pred=val_outputs, y=val_labels)
                metric_count += len(value)
                metric_sum += value.item() * len(value)
                # val_outputs = (val_outputs.sigmoid() >= 0.5).float()

            metric = metric_sum / metric_count
            print("Evaluation Metric (Dice):", metric)

        result_array = val_outputs.squeeze().data.cpu().numpy()
        # Remove the pad if the image was smaller than the patch in some directions
        result_array = result_array[0:resampled_size[0], 0:resampled_size[1], 0:resampled_size[2]]

        # resample back to the original resolution
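        # (the prediction is saved as a temporary NIfTI at the inference resolution,
        # reloaded, resampled to the original voxel spacing and resized to the shape
        # of the cropped region, before the cropped background is restored below)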
|
|
        if resolution is not None:

            result_array_np = np.transpose(result_array, (2, 1, 0))
            result_array_temp = sitk.GetImageFromArray(result_array_np)
            result_array_temp.SetSpacing(resolution)

            # save temporary label
            writer = sitk.ImageFileWriter()
            writer.SetFileName('temp_seg.nii')
            writer.Execute(result_array_temp)

            files = [{"image": 'temp_seg.nii'}]

            files_transforms = Compose([
                LoadImaged(keys=['image']),
                AddChanneld(keys=['image']),
                Spacingd(keys=['image'], pixdim=original_resolution, mode=('nearest')),
                Resized(keys=['image'], spatial_size=crop_shape, mode=('nearest')),
            ])

            files_ds = Dataset(data=files, transform=files_transforms)
            files_loader = DataLoader(files_ds, batch_size=1, num_workers=0)

            for files_data in files_loader:
                files_images = files_data["image"]

                res = files_images.squeeze().data.numpy()

            result_array = np.rint(res)

            os.remove('./temp_seg.nii')

        # recover the cropped background before saving the image
        empty_array = np.zeros(original_shape)
        empty_array[coord1[0]:coord2[0], coord1[1]:coord2[1], coord1[2]:coord2[2]] = result_array

        result_seg = from_numpy_to_itk(empty_array, image)

        # save label
        writer = sitk.ImageFileWriter()
        writer.SetFileName(result)
        writer.Execute(result_seg)
        print("Saved Result at:", str(result))


if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument("--image", type=str, default='./Data_folder/CT/0.nii', help='source image')
    parser.add_argument("--label", type=str, default=None, help='source label, if you want to compute dice. None for new case')
    parser.add_argument("--result", type=str, default='./Data_folder/test_0.nii', help='path to the .nii result to save')
    parser.add_argument("--weights", type=str, default='./best_metric_model.pth', help='network weights to load')
    parser.add_argument("--resolution", type=float, nargs=3, default=[2.25, 2.25, 3], help='Resolution used in training phase')
    parser.add_argument("--patch_size", type=int, nargs=3, default=(160, 160, 32), help="Input dimension for the generator, same as training")
    parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids, e.g. 0 or 0,1,2. use -1 for CPU')
    args = parser.parse_args()

    segment(args.image, args.label, args.result, args.weights, args.resolution, args.patch_size, args.gpu_ids)
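
# Example invocation (paths below are the argparse defaults; the label is optional
# and only needed if you want the Dice score to be printed):
#   python predict_single_image.py --image ./Data_folder/CT/0.nii --result ./Data_folder/test_0.nii \
#       --weights ./best_metric_model.pth --patch_size 160 160 32 --gpu_ids 0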
|
|
|