Diff of /prod/predict_mat.py [000000] .. [ccb1dd]

Switch to unified view

a b/prod/predict_mat.py
1
import argparse
2
import json
3
import os
4
5
import numpy as np
6
7
from brats.utils import get_last_model_path
8
from fetal_net.normalize import normalize_data
9
from fetal_net.postprocess import postprocess_prediction as process_pred
10
from fetal_net.prediction import patch_wise_prediction
11
from fetal_net.training import load_old_model
12
from brats.preprocess import window_intensities_data
13
14
from scipy.io import loadmat, savemat
15
16
17
def main(input_mat_path, output_mat_path, config, overlap_factor, model_path, preprocess_method=None, norm_params=None):
18
    print(model_path)
19
    model = load_old_model(get_last_model_path(model_path))
20
    print('Loading mat from {}...'.format(input_mat_path))
21
    mat = loadmat(input_mat_path)
22
    print('Predicting mask...')
23
    data = mat['volume'].astype(np.float)
24
25
    if preprocess_method is not None:
26
        print('Applying preprocess by {}...'.format(preprocess_method))
27
        if preprocess_method == 'window_1_99':
28
            data = window_intensities_data(data)
29
        else:
30
            raise Exception('Unknown preprocess: {}'.format(preprocess_method))
31
32
    if norm_params is not None and any(norm_params.values()):
33
        data = normalize_data(data, mean=norm_params['mean'], std=norm_params['std'])
34
35
    prediction = \
36
        patch_wise_prediction(model=model,
37
                              data=np.expand_dims(data, 0),
38
                              overlap_factor=overlap_factor,
39
                              patch_shape=config["patch_shape"] + [config["patch_depth"]])
40
41
    print('Post-processing mask...')
42
    if prediction.shape[-1] > 1:
43
        prediction = prediction[..., 1]
44
    prediction = prediction.squeeze()
45
    mat['masks'][0, 9] = \
46
        process_pred(prediction, gaussian_std=0, threshold=0.2)  # .astype(np.uint8)
47
    mat['masks'][0, 8] = \
48
        process_pred(prediction, gaussian_std=1, threshold=0.5)  # .astype(np.uint8)
49
    mat['masks'][0, 7] = \
50
        process_pred(prediction, gaussian_std=0.5, threshold=0.5)  # .astype(np.uint8)
51
    print('Saving mat to {}'.format(output_mat_path))
52
    savemat(output_mat_path, mat)
53
    print('Finished.')
54
55
56
if __name__ == '__main__':
57
    parser = argparse.ArgumentParser()
58
    parser.add_argument("--config_dir", help="specifies config dir path",
59
                        type=str, required=True)
60
    parser.add_argument("--input_mat", help="specifies mat file dir path",
61
                        type=str, required=True)
62
    parser.add_argument("--output_mat", help="specifies mat file dir path",
63
                        type=str, required=True)
64
    parser.add_argument("--overlap_factor", help="specifies overlap between prediction patches",
65
                        type=float, default=0.9)
66
    parser.add_argument("--preprocess", help="what preprocess to do",
67
                        type=str, default=None)
68
    opts = parser.parse_args()
69
70
    with open(os.path.join(opts.config_dir, 'config.json'), 'r') as f:
71
        _config = json.load(f)
72
    with open(os.path.join(opts.config_dir, 'norm_params.json'), 'r') as f:
73
        _norm_params = json.load(f)
74
75
    _model_path = os.path.join(opts.config_dir, os.path.basename(_config['model_file']))
76
    main(opts.input_mat, opts.output_mat, _config, model_path=_model_path,
77
         preprocess_method=opts.preprocess, norm_params=_norm_params, overlap_factor=opts.overlap_factor)