Switch to unified view

a b/metrics/meansensitivity.py
1
import torch
2
import numpy as np
3
4
5
def get_sensitivity(SR, GT, threshold=0.5):
6
    """
7
    cal each img sensitivity
8
    """
9
    # Sensitivity == Recall
10
    SR = SR > threshold
11
    GT = GT == torch.max(GT)
12
13
    # TP : True Positive
14
    # FN : False Negative
15
    TP = ((SR == 1) + (GT == 1)) == 2
16
    FN = ((SR == 0) + (GT == 1)) == 2
17
18
    SE = float(torch.sum(TP)) / (float(torch.sum(TP + FN)) + 1e-6)
19
20
    return SE
21
22
23
def meansensitivity_seg(pred, gt, sens):
24
    """
25
    :return save img' sensitivity values in sens
26
    """
27
    gt_tensor = gt
28
    gt_tensor = gt_tensor.cpu()
29
    pred[pred < 0.5] = 0
30
    pred[pred >= 0.5] = 1
31
    pred = pred.type(torch.LongTensor)
32
    # TO CPU
33
    # pred_np = pred.data.cpu().numpy()
34
    # gt = gt.data.cpu().numpy()
35
    for x in range(pred.size()[0]):
36
        sen = get_sensitivity(pred[x], gt_tensor[x])
37
        sens.append(sen)
38
    return sens