Diff of /medseg/uncertainty.py [000000] .. [a22922]

Switch to side-by-side view

--- a
+++ b/medseg/uncertainty.py
@@ -0,0 +1,24 @@
+import torch
+from torch.nn import Module
+from torch import Tensor
+from typing import Tuple
+from tqdm import tqdm
+
+
+def get_epistemic_uncertainty(model: Module, x: Tensor, n: int = 10) -> Tuple[Tensor, Tensor]:
+    '''
+    Estimates epistemic uncertainty with n monte carlo predictions of model on x.
+
+    Returns:
+        standard deviation uncertainty, mean prediction
+    '''
+    model = model.train()
+    with torch.no_grad():
+        uncertain_preds = [model(x).detach().cpu() for _ in tqdm(range(n), leave=False)]
+    model = model.eval()
+
+    uncertain_preds_tensor = torch.stack(uncertain_preds)
+    epistemic_uncertainty = uncertain_preds_tensor.std(dim=0)
+    mean_prediction = uncertain_preds_tensor.mean(dim=0)
+    
+    return epistemic_uncertainty, mean_prediction