--- a +++ b/adpkd_segmentation/inference/inference.py @@ -0,0 +1,122 @@ +from pathlib import Path +from tqdm import tqdm +from argparse import ArgumentParser + +from inference_utils import ( + load_config, + inference_to_disk, + display_volumes, + inference_to_nifti, +) + +parser = ArgumentParser() +parser.add_argument( + "--config_path", + type=str, + help="path to config file for inference pipeline", + default="checkpoints/inference.yml", +) + +parser.add_argument( + "-i", + "--inference_path", + type=str, + help="path to input dicom data (replaces path in config file)", + default=None, +) + +parser.add_argument( + "-o", + "--output_path", + type=str, + help="path to output location", + default=None, +) + + +def run_inference( + config_path="checkpoints/inference.yml", + inference_path=None, + saved_inference="saved_inference", + saved_figs="saved_figs", +): + + # %% + # Run inferences + print("Enter run inference...") + model_args = load_config( + config_path=config_path, inference_path=inference_path + ) + + if saved_inference is not None: + model_args["save_dir"] = saved_inference + # load_config initializes all objects including: + # model and datloader for InferenceDataset + + inference_to_disk(**model_args) + + # %% + # Creating figures for all inferences + + # Get all model inferences + inference_files = list(Path(saved_inference).glob("**/*")) + + # Folders are of form 'saved_inference/adpkd-segmentation/{PATIENT-ID}/{SERIES}' + folders = [ + f + for f in inference_files + if len(f.parts) >= 4 and f.parts[-4] == "saved_inference" + ] + folders = list(set(folders)) + + IDX_series = -1 + IDX_ID = -2 + + saved_folders = [ + Path(saved_figs) / f"{d.parts[IDX_ID]}_{d.parts[IDX_series]}" + for d in folders + ] + # %% + # Generate figures for all inferences + print("Creating Figures and Nifti outputs...") + for study_dir, save_dir in tqdm(list(zip(folders, saved_folders))): + try: + # Save inference figure to save_dir + display_volumes( + study_dir=study_dir, + style="pred", + plot_error=True, + skip_display=False, + save_dir=save_dir, + ) + + inference_to_nifti(inference_dir=study_dir) + + except Exception as e: + print(e) + + +if __name__ == "__main__": + args = parser.parse_args() + + config_path = Path(args.config_path) + inference_path = args.inference_path + output_path = args.output_path + + saved_inference = "saved_inference" + saved_figs = "saved_figs" + + if inference_path is not None: + inference_path = Path(inference_path) + + if output_path is not None: + # update with output folder path + saved_inference = Path(output_path) / saved_inference + saved_figs = Path(output_path) / saved_figs + + run_inference( + config_path=config_path, + inference_path=inference_path, + saved_inference=saved_inference, + saved_figs=saved_figs, + )