Switch to unified view

a b/opengait/evaluation/metric.py
1
import torch
2
import numpy as np
3
import torch.nn.functional as F
4
5
from utils import is_tensor
6
7
8
def cuda_dist(x, y, metric='euc'):
9
    x = torch.from_numpy(x).cuda()
10
    y = torch.from_numpy(y).cuda()
11
    if metric == 'cos':
12
        x = F.normalize(x, p=2, dim=1)  # n c p
13
        y = F.normalize(y, p=2, dim=1)  # n c p
14
    num_bin = x.size(2)
15
    n_x = x.size(0)
16
    n_y = y.size(0)
17
    dist = torch.zeros(n_x, n_y).cuda()
18
    for i in range(num_bin):
19
        _x = x[:, :, i]
20
        _y = y[:, :, i]
21
        if metric == 'cos':
22
            dist += torch.matmul(_x, _y.transpose(0, 1))
23
        else:
24
            _dist = torch.sum(_x ** 2, 1).unsqueeze(1) + torch.sum(_y ** 2, 1).unsqueeze(
25
                0) - 2 * torch.matmul(_x, _y.transpose(0, 1))
26
            dist += torch.sqrt(F.relu(_dist))
27
    return 1 - dist/num_bin if metric == 'cos' else dist / num_bin
28
29
30
def mean_iou(msk1, msk2, eps=1.0e-9):
31
    if not is_tensor(msk1):
32
        msk1 = torch.from_numpy(msk1).cuda()
33
    if not is_tensor(msk2):
34
        msk2 = torch.from_numpy(msk2).cuda()
35
    n = msk1.size(0)
36
    inter = msk1 * msk2
37
    union = ((msk1 + msk2) > 0.).float()
38
    miou = inter.view(n, -1).sum(-1) / (union.view(n, -1).sum(-1) + eps)
39
    return miou
40
41
42
def compute_ACC_mAP(distmat, q_pids, g_pids, q_views=None, g_views=None, rank=1):
43
    num_q, _ = distmat.shape
44
    # indices = np.argsort(distmat, axis=1)
45
    # matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32)
46
47
    all_ACC = []
48
    all_AP = []
49
    num_valid_q = 0.  # number of valid query
50
    for q_idx in range(num_q):
51
        q_idx_dist = distmat[q_idx]
52
        q_idx_glabels = g_pids
53
        if q_views is not None and g_views is not None:
54
            q_idx_mask = np.isin(g_views, q_views[q_idx], invert=True) | np.isin(
55
                g_pids, q_pids[q_idx], invert=True)
56
            q_idx_dist = q_idx_dist[q_idx_mask]
57
            q_idx_glabels = q_idx_glabels[q_idx_mask]
58
59
        assert(len(q_idx_glabels) >
60
               0), "No gallery after excluding identical-view cases!"
61
        q_idx_indices = np.argsort(q_idx_dist)
62
        q_idx_matches = (q_idx_glabels[q_idx_indices]
63
                         == q_pids[q_idx]).astype(np.int32)
64
65
        # binary vector, positions with value 1 are correct matches
66
        # orig_cmc = matches[q_idx]
67
        orig_cmc = q_idx_matches
68
        cmc = orig_cmc.cumsum()
69
        cmc[cmc > 1] = 1
70
        all_ACC.append(cmc[rank-1])
71
72
        # compute average precision
73
        # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision
74
        num_rel = orig_cmc.sum()
75
76
        if num_rel > 0:
77
            num_valid_q += 1.
78
            tmp_cmc = orig_cmc.cumsum()
79
            tmp_cmc = [x / (i + 1.) for i, x in enumerate(tmp_cmc)]
80
            tmp_cmc = np.asarray(tmp_cmc) * orig_cmc
81
            AP = tmp_cmc.sum() / num_rel
82
            all_AP.append(AP)
83
84
    # all_ACC = np.asarray(all_ACC).astype(np.float32)
85
    ACC = np.mean(all_ACC)
86
    mAP = np.mean(all_AP)
87
88
    return ACC, mAP
89
90
91
def evaluate_rank(distmat, p_lbls, g_lbls, max_rank=50):
92
    '''
93
    Copy from https://github.com/Gait3D/Gait3D-Benchmark/blob/72beab994c137b902d826f4b9f9e95b107bebd78/lib/utils/rank.py#L12-L63
94
    '''
95
    num_p, num_g = distmat.shape
96
97
    if num_g < max_rank:
98
        max_rank = num_g
99
        print('Note: number of gallery samples is quite small, got {}'.format(num_g))
100
101
    indices = np.argsort(distmat, axis=1)
102
103
    matches = (g_lbls[indices] == p_lbls[:, np.newaxis]).astype(np.int32)
104
105
    # compute cmc curve for each probe
106
    all_cmc = []
107
    all_AP = []
108
    all_INP = []
109
    num_valid_p = 0.  # number of valid probe
110
111
    for p_idx in range(num_p):
112
        # compute cmc curve
113
        # binary vector, positions with value 1 are correct matches
114
        raw_cmc = matches[p_idx]
115
        if not np.any(raw_cmc):
116
            # this condition is true when probe identity does not appear in gallery
117
            continue
118
119
        cmc = raw_cmc.cumsum()
120
121
        pos_idx = np.where(raw_cmc == 1)    # 返回坐标,此处raw_cmc为一维矩阵,所以返回相当于index
122
        max_pos_idx = np.max(pos_idx)
123
        inp = cmc[max_pos_idx] / (max_pos_idx + 1.0)
124
        all_INP.append(inp)
125
126
        cmc[cmc > 1] = 1
127
128
        all_cmc.append(cmc[:max_rank])
129
        num_valid_p += 1.
130
131
        # compute average precision
132
        # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision
133
        num_rel = raw_cmc.sum()
134
        tmp_cmc = raw_cmc.cumsum()
135
        tmp_cmc = [x / (i + 1.) for i, x in enumerate(tmp_cmc)]
136
        tmp_cmc = np.asarray(tmp_cmc) * raw_cmc
137
        AP = tmp_cmc.sum() / num_rel
138
        all_AP.append(AP)
139
140
    assert num_valid_p > 0, 'Error: all probe identities do not appear in gallery'
141
142
    all_cmc = np.asarray(all_cmc).astype(np.float32)
143
    all_cmc = all_cmc.sum(0) / num_valid_p
144
145
    return all_cmc, all_AP, all_INP
146
147
148
def evaluate_many(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=50):
149
    num_q, num_g = distmat.shape
150
    if num_g < max_rank:
151
        max_rank = num_g
152
        print("Note: number of gallery samples is quite small, got {}".format(num_g))
153
    indices = np.argsort(distmat, axis=1)   # 对应位置变成从小到大的序号
154
    matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(
155
        np.int32)  # 根据indices调整顺序 g_pids[indices]
156
    # print(matches)
157
158
    # compute cmc curve for each query
159
    all_cmc = []
160
    all_AP = []
161
    all_INP = []
162
    num_valid_q = 0.
163
    for q_idx in range(num_q):
164
        # get query pid and camid
165
        q_pid = q_pids[q_idx]
166
        q_camid = q_camids[q_idx]
167
168
        # remove gallery samples that have the same pid and camid with query
169
        order = indices[q_idx]
170
        remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid)
171
        keep = np.invert(remove)
172
173
        # compute cmc curve
174
        # binary vector, positions with value 1 are correct matches
175
        orig_cmc = matches[q_idx][keep]
176
        if not np.any(orig_cmc):
177
            # this condition is true when query identity does not appear in gallery
178
            continue
179
180
        cmc = orig_cmc.cumsum()
181
182
        pos_idx = np.where(orig_cmc == 1)
183
        max_pos_idx = np.max(pos_idx)
184
        inp = cmc[max_pos_idx] / (max_pos_idx + 1.0)
185
        all_INP.append(inp)
186
187
        cmc[cmc > 1] = 1
188
189
        all_cmc.append(cmc[:max_rank])
190
        num_valid_q += 1.
191
192
        # compute average precision
193
        # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision
194
        num_rel = orig_cmc.sum()
195
        tmp_cmc = orig_cmc.cumsum()
196
        tmp_cmc = [x / (i+1.) for i, x in enumerate(tmp_cmc)]
197
        tmp_cmc = np.asarray(tmp_cmc) * orig_cmc
198
        AP = tmp_cmc.sum() / num_rel
199
        all_AP.append(AP)
200
201
    assert num_valid_q > 0, "Error: all query identities do not appear in gallery"
202
203
    all_cmc = np.asarray(all_cmc).astype(np.float32)
204
    all_cmc = all_cmc.sum(0) / num_valid_q
205
    mAP = np.mean(all_AP)
206
    mINP = np.mean(all_INP)
207
208
    return all_cmc, mAP, mINP