import itertools
import os

import nibabel as nib
import numpy as np
import tables
from keras import Model
from scipy import ndimage
from tqdm import tqdm

from fetal.utils import get_last_model_path
from fetal_net.utils.threaded_generator import ThreadedGenerator
from fetal_net.utils.utils import get_image, list_load, pickle_load
from .augment import permute_data, generate_permutation_keys, reverse_permute_data, contrast_augment
from .training import load_old_model
from .utils.patches import get_patch_from_3d_data


def flip_it(data_, axes):
    """Flip an array along each of the given axes, in order.

    Applying it twice with the same ``axes`` is the identity, which is how the
    prediction code undoes the augmentation.
    """
    for ax in axes:
        data_ = np.flip(data_, ax)
    return data_


def predict_augment(data, model, overlap_factor, patch_shape, num_augments=32):
    """Test-time augmentation: predict ``num_augments`` randomly augmented copies.

    Each iteration jitters the contrast window, randomly flips/transposes/rotates
    the volume, runs ``patch_wise_prediction``, then undoes the spatial transforms
    on the prediction (in reverse order). All un-augmented predictions are stacked
    along a new leading axis.

    :param data: input volume; squeezed to 3D before augmenting.
    :param model: Keras model handed through to ``patch_wise_prediction``.
    :param overlap_factor: patch overlap factor, see ``patch_wise_prediction``.
    :param patch_shape: model input patch shape.
    :param num_augments: number of augmented predictions to produce.
    :return: array of shape (num_augments, ...) of predictions.
    """
    data_max = data.max()
    data_min = data.min()
    data = data.squeeze()

    order = 2  # spline interpolation order used for both rotations
    predictions = []
    for _ in range(num_augments):
        # pixel-wise augmentation: shift both window ends by up to +-10% of the range
        val_range = data_max - data_min
        contrast_min_val = data_min + 0.10 * np.random.uniform(-1, 1) * val_range
        contrast_max_val = data_max + 0.10 * np.random.uniform(-1, 1) * val_range
        curr_data = contrast_augment(data, contrast_min_val, contrast_max_val)

        # spatial augmentations
        rotate_factor = np.random.uniform(-30, 30)
        to_flip = np.arange(0, 3)[np.random.choice([True, False], size=3)]
        to_transpose = np.random.choice([True, False])

        curr_data = flip_it(curr_data, to_flip)
        if to_transpose:
            curr_data = curr_data.transpose([1, 0, 2])
        curr_data = ndimage.rotate(curr_data, rotate_factor, order=order, reshape=False)

        curr_prediction = patch_wise_prediction(model=model,
                                                data=curr_data[np.newaxis, ...],
                                                overlap_factor=overlap_factor,
                                                patch_shape=patch_shape).squeeze()

        # Undo the spatial augmentations in reverse order.
        # BUGFIX: the inverse rotation must also use reshape=False (and the same
        # interpolation order) -- the previous call used the defaults, and
        # reshape=True changes the array shape for non-zero angles, which makes
        # np.stack() below fail on mismatched shapes.
        curr_prediction = ndimage.rotate(curr_prediction, -rotate_factor, order=order, reshape=False)
        if to_transpose:
            curr_prediction = curr_prediction.transpose([1, 0, 2])
        curr_prediction = flip_it(curr_prediction, to_flip)
        predictions += [curr_prediction.squeeze()]

    res = np.stack(predictions, axis=0)
    return res


def predict_flips(data, model, overlap_factor, config):
    """Predict every combination of axis flips (2**3 = 8 variants).

    Each prediction is flipped back to the original orientation before being
    collected; returns the list of 8 predictions.
    """
    def powerset(iterable):
        "powerset([1,2,3]) --> () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)"
        s = list(iterable)
        return itertools.chain.from_iterable(itertools.combinations(s, r) for r in range(0, len(s) + 1))

    def predict_it(data_, axes=()):
        # flip, predict, then flip the prediction back (flip_it is an involution)
        data_ = flip_it(data_, axes)
        curr_pred = \
            patch_wise_prediction(model=model,
                                  data=np.expand_dims(data_.squeeze(), 0),
                                  overlap_factor=overlap_factor,
                                  patch_shape=config["patch_shape"] + [config["patch_depth"]]).squeeze()
        curr_pred = flip_it(curr_pred, axes)
        return curr_pred

    predictions = []
    for axes in powerset([0, 1, 2]):
        predictions += [predict_it(data, axes).squeeze()]

    return predictions


def get_set_of_patch_indices_full(start, stop, step):
    """Cartesian product of per-axis patch start indices.

    Each axis runs from start_i to stop_i inclusive in steps of step_i; when the
    step does not evenly divide the span, stop_i itself is appended so the last
    patch sits flush with the volume border.

    :return: (N, len(start)) array of patch corner indices.
    """
    indices = []
    for start_i, stop_i, step_i in zip(start, stop, step):
        indices_i = list(range(start_i, stop_i + 1, step_i))
        if stop_i % step_i > 0:
            indices_i += [stop_i]
        indices += [indices_i]
    return np.array(list(itertools.product(*indices)))


def batch_iterator(indices, batch_size, data_0, patch_shape, truth_0, prev_truth_index, truth_patch_shape):
    """Yield ``[batch, batch_indices]`` pairs of patches cut from ``data_0``.

    When ``truth_0`` is given, a truth patch (its last-axis start offset by
    ``prev_truth_index``) is concatenated onto each data patch along the channel
    axis. The final batch may be smaller than ``batch_size``.
    """
    i = 0
    while i < len(indices):
        batch = []
        curr_indices = []
        while len(batch) < batch_size and i < len(indices):
            curr_index = indices[i]
            patch = get_patch_from_3d_data(data_0, patch_shape=patch_shape, patch_index=curr_index)
            if truth_0 is not None:
                truth_index = list(curr_index[:2]) + [curr_index[2] + prev_truth_index]
                truth_patch = get_patch_from_3d_data(truth_0, patch_shape=truth_patch_shape,
                                                     patch_index=truth_index)
                patch = np.concatenate([patch, truth_patch], axis=-1)
            batch.append(patch)
            curr_indices.append(curr_index)
            i += 1
        yield [batch, curr_indices]
def patch_wise_prediction(model: Model, data, patch_shape, overlap_factor=0, batch_size=5,
                          permute=False, truth_data=None, prev_truth_index=None, prev_truth_size=None):
    """Predict a full volume by sliding (possibly overlapping) patches over it.

    Overlapping patch predictions are summed and divided by a per-voxel count,
    i.e. averaged.

    :param model: trained Keras model; its output shape decides 2D vs 3D handling.
    :param data: input of shape (1, x, y, z) -- only data[0] is used.
    :param patch_shape: shape of the patches fed to the model.
    :param overlap_factor: in [0, 1]; interpolates between the minimum overlap
        (patch minus prediction size) and the maximum (patch minus one voxel).
    :param batch_size: patches per model call.
    :param permute: forwarded to predict() for permutation averaging.
    :param truth_data: optional previous-truth volume concatenated to each patch.
    :param prev_truth_index: last-axis offset of the truth patch.
    :param prev_truth_size: last-axis size of the truth patch.
    :return: averaged prediction volume matching data[0]'s spatial shape.
    """
    # a model counts as 3D when more than two of its output axes are > 1
    is3d = np.sum(np.array(model.output_shape[1:]) > 1) > 2

    if is3d:
        prediction_shape = model.output_shape[-3:]
    else:
        prediction_shape = model.output_shape[-3:-1] + (1,)
    min_overlap = np.subtract(patch_shape, prediction_shape)
    max_overlap = np.subtract(patch_shape, (1, 1, 1))
    # BUGFIX: np.int was deprecated in NumPy 1.20 and removed in 1.24; the
    # builtin int is the documented replacement.
    overlap = min_overlap + (overlap_factor * (max_overlap - min_overlap)).astype(int)
    # pad so a centered prediction window can cover the whole volume; the fill
    # value is the 1st percentile (approximately background intensity)
    data_0 = np.pad(data[0],
                    [(np.ceil(_ / 2).astype(int), np.floor(_ / 2).astype(int)) for _ in
                     np.subtract(patch_shape, prediction_shape)],
                    mode='constant', constant_values=np.percentile(data[0], q=1))
    # extra padding in case the (padded) volume is still smaller than one patch
    pad_for_fit = [(np.ceil(_ / 2).astype(int), np.floor(_ / 2).astype(int)) for _ in
                   np.maximum(np.subtract(patch_shape, data_0.shape), 0)]
    data_0 = np.pad(data_0,
                    [_ for _ in pad_for_fit],
                    'constant', constant_values=np.percentile(data_0, q=1))

    if truth_data is not None:
        # pad the truth the same way as the data, but with zeros
        truth_0 = np.pad(truth_data[0],
                         [(np.ceil(_ / 2).astype(int), np.floor(_ / 2).astype(int)) for _ in
                          np.subtract(patch_shape, prediction_shape)],
                         mode='constant', constant_values=0)
        truth_0 = np.pad(truth_0, [_ for _ in pad_for_fit],
                         'constant', constant_values=0)

        truth_patch_shape = list(patch_shape[:2]) + [prev_truth_size]
    else:
        truth_0 = None
        truth_patch_shape = None

    indices = get_set_of_patch_indices_full((0, 0, 0),
                                            np.subtract(data_0.shape, patch_shape),
                                            np.subtract(patch_shape, overlap))

    b_iter = batch_iterator(indices, batch_size, data_0, patch_shape,
                            truth_0, prev_truth_index, truth_patch_shape)
    tb_iter = iter(ThreadedGenerator(b_iter, queue_maxsize=50))

    data_shape = list(data.shape[-3:] + np.sum(pad_for_fit, -1))
    if is3d:
        data_shape += [model.output_shape[1]]
    else:
        data_shape += [model.output_shape[-1]]
    predicted_output = np.zeros(data_shape)
    predicted_count = np.zeros(data_shape, dtype=np.int16)
    with tqdm(total=len(indices)) as pbar:
        for [curr_batch, batch_indices] in tb_iter:
            curr_batch = np.asarray(curr_batch)
            if is3d:
                curr_batch = np.expand_dims(curr_batch, 1)
            prediction = predict(model, curr_batch, permute=permute)

            if is3d:
                # channels-first -> channels-last
                prediction = prediction.transpose([0, 2, 3, 4, 1])
            else:
                prediction = np.expand_dims(prediction, -2)

            # accumulate each patch prediction and count overlaps per voxel
            for predicted_patch, predicted_index in zip(prediction, batch_indices):
                x, y, z = predicted_index
                x_len, y_len, z_len = predicted_patch.shape[:-1]
                predicted_output[x:x + x_len, y:y + y_len, z:z + z_len, :] += predicted_patch
                predicted_count[x:x + x_len, y:y + y_len, z:z + z_len] += 1
            pbar.update(batch_size)

    assert np.all(predicted_count > 0), 'Found zeros in count'

    if np.sum(pad_for_fit) > 0:
        # crop the fit-padding back off (a pad of 0 maps to slice bound None)
        x_pad, y_pad, z_pad = [[None if p2[0] == 0 else p2[0],
                                None if p2[1] == 0 else -p2[1]] for p2 in pad_for_fit]
        predicted_count = predicted_count[x_pad[0]: x_pad[1],
                                          y_pad[0]: y_pad[1],
                                          z_pad[0]: z_pad[1]]
        predicted_output = predicted_output[x_pad[0]: x_pad[1],
                                            y_pad[0]: y_pad[1],
                                            z_pad[0]: z_pad[1]]

    assert np.array_equal(predicted_count.shape[:-1], data[0].shape), 'prediction shape wrong'
    # average overlapping predictions
    return predicted_output / predicted_count


def get_prediction_labels(prediction, threshold=0.5, labels=None):
    """Convert per-class predictions into uint8 label arrays, one per sample.

    Voxels whose maximum class confidence is below ``threshold`` get label 0.

    NOTE(review): argmax is taken over axis=1 while the confidence check uses
    axis=0 -- these look inconsistent if both are meant to reduce the class
    axis; confirm the intended channel axis against callers before changing.
    """
    n_samples = prediction.shape[0]
    label_arrays = []
    for sample_number in range(n_samples):
        label_data = np.argmax(prediction[sample_number], axis=1)
        label_data[np.max(prediction[sample_number], axis=0) < threshold] = 0
        if labels:
            # remap consecutive class ids to the caller-supplied label values
            for value in np.unique(label_data).tolist()[1:]:
                label_data[label_data == value] = labels[value - 1]
        label_arrays.append(np.array(label_data, dtype=np.uint8))
    return label_arrays
def get_test_indices(testing_file):
    """Load the pickled list of test-case indices."""
    return pickle_load(testing_file)


def predict_from_data_file(model, open_data_file, index):
    """Predict directly on sample ``index`` of an open HDF5 data file."""
    return model.predict(open_data_file.root.data[index])


def predict_and_get_image(model, data, affine):
    """Predict on ``data`` and wrap channel [0, 0] in a Nifti1Image with ``affine``."""
    return nib.Nifti1Image(model.predict(data)[0, 0], affine)


def predict_from_data_file_and_get_image(model, open_data_file, index):
    """Predict on sample ``index`` and return a Nifti1Image using the file's affine."""
    return predict_and_get_image(model, open_data_file.root.data[index], open_data_file.root.affine)


def predict_from_data_file_and_write_image(model, open_data_file, index, out_file):
    """Predict on sample ``index`` and write the resulting image to ``out_file``."""
    image = predict_from_data_file_and_get_image(model, open_data_file, index)
    image.to_filename(out_file)


def prediction_to_image(prediction, label_map=False, threshold=0.5, labels=None):
    """Convert a raw prediction array into an image.

    Single-channel predictions become either raw values or a thresholded binary
    label map; multi-channel predictions become an argmax label map or a list of
    per-class images.

    :raises RuntimeError: if the prediction shape fits neither case.
    """
    if prediction.shape[0] == 1:
        data = prediction[0]
        if label_map:
            label_map_data = np.zeros(prediction[0, 0].shape, np.int8)
            if labels:
                label = labels[0]
            else:
                label = 1
            label_map_data[data > threshold] = label
            data = label_map_data
    elif prediction.shape[1] > 1:
        if label_map:
            label_map_data = get_prediction_labels(prediction, threshold=threshold, labels=labels)
            data = label_map_data[0]
        else:
            return multi_class_prediction(prediction)
    else:
        raise RuntimeError("Invalid prediction array shape: {0}".format(prediction.shape))
    return get_image(data)


def multi_class_prediction(prediction, affine=None):
    """Return one image per class channel of ``prediction``.

    BUGFIX: ``affine`` now defaults to None -- prediction_to_image calls this
    function with a single argument, which previously raised TypeError. The
    parameter is unused in the body (get_image builds the image itself) and is
    kept only for backward compatibility with any external callers.
    """
    prediction_images = []
    for i in range(prediction.shape[1]):
        prediction_images.append(get_image(prediction[0, i]))
    return prediction_images


def run_validation_case(data_index, output_dir, model, data_file, training_modalities, patch_shape,
                        overlap_factor=0, permute=False, prev_truth_index=None, prev_truth_size=None,
                        use_augmentations=False):
    """Run one test case and write the input, truth, and prediction images to disk.

    :param data_index: index into the data file's list of test cases.
    :param output_dir: directory to write images into (created if missing).
    :param model: trained Keras model.
    :param data_file: open HDF5 file with ``data`` and ``truth`` arrays.
    :param training_modalities: modality names used to name the data images.
    :param patch_shape: model input patch shape; if it equals the volume shape
        the model is applied directly, otherwise patch-wise.
    :param overlap_factor: patch overlap factor for patch-wise prediction.
    :param permute: average predictions over data permutations.
    :param prev_truth_index: last-axis offset for previous-truth patches, or None.
    :param prev_truth_size: last-axis size for previous-truth patches.
    :param use_augmentations: use test-time augmentation instead of a single pass.
    :return: path of the written prediction file (None on the multi-image branch).
    """
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    test_data = np.asarray([data_file.root.data[data_index]])
    if prev_truth_index is not None:
        test_truth_data = np.asarray([data_file.root.truth[data_index]])
    else:
        test_truth_data = None

    # NOTE(review): test_data has a leading length-1 axis, so indexing it with
    # the modality counter only looks right for a single modality -- verify.
    for i, modality in enumerate(training_modalities):
        image = get_image(test_data[i])
        image.to_filename(os.path.join(output_dir, "data_{0}.nii.gz".format(modality)))

    test_truth = get_image(data_file.root.truth[data_index])
    test_truth.to_filename(os.path.join(output_dir, "truth.nii.gz"))

    if patch_shape == test_data.shape[-3:]:
        prediction = predict(model, test_data, permute=permute)
    else:
        if use_augmentations:
            prediction = predict_augment(data=test_data, model=model, overlap_factor=overlap_factor,
                                         patch_shape=patch_shape)
        else:
            prediction = \
                patch_wise_prediction(model=model, data=test_data, overlap_factor=overlap_factor,
                                      patch_shape=patch_shape,
                                      truth_data=test_truth_data, prev_truth_index=prev_truth_index,
                                      prev_truth_size=prev_truth_size)[np.newaxis]

    prediction = prediction.squeeze()
    prediction_image = get_image(prediction)
    if isinstance(prediction_image, list):
        for i, image in enumerate(prediction_image):
            image.to_filename(os.path.join(output_dir, "prediction_{0}.nii.gz".format(i + 1)))
    else:
        filename = os.path.join(output_dir, "prediction.nii.gz")
        prediction_image.to_filename(filename)
        return filename
def run_validation_cases(validation_keys_file, model_file, training_modalities, hdf5_file, patch_shape,
                        output_dir=".", overlap_factor=0, permute=False,
                        prev_truth_index=None, prev_truth_size=None, use_augmentations=False):
    """Run ``run_validation_case`` for every index in the validation split.

    Loads the latest model checkpoint, opens the HDF5 data file read-only, and
    writes one prediction directory per case (named after the subject id when
    available, otherwise ``validation_case_<index>``).

    :return: list of prediction file paths as returned by run_validation_case.
    """
    file_names = []
    validation_indices = pickle_load(validation_keys_file)
    model = load_old_model(get_last_model_path(model_file))
    data_file = tables.open_file(hdf5_file, "r")
    # BUGFIX: close the HDF5 handle even when a validation case raises;
    # previously an exception mid-loop leaked the open file.
    try:
        for index in validation_indices:
            if 'subject_ids' in data_file.root:
                case_directory = os.path.join(output_dir, data_file.root.subject_ids[index].decode('utf-8'))
            else:
                case_directory = os.path.join(output_dir, "validation_case_{}".format(index))
            file_names.append(
                run_validation_case(data_index=index, output_dir=case_directory, model=model, data_file=data_file,
                                    training_modalities=training_modalities, overlap_factor=overlap_factor,
                                    permute=permute, patch_shape=patch_shape, prev_truth_index=prev_truth_index,
                                    prev_truth_size=prev_truth_size, use_augmentations=use_augmentations))
    finally:
        data_file.close()
    return file_names


def predict(model, data, permute=False):
    """Predict on a batch; optionally average each sample over all permutations.

    With ``permute=True`` every sample is predicted once per permutation key
    (see predict_with_permutations) and the results are averaged.
    """
    if permute:
        predictions = list()
        for batch_index in range(data.shape[0]):
            predictions.append(predict_with_permutations(model, data[batch_index]))
        return np.asarray(predictions)
    else:
        return model.predict(data)


def predict_with_permutations(model, data):
    """Average the model's predictions over every permutation of ``data``.

    Each permuted input is predicted, the permutation is reversed on the
    prediction, and the mean over all permutation keys is returned.
    """
    predictions = list()
    for permutation_key in generate_permutation_keys():
        temp_data = permute_data(data, permutation_key)[np.newaxis]
        predictions.append(reverse_permute_data(model.predict(temp_data)[0], permutation_key))
    return np.mean(predictions, axis=0)