Diff of /Metrics.py [000000] .. [b52eda]

Switch to side-by-side view

--- a
+++ b/Metrics.py
@@ -0,0 +1,48 @@
+from torch import LongTensor, isin, argwhere
+from numpy import ndarray, array, append
+
+def __Accuracy__(truth: LongTensor, test: LongTensor, value: int) -> float:
+    r"""
+        Arguments:
+            truth (torch.LongTensor): Ground truth segmentation.
+            test (torch.LongTensor): GNN segmentation result.
+            value (int): Value for which the accuracy will be returned.
+        
+        Returns:
+            out (float): Segmentation accuracy for given value.
+    """
+    mask = argwhere(isin(truth, value))
+    count = (test[mask] == value).sum().item()
+    return count / mask.shape[0] if mask.shape[0] != 0 else -1
+
+def __Calculate_Accuracy__(truth: LongTensor, test: LongTensor) -> ndarray:
+    r"""
+        Arguments:
+            truth (torch.LongTensor): Ground truth segmentation.
+            test (torch.LongTensor): GNN segmentation result.
+        
+        Returns:
+            out (numpy.ndarray): Segmentation accuracies for all values.
+    """
+    out = array([])
+    for i in range(0, 8):
+        out = append(out, __Accuracy__(truth, test, i))
+    return out
+
+def __Average_Accuracy__(acc_arr: ndarray) -> float:
+    return acc_arr[acc_arr > -1].sum() / acc_arr[acc_arr > -1].shape[0] \
+        if acc_arr[acc_arr > -1].shape[0] != 0 else -1
+
+def Accuracy_Util(truth: LongTensor, test: LongTensor) -> ndarray:
+    r"""
+        Arguments:
+            truth (torch.LongTensor): Ground truth segmentation.
+            test (torch.LongTensor): GNN segmentation result.
+        
+        Returns:
+            out (numpy.ndarray): Accuracy info.
+    """
+    out = __Calculate_Accuracy__(truth, test)
+    out = append(out, __Average_Accuracy__(out[1:]))
+    out = append(out, __Average_Accuracy__(out))
+    return out
\ No newline at end of file