--- /dev/null
+++ b/sybil/model.py
@@ -0,0 +1,448 @@
+from argparse import Namespace
+from io import BytesIO
+import os
+from typing import NamedTuple, Union, Dict, List, Optional, Tuple
+from urllib.request import urlopen
+from zipfile import ZipFile
+
+import torch
+import numpy as np
+
+from sybil.serie import Serie
+from sybil.models.sybil import SybilNet
+from sybil.models.calibrator import SimpleClassifierGroup
+from sybil.utils.logging_utils import get_logger
+from sybil.utils.device_utils import get_default_device, get_most_free_gpu, get_device_mem_info
+
+
+# Leaving this here for a bit; these are IDs to download the models from Google Drive
+NAME_TO_FILE = {
+    "sybil_base": {
+        "checkpoint": ["28a7cd44f5bcd3e6cc760b65c7e0d54d"],
+        "google_checkpoint_id": ["1ftYbav_BbUBkyR3HFCGnsp-h4uH1yhoz"],
+        "google_calibrator_id": "1F5TOtzueR-ZUvwl8Yv9Svs2NPP5El3HY",
+    },
+    "sybil_1": {
+        "checkpoint": ["28a7cd44f5bcd3e6cc760b65c7e0d54d"],
+        "google_checkpoint_id": ["1ftYbav_BbUBkyR3HFCGnsp-h4uH1yhoz"],
+        "google_calibrator_id": "1F5TOtzueR-ZUvwl8Yv9Svs2NPP5El3HY",
+    },
+    "sybil_2": {
+        "checkpoint": ["56ce1a7d241dc342982f5466c4a9d7ef"],
+        "google_checkpoint_id": ["1rscGi1grSxaVGzn-tqKtuAR3ipo0DWgA"],
+        "google_calibrator_id": "1zKLVYBaiuMOx7p--e2zabs1LbQ-XXxcZ",
+    },
+    "sybil_3": {
+        "checkpoint": ["624407ef8e3a2a009f9fa51f9846fe9a"],
+        "google_checkpoint_id": ["1DV0Ge7n9r8WAvBXyoNRPwyA7VL43csAr"],
+        "google_calibrator_id": "1qh4nawgE2Kjf_H97XuuTpL7XUIX7JOJn",
+    },
+    "sybil_4": {
+        "checkpoint": ["64a91b25f84141d32852e75a3aec7305"],
+        "google_checkpoint_id": ["1Acz_yzdJMpkz3PRrjXy526CjAboMEIHX"],
+        "google_calibrator_id": "1QIvvCYLaesPGMEiE2Up77pKL3ygDdGU2",
+    },
+    "sybil_5": {
+        "checkpoint": ["65fd1f04cb4c5847d86a9ed8ba31ac1a"],
+        "google_checkpoint_id": ["1uV58SD-Qtb6xElTzWPDWWnloH1KB_zrP"],
+        "google_calibrator_id": "1yDq1_A5w-fSdxzq4K2YSBRNcQQkDnH0K",
+    },
+    "sybil_ensemble": {
+        "checkpoint": [
+            "28a7cd44f5bcd3e6cc760b65c7e0d54d",
+            "56ce1a7d241dc342982f5466c4a9d7ef",
+            "624407ef8e3a2a009f9fa51f9846fe9a",
+            "64a91b25f84141d32852e75a3aec7305",
+            "65fd1f04cb4c5847d86a9ed8ba31ac1a",
+        ],
+        "google_checkpoint_id": [
+            "1ftYbav_BbUBkyR3HFCGnsp-h4uH1yhoz",
+            "1rscGi1grSxaVGzn-tqKtuAR3ipo0DWgA",
+            "1DV0Ge7n9r8WAvBXyoNRPwyA7VL43csAr",
+            "1Acz_yzdJMpkz3PRrjXy526CjAboMEIHX",
+            "1uV58SD-Qtb6xElTzWPDWWnloH1KB_zrP",
+        ],
+        "google_calibrator_id": "1FxHNo0HqXYyiUKE_k2bjatVt9e64J9Li",
+    },
+}
+
+CHECKPOINT_URL = os.getenv("SYBIL_CHECKPOINT_URL", "https://github.com/reginabarzilaygroup/Sybil/releases/download/v1.5.0/sybil_checkpoints.zip")
+
+
+class Prediction(NamedTuple):
+    scores: List[List[float]]
+    attentions: Optional[List[Dict[str, np.ndarray]]] = None
+
+
+class Evaluation(NamedTuple):
+    auc: List[float]
+    c_index: float
+    scores: List[List[float]]
+    attentions: Optional[List[Dict[str, np.ndarray]]] = None
+
+
+def download_sybil(name: str, cache: str) -> Tuple[List[str], str]:
+    """Download trained models and calibrator, returning their local paths."""
+    # Create the cache folder if it does not exist
+    cache = os.path.expanduser(cache)
+    os.makedirs(cache, exist_ok=True)
+
+    # Check which model files are already present
+    model_files = NAME_TO_FILE[name]
+    checkpoints = model_files["checkpoint"]
+    download_calib_path = os.path.join(cache, f"{name}_simple_calibrator.json")
+    have_all_files = os.path.exists(download_calib_path)
+
+    download_model_paths = []
+    for checkpoint in checkpoints:
+        cur_checkpoint_path = os.path.join(cache, f"{checkpoint}.ckpt")
+        have_all_files &= os.path.exists(cur_checkpoint_path)
+        download_model_paths.append(cur_checkpoint_path)
+
+    if not have_all_files:
+        print(f"Downloading models to {cache}")
+        download_and_extract(CHECKPOINT_URL, cache)
+
+    return download_model_paths, download_calib_path
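+
+
+# Usage sketch (illustrative, not part of the API surface): resolve a model
+# alias to local checkpoint paths, downloading the release archive on first
+# use. The exact paths returned depend on the cache directory and the alias.
+#
+#   ckpt_paths, calib_path = download_sybil("sybil_ensemble", "~/.sybil/")
+#   # ckpt_paths: one .ckpt per ensemble member; calib_path: calibrator JSON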
print(f"Downloading models to {cache}") + download_and_extract(CHECKPOINT_URL, cache) + + return download_model_paths, download_calib_path + + +def download_and_extract(remote_url: str, local_dir: str) -> List[str]: + os.makedirs(local_dir, exist_ok=True) + resp = urlopen(remote_url) + with ZipFile(BytesIO(resp.read())) as zip_file: + all_files_and_dirs = zip_file.namelist() + zip_file.extractall(local_dir) + return all_files_and_dirs + + +def _torch_set_num_threads(threads) -> int: + """ + Set the number of CPU threads for torch to use. + Set to a negative number for no-op. + Set to 0 for the number of CPUs. + """ + if threads < 0: + return torch.get_num_threads() + if threads is None or threads == 0: + # I've never seen a benefit to going higher than 8 and sometimes there is a big slowdown + threads = min(8, os.cpu_count()) + + torch.set_num_threads(threads) + return torch.get_num_threads() + + +class Sybil: + def __init__( + self, + name_or_path: Union[List[str], str] = "sybil_ensemble", + cache: str = "~/.sybil/", + calibrator_path: Optional[str] = None, + device: Optional[str] = None, + ): + """Initialize a trained Sybil model for inference. + + Parameters + ---------- + name_or_path: list or str + Alias to a provided pretrained Sybil model or path + to a sybil checkpoint. + cache: str + Directory to download model checkpoints to + calibrator_path: str + Path to calibrator pickle file corresponding with model + device: str + If provided, will run inference using this device. + By default, uses GPU with the most free memory, if available. + + """ + self._logger = get_logger() + # Download if needed + if isinstance(name_or_path, str) and (name_or_path in NAME_TO_FILE): + name_or_path, calibrator_path = download_sybil(name_or_path, cache) + + elif not all(os.path.exists(p) for p in name_or_path): + raise ValueError( + "No saved model or local path: {}".format( + [p for p in name_or_path if not os.path.exists(p)] + ) + ) + + # Check calibrator path before continuing + if (calibrator_path is not None) and (not os.path.exists(calibrator_path)): + raise ValueError(f"Path not found for calibrator {calibrator_path}") + + # Set device. + # If set manually, use it and stay there. + # Otherwise, pick the most free GPU now and at predict time. + self._device_flexible = True + if device is not None: + self.device = device + self._device_flexible = False + else: + self.device = get_default_device() + + self.ensemble = torch.nn.ModuleList() + for path in name_or_path: + self.ensemble.append(self.load_model(path)) + self.to(self.device) + + if calibrator_path is not None: + self.calibrator = SimpleClassifierGroup.from_json_grouped(calibrator_path) + else: + self.calibrator = None + + def load_model(self, path): + """Load model from path. + + Parameters + ---------- + path : str + Path to a sybil checkpoint. 
+
+    def load_model(self, path):
+        """Load model from path.
+
+        Parameters
+        ----------
+        path : str
+            Path to a sybil checkpoint.
+
+        Returns
+        -------
+        model
+            Pretrained Sybil model
+        """
+        # Load checkpoint
+        checkpoint = torch.load(path, map_location="cpu")
+        args = checkpoint["args"]
+        self._max_followup = args.max_followup
+        self._censoring_dist = args.censoring_distribution
+        model = SybilNet(args)
+
+        # Strip the leading "model." prefix from parameter names
+        state_dict = {k[6:]: v for k, v in checkpoint["state_dict"].items()}
+        model.load_state_dict(state_dict)  # type: ignore
+        if self.device is not None:
+            model.to(self.device)
+
+        # Set eval
+        model.eval()
+        self._logger.info(f"Loaded model from {path}")
+        return model
+
+    def _calibrate(self, scores: np.ndarray) -> np.ndarray:
+        """Calibrate raw predictions.
+
+        Parameters
+        ----------
+        scores: np.ndarray
+            Risk scores as a numpy array.
+
+        Returns
+        -------
+        np.ndarray: calibrated risk scores as a numpy array.
+        """
+        if self.calibrator is None:
+            return scores
+
+        # Each follow-up year has its own calibrator, keyed "Year1", "Year2", ...
+        calibrated_scores = []
+        for year in range(scores.shape[1]):
+            probs = scores[:, year].reshape(-1, 1)
+            probs = self.calibrator["Year{}".format(year + 1)].predict_proba(probs)[:, -1]
+            calibrated_scores.append(probs)
+
+        return np.stack(calibrated_scores, axis=1)
+
+    def _predict(
+        self,
+        model: SybilNet,
+        series: Union[Serie, List[Serie]],
+        return_attentions: bool = False,
+    ) -> Prediction:
+        """Run predictions over the given serie(s) with a single model.
+
+        Parameters
+        ----------
+        model: SybilNet
+            Instance of SybilNet
+        series : Union[Serie, Iterable[Serie]]
+            One or multiple series to run predictions for.
+        return_attentions : bool
+            If True, returns attention scores for each serie. See README for details.
+
+        Returns
+        -------
+        Prediction
+            Output prediction as risk scores.
+
+        """
+        if isinstance(series, Serie):
+            series = [series]
+        elif not isinstance(series, list):
+            raise ValueError("Expected either a Serie object or list of Serie objects.")
+
+        scores: List[List[float]] = []
+        attentions: Optional[List[Dict[str, np.ndarray]]] = [] if return_attentions else None
+        for serie in series:
+            if not isinstance(serie, Serie):
+                raise ValueError("Expected a list of Serie objects.")
+
+            volume = serie.get_volume()
+            if self.device is not None:
+                volume = volume.to(self.device)
+
+            with torch.no_grad():
+                out = model(volume)
+                score = out["logit"].sigmoid().squeeze(0).cpu().numpy()
+                scores.append(score.tolist())
+                if return_attentions:
+                    attentions.append(
+                        {
+                            "image_attention_1": out["image_attention_1"].detach().cpu(),
+                            "volume_attention_1": out["volume_attention_1"].detach().cpu(),
+                            "hidden": out["hidden"].detach().cpu(),
+                        }
+                    )
+
+        return Prediction(scores=scores, attentions=attentions)
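+
+    # Ensembling sketch: predict() below averages the raw per-model scores
+    # before calibration, i.e. for ensemble members m_1..m_k and year t,
+    #
+    #   score[t] = calibrate(mean_j sigmoid(m_j(volume))[t])
+    #
+    # so averaging happens in probability space, after the sigmoid.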
+
+    def predict(
+        self, series: Union[Serie, List[Serie]], return_attentions: bool = False, threads=0,
+    ) -> Prediction:
+        """Run predictions over the given serie(s), ensembled over all models.
+
+        Parameters
+        ----------
+        series : Union[Serie, Iterable[Serie]]
+            One or multiple series to run predictions for.
+        return_attentions : bool
+            If True, returns attention scores for each serie. See README for details.
+        threads : int
+            Number of CPU threads to use for PyTorch inference.
+
+        Returns
+        -------
+        Prediction
+            Output prediction. See details for :class:`~sybil.model.Prediction`.
+
+        """
+        # Set the CPU threads available to torch
+        num_threads = _torch_set_num_threads(threads)
+        self._logger.debug(f"Using {num_threads} threads for PyTorch inference")
+
+        if self._device_flexible:
+            self.device = self._pick_device()
+            self.to(self.device)
+        self._logger.debug(f"Beginning prediction on device: {self.device}")
+
+        scores = []
+        attentions_ = [] if return_attentions else None
+        attention_keys = None
+        for sybil in self.ensemble:
+            pred = self._predict(sybil, series, return_attentions)
+            scores.append(pred.scores)
+            if return_attentions:
+                attentions_.append(pred.attentions)
+                if attention_keys is None:
+                    attention_keys = pred.attentions[0].keys()
+
+        # Average raw scores across the ensemble, then calibrate
+        scores = np.mean(np.array(scores), axis=0)
+        calib_scores = self._calibrate(scores).tolist()
+
+        attentions = None
+        if return_attentions:
+            # Stack each attention array across ensemble members, per serie
+            attentions = []
+            for i in range(len(series)):
+                att = {}
+                for key in attention_keys:
+                    att[key] = np.stack([
+                        attentions_[j][i][key] for j in range(len(self.ensemble))
+                    ])
+                attentions.append(att)
+
+        return Prediction(scores=calib_scores, attentions=attentions)
+
+    def evaluate(
+        self, series: Union[Serie, List[Serie]], return_attentions: bool = False
+    ) -> Evaluation:
+        """Run evaluation over the given serie(s).
+
+        Parameters
+        ----------
+        series : Union[Serie, List[Serie]]
+            One or multiple series to run evaluation for.
+        return_attentions : bool
+            If True, returns attention scores for each serie. See README for details.
+
+        Returns
+        -------
+        Evaluation
+            Output evaluation. See details for :class:`~sybil.model.Evaluation`.
+
+        """
+        from sybil.utils.metrics import get_survival_metrics
+
+        if isinstance(series, Serie):
+            series = [series]
+        elif not isinstance(series, list):
+            raise ValueError(
+                "Expected either a Serie object or an iterable over Serie objects."
+            )
+
+        # Check that all series have labels
+        if not all(serie.has_label() for serie in series):
+            raise ValueError("All series must have a label for evaluation")
+
+        # Get scores and labels
+        predictions = self.predict(series, return_attentions)
+        scores = predictions.scores
+        labels = [serie.get_label(self._max_followup) for serie in series]
+
+        # Convert to the format expected by the survival metrics
+        input_dict = {
+            "probs": torch.tensor(scores),
+            "censors": torch.tensor([label.censor_time for label in labels]),
+            "golds": torch.tensor([label.y for label in labels]),
+        }
+        args = Namespace(
+            max_followup=self._max_followup, censoring_distribution=self._censoring_dist
+        )
+        out = get_survival_metrics(input_dict, args)
+        auc = [float(out[f"{i + 1}_year_auc"]) for i in range(self._max_followup)]
+        c_index = float(out["c_index"])
+
+        return Evaluation(auc=auc, c_index=c_index, scores=scores, attentions=predictions.attentions)
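+
+    # Evaluation sketch (assumes every Serie was constructed with a label):
+    #
+    #   ev = model.evaluate(series)
+    #   ev.auc      # one AUC per follow-up year
+    #   ev.c_index  # concordance index across all series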
+
+    def to(self, device: str):
+        """Move the model to a device.
+
+        Parameters
+        ----------
+        device : str
+            Device to move the model to.
+        """
+        self.device = device
+        self.ensemble.to(device)
+
+    def _pick_device(self):
+        """
+        Pick the device to run inference on.
+        Chooses the device with the most free memory, with a preference for
+        remaining on the current device.
+
+        The motivation is to enable multiprocessing without the processes
+        needing to communicate.
+        """
+        if not torch.cuda.is_available():
+            return get_default_device()
+
+        # Approximate memory required by the model (9x is a rough headroom factor)
+        model_mem = 9 * sum(p.numel() * p.element_size() for p in self.ensemble.parameters())
+
+        # Check the memory available on the current device.
+        # If it seems like we're the only thing on this GPU, stay.
+        free_mem, total_mem = get_device_mem_info(self.device)
+        cur_allocated = total_mem - free_mem
+        min_to_move = int(1.01 * model_mem)
+        if cur_allocated < min_to_move:
+            return self.device
+
+        # Otherwise, move to the most free GPU
+        return get_most_free_gpu()
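+
+
+if __name__ == "__main__":
+    # Minimal smoke-test sketch. The DICOM paths below are placeholders, and
+    # the Serie constructor is assumed to take a list of DICOM file paths;
+    # adjust to match sybil.serie.Serie in this repo.
+    model = Sybil("sybil_ensemble")
+    serie = Serie(["slice_001.dcm", "slice_002.dcm"])
+    pred = model.predict([serie])
+    print(pred.scores)  # [[year-1 risk, ..., year-N risk]]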