Diff of /EGFR/utils.py [000000] .. [d90d15]

Switch to side-by-side view

--- a
+++ b/EGFR/utils.py
@@ -0,0 +1,194 @@
+from itertools import combinations
+
+import numpy as np
+import torch
+
+
+def pdist(vectors):
+    distance_matrix = -2 * vectors.mm(torch.t(vectors)) + vectors.pow(2).sum(dim=1).view(1, -1) + vectors.pow(2).sum(
+        dim=1).view(-1, 1)
+    return distance_matrix
+
+
+class PairSelector:
+    """
+    Implementation should return indices of positive pairs and negative pairs that will be passed to compute
+    Contrastive Loss
+    return positive_pairs, negative_pairs
+    """
+
+    def __init__(self):
+        pass
+
+    def get_pairs(self, embeddings, labels):
+        raise NotImplementedError
+
+
+class AllPositivePairSelector(PairSelector):
+    """
+    Discards embeddings and generates all possible pairs given labels.
+    If balance is True, negative pairs are a random sample to match the number of positive samples
+    """
+    def __init__(self, balance=True):
+        super(AllPositivePairSelector, self).__init__()
+        self.balance = balance
+
+    def get_pairs(self, embeddings, labels):
+        labels = labels.cpu().data.numpy()
+        all_pairs = np.array(list(combinations(range(len(labels)), 2)))
+        all_pairs = torch.LongTensor(all_pairs)
+        positive_pairs = all_pairs[(labels[all_pairs[:, 0]] == labels[all_pairs[:, 1]]).nonzero()]
+        negative_pairs = all_pairs[(labels[all_pairs[:, 0]] != labels[all_pairs[:, 1]]).nonzero()]
+        if self.balance:
+            negative_pairs = negative_pairs[torch.randperm(len(negative_pairs))[:len(positive_pairs)]]
+
+        return positive_pairs, negative_pairs
+
+
+class HardNegativePairSelector(PairSelector):
+    """
+    Creates all possible positive pairs. For negative pairs, pairs with smallest distance are taken into consideration,
+    matching the number of positive pairs.
+    """
+
+    def __init__(self, cpu=True):
+        super(HardNegativePairSelector, self).__init__()
+        self.cpu = cpu
+
+    def get_pairs(self, embeddings, labels):
+        if self.cpu:
+            embeddings = embeddings.cpu()
+        distance_matrix = pdist(embeddings)
+
+        labels = labels.cpu().data.numpy()
+        all_pairs = np.array(list(combinations(range(len(labels)), 2)))
+        all_pairs = torch.LongTensor(all_pairs)
+        positive_pairs = all_pairs[(labels[all_pairs[:, 0]] == labels[all_pairs[:, 1]]).nonzero()]
+        negative_pairs = all_pairs[(labels[all_pairs[:, 0]] != labels[all_pairs[:, 1]]).nonzero()]
+
+        negative_distances = distance_matrix[negative_pairs[:, 0], negative_pairs[:, 1]]
+        negative_distances = negative_distances.cpu().data.numpy()
+        top_negatives = np.argpartition(negative_distances, len(positive_pairs))[:len(positive_pairs)]
+        top_negative_pairs = negative_pairs[torch.LongTensor(top_negatives)]
+
+        return positive_pairs, top_negative_pairs
+
+
+class TripletSelector:
+    """
+    Implementation should return indices of anchors, positive and negative samples
+    return np array of shape [N_triplets x 3]
+    """
+
+    def __init__(self):
+        pass
+
+    def get_pairs(self, embeddings, labels):
+        raise NotImplementedError
+
+
+class AllTripletSelector(TripletSelector):
+    """
+    Returns all possible triplets
+    May be impractical in most cases
+    """
+
+    def __init__(self):
+        super(AllTripletSelector, self).__init__()
+
+    def get_triplets(self, embeddings, labels):
+        labels = labels.cpu().data.numpy()
+        triplets = []
+        for label in set(labels):
+            label_mask = (labels == label)
+            label_indices = np.where(label_mask)[0]
+            if len(label_indices) < 2:
+                continue
+            negative_indices = np.where(np.logical_not(label_mask))[0]
+            anchor_positives = list(combinations(label_indices, 2))  # All anchor-positive pairs
+
+            # Add all negatives for all positive pairs
+            temp_triplets = [[anchor_positive[0], anchor_positive[1], neg_ind] for anchor_positive in anchor_positives
+                             for neg_ind in negative_indices]
+            triplets += temp_triplets
+
+        return torch.LongTensor(np.array(triplets))
+
+
+def hardest_negative(loss_values):
+    hard_negative = np.argmax(loss_values)
+    return hard_negative if loss_values[hard_negative] > 0 else None
+
+
+def random_hard_negative(loss_values):
+    hard_negatives = np.where(loss_values > 0)[0]
+    return np.random.choice(hard_negatives) if len(hard_negatives) > 0 else None
+
+
+def semihard_negative(loss_values, margin):
+    semihard_negatives = np.where(np.logical_and(loss_values < margin, loss_values > 0))[0]
+    return np.random.choice(semihard_negatives) if len(semihard_negatives) > 0 else None
+
+
+class FunctionNegativeTripletSelector(TripletSelector):
+    """
+    For each positive pair, takes the hardest negative sample (with the greatest triplet loss value) to create a triplet
+    Margin should match the margin used in triplet loss.
+    negative_selection_fn should take array of loss_values for a given anchor-positive pair and all negative samples
+    and return a negative index for that pair
+    """
+
+    def __init__(self, margin, negative_selection_fn, cpu=True):
+        super(FunctionNegativeTripletSelector, self).__init__()
+        self.cpu = cpu
+        self.margin = margin
+        self.negative_selection_fn = negative_selection_fn
+
+    def get_triplets(self, embeddings, labels):
+        if self.cpu:
+            embeddings = embeddings.cpu()
+        distance_matrix = pdist(embeddings)
+        distance_matrix = distance_matrix.cpu()
+
+        labels = labels.cpu().data.numpy()
+        triplets = []
+
+        for label in set(labels):
+            label_mask = (labels == label)
+            label_indices = np.where(label_mask)[0]
+            if len(label_indices) < 2:
+                continue
+            negative_indices = np.where(np.logical_not(label_mask))[0]
+            anchor_positives = list(combinations(label_indices, 2))  # All anchor-positive pairs
+            anchor_positives = np.array(anchor_positives)
+
+            ap_distances = distance_matrix[anchor_positives[:, 0], anchor_positives[:, 1]]
+            for anchor_positive, ap_distance in zip(anchor_positives, ap_distances):
+                loss_values = ap_distance - distance_matrix[torch.LongTensor(np.array([anchor_positive[0]])), torch.LongTensor(negative_indices)] + self.margin
+                loss_values = loss_values.data.cpu().numpy()
+                hard_negative = self.negative_selection_fn(loss_values)
+                if hard_negative is not None:
+                    hard_negative = negative_indices[hard_negative]
+                    triplets.append([anchor_positive[0], anchor_positive[1], hard_negative])
+
+        if len(triplets) == 0:
+            triplets.append([anchor_positive[0], anchor_positive[1], negative_indices[0]])
+
+        triplets = np.array(triplets)
+
+        return torch.LongTensor(triplets)
+
+
+def HardestNegativeTripletSelector(margin, cpu=False): return FunctionNegativeTripletSelector(margin=margin,
+                                                                                 negative_selection_fn=hardest_negative,
+                                                                                 cpu=cpu)
+
+
+def RandomNegativeTripletSelector(margin, cpu=False): return FunctionNegativeTripletSelector(margin=margin,
+                                                                                negative_selection_fn=random_hard_negative,
+                                                                                cpu=cpu)
+
+
+def SemihardNegativeTripletSelector(margin, cpu=False): return FunctionNegativeTripletSelector(margin=margin,
+                                                                                  negative_selection_fn=lambda x: semihard_negative(x, margin),
+                                                                                  cpu=cpu)
\ No newline at end of file