lit_gpt/lora.py

# Derived from https://github.com/microsoft/LoRA
#  ------------------------------------------------------------------------------------------
#  Copyright (c) Microsoft Corporation. All rights reserved.
#  Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
#  ------------------------------------------------------------------------------------------

r"""
Low-Rank Adaptation (LoRA) scheme for LLMs.

     ┌───────────────────┐
     ┆         h         ┆
     └───────────────────┘
               ▲
               |
               +
            /     \
    ┌─────────────────┐    ╭───────────────╮     Matrix initialization:
    ┆                 ┆     \      B      /      B = 0
    ┆   pretrained    ┆      \    r*d    /       A = N(0, sigma^2)
    ┆    weights      ┆       ╰─────────╯
    ┆                 ┆       |    r    |        r - rank
    ┆   W e R^(d*d)   ┆       | ◀─────▶ |
    ┆                 ┆       ╭─────────╮
    └─────────────────┘      /     A     \
              ▲             /     d*r     \
               \           ╰───────────────╯
                \                ▲
                 \              /
                  \            /
    ┌───────────────────┐
    ┆         x         ┆
    └───────────────────┘

With LoRA (Low-Rank Adaptation: https://arxiv.org/abs/2106.09685), instead of learning a weight update of size d*d,
we freeze the pretrained weights and learn two matrices of size d*r and r*d that store the weight updates for the
pretrained weights. The number of trainable parameters is drastically reduced (depending on the rank, of course),
yet multiplying the d*r and r*d matrices yields a d*d matrix that can be added to the frozen pretrained weights,
which is how the model is fine-tuned.

The goal of this approach is to move the weight updates into a separate matrix which is decomposed into
two matrices of a lower rank.
"""

import math
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Type, Union

import torch
import torch.nn as nn
from torch.nn import functional as F
from typing_extensions import Self

import lit_gpt
from lit_gpt.config import Config as BaseConfig
from lit_gpt.model import GPT as BaseModel
from lit_gpt.model import Block as BaseBlock
from lit_gpt.model import CausalSelfAttention as BaseCausalSelfAttention
from lit_gpt.model import KVCache
from lit_gpt.utils import map_old_state_dict_weights


class LoRALayer(nn.Module):
    def __init__(self, r: int, lora_alpha: int, lora_dropout: float):
        """Store LoRA-specific attributes in a class.

        Args:
            r: rank of the weight update matrices. For LoRA to be useful, the rank should be smaller than the rank of
                the model's weight matrices. The rank can be as low as 1: https://arxiv.org/pdf/2106.09685.pdf (section 7.2)
            lora_alpha: alpha is needed for scaling updates as alpha/r
                "This scaling helps to reduce the need to retune hyperparameters when we vary r"
                https://arxiv.org/pdf/2106.09685.pdf (section 4.1)
            lora_dropout: dropout that is applied on the input in the LoRA branch (before multiplying by matrix A)
        """
        super().__init__()
        assert r >= 0
        self.r = r
        self.lora_alpha = lora_alpha
        # Optional dropout
        if lora_dropout > 0.0:
            self.lora_dropout = nn.Dropout(p=lora_dropout)
        else:
            self.lora_dropout = lambda x: x
        # Mark the weight as unmerged
        self.merged = False


class LoRALinear(LoRALayer):
    # LoRA implemented in a dense layer
    def __init__(
        self,
        # ↓ this part is for pretrained weights
        in_features: int,
        out_features: int,
        # ↓ the remaining part is for LoRA
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.0,
        **kwargs,
    ):
        """LoRA wrapper around a linear layer.

        This class has three weight matrices:
            1. Pretrained weights are stored as `self.linear.weight`
            2. LoRA A matrix as `self.lora_A`
            3. LoRA B matrix as `self.lora_B`
        Only LoRA's A and B matrices are updated, pretrained weights stay frozen.

        Args:
            in_features: number of input features of the pretrained weights
            out_features: number of output features of the pretrained weights
            r: rank of the weight update matrices. For LoRA to be useful, the rank should be smaller than the rank of
                the model's weight matrices. The rank can be as low as 1: https://arxiv.org/pdf/2106.09685.pdf (section 7.2)
            lora_alpha: alpha is needed for scaling updates as alpha/r
                "This scaling helps to reduce the need to retune hyperparameters when we vary r"
                https://arxiv.org/pdf/2106.09685.pdf (section 4.1)
            lora_dropout: dropout that is applied on the input in the LoRA branch (before multiplying by matrix A)
        """
        super().__init__(r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout)
        self.linear = torch.nn.Linear(in_features, out_features, **kwargs)

        # Actual trainable parameters
        if r > 0:
            self.lora_A = nn.Parameter(torch.zeros((r, in_features)))
            self.lora_B = nn.Parameter(torch.zeros((out_features, r)))
            self.scaling = self.lora_alpha / self.r
            self.reset_parameters()

    def reset_parameters(self) -> None:
        """Reset all the weights, even including pretrained ones."""
        if hasattr(self, "lora_A"):
            # initialize A the same way as the default for nn.Linear, and B to zero
            # Wondering why 'a' is equal to math.sqrt(5)?: https://github.com/pytorch/pytorch/issues/15314
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B)

    def merge(self) -> None:
        """Merges the LoRA weights into the full-rank weights (W = W + delta_W)."""
        if self.r > 0 and not self.merged:
            # Merge the weights and mark it
            self.linear.weight.data += (self.lora_B @ self.lora_A) * self.scaling
            self.merged = True

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # If the weights are merged or the rank is zero (LoRA is disabled), this is only a regular nn.Linear forward pass;
        # otherwise, also run the forward pass with the LoRA weights and add its output to the output of the pretrained weights.
        pretrained = self.linear(x)
        if self.r == 0 or self.merged:
            return pretrained
        lora = (self.lora_dropout(x) @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1)) * self.scaling
        return pretrained + lora
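
# Example of how `LoRALinear` is typically used (an illustrative sketch, kept as a comment so that importing
# this module has no side effects; the shapes and hyperparameters below are made up):
#
#   layer = LoRALinear(128, 256, r=8, lora_alpha=16, lora_dropout=0.0)
#   x = torch.randn(2, 10, 128)
#   y = layer(x)   # (2, 10, 256): frozen nn.Linear output plus the scaled (B @ A) branch
#   # only `lora_A` (8*128) and `lora_B` (256*8) would be trained; see `mark_only_lora_as_trainable` below
#   layer.merge()  # folds (B @ A) * alpha/r into `layer.linear.weight` for inference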


class LoRAQKVLinear(LoRALinear):
    # LoRA implemented in a dense layer
    def __init__(
        self,
        # ↓ this part is for pretrained weights
        in_features: int,
        out_features: int,
        # ↓ the remaining part is for LoRA
        n_head: int,
        n_query_groups: int,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.0,
        enable_lora: Union[bool, Tuple[bool, bool, bool]] = False,
        **kwargs,
    ):
        """LoRA wrapper around the linear layer that computes the q, k and v matrices.

        This class has three weight matrices:
            1. Pretrained weights are stored as `self.linear.weight`
            2. LoRA A matrix as `self.lora_A`
            3. LoRA B matrix as `self.lora_B`
        Only LoRA's A and B matrices are updated, pretrained weights stay frozen.

        Args:
            in_features: number of input features of the pretrained weights
            out_features: number of output features of the pretrained weights
            n_head: number of attention heads
            n_query_groups: number of query groups (see diagram in `lit_gpt/config.py`)
            r: rank of the weight update matrices. For LoRA to be useful, the rank should be smaller than the rank of
                the model's weight matrices. The rank can be as low as 1: https://arxiv.org/pdf/2106.09685.pdf (section 7.2)
            lora_alpha: alpha is needed for scaling updates as alpha/r
                "This scaling helps to reduce the need to retune hyperparameters when we vary r"
                https://arxiv.org/pdf/2106.09685.pdf (section 4.1)
            lora_dropout: dropout that is applied on the input in the LoRA branch (before multiplying by matrix A)
            enable_lora: this class is for an attention mechanism where q, k and v are calculated with a single weight
                matrix; `enable_lora` selects which of the three receive LoRA updates. For example, to apply LoRA only
                to `query` and `value` but keep `key` without weight updates, pass `(True, False, True)`.
        """
        super(LoRALinear, self).__init__(r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout)
        self.linear = torch.nn.Linear(in_features, out_features, **kwargs)
        self.n_head = n_head
        self.n_query_groups = n_query_groups
        if isinstance(enable_lora, bool):
            enable_lora = [enable_lora] * 3
        assert len(enable_lora) == 3
        self.enable_lora = enable_lora

        # Actual trainable parameters
        # To better understand initialization, let's imagine that we have such parameters:
        # ⚬ in_features: 128 (embeddings_size)
        # ⚬ out_features: 384 (3 * embedding_size)
        # ⚬ r: 2
        # ⚬ enable_lora: [True, False, True]
        if r > 0 and any(enable_lora):
            self.lora_A = nn.Parameter(torch.zeros((r * sum(enable_lora), in_features)))  # (4, 128)
            enable_q, enable_k, enable_v = enable_lora
            self.kv_embd_size = self.linear.in_features // (n_head // n_query_groups)
            # qkv_shapes will be used to split a tensor with weights correctly
            qkv_shapes = (
                self.linear.in_features * enable_q,
                self.kv_embd_size * enable_k,
                self.kv_embd_size * enable_v,
            )
            self.qkv_shapes = [s for s in qkv_shapes if s]
            self.lora_B = nn.Parameter(torch.zeros(sum(self.qkv_shapes), r))  # (256, 2)
            # Notes about the shapes above:
            # - self.lora_A has shape (4, 128): 4 because the rank is 2 and LoRA is applied to only two matrices;
            #   128 is the input size of x (embedding size). It is (4, 128) and not (128, 4) because later on the
            #   F.linear function transposes the weights automatically. In addition, conv1d requires channels to
            #   come before the sequence length.
            # - self.lora_B has shape (256, 2): 256 because LoRA is applied to only two matrices, so the output is
            #   128*2; 2 (the rank) becomes the number of input channels per group in the grouped convolution.

            # Scaling:
            # This balances the pretrained model's knowledge and the new task-specific adaptation
            # https://lightning.ai/pages/community/tutorial/lora-llm/
            # So, set alpha to 1.0 to fully add LoRA. If the LoRA seems to have too much effect (i.e., overfitted), set
            # alpha to a lower value. If the LoRA seems to have too little effect, set alpha to a value higher than 1.0.
            # You can tune these values to your needs. This value can be even slightly greater than 1.0!
            # https://github.com/cloneofsimo/lora
            self.scaling = self.lora_alpha / self.r

            # Compute the indices
            # Indices are needed to properly pad the weight updates with zeros. If we want to fine-tune queries and
            # values, but not keys, then the weight update should be:
            #
            # [[ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW,],
            #  [....................................],
            #  [ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW,]]
            #      ↑              ↑            ↑
            # ________________________________________
            # | query         | key       | value    |
            # ----------------------------------------
            self.lora_ind = []
            if enable_q:
                self.lora_ind.extend(range(0, self.linear.in_features))
            if enable_k:
                self.lora_ind.extend(range(self.linear.in_features, self.linear.in_features + self.kv_embd_size))
            if enable_v:
                self.lora_ind.extend(range(self.linear.in_features + self.kv_embd_size, self.linear.out_features))
            self.reset_parameters()

    def zero_pad(self, x: torch.Tensor) -> torch.Tensor:
        """Properly pad the weight updates with zeros.

        If, based on `self.enable_lora`, we want to fine-tune queries and values, but not keys,
        then the weight update should be:

        [[ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW,],
         [....................................],
         [ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW,]]
            ↑              ↑            ↑
        ________________________________________
        | query         | key       | value    |
        ----------------------------------------

        Args:
            x: tensor with weight updates that will be padded with zeros if necessary

        Returns:
            A tensor with weight updates and zeros for the deselected q, k or v
        """
        # we only need to do zero padding if LoRA is disabled for one of the QKV matrices
        if all(self.enable_lora):
            return x

        # Let's imagine that:
        # ⚬ input x has shape (64, 64, 256): (batch_size, sequence_length, embeddings_size)
        # ⚬ embeddings_size: 128
        # ⚬ self.linear.out_features: 384 (3 * embeddings_size)
        # ⚬ enable_lora: [True, False, True]
        # Then x has an embeddings_size of 256 (2 * 128, since LoRA is enabled only for query and value, not key) and the
        # expected embeddings_size is 384 (self.linear.out_features), so we need to pad from 256 to 384 with zeros, but
        # only for the key updates (this is where self.lora_ind comes in handy).
        # Note: the double transpose (at the beginning and at the end) is basically a guard for two-dimensional tensors,
        # for example when we want to merge/unmerge LoRA weights and pretrained weights
        x = x.transpose(0, 1)
        result = x.new_zeros((*x.shape[:-1], self.linear.out_features))  # (64, 64, 384)
        result = result.view(-1, self.linear.out_features)  # (4096, 384)
        result = result.index_copy(
            1, torch.tensor(self.lora_ind, device=result.device), x.reshape(-1, sum(self.qkv_shapes))
        )  # (4096, 384)
        return result.view((*x.shape[:-1], self.linear.out_features)).transpose(0, 1)  # (64, 64, 384)
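
    # A toy illustration of the index_copy above (kept as a comment; the numbers are made up):
    # with enable_lora = (True, False, True), in_features = 2 and kv_embd_size = 2, lora_ind is [0, 1, 4, 5],
    # so a row of updates [q0, q1, v0, v1] is scattered into [q0, q1, 0, 0, v0, v1], leaving zeros where the
    # (not adapted) key projection lives:
    #
    #   upd = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
    #   ind = torch.tensor([0, 1, 4, 5])
    #   torch.zeros(1, 6).index_copy(1, ind, upd)  # -> [[1., 2., 0., 0., 3., 4.]]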

    def conv1d(self, input: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
        """An extension of the `torch.nn.functional.conv1d` function with a logic specific to grouped queries.

        If the number of heads is equal to the number of query groups, then grouped queries are disabled
        (see the scheme in `lit_gpt/config.py:Config`). In this case the combined QKV matrix consists of equally sized
        query, key and value parts, which means we can use the `groups` argument of `conv1d`: with this argument the
        input and weight matrices are split into equally sized parts and each part is applied separately (like having
        multiple conv layers side by side).

        Otherwise, the QKV matrix consists of unequally sized parts, so we have to split the input and weight matrices
        manually, apply each part of the weight matrix to the corresponding part of the input, and concatenate the results.

        Args:
            input: input matrix of shape (B, C, T)
            weight: weight matrix of shape (C_output, rank, 1).
                "C_output" is defined as a sum of embedding sizes for each enabled LoRA layer (see the init method of the class).

        Returns:
            A tensor with a shape (B, C_output, T)
        """
        if self.n_head == self.n_query_groups:
            return F.conv1d(input, weight, groups=sum(self.enable_lora))  # (B, C_output, T)

        # Notation:
        # ⚬ N: number of enabled LoRA layers (self.enable_lora)
        # ⚬ C_output': embeddings size for each LoRA layer (not equal in size)
        # ⚬ r: rank of all LoRA layers (equal in size)

        input_splitted = input.chunk(sum(self.enable_lora), dim=1)  # N * (B, C // N, T)
        weight_splitted = weight.split(self.qkv_shapes)  # N * (C_output', r, 1)
        return torch.cat(
            [F.conv1d(a, b) for a, b in zip(input_splitted, weight_splitted)], dim=1  # (B, C_output', T)
        )  # (B, C_output, T)
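
    # The `groups` shortcut and the manual split-and-concatenate path above compute the same result; a quick
    # check one could run (illustrative only, the shapes are made up):
    #
    #   x = torch.randn(2, 4, 7)  # (B, C_in, T)
    #   w = torch.randn(6, 2, 1)  # (C_out, C_in / groups, kW)
    #   manual = torch.cat([F.conv1d(a, b) for a, b in zip(x.chunk(2, dim=1), w.chunk(2, dim=0))], dim=1)
    #   torch.testing.assert_close(F.conv1d(x, w, groups=2), manual)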

    def merge(self) -> None:
        """Merges the LoRA weights into the full-rank weights (W = W + delta_W)."""

        # Let's assume that:
        # ⚬ self.linear.weight.data: (384, 128) or (3 * embedding_size, embedding_size)
        # ⚬ self.lora_A.data: (4, 128)
        # ⚬ self.lora_B.data: (256, 2)
        if self.r > 0 and any(self.enable_lora) and not self.merged:
            delta_w = self.conv1d(
                self.lora_A.data.unsqueeze(0),  # (4, 128) -> (1, 4, 128)
                self.lora_B.data.unsqueeze(-1),  # (256, 2) -> (256, 2, 1)
            ).squeeze(
                0
            )  # (1, 4, 128) @ (256, 2, 1) -> (1, 256, 128) -> (256, 128)
            # W = W + delta_W (merge)
            self.linear.weight.data += self.zero_pad(delta_w * self.scaling)  # (256, 128) after zero_pad (384, 128)
            self.merged = True

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Do the forward pass.

        If LoRA's weights are merged with the pretrained ones, then it's a simple matrix multiplication.
        If not, then multiply the pretrained weights by the input, apply LoRA on the input, and sum the results.

        Args:
            x: input tensor of shape (batch_size, context_length, embedding_size)

        Returns:
            Output tensor of shape (batch_size, context_length, 3 * embedding_size)
        """

        # Let's assume that:
        # ⚬ x: (64, 64, 128) or (batch_size, context_length, embedding_size)
        # ⚬ self.linear.weight: (384, 128) or (3 * embedding_size, embedding_size)
        # ⚬ self.lora_A.data: (4, 128)
        # ⚬ self.lora_B.data: (256, 2)

        # If the weights are merged or LoRA is disabled (r <= 0 or all `enable_lora` are False), this is only a regular
        # nn.Linear forward pass; otherwise, also run the forward pass with the LoRA weights and add its output to the
        # output of the pretrained weights.
        pretrained = self.linear(x)
        if self.r == 0 or not any(self.enable_lora) or self.merged:
            return pretrained
        after_A = F.linear(self.lora_dropout(x), self.lora_A)  # (64, 64, 128) @ (4, 128) -> (64, 64, 4)
        # For F.conv1d:
        # ⚬ input: input tensor of shape (mini-batch, in_channels, iW)
        # ⚬ weight: filters of shape (out_channels, in_channels/groups, kW)
        after_B = self.conv1d(
            after_A.transpose(-2, -1),  # (64, 64, 4) -> (64, 4, 64)
            self.lora_B.unsqueeze(-1),  # (256, 2) -> (256, 2, 1)
        ).transpose(
            -2, -1
        )  # (64, 4, 64) @ (256, 2, 1) -> (64, 256, 64) -> (64, 64, 256)
        lora = self.zero_pad(after_B) * self.scaling  # (64, 64, 256) after zero_pad (64, 64, 384)
        return pretrained + lora
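
# Example of how `LoRAQKVLinear` behaves (an illustrative sketch, kept as a comment; the shapes and
# hyperparameters are made up and assume regular multi-head attention, i.e. n_head == n_query_groups):
#
#   qkv = LoRAQKVLinear(
#       in_features=128, out_features=384, n_head=4, n_query_groups=4,
#       r=2, lora_alpha=4, enable_lora=(True, False, True),
#   )
#   x = torch.randn(8, 16, 128)
#   qkv(x).shape  # (8, 16, 384); the key slice of the output receives no LoRA update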


def mark_only_lora_as_trainable(model: nn.Module, bias: str = "none") -> None:
    """Freeze all modules except LoRA's and, depending on the `bias` value, unfreeze bias weights.

    Args:
        model: model with LoRA layers
        bias:
            ``"none"``: all bias weights will be frozen,
            ``"lora_only"``: only bias weights of LoRA layers will be unfrozen,
            ``"all"``: all bias weights will be unfrozen.

    Raises:
        NotImplementedError: if `bias` is not in ["none", "lora_only", "all"]
    """
    # freeze all layers except LoRA's
    for n, p in model.named_parameters():
        if "lora_" not in n:
            p.requires_grad = False

    # depending on the `bias` value unfreeze bias weights
    if bias == "none":
        return
    if bias == "all":
        for n, p in model.named_parameters():
            if "bias" in n:
                p.requires_grad = True
    elif bias == "lora_only":
        for m in model.modules():
            if isinstance(m, LoRALayer) and hasattr(m, "bias") and m.bias is not None:
                m.bias.requires_grad = True
    else:
        raise NotImplementedError
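
# Typical use (an illustrative sketch, kept as a comment; `model` stands for any module built from the LoRA
# classes above, and the learning rate is just a placeholder value):
#
#   mark_only_lora_as_trainable(model)  # freeze everything except lora_A / lora_B
#   trainable = [p for p in model.parameters() if p.requires_grad]
#   optimizer = torch.optim.AdamW(trainable, lr=3e-4)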


def lora_filter(key: str, value: Any) -> bool:
    return "lora_" in key


@dataclass
class Config(BaseConfig):
    """
    Args:
        r: rank of the weight update matrices. For LoRA to be useful, the rank should be smaller than the rank of
            the model's weight matrices. The rank can be as low as 1: https://arxiv.org/pdf/2106.09685.pdf (section 7.2)
        alpha: alpha is needed for scaling updates as alpha/r
            "This scaling helps to reduce the need to retune hyperparameters when we vary r"
            https://arxiv.org/pdf/2106.09685.pdf (section 4.1)
        dropout: dropout that is applied on the input in the LoRA branch (before multiplying by matrix A)
        to_*: either apply LoRA to the specified weights or not
    """

    r: int = 0
    alpha: int = 1
    dropout: float = 0.0
    to_query: bool = False
    to_key: bool = False
    to_value: bool = False
    to_projection: bool = False
    to_mlp: bool = False
    to_head: bool = False

    @property
    def mlp_class(self) -> Type:
        return getattr(lit_gpt.lora, self._mlp_class)


class GPT(BaseModel):
    def __init__(self, config: Config) -> None:
        nn.Module.__init__(self)
        assert config.padded_vocab_size is not None
        self.config = config

        self.lm_head = LoRALinear(
            config.n_embd,
            config.padded_vocab_size,
            bias=config.lm_head_bias,
            r=(config.r if config.to_head else 0),
            lora_alpha=config.alpha,
            lora_dropout=config.dropout,
        )
        self.transformer = nn.ModuleDict(
            dict(
                wte=nn.Embedding(config.padded_vocab_size, config.n_embd),
                h=nn.ModuleList(Block(config) for _ in range(config.n_layer)),
                ln_f=config.norm_class(config.n_embd, eps=config.norm_eps),
            )
        )
        self.max_seq_length = self.config.block_size
        self.mask_cache: Optional[torch.Tensor] = None

    def forward(
        self, idx: torch.Tensor, input_pos: Optional[torch.Tensor] = None, lm_head_chunk_size: int = 0
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        T = idx.size(1)
        if self.max_seq_length < T:
            raise ValueError(f"Cannot forward sequence of length {T}, max seq length is only {self.max_seq_length}.")

        if input_pos is not None:  # use the kv cache
            cos = self.cos.index_select(0, input_pos)
            sin = self.sin.index_select(0, input_pos)
            if self.mask_cache is None:
                raise TypeError("You need to call `gpt.set_kv_cache()`")
            mask = self.mask_cache.index_select(2, input_pos)
        else:
            cos = self.cos[:T]
            sin = self.sin[:T]
            mask = None

        x = self.transformer.wte(idx)  # token embeddings of shape (b, t, n_embd)
        for block in self.transformer.h:
            x = block(x, cos, sin, mask, input_pos)
        x = self.transformer.ln_f(x)
        if lm_head_chunk_size > 0:
            # chunk the lm head logits to reduce the peak memory used by autograd
            return [self.lm_head(x_i) for x_i in x.split(lm_head_chunk_size, dim=1)]
        return self.lm_head(x)  # (B, T, vocab_size)

    @classmethod
    def from_name(cls, name: str, **kwargs: Any) -> Self:
        return cls(Config.from_name(name, **kwargs))

    def _init_weights(self, module: nn.Module) -> None:
        """Meant to be used with `gpt.apply(gpt._init_weights)`. Unused method left for completeness."""
        super()._init_weights(module)
        if isinstance(module, LoRALinear):
            module.reset_parameters()

    def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
        """For compatibility with base checkpoints."""
        mapping = {"lm_head.weight": "lm_head.linear.weight", "lm_head.bias": "lm_head.linear.bias"}
        state_dict = map_old_state_dict_weights(state_dict, mapping, prefix)
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)


class Block(BaseBlock):
    def __init__(self, config: Config) -> None:
        nn.Module.__init__(self)
        self.norm_1 = config.norm_class(config.n_embd, eps=config.norm_eps)
        self.attn = CausalSelfAttention(config)
        if not config.shared_attention_norm:
            self.norm_2 = config.norm_class(config.n_embd, eps=config.norm_eps)
        self.mlp = config.mlp_class(config)

        self.config = config


class CausalSelfAttention(BaseCausalSelfAttention):
    def __init__(self, config: Config) -> None:
        # Skip the parent class __init__ altogether and replace it to avoid
        # useless allocations
        nn.Module.__init__(self)
        shape = (config.n_head + 2 * config.n_query_groups) * config.head_size
        # key, query, value projections for all heads, but in a batch
        self.attn = LoRAQKVLinear(
            in_features=config.n_embd,
            out_features=shape,
            r=config.r,
            lora_alpha=config.alpha,
            lora_dropout=config.dropout,
            enable_lora=(config.to_query, config.to_key, config.to_value),
            bias=config.bias,
            # for MQA/GQA support
            n_head=config.n_head,
            n_query_groups=config.n_query_groups,
        )
        # output projection
        self.proj = LoRALinear(
            config.n_embd,
            config.n_embd,
            bias=config.bias,
            r=(config.r if config.to_projection else 0),
            lora_alpha=config.alpha,
            lora_dropout=config.dropout,
        )
        # disabled by default
        self.kv_cache: Optional[KVCache] = None

        self.config = config

    def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
        """For compatibility with base checkpoints."""
        mapping = {
            "attn.weight": "attn.linear.weight",
            "attn.bias": "attn.linear.bias",
            "proj.weight": "proj.linear.weight",
            "proj.bias": "proj.linear.bias",
        }
        state_dict = map_old_state_dict_weights(state_dict, mapping, prefix)
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)


class GptNeoxMLP(lit_gpt.model.GptNeoxMLP):
    def __init__(self, config: Config) -> None:
        nn.Module.__init__(self)
        self.fc = LoRALinear(
            config.n_embd,
            config.intermediate_size,
            bias=config.bias,
            r=(config.r if config.to_mlp else 0),
            lora_alpha=config.alpha,
            lora_dropout=config.dropout,
        )
        self.proj = LoRALinear(
            config.intermediate_size,
            config.n_embd,
            bias=config.bias,
            r=(config.r if config.to_mlp else 0),
            lora_alpha=config.alpha,
            lora_dropout=config.dropout,
        )

        self.config = config

    def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
        """For compatibility with base checkpoints."""
        mapping = {
            "fc.weight": "fc.linear.weight",
            "fc.bias": "fc.linear.bias",
            "proj.weight": "proj.linear.weight",
            "proj.bias": "proj.linear.bias",
        }
        state_dict = map_old_state_dict_weights(state_dict, mapping, prefix)
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)


class LLaMAMLP(lit_gpt.model.LLaMAMLP):
    def __init__(self, config: Config) -> None:
        nn.Module.__init__(self)
        self.fc_1 = LoRALinear(
            config.n_embd,
            config.intermediate_size,
            bias=config.bias,
            r=(config.r if config.to_mlp else 0),
            lora_alpha=config.alpha,
            lora_dropout=config.dropout,
        )
        self.fc_2 = LoRALinear(
            config.n_embd,
            config.intermediate_size,
            bias=config.bias,
            r=(config.r if config.to_mlp else 0),
            lora_alpha=config.alpha,
            lora_dropout=config.dropout,
        )
        self.proj = LoRALinear(
            config.intermediate_size,
            config.n_embd,
            bias=config.bias,
            r=(config.r if config.to_mlp else 0),
            lora_alpha=config.alpha,
            lora_dropout=config.dropout,
        )

    def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
        """For compatibility with base checkpoints."""
        mapping = {
            "fc_1.weight": "fc_1.linear.weight",
            "fc_1.bias": "fc_1.linear.bias",
            "fc_2.weight": "fc_2.linear.weight",
            "fc_2.bias": "fc_2.linear.bias",
            "proj.weight": "proj.linear.weight",
            "proj.bias": "proj.linear.bias",
        }
        state_dict = map_old_state_dict_weights(state_dict, mapping, prefix)
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)


def merge_lora_weights(model: GPT) -> None:
    """Merge LoRA weights into the full-rank weights to speed up inference."""
    for module in model.modules():
        if isinstance(module, LoRALinear):
            module.merge()
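

# End-to-end sketch of how the pieces above fit together for fine-tuning (illustrative comment only; the
# model name, hyperparameters, checkpoint handling and training loop are placeholders, not part of this module):
#
#   config = Config.from_name("pythia-160m", r=8, alpha=16, dropout=0.05, to_query=True, to_value=True)
#   model = GPT(config)                  # LoRA is active only where the `to_*` flags are True
#   mark_only_lora_as_trainable(model)   # freeze the pretrained weights
#   ...                                  # load the base checkpoint and run the training loop
#   lora_state = {k: v for k, v in model.state_dict().items() if lora_filter(k, v)}  # keep only LoRA tensors
#   merge_lora_weights(model)            # then fold the updates into the linear layers for inference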