Diff of /sybil/predict.py [000000] .. [d9566e]

--- a
+++ b/sybil/predict.py
@@ -0,0 +1,172 @@
+#!/usr/bin/env python
+
+__doc__ = """
+Use Sybil to run inference on a single exam.
+"""
+
+import argparse
+import json
+import os
+import pickle
+import typing
+from typing import Literal
+
+import sybil.utils.logging_utils
+import sybil.datasets.utils
+from sybil import Serie, Sybil, visualize_attentions, __version__
+
+
+def _get_parser():
+    description = __doc__ + f"\nVersion: {__version__}\n"
+    parser = argparse.ArgumentParser(description=description)
+
+    parser.add_argument(
+        "image_dir",
+        default=None,
+        help="Path to directory containing DICOM/PNG files (from a single exam) to run inference on. "
+             "Every file in the directory will be included.",
+    )
+
+    parser.add_argument(
+        "--output-dir",
+        default="sybil_result",
+        dest="output_dir",
+        help="Output directory in which to save prediction results. "
+        "Prediction will be printed to stdout as well.",
+    )
+
+    parser.add_argument(
+        "--return-attentions",
+        default=False,
+        action="store_true",
+        help="Return hidden vectors and attention scores, write them to a pickle file.",
+    )
+
+    parser.add_argument(
+        "--write-attention-images",
+        default=False,
+        action="store_true",
+        help="Generate images with attention overlap. Sets --return-attentions (if not already set).",
+    )
+
+    parser.add_argument(
+        "--file-type",
+        default="auto",
+        dest="file_type",
+        choices={"dicom", "png", "auto"},
+        help="File type of input images. "
+        "If not provided, the file type will be inferred from input extensions.",
+    )
+
+    parser.add_argument(
+        "--model-name",
+        default="sybil_ensemble",
+        dest="model_name",
+        help="Name of the model to use for prediction. Default: sybil_ensemble",
+    )
+
+    parser.add_argument("-l", "--log", "--loglevel", "--log-level",
+                        default="INFO", dest="loglevel")
+
+    parser.add_argument("--threads", type=int, default=0,
+                        help="Number of threads to use for PyTorch inference. "
+                             "Default is 0 (use all available cores). "
+                             "Set to a negative number to use the PyTorch default.")
+
+    parser.add_argument("-v", "--version", action="version", version=__version__)
+
+    return parser
+
+
+def predict(
+    image_dir,
+    output_dir,
+    model_name="sybil_ensemble",
+    return_attentions=False,
+    write_attention_images=False,
+    file_type: Literal["auto", "dicom", "png"] = "auto",
+    threads: int = 0,
+):
+    logger = sybil.utils.logging_utils.get_logger()
+
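+    # Writing attention images requires attention scores, so force return_attentions on in that case.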
+    return_attentions |= write_attention_images
+
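+    # Collect every non-hidden regular file in the exam directory as input.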
+    input_files = os.listdir(image_dir)
+    input_files = [os.path.join(image_dir, x) for x in input_files if not x.startswith(".")]
+    input_files = [x for x in input_files if os.path.isfile(x)]
+
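+    # Infer the file type from the input extensions when --file-type is "auto";
+    # PNG series fall back to the dataset's default voxel spacing.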
+    voxel_spacing = None
+    if file_type == "auto":
+        extensions = {os.path.splitext(x)[1] for x in input_files}
+        if len(extensions) > 1:
+            raise ValueError(
+                f"Multiple file types found in {image_dir}: {','.join(extensions)}"
+            )
+        extension = extensions.pop()
+
+        file_type = "dicom"
+        if extension.lower() in {".png", "png"}:
+            file_type = "png"
+            voxel_spacing = sybil.datasets.utils.VOXEL_SPACING
+            logger.debug(f"Using default voxel spacing: {voxel_spacing}")
+    assert file_type in {"dicom", "png"}
+    file_type = typing.cast(Literal["dicom", "png"], file_type)
+
+    num_files = len(input_files)
+
+    logger.debug(f"Beginning prediction using {num_files} {file_type} files from {image_dir}")
+
+    # Load a trained model
+    model = Sybil(model_name)
+
+    # Get risk scores
+    serie = Serie(input_files, voxel_spacing=voxel_spacing, file_type=file_type)
+    series = [serie]
+    prediction = model.predict(series, return_attentions=return_attentions, threads=threads)
+    prediction_scores = prediction.scores[0]
+
+    logger.debug(f"Prediction finished. Results:\n{prediction_scores}")
+
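+    # Write the risk scores to a JSON file in the output directory.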
+    prediction_path = os.path.join(output_dir, "prediction_scores.json")
+    pred_dict = {"predictions": prediction.scores}
+    with open(prediction_path, "w") as f:
+        json.dump(pred_dict, f, indent=2)
+
+    series_with_attention = None
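+    # Pickle the full prediction object so attention tensors can be inspected later.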
+    if return_attentions:
+        attention_path = os.path.join(output_dir, "attention_scores.pkl")
+        with open(attention_path, "wb") as f:
+            pickle.dump(prediction, f)
+
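+    # Overlay attention maps on the input images and save the visualizations to output_dir.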
+    if write_attention_images:
+        series_with_attention = visualize_attentions(
+            series,
+            attentions=prediction.attentions,
+            save_directory=output_dir,
+            gain=3,
+        )
+
+    return pred_dict, series_with_attention
+
+
+def main():
+    args = _get_parser().parse_args()
+    sybil.utils.logging_utils.configure_logger(args.loglevel)
+
+    os.makedirs(args.output_dir, exist_ok=True)
+
+    pred_dict, series_with_attention = predict(
+        args.image_dir,
+        args.output_dir,
+        args.model_name,
+        args.return_attentions,
+        args.write_attention_images,
+        file_type=args.file_type,
+        threads=args.threads,
+    )
+
+    print(json.dumps(pred_dict, indent=2))
+
+
+if __name__ == "__main__":
+    main()