|
a |
|
b/generate/base.py |
|
|
1 |
## This file is adapted from the original file from the Lightning-AI lit-gpt repository: https://github.com/Lightning-AI/lit-gpt |
|
|
2 |
|
|
|
3 |
import sys |
|
|
4 |
import time |
|
|
5 |
from pathlib import Path |
|
|
6 |
from typing import Any, Literal, Optional |
|
|
7 |
|
|
|
8 |
import lightning as L |
|
|
9 |
import torch |
|
|
10 |
import torch._dynamo.config |
|
|
11 |
import torch._inductor.config |
|
|
12 |
from lightning.fabric.plugins import BitsandbytesPrecision |
|
|
13 |
from lightning.fabric.strategies import FSDPStrategy |
|
|
14 |
|
|
|
15 |
# support running without installing as a package:
# append the parent of this file's directory to the import path (presumably the directory that holds `lit_gpt`)
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))
|
|
18 |
|
|
|
19 |
from lit_gpt import GPT, Config, Tokenizer |
|
|
20 |
from lit_gpt.model import Block |
|
|
21 |
from lit_gpt.utils import ( |
|
|
22 |
check_valid_checkpoint_dir, |
|
|
23 |
get_default_supported_precision, |
|
|
24 |
gptq_quantization, |
|
|
25 |
load_checkpoint, |
|
|
26 |
) |
|
|
27 |
|
|
|
28 |
|
|
|
29 |
def multinomial_num_samples_1(probs: torch.Tensor) -> torch.Tensor:
    """Draw a single index from ``probs`` along its last dimension.

    Outside of a ``torch._dynamo`` compilation this simply calls ``torch.multinomial(probs, num_samples=1)``.
    While compiling, it divides ``probs`` by exponential noise and takes the argmax instead.

    Args:
        probs: Tensor of sampling weights.

    Returns:
        Tensor holding the sampled index, with the sampled dimension kept (size 1).
    """
    if not torch._dynamo.is_compiling():
        return torch.multinomial(probs, num_samples=1)
    # Faster alternative to `torch.multinomial(probs, num_samples=1)` that is also CUDAGraph friendly
    noise = torch.empty_like(probs).exponential_(1)
    scaled = probs / noise
    return torch.argmax(scaled, dim=-1, keepdim=True)
|
|
35 |
|
|
|
36 |
|
|
|
37 |
def sample(logits: torch.Tensor, temperature: float = 1.0, top_k: Optional[int] = None) -> torch.Tensor:
    """Pick one token index from the logits found at ``logits[0, -1]``.

    Args:
        logits: Model output; only the slice ``logits[0, -1]`` is used
            (presumably the last position of the first batch item -- TODO confirm against the model).
        temperature: When greater than zero, the logits are divided by it, turned into probabilities with a
            softmax and sampled from. Otherwise the index of the largest logit is returned.
        top_k: If given, every logit outside the ``top_k`` largest ones is set to ``-inf`` first.

    Returns:
        Tensor with a single element: the chosen index.
    """
    last_logits = logits[0, -1]

    # optionally keep only the top k candidates
    if top_k is not None:
        k = min(top_k, last_logits.size(-1))
        top_values, top_indices = torch.topk(last_logits, k)
        # do not use `torch.where` as in nanogpt because it will repeat top-k collisions
        masked = torch.full_like(last_logits, float("-inf"))
        last_logits = masked.scatter_(-1, top_indices, top_values)

    # greedy choice whenever the temperature is not strictly positive
    if not temperature > 0.0:
        return torch.argmax(last_logits, dim=-1, keepdim=True)

    # otherwise scale the logits and sample from the resulting distribution
    scaled = last_logits / temperature
    probs = torch.nn.functional.softmax(scaled, dim=-1)
    return multinomial_num_samples_1(probs)
|
|
49 |
|
|
|
50 |
|
|
|
51 |
def next_token(model: GPT, input_pos: torch.Tensor, x: torch.Tensor, **kwargs: Any) -> torch.Tensor:
    """Run ``model`` on ``x`` at ``input_pos`` and sample one token from the resulting logits.

    Args:
        model: The model to call as ``model(x, input_pos)``.
        input_pos: Positions passed to the model as its second argument.
        x: Input token tensor passed to the model as its first argument.
        **kwargs: Forwarded unchanged to ``sample``.

    Returns:
        The sampled token, cast to the type of ``x``.
    """
    sampled = sample(model(x, input_pos), **kwargs)
    return sampled.type_as(x)
|
|
55 |
|
|
|
56 |
|
|
|
57 |
@torch.inference_mode()
def generate(
    model: GPT,
    prompt: torch.Tensor,
    max_returned_tokens: int,
    *,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    eos_id: Optional[int] = None,
) -> torch.Tensor:
    """Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.

    The implementation of this function is modified from A. Karpathy's nanoGPT.

    Args:
        model: The model to use.
        prompt: Tensor of shape (T) with indices of the prompt sequence.
        max_returned_tokens: The maximum number of tokens to return (given plus generated).
        temperature: Scales the predicted logits by 1 / temperature.
        top_k: If specified, only sample among the tokens with the k highest probabilities.
        eos_id: If specified, stop generating any more token once the <eos> token is triggered. The <eos> token
            itself is kept in the returned sequence. This also applies when the very first generated token is
            <eos>.

    Returns:
        The prompt followed by the generated tokens, concatenated into a single tensor. It holds fewer than
        ``max_returned_tokens`` entries when generation stopped early on ``eos_id``.

    Raises:
        AssertionError: If ``max_returned_tokens`` is not larger than the prompt length.
        NotImplementedError: If ``model.max_seq_length`` is smaller than ``max_returned_tokens - 1``.
    """
    T = prompt.size(0)
    assert max_returned_tokens > T
    if model.max_seq_length < max_returned_tokens - 1:
        # rolling the kv cache based on the `input_pos` value would be necessary. However, doing so would introduce a
        # data dependency on the `input_pos` tensor and impact model compilation. Since this setting is uncommon, we do
        # not support it to avoid negatively impacting the overall speed
        raise NotImplementedError(f"max_seq_length {model.max_seq_length} needs to be >= {max_returned_tokens - 1}")

    device = prompt.device
    tokens = [prompt]
    # position at which the first generated token will be fed back into the model
    input_pos = torch.tensor([T], device=device)

    # first step: feed the whole prompt at positions 0..T-1 and sample the first new token
    # NOTE(review): the `.clone()` presumably protects the token from buffer reuse when `next_token` is compiled
    # with CUDA graphs -- confirm before removing
    token = next_token(
        model, torch.arange(0, T, device=device), prompt.view(1, -1), temperature=temperature, top_k=top_k
    ).clone()
    tokens.append(token)
    # the first generated token must honour `eos_id` too, otherwise generation runs past an immediate <eos>
    if eos_id is not None and token == eos_id:
        return torch.cat(tokens)

    # remaining steps: feed one token at a time, advancing the position after each step
    for _ in range(2, max_returned_tokens - T + 1):
        token = next_token(model, input_pos, token.view(1, -1), temperature=temperature, top_k=top_k).clone()
        tokens.append(token)
        if eos_id is not None and token == eos_id:
            break
        input_pos = input_pos.add_(1)
    return torch.cat(tokens)
|
|
101 |
|
|
|
102 |
|
|
|
103 |
def main(
    prompt: str = "What food do llamas eat?",
    *,
    num_samples: int = 1,
    max_new_tokens: int = 50,
    top_k: Optional[int] = 200,
    temperature: float = 0.8,
    checkpoint_dir: Path = Path("checkpoints/stabilityai/stablelm-base-alpha-3b"),
    quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8", "gptq.int4"]] = None,
    strategy: str = "auto",
    devices: int = 1,
    precision: Optional[str] = None,
    compile: bool = False,
) -> None:
    """Generates text samples based on a pre-trained model and tokenizer.

    The generated text is printed through ``fabric.print``; timing and memory information goes to stderr.

    Args:
        prompt: The prompt string to use for generating the samples.
        num_samples: The number of text samples to generate.
        max_new_tokens: The number of generation steps to take.
        top_k: The number of top most probable tokens to consider in the sampling process.
        temperature: A value controlling the randomness of the sampling process. Higher values result in more random
            samples.
        checkpoint_dir: The checkpoint directory to load.
        quantize: Whether to quantize the model and using which method:
            - bnb.nf4, bnb.nf4-dq, bnb.fp4, bnb.fp4-dq: 4-bit quantization from bitsandbytes
            - bnb.int8: 8-bit quantization from bitsandbytes
            - gptq.int4: 4-bit quantization from GPTQ
            for more details, see https://github.com/Lightning-AI/lit-gpt/blob/main/tutorials/quantize.md
        strategy: Indicates the Fabric strategy setting to use.
        devices: How many devices to use.
        precision: Indicates the Fabric precision setting to use.
        compile: Whether to compile the model.

    Raises:
        NotImplementedError: If ``quantize`` is set together with ``devices > 1``.
        ValueError: If a ``bnb.*`` quantization is combined with a "mixed" precision, or if ``gptq.int4`` is
            requested but the quantized checkpoint file is missing.
    """
    # fall back to the default precision reported by the helper when none was given
    precision = precision or get_default_supported_precision(training=False)

    plugins = None
    if quantize is not None:
        if devices > 1:
            # NOTE(review): the message says "training" although this script only generates -- confirm wording
            raise NotImplementedError(
                "Quantization is currently not supported for multi-GPU training. Please set devices=1 when using the"
                " --quantize flag."
            )
        if quantize.startswith("bnb."):
            if "mixed" in precision:
                raise ValueError("Quantization and mixed precision is not supported.")
            # only these three precision strings are mapped; any other value raises a KeyError here
            dtype = {"16-true": torch.float16, "bf16-true": torch.bfloat16, "32-true": torch.float32}[precision]
            # `quantize[4:]` drops the "bnb." prefix, leaving e.g. "nf4" or "int8"
            plugins = BitsandbytesPrecision(quantize[4:], dtype)
            # the dtype now travels with the plugin, so no separate precision is handed to Fabric
            precision = None

    if strategy == "fsdp":
        strategy = FSDPStrategy(auto_wrap_policy={Block}, cpu_offload=False)

    fabric = L.Fabric(devices=devices, precision=precision, strategy=strategy, plugins=plugins)
    fabric.launch()

    check_valid_checkpoint_dir(checkpoint_dir)

    config = Config.from_json(checkpoint_dir / "lit_config.json")

    # pick the weights file: the GPTQ variant must already have been produced by the quantization script
    if quantize == "gptq.int4":
        model_file = "lit_model_gptq.4bit.pth"
        if not (checkpoint_dir / model_file).is_file():
            raise ValueError("Please run `python quantize/gptq.py` first")
    else:
        model_file = "lit_model.pth"
    checkpoint_path = checkpoint_dir / model_file

    tokenizer = Tokenizer(checkpoint_dir)
    encoded = tokenizer.encode(prompt, device=fabric.device)
    prompt_length = encoded.size(0)
    # total sequence budget: prompt tokens plus the requested number of new tokens
    max_returned_tokens = prompt_length + max_new_tokens

    fabric.print(f"Loading model {str(checkpoint_path)!r} with {config.__dict__}", file=sys.stderr)
    t0 = time.perf_counter()
    with fabric.init_module(empty_init=True), gptq_quantization(quantize == "gptq.int4"):
        model = GPT(config)
    fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr)
    with fabric.init_tensor():
        # set the max_seq_length to limit the memory usage to what we need
        model.max_seq_length = max_returned_tokens
        # enable the kv cache
        model.set_kv_cache(batch_size=1)
    model.eval()

    if compile:
        torch._dynamo.config.automatic_dynamic_shapes = True
        torch._inductor.config.triton.unique_kernel_names = True
        torch._inductor.config.coordinate_descent_tuning = True
        # rebind the module-level `next_token` so that `generate` picks up the compiled version
        global next_token
        next_token = torch.compile(next_token, mode="reduce-overhead")

    model = fabric.setup_module(model)

    t0 = time.perf_counter()
    load_checkpoint(fabric, model, checkpoint_path)
    fabric.print(f"Time to load the model weights: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr)

    # fixed seed, set once before the sampling loop
    L.seed_everything(1234)
    for i in range(num_samples):
        t0 = time.perf_counter()
        # `eos_id` is not passed, so `generate` never stops early and always returns `max_returned_tokens` tokens
        y = generate(model, encoded, max_returned_tokens, temperature=temperature, top_k=top_k)
        t = time.perf_counter() - t0
        # reset every block's kv cache before the next sample
        for block in model.transformer.h:
            block.attn.kv_cache.reset_parameters()
        fabric.print(tokenizer.decode(y))
        # only the newly generated tokens count towards the throughput figure
        tokens_generated = y.size(0) - prompt_length
        fabric.print(
            f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec", file=sys.stderr
        )
    if fabric.device.type == "cuda":
        fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB", file=sys.stderr)
|
|
215 |
|
|
|
216 |
|
|
|
217 |
if __name__ == "__main__":
    # configure float32 matmul precision before any model code runs
    torch.set_float32_matmul_precision("high")

    # imported here so that jsonargparse is only needed when the file is run as a script
    from jsonargparse import CLI

    # expose `main`'s parameters as command line arguments
    CLI(main)