## This script is adapted from https://github.com/Lightning-AI/lit-gpt
## This script is used to download checkpoints from the HuggingFace Hub.

import os
import sys
from pathlib import Path
from typing import Optional

import torch
from lightning_utilities.core.imports import RequirementCache

# support running without installing as a package
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))

_SAFETENSORS_AVAILABLE = RequirementCache("safetensors")


def download_from_hub(
    repo_id: Optional[str] = None,
    access_token: Optional[str] = os.getenv("HF_TOKEN"),
    from_safetensors: bool = False,
    tokenizer_only: bool = False,
    checkpoint_dir: Path = Path("checkpoints"),
) -> None:
    """Download a model checkpoint (and tokenizer files) from the HuggingFace Hub.

    Args:
        repo_id: Hub repository to fetch, e.g. ``org/name``. When omitted,
            prints every supported repo id and returns.
        access_token: HF access token; defaults to the ``HF_TOKEN`` env var.
            Required for gated repos (meta-llama, falcon-180).
        from_safetensors: download ``*.safetensors`` weights and convert them
            to PyTorch ``.bin`` files instead of downloading ``.bin`` directly.
        tokenizer_only: fetch only tokenizer/generation-config files.
        checkpoint_dir: root directory under which ``repo_id`` is placed.

    Raises:
        ValueError: missing token for a gated repo, or
            ``from_safetensors`` combined with ``tokenizer_only``.
        ModuleNotFoundError: ``from_safetensors`` requested but the
            ``safetensors`` package is unavailable.
        RuntimeError: a downloaded safetensors file fails to load
            (likely corrupted).
    """
    if repo_id is None:
        # No target given: list everything lit-gpt knows how to download.
        from lit_gpt.config import configs

        available = [f"{c['hf_config']['org']}/{c['hf_config']['name']}" for c in configs]
        print("Please specify --repo_id <repo_id>. Available values:")
        print("\n".join(available))
        return

    from huggingface_hub import snapshot_download

    gated = "meta-llama" in repo_id or "falcon-180" in repo_id
    if gated and not access_token:
        raise ValueError(
            f"{repo_id} requires authentication, please set the `HF_TOKEN=your_token` environment"
            " variable or pass --access_token=your_token. You can find your token by visiting"
            " https://huggingface.co/settings/tokens"
        )

    patterns = ["tokenizer*", "generation_config.json"]
    if tokenizer_only:
        if from_safetensors:
            raise ValueError("`--from_safetensors=True` won't have an effect with `--tokenizer_only=True`")
    else:
        if from_safetensors:
            if not _SAFETENSORS_AVAILABLE:
                raise ModuleNotFoundError(str(_SAFETENSORS_AVAILABLE))
            patterns.append("*.safetensors")
        else:
            # covers `.bin` files and `.bin.index.json`
            patterns.append("*.bin*")

    target_dir = checkpoint_dir / repo_id
    snapshot_download(
        repo_id,
        local_dir=target_dir,
        local_dir_use_symlinks=False,
        resume_download=True,
        allow_patterns=patterns,
        token=access_token,
    )

    # convert safetensors to PyTorch binaries
    if from_safetensors:
        from safetensors import SafetensorError
        from safetensors.torch import load_file as safetensors_load

        print("Converting .safetensor files to PyTorch binaries (.bin)")
        for sf_file in target_dir.glob("*.safetensors"):
            pt_file = sf_file.with_suffix(".bin")
            try:
                state_dict = safetensors_load(sf_file)
            except SafetensorError as e:
                raise RuntimeError(f"{sf_file} is likely corrupted. Please try to re-download it.") from e
            print(f"{sf_file} --> {pt_file}")
            torch.save(state_dict, pt_file)
            # remove the original only after a successful save
            os.remove(sf_file)


if __name__ == "__main__":
    from jsonargparse import CLI

    CLI(download_from_hub)