Diff of /fetal_net/prediction.py [000000] .. [ccb1dd]

Switch to unified view

a b/fetal_net/prediction.py
1
import itertools
2
import os
3
4
import nibabel as nib
5
import numpy as np
6
import tables
7
from keras import Model
8
from scipy import ndimage
9
from tqdm import tqdm
10
11
from fetal.utils import get_last_model_path
12
from fetal_net.utils.threaded_generator import ThreadedGenerator
13
from fetal_net.utils.utils import get_image, list_load, pickle_load
14
from .augment import permute_data, generate_permutation_keys, reverse_permute_data, contrast_augment
15
from .training import load_old_model
16
from .utils.patches import get_patch_from_3d_data
17
18
19
def flip_it(data_, axes):
20
    for ax in axes:
21
        data_ = np.flip(data_, ax)
22
    return data_
23
24
25
def predict_augment(data, model, overlap_factor, patch_shape, num_augments=32):
26
    data_max = data.max()
27
    data_min = data.min()
28
    data = data.squeeze()
29
30
    order = 2
31
    predictions = []
32
    for _ in range(num_augments):
33
        # pixel-wise augmentations
34
        val_range = data_max - data_min
35
        contrast_min_val = data_min + 0.10 * np.random.uniform(-1, 1) * val_range
36
        contrast_max_val = data_max + 0.10 * np.random.uniform(-1, 1) * val_range
37
        curr_data = contrast_augment(data, contrast_min_val, contrast_max_val)
38
39
        # spatial augmentations
40
        rotate_factor = np.random.uniform(-30, 30)
41
        to_flip = np.arange(0, 3)[np.random.choice([True, False], size=3)]
42
        to_transpose = np.random.choice([True, False])
43
44
        curr_data = flip_it(curr_data, to_flip)
45
46
        if to_transpose:
47
            curr_data = curr_data.transpose([1, 0, 2])
48
49
        curr_data = ndimage.rotate(curr_data, rotate_factor, order=order, reshape=False)
50
51
        curr_prediction = patch_wise_prediction(model=model, data=curr_data[np.newaxis, ...], overlap_factor=overlap_factor, patch_shape=patch_shape).squeeze()
52
53
        curr_prediction = ndimage.rotate(curr_prediction, -rotate_factor)
54
55
        if to_transpose:
56
            curr_prediction = curr_prediction.transpose([1, 0, 2])
57
58
        curr_prediction = flip_it(curr_prediction, to_flip)
59
        predictions += [curr_prediction.squeeze()]
60
61
    res = np.stack(predictions, axis=0)
62
    return res
63
64
65
def predict_flips(data, model, overlap_factor, config):
66
    def powerset(iterable):
67
        "powerset([1,2,3]) --> () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)"
68
        s = list(iterable)
69
        return itertools.chain.from_iterable(itertools.combinations(s, r) for r in range(0, len(s) + 1))
70
71
    def predict_it(data_, axes=()):
72
        data_ = flip_it(data_, axes)
73
        curr_pred = \
74
            patch_wise_prediction(model=model,
75
                                  data=np.expand_dims(data_.squeeze(), 0),
76
                                  overlap_factor=overlap_factor,
77
                                  patch_shape=config["patch_shape"] + [config["patch_depth"]]).squeeze()
78
        curr_pred = flip_it(curr_pred, axes)
79
        return curr_pred
80
81
    predictions = []
82
    for axes in powerset([0, 1, 2]):
83
        predictions += [predict_it(data, axes).squeeze()]
84
85
    return predictions
86
87
88
def get_set_of_patch_indices_full(start, stop, step):
89
    indices = []
90
    for start_i, stop_i, step_i in zip(start, stop, step):
91
        indices_i = list(range(start_i, stop_i + 1, step_i))
92
        if stop_i % step_i > 0:
93
            indices_i += [stop_i]
94
        indices += [indices_i]
95
    return np.array(list(itertools.product(*indices)))
96
97
98
def batch_iterator(indices, batch_size, data_0, patch_shape, truth_0, prev_truth_index, truth_patch_shape):
99
    i = 0
100
    while i < len(indices):
101
        batch = []
102
        curr_indices = []
103
        while len(batch) < batch_size and i < len(indices):
104
            curr_index = indices[i]
105
            patch = get_patch_from_3d_data(data_0, patch_shape=patch_shape, patch_index=curr_index)
106
            if truth_0 is not None:
107
                truth_index = list(curr_index[:2]) + [curr_index[2] + prev_truth_index]
108
                truth_patch = get_patch_from_3d_data(truth_0, patch_shape=truth_patch_shape,
109
                                                     patch_index=truth_index)
110
                patch = np.concatenate([patch, truth_patch], axis=-1)
111
            batch.append(patch)
112
            curr_indices.append(curr_index)
113
            i += 1
114
        yield [batch, curr_indices]
115
    # print('Finished! {}-{}'.format(i, len(indices)))
116
117
118
def patch_wise_prediction(model: Model, data, patch_shape, overlap_factor=0, batch_size=5,
119
                          permute=False, truth_data=None, prev_truth_index=None, prev_truth_size=None):
120
    """
121
    :param truth_data:
122
    :param permute:
123
    :param overlap_factor:
124
    :param batch_size:
125
    :param model:
126
    :param data:
127
    :return:
128
    """
129
    is3d = np.sum(np.array(model.output_shape[1:]) > 1) > 2
130
131
    if is3d:
132
        prediction_shape = model.output_shape[-3:]
133
    else:
134
        prediction_shape = model.output_shape[-3:-1] + (1,)  # patch_shape[-3:-1] #[64,64]#
135
    min_overlap = np.subtract(patch_shape, prediction_shape)
136
    max_overlap = np.subtract(patch_shape, (1, 1, 1))
137
    overlap = min_overlap + (overlap_factor * (max_overlap - min_overlap)).astype(np.int)
138
    data_0 = np.pad(data[0],
139
                    [(np.ceil(_ / 2).astype(int), np.floor(_ / 2).astype(int)) for _ in
140
                     np.subtract(patch_shape, prediction_shape)],
141
                    mode='constant', constant_values=np.percentile(data[0], q=1))
142
    pad_for_fit = [(np.ceil(_ / 2).astype(int), np.floor(_ / 2).astype(int)) for _ in
143
                   np.maximum(np.subtract(patch_shape, data_0.shape), 0)]
144
    data_0 = np.pad(data_0,
145
                    [_ for _ in pad_for_fit],
146
                    'constant', constant_values=np.percentile(data_0, q=1))
147
148
    if truth_data is not None:
149
        truth_0 = np.pad(truth_data[0],
150
                         [(np.ceil(_ / 2).astype(int), np.floor(_ / 2).astype(int)) for _ in
151
                          np.subtract(patch_shape, prediction_shape)],
152
                         mode='constant', constant_values=0)
153
        truth_0 = np.pad(truth_0, [_ for _ in pad_for_fit],
154
                         'constant', constant_values=0)
155
156
        truth_patch_shape = list(patch_shape[:2]) + [prev_truth_size]
157
    else:
158
        truth_0 = None
159
        truth_patch_shape = None
160
161
    indices = get_set_of_patch_indices_full((0, 0, 0),
162
                                            np.subtract(data_0.shape, patch_shape),
163
                                            np.subtract(patch_shape, overlap))
164
165
    b_iter = batch_iterator(indices, batch_size, data_0, patch_shape,
166
                            truth_0, prev_truth_index, truth_patch_shape)
167
    tb_iter = iter(ThreadedGenerator(b_iter, queue_maxsize=50))
168
169
    data_shape = list(data.shape[-3:] + np.sum(pad_for_fit, -1))
170
    if is3d:
171
        data_shape += [model.output_shape[1]]
172
    else:
173
        data_shape += [model.output_shape[-1]]
174
    predicted_output = np.zeros(data_shape)
175
    predicted_count = np.zeros(data_shape, dtype=np.int16)
176
    with tqdm(total=len(indices)) as pbar:
177
        for [curr_batch, batch_indices] in tb_iter:
178
            curr_batch = np.asarray(curr_batch)
179
            if is3d:
180
                curr_batch = np.expand_dims(curr_batch, 1)
181
            prediction = predict(model, curr_batch, permute=permute)
182
183
            if is3d:
184
                prediction = prediction.transpose([0, 2, 3, 4, 1])
185
            else:
186
                prediction = np.expand_dims(prediction, -2)
187
188
            for predicted_patch, predicted_index in zip(prediction, batch_indices):
189
                # predictions.append(predicted_patch)
190
                x, y, z = predicted_index
191
                x_len, y_len, z_len = predicted_patch.shape[:-1]
192
                predicted_output[x:x + x_len, y:y + y_len, z:z + z_len, :] += predicted_patch
193
                predicted_count[x:x + x_len, y:y + y_len, z:z + z_len] += 1
194
            pbar.update(batch_size)
195
196
    assert np.all(predicted_count > 0), 'Found zeros in count'
197
198
    if np.sum(pad_for_fit) > 0:
199
        # must be a better way :\
200
        x_pad, y_pad, z_pad = [[None if p2[0] == 0 else p2[0],
201
                                None if p2[1] == 0 else -p2[1]] for p2 in pad_for_fit]
202
        predicted_count = predicted_count[x_pad[0]: x_pad[1],
203
                          y_pad[0]: y_pad[1],
204
                          z_pad[0]: z_pad[1]]
205
        predicted_output = predicted_output[x_pad[0]: x_pad[1],
206
                           y_pad[0]: y_pad[1],
207
                           z_pad[0]: z_pad[1]]
208
209
    assert np.array_equal(predicted_count.shape[:-1], data[0].shape), 'prediction shape wrong'
210
    return predicted_output / predicted_count
211
    # return reconstruct_from_patches(predictions, patch_indices=indices, data_shape=data_shape)
212
213
214
def get_prediction_labels(prediction, threshold=0.5, labels=None):
215
    n_samples = prediction.shape[0]
216
    label_arrays = []
217
    for sample_number in range(n_samples):
218
        label_data = np.argmax(prediction[sample_number], axis=1)
219
        label_data[np.max(prediction[sample_number], axis=0) < threshold] = 0
220
        if labels:
221
            for value in np.unique(label_data).tolist()[1:]:
222
                label_data[label_data == value] = labels[value - 1]
223
        label_arrays.append(np.array(label_data, dtype=np.uint8))
224
    return label_arrays
225
226
227
def get_test_indices(testing_file):
228
    return pickle_load(testing_file)
229
230
231
def predict_from_data_file(model, open_data_file, index):
232
    return model.predict(open_data_file.root.data[index])
233
234
235
def predict_and_get_image(model, data, affine):
236
    return nib.Nifti1Image(model.predict(data)[0, 0], affine)
237
238
239
def predict_from_data_file_and_get_image(model, open_data_file, index):
240
    return predict_and_get_image(model, open_data_file.root.data[index], open_data_file.root.affine)
241
242
243
def predict_from_data_file_and_write_image(model, open_data_file, index, out_file):
244
    image = predict_from_data_file_and_get_image(model, open_data_file, index)
245
    image.to_filename(out_file)
246
247
248
def prediction_to_image(prediction, label_map=False, threshold=0.5, labels=None):
249
    if prediction.shape[0] == 1:
250
        data = prediction[0]
251
        if label_map:
252
            label_map_data = np.zeros(prediction[0, 0].shape, np.int8)
253
            if labels:
254
                label = labels[0]
255
            else:
256
                label = 1
257
            label_map_data[data > threshold] = label
258
            data = label_map_data
259
    elif prediction.shape[1] > 1:
260
        if label_map:
261
            label_map_data = get_prediction_labels(prediction, threshold=threshold, labels=labels)
262
            data = label_map_data[0]
263
        else:
264
            return multi_class_prediction(prediction)
265
    else:
266
        raise RuntimeError("Invalid prediction array shape: {0}".format(prediction.shape))
267
    return get_image(data)
268
269
270
def multi_class_prediction(prediction, affine):
271
    prediction_images = []
272
    for i in range(prediction.shape[1]):
273
        prediction_images.append(get_image(prediction[0, i]))
274
    return prediction_images
275
276
277
def run_validation_case(data_index, output_dir, model, data_file, training_modalities, patch_shape,
278
                        overlap_factor=0, permute=False, prev_truth_index=None, prev_truth_size=None,
279
                        use_augmentations=False):
280
    """
281
    Runs a test case and writes predicted images to file.
282
    :param data_index: Index from of the list of test cases to get an image prediction from.
283
    :param output_dir: Where to write prediction images.
284
    :param output_label_map: If True, will write out a single image with one or more labels. Otherwise outputs
285
    the (sigmoid) prediction values from the model.
286
    :param threshold: If output_label_map is set to True, this threshold defines the value above which is 
287
    considered a positive result and will be assigned a label.  
288
    :param labels:
289
    :param training_modalities:
290
    :param data_file:
291
    :param model:
292
    """
293
    if not os.path.exists(output_dir):
294
        os.makedirs(output_dir)
295
296
    test_data = np.asarray([data_file.root.data[data_index]])
297
    if prev_truth_index is not None:
298
        test_truth_data = np.asarray([data_file.root.truth[data_index]])
299
    else:
300
        test_truth_data = None
301
302
    for i, modality in enumerate(training_modalities):
303
        image = get_image(test_data[i])
304
        image.to_filename(os.path.join(output_dir, "data_{0}.nii.gz".format(modality)))
305
306
    test_truth = get_image(data_file.root.truth[data_index])
307
    test_truth.to_filename(os.path.join(output_dir, "truth.nii.gz"))
308
309
    if patch_shape == test_data.shape[-3:]:
310
        prediction = predict(model, test_data, permute=permute)
311
    else:
312
        if use_augmentations:
313
            prediction = predict_augment(data=test_data, model=model, overlap_factor=overlap_factor,
314
                                         patch_shape=patch_shape)
315
        else:
316
            prediction = \
317
                patch_wise_prediction(model=model, data=test_data, overlap_factor=overlap_factor,
318
                                      patch_shape=patch_shape,
319
                                      truth_data=test_truth_data, prev_truth_index=prev_truth_index,
320
                                      prev_truth_size=prev_truth_size)[np.newaxis]
321
322
    prediction = prediction.squeeze()
323
    prediction_image = get_image(prediction)
324
    if isinstance(prediction_image, list):
325
        for i, image in enumerate(prediction_image):
326
            image.to_filename(os.path.join(output_dir, "prediction_{0}.nii.gz".format(i + 1)))
327
    else:
328
        filename = os.path.join(output_dir, "prediction.nii.gz")
329
        prediction_image.to_filename(filename)
330
    return filename
331
332
333
def run_validation_cases(validation_keys_file, model_file, training_modalities, hdf5_file, patch_shape,
334
                         output_dir=".", overlap_factor=0, permute=False,
335
                         prev_truth_index=None, prev_truth_size=None, use_augmentations=False):
336
    file_names = []
337
    validation_indices = pickle_load(validation_keys_file)
338
    model = load_old_model(get_last_model_path(model_file))
339
    data_file = tables.open_file(hdf5_file, "r")
340
    for index in validation_indices:
341
        if 'subject_ids' in data_file.root:
342
            case_directory = os.path.join(output_dir, data_file.root.subject_ids[index].decode('utf-8'))
343
        else:
344
            case_directory = os.path.join(output_dir, "validation_case_{}".format(index))
345
        file_names.append(
346
            run_validation_case(data_index=index, output_dir=case_directory, model=model, data_file=data_file,
347
                                training_modalities=training_modalities, overlap_factor=overlap_factor,
348
                                permute=permute, patch_shape=patch_shape, prev_truth_index=prev_truth_index,
349
                                prev_truth_size=prev_truth_size, use_augmentations=use_augmentations))
350
    data_file.close()
351
    return file_names
352
353
354
def predict(model, data, permute=False):
355
    if permute:
356
        predictions = list()
357
        for batch_index in range(data.shape[0]):
358
            predictions.append(predict_with_permutations(model, data[batch_index]))
359
        return np.asarray(predictions)
360
    else:
361
        return model.predict(data)
362
363
364
def predict_with_permutations(model, data):
365
    predictions = list()
366
    for permutation_key in generate_permutation_keys():
367
        temp_data = permute_data(data, permutation_key)[np.newaxis]
368
        predictions.append(reverse_permute_data(model.predict(temp_data)[0], permutation_key))
369
    return np.mean(predictions, axis=0)