Diff of /Metrics.py [000000] .. [b52eda]

Switch to unified view

a b/Metrics.py
1
from torch import LongTensor, isin, argwhere
2
from numpy import ndarray, array, append
3
4
def __Accuracy__(truth: LongTensor, test: LongTensor, value: int) -> float:
5
    r"""
6
        Arguments:
7
            truth (torch.LongTensor): Ground truth segmentation.
8
            test (torch.LongTensor): GNN segmentation result.
9
            value (int): Value for which the accuracy will be returned.
10
        
11
        Returns:
12
            out (float): Segmentation accuracy for given value.
13
    """
14
    mask = argwhere(isin(truth, value))
15
    count = (test[mask] == value).sum().item()
16
    return count / mask.shape[0] if mask.shape[0] != 0 else -1
17
18
def __Calculate_Accuracy__(truth: LongTensor, test: LongTensor) -> ndarray:
19
    r"""
20
        Arguments:
21
            truth (torch.LongTensor): Ground truth segmentation.
22
            test (torch.LongTensor): GNN segmentation result.
23
        
24
        Returns:
25
            out (numpy.ndarray): Segmentation accuracies for all values.
26
    """
27
    out = array([])
28
    for i in range(0, 8):
29
        out = append(out, __Accuracy__(truth, test, i))
30
    return out
31
32
def __Average_Accuracy__(acc_arr: ndarray) -> float:
33
    return acc_arr[acc_arr > -1].sum() / acc_arr[acc_arr > -1].shape[0] \
34
        if acc_arr[acc_arr > -1].shape[0] != 0 else -1
35
36
def Accuracy_Util(truth: LongTensor, test: LongTensor) -> ndarray:
37
    r"""
38
        Arguments:
39
            truth (torch.LongTensor): Ground truth segmentation.
40
            test (torch.LongTensor): GNN segmentation result.
41
        
42
        Returns:
43
            out (numpy.ndarray): Accuracy info.
44
    """
45
    out = __Calculate_Accuracy__(truth, test)
46
    out = append(out, __Average_Accuracy__(out[1:]))
47
    out = append(out, __Average_Accuracy__(out))
48
    return out