# coderpp/clustering/utils/ratio_cut.py
import pickle
import os
import numpy as np
from tqdm import tqdm
from sklearn.cluster import KMeans
import torch
from transformers import AutoTokenizer, AutoModel
import argparse
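
# Post-processing step for CODER++ clustering: clusters larger than
# MAX_CLUSTER_COUNT are recursively bipartitioned with spectral cuts
# (ratio cut or normalized cut) until every piece is small enough or
# its two halves are semantically close in CODER++ embedding space.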

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
batch_size = 64
MAX_CLUSTER_COUNT = 50
mode = 'ratio'  # 'ratio' -> ratio cut; anything else -> normalized cut
model = AutoModel.from_pretrained('GanjinZero/coder_eng_pp')
tokenizer = AutoTokenizer.from_pretrained('GanjinZero/coder_eng_pp')


def load_pickle(file_path):
    with open(file_path, "rb") as f:
        return pickle.load(f)


def get_bert_embed(phrase_list, normalize=True, summary_method="MEAN", tqdm_bar=False):
    """Embed phrases with CODER++ and return the mean of their embeddings.

    Assumes `phrase_list` is non-empty.
    """
    global model
    input_ids = []
    for phrase in phrase_list:
        input_ids.append(tokenizer.encode_plus(
            phrase, max_length=32, add_special_tokens=True,
            truncation=True, padding='max_length')['input_ids'])
    model.eval()
    model = model.to(device)

    count = len(input_ids)
    now_count = 0
    output_list = []
    with torch.no_grad():
        if tqdm_bar:
            pbar = tqdm(total=count)
        while now_count < count:
            input_gpu_0 = torch.LongTensor(input_ids[now_count:min(
                now_count + batch_size, count)]).to(device)
            if summary_method == "CLS":
                embed = model(input_gpu_0)[1]
            elif summary_method == "MEAN":
                embed = torch.mean(model(input_gpu_0)[0], dim=1)
            if normalize:
                embed_norm = torch.norm(
                    embed, p=2, dim=1, keepdim=True).clamp(min=1e-12)
                embed = embed / embed_norm
            # Periodically flush the accumulated tensor to CPU so GPU memory
            # stays bounded on very large phrase lists.
            if now_count % 1000000 == 0:
                if now_count != 0:
                    output_list.append(output.cpu().numpy())
                    del output
                    torch.cuda.empty_cache()
                output = embed
            else:
                output = torch.cat((output, embed), dim=0)
            if tqdm_bar:
                pbar.update(min(now_count + batch_size, count) - now_count)
            now_count = min(now_count + batch_size, count)
            del input_gpu_0
            torch.cuda.empty_cache()
        if tqdm_bar:
            pbar.close()
    output_list.append(output.cpu().numpy())
    del output
    torch.cuda.empty_cache()
    return np.mean(np.concatenate(output_list, axis=0), axis=0)
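
# Usage sketch (hypothetical phrases): the function returns a single vector,
# the mean of the (optionally L2-normalized) per-phrase embeddings, e.g.
#   centroid = get_bert_embed(["heart attack", "myocardial infarction"])
# so np.dot between two such centroids measures inter-cluster similarity,
# which is exactly how re_cluster() uses it below.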
|
|
def re_cluster(terms_list, mode, similarity, threshold):
    """Recursively bipartition an oversized cluster with spectral cuts,
    using mean CODER++ embeddings to decide when to stop splitting."""
    ready = [terms_list]
    res = []
    while ready:
        now = ready.pop()
        clu0, clu1 = cut(now, mode, similarity)
        membed_0 = get_bert_embed(clu0)
        membed_1 = get_bert_embed(clu1)
        if np.dot(membed_0, membed_1) > threshold or len(clu0) <= 1 or len(clu1) <= 1:
            res.append(clu0)
            res.append(clu1)
        else:
            if len(clu0) <= MAX_CLUSTER_COUNT:
                res.append(clu0)
            else:
                ready.append(clu0)
            if len(clu1) <= MAX_CLUSTER_COUNT:
                res.append(clu1)
            else:
                ready.append(clu1)
    return res
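
# Note: when the two halves of a split are still similar (dot product above
# `threshold`), both are kept as final clusters without further recursion,
# so such clusters may remain larger than MAX_CLUSTER_COUNT by design.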
|
|
def cut(terms_list, mode, similarity):
    if mode == 'ratio':
        return ratio_cut(terms_list, similarity)
    return normalize_cut(terms_list, similarity)
|
|
def get_sim(terms_list, similarity):
    """Build a dense pairwise similarity matrix for `terms_list` from the
    precomputed kNN arrays (module-level `phrase2id` and `indices`)."""
    idx = [phrase2id[x] for x in terms_list]
    cnt = len(idx)
    sim = np.zeros(shape=(cnt, cnt))
    for i in range(cnt):
        for j in range(cnt):
            if idx[j] in indices[idx[i]]:
                sim[i][j] = similarity[idx[i]][np.argwhere(indices[idx[i]] == idx[j])]
            elif idx[i] in indices[idx[j]]:
                sim[i][j] = similarity[idx[j]][np.argwhere(indices[idx[j]] == idx[i])]
    return sim
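
# Note: `indices[i]` stores only the top-k neighbor ids of phrase i, so a
# similarity may be recorded in one direction only; the two branches above
# recover it from whichever side stored it. If both directions were stored
# with different scores, the resulting matrix can be slightly asymmetric.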
|
|
def laplacian(matrix, normalize=False):
    """Graph Laplacian L = D - A of a similarity matrix; if `normalize`,
    the symmetric normalized form D^(-1/2) L D^(-1/2). Assumes every node
    has nonzero degree, otherwise the normalization divides by zero."""
    d_val = matrix.sum(axis=0)
    d = np.diag(d_val)
    l = d - matrix
    if normalize:
        d_inverse_root = np.diag(d_val ** (-1 / 2))
        l = np.matmul(np.matmul(d_inverse_root, l), d_inverse_root)
    return l
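
# Toy check (a minimal sketch): for the path graph a-b-c with adjacency
#   A = [[0,1,0],[1,0,1],[0,1,0]] and degrees d = [1,2,1],
#   L = D - A = [[1,-1,0],[-1,2,-1],[0,-1,1]].
# Every row of L sums to zero, and the smallest eigenvalue is 0 with the
# all-ones eigenvector, which is why the second-smallest eigenvectors
# carry the partition information used below.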
|
|
def ratio_cut(terms_list, similarity):
    """Bipartition via spectral clustering on the unnormalized Laplacian."""
    sim = get_sim(terms_list, similarity)
    l = laplacian(sim)
    u, v = np.linalg.eig(l)
    index = np.argsort(u.real)
    # Spectral embedding: eigenvectors of the two smallest eigenvalues,
    # row-normalized before k-means.
    feat = v[:, index[0:2]].real
    feat_norm = np.linalg.norm(feat, ord=2, axis=1, keepdims=True)
    feat = feat / feat_norm
    cluster = KMeans(n_clusters=2).fit_predict(feat)
    clu0 = np.array(terms_list)[cluster == 0].tolist()
    clu1 = np.array(terms_list)[cluster == 1].tolist()
    return clu0, clu1
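
# The eigenvectors of the two smallest eigenvalues form the standard spectral
# embedding for a 2-way partition; row-normalizing and running k-means on it
# recovers the two sides of the (approximate) minimum cut.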
|
|
def normalize_cut(terms_list, similarity):
    """Bipartition via spectral clustering on the normalized Laplacian."""
    sim = get_sim(terms_list, similarity)
    l = laplacian(sim, normalize=True)
    u, v = np.linalg.eig(l)
    index = np.argsort(u.real)
    feat = v[:, index[0:2]].real
    feat_norm = np.linalg.norm(feat, ord=2, axis=1, keepdims=True)
    feat = feat / feat_norm
    cluster = KMeans(n_clusters=2).fit_predict(feat)
    clu0 = np.array(terms_list)[cluster == 0].tolist()
    clu1 = np.array(terms_list)[cluster == 1].tolist()
    return clu0, clu1
|
|
def print_cluster_to_file(f, one_cluster_result):
    """Write one cluster as a single '|'-separated line."""
    f.write('|'.join(one_cluster_result) + '\n')
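
# Each output line is one cluster, written as '|'-separated terms, e.g.
#   term_a|term_b|term_c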
|
|
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--use_data_dir",
        default="../use_data/",
        type=str,
        help="Directory containing the indices, similarity, idx2phrase and phrase2idx files"
    )
    parser.add_argument(
        "--result_dir",
        default="../result/",
        type=str,
        help="Directory to save the clustering result"
    )
    args = parser.parse_args()
    args.indices_path = os.path.join(args.use_data_dir, 'indices.npy')
    args.similarity_path = os.path.join(args.use_data_dir, 'similarity.npy')
    args.idx2phrase_path = os.path.join(args.use_data_dir, 'idx2phrase.pkl')
    args.result_path = os.path.join(args.result_dir, 'clustering_result.pkl')
    args.phrase2idx_path = os.path.join(args.use_data_dir, 'phrase2idx.pkl')

    cluster_res = load_pickle(args.result_path)
    id2phrase = load_pickle(args.idx2phrase_path)
    phrase2id = load_pickle(args.phrase2idx_path)
    similarity = np.load(args.similarity_path)
    indices = np.load(args.indices_path)
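
    # Expected inputs, as used below (shapes inferred from this script, not
    # verified against the data-generation code): `similarity.npy` and
    # `indices.npy` are aligned (num_phrases, k) arrays of top-k neighbor
    # scores and ids, the two .pkl files map between phrase strings and ids,
    # and `clustering_result.pkl` maps a cluster key to its phrases.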
|
|
    # Collect the clusters that are too large and need re-clustering.
    need_cluster_set = set()
    need_cluster_length_list = []
    for key in tqdm(cluster_res):
        if len(cluster_res[key]) > MAX_CLUSTER_COUNT:
            need_cluster_set.add(key)
            need_cluster_length_list.append(len(cluster_res[key]))

    print(len(need_cluster_set))
    print(np.mean(need_cluster_length_list))
|
|
    threshold_list = [0.60]
    for threshold in threshold_list:
        print('threshold=', threshold)
        final_res = []
        for key in tqdm(cluster_res):
            if key not in need_cluster_set:
                final_res.append(list(cluster_res[key]))
            else:
                re_cluster_list = re_cluster(list(cluster_res[key]), mode, similarity, threshold)
                for cluster in re_cluster_list:
                    final_res.append(cluster)
        with open(os.path.join(args.result_dir, 'final_cluster_res.txt'), 'w') as f:
            for cluster in tqdm(final_res):
                print_cluster_to_file(f, cluster)