Diff of /utils.py [000000] .. [d90d15]

Switch to unified view

a b/utils.py
1
from itertools import combinations
2
3
import numpy as np
4
import torch
5
6
7
def pdist(vectors):
8
    distance_matrix = -2 * vectors.mm(torch.t(vectors)) + vectors.pow(2).sum(dim=1).view(1, -1) + vectors.pow(2).sum(
9
        dim=1).view(-1, 1)
10
    return distance_matrix
11
12
13
class PairSelector:
14
    """
15
    Implementation should return indices of positive pairs and negative pairs that will be passed to compute
16
    Contrastive Loss
17
    return positive_pairs, negative_pairs
18
    """
19
20
    def __init__(self):
21
        pass
22
23
    def get_pairs(self, embeddings, labels):
24
        raise NotImplementedError
25
26
27
class AllPositivePairSelector(PairSelector):
28
    """
29
    Discards embeddings and generates all possible pairs given labels.
30
    If balance is True, negative pairs are a random sample to match the number of positive samples
31
    """
32
    def __init__(self, balance=True):
33
        super(AllPositivePairSelector, self).__init__()
34
        self.balance = balance
35
36
    def get_pairs(self, embeddings, labels):
37
        labels = labels.cpu().data.numpy()
38
        all_pairs = np.array(list(combinations(range(len(labels)), 2)))
39
        all_pairs = torch.LongTensor(all_pairs)
40
        positive_pairs = all_pairs[(labels[all_pairs[:, 0]] == labels[all_pairs[:, 1]]).nonzero()]
41
        negative_pairs = all_pairs[(labels[all_pairs[:, 0]] != labels[all_pairs[:, 1]]).nonzero()]
42
        if self.balance:
43
            negative_pairs = negative_pairs[torch.randperm(len(negative_pairs))[:len(positive_pairs)]]
44
45
        return positive_pairs, negative_pairs
46
47
48
class HardNegativePairSelector(PairSelector):
49
    """
50
    Creates all possible positive pairs. For negative pairs, pairs with smallest distance are taken into consideration,
51
    matching the number of positive pairs.
52
    """
53
54
    def __init__(self, cpu=True):
55
        super(HardNegativePairSelector, self).__init__()
56
        self.cpu = cpu
57
58
    def get_pairs(self, embeddings, labels):
59
        if self.cpu:
60
            embeddings = embeddings.cpu()
61
        distance_matrix = pdist(embeddings)
62
63
        labels = labels.cpu().data.numpy()
64
        all_pairs = np.array(list(combinations(range(len(labels)), 2)))
65
        all_pairs = torch.LongTensor(all_pairs)
66
        positive_pairs = all_pairs[(labels[all_pairs[:, 0]] == labels[all_pairs[:, 1]]).nonzero()]
67
        negative_pairs = all_pairs[(labels[all_pairs[:, 0]] != labels[all_pairs[:, 1]]).nonzero()]
68
69
        negative_distances = distance_matrix[negative_pairs[:, 0], negative_pairs[:, 1]]
70
        negative_distances = negative_distances.cpu().data.numpy()
71
        top_negatives = np.argpartition(negative_distances, len(positive_pairs))[:len(positive_pairs)]
72
        top_negative_pairs = negative_pairs[torch.LongTensor(top_negatives)]
73
74
        return positive_pairs, top_negative_pairs
75
76
77
class TripletSelector:
78
    """
79
    Implementation should return indices of anchors, positive and negative samples
80
    return np array of shape [N_triplets x 3]
81
    """
82
83
    def __init__(self):
84
        pass
85
86
    def get_pairs(self, embeddings, labels):
87
        raise NotImplementedError
88
89
90
class AllTripletSelector(TripletSelector):
91
    """
92
    Returns all possible triplets
93
    May be impractical in most cases
94
    """
95
96
    def __init__(self):
97
        super(AllTripletSelector, self).__init__()
98
99
    def get_triplets(self, embeddings, labels):
100
        labels = labels.cpu().data.numpy()
101
        triplets = []
102
        for label in set(labels):
103
            label_mask = (labels == label)
104
            label_indices = np.where(label_mask)[0]
105
            if len(label_indices) < 2:
106
                continue
107
            negative_indices = np.where(np.logical_not(label_mask))[0]
108
            anchor_positives = list(combinations(label_indices, 2))  # All anchor-positive pairs
109
110
            # Add all negatives for all positive pairs
111
            temp_triplets = [[anchor_positive[0], anchor_positive[1], neg_ind] for anchor_positive in anchor_positives
112
                             for neg_ind in negative_indices]
113
            triplets += temp_triplets
114
115
        return torch.LongTensor(np.array(triplets))
116
117
118
def hardest_negative(loss_values):
119
    hard_negative = np.argmax(loss_values)
120
    return hard_negative if loss_values[hard_negative] > 0 else None
121
122
123
def random_hard_negative(loss_values):
124
    hard_negatives = np.where(loss_values > 0)[0]
125
    return np.random.choice(hard_negatives) if len(hard_negatives) > 0 else None
126
127
128
def semihard_negative(loss_values, margin):
129
    semihard_negatives = np.where(np.logical_and(loss_values < margin, loss_values > 0))[0]
130
    return np.random.choice(semihard_negatives) if len(semihard_negatives) > 0 else None
131
132
133
class FunctionNegativeTripletSelector(TripletSelector):
134
    """
135
    For each positive pair, takes the hardest negative sample (with the greatest triplet loss value) to create a triplet
136
    Margin should match the margin used in triplet loss.
137
    negative_selection_fn should take array of loss_values for a given anchor-positive pair and all negative samples
138
    and return a negative index for that pair
139
    """
140
141
    def __init__(self, margin, negative_selection_fn, cpu=True):
142
        super(FunctionNegativeTripletSelector, self).__init__()
143
        self.cpu = cpu
144
        self.margin = margin
145
        self.negative_selection_fn = negative_selection_fn
146
147
    def get_triplets(self, embeddings, labels):
148
        if self.cpu:
149
            embeddings = embeddings.cpu()
150
        distance_matrix = pdist(embeddings)
151
        distance_matrix = distance_matrix.cpu()
152
153
        labels = labels.cpu().data.numpy()
154
        triplets = []
155
156
        for label in set(labels):
157
            label_mask = (labels == label)
158
            label_indices = np.where(label_mask)[0]
159
            if len(label_indices) < 2:
160
                continue
161
            negative_indices = np.where(np.logical_not(label_mask))[0]
162
            anchor_positives = list(combinations(label_indices, 2))  # All anchor-positive pairs
163
            anchor_positives = np.array(anchor_positives)
164
165
            ap_distances = distance_matrix[anchor_positives[:, 0], anchor_positives[:, 1]]
166
            for anchor_positive, ap_distance in zip(anchor_positives, ap_distances):
167
                loss_values = ap_distance - distance_matrix[torch.LongTensor(np.array([anchor_positive[0]])), torch.LongTensor(negative_indices)] + self.margin
168
                loss_values = loss_values.data.cpu().numpy()
169
                hard_negative = self.negative_selection_fn(loss_values)
170
                if hard_negative is not None:
171
                    hard_negative = negative_indices[hard_negative]
172
                    triplets.append([anchor_positive[0], anchor_positive[1], hard_negative])
173
174
        if len(triplets) == 0:
175
            triplets.append([anchor_positive[0], anchor_positive[1], negative_indices[0]])
176
177
        triplets = np.array(triplets)
178
179
        return torch.LongTensor(triplets)
180
181
182
def HardestNegativeTripletSelector(margin, cpu=False): return FunctionNegativeTripletSelector(margin=margin,
183
                                                                                 negative_selection_fn=hardest_negative,
184
                                                                                 cpu=cpu)
185
186
187
def RandomNegativeTripletSelector(margin, cpu=False): return FunctionNegativeTripletSelector(margin=margin,
188
                                                                                negative_selection_fn=random_hard_negative,
189
                                                                                cpu=cpu)
190
191
192
def SemihardNegativeTripletSelector(margin, cpu=False): return FunctionNegativeTripletSelector(margin=margin,
193
                                                                                  negative_selection_fn=lambda x: semihard_negative(x, margin),
194
                                                                                  cpu=cpu)