|
a |
|
b/prod/predict_nifti2.py |
|
|
1 |
import argparse |
|
|
2 |
import json |
|
|
3 |
import os |
|
|
4 |
from pathlib import Path |
|
|
5 |
|
|
|
6 |
import nibabel as nib |
|
|
7 |
|
|
|
8 |
import fetal_net.preprocess |
|
|
9 |
from fetal.preprocess import window_intensities_data |
|
|
10 |
from fetal.utils import get_last_model_path |
|
|
11 |
from fetal_net.normalize import normalize_data |
|
|
12 |
from fetal_net.postprocess import postprocess_prediction as process_pred |
|
|
13 |
from fetal_net.prediction import patch_wise_prediction, predict_augment, predict_flips |
|
|
14 |
from fetal_net.preprocess import * |
|
|
15 |
from fetal_net.training import load_old_model |
|
|
16 |
from fetal_net.utils.cut_relevant_areas import find_bounding_box, check_bounding_box |
|
|
17 |
from fetal_net.utils.utils import read_img, get_image |
|
|
18 |
|
|
|
19 |
|
|
|
20 |
def save_nifti(data, path): |
|
|
21 |
nifti = get_image(data) |
|
|
22 |
nib.save(nifti, path) |
|
|
23 |
|
|
|
24 |
|
|
|
25 |
def secondary_prediction(mask, vol, config2, model2_path=None, |
|
|
26 |
preprocess_method2=None, norm_params2=None, |
|
|
27 |
overlap_factor=0.9, augment2=None, num_augment=32, return_all_preds=False): |
|
|
28 |
model2 = load_old_model(get_last_model_path(model2_path), config=config2) |
|
|
29 |
pred = mask |
|
|
30 |
bbox_start, bbox_end = find_bounding_box(pred) |
|
|
31 |
check_bounding_box(pred, bbox_start, bbox_end) |
|
|
32 |
padding = [16, 16, 8] |
|
|
33 |
if padding is not None: |
|
|
34 |
bbox_start = np.maximum(bbox_start - padding, 0) |
|
|
35 |
bbox_end = np.minimum(bbox_end + padding, mask.shape) |
|
|
36 |
data = vol.astype(np.float)[ |
|
|
37 |
bbox_start[0]:bbox_end[0], |
|
|
38 |
bbox_start[1]:bbox_end[1], |
|
|
39 |
bbox_start[2]:bbox_end[2] |
|
|
40 |
] |
|
|
41 |
|
|
|
42 |
data = preproc_and_norm(data, preprocess_method2, norm_params2) |
|
|
43 |
|
|
|
44 |
prediction = get_prediction(data, model2, augment=augment2, num_augments=num_augment, return_all_preds=return_all_preds, |
|
|
45 |
overlap_factor=overlap_factor, config=config2) |
|
|
46 |
|
|
|
47 |
padding2 = list(zip(bbox_start, np.array(vol.shape) - bbox_end)) |
|
|
48 |
if return_all_preds: |
|
|
49 |
padding2 = [(0, 0)] + padding2 |
|
|
50 |
print(padding2) |
|
|
51 |
print(prediction.shape) |
|
|
52 |
prediction = np.pad(prediction, padding2, mode='constant', constant_values=0) |
|
|
53 |
|
|
|
54 |
return prediction |
|
|
55 |
|
|
|
56 |
|
|
|
57 |
def preproc_and_norm(data, preprocess_method=None, norm_params=None, scale=None, preproc=None): |
|
|
58 |
if preprocess_method is not None: |
|
|
59 |
print('Applying preprocess by {}...'.format(preprocess_method)) |
|
|
60 |
if preprocess_method == 'window_1_99': |
|
|
61 |
data = window_intensities_data(data) |
|
|
62 |
else: |
|
|
63 |
raise Exception('Unknown preprocess: {}'.format(preprocess_method)) |
|
|
64 |
|
|
|
65 |
if scale is not None: |
|
|
66 |
data = ndimage.zoom(data, scale) |
|
|
67 |
if preproc is not None: |
|
|
68 |
preproc_func = getattr(fetal_net.preprocess, preproc) |
|
|
69 |
data = preproc_func(data) |
|
|
70 |
|
|
|
71 |
# data = normalize_data(data, mean=data.mean(), std=data.std()) |
|
|
72 |
if norm_params is not None and any(norm_params.values()): |
|
|
73 |
data = normalize_data(data, mean=norm_params['mean'], std=norm_params['std']) |
|
|
74 |
return data |
|
|
75 |
|
|
|
76 |
|
|
|
77 |
def get_prediction(data, model, augment, num_augments, return_all_preds, overlap_factor, config): |
|
|
78 |
if augment is not None: |
|
|
79 |
patch_shape = config["patch_shape"] + [config["patch_depth"]] |
|
|
80 |
if augment == 'all': |
|
|
81 |
prediction = predict_augment(data, model=model, overlap_factor=overlap_factor, num_augments=num_augments, patch_shape=patch_shape) |
|
|
82 |
elif augment == 'flip': |
|
|
83 |
prediction = predict_flips(data, model=model, overlap_factor=overlap_factor, patch_shape=patch_shape, config=config) |
|
|
84 |
else: |
|
|
85 |
raise ("Unknown augmentation {}".format(augment)) |
|
|
86 |
if not return_all_preds: |
|
|
87 |
prediction = np.median(prediction, axis=0) |
|
|
88 |
else: |
|
|
89 |
prediction = \ |
|
|
90 |
patch_wise_prediction(model=model, |
|
|
91 |
data=np.expand_dims(data, 0), |
|
|
92 |
overlap_factor=overlap_factor, |
|
|
93 |
patch_shape=config["patch_shape"] + [config["patch_depth"]]) |
|
|
94 |
prediction = prediction.squeeze() |
|
|
95 |
return prediction |
|
|
96 |
|
|
|
97 |
|
|
|
98 |
def main(input_path, output_path, overlap_factor, |
|
|
99 |
config, model_path, preprocess_method=None, norm_params=None, augment=None, num_augment=0, |
|
|
100 |
config2=None, model2_path=None, preprocess_method2=None, norm_params2=None, augment2=None, num_augment2=0, |
|
|
101 |
z_scale=None, xy_scale=None, return_all_preds=False): |
|
|
102 |
print(model_path) |
|
|
103 |
model = load_old_model(get_last_model_path(model_path), config=config) |
|
|
104 |
print('Loading nifti from {}...'.format(input_path)) |
|
|
105 |
nifti = read_img(input_path) |
|
|
106 |
print('Predicting mask...') |
|
|
107 |
data = nifti.get_fdata().astype(np.float).squeeze() |
|
|
108 |
print('original_shape: ' + str(data.shape)) |
|
|
109 |
scan_name = Path(input_path).name.split('.')[0] |
|
|
110 |
|
|
|
111 |
if (z_scale is None): |
|
|
112 |
z_scale = 1.0 |
|
|
113 |
if (xy_scale is None): |
|
|
114 |
xy_scale = 1.0 |
|
|
115 |
if z_scale != 1.0 or xy_scale != 1.0: |
|
|
116 |
data = ndimage.zoom(data, [xy_scale, xy_scale, z_scale]) |
|
|
117 |
|
|
|
118 |
data = preproc_and_norm(data, preprocess_method, norm_params, |
|
|
119 |
scale=config.get('scale_data', None), |
|
|
120 |
preproc=config.get('preproc', None)) |
|
|
121 |
|
|
|
122 |
save_nifti(data, os.path.join(output_path, scan_name + '_data.nii.gz')) |
|
|
123 |
|
|
|
124 |
data = np.pad(data, 3, 'constant', constant_values=data.min()) |
|
|
125 |
|
|
|
126 |
print('Shape: ' + str(data.shape)) |
|
|
127 |
prediction = get_prediction(data=data, model=model, augment=augment, |
|
|
128 |
num_augments=num_augment, return_all_preds=return_all_preds, |
|
|
129 |
overlap_factor=overlap_factor, config=config) |
|
|
130 |
# unpad |
|
|
131 |
prediction = prediction[3:-3, 3:-3, 3:-3] |
|
|
132 |
|
|
|
133 |
# revert to original size |
|
|
134 |
if config.get('scale_data', None) is not None: |
|
|
135 |
prediction = ndimage.zoom(prediction.squeeze(), np.divide([1, 1, 1], config.get('scale_data', None)), order=0)[..., np.newaxis] |
|
|
136 |
|
|
|
137 |
save_nifti(prediction, os.path.join(output_path, scan_name + '_pred.nii.gz')) |
|
|
138 |
|
|
|
139 |
if z_scale != 1.0 or xy_scale != 1.0: |
|
|
140 |
prediction = ndimage.zoom(prediction.squeeze(), [1.0 / xy_scale, 1.0 / xy_scale, 1.0 / z_scale], order=1)[..., np.newaxis] |
|
|
141 |
|
|
|
142 |
# if prediction.shape[-1] > 1: |
|
|
143 |
# prediction = prediction[..., 1] |
|
|
144 |
if config2 is not None: |
|
|
145 |
prediction = prediction.squeeze() |
|
|
146 |
mask = process_pred(prediction, gaussian_std=0.5, threshold=0.5) # .astype(np.uint8) |
|
|
147 |
nifti = read_img(input_path) |
|
|
148 |
prediction = secondary_prediction(mask, vol=nifti.get_fdata().astype(np.float), |
|
|
149 |
config2=config2, model2_path=model2_path, |
|
|
150 |
preprocess_method2=preprocess_method2, norm_params2=norm_params2, |
|
|
151 |
overlap_factor=overlap_factor, augment2=augment2, num_augment=num_augment2, |
|
|
152 |
return_all_preds=return_all_preds) |
|
|
153 |
save_nifti(prediction, os.path.join(output_path, scan_name + 'pred_roi.nii.gz')) |
|
|
154 |
|
|
|
155 |
print('Saving to {}'.format(output_path)) |
|
|
156 |
print('Finished.') |
|
|
157 |
|
|
|
158 |
|
|
|
159 |
def get_params(config_dir): |
|
|
160 |
with open(os.path.join(config_dir, 'config.json'), 'r') as f: |
|
|
161 |
__config = json.load(f) |
|
|
162 |
with open(os.path.join(config_dir, 'norm_params.json'), 'r') as f: |
|
|
163 |
__norm_params = json.load(f) |
|
|
164 |
__model_path = os.path.join(config_dir, os.path.basename(__config['model_file'])) |
|
|
165 |
return __config, __norm_params, __model_path |
|
|
166 |
|
|
|
167 |
|
|
|
168 |
if __name__ == '__main__': |
|
|
169 |
parser = argparse.ArgumentParser() |
|
|
170 |
parser.add_argument("--input_nii", help="specifies mat file dir path", |
|
|
171 |
type=str, required=True) |
|
|
172 |
parser.add_argument("--output_folder", help="specifies mat file dir path", |
|
|
173 |
type=str, required=True) |
|
|
174 |
parser.add_argument("--overlap_factor", help="specifies overlap between prediction patches", |
|
|
175 |
type=float, default=0.9) |
|
|
176 |
parser.add_argument("--z_scale", help="specifies overlap between prediction patches", |
|
|
177 |
type=float, default=1) |
|
|
178 |
parser.add_argument("--xy_scale", help="specifies overlap between prediction patches", |
|
|
179 |
type=float, default=1) |
|
|
180 |
parser.add_argument("--return_all_preds", help="output std for prediction", |
|
|
181 |
type=int, default=0) |
|
|
182 |
|
|
|
183 |
# Params for primary prediction |
|
|
184 |
parser.add_argument("--config_dir", help="specifies config dir path", |
|
|
185 |
type=str, required=True) |
|
|
186 |
parser.add_argument("--preprocess", help="what preprocess to do", |
|
|
187 |
type=str, required=False, default=None) |
|
|
188 |
parser.add_argument("--augment", help="what augment to do", |
|
|
189 |
type=str, required=False, default=None) # one of 'flip, all' |
|
|
190 |
parser.add_argument("--num_augment", help="what augment to do", |
|
|
191 |
type=int, required=False, default=0) # one of 'flip, all' |
|
|
192 |
|
|
|
193 |
# Params for secondary prediction |
|
|
194 |
parser.add_argument("--config2_dir", help="specifies config dir path", |
|
|
195 |
type=str, required=False, default=None) |
|
|
196 |
parser.add_argument("--preprocess2", help="what preprocess to do", |
|
|
197 |
type=str, required=False, default=None) |
|
|
198 |
parser.add_argument("--augment2", help="what augment to do", |
|
|
199 |
type=str, required=False, default=None) # one of 'flip, all' |
|
|
200 |
parser.add_argument("--num_augment2", help="what augment to do", |
|
|
201 |
type=int, required=False, default=0) # one of 'flip, all' |
|
|
202 |
|
|
|
203 |
opts = parser.parse_args() |
|
|
204 |
|
|
|
205 |
Path(opts.output_folder).mkdir(exist_ok=True) |
|
|
206 |
|
|
|
207 |
# 1 |
|
|
208 |
_config, _norm_params, _model_path = get_params(opts.config_dir) |
|
|
209 |
# 2 |
|
|
210 |
if opts.config2_dir is not None: |
|
|
211 |
_config2, _norm_params2, _model2_path = get_params(opts.config2_dir) |
|
|
212 |
else: |
|
|
213 |
_config2, _norm_params2, _model2_path = None, None, None |
|
|
214 |
|
|
|
215 |
main(opts.input_nii, opts.output_folder, overlap_factor=opts.overlap_factor, |
|
|
216 |
config=_config, model_path=_model_path, preprocess_method=opts.preprocess, norm_params=_norm_params, augment=opts.augment, |
|
|
217 |
num_augment=opts.num_augment, |
|
|
218 |
config2=_config2, model2_path=_model2_path, preprocess_method2=opts.preprocess2, norm_params2=_norm_params2, augment2=opts.augment2, |
|
|
219 |
num_augment2=opts.num_augment2, |
|
|
220 |
z_scale=opts.z_scale, xy_scale=opts.xy_scale, return_all_preds=opts.return_all_preds) |