a b/metric.py
1
import numpy as np
2
3
def compute_purity(y_pred, y_true):
    """
    Calculate the purity, a measurement of quality for the clustering
    results.

    Each cluster is assigned to the class which is most frequent in the
    cluster.  Using these classes, the percent accuracy is then calculated.

    Args:
      y_pred: 1-D array-like of predicted cluster ids, one per sample.
      y_true: 1-D array-like of ground-truth class labels, aligned with
        ``y_pred``.  Labels may be any values ``numpy.unique`` can sort
        (non-negative or negative integers, strings, ...).

    Returns:
      A number between 0 and 1.  Poor clusterings have a purity close to 0
      while a perfect clustering has a purity of 1.

    Raises:
      ValueError: if ``y_pred`` and ``y_true`` do not have the same shape.
      ZeroDivisionError: if the inputs are empty (purity is undefined).
    """
    # Coerce to arrays so plain lists/tuples work with the boolean masks below.
    y_pred = np.asarray(y_pred)
    y_true = np.asarray(y_true)
    if y_pred.shape != y_true.shape:
        raise ValueError(
            "y_pred and y_true must have the same shape, got "
            "%s and %s" % (y_pred.shape, y_true.shape)
        )

    correct = 0
    for cluster in np.unique(y_pred):
        # true labels of the samples assigned to this cluster
        cluster_labels = y_true[y_pred == cluster]

        # The size of the majority class is the number of samples counted as
        # correct.  np.unique is used rather than np.bincount so labels are
        # not restricted to non-negative integers.
        _, label_counts = np.unique(cluster_labels, return_counts=True)
        correct += label_counts.max()

    return float(correct) / len(y_pred)
32