|
a |
|
b/metrics/meansensitivity.py |
|
|
1 |
import torch |
|
|
2 |
import numpy as np |
|
|
3 |
|
|
|
4 |
|
|
|
5 |
def get_sensitivity(SR, GT, threshold=0.5):
    """Compute the sensitivity (recall) of a single segmentation prediction.

    Sensitivity = TP / (TP + FN). A small epsilon is added to the denominator
    so that an image with no ground-truth foreground returns 0.0 instead of
    raising ZeroDivisionError.

    :param SR: predicted scores for one image (tensor); entries strictly
        greater than ``threshold`` are treated as positive.
    :param GT: ground-truth mask for the same image (tensor); entries equal to
        the mask's maximum value are treated as foreground, so both 0/1 and
        0/255 style masks work.
    :param threshold: decision threshold applied to ``SR``.
    :return: sensitivity as a Python float.
    """
    SR = SR > threshold

    gt_max = torch.max(GT)
    # An all-zero mask has no foreground. Without the ``gt_max > 0`` guard,
    # every pixel would equal the max (0) and be counted as a positive.
    GT = (GT == gt_max) & (gt_max > 0)

    # Combine the boolean masks with logical AND. The old
    # ``((SR == 1) + (GT == 1)) == 2`` idiom relied on uint8 arithmetic. On
    # bool tensors ``+`` acts as logical OR, so the comparison with 2 was
    # always False and the sensitivity was always 0.
    TP = SR & GT   # predicted positive, actually positive
    FN = ~SR & GT  # predicted negative, actually positive

    tp = float(torch.sum(TP))
    fn = float(torch.sum(FN))
    return tp / (tp + fn + 1e-6)
|
|
21 |
|
|
|
22 |
|
|
|
23 |
def meansensitivity_seg(pred, gt, sens):
    """Append the sensitivity of every image in a batch to ``sens``.

    :param pred: batch of predicted probabilities (tensor); the first
        dimension is iterated as the image index. Values >= 0.5 count as
        foreground. The tensor is NOT modified.
    :param gt: batch of ground-truth masks aligned with ``pred`` along the
        first dimension (tensor, any device; copied to CPU).
    :param sens: list that receives one float per image (mutated in place).
    :return: the same ``sens`` list.
    """
    gt_tensor = gt.cpu()
    # Binarize into a new CPU long tensor. The previous in-place assignments
    # (``pred[pred < 0.5] = 0`` ...) overwrote the caller's predictions and
    # fail on leaf tensors that require grad.
    pred = (pred >= 0.5).long().cpu()
    for x in range(pred.size(0)):
        sens.append(get_sensitivity(pred[x], gt_tensor[x]))
    return sens