# coderpp/clustering/utils/clustering.py
1
import numpy as np
2
from tqdm import tqdm
3
import pickle
4
import argparse
5
6
# reference: https://stackoverflow.com/questions/3067529/a-set-union-find-algorithm
7
class DisjointSet:
    """Union-find over hashable items that also keeps each group's member set.

    Every member maps directly to its group's leader, so finding a member's
    group is a single dict lookup. When two groups merge, the smaller one is
    relabelled onto the larger one's leader (union by size).
    """

    def __init__(self):
        self.leader = {}  # maps a member to its group's leader
        self.group = {}   # maps a group leader to the group (a set of members)

    def add(self, a, b):
        """Put ``a`` and ``b`` into the same group.

        Unknown items are attached to an existing group, or start a new group
        led by ``a`` when neither item has been seen. If both items already
        belong to different groups, those groups are merged.
        """
        leadera = self.leader.get(a)
        leaderb = self.leader.get(b)
        if leadera is None and leaderb is None:
            # Neither item seen yet: start a new group led by ``a``.
            self.leader[a] = self.leader[b] = a
            self.group[a] = {a, b}
        elif leaderb is None:
            # Only ``a`` is known: attach ``b`` to a's group.
            self.group[leadera].add(b)
            self.leader[b] = leadera
        elif leadera is None:
            # Only ``b`` is known: attach ``a`` to b's group.
            self.group[leaderb].add(a)
            self.leader[a] = leaderb
        elif leadera != leaderb:
            groupa = self.group[leadera]
            groupb = self.group[leaderb]
            # Union by size: always relabel the smaller group, so each member
            # is relabelled at most O(log n) times overall.
            if len(groupa) < len(groupb):
                leadera, groupa, leaderb, groupb = leaderb, groupb, leadera, groupa
            groupa |= groupb
            del self.group[leaderb]
            for member in groupb:
                self.leader[member] = leadera
        # else: both already in the same group -- nothing to do.
36
37
def main(argv=None):
    """Cluster phrases by unioning neighbour pairs whose similarity clears a threshold.

    Reads ``indices.npy``, ``similarity.npy`` and ``idx2phrase.pkl`` from
    ``--use_data_dir``. For each row ``i``, every column ``j`` with
    ``similarity[i, j] > threshold`` links phrase ``idx2phrase[i]`` with phrase
    ``idx2phrase[indices[i, j]]``. Phrases shorter than ``--min_phrase_len``
    are ignored. The resulting ``{leader: set_of_members}`` mapping is pickled
    to ``<result_dir>/clustering_result.pkl``.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``.
    """
    import os

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--use_data_dir",
        default="../use_data/",
        type=str,
        help="Directory to indices and similarity and idx2phrase"
    )
    parser.add_argument(
        "--result_dir",
        default="../result/",
        type=str,
        help="Directory to save clustering result"
    )
    parser.add_argument(
        "--threshold",
        default=0.8,
        type=float,
        help="Link a pair only if its similarity is strictly greater than this",
    )
    parser.add_argument(
        "--min_phrase_len",
        default=4,
        type=int,
        help="Ignore phrases whose len() is below this value",
    )
    args = parser.parse_args(argv)
    # os.path.join so directories passed without a trailing slash still work.
    args.indices_path = os.path.join(args.use_data_dir, 'indices.npy')
    args.similarity_path = os.path.join(args.use_data_dir, 'similarity.npy')
    args.idx2phrase_path = os.path.join(args.use_data_dir, 'idx2phrase.pkl')
    args.result_path = os.path.join(args.result_dir, 'clustering_result.pkl')

    indices = np.load(args.indices_path)
    similarity = np.load(args.similarity_path)
    # NOTE: pickle.load can execute arbitrary code; only point --use_data_dir
    # at data you produced or trust.
    with open(args.idx2phrase_path, 'rb') as f:
        idx2phrase = pickle.load(f)

    n_neighbors = indices.shape[1]
    ds = DisjointSet()
    for idxi in tqdm(range(indices.shape[0])):
        a = idx2phrase[idxi]
        if len(a) < args.min_phrase_len:
            continue
        # Columns of this row above the threshold, in ascending column order
        # (the same order a per-column loop would visit them).
        hits = np.flatnonzero(similarity[idxi, :n_neighbors] > args.threshold)
        for idxj in hits:
            b = idx2phrase[indices[idxi, idxj]]
            if len(b) < args.min_phrase_len:
                continue
            ds.add(a, b)

    # Create the output directory if needed instead of failing at open().
    os.makedirs(args.result_dir, exist_ok=True)
    with open(args.result_path, 'wb') as f:
        pickle.dump(ds.group, f)
    print('done')


if __name__ == '__main__':
    main()