|
a |
|
b/coderpp/clustering/utils/clustering.py |
|
|
1 |
import argparse
import os
import pickle

import numpy as np
from tqdm import tqdm
|
|
5 |
|
|
|
6 |
# reference: https://stackoverflow.com/questions/3067529/a-set-union-find-algorithm |
|
|
7 |
class DisjointSet:
    """Union-find over hashable members, tracking each group as a set.

    Attributes:
        leader: maps every member to the leader member of its group.
        group: maps each leader to the set of members in its group.
    """

    def __init__(self):
        self.leader = dict()  # maps a member to the group's leader
        self.group = dict()  # maps a group leader to the group (which is a set)

    def add(self, a, b):
        """Union the groups containing ``a`` and ``b``, creating them as needed.

        Uses union-by-size: the smaller group is merged into the larger
        one, so re-pointing members to the new leader touches fewer keys.
        """
        leadera = self.leader.get(a)
        leaderb = self.leader.get(b)
        if leadera is not None:
            if leaderb is not None:
                if leadera == leaderb:
                    return  # already in the same group; nothing to do
                groupa = self.group[leadera]
                groupb = self.group[leaderb]
                # Swap so that groupa is the larger group (merge small into big).
                if len(groupa) < len(groupb):
                    a, leadera, groupa, b, leaderb, groupb = \
                        b, leaderb, groupb, a, leadera, groupa
                groupa |= groupb
                del self.group[leaderb]
                for k in groupb:
                    self.leader[k] = leadera
            else:
                # Only a belongs to a group: absorb b into it.
                self.group[leadera].add(b)
                self.leader[b] = leadera
        else:
            if leaderb is not None:
                # Only b belongs to a group: absorb a into it.
                self.group[leaderb].add(a)
                self.leader[a] = leaderb
            else:
                # Neither is known yet: start a fresh group led by a.
                self.leader[a] = self.leader[b] = a
                self.group[a] = {a, b}
|
|
36 |
|
|
|
37 |
if __name__ == '__main__':
    # Cluster phrases by single-linkage over a precomputed k-NN similarity
    # matrix: any pair above the similarity threshold is unioned.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--use_data_dir",
        default="../use_data/",
        type=str,
        help="Directory to indices and similarity and idx2phrase"
    )
    parser.add_argument(
        "--result_dir",
        default="../result/",
        type=str,
        help="Directory to save clustering result"
    )
    parser.add_argument(
        "--similarity_threshold",
        default=0.8,
        type=float,
        help="Minimum similarity for two phrases to be clustered together"
    )
    parser.add_argument(
        "--min_phrase_length",
        default=4,
        type=int,
        help="Phrases shorter than this many characters are skipped"
    )
    args = parser.parse_args()
    args.indices_path = args.use_data_dir + 'indices.npy'
    args.similarity_path = args.use_data_dir + 'similarity.npy'
    args.idx2phrase_path = args.use_data_dir + 'idx2phrase.pkl'
    args.result_path = args.result_dir + 'clustering_result.pkl'

    # indices[i, j]: index of the j-th nearest neighbor of phrase i;
    # similarity[i, j]: the corresponding similarity score.
    indices = np.load(args.indices_path)
    similarity = np.load(args.similarity_path)
    with open(args.idx2phrase_path, 'rb') as f:
        idx2phrase = pickle.load(f)

    ds = DisjointSet()
    for idxi in tqdm(range(indices.shape[0])):
        a = idx2phrase[idxi]
        # Very short phrases are skipped (original cutoff: len <= 3).
        if len(a) < args.min_phrase_length:
            continue
        for idxj in range(indices.shape[1]):
            if similarity[idxi, idxj] > args.similarity_threshold:
                b = idx2phrase[indices[idxi][idxj]]
                if len(b) < args.min_phrase_length:
                    continue
                ds.add(a, b)

    # Ensure the output directory exists before writing the result.
    os.makedirs(args.result_dir, exist_ok=True)
    with open(args.result_path, 'wb') as f:
        pickle.dump(ds.group, f)
    print('done')