[248dc9]: / scripts / convert_hf_checkpoint.py

Download this file

308 lines (275 with data), 13.3 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
## This script is adapted from: https://github.com/Lightning-AI/lit-gpt
## This script is used to convert the HF checkpoint to the LIT checkpoint
import gc
import json
import sys
from dataclasses import asdict
from functools import partial
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
import torch
from lightning.fabric.utilities.load import _NotYetLoadedTensor as NotYetLoadedTensor
# support running without installing as a package
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))
from lit_gpt import Config
from lit_gpt.utils import incremental_save, lazy_load
def copy_weights_gpt_neox(
    state_dict: Dict[str, torch.Tensor],
    hf_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]],
    saver: Optional[incremental_save] = None,
    dtype: Optional[torch.dtype] = None,
) -> None:
    """Copy GPT-NeoX style HF weights into ``state_dict`` under lit-gpt parameter names.

    Args:
        state_dict: Destination mapping; it is filled in place.
        hf_weights: Source mapping from HF parameter name to a tensor (possibly lazily loaded).
        saver: When given, every converted tensor is passed through ``saver.store_early``
            before being stored.
        dtype: When given, forwarded to ``load_param`` as the target dtype.

    Raises:
        KeyError: If a source name (or its per-layer template) is not in the name table below.
    """
    # HF name (or per-layer template) -> lit-gpt name; ``None`` marks entries that are dropped
    hf_to_lit = {
        "gpt_neox.embed_in.weight": "transformer.wte.weight",
        "gpt_neox.layers.{}.input_layernorm.bias": "transformer.h.{}.norm_1.bias",
        "gpt_neox.layers.{}.input_layernorm.weight": "transformer.h.{}.norm_1.weight",
        "gpt_neox.layers.{}.attention.query_key_value.bias": "transformer.h.{}.attn.attn.bias",
        "gpt_neox.layers.{}.attention.query_key_value.weight": "transformer.h.{}.attn.attn.weight",
        "gpt_neox.layers.{}.attention.dense.bias": "transformer.h.{}.attn.proj.bias",
        "gpt_neox.layers.{}.attention.dense.weight": "transformer.h.{}.attn.proj.weight",
        "gpt_neox.layers.{}.attention.rotary_emb.inv_freq": None,
        "gpt_neox.layers.{}.attention.bias": None,
        "gpt_neox.layers.{}.attention.masked_bias": None,
        "gpt_neox.layers.{}.post_attention_layernorm.bias": "transformer.h.{}.norm_2.bias",
        "gpt_neox.layers.{}.post_attention_layernorm.weight": "transformer.h.{}.norm_2.weight",
        "gpt_neox.layers.{}.mlp.dense_h_to_4h.bias": "transformer.h.{}.mlp.fc.bias",
        "gpt_neox.layers.{}.mlp.dense_h_to_4h.weight": "transformer.h.{}.mlp.fc.weight",
        "gpt_neox.layers.{}.mlp.dense_4h_to_h.bias": "transformer.h.{}.mlp.proj.bias",
        "gpt_neox.layers.{}.mlp.dense_4h_to_h.weight": "transformer.h.{}.mlp.proj.weight",
        "gpt_neox.final_layer_norm.bias": "transformer.ln_f.bias",
        "gpt_neox.final_layer_norm.weight": "transformer.ln_f.weight",
        "embed_out.weight": "lm_head.weight",
    }

    for hf_name, tensor in hf_weights.items():
        if "gpt_neox.layers" not in hf_name:
            # non-layer parameters are looked up verbatim
            lit_name = hf_to_lit[hf_name]
        else:
            # the layer index sits at position 2 of the dotted name
            template, layer_idx = layer_template(hf_name, 2)
            lit_template = hf_to_lit[template]
            if lit_template is None:
                continue
            lit_name = lit_template.format(layer_idx)

        tensor = load_param(tensor, hf_name, dtype)
        state_dict[lit_name] = tensor if saver is None else saver.store_early(tensor)
def copy_weights_falcon(
    model_name: str,
    state_dict: Dict[str, torch.Tensor],
    hf_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]],
    saver: Optional[incremental_save] = None,
    dtype: Optional[torch.dtype] = None,
) -> None:
    """Copy Falcon HF weights into ``state_dict`` under lit-gpt parameter names.

    Args:
        model_name: Used to pick the layer-norm naming scheme; must contain ``"7b"``,
            ``"40b"`` or ``"180B"``.
        state_dict: Destination mapping; it is filled in place.
        hf_weights: Source mapping from HF parameter name to a tensor (possibly lazily loaded).
        saver: When given, every converted tensor is passed through ``saver.store_early``
            before being stored.
        dtype: When given, forwarded to ``load_param`` as the target dtype.

    Raises:
        NotImplementedError: If ``model_name`` matches none of the known sizes.
        KeyError: If a source name (or its per-layer template) is not in the name table.
    """
    hf_to_lit = {
        "transformer.word_embeddings.weight": "transformer.wte.weight",
        "transformer.h.{}.self_attention.query_key_value.weight": "transformer.h.{}.attn.attn.weight",
        "transformer.h.{}.self_attention.dense.weight": "transformer.h.{}.attn.proj.weight",
        "transformer.h.{}.mlp.dense_h_to_4h.weight": "transformer.h.{}.mlp.fc.weight",
        "transformer.h.{}.mlp.dense_4h_to_h.weight": "transformer.h.{}.mlp.proj.weight",
        "transformer.ln_f.bias": "transformer.ln_f.bias",
        "transformer.ln_f.weight": "transformer.ln_f.weight",
        "lm_head.weight": "lm_head.weight",
    }

    # the original model definition names its layer norms differently for each size
    if "7b" in model_name:
        norm_names = {
            "transformer.h.{}.input_layernorm.bias": "transformer.h.{}.norm_1.bias",
            "transformer.h.{}.input_layernorm.weight": "transformer.h.{}.norm_1.weight",
        }
    elif "40b" in model_name or "180B" in model_name:
        norm_names = {
            "transformer.h.{}.ln_attn.bias": "transformer.h.{}.norm_1.bias",
            "transformer.h.{}.ln_attn.weight": "transformer.h.{}.norm_1.weight",
            "transformer.h.{}.ln_mlp.bias": "transformer.h.{}.norm_2.bias",
            "transformer.h.{}.ln_mlp.weight": "transformer.h.{}.norm_2.weight",
        }
    else:
        raise NotImplementedError
    hf_to_lit.update(norm_names)

    for hf_name, tensor in hf_weights.items():
        if "transformer.h" not in hf_name:
            lit_name = hf_to_lit[hf_name]
        else:
            # the layer index sits at position 2 of the dotted name
            template, layer_idx = layer_template(hf_name, 2)
            lit_name = hf_to_lit[template].format(layer_idx)

        tensor = load_param(tensor, hf_name, dtype)
        state_dict[lit_name] = tensor if saver is None else saver.store_early(tensor)
def copy_weights_hf_llama(
    config: Config,
    qkv_weights: Dict[int, List[Optional[NotYetLoadedTensor]]],
    state_dict: Dict[str, torch.Tensor],
    hf_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]],
    saver: Optional[incremental_save] = None,
    dtype: Optional[torch.dtype] = None,
) -> None:
    """Copy HF Llama weights into ``state_dict``, fusing q/k/v projections per layer.

    The separate ``q_proj``/``k_proj``/``v_proj`` tensors are stashed in ``qkv_weights`` and,
    once all three of a layer are present, concatenated into ``transformer.h.<i>.attn.attn.weight``.

    Args:
        config: Read for ``n_head``, ``n_query_groups`` and ``head_size``.
        qkv_weights: Per-layer ``[q, k, v]`` holder, mutated in place. Complete layers are
            removed; incomplete ones are left for a later call.
        state_dict: Destination mapping; it is filled in place.
        hf_weights: Source mapping from HF parameter name to a tensor (possibly lazily loaded).
        saver: When given, every directly-mapped tensor is passed through ``saver.store_early``.
        dtype: When given, forwarded to ``load_param`` as the target dtype.

    Raises:
        KeyError: If a source name (or its per-layer template) is not in the name table.
    """
    # HF name (or per-layer template) -> lit-gpt name; ``None`` means "not copied directly"
    hf_to_lit = {
        "model.embed_tokens.weight": "transformer.wte.weight",
        "model.layers.{}.input_layernorm.weight": "transformer.h.{}.norm_1.weight",
        "model.layers.{}.self_attn.q_proj.weight": None,
        "model.layers.{}.self_attn.k_proj.weight": None,
        "model.layers.{}.self_attn.v_proj.weight": None,
        "model.layers.{}.self_attn.o_proj.weight": "transformer.h.{}.attn.proj.weight",
        "model.layers.{}.self_attn.rotary_emb.inv_freq": None,
        "model.layers.{}.post_attention_layernorm.weight": "transformer.h.{}.norm_2.weight",
        "model.layers.{}.mlp.gate_proj.weight": "transformer.h.{}.mlp.fc_1.weight",
        "model.layers.{}.mlp.up_proj.weight": "transformer.h.{}.mlp.fc_2.weight",
        "model.layers.{}.mlp.down_proj.weight": "transformer.h.{}.mlp.proj.weight",
        "model.norm.weight": "transformer.ln_f.weight",
        "lm_head.weight": "lm_head.weight",
    }
    # substring marker -> slot in the per-layer [q, k, v] holder; first match wins
    projection_slots = (("q_proj", 0), ("k_proj", 1), ("v_proj", 2))

    for hf_name, tensor in hf_weights.items():
        if "model.layers" not in hf_name:
            lit_name = hf_to_lit[hf_name]
        else:
            template, layer_idx = layer_template(hf_name, 2)
            pending = qkv_weights.setdefault(layer_idx, [None, None, None])
            for marker, slot in projection_slots:
                if marker in hf_name:
                    pending[slot] = tensor
                    break
            lit_template = hf_to_lit[template]
            if lit_template is None:
                continue
            lit_name = lit_template.format(layer_idx)

        tensor = load_param(tensor, hf_name, dtype)
        state_dict[lit_name] = tensor if saver is None else saver.store_early(tensor)

    # iterate over a snapshot because finished layers are deleted from the holder
    for layer_idx, (q, k, v) in list(qkv_weights.items()):
        if any(part is None for part in (q, k, v)):
            # split across different .bin files
            continue
        q = load_param(q, f"layer {layer_idx} q", dtype)
        k = load_param(k, f"layer {layer_idx} k", dtype)
        v = load_param(v, f"layer {layer_idx} v", dtype)

        queries_per_group = config.n_head // config.n_query_groups
        q_chunks = torch.split(q, config.head_size * queries_per_group)
        k_chunks = torch.split(k, config.head_size)
        v_chunks = torch.split(v, config.head_size)

        # lay the chunks out as q0, k0, v0, q1, k1, v1, ...
        interleaved = []
        for group in zip(q_chunks, k_chunks, v_chunks):
            interleaved.extend(group)

        state_dict[f"transformer.h.{layer_idx}.attn.attn.weight"] = torch.cat(interleaved)
        del qkv_weights[layer_idx]
def copy_weights_phi(
    config: Config,
    state_dict: Dict[str, torch.Tensor],
    hf_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]],
    saver: Optional[incremental_save] = None,
    dtype: Optional[torch.dtype] = None,
) -> None:
    """Copy Phi HF weights into ``state_dict`` under lit-gpt parameter names.

    Args:
        config: Read for ``n_head``, ``n_query_groups`` and ``n_embd`` when re-laying out ``Wqkv``.
        state_dict: Destination mapping; it is filled in place.
        hf_weights: Source mapping from HF parameter name to a tensor (possibly lazily loaded).
        saver: When given, every converted tensor is passed through ``saver.store_early``
            before being stored.
        dtype: When given, forwarded to ``load_param`` as the target dtype.

    Raises:
        ValueError: If any source name starts with ``"layers."`` (outdated checkpoint layout).
        KeyError: If a source name (or its per-layer template) is not in the name table.
    """
    if any(layer_name.startswith("layers.") for layer_name in hf_weights):
        raise ValueError(
            "You are using an outdated Phi1.5 checkpoint. "
            "Please reload it as described in 'tutorials/download_phi15.md'"
        )

    # HF name (or per-layer template) -> lit-gpt name; ``None`` marks entries that are dropped
    weight_map = {
        "transformer.embd.wte.weight": "transformer.wte.weight",
        "transformer.h.{}.ln.bias": "transformer.h.{}.norm_1.bias",
        "transformer.h.{}.ln.weight": "transformer.h.{}.norm_1.weight",
        "transformer.h.{}.mixer.Wqkv.bias": "transformer.h.{}.attn.attn.bias",
        "transformer.h.{}.mixer.Wqkv.weight": "transformer.h.{}.attn.attn.weight",
        "transformer.h.{}.mixer.out_proj.bias": "transformer.h.{}.attn.proj.bias",
        "transformer.h.{}.mixer.out_proj.weight": "transformer.h.{}.attn.proj.weight",
        "transformer.h.{}.mixer.rotary_emb.inv_freq": None,
        "transformer.h.{}.mlp.fc1.bias": "transformer.h.{}.mlp.fc.bias",
        "transformer.h.{}.mlp.fc1.weight": "transformer.h.{}.mlp.fc.weight",
        "transformer.h.{}.mlp.fc2.bias": "transformer.h.{}.mlp.proj.bias",
        "transformer.h.{}.mlp.fc2.weight": "transformer.h.{}.mlp.proj.weight",
        "lm_head.ln.weight": "transformer.ln_f.weight",
        "lm_head.ln.bias": "transformer.ln_f.bias",
        "lm_head.linear.weight": "lm_head.weight",
        "lm_head.linear.bias": "lm_head.bias",
    }

    for name, param in hf_weights.items():
        if name.startswith("transformer.h."):
            # the layer index sits at position 2 of the dotted name
            from_name, number = layer_template(name, 2)
            to_name = weight_map[from_name]
            if to_name is None:
                # entries mapped to None (the rotary `inv_freq` buffer) have no lit-gpt
                # counterpart; skip them instead of calling `.format` on None
                continue
            to_name = to_name.format(number)
        else:
            to_name = weight_map[name]

        param = load_param(param, name, dtype)

        if "Wqkv" in name:
            q_per_kv = config.n_head // config.n_query_groups
            total_qkv = q_per_kv + 2  # each group has 1+ queries, 1 key, and 1 value
            # view the leading dim as (total_qkv, n_query_groups, rest) and swap the first two
            # axes so that each query group's q/k/v slices become contiguous
            # NOTE(review): assumes the HF tensor stores all-q, all-k, all-v blocks - confirm
            param = param.view(total_qkv, config.n_query_groups, -1).transpose(0, 1)
            param = param.reshape(config.n_embd * 3, -1)
            if "bias" in name:
                # the reshape above leaves a trailing singleton dim on 1-D biases
                param = param.squeeze()

        if saver is not None:
            param = saver.store_early(param)
        state_dict[to_name] = param
def layer_template(layer_name: str, idx: int) -> Tuple[str, int]:
    """Replace the dotted component at position ``idx`` with ``"{}"`` and return its integer value.

    Args:
        layer_name: Dot-separated parameter name, e.g. ``"model.layers.3.mlp.weight"``.
        idx: Position of the numeric component after splitting on ``"."``.

    Returns:
        A ``(template, number)`` pair, e.g. ``("model.layers.{}.mlp.weight", 3)``.

    Raises:
        IndexError: If ``idx`` is out of range for the split name.
        ValueError: If the selected component is not an integer literal.
    """
    parts = layer_name.split(".")
    layer_number = int(parts[idx])
    parts[idx] = "{}"
    return ".".join(parts), layer_number
def load_param(
    param: Union[torch.Tensor, "NotYetLoadedTensor"], name: str, dtype: Optional[torch.dtype]
) -> torch.Tensor:
    """Materialize ``param`` if it is lazily loaded and optionally cast it to ``dtype``.

    Args:
        param: A tensor, or any object exposing ``_load_tensor()`` (as produced by ``lazy_load()``).
        name: Label used only in the progress messages printed by this function.
        dtype: Target dtype; ``None`` keeps the tensor's own dtype.

    Returns:
        The loaded tensor, converted to ``dtype`` when one is given and it differs.
    """
    if hasattr(param, "_load_tensor"):
        # support tensors loaded via `lazy_load()`
        print(f"Loading {name!r} into RAM")
        param = param._load_tensor()
    # `dtype` is a `torch.dtype` or None, so no further type check on it is needed
    if dtype is not None and dtype != param.dtype:
        print(f"Converting {name!r} from {param.dtype} to {dtype}")
        param = param.to(dtype)
    return param
@torch.inference_mode()
def convert_hf_checkpoint(
    *,
    checkpoint_dir: Path = Path("checkpoints/stabilityai/stablelm-base-alpha-3b"),
    model_name: Optional[str] = None,
    dtype: Optional[str] = None,
) -> None:
    """Convert the HF ``.bin`` checkpoint in ``checkpoint_dir`` to a lit-gpt checkpoint.

    Writes ``lit_config.json`` and ``lit_model.pth`` into ``checkpoint_dir``.

    Args:
        checkpoint_dir: Directory holding the HF ``.bin`` file(s) and, optionally,
            ``pytorch_model.bin.index.json``.
        model_name: Name passed to ``Config.from_name``; defaults to the directory name.
        dtype: Name of a ``torch`` dtype attribute (e.g. ``"bfloat16"``) to convert the weights
            to; ``None`` keeps each tensor's own dtype.

    Raises:
        AttributeError: If ``dtype`` is not an attribute of ``torch``.
        ValueError: If ``dtype`` names a ``torch`` attribute that is not a dtype, or if no
            ``.bin`` files are found.
    """
    if model_name is None:
        model_name = checkpoint_dir.name
    if dtype is not None:
        dtype_name = dtype
        dtype = getattr(torch, dtype_name)
        if not isinstance(dtype, torch.dtype):
            # e.g. "nn" or "device" resolve to torch attributes that are not dtypes
            raise ValueError(f"{dtype_name!r} is not a valid torch dtype")

    config = Config.from_name(model_name)
    config_dict = asdict(config)
    print(f"Model config {config_dict}")
    with open(checkpoint_dir / "lit_config.json", "w") as json_config:
        json.dump(config_dict, json_config)

    # pick the weight-copy routine for this model family
    if "falcon" in model_name:
        copy_fn = partial(copy_weights_falcon, model_name)
    elif config._mlp_class == "LLaMAMLP":
        # holder to reconstitute the split q, k, v
        qkv_weights = {}
        copy_fn = partial(copy_weights_hf_llama, config, qkv_weights)
    elif "phi" in model_name:
        copy_fn = partial(copy_weights_phi, config)
    else:
        copy_fn = copy_weights_gpt_neox

    # initialize a new empty state dict to hold our new weights
    sd = {}

    # Load the json file containing weight mapping
    pytorch_bin_map_json_path = checkpoint_dir / "pytorch_model.bin.index.json"
    if pytorch_bin_map_json_path.is_file():  # not all checkpoints have this file
        with open(pytorch_bin_map_json_path) as json_map:
            bin_index = json.load(json_map)
        bin_files = {checkpoint_dir / file_name for file_name in bin_index["weight_map"].values()}
    else:
        bin_files = set(checkpoint_dir.glob("*.bin"))
        # some checkpoints serialize the training arguments
        bin_files = {f for f in bin_files if f.name != "training_args.bin"}
    if not bin_files:
        raise ValueError(f"Expected {str(checkpoint_dir)!r} to contain .bin files")

    with incremental_save(checkpoint_dir / "lit_model.pth") as saver:
        # the files are processed one at a time in sorted order; for Llama-style checkpoints
        # whose q/k/v are split across files, the pending pieces are carried between calls in
        # the `qkv_weights` holder bound into `copy_fn` above
        for bin_file in sorted(bin_files):
            print("Processing", bin_file)
            hf_weights = lazy_load(bin_file)
            copy_fn(sd, hf_weights, saver=saver, dtype=dtype)
        gc.collect()
        print("Saving converted checkpoint")
        saver.save(sd)
if __name__ == "__main__":
    # jsonargparse is only imported when the file is run as a script
    import jsonargparse

    # expose the keyword arguments of `convert_hf_checkpoint` on the command line
    jsonargparse.CLI(convert_hf_checkpoint)