# prod/predict_mat2.py
1
import argparse
2
import json
3
import os
4
5
import numpy as np
6
7
from brats.utils import get_last_model_path
8
from fetal_net.normalize import normalize_data
9
from fetal_net.postprocess import postprocess_prediction as process_pred
10
from fetal_net.prediction import patch_wise_prediction
11
from fetal_net.training import load_old_model
12
from brats.preprocess import window_intensities_data
13
14
from scipy.io import loadmat, savemat
15
16
from fetal_net.utils.cut_relevant_areas import find_bounding_box, check_bounding_box
17
18
19
def secondary_prediction(mask, vol, config2, model2_path=None,
                         preprocess_method2=None, norm_params2=None,
                         overlap_factor=0.9, padding=(16, 16, 8)):
    """Run the second-stage model on the region of ``vol`` selected by ``mask``.

    The bounding box of ``mask`` is obtained from ``find_bounding_box``,
    optionally grown by ``padding`` (clipped to ``[0, mask.shape]``), and the
    cropped sub-volume is passed through ``preproc_and_norm`` and
    ``patch_wise_prediction``. The squeezed result is zero-padded so that the
    crop sits at its original offset inside an array of ``vol.shape``.

    Args:
        mask: first-stage mask used to locate the region of interest.
        vol: input volume; it is sliced along its first three axes.
        config2: mapping providing ``"patch_shape"`` (a list) and
            ``"patch_depth"``; they are concatenated into the patch shape.
        model2_path: passed to ``get_last_model_path`` to locate the model.
        preprocess_method2: forwarded to ``preproc_and_norm``.
        norm_params2: forwarded to ``preproc_and_norm``.
        overlap_factor: forwarded to ``patch_wise_prediction``.
        padding: per-axis margin added on both sides of the bounding box,
            or ``None`` to use the tight box.

    Returns:
        The prediction for the cropped region, zero-padded back to the
        extent of ``vol``.
    """
    model2 = load_old_model(get_last_model_path(model2_path))

    bbox_start, bbox_end = find_bounding_box(mask)
    check_bounding_box(mask, bbox_start, bbox_end)

    # Coerce to arrays so the arithmetic below does not depend on whether the
    # helper returns tuples, lists or ndarrays.
    bbox_start = np.asarray(bbox_start)
    bbox_end = np.asarray(bbox_end)
    if padding is not None:
        margin = np.asarray(padding)
        bbox_start = np.maximum(bbox_start - margin, 0)
        bbox_end = np.minimum(bbox_end + margin, mask.shape)

    # ``np.float`` was removed in NumPy 1.24; ``np.float64`` is the dtype it
    # used to resolve to.
    data = vol.astype(np.float64)[
        bbox_start[0]:bbox_end[0],
        bbox_start[1]:bbox_end[1],
        bbox_start[2]:bbox_end[2],
    ]

    data = preproc_and_norm(data, preprocess_method2, norm_params2)

    prediction = patch_wise_prediction(
        model=model2,
        data=np.expand_dims(data, 0),
        overlap_factor=overlap_factor,
        patch_shape=config2["patch_shape"] + [config2["patch_depth"]])
    # NOTE(review): unlike ``main`` no channel is selected here, so this
    # presumably expects a single-channel model output -- confirm.
    prediction = prediction.squeeze()

    # (before, after) zero-padding per axis that restores the crop's offset
    # inside the full volume.
    padding2 = list(zip(bbox_start, np.array(vol.shape) - bbox_end))
    print(padding2)
    print(prediction.shape)
    prediction = np.pad(prediction, padding2, mode='constant', constant_values=0)
    return prediction
49
50
51
def preproc_and_norm(data, preprocess_method, norm_params):
    """Apply the optional preprocessing step and normalization to ``data``.

    Args:
        data: the volume to process.
        preprocess_method: ``None`` to skip preprocessing, or
            ``'window_1_99'`` to apply ``window_intensities_data``.
        norm_params: ``None`` to skip normalization, or a mapping with
            ``'mean'`` and ``'std'`` entries. Normalization is also skipped
            when every value in the mapping is falsy.

    Returns:
        The processed data (the input itself when both steps are skipped).

    Raises:
        ValueError: if ``preprocess_method`` is not a known method name.
    """
    if preprocess_method is not None:
        print('Applying preprocess by {}...'.format(preprocess_method))
        if preprocess_method == 'window_1_99':
            data = window_intensities_data(data)
        else:
            # ValueError is still caught by callers handling ``Exception``.
            raise ValueError('Unknown preprocess: {}'.format(preprocess_method))

    if norm_params is not None and any(norm_params.values()):
        data = normalize_data(data, mean=norm_params['mean'], std=norm_params['std'])
    return data
62
63
64
def main(input_mat_path, output_mat_path, overlap_factor,
         config, model_path, preprocess_method=None, norm_params=None,
         config2=None, model2_path=None, preprocess_method2=None, norm_params2=None):
    """Predict masks for the volume stored in a .mat file and save the result.

    Loads the model and the input .mat file, predicts patch-wise on
    ``mat['volume']``, post-processes the prediction in three variants and
    stores them in ``mat['masks'][0, 9]``, ``[0, 8]`` and ``[0, 7]``. When
    ``config2`` is given, a secondary prediction is made from the mask in
    slot 7 and its three variants are stored in slots 6, 5 and 4. The
    updated mat dict is then written to ``output_mat_path``.

    Args:
        input_mat_path: path of the .mat file to read.
        output_mat_path: path of the .mat file to write.
        overlap_factor: forwarded to ``patch_wise_prediction`` for the
            primary prediction.
        config: mapping providing ``"patch_shape"`` and ``"patch_depth"``.
        model_path: passed to ``get_last_model_path`` to locate the model.
        preprocess_method: forwarded to ``preproc_and_norm``.
        norm_params: forwarded to ``preproc_and_norm``.
        config2: configuration of the secondary model, or ``None`` to skip
            the secondary prediction.
        model2_path: model location for the secondary prediction.
        preprocess_method2: preprocessing for the secondary prediction.
        norm_params2: normalization parameters for the secondary prediction.
    """
    # (gaussian_std, threshold) post-processing variants, listed in the order
    # they are written to descending mask slots.
    variants = ((0, 0.2), (1, 0.5), (0.5, 0.5))

    print(model_path)
    model = load_old_model(get_last_model_path(model_path))
    print('Loading mat from {}...'.format(input_mat_path))
    mat = loadmat(input_mat_path)
    print('Predicting mask...')
    # ``np.float`` was removed in NumPy 1.24; ``np.float64`` is the dtype it
    # used to resolve to.
    data = mat['volume'].astype(np.float64)

    data = preproc_and_norm(data, preprocess_method, norm_params)

    prediction = patch_wise_prediction(
        model=model,
        data=np.expand_dims(data, 0),
        overlap_factor=overlap_factor,
        patch_shape=config["patch_shape"] + [config["patch_depth"]])

    print('Post-processing mask...')
    # With more than one entry on the last axis, keep index 1 only.
    if prediction.shape[-1] > 1:
        prediction = prediction[..., 1]
    prediction = prediction.squeeze()
    print("Storing prediction in [7-9], 7 should be the best...")
    for slot, (gaussian_std, threshold) in zip((9, 8, 7), variants):
        mat['masks'][0, slot] = process_pred(
            prediction, gaussian_std=gaussian_std, threshold=threshold)

    if config2 is not None:
        print('Making secondary prediction... [6]')
        # NOTE(review): the secondary stage uses a fixed overlap of 0.9 rather
        # than ``overlap_factor`` -- confirm this is intentional.
        prediction = secondary_prediction(
            mat['masks'][0, 7], vol=mat['volume'].astype(np.float64),
            config2=config2, model2_path=model2_path,
            preprocess_method2=preprocess_method2, norm_params2=norm_params2,
            overlap_factor=0.9)
        for slot, (gaussian_std, threshold) in zip((6, 5, 4), variants):
            mat['masks'][0, slot] = process_pred(
                prediction, gaussian_std=gaussian_std, threshold=threshold)

    print('Saving mat to {}'.format(output_mat_path))
    savemat(output_mat_path, mat)
    print('Finished.')
111
112
113
def get_params(config_dir):
    """Load the model configuration stored in ``config_dir``.

    Reads ``config.json`` and ``norm_params.json`` from ``config_dir`` and
    resolves the model file relative to that directory: only the base name
    of the config's ``'model_file'`` entry is used.

    Args:
        config_dir: directory containing ``config.json`` and
            ``norm_params.json``.

    Returns:
        A ``(config, norm_params, model_path)`` tuple.
    """
    # JSON is UTF-8 by specification; do not rely on the platform default.
    with open(os.path.join(config_dir, 'config.json'), 'r', encoding='utf-8') as f:
        config = json.load(f)
    with open(os.path.join(config_dir, 'norm_params.json'), 'r', encoding='utf-8') as f:
        norm_params = json.load(f)
    model_path = os.path.join(config_dir, os.path.basename(config['model_file']))
    return config, norm_params, model_path
120
121
122
if __name__ == '__main__':
    # Command-line entry point: parse the arguments, load the configuration
    # directories and run the prediction.
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_mat", help="specifies mat file dir path",
                        type=str, required=True)
    parser.add_argument("--output_mat", help="specifies mat file dir path",
                        type=str, required=True)
    parser.add_argument("--overlap_factor", help="specifies overlap between prediction patches",
                        type=float, default=0.9)

    # Params for primary prediction
    parser.add_argument("--config_dir", help="specifies config dir path",
                        type=str, required=True)
    parser.add_argument("--preprocess", help="what preprocess to do",
                        type=str, required=False, default=None)

    # Params for secondary prediction
    parser.add_argument("--config2_dir", help="specifies config dir path",
                        type=str, required=False, default=None)
    parser.add_argument("--preprocess2", help="what preprocess to do",
                        type=str, required=False, default=None)

    opts = parser.parse_args()

    # Primary model: configuration, normalization parameters and model path.
    _config, _norm_params, _model_path = get_params(opts.config_dir)
    # Secondary model: loaded only when --config2_dir is given.
    if opts.config2_dir is not None:
        _config2, _norm_params2, _model2_path = get_params(opts.config2_dir)
    else:
        _config2, _norm_params2, _model2_path = None, None, None

    main(opts.input_mat, opts.output_mat, overlap_factor=opts.overlap_factor,
         config=_config, model_path=_model_path, preprocess_method=opts.preprocess, norm_params=_norm_params,
         config2=_config2, model2_path=_model2_path, preprocess_method2=opts.preprocess2, norm_params2=_norm_params2)