|
a |
|
b/inference_v2.py |
|
|
1 |
import os |
|
|
2 |
import numpy as np |
|
|
3 |
import torch |
|
|
4 |
from batchgenerators.utilities.file_and_folder_operations import * |
|
|
5 |
from nnunet.dataset_conversion.Task500_BraTS_2021 import apply_threshold_to_folder |
|
|
6 |
from nnunet.dataset_conversion.Task032_BraTS_2018 import convert_labels_back_to_BraTS_2018_2019_convention |
|
|
7 |
import shutil |
|
|
8 |
|
|
|
9 |
def main(): |
|
|
10 |
input_folder = '/input' |
|
|
11 |
output_folder = '/output' |
|
|
12 |
|
|
|
13 |
tmp_input_folder = '/tmp_input' |
|
|
14 |
tmp_output_folder = '/tmp_output' |
|
|
15 |
maybe_mkdir_p(tmp_input_folder) |
|
|
16 |
os.system("export RESULTS_FOLDER=/usr/local/bin/trained_models/") |
|
|
17 |
#convert raw data to nnunet format |
|
|
18 |
contrast_to_number = {'t1':'0000', 't1ce':'0001','t2':'0002','flair':'0003'} |
|
|
19 |
for p in subfiles(input_folder, join=False): |
|
|
20 |
tokens = p.split('_') |
|
|
21 |
patient_id = tokens[0] + "_" + tokens[1] |
|
|
22 |
contrast = tokens[-1].split('.')[0] |
|
|
23 |
shutil.copy(join(input_folder, p), join(tmp_input_folder, patient_id + "_" + contrast_to_number[contrast] + ".nii.gz")) |
|
|
24 |
|
|
|
25 |
#run nnunet inference |
|
|
26 |
tmp_output_folder_BL = join(tmp_output_folder,'raw_output_1') |
|
|
27 |
tmp_output_folder_BL_L_GN = join(tmp_output_folder,'raw_output_2') |
|
|
28 |
tmp_output_folder_ensemble = join(tmp_output_folder,'ensemble') |
|
|
29 |
os.system("nnUNet_predict -i {} -o {} -t 500 -m 3d_fullres -tr nnUNetTrainerV2BraTSRegions_DA4_BN_BD --save_npz".format(tmp_input_folder, tmp_output_folder_BL)) |
|
|
30 |
os.system("nnUNet_predict -i {} -o {} -t 500 -m 3d_fullres -tr nnUNetTrainerV2BraTSRegions_DA4_BN_BD_largeUnet_Groupnorm --save_npz".format(tmp_input_folder, tmp_output_folder_BL_L_GN)) |
|
|
31 |
os.system("nnUNet_ensemble -f {} {} -o {}".format(tmp_output_folder_BL,tmp_output_folder_BL_L_GN,tmp_output_folder_ensemble)) |
|
|
32 |
apply_threshold_to_folder(tmp_output_folder_ensemble, join(tmp_output_folder, 'pp_output'), 200, 2) |
|
|
33 |
convert_labels_back_to_BraTS_2018_2019_convention(join(tmp_output_folder,'pp_output'), join(tmp_output_folder,'pp_output_converted')) |
|
|
34 |
|
|
|
35 |
for p in subfiles(join(tmp_output_folder,'pp_output_converted'), join=False): |
|
|
36 |
patient_id = p.split('_')[1].split('.')[0] |
|
|
37 |
shutil.copy(join(tmp_output_folder,'pp_output_converted',p),join(output_folder,patient_id + ".nii.gz")) |
|
|
38 |
|
|
|
39 |
if __name__ == '__main__': |
|
|
40 |
main() |