# src/llama-main/llama/model.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import math
from dataclasses import dataclass
from typing import Optional, Tuple

import fairscale.nn.model_parallel.initialize as fs_init
import torch
import torch.nn.functional as F
from fairscale.nn.model_parallel.layers import (
    ColumnParallelLinear,
    ParallelEmbedding,
    RowParallelLinear,
)
from torch import nn


@dataclass
class ModelArgs:
    dim: int = 4096
    n_layers: int = 32
    n_heads: int = 32
    n_kv_heads: Optional[int] = None
    vocab_size: int = -1  # defined later by tokenizer
    multiple_of: int = 256  # make SwiGLU hidden layer size a multiple of a large power of 2
    ffn_dim_multiplier: Optional[float] = None
    norm_eps: float = 1e-5

    max_batch_size: int = 32
    max_seq_len: int = 2048
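
# Illustrative construction (a sketch, not part of the original file): the defaults
# match Llama-2-7B, while vocab_size is normally filled in from the tokenizer
# (32000 for Llama 2) before the model is built.
#
#   args = ModelArgs(dim=4096, n_layers=32, n_heads=32, vocab_size=32000)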


class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        """
        Initialize the RMSNorm normalization layer.

        Args:
            dim (int): The dimension of the input tensor.
            eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.

        Attributes:
            eps (float): A small value added to the denominator for numerical stability.
            weight (nn.Parameter): Learnable scaling parameter.

        """
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        """
        Apply the RMSNorm normalization to the input tensor.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The normalized tensor.

        """
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        """
        Forward pass through the RMSNorm layer.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The output tensor after applying RMSNorm.

        """
        output = self._norm(x.float()).type_as(x)
        return output * self.weight
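
# Illustrative usage (a sketch, not part of the original file): RMSNorm divides each
# feature vector by its root mean square and then applies the learned per-dimension
# weight, so the output has the same shape as the input.
#
#   norm = RMSNorm(dim=4096, eps=1e-5)
#   y = norm(torch.randn(2, 16, 4096))  # y.shape == (2, 16, 4096)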


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    """
    Precompute the frequency tensor for complex exponentials (cis) with given dimensions.

    This function calculates a frequency tensor with complex exponentials using the given dimension 'dim'
    and the end index 'end'. The 'theta' parameter scales the frequencies.
    The returned tensor contains complex values in complex64 data type.

    Args:
        dim (int): Dimension of the frequency tensor.
        end (int): End index for precomputing frequencies.
        theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.

    Returns:
        torch.Tensor: Precomputed frequency tensor with complex exponentials.

    """
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)  # type: ignore
    freqs = torch.outer(t, freqs).float()  # type: ignore
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return freqs_cis
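
# Illustrative usage (a sketch with the default 7B sizes, dim=4096, n_heads=32):
#
#   freqs_cis = precompute_freqs_cis(4096 // 32, 2048 * 2)
#   # freqs_cis.shape == (4096, 64), dtype torch.complex64: one unit-magnitude
#   # complex rotation per (position, frequency-pair) entry.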


def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    """
    Reshape frequency tensor for broadcasting it with another tensor.

    This function reshapes the frequency tensor so that it broadcasts against the target
    tensor 'x' during element-wise operations: singleton dimensions everywhere except the
    sequence and feature dimensions.

    Args:
        freqs_cis (torch.Tensor): Frequency tensor to be reshaped.
        x (torch.Tensor): Target tensor for broadcasting compatibility.

    Returns:
        torch.Tensor: Reshaped frequency tensor.

    Raises:
        AssertionError: If the frequency tensor doesn't match the expected shape.
        AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions.
    """
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)
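
# Illustrative shapes (assumed for the sketch, not from the original file): for xq_ of
# shape (bsz, seqlen, n_local_heads, head_dim // 2), freqs_cis of shape
# (seqlen, head_dim // 2) is viewed as (1, seqlen, 1, head_dim // 2), so it broadcasts
# over the batch and head dimensions.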


def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply rotary embeddings to input tensors using the given frequency tensor.

    This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided
    frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
    is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
    returned as real tensors.

    Args:
        xq (torch.Tensor): Query tensor to apply rotary embeddings.
        xk (torch.Tensor): Key tensor to apply rotary embeddings.
        freqs_cis (torch.Tensor): Precomputed frequency tensor for complex exponentials.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.

    """
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)
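
# Illustrative usage (a sketch with hypothetical sizes; shapes are unchanged by the
# rotation, and the rotary position comes from which rows of freqs_cis are passed in):
#
#   xq = torch.randn(1, 16, 32, 128)   # (bsz, seqlen, n_heads, head_dim)
#   xk = torch.randn(1, 16, 8, 128)    # (bsz, seqlen, n_kv_heads, head_dim)
#   freqs_cis = precompute_freqs_cis(128, 4096)[:16]
#   xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)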


def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Equivalent to torch.repeat_interleave(x, dim=2, repeats=n_rep)."""
    bs, slen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    return (
        x[:, :, :, None, :]
        .expand(bs, slen, n_kv_heads, n_rep, head_dim)
        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
    )
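
# Illustrative example (hypothetical sizes): with n_kv_heads=8 and n_rep=4, a
# (bs, slen, 8, head_dim) key or value tensor expands to (bs, slen, 32, head_dim),
# so every query head in grouped-query attention has a matching key/value head.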


class Attention(nn.Module):
    """Multi-head attention module."""
    def __init__(self, args: ModelArgs):
        """
        Initialize the Attention module.

        Args:
            args (ModelArgs): Model configuration parameters.

        Attributes:
            n_kv_heads (int): Number of key and value heads.
            n_local_heads (int): Number of local query heads.
            n_local_kv_heads (int): Number of local key and value heads.
            n_rep (int): Number of repetitions for local heads.
            head_dim (int): Dimension size of each attention head.
            wq (ColumnParallelLinear): Linear transformation for queries.
            wk (ColumnParallelLinear): Linear transformation for keys.
            wv (ColumnParallelLinear): Linear transformation for values.
            wo (RowParallelLinear): Linear transformation for output.
            cache_k (torch.Tensor): Cached keys for attention.
            cache_v (torch.Tensor): Cached values for attention.

        """
        super().__init__()
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        model_parallel_size = fs_init.get_model_parallel_world_size()
        self.n_local_heads = args.n_heads // model_parallel_size
        self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = args.dim // args.n_heads

        self.wq = ColumnParallelLinear(
            args.dim,
            args.n_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wk = ColumnParallelLinear(
            args.dim,
            self.n_kv_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wv = ColumnParallelLinear(
            args.dim,
            self.n_kv_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wo = RowParallelLinear(
            args.n_heads * self.head_dim,
            args.dim,
            bias=False,
            input_is_parallel=True,
            init_method=lambda x: x,
        )

        self.cache_k = torch.zeros(
            (
                args.max_batch_size,
                args.max_seq_len,
                self.n_local_kv_heads,
                self.head_dim,
            )
        ).cuda()
        self.cache_v = torch.zeros(
            (
                args.max_batch_size,
                args.max_seq_len,
                self.n_local_kv_heads,
                self.head_dim,
            )
        ).cuda()

    def forward(
        self,
        x: torch.Tensor,
        start_pos: int,
        freqs_cis: torch.Tensor,
        mask: Optional[torch.Tensor],
    ):
        """
        Forward pass of the attention module.

        Args:
            x (torch.Tensor): Input tensor.
            start_pos (int): Starting position for caching.
            freqs_cis (torch.Tensor): Precomputed frequency tensor.
            mask (torch.Tensor, optional): Attention mask tensor.

        Returns:
            torch.Tensor: Output tensor after attention.

        """
        bsz, seqlen, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

        self.cache_k = self.cache_k.to(xq)
        self.cache_v = self.cache_v.to(xq)

        self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
        self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv

        keys = self.cache_k[:bsz, : start_pos + seqlen]
        values = self.cache_v[:bsz, : start_pos + seqlen]

        # repeat k/v heads if n_kv_heads < n_heads
        keys = repeat_kv(keys, self.n_rep)  # (bs, cache_len + seqlen, n_local_heads, head_dim)
        values = repeat_kv(values, self.n_rep)  # (bs, cache_len + seqlen, n_local_heads, head_dim)

        xq = xq.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
        keys = keys.transpose(1, 2)  # (bs, n_local_heads, cache_len + seqlen, head_dim)
        values = values.transpose(1, 2)  # (bs, n_local_heads, cache_len + seqlen, head_dim)
        scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
        if mask is not None:
            scores = scores + mask  # (bs, n_local_heads, seqlen, cache_len + seqlen)
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)
        output = torch.matmul(scores, values)  # (bs, n_local_heads, seqlen, head_dim)
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        return self.wo(output)
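
# Illustrative single-step decode (a sketch with hypothetical sizes, not from the
# original file): with model_parallel_size=1, dim=4096, n_heads=32, n_kv_heads=8,
# bsz=1, seqlen=1 and start_pos=20, the cache slice covers 21 positions, `scores`
# has shape (1, 32, 1, 21) and the returned tensor has shape (1, 1, 4096).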


class FeedForward(nn.Module):
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int,
        ffn_dim_multiplier: Optional[float],
    ):
        """
        Initialize the FeedForward module.

        Args:
            dim (int): Input dimension.
            hidden_dim (int): Hidden dimension of the feedforward layer.
            multiple_of (int): The hidden dimension is rounded up to a multiple of this value.
            ffn_dim_multiplier (float, optional): Custom multiplier for the hidden dimension. Defaults to None.

        Attributes:
            w1 (ColumnParallelLinear): Linear transformation for the first layer.
            w2 (RowParallelLinear): Linear transformation for the second layer.
            w3 (ColumnParallelLinear): Linear transformation for the third layer.

        """
        super().__init__()
        hidden_dim = int(2 * hidden_dim / 3)
        # custom dim factor multiplier
        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        self.w1 = ColumnParallelLinear(
            dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
        )
        self.w2 = RowParallelLinear(
            hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x
        )
        self.w3 = ColumnParallelLinear(
            dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
        )

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))
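
# Illustrative hidden_dim arithmetic (a sketch with the default 7B settings,
# hidden_dim=4*4096, multiple_of=256, ffn_dim_multiplier=None):
#
#   16384 -> int(2 * 16384 / 3) = 10922 -> rounded up to a multiple of 256 = 11008
#
# which is the SwiGLU hidden size used by the Llama-2-7B feedforward layers.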


class TransformerBlock(nn.Module):
    def __init__(self, layer_id: int, args: ModelArgs):
        """
        Initialize a TransformerBlock.

        Args:
            layer_id (int): Identifier for the layer.
            args (ModelArgs): Model configuration parameters.

        Attributes:
            n_heads (int): Number of attention heads.
            dim (int): Dimension size of the model.
            head_dim (int): Dimension size of each attention head.
            attention (Attention): Attention module.
            feed_forward (FeedForward): FeedForward module.
            layer_id (int): Identifier for the layer.
            attention_norm (RMSNorm): RMSNorm applied to the attention input.
            ffn_norm (RMSNorm): RMSNorm applied to the feedforward input.

        """
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.head_dim = args.dim // args.n_heads
        self.attention = Attention(args)
        self.feed_forward = FeedForward(
            dim=args.dim,
            hidden_dim=4 * args.dim,
            multiple_of=args.multiple_of,
            ffn_dim_multiplier=args.ffn_dim_multiplier,
        )
        self.layer_id = layer_id
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

    def forward(
        self,
        x: torch.Tensor,
        start_pos: int,
        freqs_cis: torch.Tensor,
        mask: Optional[torch.Tensor],
    ):
        """
        Perform a forward pass through the TransformerBlock.

        Args:
            x (torch.Tensor): Input tensor.
            start_pos (int): Starting position for attention caching.
            freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies.
            mask (torch.Tensor, optional): Masking tensor for attention. Defaults to None.

        Returns:
            torch.Tensor: Output tensor after applying attention and feedforward layers.

        """
        h = x + self.attention.forward(
            self.attention_norm(x), start_pos, freqs_cis, mask
        )
        out = h + self.feed_forward.forward(self.ffn_norm(h))
        return out
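
# Structure note (not part of the original file): each block is pre-normalized,
#   h   = x + Attention(RMSNorm(x))
#   out = h + FeedForward(RMSNorm(h))
# so normalization is applied to the inputs of each sublayer, not to the residual stream.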


class Transformer(nn.Module):
    def __init__(self, params: ModelArgs):
        """
        Initialize a Transformer model.

        Args:
            params (ModelArgs): Model configuration parameters.

        Attributes:
            params (ModelArgs): Model configuration parameters.
            vocab_size (int): Vocabulary size.
            n_layers (int): Number of layers in the model.
            tok_embeddings (ParallelEmbedding): Token embeddings.
            layers (torch.nn.ModuleList): List of Transformer blocks.
            norm (RMSNorm): Layer normalization for the model output.
            output (ColumnParallelLinear): Linear layer for final output.
            freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies.

        """
        super().__init__()
        self.params = params
        self.vocab_size = params.vocab_size
        self.n_layers = params.n_layers

        self.tok_embeddings = ParallelEmbedding(
            params.vocab_size, params.dim, init_method=lambda x: x
        )

        self.layers = torch.nn.ModuleList()
        for layer_id in range(params.n_layers):
            self.layers.append(TransformerBlock(layer_id, params))

        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
        self.output = ColumnParallelLinear(
            params.dim, params.vocab_size, bias=False, init_method=lambda x: x
        )

        self.freqs_cis = precompute_freqs_cis(
            # Note that self.params.max_seq_len is multiplied by 2 because the token limit
            # for the Llama 2 generation of models is 4096. Using this multiplier instead of
            # hard-coding 4096 keeps the precomputed table flexible with respect to token
            # lengths during training or fine-tuning.
            self.params.dim // self.params.n_heads, self.params.max_seq_len * 2
        )

    @torch.inference_mode()
    def forward(self, tokens: torch.Tensor, start_pos: int):
        """
        Perform a forward pass through the Transformer model.

        Args:
            tokens (torch.Tensor): Input token indices.
            start_pos (int): Starting position for attention caching.

        Returns:
            torch.Tensor: Output logits after applying the Transformer model.

        """
        _bsz, seqlen = tokens.shape
        h = self.tok_embeddings(tokens)
        self.freqs_cis = self.freqs_cis.to(h.device)
        freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]

        mask = None
        if seqlen > 1:
            mask = torch.full(
                (seqlen, seqlen), float("-inf"), device=tokens.device
            )

            mask = torch.triu(mask, diagonal=1)

            # When performing key-value caching, we compute the attention scores
            # only for the new sequence. Thus, the matrix of scores is of size
            # (seqlen, cache_len + seqlen), and the only masked entries are (i, j) for
            # j > cache_len + i, since row i corresponds to token cache_len + i.
            mask = torch.hstack([
                torch.zeros((seqlen, start_pos), device=tokens.device),
                mask
            ]).type_as(h)

        for layer in self.layers:
            h = layer(h, start_pos, freqs_cis, mask)
        h = self.norm(h)
        output = self.output(h).float()
        return output
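
# Illustrative greedy decoding loop (a sketch, not part of the original file; assumes
# the fairscale model-parallel group is initialized, weights are loaded, and `tokens`
# is a (bsz, total_len) LongTensor pre-filled with the prompt up to `prompt_len`):
#
#   model = Transformer(ModelArgs(vocab_size=32000))
#   prev_pos = 0
#   for cur_pos in range(prompt_len, total_len):
#       logits = model(tokens[:, prev_pos:cur_pos], prev_pos)
#       next_token = torch.argmax(logits[:, -1], dim=-1)
#       tokens[:, cur_pos] = next_token
#       prev_pos = cur_pos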