lit_gpt/utils.py
"""Utility functions for training and inference."""
import math
import pickle
import sys
from contextlib import nullcontext
from io import BytesIO
from pathlib import Path
from typing import TYPE_CHECKING, ContextManager, Dict, List, Mapping, Optional, TypeVar, Union

import lightning as L
import torch
import torch.nn as nn
import torch.utils._device
from lightning.fabric.strategies import FSDPStrategy
from lightning.fabric.utilities.load import _lazy_load as lazy_load
from torch.serialization import normalize_storage_type

if TYPE_CHECKING:
    from lit_gpt import GPT


def find_multiple(n: int, k: int) -> int:
    assert k > 0
    if n % k == 0:
        return n
    return n + k - (n % k)
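
# Illustrative usage (a sketch, not part of the original module): `find_multiple` rounds a size
# up to the next multiple of `k`, e.g. when padding a vocabulary size for GPU efficiency.
# >>> find_multiple(50254, 64)
# 50304
# >>> find_multiple(512, 256)
# 512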
def num_parameters(module: nn.Module, requires_grad: Optional[bool] = None) -> int:
    total = 0
    for p in module.parameters():
        if requires_grad is None or p.requires_grad == requires_grad:
            if hasattr(p, "quant_state"):
                # bitsandbytes 4bit layer support
                total += math.prod(p.quant_state[1])
            else:
                total += p.numel()
    return total
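
# Illustrative usage (a sketch): counting trainable vs. frozen parameters, e.g. when only
# adapter or LoRA weights require gradients during fine-tuning.
# trainable = num_parameters(model, requires_grad=True)
# frozen = num_parameters(model, requires_grad=False)
# total = num_parameters(model)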
def gptq_quantization(enabled: bool = False) -> ContextManager:
    if not enabled:
        return nullcontext()

    from lightning.fabric.plugins.precision.utils import _ClassReplacementContextManager

    from quantize.gptq import ColBlockQuantizedLinear

    class QuantizedLinear(ColBlockQuantizedLinear):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, bits=4, tile_cols=-1, **kwargs)

    return _ClassReplacementContextManager({"torch.nn.Linear": QuantizedLinear})
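
# Sketch of intended usage (assumes the `quantize.gptq` module is importable; `config` is a
# hypothetical model config). While the context manager is active, `torch.nn.Linear` is swapped
# for the 4-bit `ColBlockQuantizedLinear`, so a model instantiated inside the block is built quantized.
# with gptq_quantization(enabled=True):
#     model = GPT(config)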
def check_valid_checkpoint_dir(checkpoint_dir: Path) -> None:
    files = {
        "lit_model.pth": (checkpoint_dir / "lit_model.pth").is_file(),
        "lit_config.json": (checkpoint_dir / "lit_config.json").is_file(),
        "tokenizer.json OR tokenizer.model": (checkpoint_dir / "tokenizer.json").is_file() or (
            checkpoint_dir / "tokenizer.model"
        ).is_file(),
        "tokenizer_config.json": (checkpoint_dir / "tokenizer_config.json").is_file(),
    }
    if checkpoint_dir.is_dir():
        if all(files.values()):
            # we're good
            return
        problem = f" is missing the files: {[f for f, exists in files.items() if not exists]!r}"
    else:
        problem = " is not a checkpoint directory"

    # list locally available checkpoints
    available = list(Path("checkpoints").glob("*/*"))
    if available:
        options = "\n --checkpoint_dir ".join([""] + [repr(str(p.resolve())) for p in available])
        extra = f"\nYou have downloaded locally:{options}\n"
    else:
        extra = ""

    error_message = (
        f"--checkpoint_dir {str(checkpoint_dir.absolute())!r}{problem}."
        "\nFind download instructions at https://github.com/Lightning-AI/lit-gpt/blob/main/tutorials\n"
        f"{extra}\nSee all download options by running:\n python scripts/download.py"
    )
    print(error_message, file=sys.stderr)
    raise SystemExit(1)


class SavingProxyForStorage:
    def __init__(self, obj, saver, protocol_version=5):
        self.protocol_version = protocol_version
        self.saver = saver
        if not (isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj)):
            raise TypeError(f"expected storage, not {type(obj)}")

        # this logic is taken from PyTorch 2.0+ torch/serialization.py
        if isinstance(obj, torch.storage.TypedStorage):
            # PT upstream wants to deprecate this eventually...
            storage = obj._untyped_storage
            storage_type_str = obj._pickle_storage_type()
            storage_type = getattr(torch, storage_type_str)
            storage_numel = obj._size()
        else:
            storage = obj
            storage_type = normalize_storage_type(type(obj))
            storage_numel = storage.nbytes()

        storage_key = saver._write_storage_and_return_key(storage)
        location = torch.serialization.location_tag(storage)

        self.storage_info = ("storage", storage_type, storage_key, location, storage_numel)

    def __reduce_ex__(self, protocol_version):
        assert False, "this should be handled out of band"


class SavingProxyForTensor:
    def __init__(self, tensor, saver, protocol_version=5):
        self.protocol_version = protocol_version
        self.reduce_ret_fn, reduce_args = tensor.__reduce_ex__(protocol_version)
        if reduce_args[0] == torch._utils._rebuild_tensor_v2:
            # for Tensors with Python attributes
            (a0, a1, (storage, *a2_other), *other_reduce_args) = reduce_args
            assert isinstance(storage, torch.storage.TypedStorage), "Please check for updates"
            storage_proxy = SavingProxyForStorage(storage, saver, protocol_version=protocol_version)
            self.reduce_args = (a0, a1, (storage_proxy, *a2_other), *other_reduce_args)
        else:
            (storage, *other_reduce_args) = reduce_args
            assert isinstance(storage, torch.storage.TypedStorage), "Please check for updates"
            storage_proxy = SavingProxyForStorage(storage, saver, protocol_version=protocol_version)
            self.reduce_args = (storage_proxy, *other_reduce_args)

    def __reduce_ex__(self, protocol_version):
        if protocol_version != self.protocol_version:
            raise RuntimeError(f"Unexpected protocol version: expected {self.protocol_version}, got {protocol_version}")
        return self.reduce_ret_fn, self.reduce_args


class IncrementalPyTorchPickler(pickle.Pickler):
    def __init__(self, saver, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.storage_dtypes = {}
        self.saver = saver
        self.id_map = {}

    # this logic is taken from PyTorch 2.0+ torch/serialization.py
    def persistent_id(self, obj):
        # FIXME: the docs say that persistent_id should only return a string
        # but torch store returns tuples. This works only in the binary protocol
        # see
        # https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects
        # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537
        if isinstance(obj, SavingProxyForStorage):
            return obj.storage_info

        if isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj):
            if isinstance(obj, torch.storage.TypedStorage):
                # TODO: Once we decide to break serialization FC, this case
                # can be deleted
                storage = obj._untyped_storage
                storage_dtype = obj.dtype
                storage_type_str = obj._pickle_storage_type()
                storage_type = getattr(torch, storage_type_str)
                storage_numel = obj._size()

            else:
                storage = obj
                storage_dtype = torch.uint8
                storage_type = normalize_storage_type(type(obj))
                storage_numel = storage.nbytes()

            # If storage is allocated, ensure that any other saved storages
            # pointing to the same data all have the same dtype. If storage is
            # not allocated, don't perform this check
            if storage.data_ptr() != 0:
                if storage.data_ptr() in self.storage_dtypes:
                    if storage_dtype != self.storage_dtypes[storage.data_ptr()]:
                        raise RuntimeError(
                            "Cannot save multiple tensors or storages that view the same data as different types"
                        )
                else:
                    self.storage_dtypes[storage.data_ptr()] = storage_dtype

            storage_key = self.id_map.get(storage._cdata)
            if storage_key is None:
                storage_key = self.saver._write_storage_and_return_key(storage)
                self.id_map[storage._cdata] = storage_key
            location = torch.serialization.location_tag(storage)

            return ("storage", storage_type, storage_key, location, storage_numel)

        return None


class incremental_save:
    def __init__(self, name):
        self.name = name
        self.zipfile = torch._C.PyTorchFileWriter(str(name))
        self.has_saved = False
        self.next_key = 0

    def __enter__(self):
        return self

    def store_early(self, tensor):
        if isinstance(tensor, torch.Tensor):
            return SavingProxyForTensor(tensor, self)
        raise TypeError(f"can only store tensors early, not {type(tensor)}")

    def save(self, obj):
        if self.has_saved:
            raise RuntimeError("have already saved")
        # Write the pickle data for `obj`
        data_buf = BytesIO()
        pickler = IncrementalPyTorchPickler(self, data_buf, protocol=5)
        pickler.dump(obj)
        data_value = data_buf.getvalue()
        self.zipfile.write_record("data.pkl", data_value, len(data_value))
        self.has_saved = True

    def _write_storage_and_return_key(self, storage):
        if self.has_saved:
            raise RuntimeError("have already saved")
        key = self.next_key
        self.next_key += 1
        name = f"data/{key}"
        if storage.device.type != "cpu":
            storage = storage.cpu()
        num_bytes = storage.nbytes()
        self.zipfile.write_record(name, storage.data_ptr(), num_bytes)
        return key

    def __exit__(self, type, value, traceback):
        self.zipfile.write_end_of_file()
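
# Sketch of how these pieces fit together (illustrative only; `weights` is hypothetical):
# `store_early` streams a tensor's storage into the zip archive immediately and returns a proxy,
# so a converted state dict never needs a second in-memory copy before `save` pickles the
# proxy-containing object itself.
# with incremental_save("lit_model.pth") as saver:
#     state_dict = {name: saver.store_early(param) for name, param in weights.items()}
#     saver.save(state_dict)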
T = TypeVar("T")


def chunked_cross_entropy(
    logits: Union[torch.Tensor, List[torch.Tensor]], targets: torch.Tensor, chunk_size: int = 128
) -> torch.Tensor:
    # with large max_sequence_lengths, the beginning of `backward` allocates a large memory chunk which can dominate
    # the memory usage in fine-tuning settings with a low number of trainable parameters.
    # as a workaround, the cross-entropy computation is chunked so intermediate buffers are freed as we go,
    # reducing the magnitude of the memory spike

    # lm_head was chunked (we are fine-tuning)
    if isinstance(logits, list):
        # don't want to chunk cross entropy
        if chunk_size == 0:
            logits = torch.cat(logits, dim=1)
            logits = logits.reshape(-1, logits.size(-1))
            targets = targets.reshape(-1)
            return torch.nn.functional.cross_entropy(logits, targets, ignore_index=-1)

        # chunk cross entropy
        logit_chunks = [logit_chunk.reshape(-1, logit_chunk.size(-1)) for logit_chunk in logits]
        target_chunks = [target_chunk.reshape(-1) for target_chunk in targets.split(logits[0].size(1), dim=1)]
        loss_chunks = [
            torch.nn.functional.cross_entropy(logit_chunk, target_chunk, ignore_index=-1, reduction="none")
            for logit_chunk, target_chunk in zip(logit_chunks, target_chunks)
        ]
        non_masked_elems = (targets != -1).sum()
        mean_loss = torch.cat(loss_chunks).sum() / max(1, non_masked_elems)
        return mean_loss

    # no chunking at all
    logits = logits.reshape(-1, logits.size(-1))
    targets = targets.reshape(-1)
    if chunk_size == 0:
        return torch.nn.functional.cross_entropy(logits, targets, ignore_index=-1)

    # lm_head wasn't chunked, chunk cross entropy
    logit_chunks = logits.split(chunk_size)
    target_chunks = targets.split(chunk_size)
    loss_chunks = [
        torch.nn.functional.cross_entropy(logit_chunk, target_chunk, ignore_index=-1, reduction="none")
        for logit_chunk, target_chunk in zip(logit_chunks, target_chunks)
    ]
    non_masked_elems = (targets != -1).sum()
    mean_loss = torch.cat(loss_chunks).sum() / max(1, non_masked_elems)
    return mean_loss
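
# Illustrative call (shapes are hypothetical): `logits` of shape (B, T, V), or a list of
# (B, T_chunk, V) pieces when the lm_head output was chunked, and `targets` of shape (B, T)
# with positions to ignore marked as -1.
# loss = chunked_cross_entropy(logits, targets, chunk_size=128)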
def map_old_state_dict_weights(state_dict: Dict, mapping: Mapping, prefix: str) -> Dict:
    for checkpoint_name, attribute_name in mapping.items():
        full_checkpoint_name = prefix + checkpoint_name
        if full_checkpoint_name in state_dict:
            full_attribute_name = prefix + attribute_name
            state_dict[full_attribute_name] = state_dict.pop(full_checkpoint_name)
    return state_dict
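
# Illustrative mapping (key names are hypothetical): renames entries saved under an old
# attribute name to the current one, keeping the caller-supplied layer prefix.
# mapping = {"attn.weight": "attention.weight"}
# state_dict = map_old_state_dict_weights(state_dict, mapping, prefix="transformer.h.0.")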
def get_default_supported_precision(training: bool) -> str:
    """Return the default precision that is supported by the hardware: either `bf16` or `16`.

    Args:
        training: if True, return the `-mixed` variant of the precision; otherwise the `-true` variant

    Returns:
        default precision that is suitable for the task and is supported by the hardware
    """
    from lightning.fabric.accelerators import MPSAccelerator

    if MPSAccelerator.is_available() or (torch.cuda.is_available() and not torch.cuda.is_bf16_supported()):
        return "16-mixed" if training else "16-true"
    return "bf16-mixed" if training else "bf16-true"
def load_checkpoint(fabric: L.Fabric, model: nn.Module, checkpoint_path: Path, strict: bool = True) -> None:
    if isinstance(fabric.strategy, FSDPStrategy):
        fabric.load_raw(checkpoint_path, model, strict=strict)
    else:
        state_dict = lazy_load(checkpoint_path)
        state_dict = state_dict.get("model", state_dict)
        model.load_state_dict(state_dict, strict=strict)
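
# Illustrative usage (config and checkpoint_dir are hypothetical):
# fabric = L.Fabric(devices=1, precision="bf16-true")
# model = fabric.setup(GPT(config))
# load_checkpoint(fabric, model, checkpoint_dir / "lit_model.pth")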
def flops_per_param(max_seq_length: int, n_layer: int, n_embd: int, n_params: int) -> int:
    flops_per_token = 2 * n_params  # each parameter is used for a MAC (2 FLOPs) per network operation
    # this assumes that all samples have a fixed length equal to the block size
    # which is most likely false during finetuning
    flops_per_seq = flops_per_token * max_seq_length
    attn_flops_per_seq = n_layer * 2 * 2 * (n_embd * (max_seq_length**2))
    return flops_per_seq + attn_flops_per_seq
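
# Rough worked example (all numbers are illustrative assumptions): for a ~7B-parameter model
# with n_layer=32, n_embd=4096 and max_seq_length=4096, the parameter term is
# 2 * 7e9 * 4096 ≈ 5.7e13 FLOPs per sequence and the attention term is
# 32 * 2 * 2 * 4096 * 4096**2 ≈ 8.8e12 FLOPs, i.e. attention adds roughly 15% on top.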
def estimate_flops(model: "GPT", training: bool) -> int:
    """Measures estimated FLOPs for MFU.

    Refs:
        * https://ar5iv.labs.arxiv.org/html/2205.05198#A1
        * https://ar5iv.labs.arxiv.org/html/2204.02311#A2
    """
    # using all parameters for this is a naive overestimate because not all model parameters actually contribute to
    # this FLOP computation (e.g. embedding, norm). For this reason, the result will be higher than the measured
    # FLOPs by a roughly fixed percentage (~10%); the measured value is lower but more realistic.
    # For a proper estimate, this needs a more fine-grained calculation as in Appendix A of the paper.
    n_trainable_params = num_parameters(model, requires_grad=True)
    trainable_flops = flops_per_param(
        model.max_seq_length, model.config.n_layer, model.config.n_embd, n_trainable_params
    )
    # forward + backward + gradients (assumes no gradient accumulation)
    ops_per_step = 3 if training else 1
    n_frozen_params = num_parameters(model, requires_grad=False)
    frozen_flops = flops_per_param(model.max_seq_length, model.config.n_layer, model.config.n_embd, n_frozen_params)
    # forward + backward
    frozen_ops_per_step = 2 if training else 1
    return ops_per_step * trainable_flops + frozen_ops_per_step * frozen_flops
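
# Illustrative MFU computation (batch size, timing, and the peak-FLOPs figure are assumptions,
# not part of this module): the estimate is per sequence, so scale by the micro-batch size and
# the measured step rate, then divide by the device's theoretical peak.
# flops_per_step = estimate_flops(model, training=True) * micro_batch_size
# mfu = flops_per_step * steps_per_second / device_peak_flops_per_second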