lit_gpt/lora.py
# Derived from https://github.com/microsoft/LoRA
#  ------------------------------------------------------------------------------------------
#  Copyright (c) Microsoft Corporation. All rights reserved.
#  Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
#  ------------------------------------------------------------------------------------------

r"""
    Low-Rank Adaptation (LoRA) scheme for LLMs.

             ┌───────────────────┐
             ┆         h         ┆
             └───────────────────┘
                       ▲
                       |
                       +
                    /     \
    ┌─────────────────┐    ╭───────────────╮     Matrix initialization:
    ┆                 ┆     \      B      /      B = 0
    ┆   pretrained    ┆      \    r*d    /       A = N(0, sigma^2)
    ┆    weights      ┆       ╰─────────╯
    ┆                 ┆       |    r    |        r - rank
    ┆   W ∈ R^(d*d)   ┆       | ◀─────▶ |
    ┆                 ┆       ╭─────────╮
    └─────────────────┘      /     A     \
              ▲             /     d*r     \
               \           ╰───────────────╯
                \                ▲
                 \              /
                  \            /
             ┌───────────────────┐
             ┆         x         ┆
             └───────────────────┘

With LoRA (Low-Rank Adaptation: https://arxiv.org/abs/2106.09685), instead of learning a weight matrix of size d*d,
we can freeze the pretrained weights and learn two matrices of size d*r and r*d that store the weight updates
for the pretrained weights. The number of parameters is reduced drastically (depending on the rank, of course),
yet multiplying the d*r and r*d matrices yields a d*d matrix which we can add to the frozen
pretrained weights and thus fine-tune the model.

The goal of this approach is to move the weight updates into a separate matrix that is decomposed into
two matrices of a lower rank.
"""

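# A quick sense of the parameter savings described above (illustrative numbers, not tied to any
# checkpoint in this repository): with d = 4096 and r = 8, a full d*d update would train
# 4096 * 4096 = 16_777_216 values, while the two LoRA factors together train only
# d*r + r*d = 65_536 values, i.e. roughly 0.4% of the full matrix.
#
#   >>> d, r = 4096, 8
#   >>> d * d, d * r + r * d
#   (16777216, 65536)
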
import math
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Type, Union

import torch
import torch.nn as nn
from torch.nn import functional as F
from typing_extensions import Self

import lit_gpt
from lit_gpt.config import Config as BaseConfig
from lit_gpt.model import GPT as BaseModel
from lit_gpt.model import Block as BaseBlock
from lit_gpt.model import CausalSelfAttention as BaseCausalSelfAttention
from lit_gpt.model import KVCache
from lit_gpt.utils import map_old_state_dict_weights


class LoRALayer(nn.Module):
    def __init__(self, r: int, lora_alpha: int, lora_dropout: float):
        """Store LoRA-specific attributes in a class.

        Args:
            r: rank of the weight update matrices. For LoRA to make sense, the rank should be smaller than the rank of
                the weights of the model. The rank can be as low as 1: https://arxiv.org/pdf/2106.09685.pdf (section 7.2)
            lora_alpha: alpha is needed for scaling updates as alpha/r
                "This scaling helps to reduce the need to retune hyperparameters when we vary r"
                https://arxiv.org/pdf/2106.09685.pdf (section 4.1)
            lora_dropout: dropout that is applied on the input in the LoRA branch (before multiplying by matrix A)
        """
        super().__init__()
        assert r >= 0
        self.r = r
        self.lora_alpha = lora_alpha
        # Optional dropout
        if lora_dropout > 0.0:
            self.lora_dropout = nn.Dropout(p=lora_dropout)
        else:
            self.lora_dropout = lambda x: x
        # Mark the weight as unmerged
        self.merged = False


class LoRALinear(LoRALayer):
    # LoRA implemented in a dense layer
    def __init__(
        self,
        # ↓ this part is for pretrained weights
        in_features: int,
        out_features: int,
        # ↓ the remaining part is for LoRA
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.0,
        **kwargs,
    ):
        """LoRA wrapper around a linear class.

        This class has three weight matrices:
            1. Pretrained weights are stored as `self.linear.weight`
            2. LoRA A matrix as `self.lora_A`
            3. LoRA B matrix as `self.lora_B`
        Only LoRA's A and B matrices are updated, pretrained weights stay frozen.

        Args:
            in_features: number of input features of the pretrained weights
            out_features: number of output features of the pretrained weights
            r: rank of the weight update matrices. For LoRA to make sense, the rank should be smaller than the rank of
                the weights of the model. The rank can be as low as 1: https://arxiv.org/pdf/2106.09685.pdf (section 7.2)
            lora_alpha: alpha is needed for scaling updates as alpha/r
                "This scaling helps to reduce the need to retune hyperparameters when we vary r"
                https://arxiv.org/pdf/2106.09685.pdf (section 4.1)
            lora_dropout: dropout that is applied on the input in the LoRA branch (before multiplying by matrix A)
        """
        super().__init__(r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout)
        self.linear = torch.nn.Linear(in_features, out_features, **kwargs)

        # Actual trainable parameters
        if r > 0:
            self.lora_A = nn.Parameter(torch.zeros((r, in_features)))
            self.lora_B = nn.Parameter(torch.zeros((out_features, r)))
            self.scaling = self.lora_alpha / self.r
            self.reset_parameters()

    def reset_parameters(self) -> None:
        """Reset all the weights, even including pretrained ones."""
        if hasattr(self, "lora_A"):
            # initialize A the same way as the default for nn.Linear and B to zero
            # Wondering why 'a' is equal to math.sqrt(5)?: https://github.com/pytorch/pytorch/issues/15314
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B)

    def merge(self) -> None:
        """Merges the LoRA weights into the full-rank weights (W = W + delta_W)."""
        if self.r > 0 and not self.merged:
            # Merge the weights and mark it
            self.linear.weight.data += (self.lora_B @ self.lora_A) * self.scaling
            self.merged = True

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # if the weights are merged or the rank is zero (LoRA is disabled), this is only a regular nn.Linear forward pass;
        # otherwise, additionally do the forward pass with the LoRA weights and add its output to the output from the pretrained weights
        pretrained = self.linear(x)
        if self.r == 0 or self.merged:
            return pretrained
        lora = (self.lora_dropout(x) @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1)) * self.scaling
        return pretrained + lora


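# Minimal usage sketch for `LoRALinear` (the layer and tensor sizes below are made up, not taken
# from any model in this repository). Before `merge()` the LoRA branch is computed on the fly;
# after `merge()` the same result comes from the updated `self.linear.weight` alone:
#
#   >>> layer = LoRALinear(128, 256, r=8, lora_alpha=16, lora_dropout=0.0)
#   >>> x = torch.randn(2, 4, 128)   # (batch, seq, in_features)
#   >>> y_unmerged = layer(x)        # pretrained path + (x @ A^T @ B^T) * (alpha / r)
#   >>> layer.merge()                # folds (B @ A) * scaling into layer.linear.weight once
#   >>> y_merged = layer(x)          # plain nn.Linear forward from here on
#   >>> torch.allclose(y_unmerged, y_merged, atol=1e-6)
#   True
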
class LoRAQKVLinear(LoRALinear):
    # LoRA implemented in a dense layer
    def __init__(
        self,
        # ↓ this part is for pretrained weights
        in_features: int,
        out_features: int,
        # ↓ the remaining part is for LoRA
        n_head: int,
        n_query_groups: int,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.0,
        enable_lora: Union[bool, Tuple[bool, bool, bool]] = False,
        **kwargs,
    ):
        """LoRA wrapper around the linear class that is used for calculation of q, k and v matrices.

        This class has three weight matrices:
            1. Pretrained weights are stored as `self.linear.weight`
            2. LoRA A matrix as `self.lora_A`
            3. LoRA B matrix as `self.lora_B`
        Only LoRA's A and B matrices are updated, pretrained weights stay frozen.

        Args:
            in_features: number of input features of the pretrained weights
            out_features: number of output features of the pretrained weights
            n_head: number of attention heads
            n_query_groups: number of query groups (see diagram in `lit_gpt/config.py`)
            r: rank of the weight update matrices. For LoRA to make sense, the rank should be smaller than the rank of
                the weights of the model. The rank can be as low as 1: https://arxiv.org/pdf/2106.09685.pdf (section 7.2)
            lora_alpha: alpha is needed for scaling updates as alpha/r
                "This scaling helps to reduce the need to retune hyperparameters when we vary r"
                https://arxiv.org/pdf/2106.09685.pdf (section 4.1)
            lora_dropout: dropout that is applied on the input in the LoRA branch (before multiplying by matrix A)
            enable_lora: this class is used for the attention mechanism, where q, k and v are calculated with a single
                weight matrix. If we don't want to apply LoRA to one of them, we can set the corresponding flag to False.
                For example, if we want to apply LoRA only to `query` and `value` but keep `key` without weight updates,
                we should pass `(True, False, True)`
        """
        super(LoRALinear, self).__init__(r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout)
        self.linear = torch.nn.Linear(in_features, out_features, **kwargs)
        self.n_head = n_head
        self.n_query_groups = n_query_groups
        if isinstance(enable_lora, bool):
            enable_lora = [enable_lora] * 3
        assert len(enable_lora) == 3
        self.enable_lora = enable_lora

        # Actual trainable parameters
        # To better understand initialization let's imagine that we have such parameters:
        # ⚬ in_features: 128 (embeddings_size)
        # ⚬ out_features: 384 (3 * embedding_size)
        # ⚬ r: 2
        # ⚬ enable_lora: [True, False, True]
        if r > 0 and any(enable_lora):
            self.lora_A = nn.Parameter(torch.zeros((r * sum(enable_lora), in_features)))  # (4, 128)
            enable_q, enable_k, enable_v = enable_lora
            self.kv_embd_size = self.linear.in_features // (n_head // n_query_groups)
            # qkv_shapes will be used to split a tensor with weights correctly
            qkv_shapes = (
                self.linear.in_features * enable_q,
                self.kv_embd_size * enable_k,
                self.kv_embd_size * enable_v,
            )
            self.qkv_shapes = [s for s in qkv_shapes if s]
            self.lora_B = nn.Parameter(torch.zeros(sum(self.qkv_shapes), r))  # (256, 2)
            # Notes about shapes above
            # - self.lora_A has shape (4, 128): 4 because the rank is 2 and LoRA is applied only to two matrices;
            # 128 is the input size of x (the embedding size). It is (4, 128) and not (128, 4) because later on the
            # F.linear function transposes the weights automatically. In addition, conv1d requires channels to
            # come before the sequence length
            # - self.lora_B has shape (256, 2): 256 because LoRA is applied only to two matrices, so the output is
            # 128*2; 2 is the rank, i.e. two channels per group for the grouped convolution

            # Scaling:
            # This balances the pretrained model's knowledge and the new task-specific adaptation
            # https://lightning.ai/pages/community/tutorial/lora-llm/
            # So, set alpha to 1.0 to fully add LoRA. If the LoRA seems to have too much effect (i.e., overfitted), set
            # alpha to a lower value. If the LoRA seems to have too little effect, set alpha to a value higher than 1.0.
            # You can tune these values to your needs. This value can be even slightly greater than 1.0!
            # https://github.com/cloneofsimo/lora
            self.scaling = self.lora_alpha / self.r

            # Compute the indices
            # Indices are needed to properly pad the weight updates with zeros. If we want to fine-tune queries and
            # values, but not keys, then the weight update should be:
            #
            # [[ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW,],
            #  [....................................],
            #  [ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW,]]
            #      ↑              ↑            ↑
            # ________________________________________
            # | query         | key       | value    |
            # ----------------------------------------
            self.lora_ind = []
            if enable_q:
                self.lora_ind.extend(range(0, self.linear.in_features))
            if enable_k:
                self.lora_ind.extend(range(self.linear.in_features, self.linear.in_features + self.kv_embd_size))
            if enable_v:
                self.lora_ind.extend(range(self.linear.in_features + self.kv_embd_size, self.linear.out_features))
            self.reset_parameters()

    def zero_pad(self, x: torch.Tensor) -> torch.Tensor:
        """Properly pad the weight updates with zeros.

        If, based on `self.enable_lora`, we want to fine-tune queries and values, but not keys,
        then the weight update should be:

        [[ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW,],
         [....................................],
         [ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW,]]
            ↑              ↑            ↑
        ________________________________________
        | query         | key       | value    |
        ----------------------------------------

        Args:
            x: tensor with the weight update that will be padded with zeros if necessary

        Returns:
            A tensor with weight updates and zeros for deselected q, k or v
        """
        # we need to do zero padding only if LoRA is disabled for one of the QKV matrices
        if all(self.enable_lora):
            return x

        # Let's imagine that:
        # ⚬ input x has shape (64, 64, 256): (batch_size, sequence_length, embeddings_size)
        # ⚬ embeddings_size: 128
        # ⚬ self.linear.out_features: 384 (3 * embeddings_size)
        # ⚬ enable_lora: [True, False, True]
        # Then x has an embeddings_size of 256 (2 * 128, since enable_lora is set only for query and value, not key)
        # while the expected embeddings_size is 384 (self.linear.out_features), which means that we need to pad from
        # 256 to 384 with zeros, but only for the key updates (this is where self.lora_ind comes in handy)
        # Note: the double transpose (at the beginning and at the end) is basically a guard for two-dimensional tensors,
        # for example when we want to merge/unmerge LoRA weights and pretrained weights
        x = x.transpose(0, 1)
        result = x.new_zeros((*x.shape[:-1], self.linear.out_features))  # (64, 64, 384)
        result = result.view(-1, self.linear.out_features)  # (4096, 384)
        result = result.index_copy(
            1, torch.tensor(self.lora_ind, device=result.device), x.reshape(-1, sum(self.qkv_shapes))
        )  # (4096, 384), with the (4096, 256) updates copied into the LoRA-enabled columns
        return result.view((*x.shape[:-1], self.linear.out_features)).transpose(0, 1)  # (64, 64, 384)

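    # Illustration of `self.lora_ind` and `zero_pad` with small, made-up sizes (in_features=8,
    # n_head == n_query_groups so kv_embd_size=8, out_features=24, LoRA on query and value only):
    # the LoRA update carries 16 channels and `zero_pad` scatters them into the 24 output channels,
    # leaving the key block zeroed.
    #
    #   >>> qkv = LoRAQKVLinear(8, 24, n_head=2, n_query_groups=2, r=2, enable_lora=(True, False, True))
    #   >>> qkv.lora_ind == list(range(0, 8)) + list(range(16, 24))
    #   True
    #   >>> padded = qkv.zero_pad(torch.ones(1, 1, 16))   # (1, 1, 24)
    #   >>> padded[0, 0, 8:16].sum(), padded[0, 0].sum()  # key block stays zero
    #   (tensor(0.), tensor(16.))
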
    def conv1d(self, input: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
        """An extension of the `torch.nn.functional.conv1d` function with a logic specific to grouped queries.

        If the number of heads is equal to the number of query groups, grouped queries are disabled
        (see scheme in `lit_gpt/config.py:Config`). In this case the combined QKV matrix consists of equally sized
        query, key and value parts, which means we can utilize the `groups` argument from `conv1d`: with this argument
        the input and weight matrices will be split into equally sized parts and applied separately (like having
        multiple conv layers side by side).

        Otherwise the QKV matrix consists of unequally sized parts and thus we have to split the input and weight
        matrices manually, apply each part of the weight matrix to the corresponding part of the input and concatenate
        the result.

        Args:
            input: input matrix of shape (B, C, T)
            weight: weight matrix of shape (C_output, rank, 1).
                "C_output" is defined as a sum of embedding sizes for each enabled LoRA layer (see init method of the class).

        Returns:
            A tensor with a shape (B, C_output, T)

        """
        if self.n_head == self.n_query_groups:
            return F.conv1d(input, weight, groups=sum(self.enable_lora))  # (B, C_output, T)

        # Notation:
        # ⚬ N: number of enabled LoRA layers (self.enable_lora)
        # ⚬ C_output': embeddings size for each LoRA layer (not equal in size)
        # ⚬ r: rank of all LoRA layers (equal in size)

        input_splitted = input.chunk(sum(self.enable_lora), dim=1)  # N * (B, C // N, T)
        weight_splitted = weight.split(self.qkv_shapes)  # N * (C_output', r, 1)
        return torch.cat(
            [F.conv1d(a, b) for a, b in zip(input_splitted, weight_splitted)], dim=1  # (B, C_output', T)
        )  # (B, C_output, T)

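    # Sketch of the two code paths in `conv1d` above (hypothetical sizes). With equally sized q/k/v
    # parts a single grouped convolution is enough; with unequal parts (MQA/GQA) the input and
    # weight are split and convolved per part. For the equal-sized case the two are interchangeable:
    #
    #   >>> inp = torch.randn(1, 4, 10)      # (B, N * r, T) with N = 2 enabled LoRA parts, r = 2
    #   >>> w = torch.randn(16, 2, 1)        # (C_output, r, 1) with C_output = 8 + 8
    #   >>> grouped = F.conv1d(inp, w, groups=2)
    #   >>> manual = torch.cat(
    #   ...     [F.conv1d(a, b) for a, b in zip(inp.chunk(2, dim=1), w.split([8, 8]))], dim=1
    #   ... )
    #   >>> torch.allclose(grouped, manual, atol=1e-6)
    #   True
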
    def merge(self) -> None:
        """Merges the LoRA weights into the full-rank weights (W = W + delta_W)."""

        # Let's assume that:
        # ⚬ self.linear.weight.data: (384, 128) or (3 * embedding_size, embedding_size)
        # ⚬ self.lora_A.data: (4, 128)
        # ⚬ self.lora_B.data: (256, 2)
        if self.r > 0 and any(self.enable_lora) and not self.merged:
            delta_w = self.conv1d(
                self.lora_A.data.unsqueeze(0),  # (4, 128) -> (1, 4, 128)
                self.lora_B.data.unsqueeze(-1),  # (256, 2) -> (256, 2, 1)
            ).squeeze(0)  # (1, 4, 128) @ (256, 2, 1) -> (1, 256, 128) -> (256, 128)
            # W = W + delta_W (merge)
            self.linear.weight.data += self.zero_pad(delta_w * self.scaling)  # (256, 128) after zero_pad (384, 128)
            self.merged = True

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Do the forward pass.

        If LoRA's weights are merged with the pretrained ones then it's a simple matrix multiplication.
        If not, then multiply the pretrained weights with the input, apply LoRA on the input and do the summation.

        Args:
            x: input tensor of shape (batch_size, context_length, embedding_size)

        Returns:
            Output tensor of shape (batch_size, context_length, 3 * embedding_size)
        """

        # Let's assume that:
        # ⚬ x: (64, 64, 128) or (batch_size, context_length, embedding_size)
        # ⚬ self.linear.weight: (384, 128) or (3 * embedding_size, embedding_size)
        # ⚬ self.lora_A.data: (4, 128)
        # ⚬ self.lora_B.data: (256, 2)

        # if the weights are merged or LoRA is disabled (r <= 0 or all `enable_lora` are False), this is only a regular
        # nn.Linear forward pass; otherwise, additionally do the forward pass with the LoRA weights and add its output
        # to the output from the pretrained weights
        pretrained = self.linear(x)
        if self.r == 0 or not any(self.enable_lora) or self.merged:
            return pretrained
        after_A = F.linear(self.lora_dropout(x), self.lora_A)  # (64, 64, 128) @ (4, 128) -> (64, 64, 4)
        # For F.conv1d:
        # ⚬ input: input tensor of shape (mini-batch, in_channels, iW)
        # ⚬ weight: filters of shape (out_channels, in_channels/groups, kW)
        after_B = self.conv1d(
            after_A.transpose(-2, -1),  # (64, 64, 4) -> (64, 4, 64)
            self.lora_B.unsqueeze(-1),  # (256, 2) -> (256, 2, 1)
        ).transpose(-2, -1)  # (64, 4, 64) @ (256, 2, 1) -> (64, 256, 64) -> (64, 64, 256)
        lora = self.zero_pad(after_B) * self.scaling  # (64, 64, 256) after zero_pad (64, 64, 384)
        return pretrained + lora


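# Shape sketch for `LoRAQKVLinear.forward`, reusing the embedding size 128 and fused QKV width 384
# from the comments above (batch and sequence sizes are arbitrary):
#
#   >>> attn = LoRAQKVLinear(128, 384, n_head=4, n_query_groups=4, r=2, enable_lora=(True, False, True))
#   >>> attn.lora_A.shape, attn.lora_B.shape
#   (torch.Size([4, 128]), torch.Size([256, 2]))
#   >>> attn(torch.randn(2, 4, 128)).shape
#   torch.Size([2, 4, 384])
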
def mark_only_lora_as_trainable(model: nn.Module, bias: str = "none") -> None:
    """Freeze all layers except the LoRA layers and, depending on the `bias` value, unfreeze bias weights.

    Args:
        model: model with LoRA layers
        bias:
            ``"none"``: all bias weights will be frozen,
            ``"lora_only"``: only bias weights of LoRA layers will be unfrozen,
            ``"all"``: all bias weights will be unfrozen.

    Raises:
        NotImplementedError: if `bias` is not in ["none", "lora_only", "all"]
    """
    # freeze all layers except LoRA's
    for n, p in model.named_parameters():
        if "lora_" not in n:
            p.requires_grad = False

    # depending on the `bias` value unfreeze bias weights
    if bias == "none":
        return
    if bias == "all":
        for n, p in model.named_parameters():
            if "bias" in n:
                p.requires_grad = True
    elif bias == "lora_only":
        for m in model.modules():
            if isinstance(m, LoRALayer) and hasattr(m, "bias") and m.bias is not None:
                m.bias.requires_grad = True
    else:
        raise NotImplementedError


def lora_filter(key: str, value: Any) -> bool:
    return "lora_" in key


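# Typical fine-tuning setup sketch (the `model` below is assumed to be the LoRA `GPT` defined
# further down in this file): freeze everything except the LoRA matrices and, when checkpointing,
# keep only the LoRA tensors by filtering state-dict entries on their names.
#
#   >>> mark_only_lora_as_trainable(model)
#   >>> n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)  # only lora_A/lora_B remain
#   >>> lora_state = {k: v for k, v in model.state_dict().items() if lora_filter(k, v)}
#   >>> torch.save(lora_state, "out/lora_only.pth")   # hypothetical output path
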
@dataclass
class Config(BaseConfig):
    """
    Args:
        r: rank of the weight update matrices. For LoRA to make sense, the rank should be smaller than the rank of
            the weights of the model. The rank can be as low as 1: https://arxiv.org/pdf/2106.09685.pdf (section 7.2)
        alpha: alpha is needed for scaling updates as alpha/r
            "This scaling helps to reduce the need to retune hyperparameters when we vary r"
            https://arxiv.org/pdf/2106.09685.pdf (section 4.1)
        dropout: dropout that is applied on the input in the LoRA branch (before multiplying by matrix A)
        to_*: whether to apply LoRA to the specified weights or not
    """

    r: int = 0
    alpha: int = 1
    dropout: float = 0.0
    to_query: bool = False
    to_key: bool = False
    to_value: bool = False
    to_projection: bool = False
    to_mlp: bool = False
    to_head: bool = False

    @property
    def mlp_class(self) -> Type:
        return getattr(lit_gpt.lora, self._mlp_class)


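# Example configuration sketch: `Config` extends the base lit_gpt config, so LoRA fields ride along
# with the usual model fields. The base config name and hyperparameters below are arbitrary choices,
# not recommendations.
#
#   >>> config = Config.from_name(
#   ...     "pythia-70m",                  # any base config name known to lit_gpt
#   ...     r=8, alpha=16, dropout=0.05,
#   ...     to_query=True, to_value=True,  # apply LoRA to the attention query/value projections only
#   ... )
#   >>> config.r, config.to_key
#   (8, False)
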
class GPT(BaseModel):
    def __init__(self, config: Config) -> None:
        nn.Module.__init__(self)
        assert config.padded_vocab_size is not None
        self.config = config

        self.lm_head = LoRALinear(
            config.n_embd,
            config.padded_vocab_size,
            bias=config.lm_head_bias,
            r=(config.r if config.to_head else 0),
            lora_alpha=config.alpha,
            lora_dropout=config.dropout,
        )
        self.transformer = nn.ModuleDict(
            dict(
                wte=nn.Embedding(config.padded_vocab_size, config.n_embd),
                h=nn.ModuleList(Block(config) for _ in range(config.n_layer)),
                ln_f=config.norm_class(config.n_embd, eps=config.norm_eps),
            )
        )
        self.max_seq_length = self.config.block_size
        self.mask_cache: Optional[torch.Tensor] = None

    def forward(
        self, idx: torch.Tensor, input_pos: Optional[torch.Tensor] = None, lm_head_chunk_size: int = 0
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        T = idx.size(1)
        if self.max_seq_length < T:
            raise ValueError(f"Cannot forward sequence of length {T}, max seq length is only {self.max_seq_length}.")

        if input_pos is not None:  # use the kv cache
            cos = self.cos.index_select(0, input_pos)
            sin = self.sin.index_select(0, input_pos)
            if self.mask_cache is None:
                raise TypeError("You need to call `gpt.set_kv_cache()`")
            mask = self.mask_cache.index_select(2, input_pos)
        else:
            cos = self.cos[:T]
            sin = self.sin[:T]
            mask = None

        x = self.transformer.wte(idx)  # token embeddings of shape (b, t, n_embd)
        for block in self.transformer.h:
            x = block(x, cos, sin, mask, input_pos)
        x = self.transformer.ln_f(x)
        if lm_head_chunk_size > 0:
            # chunk the lm head logits to reduce the peak memory used by autograd
            return [self.lm_head(x_i) for x_i in x.split(lm_head_chunk_size, dim=1)]
        return self.lm_head(x)  # (B, T, vocab_size)

    @classmethod
    def from_name(cls, name: str, **kwargs: Any) -> Self:
        return cls(Config.from_name(name, **kwargs))

    def _init_weights(self, module: nn.Module) -> None:
        """Meant to be used with `gpt.apply(gpt._init_weights)`. Unused method left for completeness."""
        super()._init_weights(module)
        if isinstance(module, LoRALinear):
            module.reset_parameters()

    def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
        """For compatibility with base checkpoints."""
        mapping = {"lm_head.weight": "lm_head.linear.weight", "lm_head.bias": "lm_head.linear.bias"}
        state_dict = map_old_state_dict_weights(state_dict, mapping, prefix)
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)


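# Sketch of pairing the LoRA `GPT` with a base checkpoint (the config name and path are placeholders).
# The `_load_from_state_dict` hooks above remap base-checkpoint keys such as `lm_head.weight` onto the
# wrapped `lm_head.linear.weight`, so base weights load with `strict=False` while the LoRA matrices
# keep their fresh initialization.
#
#   >>> model = GPT(Config.from_name("pythia-70m", r=8, alpha=16, to_query=True, to_value=True))
#   >>> state_dict = torch.load("checkpoints/lit_model.pth", map_location="cpu")
#   >>> model.load_state_dict(state_dict, strict=False)
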
class Block(BaseBlock):
    def __init__(self, config: Config) -> None:
        nn.Module.__init__(self)
        self.norm_1 = config.norm_class(config.n_embd, eps=config.norm_eps)
        self.attn = CausalSelfAttention(config)
        if not config.shared_attention_norm:
            self.norm_2 = config.norm_class(config.n_embd, eps=config.norm_eps)
        self.mlp = config.mlp_class(config)

        self.config = config


class CausalSelfAttention(BaseCausalSelfAttention):
    def __init__(self, config: Config) -> None:
        # Skip the parent class __init__ altogether and replace it to avoid
        # useless allocations
        nn.Module.__init__(self)
        shape = (config.n_head + 2 * config.n_query_groups) * config.head_size
        # key, query, value projections for all heads, but in a batch
        self.attn = LoRAQKVLinear(
            in_features=config.n_embd,
            out_features=shape,
            r=config.r,
            lora_alpha=config.alpha,
            lora_dropout=config.dropout,
            enable_lora=(config.to_query, config.to_key, config.to_value),
            bias=config.bias,
            # for MQA/GQA support
            n_head=config.n_head,
            n_query_groups=config.n_query_groups,
        )
        # output projection
        self.proj = LoRALinear(
            config.n_embd,
            config.n_embd,
            bias=config.bias,
            r=(config.r if config.to_projection else 0),
            lora_alpha=config.alpha,
            lora_dropout=config.dropout,
        )
        # disabled by default
        self.kv_cache: Optional[KVCache] = None

        self.config = config

    def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
        """For compatibility with base checkpoints."""
        mapping = {
            "attn.weight": "attn.linear.weight",
            "attn.bias": "attn.linear.bias",
            "proj.weight": "proj.linear.weight",
            "proj.bias": "proj.linear.bias",
        }
        state_dict = map_old_state_dict_weights(state_dict, mapping, prefix)
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)


class GptNeoxMLP(lit_gpt.model.GptNeoxMLP):
    def __init__(self, config: Config) -> None:
        nn.Module.__init__(self)
        self.fc = LoRALinear(
            config.n_embd,
            config.intermediate_size,
            bias=config.bias,
            r=(config.r if config.to_mlp else 0),
            lora_alpha=config.alpha,
            lora_dropout=config.dropout,
        )
        self.proj = LoRALinear(
            config.intermediate_size,
            config.n_embd,
            bias=config.bias,
            r=(config.r if config.to_mlp else 0),
            lora_alpha=config.alpha,
            lora_dropout=config.dropout,
        )

        self.config = config

    def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
        """For compatibility with base checkpoints."""
        mapping = {
            "fc.weight": "fc.linear.weight",
            "fc.bias": "fc.linear.bias",
            "proj.weight": "proj.linear.weight",
            "proj.bias": "proj.linear.bias",
        }
        state_dict = map_old_state_dict_weights(state_dict, mapping, prefix)
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)


class LLaMAMLP(lit_gpt.model.LLaMAMLP):
    def __init__(self, config: Config) -> None:
        nn.Module.__init__(self)
        self.fc_1 = LoRALinear(
            config.n_embd,
            config.intermediate_size,
            bias=config.bias,
            r=(config.r if config.to_mlp else 0),
            lora_alpha=config.alpha,
            lora_dropout=config.dropout,
        )
        self.fc_2 = LoRALinear(
            config.n_embd,
            config.intermediate_size,
            bias=config.bias,
            r=(config.r if config.to_mlp else 0),
            lora_alpha=config.alpha,
            lora_dropout=config.dropout,
        )
        self.proj = LoRALinear(
            config.intermediate_size,
            config.n_embd,
            bias=config.bias,
            r=(config.r if config.to_mlp else 0),
            lora_alpha=config.alpha,
            lora_dropout=config.dropout,
        )

    def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
        """For compatibility with base checkpoints."""
        mapping = {
            "fc_1.weight": "fc_1.linear.weight",
            "fc_1.bias": "fc_1.linear.bias",
            "fc_2.weight": "fc_2.linear.weight",
            "fc_2.bias": "fc_2.linear.bias",
            "proj.weight": "proj.linear.weight",
            "proj.bias": "proj.linear.bias",
        }
        state_dict = map_old_state_dict_weights(state_dict, mapping, prefix)
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)


def merge_lora_weights(model: GPT) -> None:
    """Merge LoRA weights into the full-rank weights to speed up inference."""
    for module in model.modules():
        if isinstance(module, LoRALinear):
            module.merge()
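
# Inference-time sketch (checkpoint paths are placeholders): after loading the base weights and the
# fine-tuned LoRA weights, folding the LoRA matrices into the linear layers removes the extra LoRA
# branch from every subsequent forward pass.
#
#   >>> model = GPT(Config.from_name("pythia-70m", r=8, alpha=16, to_query=True, to_value=True))
#   >>> model.load_state_dict(torch.load("checkpoints/lit_model.pth"), strict=False)  # base weights
#   >>> model.load_state_dict(torch.load("out/lora_finetuned.pth"), strict=False)     # LoRA weights only
#   >>> merge_lora_weights(model)
#   >>> model.eval()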