[96354c]: / src / uncertainty / test_time_dropout.py

Download this file

17 lines (9 with data), 560 Bytes

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
from tqdm import tqdm
from src.test import predict
def ttd_uncertainty_loop(model, images, device, K=2):
    """Repeat ``K`` Monte Carlo prediction passes and collect their outputs.

    Each pass calls ``predict.predict(model, images, device, monte_carlo=True)``.
    The first element it returns is turned into a label map with
    ``predict.get_prediction_map``. The second element (the score vector) is
    kept as-is. A ``tqdm`` progress bar tracks the passes.

    NOTE(review): presumably ``monte_carlo=True`` keeps dropout active at
    inference time so that the passes differ (test-time dropout). Confirm this
    in ``src.test.predict``.

    Args:
        model: Model handed unchanged to ``predict.predict``.
        images: Input batch handed unchanged to ``predict.predict``.
        device: Device handed unchanged to ``predict.predict``.
        K: Number of passes to run (default 2). A non-positive value runs
            no passes and yields two empty lists.

    Returns:
        A tuple ``(label_maps, score_vectors)``. Each element is a list with
        one entry per pass, in pass order.
    """
    label_maps = []
    score_vectors = []

    passes = tqdm(range(K), desc="Predicting..")
    for _pass_index in passes:
        # The original code named this output "four channels". It is assumed
        # to be a per-class channel map; TODO confirm the shape in predict.py.
        raw_output, pass_scores = predict.predict(
            model, images, device, monte_carlo=True
        )
        label_maps.append(predict.get_prediction_map(raw_output))
        score_vectors.append(pass_scores)

    return label_maps, score_vectors