Switch to unified view

a b/metrics/meanspecificity.py
1
import torch
2
import numpy as np
3
4
5
def get_specificity(SR, GT, threshold=0.5):
6
    """
7
    cal each img specificity
8
    所有负例中被分对的概率
9
    结节在在输入图片中所占比例较少 所以该指标的值很高
10
    """
11
    # Sensitivity == Recall
12
    SR = SR > threshold
13
    GT = GT == torch.max(GT)
14
15
    # TP : True Positive
16
    # FN : False Negative
17
    # TP = ((SR == 1) + (GT == 1)) == 2
18
    # FN = ((SR == 0) + (GT == 1)) == 2
19
    # TN : True Negative
20
    # FP : False Positive
21
    TN = ((SR == 0) + (GT == 0)) == 2
22
    FP = ((SR == 1) + (GT == 0)) == 2
23
24
    SE = float(torch.sum(TN)) / (float(torch.sum(TN + FP)) + 1e-6)
25
26
    return SE
27
28
29
def meanspecificity_seg(pred, gt, spes):
30
    """
31
    :return save img' sensitivity values in sens
32
    """
33
    gt_tensor = gt
34
    gt_tensor = gt_tensor.cpu()
35
    pred[pred < 0.5] = 0
36
    pred[pred >= 0.5] = 1
37
    pred = pred.type(torch.LongTensor)
38
    # TO CPU
39
    # pred_np = pred.data.cpu().numpy()
40
    # gt = gt.data.cpu().numpy()
41
    for x in range(pred.size()[0]):
42
        spe = get_specificity(pred[x], gt_tensor[x])
43
        spes.append(spe)
44
    return spes