Diff of /inference_v2.py [000000] .. [4be099]

Switch to unified view

a b/inference_v2.py
1
import os
2
import numpy as np
3
import torch
4
from batchgenerators.utilities.file_and_folder_operations import *
5
from nnunet.dataset_conversion.Task500_BraTS_2021 import apply_threshold_to_folder
6
from nnunet.dataset_conversion.Task032_BraTS_2018 import convert_labels_back_to_BraTS_2018_2019_convention
7
import shutil
8
9
def main():
10
    input_folder = '/input'
11
    output_folder = '/output'
12
13
    tmp_input_folder = '/tmp_input'
14
    tmp_output_folder = '/tmp_output'
15
    maybe_mkdir_p(tmp_input_folder)
16
    os.system("export RESULTS_FOLDER=/usr/local/bin/trained_models/")
17
    #convert raw data to nnunet format
18
    contrast_to_number = {'t1':'0000', 't1ce':'0001','t2':'0002','flair':'0003'}
19
    for p in subfiles(input_folder, join=False):
20
        tokens = p.split('_')
21
        patient_id = tokens[0] + "_" + tokens[1]
22
        contrast = tokens[-1].split('.')[0]
23
        shutil.copy(join(input_folder, p), join(tmp_input_folder, patient_id + "_" + contrast_to_number[contrast] + ".nii.gz"))
24
25
    #run nnunet inference
26
    tmp_output_folder_BL = join(tmp_output_folder,'raw_output_1')
27
    tmp_output_folder_BL_L_GN = join(tmp_output_folder,'raw_output_2')
28
    tmp_output_folder_ensemble = join(tmp_output_folder,'ensemble')
29
    os.system("nnUNet_predict -i {} -o {} -t 500 -m 3d_fullres -tr nnUNetTrainerV2BraTSRegions_DA4_BN_BD --save_npz".format(tmp_input_folder, tmp_output_folder_BL))
30
    os.system("nnUNet_predict -i {} -o {} -t 500 -m 3d_fullres -tr nnUNetTrainerV2BraTSRegions_DA4_BN_BD_largeUnet_Groupnorm --save_npz".format(tmp_input_folder, tmp_output_folder_BL_L_GN))
31
    os.system("nnUNet_ensemble -f {} {} -o {}".format(tmp_output_folder_BL,tmp_output_folder_BL_L_GN,tmp_output_folder_ensemble))
32
    apply_threshold_to_folder(tmp_output_folder_ensemble, join(tmp_output_folder, 'pp_output'), 200, 2)
33
    convert_labels_back_to_BraTS_2018_2019_convention(join(tmp_output_folder,'pp_output'), join(tmp_output_folder,'pp_output_converted'))
34
35
    for p in subfiles(join(tmp_output_folder,'pp_output_converted'), join=False):
36
        patient_id = p.split('_')[1].split('.')[0]
37
        shutil.copy(join(tmp_output_folder,'pp_output_converted',p),join(output_folder,patient_id + ".nii.gz"))
38
39
if __name__ == '__main__':
40
    main()