b/src/merge_weights.py

# -*- coding: utf-8 -*-
"""
@Time : 2023/12/12 15:09
@Auth : Juexiao Zhou
@File : merge_weights.py
@IDE : PyCharm
@Page : www.joshuachou.ink
"""

import os
import re
import torch
from tqdm.cli import tqdm

#path_70b = '/home/zhouj0d/Science/PID28.ABC/AutoBA/src/llama-main/llama-2-13b-chat/'
#path_70b = '/home/zhouj0d/Science/PID28.ABC/AutoBA/src/codellama-main/CodeLlama-13b-Instruct/'
path_70b = '/home/zhouj0d/Science/PID28.ABC/AutoBA/src/codellama-main/CodeLlama-34b-Instruct/'

# Which files are merged into one
#merge_groups = [[0,1]]
merge_groups = [[0,1,2,3]]
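
# Load every shard onto CPU, keyed by the rank parsed from its filename
# (consolidated.00.pth -> 0, consolidated.01.pth -> 1, ...)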
weights = {
    int(fn.split('.')[1]): torch.load(f'{path_70b}{fn}', map_location=torch.device('cpu'))
    for fn in tqdm(sorted(os.listdir(path_70b)))
    if fn.endswith('.pth')
}

# These tensors are duplicated rather than distributed among the files
not_distributed = {
    k
    for k in weights[0].keys()
    #if all((weights[0][k] == weights[i][k]).min() for i in range(1,2))
    if all((weights[0][k] == weights[i][k]).min() for i in range(1,4))
}

# What tensor dimensions should be merged, based on whether they are implemented
# as Embedding, Row or Column Parallel.
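# Column-parallel layers shard their weight along dim 0 (output features),
# row-parallel layers along dim 1 (input features), and the parallel token
# embedding along dim 1 (the embedding dimension), so that is the axis to
# concatenate along when undoing the split.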
merge_dimensions = {
    r'^layers.\d+.attention.wq.weight$': 0,
    r'^layers.\d+.attention.wk.weight$': 0,
    r'^layers.\d+.attention.wv.weight$': 0,
    r'^layers.\d+.attention.wo.weight$': 1,

    r'^tok_embeddings.weight$': 1,

    r'^layers.\d+.feed_forward.w1.weight$': 0,
    r'^layers.\d+.feed_forward.w2.weight$': 1,
    r'^layers.\d+.feed_forward.w3.weight$': 0,

    r'^output.weight$': 0
}

# Merging (or copying if not distributed)
output_weights = {}
for output, group in enumerate(merge_groups):
    output_weights[output] = dict()
    for name in tqdm(weights[group[0]], leave=False):
        if name in not_distributed:
            output_weights[output][name] = weights[0][name]
        else:
            axis = next(axis for exp, axis in merge_dimensions.items() if re.match(exp, name))
            output_weights[output][name] = torch.cat([
                weights[member][name]
                for member in group
            ], axis=axis)
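
# Write the merged shard next to the originals, together with a copy of params.json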
os.makedirs(f'{path_70b}/one-gpu/', exist_ok=True)
with open(f'{path_70b}/params.json') as fin:
    with open(f'{path_70b}/one-gpu/params.json', 'w') as fout:
        fout.write(fin.read())

torch.save(
    output_weights[0],
    f'{path_70b}/one-gpu/consolidated.00.pth'
)
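
# Optional sanity check (not part of the original script; a minimal sketch that
# reuses the variables defined above): reload the merged checkpoint and confirm
# that duplicated tensors are unchanged and that each distributed tensor's size
# along its merge axis equals the sum of the shard sizes.
merged = torch.load(f'{path_70b}/one-gpu/consolidated.00.pth', map_location='cpu')
assert merged.keys() == weights[0].keys()
for name, tensor in merged.items():
    if name in not_distributed:
        # Duplicated tensors are copied verbatim from shard 0
        assert tensor.shape == weights[0][name].shape
    else:
        # Distributed tensors are the per-shard pieces concatenated along `axis`
        axis = next(a for exp, a in merge_dimensions.items() if re.match(exp, name))
        assert tensor.shape[axis] == sum(weights[m][name].shape[axis] for m in merge_groups[0])
print('Merged checkpoint is consistent with the shards.')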