Switch to unified view

a b/scripts/commands/SynthSeg_predict.py
1
"""
2
This script enables to launch predictions with SynthSeg from the terminal.
3
4
If you use this code, please cite one of the SynthSeg papers:
5
https://github.com/BBillot/SynthSeg/blob/master/bibtex.bib
6
7
Copyright 2020 Benjamin Billot
8
9
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
10
compliance with the License. You may obtain a copy of the License at
11
https://www.apache.org/licenses/LICENSE-2.0
12
Unless required by applicable law or agreed to in writing, software distributed under the License is
13
distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
14
implied. See the License for the specific language governing permissions and limitations under the
15
License.
16
"""
17
18
# python imports
19
import os
20
import sys
21
from argparse import ArgumentParser
22
23
# add main folder to python path and import ./SynthSeg/predict_synthseg.py
24
synthseg_home = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(sys.argv[0]))))
25
sys.path.append(synthseg_home)
26
model_dir = os.path.join(synthseg_home, 'models')
27
labels_dir = os.path.join(synthseg_home, 'data/labels_classes_priors')
28
from SynthSeg.predict_synthseg import predict
29
30
31
# parse arguments
32
parser = ArgumentParser(description="SynthSeg", epilog='\n')
33
34
# input/outputs
35
parser.add_argument("--i", help="Image(s) to segment. Can be a path to an image or to a folder.")
36
parser.add_argument("--o", help="Segmentation output(s). Must be a folder if --i designates a folder.")
37
parser.add_argument("--parc", action="store_true", help="(optional) Whether to perform cortex parcellation.")
38
parser.add_argument("--robust", action="store_true", help="(optional) Whether to use robust predictions (slower).")
39
parser.add_argument("--fast", action="store_true", help="(optional) Bypass some postprocessing for faster predictions.")
40
parser.add_argument("--ct", action="store_true", help="(optional) Clip intensities to [0,80] for CT scans.")
41
parser.add_argument("--vol", help="(optional) Path to output CSV file with volumes (mm3) for all regions and subjects.")
42
parser.add_argument("--qc", help="(optional) Path to output CSV file with qc scores for all subjects.")
43
parser.add_argument("--post", help="(optional) Posteriors output(s). Must be a folder if --i designates a folder.")
44
parser.add_argument("--resample", help="(optional) Resampled image(s). Must be a folder if --i designates a folder.")
45
parser.add_argument("--crop", nargs='+', type=int, help="(optional) Size of 3D patches to analyse. Default is 192.")
46
parser.add_argument("--threads", type=int, default=1, help="(optional) Number of cores to be used. Default is 1.")
47
parser.add_argument("--cpu", action="store_true", help="(optional) Enforce running with CPU rather than GPU.")
48
parser.add_argument("--v1", action="store_true", help="(optional) Use SynthSeg 1.0 (updated 25/06/22).")
49
50
# check for no arguments
51
if len(sys.argv) < 2:
52
    parser.print_help()
53
    sys.exit(1)
54
55
# parse commandline
56
args = vars(parser.parse_args())
57
58
# print SynthSeg version and checks boolean params for SynthSeg-robust
59
if args['robust']:
60
    args['fast'] = True
61
    assert not args['v1'], 'The flag --v1 cannot be used with --robust since SynthSeg-robust only came out with 2.0.'
62
    version = 'SynthSeg-robust 2.0'
63
else:
64
    version = 'SynthSeg 1.0' if args['v1'] else 'SynthSeg 2.0'
65
    if args['fast']:
66
        version += ' (fast)'
67
print('\n' + version + '\n')
68
69
# enforce CPU processing if necessary
70
if args['cpu']:
71
    print('using CPU, hiding all CUDA_VISIBLE_DEVICES')
72
    os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
73
74
# limit the number of threads to be used if running on CPU
75
import tensorflow as tf
76
if args['threads'] == 1:
77
    print('using 1 thread')
78
else:
79
    print('using %s threads' % args['threads'])
80
tf.config.threading.set_inter_op_parallelism_threads(args['threads'])
81
tf.config.threading.set_intra_op_parallelism_threads(args['threads'])
82
83
# path models
84
if args['robust']:
85
    args['path_model_segmentation'] = os.path.join(model_dir, 'synthseg_robust_2.0.h5')
86
else:
87
    args['path_model_segmentation'] = os.path.join(model_dir, 'synthseg_2.0.h5')
88
args['path_model_parcellation'] = os.path.join(model_dir, 'synthseg_parc_2.0.h5')
89
args['path_model_qc'] = os.path.join(model_dir, 'synthseg_qc_2.0.h5')
90
91
# path labels
92
args['labels_segmentation'] = os.path.join(labels_dir, 'synthseg_segmentation_labels_2.0.npy')
93
args['labels_denoiser'] = os.path.join(labels_dir, 'synthseg_denoiser_labels_2.0.npy')
94
args['labels_parcellation'] = os.path.join(labels_dir, 'synthseg_parcellation_labels.npy')
95
args['labels_qc'] = os.path.join(labels_dir, 'synthseg_qc_labels_2.0.npy')
96
args['names_segmentation_labels'] = os.path.join(labels_dir, 'synthseg_segmentation_names_2.0.npy')
97
args['names_parcellation_labels'] = os.path.join(labels_dir, 'synthseg_parcellation_names.npy')
98
args['names_qc_labels'] = os.path.join(labels_dir, 'synthseg_qc_names_2.0.npy')
99
args['topology_classes'] = os.path.join(labels_dir, 'synthseg_topological_classes_2.0.npy')
100
args['n_neutral_labels'] = 19
101
102
# use previous model if needed
103
if args['v1']:
104
    args['path_model_segmentation'] = os.path.join(model_dir, 'synthseg_1.0.h5')
105
    args['labels_segmentation'] = args['labels_segmentation'].replace('_2.0.npy', '.npy')
106
    args['labels_qc'] = args['labels_qc'].replace('_2.0.npy', '.npy')
107
    args['names_segmentation_labels'] = args['names_segmentation_labels'].replace('_2.0.npy', '.npy')
108
    args['names_qc_labels'] = args['names_qc_labels'].replace('_2.0.npy', '.npy')
109
    args['topology_classes'] = args['topology_classes'].replace('_2.0.npy', '.npy')
110
    args['n_neutral_labels'] = 18
111
112
# run prediction
113
predict(path_images=args['i'],
114
        path_segmentations=args['o'],
115
        path_model_segmentation=args['path_model_segmentation'],
116
        labels_segmentation=args['labels_segmentation'],
117
        robust=args['robust'],
118
        fast=args['fast'],
119
        v1=args['v1'],
120
        do_parcellation=args['parc'],
121
        n_neutral_labels=args['n_neutral_labels'],
122
        names_segmentation=args['names_segmentation_labels'],
123
        labels_denoiser=args['labels_denoiser'],
124
        path_posteriors=args['post'],
125
        path_resampled=args['resample'],
126
        path_volumes=args['vol'],
127
        path_model_parcellation=args['path_model_parcellation'],
128
        labels_parcellation=args['labels_parcellation'],
129
        names_parcellation=args['names_parcellation_labels'],
130
        path_model_qc=args['path_model_qc'],
131
        labels_qc=args['labels_qc'],
132
        path_qc_scores=args['qc'],
133
        names_qc=args['names_qc_labels'],
134
        cropping=args['crop'],
135
        topology_classes=args['topology_classes'],
136
        ct=args['ct'])