Diff of /scripts/download.py [000000] .. [248dc9]

Switch to side-by-side view

--- a
+++ b/scripts/download.py
@@ -0,0 +1,85 @@
+## This script is adapted from https://github.com/Lightning-AI/lit-gpt
+## This script is used to download checkpoints from the HuggingFace Hub.
+
+import os
+import sys
+from pathlib import Path
+from typing import Optional
+
+import torch
+from lightning_utilities.core.imports import RequirementCache
+
# Support running the script without installing the repo as a package:
# add the repository root (the parent of scripts/) to sys.path so that
# `lit_gpt` can be imported inside `download_from_hub`.
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))

# Lazy requirement check for the optional `safetensors` dependency; it is
# only needed when downloading with `--from_safetensors`.
_SAFETENSORS_AVAILABLE = RequirementCache("safetensors")
+
+
def download_from_hub(
    repo_id: Optional[str] = None,
    access_token: Optional[str] = os.getenv("HF_TOKEN"),
    from_safetensors: bool = False,
    tokenizer_only: bool = False,
    checkpoint_dir: Path = Path("checkpoints"),
) -> None:
    """Download model checkpoint files from the HuggingFace Hub.

    Args:
        repo_id: HuggingFace repository id (``org/name``). When ``None``, the
            list of supported repos (from ``lit_gpt.config``) is printed and
            the function returns without downloading anything.
        access_token: HuggingFace API token for gated repos. Defaults to the
            value of the ``HF_TOKEN`` environment variable at import time.
        from_safetensors: download ``*.safetensors`` weight files and convert
            them to PyTorch ``.bin`` files instead of downloading ``.bin``
            files directly.
        tokenizer_only: download only the tokenizer and generation-config
            files, skipping the weights.
        checkpoint_dir: local root directory; files are placed in
            ``checkpoint_dir / repo_id``.

    Raises:
        ValueError: for gated repos without a token, or when
            ``from_safetensors`` is combined with ``tokenizer_only``.
        ModuleNotFoundError: when ``from_safetensors`` is requested but the
            ``safetensors`` package is not installed.
        RuntimeError: when a downloaded safetensors file cannot be loaded.
    """
    if repo_id is None:
        from lit_gpt.config import configs

        options = [f"{config['hf_config']['org']}/{config['hf_config']['name']}" for config in configs]
        print("Please specify --repo_id <repo_id>. Available values:")
        print("\n".join(options))
        return

    from huggingface_hub import snapshot_download

    # Gated repos require authentication; fail early with a helpful hint
    # instead of letting the download fail with an opaque HTTP error.
    if ("meta-llama" in repo_id or "falcon-180" in repo_id) and not access_token:
        raise ValueError(
            f"{repo_id} requires authentication, please set the `HF_TOKEN=your_token` environment"
            " variable or pass --access_token=your_token. You can find your token by visiting"
            " https://huggingface.co/settings/tokens"
        )

    download_files = ["tokenizer*", "generation_config.json"]
    if not tokenizer_only:
        if from_safetensors:
            if not _SAFETENSORS_AVAILABLE:
                raise ModuleNotFoundError(str(_SAFETENSORS_AVAILABLE))
            download_files.append("*.safetensors")
        else:
            # covers `.bin` files and `.bin.index.json`
            download_files.append("*.bin*")
    elif from_safetensors:
        raise ValueError("`--from_safetensors=True` won't have an effect with `--tokenizer_only=True`")

    directory = checkpoint_dir / repo_id
    snapshot_download(
        repo_id,
        local_dir=directory,
        local_dir_use_symlinks=False,
        resume_download=True,
        allow_patterns=download_files,
        token=access_token,
    )

    # Convert safetensors to PyTorch binaries so downstream lit-gpt scripts
    # can load them with `torch.load`.
    if from_safetensors:
        _convert_safetensors_to_bin(directory)


def _convert_safetensors_to_bin(directory: Path) -> None:
    """Convert every ``*.safetensors`` file in ``directory`` to ``.bin``.

    Each source file is deleted only after the converted ``.bin`` file has
    been written successfully, so a failed conversion never loses data.

    Raises:
        RuntimeError: when a safetensors file fails to load (likely a
            corrupted download).
    """
    from safetensors import SafetensorError
    from safetensors.torch import load_file as safetensors_load

    print("Converting .safetensor files to PyTorch binaries (.bin)")
    for safetensor_path in directory.glob("*.safetensors"):
        bin_path = safetensor_path.with_suffix(".bin")
        try:
            result = safetensors_load(safetensor_path)
        except SafetensorError as e:
            raise RuntimeError(f"{safetensor_path} is likely corrupted. Please try to re-download it.") from e
        print(f"{safetensor_path} --> {bin_path}")
        torch.save(result, bin_path)
        # Remove the source only after the .bin was written.
        safetensor_path.unlink()
+
+
if __name__ == "__main__":
    # jsonargparse builds a command-line interface from the function
    # signature, exposing each parameter as a flag (e.g. --repo_id,
    # --from_safetensors) — see the --repo_id hint printed by the function.
    from jsonargparse import CLI

    CLI(download_from_hub)