# sybil/model.py (revision d9566e)

from argparse import Namespace
from io import BytesIO
import os
from typing import NamedTuple, Union, Dict, List, Optional, Tuple
from urllib.request import urlopen
from zipfile import ZipFile
import torch
import numpy as np
from sybil.serie import Serie
from sybil.models.sybil import SybilNet
from sybil.models.calibrator import SimpleClassifierGroup
from sybil.utils.logging_utils import get_logger
from sybil.utils.device_utils import get_default_device, get_most_free_gpu, get_device_mem_info
# Leaving this here for a bit; these are IDs to download the models from Google Drive
NAME_TO_FILE = {
    "sybil_base": {
        "checkpoint": ["28a7cd44f5bcd3e6cc760b65c7e0d54d"],
        "google_checkpoint_id": ["1ftYbav_BbUBkyR3HFCGnsp-h4uH1yhoz"],
        "google_calibrator_id": "1F5TOtzueR-ZUvwl8Yv9Svs2NPP5El3HY",
    },
    "sybil_1": {
        "checkpoint": ["28a7cd44f5bcd3e6cc760b65c7e0d54d"],
        "google_checkpoint_id": ["1ftYbav_BbUBkyR3HFCGnsp-h4uH1yhoz"],
        "google_calibrator_id": "1F5TOtzueR-ZUvwl8Yv9Svs2NPP5El3HY",
    },
    "sybil_2": {
        "checkpoint": ["56ce1a7d241dc342982f5466c4a9d7ef"],
        "google_checkpoint_id": ["1rscGi1grSxaVGzn-tqKtuAR3ipo0DWgA"],
        "google_calibrator_id": "1zKLVYBaiuMOx7p--e2zabs1LbQ-XXxcZ",
    },
    "sybil_3": {
        "checkpoint": ["624407ef8e3a2a009f9fa51f9846fe9a"],
        "google_checkpoint_id": ["1DV0Ge7n9r8WAvBXyoNRPwyA7VL43csAr"],
        "google_calibrator_id": "1qh4nawgE2Kjf_H97XuuTpL7XUIX7JOJn",
    },
    "sybil_4": {
        "checkpoint": ["64a91b25f84141d32852e75a3aec7305"],
        "google_checkpoint_id": ["1Acz_yzdJMpkz3PRrjXy526CjAboMEIHX"],
        "google_calibrator_id": "1QIvvCYLaesPGMEiE2Up77pKL3ygDdGU2",
    },
    "sybil_5": {
        "checkpoint": ["65fd1f04cb4c5847d86a9ed8ba31ac1a"],
        "google_checkpoint_id": ["1uV58SD-Qtb6xElTzWPDWWnloH1KB_zrP"],
        "google_calibrator_id": "1yDq1_A5w-fSdxzq4K2YSBRNcQQkDnH0K",
    },
    "sybil_ensemble": {
        "checkpoint": [
            "28a7cd44f5bcd3e6cc760b65c7e0d54d",
            "56ce1a7d241dc342982f5466c4a9d7ef",
            "624407ef8e3a2a009f9fa51f9846fe9a",
            "64a91b25f84141d32852e75a3aec7305",
            "65fd1f04cb4c5847d86a9ed8ba31ac1a",
        ],
        "google_checkpoint_id": [
            "1ftYbav_BbUBkyR3HFCGnsp-h4uH1yhoz",
            "1rscGi1grSxaVGzn-tqKtuAR3ipo0DWgA",
            "1DV0Ge7n9r8WAvBXyoNRPwyA7VL43csAr",
            "1Acz_yzdJMpkz3PRrjXy526CjAboMEIHX",
            "1uV58SD-Qtb6xElTzWPDWWnloH1KB_zrP",
        ],
        "google_calibrator_id": "1FxHNo0HqXYyiUKE_k2bjatVt9e64J9Li",
    },
}

CHECKPOINT_URL = os.getenv("SYBIL_CHECKPOINT_URL", "https://github.com/reginabarzilaygroup/Sybil/releases/download/v1.5.0/sybil_checkpoints.zip")


class Prediction(NamedTuple):
    scores: List[List[float]]
    attentions: Optional[List[Dict[str, np.ndarray]]] = None


class Evaluation(NamedTuple):
    auc: List[float]
    c_index: float
    scores: List[List[float]]
    attentions: Optional[List[Dict[str, np.ndarray]]] = None


def download_sybil(name, cache) -> Tuple[List[str], str]:
    """Download trained models and calibrator."""
    # Create cache folder if it does not exist
    cache = os.path.expanduser(cache)
    os.makedirs(cache, exist_ok=True)

    # Download models
    model_files = NAME_TO_FILE[name]
    checkpoints = model_files["checkpoint"]
    download_calib_path = os.path.join(cache, f"{name}_simple_calibrator.json")
    have_all_files = os.path.exists(download_calib_path)

    download_model_paths = []
    for checkpoint in checkpoints:
        cur_checkpoint_path = os.path.join(cache, f"{checkpoint}.ckpt")
        have_all_files &= os.path.exists(cur_checkpoint_path)
        download_model_paths.append(cur_checkpoint_path)

    if not have_all_files:
        print(f"Downloading models to {cache}")
        download_and_extract(CHECKPOINT_URL, cache)

    return download_model_paths, download_calib_path


def download_and_extract(remote_url: str, local_dir: str) -> List[str]:
    os.makedirs(local_dir, exist_ok=True)
    resp = urlopen(remote_url)
    with ZipFile(BytesIO(resp.read())) as zip_file:
        all_files_and_dirs = zip_file.namelist()
        zip_file.extractall(local_dir)
    return all_files_and_dirs
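
# Example (a sketch, not a library API): resolving the "sybil_ensemble" alias
# into local checkpoint paths. Files are only fetched when a checkpoint or the
# calibrator JSON is missing from the cache.
#
#   ckpt_paths, calib_path = download_sybil("sybil_ensemble", "~/.sybil/")
#   # ckpt_paths -> ["<cache>/28a7cd44f5bcd3e6cc760b65c7e0d54d.ckpt", ...]
#   # calib_path -> "<cache>/sybil_ensemble_simple_calibrator.json"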


def _torch_set_num_threads(threads) -> int:
    """
    Set the number of CPU threads for torch to use.

    Pass a negative number for a no-op; pass 0 or None to use one thread
    per CPU core, capped at 8.
    """
    if threads is None or threads == 0:
        # Going higher than 8 threads has never shown a benefit here,
        # and sometimes causes a big slowdown
        threads = min(8, os.cpu_count())
    elif threads < 0:
        return torch.get_num_threads()
    torch.set_num_threads(threads)
    return torch.get_num_threads()
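
# For reference, given the semantics above: _torch_set_num_threads(-1) only
# reports the current setting, 0 or None picks min(8, os.cpu_count()), and any
# positive value is passed straight to torch.set_num_threads.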


class Sybil:
    def __init__(
        self,
        name_or_path: Union[List[str], str] = "sybil_ensemble",
        cache: str = "~/.sybil/",
        calibrator_path: Optional[str] = None,
        device: Optional[str] = None,
    ):
        """Initialize a trained Sybil model for inference.

        Parameters
        ----------
        name_or_path : list or str
            Alias of a provided pretrained Sybil model, or path(s)
            to sybil checkpoint(s).
        cache : str
            Directory to download model checkpoints to.
        calibrator_path : str
            Path to the calibrator JSON file corresponding to the model.
        device : str
            If provided, run inference on this device.
            By default, uses the GPU with the most free memory, if available.
        """
        self._logger = get_logger()
        # Download if needed; a single local checkpoint path is normalized to a list
        if isinstance(name_or_path, str) and (name_or_path in NAME_TO_FILE):
            name_or_path, calibrator_path = download_sybil(name_or_path, cache)
        else:
            if isinstance(name_or_path, str):
                name_or_path = [name_or_path]
            if not all(os.path.exists(p) for p in name_or_path):
                raise ValueError(
                    "No saved model or local path: {}".format(
                        [p for p in name_or_path if not os.path.exists(p)]
                    )
                )

        # Check calibrator path before continuing
        if (calibrator_path is not None) and (not os.path.exists(calibrator_path)):
            raise ValueError(f"Path not found for calibrator {calibrator_path}")

        # Set device.
        # If set manually, use it and stay there.
        # Otherwise, pick the GPU with the most free memory now and again at predict time.
        self._device_flexible = True
        if device is not None:
            self.device = device
            self._device_flexible = False
        else:
            self.device = get_default_device()

        self.ensemble = torch.nn.ModuleList()
        for path in name_or_path:
            self.ensemble.append(self.load_model(path))
        self.to(self.device)

        if calibrator_path is not None:
            self.calibrator = SimpleClassifierGroup.from_json_grouped(calibrator_path)
        else:
            self.calibrator = None
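
    # Example construction (a sketch; the aliases come from NAME_TO_FILE above,
    # and the local paths are hypothetical):
    #
    #   model = Sybil("sybil_ensemble")            # downloads and loads 5 checkpoints
    #   model = Sybil("sybil_1", device="cuda:0")  # single model, pinned to one device
    #   model = Sybil(["/path/to/model.ckpt"], calibrator_path="/path/to/calibrator.json")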

    def load_model(self, path):
        """Load model from path.

        Parameters
        ----------
        path : str
            Path to a sybil checkpoint.

        Returns
        -------
        model
            Pretrained Sybil model
        """
        # Load checkpoint
        checkpoint = torch.load(path, map_location="cpu")
        args = checkpoint["args"]
        self._max_followup = args.max_followup
        self._censoring_dist = args.censoring_distribution
        model = SybilNet(args)

        # Strip the leading "model." prefix from parameter names
        state_dict = {k[6:]: v for k, v in checkpoint["state_dict"].items()}
        model.load_state_dict(state_dict)  # type: ignore
        if self.device is not None:
            model.to(self.device)

        # Set eval mode for inference
        model.eval()
        self._logger.info(f"Loaded model from {path}")
        return model

    def _calibrate(self, scores: np.ndarray) -> np.ndarray:
        """Calibrate raw predictions.

        Parameters
        ----------
        scores : np.ndarray
            Risk scores as a numpy array.

        Returns
        -------
        np.ndarray
            Calibrated risk scores as a numpy array.
        """
        if self.calibrator is None:
            return scores

        calibrated_scores = []
        for year in range(scores.shape[1]):
            probs = scores[:, year].reshape(-1, 1)
            probs = self.calibrator["Year{}".format(year + 1)].predict_proba(probs)[:, -1]
            calibrated_scores.append(probs)
        return np.stack(calibrated_scores, axis=1)
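
    # Shape note for _calibrate: `scores` is (n_series, max_followup); each column
    # (one per follow-up year) is calibrated independently by its "Year{N}"
    # classifier, so the output keeps the input's shape.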

    def _predict(
        self,
        model: SybilNet,
        series: Union[Serie, List[Serie]],
        return_attentions: bool = False,
    ) -> Prediction:
        """Run predictions over the given serie(s) with a single model.

        Parameters
        ----------
        model : SybilNet
            Instance of SybilNet.
        series : Union[Serie, List[Serie]]
            One or multiple series to run predictions for.
        return_attentions : bool
            If True, returns attention scores for each serie. See README for details.

        Returns
        -------
        Prediction
            Output prediction as risk scores.
        """
        if isinstance(series, Serie):
            series = [series]
        elif not isinstance(series, list):
            raise ValueError("Expected either a Serie object or list of Serie objects.")

        scores: List[List[float]] = []
        attentions: Optional[List[Dict[str, np.ndarray]]] = [] if return_attentions else None
        for serie in series:
            if not isinstance(serie, Serie):
                raise ValueError("Expected a list of Serie objects.")
            volume = serie.get_volume()
            if self.device is not None:
                volume = volume.to(self.device)

            with torch.no_grad():
                out = model(volume)
                score = out["logit"].sigmoid().squeeze(0).cpu().numpy()
                scores.append(score.tolist())
                if return_attentions:
                    attentions.append(
                        {
                            "image_attention_1": out["image_attention_1"].detach().cpu(),
                            "volume_attention_1": out["volume_attention_1"].detach().cpu(),
                            "hidden": out["hidden"].detach().cpu(),
                        }
                    )

        return Prediction(scores=scores, attentions=attentions)

    def predict(
        self,
        series: Union[Serie, List[Serie]],
        return_attentions: bool = False,
        threads: int = 0,
    ) -> Prediction:
        """Run predictions over the given serie(s), averaging over the ensemble.

        Parameters
        ----------
        series : Union[Serie, List[Serie]]
            One or multiple series to run predictions for.
        return_attentions : bool
            If True, returns attention scores for each serie. See README for details.
        threads : int
            Number of CPU threads to use for PyTorch inference.

        Returns
        -------
        Prediction
            Output prediction. See details for :class:`~sybil.model.Prediction`.
        """
        # Set CPU threads available to torch
        num_threads = _torch_set_num_threads(threads)
        self._logger.debug(f"Using {num_threads} threads for PyTorch inference")

        if self._device_flexible:
            self.device = self._pick_device()
            self.to(self.device)
        self._logger.debug(f"Beginning prediction on device: {self.device}")

        scores = []
        attentions_ = [] if return_attentions else None
        attention_keys = None
        for sybil in self.ensemble:
            pred = self._predict(sybil, series, return_attentions)
            scores.append(pred.scores)
            if return_attentions:
                attentions_.append(pred.attentions)
                if attention_keys is None:
                    attention_keys = pred.attentions[0].keys()

        # Average raw scores across ensemble members, then calibrate
        scores = np.mean(np.array(scores), axis=0)
        calib_scores = self._calibrate(scores).tolist()

        # For each serie, stack its attention maps across ensemble members
        attentions = None
        if return_attentions:
            attentions = []
            for i in range(len(series)):
                att = {}
                for key in attention_keys:
                    att[key] = np.stack(
                        [attentions_[j][i][key] for j in range(len(self.ensemble))]
                    )
                attentions.append(att)

        return Prediction(scores=calib_scores, attentions=attentions)
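
    # Example usage (a sketch; `dicom_paths` is a hypothetical list of DICOM file
    # paths for one CT series, and `model` is a loaded Sybil instance):
    #
    #   serie = Serie(dicom_paths)
    #   pred = model.predict([serie])
    #   pred.scores[0]  # calibrated risk scores, one per follow-up year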

    def evaluate(
        self, series: Union[Serie, List[Serie]], return_attentions: bool = False
    ) -> Evaluation:
        """Run evaluation over the given serie(s).

        Parameters
        ----------
        series : Union[Serie, List[Serie]]
            One or multiple series to run evaluation for.
        return_attentions : bool
            If True, returns attention scores for each serie. See README for details.

        Returns
        -------
        Evaluation
            Output evaluation. See details for :class:`~sybil.model.Evaluation`.
        """
        from sybil.utils.metrics import get_survival_metrics

        if isinstance(series, Serie):
            series = [series]
        elif not isinstance(series, list):
            raise ValueError(
                "Expected either a Serie object or an iterable over Serie objects."
            )

        # Check all have labels
        if not all(serie.has_label() for serie in series):
            raise ValueError("All series must have a label for evaluation")

        # Get scores and labels
        predictions = self.predict(series, return_attentions)
        scores = predictions.scores
        labels = [serie.get_label(self._max_followup) for serie in series]

        # Convert to format for survival metrics
        input_dict = {
            "probs": torch.tensor(scores),
            "censors": torch.tensor([label.censor_time for label in labels]),
            "golds": torch.tensor([label.y for label in labels]),
        }
        args = Namespace(
            max_followup=self._max_followup, censoring_distribution=self._censoring_dist
        )
        out = get_survival_metrics(input_dict, args)
        auc = [float(out[f"{i + 1}_year_auc"]) for i in range(self._max_followup)]
        c_index = float(out["c_index"])

        return Evaluation(
            auc=auc, c_index=c_index, scores=scores, attentions=predictions.attentions
        )
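
    # Evaluation requires every serie to carry a label (how labels are attached
    # is defined by Serie and not shown here):
    #
    #   results = model.evaluate(labeled_series)  # labeled_series: List[Serie]
    #   results.auc      # one AUC per follow-up year
    #   results.c_index  # concordance index for the ensemble predictions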

    def to(self, device: str):
        """Move the model to a device.

        Parameters
        ----------
        device : str
            Device to move the model to.
        """
        self.device = device
        self.ensemble.to(device)

    def _pick_device(self):
        """
        Pick the device to run inference on.

        This is based on the device with the most free memory, with a preference
        for remaining on the current device. The motivation is to enable
        multiprocessing without the processes needing to communicate.
        """
        if not torch.cuda.is_available():
            return get_default_device()

        # Approximate size of the model in memory; the factor of 9 is presumably
        # an empirical allowance for activations and inference overhead
        model_mem = 9 * sum(p.numel() * p.element_size() for p in self.ensemble.parameters())

        # Check memory available on the current device.
        # If it seems like we're the only thing on this GPU, stay.
        free_mem, total_mem = get_device_mem_info(self.device)
        cur_allocated = total_mem - free_mem
        min_to_move = int(1.01 * model_mem)
        if cur_allocated < min_to_move:
            return self.device

        # Otherwise, move to the GPU with the most free memory
        return get_most_free_gpu()