Diff of /src/merge_weights.py [000000] .. [014e6e]

--- a
+++ b/src/merge_weights.py
@@ -0,0 +1,95 @@
+# -*- coding: utf-8 -*-
+"""
+@Time : 2023/12/12 15:09
+@Auth : Juexiao Zhou
+@File : merge_weights.py
+@IDE  : PyCharm
+@Page: www.joshuachou.ink
+"""
+
+import os
+import re
+import torch
+from tqdm import tqdm
+
+#path_70b = '/home/zhouj0d/Science/PID28.ABC/AutoBA/src/llama-main/llama-2-13b-chat/'
+#path_70b = '/home/zhouj0d/Science/PID28.ABC/AutoBA/src/codellama-main/CodeLlama-13b-Instruct/'
+path_70b = '/home/zhouj0d/Science/PID28.ABC/AutoBA/src/codellama-main/CodeLlama-34b-Instruct/'
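+# NOTE: the loader below does f'{path_70b}{fn}', so the path must end with '/'.
+# The name `path_70b` is historical; it currently points at the 34B checkpoint.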
+
+# Which checkpoint shards are merged into one output file
+#merge_groups = [[0,1]]
+merge_groups = [[0,1,2,3]]
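+# e.g. the 34B checkpoint ships as four shards (consolidated.00-03.pth), the
+# 13B as two, hence the commented-out [[0,1]] above.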
+
+weights = {
+  int(fn.split('.')[1]): torch.load(f'{path_70b}{fn}', map_location=torch.device('cpu'))
+  for fn in tqdm(sorted(os.listdir(path_70b)))
+  if fn.endswith('.pth')
+}
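+# weights[i] now holds shard i, keyed by the index parsed from 'consolidated.NN.pth'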
+
+# These tensors are duplicated rather than distributed among the files
+
+not_distributed = {
+  k
+  for k in weights[0].keys()
+  #if all((weights[0][k] == weights[i][k]).min() for i in range(1,2))
+  if all((weights[0][k] == weights[i][k]).min() for i in range(1,4))
+}
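+# Typically the per-layer norm scales and rope.freqs, which every shard carries
+# verbatim. range(1, 4) compares shard 0 against shards 1-3; adjust it together
+# with merge_groups when the shard count changes.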
+
+# The dimension along which each weight is concatenated, based on whether it is
+# implemented as Embedding, Row or Column Parallel.
+
+merge_dimensions = {
+  r'^layers\.\d+\.attention\.wq\.weight$': 0,
+  r'^layers\.\d+\.attention\.wk\.weight$': 0,
+  r'^layers\.\d+\.attention\.wv\.weight$': 0,
+  r'^layers\.\d+\.attention\.wo\.weight$': 1,
+
+  r'^tok_embeddings\.weight$': 1,
+
+  r'^layers\.\d+\.feed_forward\.w1\.weight$': 0,
+  r'^layers\.\d+\.feed_forward\.w2\.weight$': 1,
+  r'^layers\.\d+\.feed_forward\.w3\.weight$': 0,
+  r'^output\.weight$': 0
+}
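+# Column-parallel layers shard their output features, so their pieces are
+# concatenated along dim 0; row-parallel layers shard their input features
+# (dim 1). The token embedding is sharded along the embedding dim, hence 1.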
+
+# Merging (or copying if not distributed)
+output_weights = {}
+for output, group in enumerate(merge_groups):
+  output_weights[output] = dict()
+  for name in tqdm(weights[group[0]], leave=False):
+    if name in not_distributed:
+      output_weights[output][name] = weights[0][name]
+    else:
+      axis = next(axis for exp, axis in merge_dimensions.items() if re.match(exp, name))
+      output_weights[output][name] = torch.cat([
+          weights[member][name]
+          for member in group
+      ], dim=axis)
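+# Each merged tensor now has its full single-GPU shape, e.g. a column-parallel
+# weight of per-shard shape [out/4, in] becomes [out, in].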
+
+os.makedirs(f'{path_70b}/one-gpu/', exist_ok=True)
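+# params.json is copied unchanged; only the weights need re-sharding.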
+with open(f'{path_70b}/params.json') as fin:
+  with open(f'{path_70b}/one-gpu/params.json', 'w') as fout:
+    fout.write(fin.read())
+
+torch.save(
+    output_weights[0],
+    f'{path_70b}/one-gpu/consolidated.00.pth'
+)
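+# NOTE: only the first merge group is written; extend this save if merge_groups
+# ever contains more than one group. The merged checkpoint can then be run with
+# model parallelism 1 (e.g. torchrun --nproc_per_node 1 with ckpt_dir pointing
+# at one-gpu/; exact script flags per the llama/codellama READMEs).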
\ No newline at end of file