# lit_gpt/utils.py

"""Utility functions for training and inference."""
import math
import pickle
import sys
from contextlib import nullcontext
from io import BytesIO
from pathlib import Path
from typing import TYPE_CHECKING, ContextManager, Dict, List, Mapping, Optional, TypeVar, Union

import lightning as L
import torch
import torch.nn as nn
import torch.utils._device
from lightning.fabric.strategies import FSDPStrategy
from lightning.fabric.utilities.load import _lazy_load as lazy_load
from torch.serialization import normalize_storage_type

if TYPE_CHECKING:
    from lit_gpt import GPT


def find_multiple(n: int, k: int) -> int:
    assert k > 0
    if n % k == 0:
        return n
    return n + k - (n % k)

|
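# Example (illustrative): `find_multiple` rounds `n` up to the nearest multiple of `k`, e.g. when
# padding a vocabulary size so GPU kernels can use more efficient tile sizes:
#   find_multiple(50257, 64)  # -> 50304, since 50257 % 64 == 17 and 50257 + 64 - 17 == 50304
#   find_multiple(50304, 64)  # -> 50304, already a multiple
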
def num_parameters(module: nn.Module, requires_grad: Optional[bool] = None) -> int:
    total = 0
    for p in module.parameters():
        if requires_grad is None or p.requires_grad == requires_grad:
            if hasattr(p, "quant_state"):
                # bitsandbytes 4bit layer support
                total += math.prod(p.quant_state[1])
            else:
                total += p.numel()
    return total

|
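# Example (illustrative): `requires_grad=None` counts every parameter, while `True`/`False`
# separates trainable from frozen parameters, e.g. when reporting the size of an adapter/LoRA run:
#   trainable = num_parameters(model, requires_grad=True)
#   frozen = num_parameters(model, requires_grad=False)
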
def gptq_quantization(enabled: bool = False) -> ContextManager:
    if not enabled:
        return nullcontext()

    from lightning.fabric.plugins.precision.utils import _ClassReplacementContextManager

    from quantize.gptq import ColBlockQuantizedLinear

    class QuantizedLinear(ColBlockQuantizedLinear):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, bits=4, tile_cols=-1, **kwargs)

    return _ClassReplacementContextManager({"torch.nn.Linear": QuantizedLinear})

|
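# Example (illustrative sketch): the returned context manager swaps `torch.nn.Linear` for the
# 4-bit `ColBlockQuantizedLinear` while the model is being instantiated, so a GPTQ-quantized
# checkpoint can be loaded into matching layers:
#   with fabric.init_module(empty_init=True), gptq_quantization(enabled=True):
#       model = GPT(config)
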
def check_valid_checkpoint_dir(checkpoint_dir: Path) -> None:
    files = {
        "lit_model.pth": (checkpoint_dir / "lit_model.pth").is_file(),
        "lit_config.json": (checkpoint_dir / "lit_config.json").is_file(),
        "tokenizer.json OR tokenizer.model": (checkpoint_dir / "tokenizer.json").is_file() or (
            checkpoint_dir / "tokenizer.model"
        ).is_file(),
        "tokenizer_config.json": (checkpoint_dir / "tokenizer_config.json").is_file(),
    }
    if checkpoint_dir.is_dir():
        if all(files.values()):
            # we're good
            return
        problem = f" is missing the files: {[f for f, exists in files.items() if not exists]!r}"
    else:
        problem = " is not a checkpoint directory"

    # list locally available checkpoints
    available = list(Path("checkpoints").glob("*/*"))
    if available:
        options = "\n --checkpoint_dir ".join([""] + [repr(str(p.resolve())) for p in available])
        extra = f"\nYou have downloaded locally:{options}\n"
    else:
        extra = ""

    error_message = (
        f"--checkpoint_dir {str(checkpoint_dir.absolute())!r}{problem}."
        "\nFind download instructions at https://github.com/Lightning-AI/lit-gpt/blob/main/tutorials\n"
        f"{extra}\nSee all download options by running:\n python scripts/download.py"
    )
    print(error_message, file=sys.stderr)
    raise SystemExit(1)

|
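# Expected layout (illustrative), matching the files checked above:
#   checkpoints/<org>/<model>/
#       lit_model.pth
#       lit_config.json
#       tokenizer.json (or tokenizer.model)
#       tokenizer_config.json
# Anything else exits with download instructions printed to stderr.
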
class SavingProxyForStorage:
    def __init__(self, obj, saver, protocol_version=5):
        self.protocol_version = protocol_version
        self.saver = saver
        if not (isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj)):
            raise TypeError(f"expected storage, not {type(obj)}")

        # this logic is taken from PyTorch 2.0+ torch/serialization.py
        if isinstance(obj, torch.storage.TypedStorage):
            # PT upstream wants to deprecate this eventually...
            storage = obj._untyped_storage
            storage_type_str = obj._pickle_storage_type()
            storage_type = getattr(torch, storage_type_str)
            storage_numel = obj._size()
        else:
            storage = obj
            storage_type = normalize_storage_type(type(obj))
            storage_numel = storage.nbytes()

        storage_key = saver._write_storage_and_return_key(storage)
        location = torch.serialization.location_tag(storage)

        self.storage_info = ("storage", storage_type, storage_key, location, storage_numel)

    def __reduce_ex__(self, protocol_version):
        assert False, "this should be handled out of band"


class SavingProxyForTensor:
    def __init__(self, tensor, saver, protocol_version=5):
        self.protocol_version = protocol_version
        self.reduce_ret_fn, reduce_args = tensor.__reduce_ex__(protocol_version)
        if reduce_args[0] == torch._utils._rebuild_tensor_v2:
            # for Tensors with Python attributes
            (a0, a1, (storage, *a2_other), *other_reduce_args) = reduce_args
            assert isinstance(storage, torch.storage.TypedStorage), "Please check for updates"
            storage_proxy = SavingProxyForStorage(storage, saver, protocol_version=protocol_version)
            self.reduce_args = (a0, a1, (storage_proxy, *a2_other), *other_reduce_args)
        else:
            (storage, *other_reduce_args) = reduce_args
            assert isinstance(storage, torch.storage.TypedStorage), "Please check for updates"
            storage_proxy = SavingProxyForStorage(storage, saver, protocol_version=protocol_version)
            self.reduce_args = (storage_proxy, *other_reduce_args)

    def __reduce_ex__(self, protocol_version):
        if protocol_version != self.protocol_version:
            raise RuntimeError(f"Unexpected protocol version: expected {self.protocol_version}, got {protocol_version}")
        return self.reduce_ret_fn, self.reduce_args


class IncrementalPyTorchPickler(pickle.Pickler):
    def __init__(self, saver, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.storage_dtypes = {}
        self.saver = saver
        self.id_map = {}

    # this logic is taken from PyTorch 2.0+ torch/serialization.py
    def persistent_id(self, obj):
        # FIXME: the docs say that persistent_id should only return a string
        # but torch store returns tuples. This works only in the binary protocol
        # see
        # https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects
        # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537
        if isinstance(obj, SavingProxyForStorage):
            return obj.storage_info

        if isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj):
            if isinstance(obj, torch.storage.TypedStorage):
                # TODO: Once we decide to break serialization FC, this case
                # can be deleted
                storage = obj._untyped_storage
                storage_dtype = obj.dtype
                storage_type_str = obj._pickle_storage_type()
                storage_type = getattr(torch, storage_type_str)
                storage_numel = obj._size()

            else:
                storage = obj
                storage_dtype = torch.uint8
                storage_type = normalize_storage_type(type(obj))
                storage_numel = storage.nbytes()

            # If storage is allocated, ensure that any other saved storages
            # pointing to the same data all have the same dtype. If storage is
            # not allocated, don't perform this check
            if storage.data_ptr() != 0:
                if storage.data_ptr() in self.storage_dtypes:
                    if storage_dtype != self.storage_dtypes[storage.data_ptr()]:
                        raise RuntimeError(
                            "Cannot save multiple tensors or storages that view the same data as different types"
                        )
                else:
                    self.storage_dtypes[storage.data_ptr()] = storage_dtype

            storage_key = self.id_map.get(storage._cdata)
            if storage_key is None:
                storage_key = self.saver._write_storage_and_return_key(storage)
                self.id_map[storage._cdata] = storage_key
            location = torch.serialization.location_tag(storage)

            return ("storage", storage_type, storage_key, location, storage_numel)

        return None


class incremental_save:
    def __init__(self, name):
        self.name = name
        self.zipfile = torch._C.PyTorchFileWriter(str(name))
        self.has_saved = False
        self.next_key = 0

    def __enter__(self):
        return self

    def store_early(self, tensor):
        if isinstance(tensor, torch.Tensor):
            return SavingProxyForTensor(tensor, self)
        raise TypeError(f"can only store tensors early, not {type(tensor)}")

    def save(self, obj):
        if self.has_saved:
            raise RuntimeError("have already saved")
        # Write the pickle data for `obj`
        data_buf = BytesIO()
        pickler = IncrementalPyTorchPickler(self, data_buf, protocol=5)
        pickler.dump(obj)
        data_value = data_buf.getvalue()
        self.zipfile.write_record("data.pkl", data_value, len(data_value))
        self.has_saved = True

    def _write_storage_and_return_key(self, storage):
        if self.has_saved:
            raise RuntimeError("have already saved")
        key = self.next_key
        self.next_key += 1
        name = f"data/{key}"
        if storage.device.type != "cpu":
            storage = storage.cpu()
        num_bytes = storage.nbytes()
        self.zipfile.write_record(name, storage.data_ptr(), num_bytes)
        return key

    def __exit__(self, type, value, traceback):
        self.zipfile.write_end_of_file()

|
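# Example (illustrative sketch; `converted_weights` is a hypothetical dict of tensors):
# `incremental_save` writes a `torch.load`-compatible zip archive, streaming each tensor's storage
# to disk as soon as it is registered instead of keeping the whole state dict in memory:
#   sd = {}
#   with incremental_save("lit_model.pth") as saver:
#       for name, param in converted_weights.items():
#           sd[name] = saver.store_early(param)
#       saver.save(sd)
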
T = TypeVar("T")


def chunked_cross_entropy(
    logits: Union[torch.Tensor, List[torch.Tensor]], targets: torch.Tensor, chunk_size: int = 128
) -> torch.Tensor:
    # with large max_sequence_lengths, the beginning of `backward` allocates a large memory chunk that can dominate
    # the memory usage in fine-tuning settings with a low number of trainable parameters.
    # as a workaround, the cross entropy computation is chunked to force it to deallocate on the go, reducing
    # the memory spike's magnitude

    # lm_head was chunked (we are fine-tuning)
    if isinstance(logits, list):
        # don't want to chunk cross entropy
        if chunk_size == 0:
            logits = torch.cat(logits, dim=1)
            logits = logits.reshape(-1, logits.size(-1))
            targets = targets.reshape(-1)
            return torch.nn.functional.cross_entropy(logits, targets, ignore_index=-1)

        # chunk cross entropy
        logit_chunks = [logit_chunk.reshape(-1, logit_chunk.size(-1)) for logit_chunk in logits]
        target_chunks = [target_chunk.reshape(-1) for target_chunk in targets.split(logits[0].size(1), dim=1)]
        loss_chunks = [
            torch.nn.functional.cross_entropy(logit_chunk, target_chunk, ignore_index=-1, reduction="none")
            for logit_chunk, target_chunk in zip(logit_chunks, target_chunks)
        ]
        non_masked_elems = (targets != -1).sum()
        mean_loss = torch.cat(loss_chunks).sum() / max(1, non_masked_elems)
        return mean_loss

    # no chunking at all
    logits = logits.reshape(-1, logits.size(-1))
    targets = targets.reshape(-1)
    if chunk_size == 0:
        return torch.nn.functional.cross_entropy(logits, targets, ignore_index=-1)

    # lm_head wasn't chunked, chunk cross entropy
    logit_chunks = logits.split(chunk_size)
    target_chunks = targets.split(chunk_size)
    loss_chunks = [
        torch.nn.functional.cross_entropy(logit_chunk, target_chunk, ignore_index=-1, reduction="none")
        for logit_chunk, target_chunk in zip(logit_chunks, target_chunks)
    ]
    non_masked_elems = (targets != -1).sum()
    mean_loss = torch.cat(loss_chunks).sum() / max(1, non_masked_elems)
    return mean_loss

|
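# Example (illustrative): for logits of shape (B, T, vocab_size) and targets of shape (B, T),
# positions whose target is -1 are ignored and the loss is averaged over the remaining tokens:
#   logits = model(input_ids)                      # (B, T, vocab_size)
#   loss = chunked_cross_entropy(logits, targets)  # scalar, computed in chunks of 128 rows
# Passing `chunk_size=0` falls back to a single, un-chunked cross entropy call.
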
def map_old_state_dict_weights(state_dict: Dict, mapping: Mapping, prefix: str) -> Dict:
    for checkpoint_name, attribute_name in mapping.items():
        full_checkpoint_name = prefix + checkpoint_name
        if full_checkpoint_name in state_dict:
            full_attribute_name = prefix + attribute_name
            state_dict[full_attribute_name] = state_dict.pop(full_checkpoint_name)
    return state_dict

|
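# Example (illustrative, hypothetical key names): renames legacy checkpoint keys in place, e.g.
#   mapping = {"attn.proj.weight": "attn.out_proj.weight"}
#   state_dict = map_old_state_dict_weights(state_dict, mapping, prefix="transformer.h.0.")
# would pop "transformer.h.0.attn.proj.weight" and re-insert it under
# "transformer.h.0.attn.out_proj.weight"; keys not covered by the mapping are left untouched.
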
def get_default_supported_precision(training: bool) -> str:
    """Return the default precision that is supported by the hardware: either `bf16` or `16`.

    Args:
        training: if True, return the `-mixed` variant of the precision (for training);
            otherwise return the `-true` variant (for inference)

    Returns:
        default precision that is suitable for the task and is supported by the hardware
    """
    from lightning.fabric.accelerators import MPSAccelerator

    if MPSAccelerator.is_available() or (torch.cuda.is_available() and not torch.cuda.is_bf16_supported()):
        return "16-mixed" if training else "16-true"
    return "bf16-mixed" if training else "bf16-true"

|
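# Example (illustrative):
#   precision = get_default_supported_precision(training=True)
#   # -> "bf16-mixed" where bfloat16 is supported, "16-mixed" on MPS or older CUDA devices
#   fabric = L.Fabric(devices=1, precision=precision)
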
def load_checkpoint(fabric: L.Fabric, model: nn.Module, checkpoint_path: Path, strict: bool = True) -> None:
    if isinstance(fabric.strategy, FSDPStrategy):
        fabric.load_raw(checkpoint_path, model, strict=strict)
    else:
        state_dict = lazy_load(checkpoint_path)
        state_dict = state_dict.get("model", state_dict)
        model.load_state_dict(state_dict, strict=strict)

|
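# Example (illustrative sketch): with FSDP the checkpoint is loaded through Fabric so the state
# dict ends up in the sharded modules; otherwise it is lazily loaded to avoid an extra full copy
# in CPU memory:
#   model = fabric.setup(model)
#   load_checkpoint(fabric, model, checkpoint_dir / "lit_model.pth")
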
def flops_per_param(max_seq_length: int, n_layer: int, n_embd: int, n_params: int) -> int:
    flops_per_token = 2 * n_params  # each parameter is used for a MAC (2 FLOPS) per network operation
    # this assumes that all samples have a fixed length equal to the block size
    # which is most likely false during finetuning
    flops_per_seq = flops_per_token * max_seq_length
    attn_flops_per_seq = n_layer * 2 * 2 * (n_embd * (max_seq_length**2))
    return flops_per_seq + attn_flops_per_seq

|
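# Worked example (illustrative, hypothetical values): per sequence of length T this returns
#   2 * n_params * T  +  n_layer * 2 * 2 * n_embd * T**2
# i.e. one multiply-accumulate per parameter per token plus the quadratic attention matmuls.
# With n_params=1_000_000, n_layer=2, n_embd=64, T=128:
#   2 * 1_000_000 * 128 + 2 * 4 * 64 * 128**2 = 256_000_000 + 8_388_608 = 264_388_608
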
def estimate_flops(model: "GPT", training: bool) -> int:
    """Measures the estimated FLOPs per sequence, used for computing MFU (model FLOPs utilization).

    Refs:
        * https://ar5iv.labs.arxiv.org/html/2205.05198#A1
        * https://ar5iv.labs.arxiv.org/html/2204.02311#A2
    """
    # using all parameters for this is a naive overestimation because not all model parameters actually contribute to
    # this FLOP computation (e.g. embedding, norm). For this reason, the result will be higher by a fixed percentage
    # (~10%) compared to the measured FLOPs, which are lower but more realistic.
    # For a proper estimate, this needs a more fine-grained calculation as in Appendix A of the referenced papers.
    n_trainable_params = num_parameters(model, requires_grad=True)
    trainable_flops = flops_per_param(
        model.max_seq_length, model.config.n_layer, model.config.n_embd, n_trainable_params
    )
    # forward + backward + gradients (assumes no gradient accumulation)
    ops_per_step = 3 if training else 1
    n_frozen_params = num_parameters(model, requires_grad=False)
    frozen_flops = flops_per_param(model.max_seq_length, model.config.n_layer, model.config.n_embd, n_frozen_params)
    # forward + backward
    frozen_ops_per_step = 2 if training else 1
    return ops_per_step * trainable_flops + frozen_ops_per_step * frozen_flops
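
# Example (illustrative, hypothetical variable names): MFU compares achieved to peak throughput,
#   flops_per_batch = estimate_flops(model, training=True) * batch_size
#   mfu = flops_per_batch * batches_per_second / device_peak_flops_per_second
# where `device_peak_flops_per_second` is the accelerator's advertised peak for the dtype in use.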