Diff of /sybil/model.py [000000] .. [d9566e]

Switch to side-by-side view

--- a
+++ b/sybil/model.py
@@ -0,0 +1,448 @@
+from argparse import Namespace
+from io import BytesIO
+import os
+from typing import NamedTuple, Union, Dict, List, Optional, Tuple
+from urllib.request import urlopen
+from zipfile import ZipFile
+
+import torch
+import numpy as np
+
+from sybil.serie import Serie
+from sybil.models.sybil import SybilNet
+from sybil.models.calibrator import SimpleClassifierGroup
+from sybil.utils.logging_utils import get_logger
+from sybil.utils.device_utils import get_default_device, get_most_free_gpu, get_device_mem_info
+
+
+# Leaving this here for a bit; these are IDs to download the models from Google Drive
+NAME_TO_FILE = {
+    "sybil_base": {
+        "checkpoint": ["28a7cd44f5bcd3e6cc760b65c7e0d54d"],
+        "google_checkpoint_id": ["1ftYbav_BbUBkyR3HFCGnsp-h4uH1yhoz"],
+        "google_calibrator_id": "1F5TOtzueR-ZUvwl8Yv9Svs2NPP5El3HY",
+    },
+    "sybil_1": {
+        "checkpoint": ["28a7cd44f5bcd3e6cc760b65c7e0d54d"],
+        "google_checkpoint_id": ["1ftYbav_BbUBkyR3HFCGnsp-h4uH1yhoz"],
+        "google_calibrator_id": "1F5TOtzueR-ZUvwl8Yv9Svs2NPP5El3HY",
+    },
+    "sybil_2": {
+        "checkpoint": ["56ce1a7d241dc342982f5466c4a9d7ef"],
+        "google_checkpoint_id": ["1rscGi1grSxaVGzn-tqKtuAR3ipo0DWgA"],
+        "google_calibrator_id": "1zKLVYBaiuMOx7p--e2zabs1LbQ-XXxcZ",
+    },
+    "sybil_3": {
+        "checkpoint": ["624407ef8e3a2a009f9fa51f9846fe9a"],
+        "google_checkpoint_id": ["1DV0Ge7n9r8WAvBXyoNRPwyA7VL43csAr"],
+        "google_calibrator_id": "1qh4nawgE2Kjf_H97XuuTpL7XUIX7JOJn",
+    },
+    "sybil_4": {
+        "checkpoint": ["64a91b25f84141d32852e75a3aec7305"],
+        "google_checkpoint_id": ["1Acz_yzdJMpkz3PRrjXy526CjAboMEIHX"],
+        "google_calibrator_id": "1QIvvCYLaesPGMEiE2Up77pKL3ygDdGU2",
+    },
+    "sybil_5": {
+        "checkpoint": ["65fd1f04cb4c5847d86a9ed8ba31ac1a"],
+        "google_checkpoint_id": ["1uV58SD-Qtb6xElTzWPDWWnloH1KB_zrP"],
+        "google_calibrator_id": "1yDq1_A5w-fSdxzq4K2YSBRNcQQkDnH0K",
+    },
+    "sybil_ensemble": {
+        "checkpoint": [
+            "28a7cd44f5bcd3e6cc760b65c7e0d54d",
+            "56ce1a7d241dc342982f5466c4a9d7ef",
+            "624407ef8e3a2a009f9fa51f9846fe9a",
+            "64a91b25f84141d32852e75a3aec7305",
+            "65fd1f04cb4c5847d86a9ed8ba31ac1a",
+        ],
+        "google_checkpoint_id": [
+            "1ftYbav_BbUBkyR3HFCGnsp-h4uH1yhoz",
+            "1rscGi1grSxaVGzn-tqKtuAR3ipo0DWgA",
+            "1DV0Ge7n9r8WAvBXyoNRPwyA7VL43csAr",
+            "1Acz_yzdJMpkz3PRrjXy526CjAboMEIHX",
+            "1uV58SD-Qtb6xElTzWPDWWnloH1KB_zrP",
+        ],
+        "google_calibrator_id": "1FxHNo0HqXYyiUKE_k2bjatVt9e64J9Li",
+    },
+}
+
+CHECKPOINT_URL = os.getenv("SYBIL_CHECKPOINT_URL", "https://github.com/reginabarzilaygroup/Sybil/releases/download/v1.5.0/sybil_checkpoints.zip")
+
+
+class Prediction(NamedTuple):
+    scores: List[List[float]]
+    attentions: List[Dict[str, np.ndarray]] = None
+
+
+class Evaluation(NamedTuple):
+    auc: List[float]
+    c_index: float
+    scores: List[List[float]]
+    attentions: List[Dict[str, np.ndarray]] = None
+
+
+def download_sybil(name, cache) -> Tuple[List[str], str]:
+    """Download trained models and calibrator"""
+    # Create cache folder if not exists
+    cache = os.path.expanduser(cache)
+    os.makedirs(cache, exist_ok=True)
+
+    # Download models
+    model_files = NAME_TO_FILE[name]
+    checkpoints = model_files["checkpoint"]
+    download_calib_path = os.path.join(cache, f"{name}_simple_calibrator.json")
+    have_all_files = os.path.exists(download_calib_path)
+
+    download_model_paths = []
+    for checkpoint in checkpoints:
+        cur_checkpoint_path = os.path.join(cache, f"{checkpoint}.ckpt")
+        have_all_files &= os.path.exists(cur_checkpoint_path)
+        download_model_paths.append(cur_checkpoint_path)
+
+    if not have_all_files:
+        print(f"Downloading models to {cache}")
+        download_and_extract(CHECKPOINT_URL, cache)
+
+    return download_model_paths, download_calib_path
+
+
+def download_and_extract(remote_url: str, local_dir: str) -> List[str]:
+    os.makedirs(local_dir, exist_ok=True)
+    resp = urlopen(remote_url)
+    with ZipFile(BytesIO(resp.read())) as zip_file:
+        all_files_and_dirs = zip_file.namelist()
+        zip_file.extractall(local_dir)
+    return all_files_and_dirs
+
+
+def _torch_set_num_threads(threads) -> int:
+    """
+    Set the number of CPU threads for torch to use.
+    Set to a negative number for no-op.
+    Set to 0 for the number of CPUs.
+    """
+    if threads < 0:
+        return torch.get_num_threads()
+    if threads is None or threads == 0:
+        # I've never seen a benefit to going higher than 8 and sometimes there is a big slowdown
+        threads = min(8, os.cpu_count())
+
+    torch.set_num_threads(threads)
+    return torch.get_num_threads()
+
+
+class Sybil:
+    def __init__(
+        self,
+        name_or_path: Union[List[str], str] = "sybil_ensemble",
+        cache: str = "~/.sybil/",
+        calibrator_path: Optional[str] = None,
+        device: Optional[str] = None,
+    ):
+        """Initialize a trained Sybil model for inference.
+
+        Parameters
+        ----------
+        name_or_path: list or str
+            Alias to a provided pretrained Sybil model or path
+            to a sybil checkpoint.
+        cache: str
+            Directory to download model checkpoints to
+        calibrator_path: str
+            Path to calibrator pickle file corresponding with model
+        device: str
+            If provided, will run inference using this device.
+            By default, uses GPU with the most free memory, if available.
+
+        """
+        self._logger = get_logger()
+        # Download if needed
+        if isinstance(name_or_path, str) and (name_or_path in NAME_TO_FILE):
+            name_or_path, calibrator_path = download_sybil(name_or_path, cache)
+
+        elif not all(os.path.exists(p) for p in name_or_path):
+            raise ValueError(
+                "No saved model or local path: {}".format(
+                    [p for p in name_or_path if not os.path.exists(p)]
+                )
+            )
+
+        # Check calibrator path before continuing
+        if (calibrator_path is not None) and (not os.path.exists(calibrator_path)):
+            raise ValueError(f"Path not found for calibrator {calibrator_path}")
+
+        # Set device.
+        # If set manually, use it and stay there.
+        # Otherwise, pick the most free GPU now and at predict time.
+        self._device_flexible = True
+        if device is not None:
+            self.device = device
+            self._device_flexible = False
+        else:
+            self.device = get_default_device()
+
+        self.ensemble = torch.nn.ModuleList()
+        for path in name_or_path:
+            self.ensemble.append(self.load_model(path))
+        self.to(self.device)
+
+        if calibrator_path is not None:
+            self.calibrator = SimpleClassifierGroup.from_json_grouped(calibrator_path)
+        else:
+            self.calibrator = None
+
+    def load_model(self, path):
+        """Load model from path.
+
+        Parameters
+        ----------
+        path : str
+            Path to a sybil checkpoint.
+
+        Returns
+        -------
+        model
+            Pretrained Sybil model
+        """
+        # Load checkpoint
+        checkpoint = torch.load(path, map_location="cpu")
+        args = checkpoint["args"]
+        self._max_followup = args.max_followup
+        self._censoring_dist = args.censoring_distribution
+        model = SybilNet(args)
+
+        # Remove model from param names
+        state_dict = {k[6:]: v for k, v in checkpoint["state_dict"].items()}
+        model.load_state_dict(state_dict)  # type: ignore
+        if self.device is not None:
+            model.to(self.device)
+
+        # Set eval
+        model.eval()
+        self._logger.info(f"Loaded model from {path}")
+        return model
+
+    def _calibrate(self, scores: np.ndarray) -> np.ndarray:
+        """Calibrate raw predictions
+
+        Parameters
+        ----------
+        scores: np.ndarray
+            risk scores as numpy array
+
+        Returns
+        -------
+            np.ndarray: calibrated risk scores as numpy array
+        """
+        if self.calibrator is None:
+            return scores
+
+        calibrated_scores = []
+        for YEAR in range(scores.shape[1]):
+            probs = scores[:, YEAR].reshape(-1, 1)
+            probs = self.calibrator["Year{}".format(YEAR + 1)].predict_proba(probs)[:, -1]
+            calibrated_scores.append(probs)
+
+        return np.stack(calibrated_scores, axis=1)
+
+    def _predict(
+        self,
+        model: SybilNet,
+        series: Union[Serie, List[Serie]],
+        return_attentions: bool = False,
+    ) -> Prediction:
+        """Run predictions over the given serie(s).
+
+        Parameters
+        ----------
+        model: SybilNet
+            Instance of SybilNet
+        series : Union[Serie, Iterable[Serie]]
+            One or multiple series to run predictions for.
+        return_attentions : bool
+            If True, returns attention scores for each serie. See README for details.
+
+        Returns
+        -------
+        Prediction
+            Output prediction as risk scores.
+
+        """
+        if isinstance(series, Serie):
+            series = [series]
+        elif not isinstance(series, list):
+            raise ValueError("Expected either a Serie object or list of Serie objects.")
+
+        scores: List[List[float]] = []
+        attentions: List[Dict[str, np.ndarray]] = [] if return_attentions else None
+        for serie in series:
+            if not isinstance(serie, Serie):
+                raise ValueError("Expected a list of Serie objects.")
+
+            volume = serie.get_volume()
+            if self.device is not None:
+                volume = volume.to(self.device)
+
+            with torch.no_grad():
+                out = model(volume)
+                score = out["logit"].sigmoid().squeeze(0).cpu().numpy()
+                scores.append(score.tolist())
+                if return_attentions:
+                    attentions.append(
+                        {
+                            "image_attention_1": out["image_attention_1"]
+                            .detach()
+                            .cpu(),
+                            "volume_attention_1": out["volume_attention_1"]
+                            .detach()
+                            .cpu(),
+                            "hidden": out["hidden"]
+                            .detach()
+                            .cpu(),
+                        }
+                    )
+
+        return Prediction(scores=scores, attentions=attentions)
+
+    def predict(
+        self, series: Union[Serie, List[Serie]], return_attentions: bool = False, threads=0,
+    ) -> Prediction:
+        """Run predictions over the given serie(s) and ensemble
+
+        Parameters
+        ----------
+        series : Union[Serie, Iterable[Serie]]
+            One or multiple series to run predictions for.
+        return_attentions : bool
+            If True, returns attention scores for each serie. See README for details.
+        threads : int
+            Number of CPU threads to use for PyTorch inference.
+
+        Returns
+        -------
+        Prediction
+            Output prediction. See details for :class:`~sybil.model.Prediction`".
+
+        """
+
+        # Set CPU threads available to torch
+        num_threads = _torch_set_num_threads(threads)
+        self._logger.debug(f"Using {num_threads} threads for PyTorch inference")
+
+        if self._device_flexible:
+            self.device = self._pick_device()
+            self.to(self.device)
+        self._logger.debug(f"Beginning prediction on device: {self.device}")
+
+        scores = []
+        attentions_ = [] if return_attentions else None
+        attention_keys = None
+        for sybil in self.ensemble:
+            pred = self._predict(sybil, series, return_attentions)
+            scores.append(pred.scores)
+            if return_attentions:
+                attentions_.append(pred.attentions)
+                if attention_keys is None:
+                    attention_keys = pred.attentions[0].keys()
+
+        scores = np.mean(np.array(scores), axis=0)
+        calib_scores = self._calibrate(scores).tolist()
+
+        attentions = None
+        if return_attentions:
+            attentions = []
+            for i in range(len(series)):
+                att = {}
+                for key in attention_keys:
+                    att[key] = np.stack([
+                        attentions_[j][i][key] for j in range(len(self.ensemble))
+                    ])
+                attentions.append(att)
+
+        return Prediction(scores=calib_scores, attentions=attentions)
+
+    def evaluate(
+        self, series: Union[Serie, List[Serie]], return_attentions: bool = False
+    ) -> Evaluation:
+        """Run evaluation over the given serie(s).
+
+        Parameters
+        ----------
+        series : Union[Serie, List[Serie]]
+            One or multiple series to run evaluation for.
+        return_attentions : bool
+            If True, returns attention scores for each serie. See README for details.
+
+        Returns
+        -------
+        Evaluation
+            Output evaluation. See details for :class:`~sybil.model.Evaluation`.
+
+        """
+        from sybil.utils.metrics import get_survival_metrics
+        if isinstance(series, Serie):
+            series = [series]
+        elif not isinstance(series, list):
+            raise ValueError(
+                "Expected either a Serie object or an iterable over Serie objects."
+            )
+
+        # Check all have labels
+        if not all(serie.has_label() for serie in series):
+            raise ValueError("All series must have a label for evaluation")
+
+        # Get scores and labels
+        predictions = self.predict(series, return_attentions)
+        scores = predictions.scores
+        labels = [serie.get_label(self._max_followup) for serie in series]
+
+        # Convert to format for survival metrics
+        input_dict = {
+            "probs": torch.tensor(scores),
+            "censors": torch.tensor([label.censor_time for label in labels]),
+            "golds": torch.tensor([label.y for label in labels]),
+        }
+        args = Namespace(
+            max_followup=self._max_followup, censoring_distribution=self._censoring_dist
+        )
+        out = get_survival_metrics(input_dict, args)
+        auc = [float(out[f"{i + 1}_year_auc"]) for i in range(self._max_followup)]
+        c_index = float(out["c_index"])
+
+        return Evaluation(auc=auc, c_index=c_index, scores=scores, attentions=predictions.attentions)
+
+    def to(self, device: str):
+        """Move model to device.
+
+        Parameters
+        ----------
+        device : str
+            Device to move model to.
+        """
+        self.device = device
+        self.ensemble.to(device)
+
+    def _pick_device(self):
+        """
+        Pick the device to run inference on.
+        This is based on the device with the most free memory, with a preference for remaining
+        on the current device.
+
+        Motivation is to enable multiprocessing without the processes needed to communicate.
+        """
+        if not torch.cuda.is_available():
+            return get_default_device()
+
+        # Get size of the model in memory (approximate)
+        model_mem = 9*sum(p.numel() * p.element_size() for p in self.ensemble.parameters())
+
+        # Check memory available on current device.
+        # If it seems like we're the only thing on this GPU, stay.
+        free_mem, total_mem = get_device_mem_info(self.device)
+        cur_allocated = total_mem - free_mem
+        min_to_move = int(1.01 * model_mem)
+        if cur_allocated < min_to_move:
+            return self.device
+        else:
+            # Otherwise, get the most free GPU
+            return get_most_free_gpu()