--- a +++ b/prod/predict_mat.py @@ -0,0 +1,77 @@ +import argparse +import json +import os + +import numpy as np + +from brats.utils import get_last_model_path +from fetal_net.normalize import normalize_data +from fetal_net.postprocess import postprocess_prediction as process_pred +from fetal_net.prediction import patch_wise_prediction +from fetal_net.training import load_old_model +from brats.preprocess import window_intensities_data + +from scipy.io import loadmat, savemat + + +def main(input_mat_path, output_mat_path, config, overlap_factor, model_path, preprocess_method=None, norm_params=None): + print(model_path) + model = load_old_model(get_last_model_path(model_path)) + print('Loading mat from {}...'.format(input_mat_path)) + mat = loadmat(input_mat_path) + print('Predicting mask...') + data = mat['volume'].astype(np.float) + + if preprocess_method is not None: + print('Applying preprocess by {}...'.format(preprocess_method)) + if preprocess_method == 'window_1_99': + data = window_intensities_data(data) + else: + raise Exception('Unknown preprocess: {}'.format(preprocess_method)) + + if norm_params is not None and any(norm_params.values()): + data = normalize_data(data, mean=norm_params['mean'], std=norm_params['std']) + + prediction = \ + patch_wise_prediction(model=model, + data=np.expand_dims(data, 0), + overlap_factor=overlap_factor, + patch_shape=config["patch_shape"] + [config["patch_depth"]]) + + print('Post-processing mask...') + if prediction.shape[-1] > 1: + prediction = prediction[..., 1] + prediction = prediction.squeeze() + mat['masks'][0, 9] = \ + process_pred(prediction, gaussian_std=0, threshold=0.2) # .astype(np.uint8) + mat['masks'][0, 8] = \ + process_pred(prediction, gaussian_std=1, threshold=0.5) # .astype(np.uint8) + mat['masks'][0, 7] = \ + process_pred(prediction, gaussian_std=0.5, threshold=0.5) # .astype(np.uint8) + print('Saving mat to {}'.format(output_mat_path)) + savemat(output_mat_path, mat) + print('Finished.') + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument("--config_dir", help="specifies config dir path", + type=str, required=True) + parser.add_argument("--input_mat", help="specifies mat file dir path", + type=str, required=True) + parser.add_argument("--output_mat", help="specifies mat file dir path", + type=str, required=True) + parser.add_argument("--overlap_factor", help="specifies overlap between prediction patches", + type=float, default=0.9) + parser.add_argument("--preprocess", help="what preprocess to do", + type=str, default=None) + opts = parser.parse_args() + + with open(os.path.join(opts.config_dir, 'config.json'), 'r') as f: + _config = json.load(f) + with open(os.path.join(opts.config_dir, 'norm_params.json'), 'r') as f: + _norm_params = json.load(f) + + _model_path = os.path.join(opts.config_dir, os.path.basename(_config['model_file'])) + main(opts.input_mat, opts.output_mat, _config, model_path=_model_path, + preprocess_method=opts.preprocess, norm_params=_norm_params, overlap_factor=opts.overlap_factor)