# src/merge_weights.py — file-viewer chrome (header, download link, gutter line numbers) removed
# -*- coding: utf-8 -*-
"""
@Time : 2023/12/12 15:09
@Auth : Juexiao Zhou
@File :merge_weights.py
@IDE :PyCharm
@Page: www.joshuachou.ink

Merge tensor-parallel checkpoint shards (consolidated.NN.pth) into fewer
files so the model can be loaded with less model parallelism (e.g. one GPU).
Each entry of ``merge_groups`` lists the shard indices concatenated into one
output checkpoint.
"""
import os
import re
import torch
from tqdm.cli import tqdm

# Directory holding the sharded checkpoint files; must end with '/'.
#path_70b = '/home/zhouj0d/Science/PID28.ABC/AutoBA/src/llama-main/llama-2-13b-chat/'
#path_70b = '/home/zhouj0d/Science/PID28.ABC/AutoBA/src/codellama-main/CodeLlama-13b-Instruct/'
path_70b = '/home/zhouj0d/Science/PID28.ABC/AutoBA/src/codellama-main/CodeLlama-34b-Instruct/'
# Which shard files are merged into one output file.
#merge_groups = [[0,1]]
merge_groups = [[0,1,2,3]]
# Load every *.pth shard onto CPU, keyed by the index embedded in its
# filename (consolidated.<idx>.pth -> idx).
weights = {
int(fn.split('.')[1]): torch.load(f'{path_70b}{fn}', map_location=torch.device('cpu'))
for fn in tqdm(sorted(os.listdir(path_70b)))
if fn.endswith('.pth')
}
# These tensors are duplicated in every shard rather than split among them,
# so they are copied verbatim instead of concatenated.
# Fix: compare shard 0 against *all* other shards instead of a hard-coded
# range(1,4), so the script works for any shard count without editing.
not_distributed = {
k
for k in weights[0].keys()
if all(torch.equal(weights[0][k], weights[i][k]) for i in range(1, len(weights)))
}
# Concatenation axis per tensor name, based on whether the layer is
# implemented as Embedding, Row or Column Parallel in the model code.
merge_dimensions ={
r'^layers.\d+.attention.wq.weight$': 0,
r'^layers.\d+.attention.wk.weight$': 0,
r'^layers.\d+.attention.wv.weight$': 0,
r'^layers.\d+.attention.wo.weight$': 1,
r'^tok_embeddings.weight$': 1,
r'^layers.\d+.feed_forward.w1.weight$': 0,
r'^layers.\d+.feed_forward.w2.weight$': 1,
r'^layers.\d+.feed_forward.w3.weight$': 0,
r'^output.weight$': 0
}
# Merging (or copying if not distributed).
output_weights = {}
for output, group in enumerate(merge_groups):
    output_weights[output] = dict()
    for name in tqdm(weights[group[0]], leave=False):
        if name in not_distributed:
            # Fix: copy from the group's own first shard, not always shard 0,
            # so configurations with multiple merge groups are correct.
            output_weights[output][name] = weights[group[0]][name]
        else:
            axis = next(axis for exp, axis in merge_dimensions.items() if re.match(exp, name))
            output_weights[output][name] = torch.cat([
                weights[member][name]
                for member in group
            ], axis=axis)
# Write the merged checkpoint(s) plus a copy of params.json next to them.
os.makedirs(f'{path_70b}/one-gpu/', exist_ok=True)
with open(f'{path_70b}/params.json') as fin:
    with open(f'{path_70b}/one-gpu/params.json', 'w') as fout:
        fout.write(fin.read())
# Fix: save every merge group, not only the first one.  With a single group
# the output filename is unchanged (consolidated.00.pth).
for output in output_weights:
    torch.save(
        output_weights[output],
        f'{path_70b}/one-gpu/consolidated.{output:02d}.pth'
    )