|
a |
|
b/medseg/uncertainty.py |
|
|
1 |
import torch |
|
|
2 |
from torch.nn import Module |
|
|
3 |
from torch import Tensor |
|
|
4 |
from typing import Tuple |
|
|
5 |
from tqdm import tqdm |
|
|
6 |
|
|
|
7 |
|
|
|
8 |
def get_epistemic_uncertainty(model: Module, x: Tensor, n: int = 10) -> Tuple[Tensor, Tensor]: |
|
|
9 |
''' |
|
|
10 |
Estimates epistemic uncertainty with n monte carlo predictions of model on x. |
|
|
11 |
|
|
|
12 |
Returns: |
|
|
13 |
standard deviation uncertainty, mean prediction |
|
|
14 |
''' |
|
|
15 |
model = model.train() |
|
|
16 |
with torch.no_grad(): |
|
|
17 |
uncertain_preds = [model(x).detach().cpu() for _ in tqdm(range(n), leave=False)] |
|
|
18 |
model = model.eval() |
|
|
19 |
|
|
|
20 |
uncertain_preds_tensor = torch.stack(uncertain_preds) |
|
|
21 |
epistemic_uncertainty = uncertain_preds_tensor.std(dim=0) |
|
|
22 |
mean_prediction = uncertain_preds_tensor.mean(dim=0) |
|
|
23 |
|
|
|
24 |
return epistemic_uncertainty, mean_prediction |