|
a |
|
b/fetal_net/prediction.py |
|
|
1 |
import itertools |
|
|
2 |
import os |
|
|
3 |
|
|
|
4 |
import nibabel as nib |
|
|
5 |
import numpy as np |
|
|
6 |
import tables |
|
|
7 |
from keras import Model |
|
|
8 |
from scipy import ndimage |
|
|
9 |
from tqdm import tqdm |
|
|
10 |
|
|
|
11 |
from fetal.utils import get_last_model_path |
|
|
12 |
from fetal_net.utils.threaded_generator import ThreadedGenerator |
|
|
13 |
from fetal_net.utils.utils import get_image, list_load, pickle_load |
|
|
14 |
from .augment import permute_data, generate_permutation_keys, reverse_permute_data, contrast_augment |
|
|
15 |
from .training import load_old_model |
|
|
16 |
from .utils.patches import get_patch_from_3d_data |
|
|
17 |
|
|
|
18 |
|
|
|
19 |
def flip_it(data_, axes): |
|
|
20 |
for ax in axes: |
|
|
21 |
data_ = np.flip(data_, ax) |
|
|
22 |
return data_ |
|
|
23 |
|
|
|
24 |
|
|
|
25 |
def predict_augment(data, model, overlap_factor, patch_shape, num_augments=32): |
|
|
26 |
data_max = data.max() |
|
|
27 |
data_min = data.min() |
|
|
28 |
data = data.squeeze() |
|
|
29 |
|
|
|
30 |
order = 2 |
|
|
31 |
predictions = [] |
|
|
32 |
for _ in range(num_augments): |
|
|
33 |
# pixel-wise augmentations |
|
|
34 |
val_range = data_max - data_min |
|
|
35 |
contrast_min_val = data_min + 0.10 * np.random.uniform(-1, 1) * val_range |
|
|
36 |
contrast_max_val = data_max + 0.10 * np.random.uniform(-1, 1) * val_range |
|
|
37 |
curr_data = contrast_augment(data, contrast_min_val, contrast_max_val) |
|
|
38 |
|
|
|
39 |
# spatial augmentations |
|
|
40 |
rotate_factor = np.random.uniform(-30, 30) |
|
|
41 |
to_flip = np.arange(0, 3)[np.random.choice([True, False], size=3)] |
|
|
42 |
to_transpose = np.random.choice([True, False]) |
|
|
43 |
|
|
|
44 |
curr_data = flip_it(curr_data, to_flip) |
|
|
45 |
|
|
|
46 |
if to_transpose: |
|
|
47 |
curr_data = curr_data.transpose([1, 0, 2]) |
|
|
48 |
|
|
|
49 |
curr_data = ndimage.rotate(curr_data, rotate_factor, order=order, reshape=False) |
|
|
50 |
|
|
|
51 |
curr_prediction = patch_wise_prediction(model=model, data=curr_data[np.newaxis, ...], overlap_factor=overlap_factor, patch_shape=patch_shape).squeeze() |
|
|
52 |
|
|
|
53 |
curr_prediction = ndimage.rotate(curr_prediction, -rotate_factor) |
|
|
54 |
|
|
|
55 |
if to_transpose: |
|
|
56 |
curr_prediction = curr_prediction.transpose([1, 0, 2]) |
|
|
57 |
|
|
|
58 |
curr_prediction = flip_it(curr_prediction, to_flip) |
|
|
59 |
predictions += [curr_prediction.squeeze()] |
|
|
60 |
|
|
|
61 |
res = np.stack(predictions, axis=0) |
|
|
62 |
return res |
|
|
63 |
|
|
|
64 |
|
|
|
65 |
def predict_flips(data, model, overlap_factor, config): |
|
|
66 |
def powerset(iterable): |
|
|
67 |
"powerset([1,2,3]) --> () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)" |
|
|
68 |
s = list(iterable) |
|
|
69 |
return itertools.chain.from_iterable(itertools.combinations(s, r) for r in range(0, len(s) + 1)) |
|
|
70 |
|
|
|
71 |
def predict_it(data_, axes=()): |
|
|
72 |
data_ = flip_it(data_, axes) |
|
|
73 |
curr_pred = \ |
|
|
74 |
patch_wise_prediction(model=model, |
|
|
75 |
data=np.expand_dims(data_.squeeze(), 0), |
|
|
76 |
overlap_factor=overlap_factor, |
|
|
77 |
patch_shape=config["patch_shape"] + [config["patch_depth"]]).squeeze() |
|
|
78 |
curr_pred = flip_it(curr_pred, axes) |
|
|
79 |
return curr_pred |
|
|
80 |
|
|
|
81 |
predictions = [] |
|
|
82 |
for axes in powerset([0, 1, 2]): |
|
|
83 |
predictions += [predict_it(data, axes).squeeze()] |
|
|
84 |
|
|
|
85 |
return predictions |
|
|
86 |
|
|
|
87 |
|
|
|
88 |
def get_set_of_patch_indices_full(start, stop, step): |
|
|
89 |
indices = [] |
|
|
90 |
for start_i, stop_i, step_i in zip(start, stop, step): |
|
|
91 |
indices_i = list(range(start_i, stop_i + 1, step_i)) |
|
|
92 |
if stop_i % step_i > 0: |
|
|
93 |
indices_i += [stop_i] |
|
|
94 |
indices += [indices_i] |
|
|
95 |
return np.array(list(itertools.product(*indices))) |
|
|
96 |
|
|
|
97 |
|
|
|
98 |
def batch_iterator(indices, batch_size, data_0, patch_shape, truth_0, prev_truth_index, truth_patch_shape): |
|
|
99 |
i = 0 |
|
|
100 |
while i < len(indices): |
|
|
101 |
batch = [] |
|
|
102 |
curr_indices = [] |
|
|
103 |
while len(batch) < batch_size and i < len(indices): |
|
|
104 |
curr_index = indices[i] |
|
|
105 |
patch = get_patch_from_3d_data(data_0, patch_shape=patch_shape, patch_index=curr_index) |
|
|
106 |
if truth_0 is not None: |
|
|
107 |
truth_index = list(curr_index[:2]) + [curr_index[2] + prev_truth_index] |
|
|
108 |
truth_patch = get_patch_from_3d_data(truth_0, patch_shape=truth_patch_shape, |
|
|
109 |
patch_index=truth_index) |
|
|
110 |
patch = np.concatenate([patch, truth_patch], axis=-1) |
|
|
111 |
batch.append(patch) |
|
|
112 |
curr_indices.append(curr_index) |
|
|
113 |
i += 1 |
|
|
114 |
yield [batch, curr_indices] |
|
|
115 |
# print('Finished! {}-{}'.format(i, len(indices))) |
|
|
116 |
|
|
|
117 |
|
|
|
118 |
def patch_wise_prediction(model: Model, data, patch_shape, overlap_factor=0, batch_size=5, |
|
|
119 |
permute=False, truth_data=None, prev_truth_index=None, prev_truth_size=None): |
|
|
120 |
""" |
|
|
121 |
:param truth_data: |
|
|
122 |
:param permute: |
|
|
123 |
:param overlap_factor: |
|
|
124 |
:param batch_size: |
|
|
125 |
:param model: |
|
|
126 |
:param data: |
|
|
127 |
:return: |
|
|
128 |
""" |
|
|
129 |
is3d = np.sum(np.array(model.output_shape[1:]) > 1) > 2 |
|
|
130 |
|
|
|
131 |
if is3d: |
|
|
132 |
prediction_shape = model.output_shape[-3:] |
|
|
133 |
else: |
|
|
134 |
prediction_shape = model.output_shape[-3:-1] + (1,) # patch_shape[-3:-1] #[64,64]# |
|
|
135 |
min_overlap = np.subtract(patch_shape, prediction_shape) |
|
|
136 |
max_overlap = np.subtract(patch_shape, (1, 1, 1)) |
|
|
137 |
overlap = min_overlap + (overlap_factor * (max_overlap - min_overlap)).astype(np.int) |
|
|
138 |
data_0 = np.pad(data[0], |
|
|
139 |
[(np.ceil(_ / 2).astype(int), np.floor(_ / 2).astype(int)) for _ in |
|
|
140 |
np.subtract(patch_shape, prediction_shape)], |
|
|
141 |
mode='constant', constant_values=np.percentile(data[0], q=1)) |
|
|
142 |
pad_for_fit = [(np.ceil(_ / 2).astype(int), np.floor(_ / 2).astype(int)) for _ in |
|
|
143 |
np.maximum(np.subtract(patch_shape, data_0.shape), 0)] |
|
|
144 |
data_0 = np.pad(data_0, |
|
|
145 |
[_ for _ in pad_for_fit], |
|
|
146 |
'constant', constant_values=np.percentile(data_0, q=1)) |
|
|
147 |
|
|
|
148 |
if truth_data is not None: |
|
|
149 |
truth_0 = np.pad(truth_data[0], |
|
|
150 |
[(np.ceil(_ / 2).astype(int), np.floor(_ / 2).astype(int)) for _ in |
|
|
151 |
np.subtract(patch_shape, prediction_shape)], |
|
|
152 |
mode='constant', constant_values=0) |
|
|
153 |
truth_0 = np.pad(truth_0, [_ for _ in pad_for_fit], |
|
|
154 |
'constant', constant_values=0) |
|
|
155 |
|
|
|
156 |
truth_patch_shape = list(patch_shape[:2]) + [prev_truth_size] |
|
|
157 |
else: |
|
|
158 |
truth_0 = None |
|
|
159 |
truth_patch_shape = None |
|
|
160 |
|
|
|
161 |
indices = get_set_of_patch_indices_full((0, 0, 0), |
|
|
162 |
np.subtract(data_0.shape, patch_shape), |
|
|
163 |
np.subtract(patch_shape, overlap)) |
|
|
164 |
|
|
|
165 |
b_iter = batch_iterator(indices, batch_size, data_0, patch_shape, |
|
|
166 |
truth_0, prev_truth_index, truth_patch_shape) |
|
|
167 |
tb_iter = iter(ThreadedGenerator(b_iter, queue_maxsize=50)) |
|
|
168 |
|
|
|
169 |
data_shape = list(data.shape[-3:] + np.sum(pad_for_fit, -1)) |
|
|
170 |
if is3d: |
|
|
171 |
data_shape += [model.output_shape[1]] |
|
|
172 |
else: |
|
|
173 |
data_shape += [model.output_shape[-1]] |
|
|
174 |
predicted_output = np.zeros(data_shape) |
|
|
175 |
predicted_count = np.zeros(data_shape, dtype=np.int16) |
|
|
176 |
with tqdm(total=len(indices)) as pbar: |
|
|
177 |
for [curr_batch, batch_indices] in tb_iter: |
|
|
178 |
curr_batch = np.asarray(curr_batch) |
|
|
179 |
if is3d: |
|
|
180 |
curr_batch = np.expand_dims(curr_batch, 1) |
|
|
181 |
prediction = predict(model, curr_batch, permute=permute) |
|
|
182 |
|
|
|
183 |
if is3d: |
|
|
184 |
prediction = prediction.transpose([0, 2, 3, 4, 1]) |
|
|
185 |
else: |
|
|
186 |
prediction = np.expand_dims(prediction, -2) |
|
|
187 |
|
|
|
188 |
for predicted_patch, predicted_index in zip(prediction, batch_indices): |
|
|
189 |
# predictions.append(predicted_patch) |
|
|
190 |
x, y, z = predicted_index |
|
|
191 |
x_len, y_len, z_len = predicted_patch.shape[:-1] |
|
|
192 |
predicted_output[x:x + x_len, y:y + y_len, z:z + z_len, :] += predicted_patch |
|
|
193 |
predicted_count[x:x + x_len, y:y + y_len, z:z + z_len] += 1 |
|
|
194 |
pbar.update(batch_size) |
|
|
195 |
|
|
|
196 |
assert np.all(predicted_count > 0), 'Found zeros in count' |
|
|
197 |
|
|
|
198 |
if np.sum(pad_for_fit) > 0: |
|
|
199 |
# must be a better way :\ |
|
|
200 |
x_pad, y_pad, z_pad = [[None if p2[0] == 0 else p2[0], |
|
|
201 |
None if p2[1] == 0 else -p2[1]] for p2 in pad_for_fit] |
|
|
202 |
predicted_count = predicted_count[x_pad[0]: x_pad[1], |
|
|
203 |
y_pad[0]: y_pad[1], |
|
|
204 |
z_pad[0]: z_pad[1]] |
|
|
205 |
predicted_output = predicted_output[x_pad[0]: x_pad[1], |
|
|
206 |
y_pad[0]: y_pad[1], |
|
|
207 |
z_pad[0]: z_pad[1]] |
|
|
208 |
|
|
|
209 |
assert np.array_equal(predicted_count.shape[:-1], data[0].shape), 'prediction shape wrong' |
|
|
210 |
return predicted_output / predicted_count |
|
|
211 |
# return reconstruct_from_patches(predictions, patch_indices=indices, data_shape=data_shape) |
|
|
212 |
|
|
|
213 |
|
|
|
214 |
def get_prediction_labels(prediction, threshold=0.5, labels=None): |
|
|
215 |
n_samples = prediction.shape[0] |
|
|
216 |
label_arrays = [] |
|
|
217 |
for sample_number in range(n_samples): |
|
|
218 |
label_data = np.argmax(prediction[sample_number], axis=1) |
|
|
219 |
label_data[np.max(prediction[sample_number], axis=0) < threshold] = 0 |
|
|
220 |
if labels: |
|
|
221 |
for value in np.unique(label_data).tolist()[1:]: |
|
|
222 |
label_data[label_data == value] = labels[value - 1] |
|
|
223 |
label_arrays.append(np.array(label_data, dtype=np.uint8)) |
|
|
224 |
return label_arrays |
|
|
225 |
|
|
|
226 |
|
|
|
227 |
def get_test_indices(testing_file): |
|
|
228 |
return pickle_load(testing_file) |
|
|
229 |
|
|
|
230 |
|
|
|
231 |
def predict_from_data_file(model, open_data_file, index): |
|
|
232 |
return model.predict(open_data_file.root.data[index]) |
|
|
233 |
|
|
|
234 |
|
|
|
235 |
def predict_and_get_image(model, data, affine): |
|
|
236 |
return nib.Nifti1Image(model.predict(data)[0, 0], affine) |
|
|
237 |
|
|
|
238 |
|
|
|
239 |
def predict_from_data_file_and_get_image(model, open_data_file, index): |
|
|
240 |
return predict_and_get_image(model, open_data_file.root.data[index], open_data_file.root.affine) |
|
|
241 |
|
|
|
242 |
|
|
|
243 |
def predict_from_data_file_and_write_image(model, open_data_file, index, out_file): |
|
|
244 |
image = predict_from_data_file_and_get_image(model, open_data_file, index) |
|
|
245 |
image.to_filename(out_file) |
|
|
246 |
|
|
|
247 |
|
|
|
248 |
def prediction_to_image(prediction, label_map=False, threshold=0.5, labels=None): |
|
|
249 |
if prediction.shape[0] == 1: |
|
|
250 |
data = prediction[0] |
|
|
251 |
if label_map: |
|
|
252 |
label_map_data = np.zeros(prediction[0, 0].shape, np.int8) |
|
|
253 |
if labels: |
|
|
254 |
label = labels[0] |
|
|
255 |
else: |
|
|
256 |
label = 1 |
|
|
257 |
label_map_data[data > threshold] = label |
|
|
258 |
data = label_map_data |
|
|
259 |
elif prediction.shape[1] > 1: |
|
|
260 |
if label_map: |
|
|
261 |
label_map_data = get_prediction_labels(prediction, threshold=threshold, labels=labels) |
|
|
262 |
data = label_map_data[0] |
|
|
263 |
else: |
|
|
264 |
return multi_class_prediction(prediction) |
|
|
265 |
else: |
|
|
266 |
raise RuntimeError("Invalid prediction array shape: {0}".format(prediction.shape)) |
|
|
267 |
return get_image(data) |
|
|
268 |
|
|
|
269 |
|
|
|
270 |
def multi_class_prediction(prediction, affine): |
|
|
271 |
prediction_images = [] |
|
|
272 |
for i in range(prediction.shape[1]): |
|
|
273 |
prediction_images.append(get_image(prediction[0, i])) |
|
|
274 |
return prediction_images |
|
|
275 |
|
|
|
276 |
|
|
|
277 |
def run_validation_case(data_index, output_dir, model, data_file, training_modalities, patch_shape, |
|
|
278 |
overlap_factor=0, permute=False, prev_truth_index=None, prev_truth_size=None, |
|
|
279 |
use_augmentations=False): |
|
|
280 |
""" |
|
|
281 |
Runs a test case and writes predicted images to file. |
|
|
282 |
:param data_index: Index from of the list of test cases to get an image prediction from. |
|
|
283 |
:param output_dir: Where to write prediction images. |
|
|
284 |
:param output_label_map: If True, will write out a single image with one or more labels. Otherwise outputs |
|
|
285 |
the (sigmoid) prediction values from the model. |
|
|
286 |
:param threshold: If output_label_map is set to True, this threshold defines the value above which is |
|
|
287 |
considered a positive result and will be assigned a label. |
|
|
288 |
:param labels: |
|
|
289 |
:param training_modalities: |
|
|
290 |
:param data_file: |
|
|
291 |
:param model: |
|
|
292 |
""" |
|
|
293 |
if not os.path.exists(output_dir): |
|
|
294 |
os.makedirs(output_dir) |
|
|
295 |
|
|
|
296 |
test_data = np.asarray([data_file.root.data[data_index]]) |
|
|
297 |
if prev_truth_index is not None: |
|
|
298 |
test_truth_data = np.asarray([data_file.root.truth[data_index]]) |
|
|
299 |
else: |
|
|
300 |
test_truth_data = None |
|
|
301 |
|
|
|
302 |
for i, modality in enumerate(training_modalities): |
|
|
303 |
image = get_image(test_data[i]) |
|
|
304 |
image.to_filename(os.path.join(output_dir, "data_{0}.nii.gz".format(modality))) |
|
|
305 |
|
|
|
306 |
test_truth = get_image(data_file.root.truth[data_index]) |
|
|
307 |
test_truth.to_filename(os.path.join(output_dir, "truth.nii.gz")) |
|
|
308 |
|
|
|
309 |
if patch_shape == test_data.shape[-3:]: |
|
|
310 |
prediction = predict(model, test_data, permute=permute) |
|
|
311 |
else: |
|
|
312 |
if use_augmentations: |
|
|
313 |
prediction = predict_augment(data=test_data, model=model, overlap_factor=overlap_factor, |
|
|
314 |
patch_shape=patch_shape) |
|
|
315 |
else: |
|
|
316 |
prediction = \ |
|
|
317 |
patch_wise_prediction(model=model, data=test_data, overlap_factor=overlap_factor, |
|
|
318 |
patch_shape=patch_shape, |
|
|
319 |
truth_data=test_truth_data, prev_truth_index=prev_truth_index, |
|
|
320 |
prev_truth_size=prev_truth_size)[np.newaxis] |
|
|
321 |
|
|
|
322 |
prediction = prediction.squeeze() |
|
|
323 |
prediction_image = get_image(prediction) |
|
|
324 |
if isinstance(prediction_image, list): |
|
|
325 |
for i, image in enumerate(prediction_image): |
|
|
326 |
image.to_filename(os.path.join(output_dir, "prediction_{0}.nii.gz".format(i + 1))) |
|
|
327 |
else: |
|
|
328 |
filename = os.path.join(output_dir, "prediction.nii.gz") |
|
|
329 |
prediction_image.to_filename(filename) |
|
|
330 |
return filename |
|
|
331 |
|
|
|
332 |
|
|
|
333 |
def run_validation_cases(validation_keys_file, model_file, training_modalities, hdf5_file, patch_shape, |
|
|
334 |
output_dir=".", overlap_factor=0, permute=False, |
|
|
335 |
prev_truth_index=None, prev_truth_size=None, use_augmentations=False): |
|
|
336 |
file_names = [] |
|
|
337 |
validation_indices = pickle_load(validation_keys_file) |
|
|
338 |
model = load_old_model(get_last_model_path(model_file)) |
|
|
339 |
data_file = tables.open_file(hdf5_file, "r") |
|
|
340 |
for index in validation_indices: |
|
|
341 |
if 'subject_ids' in data_file.root: |
|
|
342 |
case_directory = os.path.join(output_dir, data_file.root.subject_ids[index].decode('utf-8')) |
|
|
343 |
else: |
|
|
344 |
case_directory = os.path.join(output_dir, "validation_case_{}".format(index)) |
|
|
345 |
file_names.append( |
|
|
346 |
run_validation_case(data_index=index, output_dir=case_directory, model=model, data_file=data_file, |
|
|
347 |
training_modalities=training_modalities, overlap_factor=overlap_factor, |
|
|
348 |
permute=permute, patch_shape=patch_shape, prev_truth_index=prev_truth_index, |
|
|
349 |
prev_truth_size=prev_truth_size, use_augmentations=use_augmentations)) |
|
|
350 |
data_file.close() |
|
|
351 |
return file_names |
|
|
352 |
|
|
|
353 |
|
|
|
354 |
def predict(model, data, permute=False): |
|
|
355 |
if permute: |
|
|
356 |
predictions = list() |
|
|
357 |
for batch_index in range(data.shape[0]): |
|
|
358 |
predictions.append(predict_with_permutations(model, data[batch_index])) |
|
|
359 |
return np.asarray(predictions) |
|
|
360 |
else: |
|
|
361 |
return model.predict(data) |
|
|
362 |
|
|
|
363 |
|
|
|
364 |
def predict_with_permutations(model, data): |
|
|
365 |
predictions = list() |
|
|
366 |
for permutation_key in generate_permutation_keys(): |
|
|
367 |
temp_data = permute_data(data, permutation_key)[np.newaxis] |
|
|
368 |
predictions.append(reverse_permute_data(model.predict(temp_data)[0], permutation_key)) |
|
|
369 |
return np.mean(predictions, axis=0) |