|
a |
|
b/fetal/predict2.py |
|
|
1 |
import json |
|
|
2 |
import os |
|
|
3 |
from glob import glob |
|
|
4 |
from pathlib import Path |
|
|
5 |
|
|
|
6 |
import numpy as np |
|
|
7 |
import nibabel as nib |
|
|
8 |
|
|
|
9 |
from brats.preprocess import window_intensities_data |
|
|
10 |
from fetal_net.utils.utils import get_image |
|
|
11 |
from fetal_net.postprocess import postprocess_prediction as process_pred |
|
|
12 |
from brats.utils import get_last_model_path |
|
|
13 |
from fetal_net.normalize import normalize_data |
|
|
14 |
from fetal_net.prediction import run_validation_cases, patch_wise_prediction |
|
|
15 |
import argparse |
|
|
16 |
|
|
|
17 |
from fetal_net.training import load_old_model |
|
|
18 |
from fetal_net.utils.cut_relevant_areas import find_bounding_box, cut_bounding_box, check_bounding_box |
|
|
19 |
|
|
|
20 |
# Root folder of the original (uncropped) input volumes; each subject is
# expected at <original_data_folder>/<subject_id>/volume.nii.
# NOTE: this is a relative path, so it is resolved against the current
# working directory, not against this file's location.
original_data_folder = '../Datasets/Fetus'
|
|
21 |
|
|
|
22 |
|
|
|
23 |
def _load_norm_params(config_dir):
    """Return the parsed contents of ``<config_dir>/norm_params.json``."""
    with open(os.path.join(config_dir, 'norm_params.json'), 'r') as f:
        return json.load(f)


def _padded_bbox(mask, padding):
    """Return (start, end) of the bounding box of *mask*, grown by *padding* and clipped to mask.shape."""
    bbox_start, bbox_end = find_bounding_box(mask)
    check_bounding_box(mask, bbox_start, bbox_end)
    # Work on arrays so element-wise arithmetic is valid whatever sequence type
    # find_bounding_box returns.
    bbox_start = np.asarray(bbox_start)
    bbox_end = np.asarray(bbox_end)
    if padding is not None:
        pad = np.asarray(padding)
        bbox_start = np.maximum(bbox_start - pad, 0)
        bbox_end = np.minimum(bbox_end + pad, mask.shape)
    return bbox_start, bbox_end


def _predict_sample(sample_folder, prediction2_dir, model, config, norm_params,
                    overlap_factor, preprocess_method, padding):
    """Re-predict one subject folder and write the result under *prediction2_dir*/<subject_id>."""
    mask_path = os.path.join(sample_folder, 'prediction.nii.gz')
    truth_path = os.path.join(sample_folder, 'truth.nii.gz')

    subject_id = Path(sample_folder).name
    dest_folder = os.path.join(prediction2_dir, subject_id)
    Path(dest_folder).mkdir(parents=True, exist_ok=True)

    # Re-save the ground truth next to the new prediction.
    truth = nib.load(truth_path)
    nib.save(truth, os.path.join(dest_folder, Path(truth_path).name))

    # np.asanyarray(img.dataobj) replaces the removed nibabel get_data() and,
    # unlike get_fdata(), keeps the on-disk dtype.
    mask = np.asanyarray(nib.load(mask_path).dataobj)
    mask = process_pred(mask, gaussian_std=0.5, threshold=0.5)
    bbox_start, bbox_end = _padded_bbox(mask, padding)
    print("BBox: {}-{}".format(bbox_start, bbox_end))

    volume_img = nib.load(os.path.join(original_data_folder, subject_id, 'volume.nii'))
    # The header shape equals the data shape; no need to load the voxels here.
    orig_volume_shape = np.array(volume_img.shape)
    # get_fdata() yields float64, matching the former get_data().astype(np.float).
    volume = cut_bounding_box(volume_img, bbox_start, bbox_end).get_fdata()

    if preprocess_method is not None:
        print('Applying preprocess by {}...'.format(preprocess_method))
        if preprocess_method == 'window_1_99':
            volume = window_intensities_data(volume)
        else:
            raise Exception('Unknown preprocess: {}'.format(preprocess_method))

    if norm_params is not None and any(norm_params.values()):
        volume = normalize_data(volume, mean=norm_params['mean'], std=norm_params['std'])

    prediction = patch_wise_prediction(
        model=model, data=np.expand_dims(volume, 0),
        patch_shape=list(config["patch_shape"]) + [config["patch_depth"]],
        overlap_factor=overlap_factor
    )
    prediction = prediction.squeeze()

    # Zero-pad the cropped prediction back into the original volume's frame.
    pad_width = list(zip(bbox_start, orig_volume_shape - bbox_end))
    print(pad_width)
    prediction = np.pad(prediction, pad_width, mode='constant', constant_values=0)
    expected_shape = tuple(int(s) for s in orig_volume_shape)
    if prediction.shape != expected_shape:
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        raise ValueError('Prediction shape {} does not match volume shape {} for subject {}'.format(
            prediction.shape, expected_shape, subject_id))

    prediction = get_image(prediction)
    nib.save(prediction, os.path.join(dest_folder, Path(mask_path).name))


def main(pred_dir, config, split='test', overlap_factor=1, preprocess_method=None,
         padding=(16, 16, 8), config_dir=None):
    """Re-predict each subject of *split* on the region found by a previous prediction.

    For every ``<pred_dir>/<split>/<subject_id>/`` folder, the previous
    ``prediction.nii.gz`` is post-processed into a mask. Its bounding box,
    grown by *padding*, is cut out of the original volume, and the model is
    run patch-wise on that crop. The result is zero-padded back to the
    original volume shape. The new prediction and a copy of ``truth.nii.gz``
    are written to ``<config['base_dir']>/predictions2/<split>/<subject_id>/``.

    Args:
        pred_dir: Root directory of the previous predictions.
        config: Parsed ``config.json``. Keys read here: ``base_dir``,
            ``model_file``, ``patch_shape``, ``patch_depth``.
        split: Sub-folder of *pred_dir* to process.
        overlap_factor: Forwarded to ``patch_wise_prediction``.
        preprocess_method: ``None`` (no preprocessing) or ``'window_1_99'``.
        padding: Per-axis margin added around the mask bounding box before
            cropping; ``None`` disables it.
        config_dir: Directory containing ``norm_params.json``. When omitted,
            the value parsed by the command-line entry point
            (``opts.config_dir``) is used, as in earlier versions.

    Raises:
        Exception: If *preprocess_method* is not supported.
        ValueError: If *config_dir* cannot be determined, or a re-padded
            prediction does not match its original volume shape.
    """
    # Fail fast, before the (expensive) model load.
    if preprocess_method is not None and preprocess_method not in ('window_1_99',):
        raise Exception('Unknown preprocess: {}'.format(preprocess_method))

    if config_dir is None:
        # Backward compatibility: earlier versions read the module-level `opts`
        # created by the `__main__` block, which breaks when main() is imported.
        config_dir = getattr(globals().get('opts'), 'config_dir', None)
        if config_dir is None:
            raise ValueError('config_dir must be given when main() is not run from the command line')

    prediction2_dir = os.path.abspath(os.path.join(config['base_dir'], 'predictions2', split))
    norm_params = _load_norm_params(config_dir)
    model = load_old_model(get_last_model_path(config["model_file"]))

    # Only sub-directories are subjects; sort for a reproducible processing order.
    sample_folders = sorted(p for p in glob(os.path.join(pred_dir, split, '*')) if os.path.isdir(p))
    for sample_folder in sample_folders:
        _predict_sample(sample_folder, prediction2_dir, model, config, norm_params,
                        overlap_factor, preprocess_method, padding)
|
|
77 |
|
|
|
78 |
|
|
|
79 |
if __name__ == "__main__":
    # Command-line entry point. `opts` must stay a module-level name:
    # main() falls back to `opts.config_dir` when no config_dir is passed.
    parser = argparse.ArgumentParser(
        description="Predict on volumes cropped to the bounding box of a previous prediction.")
    parser.add_argument("--pred_dir",
                        help="root directory of the previous predictions (<pred_dir>/<split>/<subject_id>/)",
                        type=str, required=True)
    parser.add_argument("--config_dir", help="specifies config dir path",
                        type=str, required=True)
    # main() treats None as "no preprocessing", so the flag is optional.
    parser.add_argument("--preprocess", help="optional intensity preprocessing (supported: window_1_99)",
                        type=str, default=None)
    parser.add_argument("--split", help="What split to predict on? (test/val)",
                        type=str, default='test')
    parser.add_argument("--overlap_factor", help="specifies overlap between prediction patches",
                        type=float, default=1)
    opts = parser.parse_args()

    with open(os.path.join(opts.config_dir, 'config.json')) as f:
        config = json.load(f)

    main(opts.pred_dir, config, opts.split, opts.overlap_factor, opts.preprocess)