Switch to unified view

a b/src/extensions/torchmetrics.py
1
# Base Dependencies
2
# -----------------
3
import numpy as np
4
from typing import Optional
5
6
# 3rd-Party Dependencies
7
# -----------------
8
import torch
9
from torch import Tensor
10
from torchmetrics import Metric
11
from torchmetrics.utilities.checks import _input_format_classification
12
from sklearn.metrics import precision_score, recall_score, f1_score
13
14
15
def _make_binary(preds: torch.Tensor, target: torch.Tensor):
16
17
    # obtain decimal values from one-hot encoding
18
    preds2 = preds.argmax(axis=1).int()
19
    target2 = target.argmax(axis=1).int()
20
    
21
    # replace positive classes by 1
22
    preds2[preds2 != 0] = 1
23
    target2[target2 != 0] = 1
24
25
    return preds2, target2
26
27
28
class DetectionF1Score(Metric):
29
    def __init__(self, ) -> None:
30
        super().__init__()
31
        self.add_state("y_true", default=torch.Tensor([]).int())
32
        self.add_state("y_pred", default=torch.Tensor([]).int())
33
34
    def update(self, preds: torch.Tensor, target: torch.Tensor): 
35
        preds, target, datatype = _input_format_classification(preds, target)
36
        p, t = _make_binary(preds, target)
37
        self.y_pred = torch.cat((self.y_pred, p))
38
        self.y_true = torch.cat([self.y_true, t])
39
40
    def compute(self) -> torch.Tensor:
41
        """Computes f-beta over state."""
42
        score =  f1_score(y_true=self.y_true.cpu().numpy(), y_pred=self.y_pred.cpu().numpy(), average="binary")
43
        return torch.tensor(score)
44
45
46
class DetectionPrecision(Metric):
47
    def __init__(self, ) -> None:
48
        super().__init__()
49
        self.add_state("y_true", default=torch.Tensor([]).int())
50
        self.add_state("y_pred", default=torch.Tensor([]).int())
51
52
    def update(self, preds: torch.Tensor, target: torch.Tensor): 
53
        preds, target, datatype = _input_format_classification(preds, target)
54
        p, t = _make_binary(preds, target)
55
        self.y_pred = torch.cat((self.y_pred, p))
56
        self.y_true = torch.cat([self.y_true, t])
57
58
    def compute(self) -> torch.Tensor: 
59
        score = precision_score(y_true=self.y_true.cpu().numpy(), y_pred=self.y_pred.cpu().numpy(), average="binary")
60
        return torch.tensor(score)
61
62
63
class DetectionRecall(Metric):
64
    def __init__(self, ) -> None:
65
        super().__init__()
66
        self.add_state("y_true", default=torch.Tensor([]).int())
67
        self.add_state("y_pred", default=torch.Tensor([]).int())
68
69
    def update(self, preds: torch.Tensor, target: torch.Tensor): 
70
        preds, target, datatype = _input_format_classification(preds, target)
71
        p, t = _make_binary(preds, target)
72
        self.y_pred = torch.cat((self.y_pred, p))
73
        self.y_true = torch.cat([self.y_true, t])
74
75
    def compute(self) -> torch.Tensor:
76
        score = recall_score(y_true=self.y_true.cpu().numpy(), y_pred=self.y_pred.cpu().numpy(), average="binary")
77
        return torch.tensor(score)