|
a |
|
b/fetal/predict.py |
|
|
1 |
import argparse |
|
|
2 |
import json |
|
|
3 |
import os |
|
|
4 |
|
|
|
5 |
from fetal_net.prediction import run_validation_cases |
|
|
6 |
|
|
|
7 |
|
|
|
8 |
def main(config, split='test', overlap_factor=1, use_augmentations=False):
    """Run prediction on one data split and write results under ``base_dir``.

    Predictions are written to ``<base_dir>/predictions/<split>`` (made absolute).

    Args:
        config: Experiment configuration dict (as loaded from ``config.json``).
            Keys read: ``base_dir``, ``model_file``, ``training_modalities``,
            ``data_file``, ``patch_shape``, ``patch_depth``,
            ``prev_truth_index``, ``prev_truth_size``, plus the indices-file
            key for the chosen split (``test_file`` / ``validation_file`` /
            ``training_file``).
        split: Which split to predict on: ``'test'``, ``'val'`` or ``'train'``.
        overlap_factor: Forwarded unchanged to ``run_validation_cases``.
        use_augmentations: Forwarded unchanged to ``run_validation_cases``.

    Raises:
        KeyError: if ``split`` is not a known split name, or a required key is
            missing from ``config``.
    """
    # Map each split to the config key holding its indices file. Only the key
    # for the requested split is read, so a config that lacks e.g.
    # "training_file" can still be used to predict on the test split.
    split_to_config_key = {
        'test': 'test_file',
        'val': 'validation_file',
        'train': 'training_file',
    }
    if split not in split_to_config_key:
        raise KeyError('Unknown split {!r}; expected one of {}'.format(
            split, sorted(split_to_config_key)))
    indices_file = config[split_to_config_key[split]]

    prediction_dir = os.path.abspath(os.path.join(config['base_dir'], 'predictions', split))

    # The depth is appended as the last patch dimension; list() also accepts a
    # tuple-valued patch_shape (JSON configs already yield a list).
    patch_shape = list(config["patch_shape"]) + [config["patch_depth"]]

    run_validation_cases(validation_keys_file=indices_file,
                         model_file=config["model_file"],
                         training_modalities=config["training_modalities"],
                         hdf5_file=config["data_file"],
                         output_dir=prediction_dir,
                         overlap_factor=overlap_factor,
                         patch_shape=patch_shape,
                         prev_truth_index=config["prev_truth_index"],
                         prev_truth_size=config["prev_truth_size"],
                         use_augmentations=use_augmentations)
|
|
26 |
|
|
|
27 |
|
|
|
28 |
if __name__ == "__main__":
    # Command-line entry point: load <config_dir>/config.json and run main().
    parser = argparse.ArgumentParser()
    parser.add_argument("--config_dir", help="specifies config dir path",
                        type=str, required=True)
    # Restrict to the splits main() understands so typos fail at parse time.
    parser.add_argument("--split", help="What split to predict on? (test/val/train)",
                        type=str, default='test', choices=['test', 'val', 'train'])
    parser.add_argument("--overlap_factor", help="specifies overlap between prediction patches",
                        type=float, default=0.9)
    # Kept as a numeric flag for CLI backward compatibility (e.g. "1" or "1.0");
    # converted to a bool before being passed on.
    parser.add_argument("--use_augmentation", help="1 to use predict-time augmentations",
                        type=float, default=0)
    opts = parser.parse_args()

    with open(os.path.join(opts.config_dir, 'config.json'), encoding='utf-8') as f:
        config = json.load(f)

    main(config, opts.split, opts.overlap_factor,
         use_augmentations=bool(opts.use_augmentation))