## This script is adapted from https://github.com/Lightning-AI/lit-gpt
## This script is used to download checkpoints from the HuggingFace Hub.
import os
import sys
from pathlib import Path
from typing import Optional

import torch
from lightning_utilities.core.imports import RequirementCache

# Support running the script without installing the project as a package:
# put the repository root (two levels up from this file) on sys.path.
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))

# Records whether the optional `safetensors` package is installed; truthy when
# available, and its str() explains the missing requirement when it is not.
_SAFETENSORS_AVAILABLE = RequirementCache("safetensors")
def download_from_hub(
    repo_id: Optional[str] = None,
    access_token: Optional[str] = os.getenv("HF_TOKEN"),
    from_safetensors: bool = False,
    tokenizer_only: bool = False,
    checkpoint_dir: Path = Path("checkpoints"),
) -> None:
    """Download a model checkpoint from the HuggingFace Hub.

    Args:
        repo_id: Hub repository to download, e.g. ``org/name``. When ``None``,
            prints the list of supported repos and returns without downloading.
        access_token: HuggingFace access token; defaults to the ``HF_TOKEN``
            environment variable. Required for gated repos.
        from_safetensors: Download ``*.safetensors`` weights and convert them
            to PyTorch ``.bin`` files instead of downloading ``.bin`` directly.
        tokenizer_only: Download only tokenizer files, skipping the weights.
        checkpoint_dir: Root directory; files are written to
            ``checkpoint_dir / repo_id``.

    Raises:
        ValueError: If a gated repo is requested without a token, or if
            ``from_safetensors`` is combined with ``tokenizer_only``.
        ModuleNotFoundError: If ``from_safetensors`` is set but the
            ``safetensors`` package is not installed.
        RuntimeError: If a downloaded safetensors file cannot be loaded
            (likely corrupted).
    """
    if repo_id is None:
        # No repo requested: list everything the project's config table knows about.
        from lit_gpt.config import configs

        options = [f"{config['hf_config']['org']}/{config['hf_config']['name']}" for config in configs]
        print("Please specify --repo_id <repo_id>. Available values:")
        print("\n".join(options))
        return

    from huggingface_hub import snapshot_download

    # Gated repos require authentication; fail early with a helpful message.
    if ("meta-llama" in repo_id or "falcon-180" in repo_id) and not access_token:
        raise ValueError(
            f"{repo_id} requires authentication, please set the `HF_TOKEN=your_token` environment"
            " variable or pass --access_token=your_token. You can find your token by visiting"
            " https://huggingface.co/settings/tokens"
        )

    # Tokenizer files are always fetched; weight patterns are added on demand.
    download_files = ["tokenizer*", "generation_config.json"]
    if not tokenizer_only:
        if from_safetensors:
            if not _SAFETENSORS_AVAILABLE:
                raise ModuleNotFoundError(str(_SAFETENSORS_AVAILABLE))
            download_files.append("*.safetensors")
        else:
            # covers `.bin` files and `.bin.index.json`
            download_files.append("*.bin*")
    elif from_safetensors:
        raise ValueError("`--from_safetensors=True` won't have an effect with `--tokenizer_only=True`")

    directory = checkpoint_dir / repo_id
    snapshot_download(
        repo_id,
        local_dir=directory,
        local_dir_use_symlinks=False,
        resume_download=True,
        allow_patterns=download_files,
        token=access_token,
    )

    # Convert safetensors to PyTorch binaries; imports are deferred since the
    # package is optional and only needed on this path.
    if from_safetensors:
        from safetensors import SafetensorError
        from safetensors.torch import load_file as safetensors_load

        print("Converting .safetensor files to PyTorch binaries (.bin)")
        for safetensor_path in directory.glob("*.safetensors"):
            bin_path = safetensor_path.with_suffix(".bin")
            try:
                result = safetensors_load(safetensor_path)
            except SafetensorError as e:
                raise RuntimeError(f"{safetensor_path} is likely corrupted. Please try to re-download it.") from e
            print(f"{safetensor_path} --> {bin_path}")
            torch.save(result, bin_path)
            # The safetensors source is deleted once the .bin copy exists.
            os.remove(safetensor_path)
if __name__ == "__main__":
    from jsonargparse import CLI

    # Expose download_from_hub's signature as command-line flags
    # (e.g. --repo_id, --from_safetensors) via jsonargparse.
    CLI(download_from_hub)