Switch to side-by-side view

--- a
+++ b/adpkd_segmentation/inference/inference.py
@@ -0,0 +1,122 @@
+from pathlib import Path
+from tqdm import tqdm
+from argparse import ArgumentParser
+
+from inference_utils import (
+    load_config,
+    inference_to_disk,
+    display_volumes,
+    inference_to_nifti,
+)
+
+parser = ArgumentParser()
+parser.add_argument(
+    "--config_path",
+    type=str,
+    help="path to config file for inference pipeline",
+    default="checkpoints/inference.yml",
+)
+
+parser.add_argument(
+    "-i",
+    "--inference_path",
+    type=str,
+    help="path to input dicom data (replaces path in config file)",
+    default=None,
+)
+
+parser.add_argument(
+    "-o",
+    "--output_path",
+    type=str,
+    help="path to output location",
+    default=None,
+)
+
+
+def run_inference(
+    config_path="checkpoints/inference.yml",
+    inference_path=None,
+    saved_inference="saved_inference",
+    saved_figs="saved_figs",
+):
+
+    # %%
+    # Run inferences
+    print("Enter run inference...")
+    model_args = load_config(
+        config_path=config_path, inference_path=inference_path
+    )
+
+    if saved_inference is not None:
+        model_args["save_dir"] = saved_inference
+    # load_config initializes all objects including:
+    # model and datloader for InferenceDataset
+
+    inference_to_disk(**model_args)
+
+    # %%
+    # Creating figures for all inferences
+
+    # Get all model inferences
+    inference_files = list(Path(saved_inference).glob("**/*"))
+
+    # Folders are of form 'saved_inference/adpkd-segmentation/{PATIENT-ID}/{SERIES}'
+    folders = [
+        f
+        for f in inference_files
+        if len(f.parts) >= 4 and f.parts[-4] == "saved_inference"
+    ]
+    folders = list(set(folders))
+
+    IDX_series = -1
+    IDX_ID = -2
+
+    saved_folders = [
+        Path(saved_figs) / f"{d.parts[IDX_ID]}_{d.parts[IDX_series]}"
+        for d in folders
+    ]
+    # %%
+    # Generate figures for all inferences
+    print("Creating Figures and Nifti outputs...")
+    for study_dir, save_dir in tqdm(list(zip(folders, saved_folders))):
+        try:
+            # Save inference figure to save_dir
+            display_volumes(
+                study_dir=study_dir,
+                style="pred",
+                plot_error=True,
+                skip_display=False,
+                save_dir=save_dir,
+            )
+
+            inference_to_nifti(inference_dir=study_dir)
+
+        except Exception as e:
+            print(e)
+
+
+if __name__ == "__main__":
+    args = parser.parse_args()
+
+    config_path = Path(args.config_path)
+    inference_path = args.inference_path
+    output_path = args.output_path
+
+    saved_inference = "saved_inference"
+    saved_figs = "saved_figs"
+
+    if inference_path is not None:
+        inference_path = Path(inference_path)
+
+    if output_path is not None:
+        # update with output folder path
+        saved_inference = Path(output_path) / saved_inference
+        saved_figs = Path(output_path) / saved_figs
+
+    run_inference(
+        config_path=config_path,
+        inference_path=inference_path,
+        saved_inference=saved_inference,
+        saved_figs=saved_figs,
+    )