b/metric.py

import numpy as np


def compute_purity(y_pred, y_true):
    """
    Calculate the purity, a measure of the quality of a clustering
    result.

    Each cluster is assigned to the class that is most frequent in the
    cluster. Using these assignments, the percent accuracy is then
    calculated.

    Args:
        y_pred: array of predicted cluster ids, one per sample.
        y_true: array of ground-truth class labels, one per sample.
            Labels must be non-negative integers for np.bincount.

    Returns:
        A number between 0 and 1. Poor clusterings have a purity close
        to 0, while a perfect clustering has a purity of 1.
    """
    # Accept plain lists as well as numpy arrays.
    y_pred = np.asarray(y_pred)
    y_true = np.asarray(y_true)

    # Get the set of unique cluster ids.
    clusters = set(y_pred)

    # Count the samples that agree with the majority class of their cluster.
    correct = 0
    for cluster in clusters:
        # Get the indices of the samples in this cluster.
        indices = np.where(y_pred == cluster)[0]

        # Find the most frequent true label in this cluster and count
        # how many of its samples carry that label.
        cluster_labels = y_true[indices]
        majority_label = np.argmax(np.bincount(cluster_labels))
        correct += np.sum(cluster_labels == majority_label)

    return float(correct) / len(y_pred)
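

# Example usage: a minimal, hypothetical sketch (the arrays below are
# made up for illustration and are not part of the original file).
# Cluster 0 contains one point whose true label disagrees with the
# cluster's majority class, so the purity is 6/7.
if __name__ == "__main__":
    y_pred = np.array([0, 0, 0, 1, 1, 2, 2])
    y_true = np.array([0, 0, 1, 1, 1, 2, 2])
    print(compute_purity(y_pred, y_true))  # 6/7 ≈ 0.857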