# scripts/convert_hf_checkpoint.py

## This script is adapted from: https://github.com/Lightning-AI/lit-gpt
## It converts a Hugging Face (HF) checkpoint into a Lit-GPT checkpoint (`lit_model.pth`).

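# Example invocation (a sketch; the checkpoint path below is simply the function's default
# and any torch dtype name can be passed). `jsonargparse.CLI` at the bottom of this file
# exposes the keyword-only arguments of `convert_hf_checkpoint` as command-line flags:
#
#   python scripts/convert_hf_checkpoint.py \
#       --checkpoint_dir checkpoints/stabilityai/stablelm-base-alpha-3b \
#       --dtype bfloat16
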
import gc
import json
import sys
from dataclasses import asdict
from functools import partial
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

import torch
from lightning.fabric.utilities.load import _NotYetLoadedTensor as NotYetLoadedTensor

# support running without installing as a package
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))

from lit_gpt import Config
from lit_gpt.utils import incremental_save, lazy_load


def copy_weights_gpt_neox(
    state_dict: Dict[str, torch.Tensor],
    hf_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]],
    saver: Optional[incremental_save] = None,
    dtype: Optional[torch.dtype] = None,
) -> None:
    weight_map = {
        "gpt_neox.embed_in.weight": "transformer.wte.weight",
        "gpt_neox.layers.{}.input_layernorm.bias": "transformer.h.{}.norm_1.bias",
        "gpt_neox.layers.{}.input_layernorm.weight": "transformer.h.{}.norm_1.weight",
        "gpt_neox.layers.{}.attention.query_key_value.bias": "transformer.h.{}.attn.attn.bias",
        "gpt_neox.layers.{}.attention.query_key_value.weight": "transformer.h.{}.attn.attn.weight",
        "gpt_neox.layers.{}.attention.dense.bias": "transformer.h.{}.attn.proj.bias",
        "gpt_neox.layers.{}.attention.dense.weight": "transformer.h.{}.attn.proj.weight",
        "gpt_neox.layers.{}.attention.rotary_emb.inv_freq": None,
        "gpt_neox.layers.{}.attention.bias": None,
        "gpt_neox.layers.{}.attention.masked_bias": None,
        "gpt_neox.layers.{}.post_attention_layernorm.bias": "transformer.h.{}.norm_2.bias",
        "gpt_neox.layers.{}.post_attention_layernorm.weight": "transformer.h.{}.norm_2.weight",
        "gpt_neox.layers.{}.mlp.dense_h_to_4h.bias": "transformer.h.{}.mlp.fc.bias",
        "gpt_neox.layers.{}.mlp.dense_h_to_4h.weight": "transformer.h.{}.mlp.fc.weight",
        "gpt_neox.layers.{}.mlp.dense_4h_to_h.bias": "transformer.h.{}.mlp.proj.bias",
        "gpt_neox.layers.{}.mlp.dense_4h_to_h.weight": "transformer.h.{}.mlp.proj.weight",
        "gpt_neox.final_layer_norm.bias": "transformer.ln_f.bias",
        "gpt_neox.final_layer_norm.weight": "transformer.ln_f.weight",
        "embed_out.weight": "lm_head.weight",
    }

    for name, param in hf_weights.items():
        if "gpt_neox.layers" in name:
            from_name, number = layer_template(name, 2)
            to_name = weight_map[from_name]
            if to_name is None:
                continue
            to_name = to_name.format(number)
        else:
            to_name = weight_map[name]
        param = load_param(param, name, dtype)
        if saver is not None:
            param = saver.store_early(param)
        state_dict[to_name] = param


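# Note on copy_weights_gpt_neox (above): source keys mapped to `None` in `weight_map`
# (the rotary `inv_freq` and the causal-mask `bias`/`masked_bias` buffers) have no
# counterpart in the Lit-GPT model, which recreates them itself, so they are skipped.
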
def copy_weights_falcon(
    model_name: str,
    state_dict: Dict[str, torch.Tensor],
    hf_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]],
    saver: Optional[incremental_save] = None,
    dtype: Optional[torch.dtype] = None,
) -> None:
    weight_map = {
        "transformer.word_embeddings.weight": "transformer.wte.weight",
        "transformer.h.{}.self_attention.query_key_value.weight": "transformer.h.{}.attn.attn.weight",
        "transformer.h.{}.self_attention.dense.weight": "transformer.h.{}.attn.proj.weight",
        "transformer.h.{}.mlp.dense_h_to_4h.weight": "transformer.h.{}.mlp.fc.weight",
        "transformer.h.{}.mlp.dense_4h_to_h.weight": "transformer.h.{}.mlp.proj.weight",
        "transformer.ln_f.bias": "transformer.ln_f.bias",
        "transformer.ln_f.weight": "transformer.ln_f.weight",
        "lm_head.weight": "lm_head.weight",
    }
    # the original model definition is different for each size
    if "7b" in model_name:
        weight_map.update(
            {
                "transformer.h.{}.input_layernorm.bias": "transformer.h.{}.norm_1.bias",
                "transformer.h.{}.input_layernorm.weight": "transformer.h.{}.norm_1.weight",
            }
        )
    elif "40b" in model_name or "180B" in model_name:
        weight_map.update(
            {
                "transformer.h.{}.ln_attn.bias": "transformer.h.{}.norm_1.bias",
                "transformer.h.{}.ln_attn.weight": "transformer.h.{}.norm_1.weight",
                "transformer.h.{}.ln_mlp.bias": "transformer.h.{}.norm_2.bias",
                "transformer.h.{}.ln_mlp.weight": "transformer.h.{}.norm_2.weight",
            }
        )
    else:
        raise NotImplementedError(f"Unsupported Falcon variant: {model_name!r}")

    for name, param in hf_weights.items():
        if "transformer.h" in name:
            from_name, number = layer_template(name, 2)
            to_name = weight_map[from_name].format(number)
        else:
            to_name = weight_map[name]
        param = load_param(param, name, dtype)
        if saver is not None:
            param = saver.store_early(param)
        state_dict[to_name] = param


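# Note on copy_weights_falcon (above): Falcon-7B blocks use a single shared
# `input_layernorm`, while the 40B/180B variants use separate `ln_attn`/`ln_mlp` norms,
# hence the size-dependent `norm_1`/`norm_2` mappings.
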
def copy_weights_hf_llama(
    config: Config,
    qkv_weights: Dict[int, List[Optional[NotYetLoadedTensor]]],
    state_dict: Dict[str, torch.Tensor],
    hf_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]],
    saver: Optional[incremental_save] = None,
    dtype: Optional[torch.dtype] = None,
) -> None:
    weight_map = {
        "model.embed_tokens.weight": "transformer.wte.weight",
        "model.layers.{}.input_layernorm.weight": "transformer.h.{}.norm_1.weight",
        "model.layers.{}.self_attn.q_proj.weight": None,
        "model.layers.{}.self_attn.k_proj.weight": None,
        "model.layers.{}.self_attn.v_proj.weight": None,
        "model.layers.{}.self_attn.o_proj.weight": "transformer.h.{}.attn.proj.weight",
        "model.layers.{}.self_attn.rotary_emb.inv_freq": None,
        "model.layers.{}.post_attention_layernorm.weight": "transformer.h.{}.norm_2.weight",
        "model.layers.{}.mlp.gate_proj.weight": "transformer.h.{}.mlp.fc_1.weight",
        "model.layers.{}.mlp.up_proj.weight": "transformer.h.{}.mlp.fc_2.weight",
        "model.layers.{}.mlp.down_proj.weight": "transformer.h.{}.mlp.proj.weight",
        "model.norm.weight": "transformer.ln_f.weight",
        "lm_head.weight": "lm_head.weight",
    }

    for name, param in hf_weights.items():
        if "model.layers" in name:
            from_name, number = layer_template(name, 2)
            qkv = qkv_weights.setdefault(number, [None, None, None])
            if "q_proj" in name:
                qkv[0] = param
            elif "k_proj" in name:
                qkv[1] = param
            elif "v_proj" in name:
                qkv[2] = param
            to_name = weight_map[from_name]
            if to_name is None:
                continue
            to_name = to_name.format(number)
        else:
            to_name = weight_map[name]
        param = load_param(param, name, dtype)
        if saver is not None:
            param = saver.store_early(param)
        state_dict[to_name] = param

    for i, (q, k, v) in list(qkv_weights.items()):
        if q is None or k is None or v is None:
            # split across different .bin files
            continue
        q = load_param(q, f"layer {i} q", dtype)
        k = load_param(k, f"layer {i} k", dtype)
        v = load_param(v, f"layer {i} v", dtype)
        q_per_kv = config.n_head // config.n_query_groups
        qs = torch.split(q, config.head_size * q_per_kv)
        ks = torch.split(k, config.head_size)
        vs = torch.split(v, config.head_size)
        cycled = [t for group in zip(qs, ks, vs) for t in group]
        qkv = torch.cat(cycled)
        state_dict[f"transformer.h.{i}.attn.attn.weight"] = qkv
        del qkv_weights[i]


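# Sketch of the q/k/v merge above, with illustrative numbers (not taken from any real
# config): for n_head=32, n_query_groups=8, head_size=128 we get q_per_kv=4, so
#   q has 32*128 rows, while k and v have 8*128 rows each;
#   qs/ks/vs are 8 chunks of 4*128 / 128 / 128 rows respectively;
# `cycled` then interleaves them as [q_g0, k_g0, v_g0, q_g1, k_g1, v_g1, ...], which is
# the fused group-interleaved layout Lit-GPT expects for `attn.attn.weight`.
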
def copy_weights_phi(
    config: Config,
    state_dict: Dict[str, torch.Tensor],
    hf_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]],
    saver: Optional[incremental_save] = None,
    dtype: Optional[torch.dtype] = None,
) -> None:
    if any(layer_name.startswith("layers.") for layer_name in hf_weights):
        raise ValueError(
            "You are using an outdated Phi1.5 checkpoint. "
            "Please reload it as described in 'tutorials/download_phi15.md'"
        )

    weight_map = {
        "transformer.embd.wte.weight": "transformer.wte.weight",
        "transformer.h.{}.ln.bias": "transformer.h.{}.norm_1.bias",
        "transformer.h.{}.ln.weight": "transformer.h.{}.norm_1.weight",
        "transformer.h.{}.mixer.Wqkv.bias": "transformer.h.{}.attn.attn.bias",
        "transformer.h.{}.mixer.Wqkv.weight": "transformer.h.{}.attn.attn.weight",
        "transformer.h.{}.mixer.out_proj.bias": "transformer.h.{}.attn.proj.bias",
        "transformer.h.{}.mixer.out_proj.weight": "transformer.h.{}.attn.proj.weight",
        "transformer.h.{}.mixer.rotary_emb.inv_freq": None,
        "transformer.h.{}.mlp.fc1.bias": "transformer.h.{}.mlp.fc.bias",
        "transformer.h.{}.mlp.fc1.weight": "transformer.h.{}.mlp.fc.weight",
        "transformer.h.{}.mlp.fc2.bias": "transformer.h.{}.mlp.proj.bias",
        "transformer.h.{}.mlp.fc2.weight": "transformer.h.{}.mlp.proj.weight",
        "lm_head.ln.weight": "transformer.ln_f.weight",
        "lm_head.ln.bias": "transformer.ln_f.bias",
        "lm_head.linear.weight": "lm_head.weight",
        "lm_head.linear.bias": "lm_head.bias",
    }

    for name, param in hf_weights.items():
        if name.startswith("transformer.h."):
            from_name, number = layer_template(name, 2)
            to_name = weight_map[from_name].format(number)
        else:
            to_name = weight_map[name]
        param = load_param(param, name, dtype)
        if "Wqkv" in name:
            q_per_kv = config.n_head // config.n_query_groups
            total_qkv = q_per_kv + 2  # each group has 1+ queries, 1 key, and 1 value
            param = param.view(total_qkv, config.n_query_groups, -1).transpose(0, 1)
            param = param.reshape(config.n_embd * 3, -1)
            if "bias" in name:
                param = param.squeeze()
        if saver is not None:
            param = saver.store_early(param)
        state_dict[to_name] = param


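# Sketch of the Wqkv reshuffle above (for Phi-1.5, n_query_groups == n_head, so
# q_per_kv=1 and total_qkv=3): the HF weight of shape (3*n_embd, n_embd) is viewed as
# (3, n_head, head_size*n_embd), transposed to (n_head, 3, head_size*n_embd) and reshaped
# back to (3*n_embd, n_embd), i.e. converted from "all q, all k, all v" row order into the
# per-head-interleaved (q, k, v) order used by the fused Lit-GPT attention weight.
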
def layer_template(layer_name: str, idx: int) -> Tuple[str, int]:
    split = layer_name.split(".")
    number = int(split[idx])
    split[idx] = "{}"
    from_name = ".".join(split)
    return from_name, number


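# Example (hypothetical name): layer_template("model.layers.12.mlp.up_proj.weight", 2)
# returns ("model.layers.{}.mlp.up_proj.weight", 12): the dot-separated component at
# position `idx` is the layer index, and it is replaced with a "{}" placeholder so the
# name can be looked up in the `weight_map` dictionaries above.
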
def load_param(param: Union[torch.Tensor, NotYetLoadedTensor], name: str, dtype: Optional[torch.dtype]) -> torch.Tensor:
    if hasattr(param, "_load_tensor"):
        # support tensors loaded via `lazy_load()`
        print(f"Loading {name!r} into RAM")
        param = param._load_tensor()
    if dtype is not None and type(dtype) is not NotYetLoadedTensor and dtype != param.dtype:
        print(f"Converting {name!r} from {param.dtype} to {dtype}")
        param = param.to(dtype)
    return param


@torch.inference_mode()
def convert_hf_checkpoint(
    *,
    checkpoint_dir: Path = Path("checkpoints/stabilityai/stablelm-base-alpha-3b"),
    model_name: Optional[str] = None,
    dtype: Optional[str] = None,
) -> None:
    if model_name is None:
        model_name = checkpoint_dir.name
    if dtype is not None:
        dtype = getattr(torch, dtype)

    config = Config.from_name(model_name)
    config_dict = asdict(config)
    print(f"Model config {config_dict}")
    with open(checkpoint_dir / "lit_config.json", "w") as json_config:
        json.dump(config_dict, json_config)

    if "falcon" in model_name:
        copy_fn = partial(copy_weights_falcon, model_name)
    elif config._mlp_class == "LLaMAMLP":
        # holder to reconstitute the split q, k, v
        qkv_weights = {}
        copy_fn = partial(copy_weights_hf_llama, config, qkv_weights)
    elif "phi" in model_name:
        copy_fn = partial(copy_weights_phi, config)
    else:
        copy_fn = copy_weights_gpt_neox

    # initialize a new empty state dict to hold our new weights
    sd = {}

    # Load the json file containing the weight mapping
    pytorch_bin_map_json_path = checkpoint_dir / "pytorch_model.bin.index.json"
    if pytorch_bin_map_json_path.is_file():  # not all checkpoints have this file
        with open(pytorch_bin_map_json_path) as json_map:
            bin_index = json.load(json_map)
        bin_files = {checkpoint_dir / bin for bin in bin_index["weight_map"].values()}
    else:
        bin_files = set(checkpoint_dir.glob("*.bin"))
        # some checkpoints serialize the training arguments
        bin_files = {f for f in bin_files if f.name != "training_args.bin"}
    if not bin_files:
        raise ValueError(f"Expected {str(checkpoint_dir)!r} to contain .bin files")

    with incremental_save(checkpoint_dir / "lit_model.pth") as saver:
        # a layer's q, k and v projections may be split across several .bin files;
        # `copy_weights_hf_llama` accumulates the pieces in `qkv_weights` and only writes
        # the fused weight once all three have been seen
        for bin_file in sorted(bin_files):
            print("Processing", bin_file)
            hf_weights = lazy_load(bin_file)
            copy_fn(sd, hf_weights, saver=saver, dtype=dtype)
            gc.collect()
        print("Saving converted checkpoint")
        saver.save(sd)


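# The function can also be called directly from Python (the path below is just the
# default shown in the signature above):
#
#   convert_hf_checkpoint(
#       checkpoint_dir=Path("checkpoints/stabilityai/stablelm-base-alpha-3b"),
#       dtype="bfloat16",
#   )
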
if __name__ == "__main__":
    from jsonargparse import CLI

    CLI(convert_hf_checkpoint)