a b/fetal/predict2.py
1
import json
2
import os
3
from glob import glob
4
from pathlib import Path
5
6
import numpy as np
7
import nibabel as nib
8
9
from brats.preprocess import window_intensities_data
10
from fetal_net.utils.utils import get_image
11
from fetal_net.postprocess import postprocess_prediction as process_pred
12
from brats.utils import get_last_model_path
13
from fetal_net.normalize import normalize_data
14
from fetal_net.prediction import run_validation_cases, patch_wise_prediction
15
import argparse
16
17
from fetal_net.training import load_old_model
18
from fetal_net.utils.cut_relevant_areas import find_bounding_box, cut_bounding_box, check_bounding_box
19
20
# Root folder of the original (uncropped) scans. main() reads
# <original_data_folder>/<subject_id>/volume.nii from here. The path is
# relative, so it is resolved against the current working directory.
original_data_folder = '../Datasets/Fetus'
21
22
23
def main(pred_dir, config, split='test', overlap_factor=1, preprocess_method=None,
         config_dir=None):
    """Run a second prediction pass on a crop around each first-pass mask.

    For every sample folder matching ``<pred_dir>/<split>/*``:

    1. ``truth.nii.gz`` is copied to the destination folder.
    2. ``prediction.nii.gz`` is post-processed (``gaussian_std=0.5``,
       ``threshold=0.5``) and its bounding box is computed, then enlarged by
       a fixed padding of ``[16, 16, 8]`` and clipped to the mask extent.
    3. ``<original_data_folder>/<subject_id>/volume.nii`` is cropped to that
       box, optionally preprocessed, and normalized when ``norm_params.json``
       contains any truthy value.
    4. The model predicts patch-wise on the crop; the result is zero-padded
       back to the original volume shape, converted with ``get_image`` and
       saved as ``prediction.nii.gz``.

    Outputs go to ``<config['base_dir']>/predictions2/<split>/<subject_id>/``.

    Args:
        pred_dir: Directory containing one sub-folder per split with the
            first-pass sample folders.
        config: Dict providing ``base_dir``, ``model_file``, ``patch_shape``
            and ``patch_depth``.
        split: Name of the split sub-folder to process.
        overlap_factor: Forwarded to ``patch_wise_prediction``.
        preprocess_method: ``None`` for no preprocessing, or
            ``'window_1_99'`` to apply ``window_intensities_data``.
        config_dir: Directory containing ``norm_params.json``. When omitted,
            the module-level ``opts.config_dir`` set by the command-line
            entry point is used.

    Raises:
        ValueError: If ``config_dir`` is omitted and no command-line options
            are available.
        Exception: If ``preprocess_method`` is not a known method.
    """
    if config_dir is None:
        # Backward compatibility: the script entry point calls main() without
        # config_dir and relies on the module-level ``opts`` it defines.
        try:
            config_dir = opts.config_dir
        except NameError:
            raise ValueError(
                'config_dir must be given when main() is not run from the '
                'command line') from None

    padding = [16, 16, 8]
    prediction2_dir = os.path.abspath(os.path.join(config['base_dir'], 'predictions2', split))
    model = load_old_model(get_last_model_path(config["model_file"]))
    with open(os.path.join(config_dir, 'norm_params.json'), 'r') as f:
        norm_params = json.load(f)

    for sample_folder in glob(os.path.join(pred_dir, split, '*')):
        mask_path = os.path.join(sample_folder, 'prediction.nii.gz')
        truth_path = os.path.join(sample_folder, 'truth.nii.gz')

        subject_id = Path(sample_folder).name
        dest_folder = os.path.join(prediction2_dir, subject_id)
        Path(dest_folder).mkdir(parents=True, exist_ok=True)

        # Keep the ground truth next to the new prediction.
        truth = nib.load(truth_path)
        nib.save(truth, os.path.join(dest_folder, Path(truth_path).name))

        # np.asanyarray(img.dataobj) is nibabel's documented equivalent of the
        # removed img.get_data().
        mask = np.asanyarray(nib.load(mask_path).dataobj)
        mask = process_pred(mask, gaussian_std=0.5, threshold=0.5)
        bbox_start, bbox_end = find_bounding_box(mask)
        check_bounding_box(mask, bbox_start, bbox_end)
        if padding is not None:
            # Grow the box, but never beyond the mask extent.
            bbox_start = np.maximum(bbox_start - padding, 0)
            bbox_end = np.minimum(bbox_end + padding, mask.shape)
        print("BBox: {}-{}".format(bbox_start, bbox_end))

        volume_img = nib.load(os.path.join(original_data_folder, subject_id, 'volume.nii'))
        orig_volume_shape = np.array(volume_img.shape)
        # NOTE(review): assumes cut_bounding_box returns a nibabel image (the
        # previous code called .get_data() on its result) - confirm.
        cropped = cut_bounding_box(volume_img, bbox_start, bbox_end)
        # np.float was removed from NumPy; np.float64 is the same dtype.
        volume = np.asanyarray(cropped.dataobj).astype(np.float64)

        if preprocess_method is not None:
            print('Applying preprocess by {}...'.format(preprocess_method))
            if preprocess_method == 'window_1_99':
                volume = window_intensities_data(volume)
            else:
                raise Exception('Unknown preprocess: {}'.format(preprocess_method))

        if norm_params is not None and any(norm_params.values()):
            volume = normalize_data(volume, mean=norm_params['mean'], std=norm_params['std'])

        prediction = patch_wise_prediction(
            model=model, data=np.expand_dims(volume, 0),
            patch_shape=config["patch_shape"] + [config["patch_depth"]],
            overlap_factor=overlap_factor
        )
        prediction = prediction.squeeze()

        # Zero-pad the cropped prediction back to the original volume shape:
        # per axis, (voxels before the box, voxels after the box).
        padding2 = list(zip(bbox_start, orig_volume_shape - bbox_end))
        print(padding2)
        prediction = np.pad(prediction, padding2, mode='constant', constant_values=0)
        assert tuple(prediction.shape) == tuple(orig_volume_shape)
        prediction = get_image(prediction)
        nib.save(prediction, os.path.join(dest_folder, Path(mask_path).name))
77
78
79
if __name__ == "__main__":
    # Command-line entry point. The names ``opts`` and ``config`` are
    # module-level on purpose: main() reads ``opts.config_dir`` to locate
    # norm_params.json.
    parser = argparse.ArgumentParser()
    parser.add_argument("--pred_dir",
                        help="directory holding the first-pass predictions, "
                             "one sub-folder per split",
                        type=str, required=True)
    parser.add_argument("--config_dir", help="specifies config dir path",
                        type=str, required=True)
    parser.add_argument("--preprocess",
                        help="preprocess method applied to each cropped volume "
                             "(supported: window_1_99)",
                        type=str, required=True)
    parser.add_argument("--split", help="What split to predict on? (test/val)",
                        type=str, default='test')
    parser.add_argument("--overlap_factor", help="specifies overlap between prediction patches",
                        type=float, default=1)
    opts = parser.parse_args()

    # config.json lives in the config dir given on the command line.
    with open(os.path.join(opts.config_dir, 'config.json')) as f:
        config = json.load(f)

    main(opts.pred_dir, config, opts.split, opts.overlap_factor, opts.preprocess)