# lit_gpt/model.py
"""Full definition of a GPT NeoX Language Model, all of it in this single file.
2
3
Based on the nanoGPT implementation: https://github.com/karpathy/nanoGPT and
4
https://github.com/EleutherAI/gpt-neox/tree/main/megatron/model.
5
"""
6
import math
7
from typing import Any, Optional, Tuple
8
9
import torch
10
import torch.nn as nn
11
from typing_extensions import Self
12
13
from lit_gpt.config import Config
14
15
16
class GPT(nn.Module):
    def __init__(self, config: Config) -> None:
        super().__init__()
        assert config.padded_vocab_size is not None
        self.config = config

        self.lm_head = nn.Linear(config.n_embd, config.padded_vocab_size, bias=config.lm_head_bias)
        self.transformer = nn.ModuleDict(
            dict(
                wte=nn.Embedding(config.padded_vocab_size, config.n_embd),
                h=nn.ModuleList(Block(config) for _ in range(config.n_layer)),
                ln_f=config.norm_class(config.n_embd, eps=config.norm_eps),
            )
        )
        self.max_seq_length = self.config.block_size
        self.mask_cache: Optional[torch.Tensor] = None

    @property
    def max_seq_length(self) -> int:
        return self._max_seq_length

    @max_seq_length.setter
    def max_seq_length(self, value: int) -> None:
        """
        When doing inference, the sequences used might be shorter than the model's context length.
        This allows setting a smaller number to avoid allocating unused memory.
        """
        if value > self.config.block_size:
            raise ValueError(f"Cannot attend to {value}, block size is only {self.config.block_size}")
        self._max_seq_length = value
        if not hasattr(self, "cos"):
            # first call
            cos, sin = self.rope_cache()
            self.register_buffer("cos", cos, persistent=False)
            self.register_buffer("sin", sin, persistent=False)
        elif value != self.cos.size(0):
            # override
            self.cos, self.sin = self.rope_cache(device=self.cos.device)
        # The mask and kv cache size will get updated on `set_kv_cache`. We cannot update them here because we
        # don't know whether a kv cache is expected.

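    # Usage note (illustrative numbers, not tied to any particular checkpoint): a model configured with
    # `block_size=2048` that will only ever be prompted with up to 256 tokens can run
    #     model.max_seq_length = 256
    # before inference, so the RoPE buffers (and any kv cache built afterwards) cover 256 positions only.
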
    def reset_parameters(self) -> None:
        # Trigger resetting the rope-cache
        self.max_seq_length = self.config.block_size

    def _init_weights(self, module: nn.Module) -> None:
        """Meant to be used with `gpt.apply(gpt._init_weights)`."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx: torch.Tensor, input_pos: Optional[torch.Tensor] = None) -> torch.Tensor:
        T = idx.size(1)
        if self.max_seq_length < T:
            raise ValueError(f"Cannot forward sequence of length {T}, max seq length is only {self.max_seq_length}.")

        if input_pos is not None:  # use the kv cache
            cos = self.cos.index_select(0, input_pos)
            sin = self.sin.index_select(0, input_pos)
            if self.mask_cache is None:
                raise TypeError("You need to call `gpt.set_kv_cache()`")
            mask = self.mask_cache.index_select(2, input_pos)
        else:
            cos = self.cos[:T]
            sin = self.sin[:T]
            mask = None

        x = self.transformer.wte(idx)  # token embeddings of shape (b, t, n_embd)
        for block in self.transformer.h:
            x = block(x, cos, sin, mask, input_pos)
        x = self.transformer.ln_f(x)
        return self.lm_head(x)  # (b, t, vocab_size)

    @classmethod
    def from_name(cls, name: str, **kwargs: Any) -> Self:
        return cls(Config.from_name(name, **kwargs))

    def rope_cache(self, device: Optional[torch.device] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        return build_rope_cache(
            seq_len=self.max_seq_length,
            n_elem=self.config.rope_n_elem,
            device=device,
            condense_ratio=self.config.rope_condense_ratio,
            base=self.config.rope_base,
        )

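    # Note: `build_rope_cache` returns `cos`/`sin` of shape (max_seq_length, rope_n_elem), which is why
    # `set_kv_cache` below can recover the rotary dimension via `self.cos.size(-1)`.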
    def set_kv_cache(
        self,
        batch_size: int,
        rope_cache_length: Optional[int] = None,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> None:
        if rope_cache_length is None:
            rope_cache_length = self.cos.size(-1)
        max_seq_length = self.max_seq_length

        # initialize the kv cache for all blocks
        for block in self.transformer.h:
            block.attn.kv_cache = block.attn.build_kv_cache(
                batch_size, max_seq_length, rope_cache_length, device, dtype
            )

        if self.mask_cache is None or self.mask_cache.size(3) != max_seq_length:
            # Passing `attn_mask` to SDPA downgrades it to the inefficient implementation. Since the mask is only
            # needed for kv-cache support (i.e. during inference), we only create it in that situation.
            # This will be resolved by https://github.com/pytorch/pytorch/issues/96099
            ones = torch.ones((max_seq_length, max_seq_length), device=device, dtype=torch.bool)
            self.mask_cache = torch.tril(ones).unsqueeze(0).unsqueeze(0)

    def clear_kv_cache(self) -> None:
        self.mask_cache = None
        for block in self.transformer.h:
            block.attn.kv_cache = None

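
# A minimal decoding sketch (illustrative only; "pythia-70m" stands in for any supported config name and the
# greedy loop below is not part of this module):
#
#     model = GPT.from_name("pythia-70m")
#     model.eval()
#     model.set_kv_cache(batch_size=1)
#     idx = torch.tensor([[1, 2, 3]])                      # prompt token ids, shape (1, T)
#     input_pos = torch.arange(idx.size(1))
#     with torch.no_grad():
#         logits = model(idx, input_pos)                   # prefill; caches k/v for positions 0..T-1
#         next_token = logits[:, -1].argmax(dim=-1, keepdim=True)
#         for pos in range(idx.size(1), model.max_seq_length):
#             input_pos = torch.tensor([pos])
#             logits = model(next_token, input_pos)        # one token at a time, reusing the kv cache
#             next_token = logits[:, -1].argmax(dim=-1, keepdim=True)
#     model.clear_kv_cache()
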

class Block(nn.Module):
    def __init__(self, config: Config) -> None:
        super().__init__()
        self.norm_1 = config.norm_class(config.n_embd, eps=config.norm_eps)
        self.attn = CausalSelfAttention(config)
        self.norm_2 = None if config.shared_attention_norm else config.norm_class(config.n_embd, eps=config.norm_eps)
        self.mlp = config.mlp_class(config)

        self.config = config

    def forward(
        self,
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        input_pos: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        n_1 = self.norm_1(x)
        h = self.attn(n_1, cos, sin, mask, input_pos)
        if self.config.parallel_residual:
            n_2 = n_1 if self.config.shared_attention_norm else self.norm_2(x)
            x = self.mlp(n_2) + h + x
        else:
            if self.config.shared_attention_norm:
                raise NotImplementedError(
                    "No checkpoint amongst the ones we support uses this configuration"
                    " (non-parallel residual and shared attention norm)."
                )
            x = h + x
            x = self.mlp(self.norm_2(x)) + x
        return x

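
# For reference, the two residual wirings implemented in `Block.forward`, with h = attn(norm_1(x)):
#     parallel residual:      out = x + h + mlp(norm_2(x))    (norm_2 is norm_1 when `shared_attention_norm` is set)
#     non-parallel residual:  out = (x + h) + mlp(norm_2(x + h))
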

class CausalSelfAttention(nn.Module):
    def __init__(self, config: Config) -> None:
        super().__init__()
        shape = (config.n_head + 2 * config.n_query_groups) * config.head_size
        # key, query, value projections for all heads, but in a batch
        self.attn = nn.Linear(config.n_embd, shape, bias=config.bias)
        # output projection
        self.proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        # disabled by default
        self.kv_cache: Optional[KVCache] = None

        self.config = config

    def forward(
        self,
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        input_pos: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)

        qkv = self.attn(x)

        # assemble into a number of query groups to support MHA, MQA and GQA together (see `config.n_query_groups`)
        q_per_kv = self.config.n_head // self.config.n_query_groups
        total_qkv = q_per_kv + 2  # each group has 1+ queries, 1 key, and 1 value
        qkv = qkv.view(B, T, self.config.n_query_groups, total_qkv, self.config.head_size)
        qkv = qkv.permute(0, 2, 3, 1, 4)  # (B, n_query_groups, total_qkv, T, hs)

        # split batched computation into three
        q, k, v = qkv.split((q_per_kv, 1, 1), dim=2)

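        # Worked example (hypothetical GQA config): with n_head=8, n_query_groups=2, head_size=64, the projection
        # above has out_features (8 + 2 * 2) * 64 = 768, q_per_kv = 4, total_qkv = 6, and after the
        # view/permute/split: q is (B, 2, 4, T, 64) while k and v are each (B, 2, 1, T, 64).
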
        # repeat k and v if necessary for the non-multi-head-attention cases
        # training: flash attention requires it
        # inference: multi-query would require a full kv cache so avoid it to limit its memory usage
        if self.config.n_query_groups != self.config.n_head and (input_pos is None or self.config.n_query_groups != 1):
            k = k.expand(B, self.config.n_query_groups, q_per_kv, T, self.config.head_size)
            v = v.expand(B, self.config.n_query_groups, q_per_kv, T, self.config.head_size)

        q = q.reshape(B, -1, T, self.config.head_size)  # (B, nh_q, T, hs)
        k = k.reshape(B, -1, T, self.config.head_size)  # (B, nh_k, T, hs)
        v = v.reshape(B, -1, T, self.config.head_size)  # (B, nh_v, T, hs)

        q_roped = apply_rope(q[..., : self.config.rope_n_elem], cos, sin)
        k_roped = apply_rope(k[..., : self.config.rope_n_elem], cos, sin)
        q = torch.cat((q_roped, q[..., self.config.rope_n_elem :]), dim=-1)
        k = torch.cat((k_roped, k[..., self.config.rope_n_elem :]), dim=-1)

        if input_pos is not None:
            if not isinstance(self.kv_cache, KVCache):
                raise TypeError("You need to call `gpt.set_kv_cache()`")
            k, v = self.kv_cache(input_pos, k, v)

        y = self.scaled_dot_product_attention(q, k, v, mask)

        y = y.reshape(B, T, C)  # re-assemble all head outputs side by side

        # output projection
        return self.proj(y)

    def scaled_dot_product_attention(
        self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        scale = 1.0 / math.sqrt(self.config.head_size)
        y = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, attn_mask=mask, dropout_p=0.0, scale=scale, is_causal=mask is None
        )
        return y.transpose(1, 2)

    def build_kv_cache(
        self,
        batch_size: int,
        max_seq_length: int,
        rope_cache_length: Optional[int] = None,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> "KVCache":
        heads = 1 if self.config.n_query_groups == 1 else self.config.n_head
        v_shape = (batch_size, heads, max_seq_length, self.config.head_size)
        if rope_cache_length is None:
            if self.config.rotary_percentage != 1.0:
                raise TypeError("Please pass the `rope_cache_length=gpt.cos.size(-1)` value")
            k_shape = v_shape
        else:
            k_shape = (
                batch_size,
                heads,
                max_seq_length,
                rope_cache_length + self.config.head_size - self.config.rope_n_elem,
            )
        return KVCache(k_shape, v_shape, device=device, dtype=dtype)


class GptNeoxMLP(nn.Module):
    def __init__(self, config: Config) -> None:
        super().__init__()
        self.fc = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)
        self.proj = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias)

        self.config = config

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.fc(x)
        x = torch.nn.functional.gelu(x, approximate=self.config.gelu_approximate)
        return self.proj(x)

class LLaMAMLP(nn.Module):
    def __init__(self, config: Config) -> None:
        super().__init__()
        self.fc_1 = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)
        self.fc_2 = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)
        self.proj = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_fc_1 = self.fc_1(x)
        x_fc_2 = self.fc_2(x)
        x = torch.nn.functional.silu(x_fc_1) * x_fc_2
        return self.proj(x)

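
# `LLaMAMLP` is the SwiGLU-style feed-forward used by LLaMA-family checkpoints:
# proj(silu(fc_1(x)) * fc_2(x)), i.e. a gated unit rather than GptNeoxMLP's single GELU branch.
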

def build_rope_cache(
    seq_len: int, n_elem: int, device: Optional[torch.device] = None, base: int = 10000, condense_ratio: int = 1
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Enhanced Transformer with Rotary Position Embedding.

    Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
    transformers/rope/__init__.py. MIT License:
    https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
    """
    # $\Theta = {\theta_i = 10000^{-\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
    theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, device=device).float() / n_elem))

    # Create position indexes `[0, 1, ..., seq_len - 1]`
    seq_idx = torch.arange(seq_len, device=device) / condense_ratio

    # Calculate the product of position index and $\theta_i$
    idx_theta = torch.outer(seq_idx, theta).repeat(1, 2)

    return torch.cos(idx_theta), torch.sin(idx_theta)

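
# Worked example (arbitrary small sizes): build_rope_cache(seq_len=3, n_elem=4) gives theta = [1.0, 0.01] and
# idx_theta of shape (3, 4) with rows [0 * theta, 1 * theta, 2 * theta] tiled twice, so cos and sin each have
# shape (seq_len, n_elem) = (3, 4).
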

def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    head_size = x.size(-1)
    x1 = x[..., : head_size // 2]  # (B, nh, T, hs/2)
    x2 = x[..., head_size // 2 :]  # (B, nh, T, hs/2)
    rotated = torch.cat((-x2, x1), dim=-1)  # (B, nh, T, hs)
    roped = (x * cos) + (rotated * sin)
    return roped.type_as(x)

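
# `apply_rope` uses the "rotate half" (GPT-NeoX style) formulation: element i is paired with element i + hs/2,
# and because `idx_theta` above was tiled with `.repeat(1, 2)`, both halves see the same angle, giving
# (x1 * cos - x2 * sin, x2 * cos + x1 * sin) per pair.
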

class KVCache(nn.Module):
    def __init__(
        self,
        k_shape: Tuple[int, int, int, int],
        v_shape: Tuple[int, int, int, int],
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> None:
        super().__init__()
        self.register_buffer("k", torch.zeros(k_shape, device=device, dtype=dtype), persistent=False)
        self.register_buffer("v", torch.zeros(v_shape, device=device, dtype=dtype), persistent=False)

    def forward(self, input_pos: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # move the buffer to the activation dtype for when AMP is used
        self.k = self.k.to(k.dtype)
        self.v = self.v.to(v.dtype)
        # update the cache
        k = self.k.index_copy_(2, input_pos, k)
        v = self.v.index_copy_(2, input_pos, v)
        return k, v

    def reset_parameters(self) -> None:
        torch.nn.init.zeros_(self.k)
        torch.nn.init.zeros_(self.v)
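
# A small sketch of the cache update semantics (hypothetical standalone use; in practice the model calls this
# through `CausalSelfAttention.forward`):
#
#     cache = KVCache(k_shape=(1, 1, 8, 4), v_shape=(1, 1, 8, 4))
#     k_new = torch.randn(1, 1, 1, 4)
#     v_new = torch.randn(1, 1, 1, 4)
#     k_all, v_all = cache(torch.tensor([3]), k_new, v_new)   # writes position 3, returns the full (.., 8, ..) buffers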