a b/sybil/model.py
1
from argparse import Namespace
2
from io import BytesIO
3
import os
4
from typing import NamedTuple, Union, Dict, List, Optional, Tuple
5
from urllib.request import urlopen
6
from zipfile import ZipFile
7
8
import torch
9
import numpy as np
10
11
from sybil.serie import Serie
12
from sybil.models.sybil import SybilNet
13
from sybil.models.calibrator import SimpleClassifierGroup
14
from sybil.utils.logging_utils import get_logger
15
from sybil.utils.device_utils import get_default_device, get_most_free_gpu, get_device_mem_info
16
17
18
# Registry of the pretrained Sybil models that can be requested by alias.
#
# Each alias maps to a dict with:
#   "checkpoint"           -- checkpoint basenames; ``download_sybil`` expects
#                             ``<cache>/<basename>.ckpt`` for each of them.
#   "google_checkpoint_id" -- legacy Google Drive file IDs of those checkpoints.
#   "google_calibrator_id" -- legacy Google Drive file ID of the calibrator.
#
# The Google Drive IDs are kept for reference only; the download code below
# does not read them (it fetches the ``CHECKPOINT_URL`` archive instead).
_SYBIL_SINGLE_MODELS = (
    # (checkpoint basename, Google Drive checkpoint ID, Google Drive calibrator ID)
    ("28a7cd44f5bcd3e6cc760b65c7e0d54d", "1ftYbav_BbUBkyR3HFCGnsp-h4uH1yhoz", "1F5TOtzueR-ZUvwl8Yv9Svs2NPP5El3HY"),
    ("56ce1a7d241dc342982f5466c4a9d7ef", "1rscGi1grSxaVGzn-tqKtuAR3ipo0DWgA", "1zKLVYBaiuMOx7p--e2zabs1LbQ-XXxcZ"),
    ("624407ef8e3a2a009f9fa51f9846fe9a", "1DV0Ge7n9r8WAvBXyoNRPwyA7VL43csAr", "1qh4nawgE2Kjf_H97XuuTpL7XUIX7JOJn"),
    ("64a91b25f84141d32852e75a3aec7305", "1Acz_yzdJMpkz3PRrjXy526CjAboMEIHX", "1QIvvCYLaesPGMEiE2Up77pKL3ygDdGU2"),
    ("65fd1f04cb4c5847d86a9ed8ba31ac1a", "1uV58SD-Qtb6xElTzWPDWWnloH1KB_zrP", "1yDq1_A5w-fSdxzq4K2YSBRNcQQkDnH0K"),
)


def _single_model_entry(index: int) -> Dict[str, Union[List[str], str]]:
    """Build the registry entry for the ``index``-th (0-based) single model."""
    checkpoint, gdrive_checkpoint, gdrive_calibrator = _SYBIL_SINGLE_MODELS[index]
    return {
        "checkpoint": [checkpoint],
        "google_checkpoint_id": [gdrive_checkpoint],
        "google_calibrator_id": gdrive_calibrator,
    }


NAME_TO_FILE = {
    # "sybil_base" is the same model as "sybil_1".
    "sybil_base": _single_model_entry(0),
    **{f"sybil_{i + 1}": _single_model_entry(i) for i in range(len(_SYBIL_SINGLE_MODELS))},
    # The ensemble uses all single-model checkpoints, in order, with its own calibrator.
    "sybil_ensemble": {
        "checkpoint": [entry[0] for entry in _SYBIL_SINGLE_MODELS],
        "google_checkpoint_id": [entry[1] for entry in _SYBIL_SINGLE_MODELS],
        "google_calibrator_id": "1FxHNo0HqXYyiUKE_k2bjatVt9e64J9Li",
    },
}

# Archive that ``download_sybil`` fetches when any expected file is missing from the cache.
# Can be overridden with the SYBIL_CHECKPOINT_URL environment variable.
CHECKPOINT_URL = os.getenv("SYBIL_CHECKPOINT_URL", "https://github.com/reginabarzilaygroup/Sybil/releases/download/v1.5.0/sybil_checkpoints.zip")
70
71
72
class Prediction(NamedTuple):
    """Risk predictions for one or more series.

    Attributes
    ----------
    scores
        One list of per-year risk scores for each input serie.
    attentions
        One dict of attention arrays per serie when attentions were requested,
        otherwise ``None``.
    """

    scores: List[List[float]]
    # Explicit Optional: a ``None`` default does not make the annotation Optional (PEP 484).
    attentions: Optional[List[Dict[str, np.ndarray]]] = None
75
76
77
class Evaluation(NamedTuple):
    """Evaluation metrics plus the predictions they were computed from.

    Attributes
    ----------
    auc
        AUC for each follow-up year, from year 1 up to the model's maximum follow-up.
    c_index
        Concordance index over all evaluated series.
    scores
        Per-year risk scores for each serie, as produced by prediction.
    attentions
        One dict of attention arrays per serie when attentions were requested,
        otherwise ``None``.
    """

    auc: List[float]
    c_index: float
    scores: List[List[float]]
    # Explicit Optional: a ``None`` default does not make the annotation Optional (PEP 484).
    attentions: Optional[List[Dict[str, np.ndarray]]] = None
82
83
84
def download_sybil(name, cache) -> Tuple[List[str], str]:
    """Make sure the checkpoints and calibrator of a named model are in the cache.

    Parameters
    ----------
    name: str
        Key of ``NAME_TO_FILE`` identifying the pretrained model.
    cache: str
        Cache directory (``~`` is expanded); created if it does not exist.

    Returns
    -------
    Tuple[List[str], str]
        Paths of the model checkpoint files and the path of the calibrator JSON.
        If any of them is missing, the ``CHECKPOINT_URL`` archive is downloaded
        and extracted into the cache directory first.
    """
    cache_dir = os.path.expanduser(cache)
    os.makedirs(cache_dir, exist_ok=True)

    checkpoint_names = NAME_TO_FILE[name]["checkpoint"]
    model_paths = [os.path.join(cache_dir, f"{ckpt}.ckpt") for ckpt in checkpoint_names]
    calibrator_path = os.path.join(cache_dir, f"{name}_simple_calibrator.json")

    # One archive holds every model, so a single missing file triggers a full download.
    if not all(os.path.exists(p) for p in [calibrator_path, *model_paths]):
        print(f"Downloading models to {cache_dir}")
        download_and_extract(CHECKPOINT_URL, cache_dir)

    return model_paths, calibrator_path
107
108
109
def download_and_extract(remote_url: str, local_dir: str) -> List[str]:
    """Download a zip archive and extract all of its members into ``local_dir``.

    The archive is streamed to a temporary file on disk instead of being
    buffered in memory, and the network response is always closed.

    Parameters
    ----------
    remote_url: str
        URL of the zip archive (any scheme ``urllib.request.urlopen`` supports).
    local_dir: str
        Destination directory; created if it does not exist.

    Returns
    -------
    List[str]
        Names of all files and directories in the archive, as reported by
        ``ZipFile.namelist()``.
    """
    import shutil
    import tempfile

    os.makedirs(local_dir, exist_ok=True)
    # Stream to a temporary file: the checkpoint archive may be large, and
    # ``BytesIO(resp.read())`` would hold all of it in memory at once.
    with urlopen(remote_url) as resp, tempfile.TemporaryFile() as archive:
        shutil.copyfileobj(resp, archive)
        archive.seek(0)
        with ZipFile(archive) as zip_file:
            all_files_and_dirs = zip_file.namelist()
            zip_file.extractall(local_dir)
    return all_files_and_dirs
116
117
118
def _torch_set_num_threads(threads) -> int:
    """
    Set the number of CPU threads for torch to use.

    Parameters
    ----------
    threads: Optional[int]
        Negative: no-op, torch's current setting is left unchanged.
        ``0`` or ``None``: use the number of CPUs, capped at 8.
        Positive: use exactly this many threads.

    Returns
    -------
    int
        The thread count torch reports after the (possible) update.
    """
    # Test for None first: ``None < 0`` raises TypeError.
    if threads is not None and threads < 0:
        return torch.get_num_threads()
    if threads is None or threads == 0:
        # I've never seen a benefit to going higher than 8 and sometimes there is a big slowdown.
        # os.cpu_count() returns None when the CPU count cannot be determined.
        threads = min(8, os.cpu_count() or 1)

    torch.set_num_threads(threads)
    return torch.get_num_threads()
132
133
134
class Sybil:
    """Inference wrapper around one or more trained SybilNet models.

    Raw scores are averaged across ensemble members and, when a calibrator is
    available, calibrated separately for each follow-up year.
    """

    def __init__(
        self,
        name_or_path: Union[List[str], str] = "sybil_ensemble",
        cache: str = "~/.sybil/",
        calibrator_path: Optional[str] = None,
        device: Optional[str] = None,
    ):
        """Initialize a trained Sybil model for inference.

        Parameters
        ----------
        name_or_path: list or str
            Alias of a provided pretrained Sybil model (a key of ``NAME_TO_FILE``),
            a path to a single sybil checkpoint, or a list of checkpoint paths
            forming an ensemble.
        cache: str
            Directory to download model checkpoints to
        calibrator_path: str
            Path to the calibrator JSON file corresponding with the model.
            Ignored when ``name_or_path`` is a pretrained alias; the downloaded
            calibrator is used instead.
        device: str
            If provided, will run inference using this device.
            By default, uses GPU with the most free memory, if available.

        Raises
        ------
        ValueError
            If a local checkpoint path or the calibrator path does not exist.
        """
        self._logger = get_logger()
        # Download if needed
        if isinstance(name_or_path, str) and (name_or_path in NAME_TO_FILE):
            name_or_path, calibrator_path = download_sybil(name_or_path, cache)
        else:
            if isinstance(name_or_path, str):
                # A single checkpoint path. Wrap it so that the existence check and
                # the loading loop below see one path rather than its characters.
                name_or_path = [name_or_path]
            missing_paths = [p for p in name_or_path if not os.path.exists(p)]
            if missing_paths:
                raise ValueError(
                    "No saved model or local path: {}".format(missing_paths)
                )

        # Check calibrator path before continuing
        if (calibrator_path is not None) and (not os.path.exists(calibrator_path)):
            raise ValueError(f"Path not found for calibrator {calibrator_path}")

        # Set device.
        # If set manually, use it and stay there.
        # Otherwise, pick the most free GPU now and at predict time.
        self._device_flexible = True
        if device is not None:
            self.device = device
            self._device_flexible = False
        else:
            self.device = get_default_device()

        self.ensemble = torch.nn.ModuleList()
        for path in name_or_path:
            self.ensemble.append(self.load_model(path))
        self.to(self.device)

        if calibrator_path is not None:
            self.calibrator = SimpleClassifierGroup.from_json_grouped(calibrator_path)
        else:
            self.calibrator = None

    def load_model(self, path):
        """Load model from path.

        Also records ``max_followup`` and ``censoring_distribution`` from the
        checkpoint's stored args. Each loaded checkpoint overwrites them, so the
        last one loaded wins (presumably all ensemble members agree; unverified).

        Parameters
        ----------
        path : str
            Path to a sybil checkpoint.

        Returns
        -------
        model
            Pretrained Sybil model, in eval mode.
        """
        # Load checkpoint
        # NOTE(review): torch.load unpickles arbitrary objects (``args`` is read as an
        # attribute object below), so only load trusted checkpoints. Newer torch releases
        # default to ``weights_only=True``, which may reject this checkpoint; confirm
        # against the pinned torch version.
        checkpoint = torch.load(path, map_location="cpu")
        args = checkpoint["args"]
        self._max_followup = args.max_followup
        self._censoring_dist = args.censoring_distribution
        model = SybilNet(args)

        # Remove model from param names: drops the first 6 characters of every key
        # (presumably a "model." prefix added at training time; confirm).
        state_dict = {k[6:]: v for k, v in checkpoint["state_dict"].items()}
        model.load_state_dict(state_dict)  # type: ignore
        if self.device is not None:
            model.to(self.device)

        # Set eval
        model.eval()
        self._logger.info(f"Loaded model from {path}")
        return model

    def _calibrate(self, scores: np.ndarray) -> np.ndarray:
        """Calibrate raw predictions.

        Parameters
        ----------
        scores: np.ndarray
            2-D array of risk scores, one row per serie and one column per year.

        Returns
        -------
        np.ndarray
            Calibrated risk scores with the same layout, or ``scores`` unchanged
            when no calibrator is loaded.
        """
        if self.calibrator is None:
            return scores

        calibrated_scores = []
        for year_idx in range(scores.shape[1]):
            probs = scores[:, year_idx].reshape(-1, 1)
            # Calibrators are keyed "Year1", "Year2", ...; keep the last predict_proba column.
            probs = self.calibrator[f"Year{year_idx + 1}"].predict_proba(probs)[:, -1]
            calibrated_scores.append(probs)

        return np.stack(calibrated_scores, axis=1)

    def _predict(
        self,
        model: SybilNet,
        series: Union[Serie, List[Serie]],
        return_attentions: bool = False,
    ) -> Prediction:
        """Run predictions over the given serie(s).

        Parameters
        ----------
        model: SybilNet
            Instance of SybilNet
        series : Union[Serie, List[Serie]]
            One or multiple series to run predictions for.
        return_attentions : bool
            If True, returns attention scores for each serie. See README for details.

        Returns
        -------
        Prediction
            Output prediction as uncalibrated (sigmoid) risk scores.

        Raises
        ------
        ValueError
            If ``series`` is neither a Serie nor a list of Serie objects.
        """
        if isinstance(series, Serie):
            series = [series]
        elif not isinstance(series, list):
            raise ValueError("Expected either a Serie object or list of Serie objects.")

        scores: List[List[float]] = []
        attentions: Optional[List[Dict[str, np.ndarray]]] = [] if return_attentions else None
        for serie in series:
            if not isinstance(serie, Serie):
                raise ValueError("Expected a list of Serie objects.")

            volume = serie.get_volume()
            if self.device is not None:
                volume = volume.to(self.device)

            with torch.no_grad():
                out = model(volume)
                score = out["logit"].sigmoid().squeeze(0).cpu().numpy()
                scores.append(score.tolist())
                if return_attentions:
                    # Convert to numpy here so the values match Prediction's declared type.
                    attentions.append(
                        {
                            key: out[key].detach().cpu().numpy()
                            for key in ("image_attention_1", "volume_attention_1", "hidden")
                        }
                    )

        return Prediction(scores=scores, attentions=attentions)

    def predict(
        self, series: Union[Serie, List[Serie]], return_attentions: bool = False, threads=0,
    ) -> Prediction:
        """Run predictions over the given serie(s) and ensemble.

        Parameters
        ----------
        series : Union[Serie, List[Serie]]
            One or multiple series to run predictions for.
        return_attentions : bool
            If True, returns attention scores for each serie. See README for details.
            Each attention array is stacked across ensemble members along a new
            leading axis.
        threads : int
            Number of CPU threads to use for PyTorch inference.
            ``0`` uses the number of CPUs (capped at 8); negative leaves it unchanged.

        Returns
        -------
        Prediction
            Output prediction. See details for :class:`~sybil.model.Prediction`.
        """
        # Normalize here too (not only in ``_predict``): ``len(series)`` is needed
        # below when assembling the per-serie attentions.
        if isinstance(series, Serie):
            series = [series]

        # Set CPU threads available to torch
        num_threads = _torch_set_num_threads(threads)
        self._logger.debug(f"Using {num_threads} threads for PyTorch inference")

        if self._device_flexible:
            self.device = self._pick_device()
            self.to(self.device)
        self._logger.debug(f"Beginning prediction on device: {self.device}")

        scores = []
        attentions_ = [] if return_attentions else None
        attention_keys = None
        for sybil in self.ensemble:
            pred = self._predict(sybil, series, return_attentions)
            scores.append(pred.scores)
            if return_attentions:
                attentions_.append(pred.attentions)
                if attention_keys is None:
                    attention_keys = pred.attentions[0].keys()

        # Average raw scores over ensemble members, then calibrate.
        scores = np.mean(np.array(scores), axis=0)
        calib_scores = self._calibrate(scores).tolist()

        attentions = None
        if return_attentions:
            # For each serie, stack every attention output across ensemble members.
            attentions = [
                {
                    key: np.stack([member_atts[i][key] for member_atts in attentions_])
                    for key in attention_keys
                }
                for i in range(len(series))
            ]

        return Prediction(scores=calib_scores, attentions=attentions)

    def evaluate(
        self, series: Union[Serie, List[Serie]], return_attentions: bool = False
    ) -> Evaluation:
        """Run evaluation over the given serie(s).

        Parameters
        ----------
        series : Union[Serie, List[Serie]]
            One or multiple series to run evaluation for.
        return_attentions : bool
            If True, returns attention scores for each serie. See README for details.

        Returns
        -------
        Evaluation
            Output evaluation. See details for :class:`~sybil.model.Evaluation`.

        Raises
        ------
        ValueError
            If ``series`` is not a Serie or list, or any serie lacks a label.
        """
        from sybil.utils.metrics import get_survival_metrics
        if isinstance(series, Serie):
            series = [series]
        elif not isinstance(series, list):
            raise ValueError(
                "Expected either a Serie object or an iterable over Serie objects."
            )

        # Check all have labels
        if not all(serie.has_label() for serie in series):
            raise ValueError("All series must have a label for evaluation")

        # Get scores and labels
        predictions = self.predict(series, return_attentions)
        scores = predictions.scores
        labels = [serie.get_label(self._max_followup) for serie in series]

        # Convert to format for survival metrics
        input_dict = {
            "probs": torch.tensor(scores),
            "censors": torch.tensor([label.censor_time for label in labels]),
            "golds": torch.tensor([label.y for label in labels]),
        }
        args = Namespace(
            max_followup=self._max_followup, censoring_distribution=self._censoring_dist
        )
        out = get_survival_metrics(input_dict, args)
        auc = [float(out[f"{i + 1}_year_auc"]) for i in range(self._max_followup)]
        c_index = float(out["c_index"])

        return Evaluation(auc=auc, c_index=c_index, scores=scores, attentions=predictions.attentions)

    def to(self, device: str):
        """Move model to device.

        Parameters
        ----------
        device : str
            Device to move model to. Also recorded as ``self.device``.
        """
        self.device = device
        self.ensemble.to(device)

    def _pick_device(self):
        """
        Pick the device to run inference on.
        This is based on the device with the most free memory, with a preference for remaining
        on the current device.

        Motivation is to enable multiprocessing without the processes needing to communicate.
        """
        if not torch.cuda.is_available():
            return get_default_device()

        # Get size of the model in memory (approximate): parameter bytes times a
        # heuristic factor of 9 (presumably headroom for activations; origin unverified).
        model_mem = 9*sum(p.numel() * p.element_size() for p in self.ensemble.parameters())

        # Check memory available on current device.
        # If it seems like we're the only thing on this GPU, stay: memory in use there is
        # below ~1.01x our estimated footprint.
        free_mem, total_mem = get_device_mem_info(self.device)
        cur_allocated = total_mem - free_mem
        min_to_move = int(1.01 * model_mem)
        if cur_allocated < min_to_move:
            return self.device
        else:
            # Otherwise, get the most free GPU
            return get_most_free_gpu()