Diff of /src/merge_weights.py [000000] .. [014e6e]


# -*- coding: utf-8 -*-
"""
@Time : 2023/12/12 15:09
@Auth : Juexiao Zhou
@File : merge_weights.py
@IDE  : PyCharm
@Page : www.joshuachou.ink
"""

import os
import re
import torch
from tqdm import tqdm

#path_70b = '/home/zhouj0d/Science/PID28.ABC/AutoBA/src/llama-main/llama-2-13b-chat/'
#path_70b = '/home/zhouj0d/Science/PID28.ABC/AutoBA/src/codellama-main/CodeLlama-13b-Instruct/'
path_70b = '/home/zhouj0d/Science/PID28.ABC/AutoBA/src/codellama-main/CodeLlama-34b-Instruct/'

# Which shard files are merged into one output file. CodeLlama-34b-Instruct is
# published as four model-parallel shards (consolidated.00.pth .. consolidated.03.pth),
# so a single group covering indices 0-3 consolidates all of them; the 13B models
# ship as two shards, hence the commented-out [[0, 1]].
#merge_groups = [[0, 1]]
merge_groups = [[0, 1, 2, 3]]

# Load every shard (consolidated.NN.pth) onto CPU, keyed by its shard index NN.
weights = {
  int(fn.split('.')[1]): torch.load(f'{path_70b}{fn}', map_location=torch.device('cpu'))
  for fn in tqdm(sorted(os.listdir(path_70b)))
  if fn.endswith('.pth')
}
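
# Optional sanity check: every shard index referenced in merge_groups must have
# been loaded above (this assumes the stock consolidated.NN.pth shard naming).
for group in merge_groups:
  missing = [i for i in group if i not in weights]
  assert not missing, f'missing shard files for indices {missing}'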

# These tensors are replicated (identical) across the shards rather than split
# between them; they are detected by checking that shard 0 matches every other
# shard. NOTE: the upper bound of range() must equal the number of shards.
not_distributed = {
  k
  for k in weights[0].keys()
  #if all((weights[0][k] == weights[i][k]).min() for i in range(1, 2))
  if all((weights[0][k] == weights[i][k]).min() for i in range(1, 4))
}
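
# For reference: in these checkpoints the replicated tensors are typically the
# RMSNorm weights and 'rope.freqs' (an observation about the LLaMA/CodeLlama
# layout, not something this script relies on). Printing the set is a quick check:
print(f'{len(not_distributed)} replicated tensors:', sorted(not_distributed))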

# The dimension along which each distributed tensor is concatenated, based on
# whether the layer is implemented as a parallel Embedding, RowParallel or
# ColumnParallel layer in the reference model code.
merge_dimensions = {
  r'^layers\.\d+\.attention\.wq\.weight$': 0,
  r'^layers\.\d+\.attention\.wk\.weight$': 0,
  r'^layers\.\d+\.attention\.wv\.weight$': 0,
  r'^layers\.\d+\.attention\.wo\.weight$': 1,

  r'^tok_embeddings\.weight$': 1,

  r'^layers\.\d+\.feed_forward\.w1\.weight$': 0,
  r'^layers\.\d+\.feed_forward\.w2\.weight$': 1,
  r'^layers\.\d+\.feed_forward\.w3\.weight$': 0,
  r'^output\.weight$': 0,
}
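
# Optional check: every distributed tensor name should match one of the patterns
# above; anything unmatched would make the merge loop below fail.
unmatched = [
  name for name in weights[0]
  if name not in not_distributed
  and not any(re.match(exp, name) for exp in merge_dimensions)
]
assert not unmatched, f'no merge dimension known for: {unmatched}'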

# Merge the distributed tensors (or copy them unchanged if they are replicated).
output_weights = {}
for output, group in enumerate(merge_groups):
  output_weights[output] = dict()
  for name in tqdm(weights[group[0]], leave=False):
    if name in not_distributed:
      output_weights[output][name] = weights[group[0]][name]
    else:
      axis = next(axis for exp, axis in merge_dimensions.items() if re.match(exp, name))
      output_weights[output][name] = torch.cat([
          weights[member][name]
          for member in group
      ], axis=axis)
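
# Optional sanity check: along its merge axis, each merged tensor should be as
# large as the corresponding shard tensors put together.
for output, group in enumerate(merge_groups):
  for name, tensor in output_weights[output].items():
    if name not in not_distributed:
      axis = next(axis for exp, axis in merge_dimensions.items() if re.match(exp, name))
      assert tensor.shape[axis] == sum(weights[m][name].shape[axis] for m in group), name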

# Write the merged checkpoint next to the originals, together with an unchanged
# copy of params.json, so the 'one-gpu' directory works as a standalone
# checkpoint directory. Only the first (and here only) merge group is saved.
os.makedirs(f'{path_70b}/one-gpu/', exist_ok=True)
with open(f'{path_70b}/params.json') as fin:
  with open(f'{path_70b}/one-gpu/params.json', 'w') as fout:
    fout.write(fin.read())

torch.save(
    output_weights[0],
    f'{path_70b}/one-gpu/consolidated.00.pth'
)
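
# With all four shards merged into a single consolidated.00.pth, the 'one-gpu'
# directory can be loaded with a model-parallel size of 1. Assuming the standard
# example script from the codellama repository, that would look roughly like:
#
#   torchrun --nproc_per_node 1 example_instructions.py \
#     --ckpt_dir src/codellama-main/CodeLlama-34b-Instruct/one-gpu/ \
#     --tokenizer_path src/codellama-main/CodeLlama-34b-Instruct/tokenizer.model
#
# (paths here are assumptions about this checkout; see the codellama README for
# the exact flags).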