--- a +++ b/prod/predict_nifti2.py @@ -0,0 +1,220 @@ +import argparse +import json +import os +from pathlib import Path + +import nibabel as nib + +import fetal_net.preprocess +from fetal.preprocess import window_intensities_data +from fetal.utils import get_last_model_path +from fetal_net.normalize import normalize_data +from fetal_net.postprocess import postprocess_prediction as process_pred +from fetal_net.prediction import patch_wise_prediction, predict_augment, predict_flips +from fetal_net.preprocess import * +from fetal_net.training import load_old_model +from fetal_net.utils.cut_relevant_areas import find_bounding_box, check_bounding_box +from fetal_net.utils.utils import read_img, get_image + + +def save_nifti(data, path): + nifti = get_image(data) + nib.save(nifti, path) + + +def secondary_prediction(mask, vol, config2, model2_path=None, + preprocess_method2=None, norm_params2=None, + overlap_factor=0.9, augment2=None, num_augment=32, return_all_preds=False): + model2 = load_old_model(get_last_model_path(model2_path), config=config2) + pred = mask + bbox_start, bbox_end = find_bounding_box(pred) + check_bounding_box(pred, bbox_start, bbox_end) + padding = [16, 16, 8] + if padding is not None: + bbox_start = np.maximum(bbox_start - padding, 0) + bbox_end = np.minimum(bbox_end + padding, mask.shape) + data = vol.astype(np.float)[ + bbox_start[0]:bbox_end[0], + bbox_start[1]:bbox_end[1], + bbox_start[2]:bbox_end[2] + ] + + data = preproc_and_norm(data, preprocess_method2, norm_params2) + + prediction = get_prediction(data, model2, augment=augment2, num_augments=num_augment, return_all_preds=return_all_preds, + overlap_factor=overlap_factor, config=config2) + + padding2 = list(zip(bbox_start, np.array(vol.shape) - bbox_end)) + if return_all_preds: + padding2 = [(0, 0)] + padding2 + print(padding2) + print(prediction.shape) + prediction = np.pad(prediction, padding2, mode='constant', constant_values=0) + + return prediction + + +def preproc_and_norm(data, preprocess_method=None, norm_params=None, scale=None, preproc=None): + if preprocess_method is not None: + print('Applying preprocess by {}...'.format(preprocess_method)) + if preprocess_method == 'window_1_99': + data = window_intensities_data(data) + else: + raise Exception('Unknown preprocess: {}'.format(preprocess_method)) + + if scale is not None: + data = ndimage.zoom(data, scale) + if preproc is not None: + preproc_func = getattr(fetal_net.preprocess, preproc) + data = preproc_func(data) + + # data = normalize_data(data, mean=data.mean(), std=data.std()) + if norm_params is not None and any(norm_params.values()): + data = normalize_data(data, mean=norm_params['mean'], std=norm_params['std']) + return data + + +def get_prediction(data, model, augment, num_augments, return_all_preds, overlap_factor, config): + if augment is not None: + patch_shape = config["patch_shape"] + [config["patch_depth"]] + if augment == 'all': + prediction = predict_augment(data, model=model, overlap_factor=overlap_factor, num_augments=num_augments, patch_shape=patch_shape) + elif augment == 'flip': + prediction = predict_flips(data, model=model, overlap_factor=overlap_factor, patch_shape=patch_shape, config=config) + else: + raise ("Unknown augmentation {}".format(augment)) + if not return_all_preds: + prediction = np.median(prediction, axis=0) + else: + prediction = \ + patch_wise_prediction(model=model, + data=np.expand_dims(data, 0), + overlap_factor=overlap_factor, + patch_shape=config["patch_shape"] + [config["patch_depth"]]) + prediction = prediction.squeeze() + return prediction + + +def main(input_path, output_path, overlap_factor, + config, model_path, preprocess_method=None, norm_params=None, augment=None, num_augment=0, + config2=None, model2_path=None, preprocess_method2=None, norm_params2=None, augment2=None, num_augment2=0, + z_scale=None, xy_scale=None, return_all_preds=False): + print(model_path) + model = load_old_model(get_last_model_path(model_path), config=config) + print('Loading nifti from {}...'.format(input_path)) + nifti = read_img(input_path) + print('Predicting mask...') + data = nifti.get_fdata().astype(np.float).squeeze() + print('original_shape: ' + str(data.shape)) + scan_name = Path(input_path).name.split('.')[0] + + if (z_scale is None): + z_scale = 1.0 + if (xy_scale is None): + xy_scale = 1.0 + if z_scale != 1.0 or xy_scale != 1.0: + data = ndimage.zoom(data, [xy_scale, xy_scale, z_scale]) + + data = preproc_and_norm(data, preprocess_method, norm_params, + scale=config.get('scale_data', None), + preproc=config.get('preproc', None)) + + save_nifti(data, os.path.join(output_path, scan_name + '_data.nii.gz')) + + data = np.pad(data, 3, 'constant', constant_values=data.min()) + + print('Shape: ' + str(data.shape)) + prediction = get_prediction(data=data, model=model, augment=augment, + num_augments=num_augment, return_all_preds=return_all_preds, + overlap_factor=overlap_factor, config=config) + # unpad + prediction = prediction[3:-3, 3:-3, 3:-3] + + # revert to original size + if config.get('scale_data', None) is not None: + prediction = ndimage.zoom(prediction.squeeze(), np.divide([1, 1, 1], config.get('scale_data', None)), order=0)[..., np.newaxis] + + save_nifti(prediction, os.path.join(output_path, scan_name + '_pred.nii.gz')) + + if z_scale != 1.0 or xy_scale != 1.0: + prediction = ndimage.zoom(prediction.squeeze(), [1.0 / xy_scale, 1.0 / xy_scale, 1.0 / z_scale], order=1)[..., np.newaxis] + + # if prediction.shape[-1] > 1: + # prediction = prediction[..., 1] + if config2 is not None: + prediction = prediction.squeeze() + mask = process_pred(prediction, gaussian_std=0.5, threshold=0.5) # .astype(np.uint8) + nifti = read_img(input_path) + prediction = secondary_prediction(mask, vol=nifti.get_fdata().astype(np.float), + config2=config2, model2_path=model2_path, + preprocess_method2=preprocess_method2, norm_params2=norm_params2, + overlap_factor=overlap_factor, augment2=augment2, num_augment=num_augment2, + return_all_preds=return_all_preds) + save_nifti(prediction, os.path.join(output_path, scan_name + 'pred_roi.nii.gz')) + + print('Saving to {}'.format(output_path)) + print('Finished.') + + +def get_params(config_dir): + with open(os.path.join(config_dir, 'config.json'), 'r') as f: + __config = json.load(f) + with open(os.path.join(config_dir, 'norm_params.json'), 'r') as f: + __norm_params = json.load(f) + __model_path = os.path.join(config_dir, os.path.basename(__config['model_file'])) + return __config, __norm_params, __model_path + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument("--input_nii", help="specifies mat file dir path", + type=str, required=True) + parser.add_argument("--output_folder", help="specifies mat file dir path", + type=str, required=True) + parser.add_argument("--overlap_factor", help="specifies overlap between prediction patches", + type=float, default=0.9) + parser.add_argument("--z_scale", help="specifies overlap between prediction patches", + type=float, default=1) + parser.add_argument("--xy_scale", help="specifies overlap between prediction patches", + type=float, default=1) + parser.add_argument("--return_all_preds", help="output std for prediction", + type=int, default=0) + + # Params for primary prediction + parser.add_argument("--config_dir", help="specifies config dir path", + type=str, required=True) + parser.add_argument("--preprocess", help="what preprocess to do", + type=str, required=False, default=None) + parser.add_argument("--augment", help="what augment to do", + type=str, required=False, default=None) # one of 'flip, all' + parser.add_argument("--num_augment", help="what augment to do", + type=int, required=False, default=0) # one of 'flip, all' + + # Params for secondary prediction + parser.add_argument("--config2_dir", help="specifies config dir path", + type=str, required=False, default=None) + parser.add_argument("--preprocess2", help="what preprocess to do", + type=str, required=False, default=None) + parser.add_argument("--augment2", help="what augment to do", + type=str, required=False, default=None) # one of 'flip, all' + parser.add_argument("--num_augment2", help="what augment to do", + type=int, required=False, default=0) # one of 'flip, all' + + opts = parser.parse_args() + + Path(opts.output_folder).mkdir(exist_ok=True) + + # 1 + _config, _norm_params, _model_path = get_params(opts.config_dir) + # 2 + if opts.config2_dir is not None: + _config2, _norm_params2, _model2_path = get_params(opts.config2_dir) + else: + _config2, _norm_params2, _model2_path = None, None, None + + main(opts.input_nii, opts.output_folder, overlap_factor=opts.overlap_factor, + config=_config, model_path=_model_path, preprocess_method=opts.preprocess, norm_params=_norm_params, augment=opts.augment, + num_augment=opts.num_augment, + config2=_config2, model2_path=_model2_path, preprocess_method2=opts.preprocess2, norm_params2=_norm_params2, augment2=opts.augment2, + num_augment2=opts.num_augment2, + z_scale=opts.z_scale, xy_scale=opts.xy_scale, return_all_preds=opts.return_all_preds)