coderpp/clustering/utils/ratio_cut.py
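"""Re-split oversized clusters with 2-way spectral cuts.

Loads a previous clustering result plus precomputed nearest-neighbour
indices and similarities, then recursively bipartitions every cluster
larger than MAX_CLUSTER_COUNT (ratio cut or normalized cut) until each
piece is small enough, or the two halves of a piece are too similar to
separate, judged by the dot product of their mean CODER++ embeddings
against a threshold.
"""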
import argparse
import os
import pickle

import numpy as np
import torch
from sklearn.cluster import KMeans
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer


device = torch.device("cuda:0")
batch_size = 64
MAX_CLUSTER_COUNT = 50  # clusters larger than this get re-split
mode = 'ratio'  # 'ratio' (unnormalized Laplacian) or 'normalize' (normalized cut)
model = AutoModel.from_pretrained('GanjinZero/coder_eng_pp')
tokenizer = AutoTokenizer.from_pretrained('GanjinZero/coder_eng_pp')


def load_pickle(file_path):
    with open(file_path, "rb") as f:
        return pickle.load(f)

def get_bert_embed(phrase_list, normalize=True, summary_method="MEAN", tqdm_bar=False):
    """Return the mean embedding (1-D array) over all phrases in phrase_list."""
    global model
    input_ids = []
    for phrase in phrase_list:
        input_ids.append(tokenizer.encode_plus(
            phrase, max_length=32, add_special_tokens=True,
            truncation=True, padding='max_length')['input_ids'])
    model.eval()
    model = model.to(device)

    count = len(input_ids)
    now_count = 0
    output_list = []
    with torch.no_grad():
        if tqdm_bar:
            pbar = tqdm(total=count)
        while now_count < count:
            input_gpu_0 = torch.LongTensor(input_ids[now_count:min(
                now_count + batch_size, count)]).to(device)
            if summary_method == "CLS":
                embed = model(input_gpu_0)[1]  # pooled [CLS] representation
            elif summary_method == "MEAN":
                embed = torch.mean(model(input_gpu_0)[0], dim=1)  # mean over tokens
            if normalize:
                embed_norm = torch.norm(
                    embed, p=2, dim=1, keepdim=True).clamp(min=1e-12)
                embed = embed / embed_norm
            # Accumulate batches on the GPU; every 1,000,000 phrases, flush
            # the buffer to CPU to bound GPU memory usage.
            if now_count % 1000000 == 0:
                if now_count != 0:
                    output_list.append(output.cpu().numpy())
                    del output
                    torch.cuda.empty_cache()
                output = embed
            else:
                output = torch.cat((output, embed), dim=0)
            if tqdm_bar:
                pbar.update(min(now_count + batch_size, count) - now_count)
            now_count = min(now_count + batch_size, count)
            del input_gpu_0
            torch.cuda.empty_cache()
        if tqdm_bar:
            pbar.close()
    output_list.append(output.cpu().numpy())
    del output
    torch.cuda.empty_cache()
    return np.mean(np.concatenate(output_list, axis=0), axis=0)

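# Usage sketch (hypothetical phrases; assumes a CUDA device, since `device`
# is hard-coded to cuda:0 above):
#   centroid = get_bert_embed(['heart attack', 'myocardial infarction'])
#   # centroid is a 1-D numpy array (768 dims for this BERT-base-style model)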

def re_cluster(terms_list, mode, similarity, threshold):
    """Recursively bipartition terms_list until every piece is acceptable."""
    ready = [terms_list]
    res = []
    while ready:
        now = ready.pop()
        clu0, clu1 = cut(now, mode, similarity)
        membed_0 = get_bert_embed(clu0)
        membed_1 = get_bert_embed(clu1)
        # Stop splitting when the two halves are still too similar to
        # separate, or when a half is trivially small.
        if np.dot(membed_0, membed_1) > threshold or len(clu0) <= 1 or len(clu1) <= 1:
            res.append(clu0)
            res.append(clu1)
        else:
            # Otherwise keep splitting any half that is still oversized.
            if len(clu0) <= MAX_CLUSTER_COUNT:
                res.append(clu0)
            else:
                ready.append(clu0)
            if len(clu1) <= MAX_CLUSTER_COUNT:
                res.append(clu1)
            else:
                ready.append(clu1)
    return res

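# Usage sketch (names as used in __main__ below):
#   pieces = re_cluster(list(cluster_res[key]), 'ratio', similarity, 0.6)
#   # -> list of term lists, each at most MAX_CLUSTER_COUNT long unless the
#   #    halves of a larger piece were too similar to split further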

def cut(terms_list, mode, similarity):
    # Dispatch on mode: 'ratio' -> unnormalized Laplacian cut,
    # anything else -> normalized cut.
    if mode == 'ratio':
        clu0, clu1 = ratio_cut(terms_list, similarity)
    else:
        clu0, clu1 = normalize_cut(terms_list, similarity)
    return clu0, clu1

def get_sim(terms_list, similarity):
    """Rebuild a dense pairwise similarity matrix from the kNN arrays.

    Relies on the globals `phrase2id` and `indices` loaded in __main__;
    sim[i][j] stays 0 when neither term appears in the other's
    nearest-neighbour list.
    """
    idx = [phrase2id[x] for x in terms_list]
    cnt = len(idx)
    sim = np.zeros(shape=(cnt, cnt))
    for i in range(cnt):
        for j in range(cnt):
            if idx[j] in indices[idx[i]]:
                pos = np.argwhere(indices[idx[i]] == idx[j])[0, 0]
                sim[i][j] = similarity[idx[i]][pos]
            elif idx[i] in indices[idx[j]]:
                pos = np.argwhere(indices[idx[j]] == idx[i])[0, 0]
                sim[i][j] = similarity[idx[j]][pos]
    return sim


def laplacian(matrix, normalize=False):
    # Graph Laplacian L = D - W; the normalized variant is D^-1/2 L D^-1/2.
    d_val = matrix.sum(axis=0)
    d = np.diag(d_val)
    l = d - matrix
    if normalize:
        # Guard against zero-degree nodes to avoid division by zero.
        d_inverse_root_val = np.where(d_val > 0, d_val, 1e-12) ** (-1/2)
        d_inverse_root = np.diag(d_inverse_root_val)
        l = np.matmul(np.matmul(d_inverse_root, l), d_inverse_root)
    return l

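# Worked example: for W = [[0, 1], [1, 0]] the degrees are [1, 1], so
# laplacian(W) == [[1, -1], [-1, 1]]; since D is the identity here, the
# normalized form is the same matrix.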

def _spectral_bipartition(terms_list, similarity, normalize):
    # Shared 2-way spectral cut: take the eigenvectors for the two smallest
    # eigenvalues of the (optionally normalized) Laplacian, row-normalize
    # them, and split the rows with 2-means.
    sim = get_sim(terms_list, similarity)
    l = laplacian(sim, normalize)
    u, v = np.linalg.eig(l)
    index = np.argsort(u.real)
    feat = v[:, index[0:2]].real
    feat_norm = np.linalg.norm(feat, ord=2, axis=1, keepdims=True)
    feat = feat / np.clip(feat_norm, 1e-12, None)  # guard all-zero rows
    cluster = KMeans(n_clusters=2).fit_predict(feat)
    clu0 = np.array(terms_list)[cluster == 0].tolist()
    clu1 = np.array(terms_list)[cluster == 1].tolist()
    return clu0, clu1


def ratio_cut(terms_list, similarity):
    # Ratio cut: unnormalized Laplacian.
    return _spectral_bipartition(terms_list, similarity, normalize=False)


def normalize_cut(terms_list, similarity):
    # Normalized cut: symmetrically normalized Laplacian.
    return _spectral_bipartition(terms_list, similarity, normalize=True)

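# Note: KMeans uses random initialization, so which half comes back as clu0
# can vary between runs; pass random_state to KMeans for reproducibility.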

def print_cluster_to_file(f, one_cluster_result):
    # One cluster per line, terms separated by '|'.
    f.write('|'.join(one_cluster_result) + '\n')

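# Example output line (hypothetical terms): "heart attack|myocardial infarction"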

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--use_data_dir",
        default="../use_data/",
        type=str,
        help="Directory containing the indices, similarity and idx2phrase files"
    )
    parser.add_argument(
        "--result_dir",
        default="../result/",
        type=str,
        help="Directory to save the clustering result"
    )
    args = parser.parse_args()
    args.indices_path = os.path.join(args.use_data_dir, 'indices.npy')
    args.similarity_path = os.path.join(args.use_data_dir, 'similarity.npy')
    args.idx2phrase_path = os.path.join(args.use_data_dir, 'idx2phrase.pkl')
    args.phrase2idx_path = os.path.join(args.use_data_dir, 'phrase2idx.pkl')
    args.result_path = os.path.join(args.result_dir, 'clustering_result.pkl')
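    # Usage sketch:
    #   python ratio_cut.py --use_data_dir ../use_data/ --result_dir ../result/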

    cluster_res = load_pickle(args.result_path)
    id2phrase = load_pickle(args.idx2phrase_path)
    phrase2id = load_pickle(args.phrase2idx_path)
    similarity = np.load(args.similarity_path)
    indices = np.load(args.indices_path)

    # Collect every cluster that is too large and must be re-split.
    need_cluster_set = set()
    need_cluster_length_list = []
    for key in tqdm(cluster_res):
        if len(cluster_res[key]) > MAX_CLUSTER_COUNT:
            need_cluster_set.add(key)
            need_cluster_length_list.append(len(cluster_res[key]))

    print('clusters to re-split:', len(need_cluster_set))
    print('mean size:', np.mean(need_cluster_length_list))

    threshold_list = [0.60]
    for threshold in threshold_list:
        print('threshold=', threshold)
        final_res = []
        for key in tqdm(cluster_res):
            if key not in need_cluster_set:
                final_res.append(list(cluster_res[key]))
            else:
                final_res.extend(re_cluster(list(cluster_res[key]), mode, similarity, threshold))
        # Note: the output name does not include the threshold, so with more
        # than one threshold each pass overwrites the previous file.
        with open(os.path.join(args.result_dir, 'final_cluster_res.txt'), 'w') as f:
            for cluster in tqdm(final_res):
                print_cluster_to_file(f, cluster)