|
a |
|
b/sybil/model.py |
|
|
1 |
from argparse import Namespace |
|
|
2 |
from io import BytesIO |
|
|
3 |
import os |
|
|
4 |
from typing import NamedTuple, Union, Dict, List, Optional, Tuple |
|
|
5 |
from urllib.request import urlopen |
|
|
6 |
from zipfile import ZipFile |
|
|
7 |
|
|
|
8 |
import torch |
|
|
9 |
import numpy as np |
|
|
10 |
|
|
|
11 |
from sybil.serie import Serie |
|
|
12 |
from sybil.models.sybil import SybilNet |
|
|
13 |
from sybil.models.calibrator import SimpleClassifierGroup |
|
|
14 |
from sybil.utils.logging_utils import get_logger |
|
|
15 |
from sybil.utils.device_utils import get_default_device, get_most_free_gpu, get_device_mem_info |
|
|
16 |
|
|
|
17 |
|
|
|
18 |
# Leaving this here for a bit; the Google IDs are what is needed to download
# the models from Google Drive.
# Each row: (alias, checkpoint file stem, Google checkpoint ID, Google calibrator ID)
_SINGLE_MODEL_FILES = (
    (
        "sybil_base",
        "28a7cd44f5bcd3e6cc760b65c7e0d54d",
        "1ftYbav_BbUBkyR3HFCGnsp-h4uH1yhoz",
        "1F5TOtzueR-ZUvwl8Yv9Svs2NPP5El3HY",
    ),
    (
        "sybil_1",
        "28a7cd44f5bcd3e6cc760b65c7e0d54d",
        "1ftYbav_BbUBkyR3HFCGnsp-h4uH1yhoz",
        "1F5TOtzueR-ZUvwl8Yv9Svs2NPP5El3HY",
    ),
    (
        "sybil_2",
        "56ce1a7d241dc342982f5466c4a9d7ef",
        "1rscGi1grSxaVGzn-tqKtuAR3ipo0DWgA",
        "1zKLVYBaiuMOx7p--e2zabs1LbQ-XXxcZ",
    ),
    (
        "sybil_3",
        "624407ef8e3a2a009f9fa51f9846fe9a",
        "1DV0Ge7n9r8WAvBXyoNRPwyA7VL43csAr",
        "1qh4nawgE2Kjf_H97XuuTpL7XUIX7JOJn",
    ),
    (
        "sybil_4",
        "64a91b25f84141d32852e75a3aec7305",
        "1Acz_yzdJMpkz3PRrjXy526CjAboMEIHX",
        "1QIvvCYLaesPGMEiE2Up77pKL3ygDdGU2",
    ),
    (
        "sybil_5",
        "65fd1f04cb4c5847d86a9ed8ba31ac1a",
        "1uV58SD-Qtb6xElTzWPDWWnloH1KB_zrP",
        "1yDq1_A5w-fSdxzq4K2YSBRNcQQkDnH0K",
    ),
)

# Maps a model alias to its checkpoint stems and Google download IDs.
# Every single-model entry owns exactly one checkpoint.
NAME_TO_FILE = {
    alias: {
        "checkpoint": [stem],
        "google_checkpoint_id": [checkpoint_id],
        "google_calibrator_id": calibrator_id,
    }
    for alias, stem, checkpoint_id, calibrator_id in _SINGLE_MODEL_FILES
}

# The ensemble bundles the five numbered models (in order) and has its own
# calibrator ID.
NAME_TO_FILE["sybil_ensemble"] = {
    "checkpoint": [
        stem for alias, stem, _, _ in _SINGLE_MODEL_FILES if alias != "sybil_base"
    ],
    "google_checkpoint_id": [
        checkpoint_id
        for alias, _, checkpoint_id, _ in _SINGLE_MODEL_FILES
        if alias != "sybil_base"
    ],
    "google_calibrator_id": "1FxHNo0HqXYyiUKE_k2bjatVt9e64J9Li",
}

# Archive fetched by download_sybil; can be overridden through the
# SYBIL_CHECKPOINT_URL environment variable.
CHECKPOINT_URL = os.environ.get(
    "SYBIL_CHECKPOINT_URL",
    "https://github.com/reginabarzilaygroup/Sybil/releases/download/v1.5.0/sybil_checkpoints.zip",
)
|
|
70 |
|
|
|
71 |
|
|
|
72 |
class Prediction(NamedTuple):
    """Result of running inference on one or more series."""

    # One list of risk scores per input serie.
    scores: List[List[float]]
    # One dict per input serie mapping an output name to its array;
    # None when attentions were not requested.
    attentions: Optional[List[Dict[str, np.ndarray]]] = None
|
|
75 |
|
|
|
76 |
|
|
|
77 |
class Evaluation(NamedTuple):
    """Result of evaluating predictions against labelled series."""

    # One AUC value per follow-up year.
    auc: List[float]
    # Concordance index reported by the survival metrics.
    c_index: float
    # One list of risk scores per input serie.
    scores: List[List[float]]
    # One dict per input serie mapping an output name to its array;
    # None when attentions were not requested.
    attentions: Optional[List[Dict[str, np.ndarray]]] = None
|
|
82 |
|
|
|
83 |
|
|
|
84 |
def download_sybil(name, cache) -> Tuple[List[str], str]:
    """Make sure the checkpoints and calibrator of a pretrained model are cached.

    Parameters
    ----------
    name : str
        Key of ``NAME_TO_FILE`` identifying the pretrained model.
    cache : str
        Directory holding the downloaded files. ``~`` is expanded and the
        directory is created when it does not exist.

    Returns
    -------
    Tuple[List[str], str]
        The expected checkpoint paths (``<cache>/<checkpoint>.ckpt``) and the
        expected calibrator path (``<cache>/<name>_simple_calibrator.json``).

    Raises
    ------
    KeyError
        If ``name`` is not a key of ``NAME_TO_FILE``.
    """
    cache = os.path.expanduser(cache)
    os.makedirs(cache, exist_ok=True)

    calibrator_file = os.path.join(cache, f"{name}_simple_calibrator.json")
    checkpoint_files = [
        os.path.join(cache, f"{stem}.ckpt")
        for stem in NAME_TO_FILE[name]["checkpoint"]
    ]

    # Any missing file triggers one download of the whole archive
    expected_files = [calibrator_file, *checkpoint_files]
    if not all(os.path.exists(path) for path in expected_files):
        print(f"Downloading models to {cache}")
        download_and_extract(CHECKPOINT_URL, cache)

    return checkpoint_files, calibrator_file
|
|
107 |
|
|
|
108 |
|
|
|
109 |
def download_and_extract(remote_url: str, local_dir: str) -> List[str]:
    """Download a zip archive and unpack it into a local directory.

    Parameters
    ----------
    remote_url : str
        URL of the zip archive.
    local_dir : str
        Directory to extract into; created when it does not exist.

    Returns
    -------
    List[str]
        Names of all archive members, as reported by ``ZipFile.namelist``.
    """
    os.makedirs(local_dir, exist_ok=True)

    # Read the payload inside a context manager so the response (and its
    # underlying connection) is always closed, even if reading fails.
    with urlopen(remote_url) as resp:
        payload = resp.read()

    with ZipFile(BytesIO(payload)) as zip_file:
        all_files_and_dirs = zip_file.namelist()
        zip_file.extractall(local_dir)
    return all_files_and_dirs
|
|
116 |
|
|
|
117 |
|
|
|
118 |
def _torch_set_num_threads(threads) -> int: |
|
|
119 |
""" |
|
|
120 |
Set the number of CPU threads for torch to use. |
|
|
121 |
Set to a negative number for no-op. |
|
|
122 |
Set to 0 for the number of CPUs. |
|
|
123 |
""" |
|
|
124 |
if threads < 0: |
|
|
125 |
return torch.get_num_threads() |
|
|
126 |
if threads is None or threads == 0: |
|
|
127 |
# I've never seen a benefit to going higher than 8 and sometimes there is a big slowdown |
|
|
128 |
threads = min(8, os.cpu_count()) |
|
|
129 |
|
|
|
130 |
torch.set_num_threads(threads) |
|
|
131 |
return torch.get_num_threads() |
|
|
132 |
|
|
|
133 |
|
|
|
134 |
class Sybil: |
|
|
135 |
def __init__(
    self,
    name_or_path: Union[List[str], str] = "sybil_ensemble",
    cache: str = "~/.sybil/",
    calibrator_path: Optional[str] = None,
    device: Optional[str] = None,
):
    """Initialize a trained Sybil model for inference.

    Parameters
    ----------
    name_or_path: list or str
        Alias to a provided pretrained Sybil model, a path to a sybil
        checkpoint, or a list of checkpoint paths.
    cache: str
        Directory to download model checkpoints to
    calibrator_path: str
        Path to the calibrator JSON file corresponding with the model.
        Replaced by the downloaded calibrator when ``name_or_path`` is
        a pretrained alias.
    device: str
        If provided, will run inference using this device.
        By default, uses GPU with the most free memory, if available.

    Raises
    ------
    ValueError
        If a local checkpoint path or the calibrator path does not exist.
    """
    self._logger = get_logger()

    if isinstance(name_or_path, str) and (name_or_path in NAME_TO_FILE):
        # Download if needed
        name_or_path, calibrator_path = download_sybil(name_or_path, cache)
    else:
        # A bare string is one checkpoint path. Wrap it so the checks and the
        # loading loop below see paths rather than single characters.
        if isinstance(name_or_path, str):
            name_or_path = [name_or_path]
        missing = [p for p in name_or_path if not os.path.exists(p)]
        if missing:
            raise ValueError("No saved model or local path: {}".format(missing))

    # Check calibrator path before continuing
    if (calibrator_path is not None) and (not os.path.exists(calibrator_path)):
        raise ValueError(f"Path not found for calibrator {calibrator_path}")

    # Set device.
    # If set manually, use it and stay there.
    # Otherwise, pick the most free GPU now and at predict time.
    self._device_flexible = True
    if device is not None:
        self.device = device
        self._device_flexible = False
    else:
        self.device = get_default_device()

    self.ensemble = torch.nn.ModuleList()
    for path in name_or_path:
        self.ensemble.append(self.load_model(path))
    self.to(self.device)

    if calibrator_path is not None:
        self.calibrator = SimpleClassifierGroup.from_json_grouped(calibrator_path)
    else:
        self.calibrator = None
|
|
193 |
|
|
|
194 |
def load_model(self, path):
    """Build a SybilNet from a checkpoint file and put it in eval mode.

    Also records ``max_followup`` and ``censoring_distribution`` from the
    checkpoint arguments on this instance.

    Parameters
    ----------
    path : str
        Path to a sybil checkpoint.

    Returns
    -------
    model
        Pretrained Sybil model
    """
    ckpt = torch.load(path, map_location="cpu")
    hparams = ckpt["args"]
    self._max_followup = hparams.max_followup
    self._censoring_dist = hparams.censoring_distribution
    net = SybilNet(hparams)

    # Drop the first six characters of every parameter name
    # (presumably a "model." prefix added at training time).
    weights = {}
    for param_name, tensor in ckpt["state_dict"].items():
        weights[param_name[6:]] = tensor
    net.load_state_dict(weights)  # type: ignore

    target_device = self.device
    if target_device is not None:
        net.to(target_device)

    net.eval()
    self._logger.info(f"Loaded model from {path}")
    return net
|
|
224 |
|
|
|
225 |
def _calibrate(self, scores: np.ndarray) -> np.ndarray: |
|
|
226 |
"""Calibrate raw predictions |
|
|
227 |
|
|
|
228 |
Parameters |
|
|
229 |
---------- |
|
|
230 |
scores: np.ndarray |
|
|
231 |
risk scores as numpy array |
|
|
232 |
|
|
|
233 |
Returns |
|
|
234 |
------- |
|
|
235 |
np.ndarray: calibrated risk scores as numpy array |
|
|
236 |
""" |
|
|
237 |
if self.calibrator is None: |
|
|
238 |
return scores |
|
|
239 |
|
|
|
240 |
calibrated_scores = [] |
|
|
241 |
for YEAR in range(scores.shape[1]): |
|
|
242 |
probs = scores[:, YEAR].reshape(-1, 1) |
|
|
243 |
probs = self.calibrator["Year{}".format(YEAR + 1)].predict_proba(probs)[:, -1] |
|
|
244 |
calibrated_scores.append(probs) |
|
|
245 |
|
|
|
246 |
return np.stack(calibrated_scores, axis=1) |
|
|
247 |
|
|
|
248 |
def _predict(
    self,
    model: SybilNet,
    series: Union[Serie, List[Serie]],
    return_attentions: bool = False,
) -> Prediction:
    """Score the given serie(s) with a single model.

    Parameters
    ----------
    model: SybilNet
        Instance of SybilNet
    series : Union[Serie, List[Serie]]
        One or multiple series to run predictions for.
    return_attentions : bool
        If True, also collects the model's attention outputs for each serie.
        See README for details.

    Returns
    -------
    Prediction
        Sigmoid of the model logits for each serie, plus the attention
        outputs when requested (``None`` otherwise).

    Raises
    ------
    ValueError
        If ``series`` is neither a Serie nor a list of Serie objects.
    """
    if isinstance(series, Serie):
        series = [series]
    elif not isinstance(series, list):
        raise ValueError("Expected either a Serie object or list of Serie objects.")

    # Model outputs that are reported back as "attentions"
    attention_names = ("image_attention_1", "volume_attention_1", "hidden")

    all_scores: List[List[float]] = []
    all_attentions: List[Dict[str, np.ndarray]] = [] if return_attentions else None
    for item in series:
        if not isinstance(item, Serie):
            raise ValueError("Expected a list of Serie objects.")

        volume = item.get_volume()
        if self.device is not None:
            volume = volume.to(self.device)

        with torch.no_grad():
            out = model(volume)
            probs = out["logit"].sigmoid().squeeze(0).cpu().numpy()
        all_scores.append(probs.tolist())

        if return_attentions:
            all_attentions.append(
                {key: out[key].detach().cpu() for key in attention_names}
            )

    return Prediction(scores=all_scores, attentions=all_attentions)
|
|
306 |
|
|
|
307 |
def predict(
    self, series: Union[Serie, List[Serie]], return_attentions: bool = False, threads=0,
) -> Prediction:
    """Run predictions over the given serie(s) and ensemble

    Scores are averaged over the ensemble members and then calibrated.

    Parameters
    ----------
    series : Union[Serie, List[Serie]]
        One or multiple series to run predictions for.
    return_attentions : bool
        If True, returns attention scores for each serie, stacked over the
        ensemble members. See README for details.
    threads : int
        Number of CPU threads to use for PyTorch inference.

    Returns
    -------
    Prediction
        Output prediction. See details for :class:`~sybil.model.Prediction`.

    """
    # Normalise a single serie to a list up front: the attention regrouping
    # below needs len(series), which fails on a bare Serie.
    if isinstance(series, Serie):
        series = [series]

    # Set CPU threads available to torch
    num_threads = _torch_set_num_threads(threads)
    self._logger.debug(f"Using {num_threads} threads for PyTorch inference")

    if self._device_flexible:
        self.device = self._pick_device()
        self.to(self.device)
    self._logger.debug(f"Beginning prediction on device: {self.device}")

    scores = []
    attentions_ = [] if return_attentions else None
    for sybil in self.ensemble:
        pred = self._predict(sybil, series, return_attentions)
        scores.append(pred.scores)
        if return_attentions:
            attentions_.append(pred.attentions)

    # Average over ensemble members, then calibrate
    scores = np.mean(np.array(scores), axis=0)
    calib_scores = self._calibrate(scores).tolist()

    attentions = None
    if return_attentions:
        # Regroup from [member][serie][key] to [serie][key] -> stacked members.
        # Keys are read per serie, so an empty series list yields an empty
        # result instead of an IndexError.
        attentions = []
        for i in range(len(series)):
            attentions.append(
                {
                    key: np.stack([member[i][key] for member in attentions_])
                    for key in attentions_[0][i].keys()
                }
            )

    return Prediction(scores=calib_scores, attentions=attentions)
|
|
363 |
|
|
|
364 |
def evaluate(
    self, series: Union[Serie, List[Serie]], return_attentions: bool = False
) -> Evaluation:
    """Predict on labelled serie(s) and compute survival metrics.

    Parameters
    ----------
    series : Union[Serie, List[Serie]]
        One or multiple series to run evaluation for.
    return_attentions : bool
        If True, returns attention scores for each serie. See README for details.

    Returns
    -------
    Evaluation
        Output evaluation. See details for :class:`~sybil.model.Evaluation`.

    Raises
    ------
    ValueError
        If ``series`` is neither a Serie nor a list, or if any serie has
        no label.
    """
    from sybil.utils.metrics import get_survival_metrics

    if isinstance(series, Serie):
        series = [series]
    elif not isinstance(series, list):
        raise ValueError(
            "Expected either a Serie object or an iterable over Serie objects."
        )

    # Every serie needs a label to be scored against
    for candidate in series:
        if not candidate.has_label():
            raise ValueError("All series must have a label for evaluation")

    predictions = self.predict(series, return_attentions)
    risk_scores = predictions.scores
    horizon = self._max_followup
    labels = [candidate.get_label(horizon) for candidate in series]

    # Pack everything in the format expected by the survival metrics
    metrics = get_survival_metrics(
        {
            "probs": torch.tensor(risk_scores),
            "censors": torch.tensor([lbl.censor_time for lbl in labels]),
            "golds": torch.tensor([lbl.y for lbl in labels]),
        },
        Namespace(
            max_followup=horizon, censoring_distribution=self._censoring_dist
        ),
    )
    yearly_auc = [float(metrics[f"{year}_year_auc"]) for year in range(1, horizon + 1)]

    return Evaluation(
        auc=yearly_auc,
        c_index=float(metrics["c_index"]),
        scores=risk_scores,
        attentions=predictions.attentions,
    )
|
|
413 |
|
|
|
414 |
def to(self, device: str):
    """Record ``device`` as the current device and move the ensemble to it.

    Parameters
    ----------
    device : str
        Device to move model to.
    """
    # Remember the target first, then relocate every ensemble member
    target = device
    self.device = target
    self.ensemble.to(target)
|
|
424 |
|
|
|
425 |
def _pick_device(self):
    """
    Choose the device to run inference on.

    Stays on the current device when the memory in use there is below
    roughly the approximate footprint of the ensemble; otherwise switches
    to the GPU with the most free memory. Without CUDA the default device
    is returned.

    Motivation is to enable multiprocessing without the processes needed to communicate.
    """
    if not torch.cuda.is_available():
        return get_default_device()

    # Approximate size of the model in memory: 9x the raw parameter bytes
    param_bytes = sum(p.numel() * p.element_size() for p in self.ensemble.parameters())
    model_mem = 9 * param_bytes

    # Memory currently in use on the device we are on
    free_mem, total_mem = get_device_mem_info(self.device)
    in_use = total_mem - free_mem
    threshold = int(1.01 * model_mem)

    # More in use than our own footprint explains: go to the most free GPU
    if in_use >= threshold:
        return get_most_free_gpu()

    # It seems like we're the only thing on this GPU, stay.
    return self.device