Diff of /prod/predict_nifti2.py [000000] .. [ccb1dd]

Switch to unified view

a b/prod/predict_nifti2.py
1
import argparse
2
import json
3
import os
4
from pathlib import Path
5
6
import nibabel as nib
7
8
import fetal_net.preprocess
9
from fetal.preprocess import window_intensities_data
10
from fetal.utils import get_last_model_path
11
from fetal_net.normalize import normalize_data
12
from fetal_net.postprocess import postprocess_prediction as process_pred
13
from fetal_net.prediction import patch_wise_prediction, predict_augment, predict_flips
14
from fetal_net.preprocess import *
15
from fetal_net.training import load_old_model
16
from fetal_net.utils.cut_relevant_areas import find_bounding_box, check_bounding_box
17
from fetal_net.utils.utils import read_img, get_image
18
19
20
def save_nifti(data, path):
21
    nifti = get_image(data)
22
    nib.save(nifti, path)
23
24
25
def secondary_prediction(mask, vol, config2, model2_path=None,
26
                         preprocess_method2=None, norm_params2=None,
27
                         overlap_factor=0.9, augment2=None, num_augment=32, return_all_preds=False):
28
    model2 = load_old_model(get_last_model_path(model2_path), config=config2)
29
    pred = mask
30
    bbox_start, bbox_end = find_bounding_box(pred)
31
    check_bounding_box(pred, bbox_start, bbox_end)
32
    padding = [16, 16, 8]
33
    if padding is not None:
34
        bbox_start = np.maximum(bbox_start - padding, 0)
35
        bbox_end = np.minimum(bbox_end + padding, mask.shape)
36
    data = vol.astype(np.float)[
37
           bbox_start[0]:bbox_end[0],
38
           bbox_start[1]:bbox_end[1],
39
           bbox_start[2]:bbox_end[2]
40
           ]
41
42
    data = preproc_and_norm(data, preprocess_method2, norm_params2)
43
44
    prediction = get_prediction(data, model2, augment=augment2, num_augments=num_augment, return_all_preds=return_all_preds,
45
                                overlap_factor=overlap_factor, config=config2)
46
47
    padding2 = list(zip(bbox_start, np.array(vol.shape) - bbox_end))
48
    if return_all_preds:
49
        padding2 = [(0, 0)] + padding2
50
    print(padding2)
51
    print(prediction.shape)
52
    prediction = np.pad(prediction, padding2, mode='constant', constant_values=0)
53
54
    return prediction
55
56
57
def preproc_and_norm(data, preprocess_method=None, norm_params=None, scale=None, preproc=None):
58
    if preprocess_method is not None:
59
        print('Applying preprocess by {}...'.format(preprocess_method))
60
        if preprocess_method == 'window_1_99':
61
            data = window_intensities_data(data)
62
        else:
63
            raise Exception('Unknown preprocess: {}'.format(preprocess_method))
64
65
    if scale is not None:
66
        data = ndimage.zoom(data, scale)
67
    if preproc is not None:
68
        preproc_func = getattr(fetal_net.preprocess, preproc)
69
        data = preproc_func(data)
70
71
    # data = normalize_data(data, mean=data.mean(), std=data.std())
72
    if norm_params is not None and any(norm_params.values()):
73
        data = normalize_data(data, mean=norm_params['mean'], std=norm_params['std'])
74
    return data
75
76
77
def get_prediction(data, model, augment, num_augments, return_all_preds, overlap_factor, config):
78
    if augment is not None:
79
        patch_shape = config["patch_shape"] + [config["patch_depth"]]
80
        if augment == 'all':
81
            prediction = predict_augment(data, model=model, overlap_factor=overlap_factor, num_augments=num_augments, patch_shape=patch_shape)
82
        elif augment == 'flip':
83
            prediction = predict_flips(data, model=model, overlap_factor=overlap_factor, patch_shape=patch_shape, config=config)
84
        else:
85
            raise ("Unknown augmentation {}".format(augment))
86
        if not return_all_preds:
87
            prediction = np.median(prediction, axis=0)
88
    else:
89
        prediction = \
90
            patch_wise_prediction(model=model,
91
                                  data=np.expand_dims(data, 0),
92
                                  overlap_factor=overlap_factor,
93
                                  patch_shape=config["patch_shape"] + [config["patch_depth"]])
94
    prediction = prediction.squeeze()
95
    return prediction
96
97
98
def main(input_path, output_path, overlap_factor,
99
         config, model_path, preprocess_method=None, norm_params=None, augment=None, num_augment=0,
100
         config2=None, model2_path=None, preprocess_method2=None, norm_params2=None, augment2=None, num_augment2=0,
101
         z_scale=None, xy_scale=None, return_all_preds=False):
102
    print(model_path)
103
    model = load_old_model(get_last_model_path(model_path), config=config)
104
    print('Loading nifti from {}...'.format(input_path))
105
    nifti = read_img(input_path)
106
    print('Predicting mask...')
107
    data = nifti.get_fdata().astype(np.float).squeeze()
108
    print('original_shape: ' + str(data.shape))
109
    scan_name = Path(input_path).name.split('.')[0]
110
111
    if (z_scale is None):
112
        z_scale = 1.0
113
    if (xy_scale is None):
114
        xy_scale = 1.0
115
    if z_scale != 1.0 or xy_scale != 1.0:
116
        data = ndimage.zoom(data, [xy_scale, xy_scale, z_scale])
117
118
    data = preproc_and_norm(data, preprocess_method, norm_params,
119
                            scale=config.get('scale_data', None),
120
                            preproc=config.get('preproc', None))
121
122
    save_nifti(data, os.path.join(output_path, scan_name + '_data.nii.gz'))
123
124
    data = np.pad(data, 3, 'constant', constant_values=data.min())
125
126
    print('Shape: ' + str(data.shape))
127
    prediction = get_prediction(data=data, model=model, augment=augment,
128
                                num_augments=num_augment, return_all_preds=return_all_preds,
129
                                overlap_factor=overlap_factor, config=config)
130
    # unpad
131
    prediction = prediction[3:-3, 3:-3, 3:-3]
132
133
    # revert to original size
134
    if config.get('scale_data', None) is not None:
135
        prediction = ndimage.zoom(prediction.squeeze(), np.divide([1, 1, 1], config.get('scale_data', None)), order=0)[..., np.newaxis]
136
137
    save_nifti(prediction, os.path.join(output_path, scan_name + '_pred.nii.gz'))
138
139
    if z_scale != 1.0 or xy_scale != 1.0:
140
        prediction = ndimage.zoom(prediction.squeeze(), [1.0 / xy_scale, 1.0 / xy_scale, 1.0 / z_scale], order=1)[..., np.newaxis]
141
142
    # if prediction.shape[-1] > 1:
143
    #    prediction = prediction[..., 1]
144
    if config2 is not None:
145
        prediction = prediction.squeeze()
146
        mask = process_pred(prediction, gaussian_std=0.5, threshold=0.5)  # .astype(np.uint8)
147
        nifti = read_img(input_path)
148
        prediction = secondary_prediction(mask, vol=nifti.get_fdata().astype(np.float),
149
                                          config2=config2, model2_path=model2_path,
150
                                          preprocess_method2=preprocess_method2, norm_params2=norm_params2,
151
                                          overlap_factor=overlap_factor, augment2=augment2, num_augment=num_augment2,
152
                                          return_all_preds=return_all_preds)
153
        save_nifti(prediction, os.path.join(output_path, scan_name + 'pred_roi.nii.gz'))
154
155
    print('Saving to {}'.format(output_path))
156
    print('Finished.')
157
158
159
def get_params(config_dir):
160
    with open(os.path.join(config_dir, 'config.json'), 'r') as f:
161
        __config = json.load(f)
162
    with open(os.path.join(config_dir, 'norm_params.json'), 'r') as f:
163
        __norm_params = json.load(f)
164
    __model_path = os.path.join(config_dir, os.path.basename(__config['model_file']))
165
    return __config, __norm_params, __model_path
166
167
168
if __name__ == '__main__':
169
    parser = argparse.ArgumentParser()
170
    parser.add_argument("--input_nii", help="specifies mat file dir path",
171
                        type=str, required=True)
172
    parser.add_argument("--output_folder", help="specifies mat file dir path",
173
                        type=str, required=True)
174
    parser.add_argument("--overlap_factor", help="specifies overlap between prediction patches",
175
                        type=float, default=0.9)
176
    parser.add_argument("--z_scale", help="specifies overlap between prediction patches",
177
                        type=float, default=1)
178
    parser.add_argument("--xy_scale", help="specifies overlap between prediction patches",
179
                        type=float, default=1)
180
    parser.add_argument("--return_all_preds", help="output std for prediction",
181
                        type=int, default=0)
182
183
    # Params for primary prediction
184
    parser.add_argument("--config_dir", help="specifies config dir path",
185
                        type=str, required=True)
186
    parser.add_argument("--preprocess", help="what preprocess to do",
187
                        type=str, required=False, default=None)
188
    parser.add_argument("--augment", help="what augment to do",
189
                        type=str, required=False, default=None)  # one of 'flip, all'
190
    parser.add_argument("--num_augment", help="what augment to do",
191
                        type=int, required=False, default=0)  # one of 'flip, all'
192
193
    # Params for secondary prediction
194
    parser.add_argument("--config2_dir", help="specifies config dir path",
195
                        type=str, required=False, default=None)
196
    parser.add_argument("--preprocess2", help="what preprocess to do",
197
                        type=str, required=False, default=None)
198
    parser.add_argument("--augment2", help="what augment to do",
199
                        type=str, required=False, default=None)  # one of 'flip, all'
200
    parser.add_argument("--num_augment2", help="what augment to do",
201
                        type=int, required=False, default=0)  # one of 'flip, all'
202
203
    opts = parser.parse_args()
204
205
    Path(opts.output_folder).mkdir(exist_ok=True)
206
207
    # 1
208
    _config, _norm_params, _model_path = get_params(opts.config_dir)
209
    # 2
210
    if opts.config2_dir is not None:
211
        _config2, _norm_params2, _model2_path = get_params(opts.config2_dir)
212
    else:
213
        _config2, _norm_params2, _model2_path = None, None, None
214
215
    main(opts.input_nii, opts.output_folder, overlap_factor=opts.overlap_factor,
216
         config=_config, model_path=_model_path, preprocess_method=opts.preprocess, norm_params=_norm_params, augment=opts.augment,
217
         num_augment=opts.num_augment,
218
         config2=_config2, model2_path=_model2_path, preprocess_method2=opts.preprocess2, norm_params2=_norm_params2, augment2=opts.augment2,
219
         num_augment2=opts.num_augment2,
220
         z_scale=opts.z_scale, xy_scale=opts.xy_scale, return_all_preds=opts.return_all_preds)