Diff of /utils/segment/metrics.py [000000] .. [190ca4]

Switch to unified view

a b/utils/segment/metrics.py
1
# YOLOv5 🚀 by Ultralytics, AGPL-3.0 license
2
"""
3
Model validation metrics
4
"""
5
6
import numpy as np
7
8
from ..metrics import ap_per_class
9
10
11
def fitness(x):
12
    # Model fitness as a weighted combination of metrics
13
    w = [0.0, 0.0, 0.1, 0.9, 0.0, 0.0, 0.1, 0.9]
14
    return (x[:, :8] * w).sum(1)
15
16
17
def ap_per_class_box_and_mask(
18
        tp_m,
19
        tp_b,
20
        conf,
21
        pred_cls,
22
        target_cls,
23
        plot=False,
24
        save_dir='.',
25
        names=(),
26
):
27
    """
28
    Args:
29
        tp_b: tp of boxes.
30
        tp_m: tp of masks.
31
        other arguments see `func: ap_per_class`.
32
    """
33
    results_boxes = ap_per_class(tp_b,
34
                                 conf,
35
                                 pred_cls,
36
                                 target_cls,
37
                                 plot=plot,
38
                                 save_dir=save_dir,
39
                                 names=names,
40
                                 prefix='Box')[2:]
41
    results_masks = ap_per_class(tp_m,
42
                                 conf,
43
                                 pred_cls,
44
                                 target_cls,
45
                                 plot=plot,
46
                                 save_dir=save_dir,
47
                                 names=names,
48
                                 prefix='Mask')[2:]
49
50
    results = {
51
        'boxes': {
52
            'p': results_boxes[0],
53
            'r': results_boxes[1],
54
            'ap': results_boxes[3],
55
            'f1': results_boxes[2],
56
            'ap_class': results_boxes[4]},
57
        'masks': {
58
            'p': results_masks[0],
59
            'r': results_masks[1],
60
            'ap': results_masks[3],
61
            'f1': results_masks[2],
62
            'ap_class': results_masks[4]}}
63
    return results
64
65
66
class Metric:
67
68
    def __init__(self) -> None:
69
        self.p = []  # (nc, )
70
        self.r = []  # (nc, )
71
        self.f1 = []  # (nc, )
72
        self.all_ap = []  # (nc, 10)
73
        self.ap_class_index = []  # (nc, )
74
75
    @property
76
    def ap50(self):
77
        """AP@0.5 of all classes.
78
        Return:
79
            (nc, ) or [].
80
        """
81
        return self.all_ap[:, 0] if len(self.all_ap) else []
82
83
    @property
84
    def ap(self):
85
        """AP@0.5:0.95
86
        Return:
87
            (nc, ) or [].
88
        """
89
        return self.all_ap.mean(1) if len(self.all_ap) else []
90
91
    @property
92
    def mp(self):
93
        """mean precision of all classes.
94
        Return:
95
            float.
96
        """
97
        return self.p.mean() if len(self.p) else 0.0
98
99
    @property
100
    def mr(self):
101
        """mean recall of all classes.
102
        Return:
103
            float.
104
        """
105
        return self.r.mean() if len(self.r) else 0.0
106
107
    @property
108
    def map50(self):
109
        """Mean AP@0.5 of all classes.
110
        Return:
111
            float.
112
        """
113
        return self.all_ap[:, 0].mean() if len(self.all_ap) else 0.0
114
115
    @property
116
    def map(self):
117
        """Mean AP@0.5:0.95 of all classes.
118
        Return:
119
            float.
120
        """
121
        return self.all_ap.mean() if len(self.all_ap) else 0.0
122
123
    def mean_results(self):
124
        """Mean of results, return mp, mr, map50, map"""
125
        return (self.mp, self.mr, self.map50, self.map)
126
127
    def class_result(self, i):
128
        """class-aware result, return p[i], r[i], ap50[i], ap[i]"""
129
        return (self.p[i], self.r[i], self.ap50[i], self.ap[i])
130
131
    def get_maps(self, nc):
132
        maps = np.zeros(nc) + self.map
133
        for i, c in enumerate(self.ap_class_index):
134
            maps[c] = self.ap[i]
135
        return maps
136
137
    def update(self, results):
138
        """
139
        Args:
140
            results: tuple(p, r, ap, f1, ap_class)
141
        """
142
        p, r, all_ap, f1, ap_class_index = results
143
        self.p = p
144
        self.r = r
145
        self.all_ap = all_ap
146
        self.f1 = f1
147
        self.ap_class_index = ap_class_index
148
149
150
class Metrics:
151
    """Metric for boxes and masks."""
152
153
    def __init__(self) -> None:
154
        self.metric_box = Metric()
155
        self.metric_mask = Metric()
156
157
    def update(self, results):
158
        """
159
        Args:
160
            results: Dict{'boxes': Dict{}, 'masks': Dict{}}
161
        """
162
        self.metric_box.update(list(results['boxes'].values()))
163
        self.metric_mask.update(list(results['masks'].values()))
164
165
    def mean_results(self):
166
        return self.metric_box.mean_results() + self.metric_mask.mean_results()
167
168
    def class_result(self, i):
169
        return self.metric_box.class_result(i) + self.metric_mask.class_result(i)
170
171
    def get_maps(self, nc):
172
        return self.metric_box.get_maps(nc) + self.metric_mask.get_maps(nc)
173
174
    @property
175
    def ap_class_index(self):
176
        # boxes and masks have the same ap_class_index
177
        return self.metric_box.ap_class_index
178
179
180
KEYS = [
181
    'train/box_loss',
182
    'train/seg_loss',  # train loss
183
    'train/obj_loss',
184
    'train/cls_loss',
185
    'metrics/precision(B)',
186
    'metrics/recall(B)',
187
    'metrics/mAP_0.5(B)',
188
    'metrics/mAP_0.5:0.95(B)',  # metrics
189
    'metrics/precision(M)',
190
    'metrics/recall(M)',
191
    'metrics/mAP_0.5(M)',
192
    'metrics/mAP_0.5:0.95(M)',  # metrics
193
    'val/box_loss',
194
    'val/seg_loss',  # val loss
195
    'val/obj_loss',
196
    'val/cls_loss',
197
    'x/lr0',
198
    'x/lr1',
199
    'x/lr2', ]
200
201
BEST_KEYS = [
202
    'best/epoch',
203
    'best/precision(B)',
204
    'best/recall(B)',
205
    'best/mAP_0.5(B)',
206
    'best/mAP_0.5:0.95(B)',
207
    'best/precision(M)',
208
    'best/recall(M)',
209
    'best/mAP_0.5(M)',
210
    'best/mAP_0.5:0.95(M)', ]