Diff of /medseg/uncertainty.py [000000] .. [a22922]

Switch to unified view

a b/medseg/uncertainty.py
1
import torch
2
from torch.nn import Module
3
from torch import Tensor
4
from typing import Tuple
5
from tqdm import tqdm
6
7
8
def get_epistemic_uncertainty(model: Module, x: Tensor, n: int = 10) -> Tuple[Tensor, Tensor]:
9
    '''
10
    Estimates epistemic uncertainty with n monte carlo predictions of model on x.
11
12
    Returns:
13
        standard deviation uncertainty, mean prediction
14
    '''
15
    model = model.train()
16
    with torch.no_grad():
17
        uncertain_preds = [model(x).detach().cpu() for _ in tqdm(range(n), leave=False)]
18
    model = model.eval()
19
20
    uncertain_preds_tensor = torch.stack(uncertain_preds)
21
    epistemic_uncertainty = uncertain_preds_tensor.std(dim=0)
22
    mean_prediction = uncertain_preds_tensor.mean(dim=0)
23
    
24
    return epistemic_uncertainty, mean_prediction