a b/monai 0.5.0/predict_single_image.py
1
#!/usr/bin/env python3
2
# -*- coding: utf-8 -*-
3
4
from utils import *
5
import argparse
6
from networks import *
7
from monai.inferers import sliding_window_inference
8
from monai.metrics import DiceMetric
9
from monai.data import NiftiSaver, create_test_image_3d, list_data_collate
10
11
12
def _build_val_transforms(keys, resolution, patch_size):
    """Build the preprocessing pipeline applied to the image (and optional label).

    Args:
        keys: ``['image']`` or ``['image', 'label']``. Intensity steps touch
            ``'image'`` only; loading, cropping, resampling, padding and tensor
            conversion are applied to every key.
        resolution: target voxel spacing for ``Spacingd``, or ``None`` to skip
            resampling.
        patch_size: minimum spatial size; smaller volumes are padded at the end.

    Returns:
        A ``Compose`` of the transforms, in the same order as the training-time
        validation pipeline used by this script.
    """
    transforms = [
        LoadImaged(keys=keys),
        AddChanneld(keys=keys),
        # Clip CT intensities to the [-135, 215] window.
        ThresholdIntensityd(keys=['image'], threshold=-135, above=True, cval=-135),
        ThresholdIntensityd(keys=['image'], threshold=215, above=False, cval=215),
        CropForegroundd(keys=keys, source_key='image'),  # crop background using the image
        NormalizeIntensityd(keys=['image']),
        ScaleIntensityd(keys=['image']),
    ]
    if resolution is not None:
        # Image is interpolated bilinearly; a label (if present) uses nearest
        # neighbour so it stays a discrete mask.
        modes = ('bilinear', 'nearest') if 'label' in keys else 'bilinear'
        transforms.append(Spacingd(keys=keys, pixdim=resolution, mode=modes))
    transforms += [
        SpatialPadd(keys=keys, spatial_size=patch_size, method='end'),  # pad if smaller than the patch
        ToTensord(keys=keys),
    ]
    return Compose(transforms)


def _select_device(gpu_ids):
    """Return the torch device to run on.

    ``gpu_ids == '-1'`` forces CPU. Otherwise ``CUDA_VISIBLE_DEVICES`` is set to
    ``gpu_ids`` and CUDA is used if available, falling back to CPU.
    """
    if gpu_ids == '-1':
        return torch.device("cpu")
    os.environ['CUDA_VISIBLE_DEVICES'] = gpu_ids
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _resample_to_original(result_array, resolution, original_resolution, crop_shape):
    """Resample a prediction from the training spacing back to the input's geometry.

    The array is written to a temporary NIfTI file with spacing ``resolution``.
    It is then reloaded through MONAI, resampled to ``original_resolution`` and
    resized to ``crop_shape``, both with nearest-neighbour interpolation.
    Finally it is rounded to the nearest integer.

    A private temporary directory replaces the fixed ``./temp_seg.nii`` path.
    This way concurrent runs do not clobber each other, and the file is
    removed even if loading fails.
    """
    import tempfile

    # SimpleITK arrays are indexed (z, y, x), so reverse the axes first.
    result_image = sitk.GetImageFromArray(np.transpose(result_array, (2, 1, 0)))
    result_image.SetSpacing([float(r) for r in resolution])

    with tempfile.TemporaryDirectory() as tmp_dir:
        temp_path = os.path.join(tmp_dir, 'temp_seg.nii')
        writer = sitk.ImageFileWriter()
        writer.SetFileName(temp_path)
        writer.Execute(result_image)

        files_transforms = Compose([
            LoadImaged(keys=['image']),
            AddChanneld(keys=['image']),
            Spacingd(keys=['image'], pixdim=original_resolution, mode='nearest'),
            Resized(keys=['image'], spatial_size=crop_shape, mode='nearest'),
        ])
        files_ds = Dataset(data=[{"image": temp_path}], transform=files_transforms)
        files_loader = DataLoader(files_ds, batch_size=1, num_workers=0)

        # Exactly one file was queued, so take the single batch.
        files_data = next(iter(files_loader))
        res = files_data["image"].squeeze().data.numpy()

    return np.rint(res)


def segment(image, label, result, weights, resolution, patch_size, gpu_ids):
    """Segment ``image`` with the trained network and save the mask to ``result``.

    Args:
        image: path of the input volume.
        label: path of a reference label, or ``None``. When given, its
            dimensions are first aligned to the image, and the mean Dice of the
            prediction against it is printed.
        result: output path for the predicted segmentation.
        weights: network weights file passed to ``new_state_dict`` /
            ``new_state_dict_cpu``.
        resolution: voxel spacing used during training, or ``None`` to run at
            the native spacing.
        patch_size: sliding-window ROI size; volumes smaller than it are padded.
        gpu_ids: comma-separated GPU ids as a string, ``'-1'`` for CPU.
    """
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    if label is not None:
        uniform_img_dimensions_internal(image, label, True)
        files = [{"image": image, "label": label}]
        keys = ['image', 'label']
    else:
        files = [{"image": image}]
        keys = ['image']

    # original size, size after crop_background, cropped roi coordinates, cropped resampled roi size
    original_shape, crop_shape, coord1, coord2, resampled_size, original_resolution = statistics_crop(image, resolution)

    val_ds = monai.data.Dataset(data=files, transform=_build_val_transforms(keys, resolution, patch_size))
    val_loader = DataLoader(val_ds, batch_size=1, num_workers=0, collate_fn=list_data_collate, pin_memory=False)

    dice_metric = DiceMetric(include_background=True, reduction="mean")
    post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold_values=True)])

    device = _select_device(gpu_ids)
    net = build_net().to(device)
    # Pick the loader by the device actually in use, so a CUDA request on a
    # machine without CUDA also takes the CPU path.
    # NOTE(review): assumes new_state_dict_cpu is the CPU-safe variant — confirm in utils.
    if device.type == 'cpu':
        net.load_state_dict(new_state_dict_cpu(weights))
    else:
        net.load_state_dict(new_state_dict(weights))

    # sliding window size and batch size for window inference
    roi_size = patch_size
    sw_batch_size = 4

    net.eval()
    with torch.no_grad():
        metric_sum = 0.0
        metric_count = 0
        for val_data in val_loader:
            # .to(device) rather than .cuda(), so CPU inference works.
            val_images = val_data["image"].to(device)
            val_outputs = post_trans(sliding_window_inference(val_images, roi_size, sw_batch_size, net))
            if label is not None:
                val_labels = val_data["label"].to(device)
                # With reduction="mean", DiceMetric returns a scalar value plus
                # the number of non-NaN entries it averaged. Weight by that
                # count (len() of the 0-d value would raise).
                value, not_nans = dice_metric(y_pred=val_outputs, y=val_labels)
                metric_count += not_nans.item()
                metric_sum += value.item() * not_nans.item()

        if label is not None:
            metric = metric_sum / metric_count if metric_count > 0 else float('nan')
            print("Evaluation Metric (Dice):", metric)

    result_array = val_outputs.squeeze().data.cpu().numpy()
    # Remove the pad if the image was smaller than the patch in some directions
    result_array = result_array[0:resampled_size[0], 0:resampled_size[1], 0:resampled_size[2]]

    if resolution is not None:
        result_array = _resample_to_original(result_array, resolution, original_resolution, crop_shape)

    # recover the cropped background before saving the image
    empty_array = np.zeros(original_shape)
    empty_array[coord1[0]:coord2[0], coord1[1]:coord2[1], coord1[2]:coord2[2]] = result_array

    result_seg = from_numpy_to_itk(empty_array, image)

    writer = sitk.ImageFileWriter()
    writer.SetFileName(result)
    writer.Execute(result_seg)
    print("Saved Result at:", str(result))
193
194
if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument("--image", type=str, default='./Data_folder/CT/0.nii', help='source image')
    parser.add_argument("--label", type=str, default=None, help='source label, if you want to compute dice. None for new case')
    parser.add_argument("--result", type=str, default='./Data_folder/test_0.nii', help='path to the .nii result to save')
    parser.add_argument("--weights", type=str, default='./best_metric_model.pth', help='network weights to load')
    # type/nargs so a user-supplied value parses into three floats, like the default
    # (previously any value on the command line arrived as a single string).
    parser.add_argument("--resolution", type=float, nargs=3, default=[2.25, 2.25, 3], help='Resolution used in training phase')
    parser.add_argument("--patch_size", type=int, nargs=3, default=(160, 160, 32), help="Input dimension for the generator, same of training")
    parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0  0,1,2, 0,2. use -1 for CPU')
    args = parser.parse_args()

    segment(args.image, args.label, args.result, args.weights, args.resolution, args.patch_size, args.gpu_ids)
207
208
209
210
211
212
213
214
215
216
217
218
219