a b/src/metrics/evaluation_metrics.py
1
from typing import Tuple
2
from medpy import metric
3
import numpy as np
4
5
6
def get_confusion_matrix(prediction: np.ndarray, reference: np.ndarray, roi_mask: np.ndarray) -> Tuple[int, int, int, int]:
    """
    Count true/false positives and negatives of a segmentation inside a region of interest.

    Every non-zero element of 'prediction' and 'reference' is treated as foreground,
    every zero element as background. Only elements where 'roi_mask' is non-zero are counted.

    :param prediction: predicted segmentation
    :param reference: reference (ground truth) segmentation, same shape as 'prediction'
    :param roi_mask: region-of-interest mask; must be broadcastable against 'prediction'
    :return: the counts (tp, fp, tn, fn) as plain Python ints
    :raises AssertionError: if 'prediction' and 'reference' differ in shape
    """
    assert prediction.shape == reference.shape, "'prediction' and 'reference' must have the same shape"

    predicted_fg = prediction != 0
    reference_fg = reference != 0
    # Binarise the mask: multiplying by the raw mask values would scale the counts
    # for masks that are not strictly 0/1 (e.g. stored as 0/255).
    inside_roi = roi_mask != 0

    tp = int(np.count_nonzero(inside_roi & predicted_fg & reference_fg))  # overlap
    fp = int(np.count_nonzero(inside_roi & predicted_fg & ~reference_fg))
    tn = int(np.count_nonzero(inside_roi & ~predicted_fg & ~reference_fg))  # no segmentation
    fn = int(np.count_nonzero(inside_roi & ~predicted_fg & reference_fg))

    return tp, fp, tn, fn
19
20
21
def dice(tp: int, fp:int, fn:int) -> float:
    """
    Dice similarity coefficient from confusion-matrix counts:
    2TP / (2TP + FP + FN)

    Returns 0 when the denominator is not positive, i.e. when there are
    no true positives, false positives or false negatives at all.
    """
    doubled_overlap = 2 * tp
    total = doubled_overlap + fp + fn
    return 0 if total <= 0 else doubled_overlap / total
31
32
# Hausdorff
33
def hausdorff(prediction: np.ndarray, reference: np.ndarray) -> float:
    """
    Distance between 'prediction' and 'reference' as computed by medpy's metric.hd95.

    If metric.hd95 raises, the error and the label values found in both arrays are
    printed and the fixed fallback value 373 is returned instead of propagating.
    NOTE(review): 373 is presumably a worst-case distance for the volumes this
    project evaluates - confirm before reusing on other data.
    """
    try:
        distance = metric.hd95(prediction, reference)
    except Exception as err:
        print("Error: ", err)
        predicted_labels = np.unique(prediction)
        reference_labels = np.unique(reference)
        print(f"Prediction does not contain the same label as gt. "
              f"Pred labels {predicted_labels} GT labels {reference_labels}")
        return 373

    return distance
42
43
44
# Sensitivity: recall
45
def recall(tp, fn) -> float:
    """
    Recall (sensitivity): TP / (TP + FN)

    Returns 0 when there are no positives in the reference (TP + FN is not positive).
    """
    reference_positives = tp + fn
    return 0 if reference_positives <= 0 else tp / reference_positives
51
52
# Precision (positive predictive value); not the same as specificity, which is TN / (TN + FP)
53
def precision(tp, fp) -> float:
    """
    Precision (positive predictive value): TP / (TP + FP)

    Returns 0 when nothing was predicted as positive (TP + FP is not positive).
    """
    flagged = tp + fp
    if flagged > 0:
        return tp / flagged
    if flagged <= 0:
        return 0
    # Only reachable when the sum is neither > 0 nor <= 0 (NaN); divide like any other input.
    return tp / flagged
59
60
61
def fscore(tp, fp, tn, fn, beta:int=1) -> float:
    """
    F-beta score computed from confusion-matrix counts:
    (1 + b^2) * TP / ((1 + b^2) * TP + b^2 * FN + FP)

    This equals the weighted harmonic mean of precision TP / (TP + FP) and
    recall TP / (TP + FN): (1 + b^2) * P * R / (b^2 * P + R).
    'beta' > 1 weights recall more strongly, 'beta' < 1 weights precision more strongly.

    'tn' is not used - true negatives play no part in the F-score; the parameter
    is kept so that existing callers keep working.

    Returns 0 when the score is undefined (TP, FP and FN are all zero).
    """
    assert beta > 0, "'beta' must be positive"

    beta_squared = beta * beta
    weighted_tp = (1 + beta_squared) * tp
    denominator = weighted_tp + beta_squared * fn + fp
    if denominator <= 0:
        return 0

    return weighted_tp / denominator
73
74
75
def accuracy(tp, fp, tn, fn) -> float:
    """
    Accuracy: (TP + TN) / (TP + FP + FN + TN)

    Returns 0 when the four counts do not add up to a positive total.
    """
    total = tp + fp + tn + fn
    if total <= 0:
        return 0

    correct = tp + tn
    return correct / total
80
81