scripts/convert_hf_checkpoint.py
## This script is adapted from: https://github.com/Lightning-AI/lit-gpt
## It converts a Hugging Face (HF) checkpoint into the Lit-GPT checkpoint format.

import gc
import json
import sys
from dataclasses import asdict
from functools import partial
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

import torch
from lightning.fabric.utilities.load import _NotYetLoadedTensor as NotYetLoadedTensor

# support running without installing as a package
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))

from lit_gpt import Config
from lit_gpt.utils import incremental_save, lazy_load


def copy_weights_gpt_neox(
    state_dict: Dict[str, torch.Tensor],
    hf_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]],
    saver: Optional[incremental_save] = None,
    dtype: Optional[torch.dtype] = None,
) -> None:
    weight_map = {
        "gpt_neox.embed_in.weight": "transformer.wte.weight",
        "gpt_neox.layers.{}.input_layernorm.bias": "transformer.h.{}.norm_1.bias",
        "gpt_neox.layers.{}.input_layernorm.weight": "transformer.h.{}.norm_1.weight",
        "gpt_neox.layers.{}.attention.query_key_value.bias": "transformer.h.{}.attn.attn.bias",
        "gpt_neox.layers.{}.attention.query_key_value.weight": "transformer.h.{}.attn.attn.weight",
        "gpt_neox.layers.{}.attention.dense.bias": "transformer.h.{}.attn.proj.bias",
        "gpt_neox.layers.{}.attention.dense.weight": "transformer.h.{}.attn.proj.weight",
        "gpt_neox.layers.{}.attention.rotary_emb.inv_freq": None,
        "gpt_neox.layers.{}.attention.bias": None,
        "gpt_neox.layers.{}.attention.masked_bias": None,
        "gpt_neox.layers.{}.post_attention_layernorm.bias": "transformer.h.{}.norm_2.bias",
        "gpt_neox.layers.{}.post_attention_layernorm.weight": "transformer.h.{}.norm_2.weight",
        "gpt_neox.layers.{}.mlp.dense_h_to_4h.bias": "transformer.h.{}.mlp.fc.bias",
        "gpt_neox.layers.{}.mlp.dense_h_to_4h.weight": "transformer.h.{}.mlp.fc.weight",
        "gpt_neox.layers.{}.mlp.dense_4h_to_h.bias": "transformer.h.{}.mlp.proj.bias",
        "gpt_neox.layers.{}.mlp.dense_4h_to_h.weight": "transformer.h.{}.mlp.proj.weight",
        "gpt_neox.final_layer_norm.bias": "transformer.ln_f.bias",
        "gpt_neox.final_layer_norm.weight": "transformer.ln_f.weight",
        "embed_out.weight": "lm_head.weight",
    }

    for name, param in hf_weights.items():
        if "gpt_neox.layers" in name:
            from_name, number = layer_template(name, 2)
            to_name = weight_map[from_name]
            if to_name is None:
                continue
            to_name = to_name.format(number)
        else:
            to_name = weight_map[name]
        param = load_param(param, name, dtype)
        if saver is not None:
            param = saver.store_early(param)
        state_dict[to_name] = param

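# NOTE: entries in `weight_map` that point to None (the rotary `inv_freq` buffer and the
# attention `bias`/`masked_bias` mask buffers) are skipped on purpose, since Lit-GPT derives
# the rotary embeddings and the causal mask at runtime rather than loading them from the
# checkpoint.
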
def copy_weights_falcon(
    model_name: str,
    state_dict: Dict[str, torch.Tensor],
    hf_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]],
    saver: Optional[incremental_save] = None,
    dtype: Optional[torch.dtype] = None,
) -> None:
    weight_map = {
        "transformer.word_embeddings.weight": "transformer.wte.weight",
        "transformer.h.{}.self_attention.query_key_value.weight": "transformer.h.{}.attn.attn.weight",
        "transformer.h.{}.self_attention.dense.weight": "transformer.h.{}.attn.proj.weight",
        "transformer.h.{}.mlp.dense_h_to_4h.weight": "transformer.h.{}.mlp.fc.weight",
        "transformer.h.{}.mlp.dense_4h_to_h.weight": "transformer.h.{}.mlp.proj.weight",
        "transformer.ln_f.bias": "transformer.ln_f.bias",
        "transformer.ln_f.weight": "transformer.ln_f.weight",
        "lm_head.weight": "lm_head.weight",
    }
    # the original model definition is different for each size
    if "7b" in model_name:
        weight_map.update(
            {
                "transformer.h.{}.input_layernorm.bias": "transformer.h.{}.norm_1.bias",
                "transformer.h.{}.input_layernorm.weight": "transformer.h.{}.norm_1.weight",
            }
        )
    elif "40b" in model_name or "180B" in model_name:
        weight_map.update(
            {
                "transformer.h.{}.ln_attn.bias": "transformer.h.{}.norm_1.bias",
                "transformer.h.{}.ln_attn.weight": "transformer.h.{}.norm_1.weight",
                "transformer.h.{}.ln_mlp.bias": "transformer.h.{}.norm_2.bias",
                "transformer.h.{}.ln_mlp.weight": "transformer.h.{}.norm_2.weight",
            }
        )
    else:
        raise NotImplementedError

    for name, param in hf_weights.items():
        if "transformer.h" in name:
            from_name, number = layer_template(name, 2)
            to_name = weight_map[from_name].format(number)
        else:
            to_name = weight_map[name]
        param = load_param(param, name, dtype)
        if saver is not None:
            param = saver.store_early(param)
        state_dict[to_name] = param

def copy_weights_hf_llama(
    config: Config,
    qkv_weights: Dict[int, List[Optional[NotYetLoadedTensor]]],
    state_dict: Dict[str, torch.Tensor],
    hf_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]],
    saver: Optional[incremental_save] = None,
    dtype: Optional[torch.dtype] = None,
) -> None:
    weight_map = {
        "model.embed_tokens.weight": "transformer.wte.weight",
        "model.layers.{}.input_layernorm.weight": "transformer.h.{}.norm_1.weight",
        "model.layers.{}.self_attn.q_proj.weight": None,
        "model.layers.{}.self_attn.k_proj.weight": None,
        "model.layers.{}.self_attn.v_proj.weight": None,
        "model.layers.{}.self_attn.o_proj.weight": "transformer.h.{}.attn.proj.weight",
        "model.layers.{}.self_attn.rotary_emb.inv_freq": None,
        "model.layers.{}.post_attention_layernorm.weight": "transformer.h.{}.norm_2.weight",
        "model.layers.{}.mlp.gate_proj.weight": "transformer.h.{}.mlp.fc_1.weight",
        "model.layers.{}.mlp.up_proj.weight": "transformer.h.{}.mlp.fc_2.weight",
        "model.layers.{}.mlp.down_proj.weight": "transformer.h.{}.mlp.proj.weight",
        "model.norm.weight": "transformer.ln_f.weight",
        "lm_head.weight": "lm_head.weight",
    }

    for name, param in hf_weights.items():
        if "model.layers" in name:
            from_name, number = layer_template(name, 2)
            qkv = qkv_weights.setdefault(number, [None, None, None])
            if "q_proj" in name:
                qkv[0] = param
            elif "k_proj" in name:
                qkv[1] = param
            elif "v_proj" in name:
                qkv[2] = param
            to_name = weight_map[from_name]
            if to_name is None:
                continue
            to_name = to_name.format(number)
        else:
            to_name = weight_map[name]
        param = load_param(param, name, dtype)
        if saver is not None:
            param = saver.store_early(param)
        state_dict[to_name] = param

    for i, (q, k, v) in list(qkv_weights.items()):
        if q is None or k is None or v is None:
            # split across different .bin files
            continue
        q = load_param(q, f"layer {i} q", dtype)
        k = load_param(k, f"layer {i} k", dtype)
        v = load_param(v, f"layer {i} v", dtype)
        q_per_kv = config.n_head // config.n_query_groups
        qs = torch.split(q, config.head_size * q_per_kv)
        ks = torch.split(k, config.head_size)
        vs = torch.split(v, config.head_size)
        cycled = [t for group in zip(qs, ks, vs) for t in group]
        qkv = torch.cat(cycled)
        state_dict[f"transformer.h.{i}.attn.attn.weight"] = qkv
        del qkv_weights[i]

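# NOTE: Lit-GPT stores attention as a single fused `attn.attn.weight` tensor. The loop above
# rebuilds it from HF's separate q/k/v projections by interleaving them per query group: for
# each group, `q_per_kv * head_size` query rows are followed by `head_size` key rows and
# `head_size` value rows. With grouped-/multi-query attention (n_query_groups < n_head),
# several query heads therefore share one key/value pair. Layers whose q/k/v pieces are spread
# across multiple .bin files stay in `qkv_weights` until every piece has been seen.
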
def copy_weights_phi(
    config: Config,
    state_dict: Dict[str, torch.Tensor],
    hf_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]],
    saver: Optional[incremental_save] = None,
    dtype: Optional[torch.dtype] = None,
) -> None:
    if any(layer_name.startswith("layers.") for layer_name in hf_weights):
        raise ValueError(
            "You are using an outdated Phi1.5 checkpoint. "
            "Please reload it as described in 'tutorials/download_phi15.md'"
        )

    weight_map = {
        "transformer.embd.wte.weight": "transformer.wte.weight",
        "transformer.h.{}.ln.bias": "transformer.h.{}.norm_1.bias",
        "transformer.h.{}.ln.weight": "transformer.h.{}.norm_1.weight",
        "transformer.h.{}.mixer.Wqkv.bias": "transformer.h.{}.attn.attn.bias",
        "transformer.h.{}.mixer.Wqkv.weight": "transformer.h.{}.attn.attn.weight",
        "transformer.h.{}.mixer.out_proj.bias": "transformer.h.{}.attn.proj.bias",
        "transformer.h.{}.mixer.out_proj.weight": "transformer.h.{}.attn.proj.weight",
        "transformer.h.{}.mixer.rotary_emb.inv_freq": None,
        "transformer.h.{}.mlp.fc1.bias": "transformer.h.{}.mlp.fc.bias",
        "transformer.h.{}.mlp.fc1.weight": "transformer.h.{}.mlp.fc.weight",
        "transformer.h.{}.mlp.fc2.bias": "transformer.h.{}.mlp.proj.bias",
        "transformer.h.{}.mlp.fc2.weight": "transformer.h.{}.mlp.proj.weight",
        "lm_head.ln.weight": "transformer.ln_f.weight",
        "lm_head.ln.bias": "transformer.ln_f.bias",
        "lm_head.linear.weight": "lm_head.weight",
        "lm_head.linear.bias": "lm_head.bias",
    }

    for name, param in hf_weights.items():
        if name.startswith("transformer.h."):
            from_name, number = layer_template(name, 2)
            to_name = weight_map[from_name].format(number)
        else:
            to_name = weight_map[name]
        param = load_param(param, name, dtype)
        if "Wqkv" in name:
            q_per_kv = config.n_head // config.n_query_groups
            total_qkv = q_per_kv + 2  # each group has 1+ queries, 1 key, and 1 value
            param = param.view(total_qkv, config.n_query_groups, -1).transpose(0, 1)
            param = param.reshape(config.n_embd * 3, -1)
            if "bias" in name:
                param = param.squeeze()
        if saver is not None:
            param = saver.store_early(param)
        state_dict[to_name] = param

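# NOTE: the `Wqkv` reshaping above assumes HF's fused tensor stacks all query rows, then all
# key rows, then all value rows. The view/transpose/reshape regroups it into the interleaved
# per-group (queries, key, value) layout expected by Lit-GPT; for Phi-1.5, where
# n_head == n_query_groups, this amounts to per-head [q, k, v] interleaving.
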
def layer_template(layer_name: str, idx: int) -> Tuple[str, int]:
    split = layer_name.split(".")
    number = int(split[idx])
    split[idx] = "{}"
    from_name = ".".join(split)
    return from_name, number

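# Example: layer_template("model.layers.3.self_attn.q_proj.weight", 2) returns
# ("model.layers.{}.self_attn.q_proj.weight", 3), i.e. a templated key usable for
# `weight_map` lookups together with the extracted layer index.
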
def load_param(param: Union[torch.Tensor, NotYetLoadedTensor], name: str, dtype: Optional[torch.dtype]) -> torch.Tensor:
    if hasattr(param, "_load_tensor"):
        # support tensors loaded via `lazy_load()`
        print(f"Loading {name!r} into RAM")
        param = param._load_tensor()
    if dtype is not None and type(dtype) is not NotYetLoadedTensor and dtype != param.dtype:
        print(f"Converting {name!r} from {param.dtype} to {dtype}")
        param = param.to(dtype)
    return param

@torch.inference_mode()
def convert_hf_checkpoint(
    *,
    checkpoint_dir: Path = Path("checkpoints/stabilityai/stablelm-base-alpha-3b"),
    model_name: Optional[str] = None,
    dtype: Optional[str] = None,
) -> None:
    if model_name is None:
        model_name = checkpoint_dir.name
    if dtype is not None:
        dtype = getattr(torch, dtype)

    config = Config.from_name(model_name)
    config_dict = asdict(config)
    print(f"Model config {config_dict}")
    with open(checkpoint_dir / "lit_config.json", "w") as json_config:
        json.dump(config_dict, json_config)

    if "falcon" in model_name:
        copy_fn = partial(copy_weights_falcon, model_name)
    elif config._mlp_class == "LLaMAMLP":
        # holder to reconstitute the split q, k, v
        qkv_weights = {}
        copy_fn = partial(copy_weights_hf_llama, config, qkv_weights)
    elif "phi" in model_name:
        copy_fn = partial(copy_weights_phi, config)
    else:
        copy_fn = copy_weights_gpt_neox

    # initialize a new empty state dict to hold our new weights
    sd = {}

    # Load the json file containing the weight mapping
    pytorch_bin_map_json_path = checkpoint_dir / "pytorch_model.bin.index.json"
    if pytorch_bin_map_json_path.is_file():  # not all checkpoints have this file
        with open(pytorch_bin_map_json_path) as json_map:
            bin_index = json.load(json_map)
        bin_files = {checkpoint_dir / bin for bin in bin_index["weight_map"].values()}
    else:
        bin_files = set(checkpoint_dir.glob("*.bin"))
        # some checkpoints serialize the training arguments
        bin_files = {f for f in bin_files if f.name != "training_args.bin"}
    if not bin_files:
        raise ValueError(f"Expected {str(checkpoint_dir)!r} to contain .bin files")

    with incremental_save(checkpoint_dir / "lit_model.pth") as saver:
        # for checkpoints that split the q, k, v projections across several .bin files, the
        # shared `qkv_weights` holder lets `copy_fn` reassemble them once all pieces are loaded
        for bin_file in sorted(bin_files):
            print("Processing", bin_file)
            hf_weights = lazy_load(bin_file)
            copy_fn(sd, hf_weights, saver=saver, dtype=dtype)
        gc.collect()
        print("Saving converted checkpoint")
        saver.save(sd)


if __name__ == "__main__":
    from jsonargparse import CLI

    CLI(convert_hf_checkpoint)
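
# Example usage (the flags mirror the keyword-only arguments of `convert_hf_checkpoint`;
# the checkpoint path below is illustrative):
#
#   python scripts/convert_hf_checkpoint.py \
#       --checkpoint_dir checkpoints/stabilityai/stablelm-base-alpha-3b \
#       --dtype bfloat16
#
# The converted weights are written to `lit_model.pth` (and the model config to
# `lit_config.json`) inside `checkpoint_dir`.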