Diff of /evaluation/metrics.py [000000] .. [92cc18]

Switch to unified view

a b/evaluation/metrics.py
1
import matplotlib.pyplot as plt
2
import torch
3
import torch.nn as nn
4
5
6
def iou(outputs: torch.Tensor, labels: torch.Tensor):
7
    SMOOTH = 1e-6
8
    outputs = (outputs.squeeze(1) > 0.5).int()
9
    labels = (labels > 0.5).int()
10
    intersection = torch.sum((outputs & labels).float())
11
    union = torch.sum((outputs | labels).float())
12
13
    iou = (intersection + SMOOTH) / (union + SMOOTH)
14
    return iou
15
16
17
def sis(new_mask, old_mask, new_seg, old_seg):
18
    def difference(mask1, mask2):
19
        return torch.round(mask1) * (1 - torch.round(mask2)) + torch.round(mask2) * (
20
                1 - torch.round(mask1))
21
22
    epsilon = 1e-5
23
    sis = torch.sum(
24
        difference(
25
            difference(new_mask, old_mask),
26
            difference(new_seg, old_seg))
27
    ) / torch.sum(torch.clamp(new_mask + old_mask + new_seg + old_seg, 0, 1) + epsilon)  # normalizing factor
28
    return sis
29
30
31
def precision(output, labels, threshold):
32
    t = (output > threshold).float()
33
    tp = torch.sum(t * labels)
34
    fp = torch.sum(t * (1 - labels))
35
    return tp / (tp + fp + 1e-5)
36
37
38
def recall(output, labels, threshold):
39
    t = (output > threshold).float()
40
    tp = torch.sum(t * labels)
41
    fn = torch.sum((1 - t) * labels)
42
    return tp / (tp + fn + 1e-5)
43
44
45
def tp_rate(output, labels, threshold):
46
    t = (output > threshold).float()
47
    tp = torch.sum(t * labels)
48
    fn = torch.sum((1 - t) * labels)
49
    return tp / (tp + fn + 1e-5)
50
51
52
def fp_rate(output, labels, threshold):
53
    t = (output > threshold).float()
54
    fp = torch.sum(t * (1 - labels))
55
    tn = torch.sum((1 - t) * (1 - labels))
56
    return fp / (fp + tn + 1e-5)
57
58
59
if __name__ == '__main__':
60
    test = torch.zeros(10, 10)
61
    test[:3, :3] = 1
62
    test2 = torch.zeros(10, 10)
63
    test2[:3, :3] = 1
64
    print(iou(test, test2))