|
a |
|
b/lit_gpt/model.py |
|
|
1 |
"""Full definition of a GPT NeoX Language Model, all of it in this single file. |
|
|
2 |
|
|
|
3 |
Based on the nanoGPT implementation: https://github.com/karpathy/nanoGPT and |
|
|
4 |
https://github.com/EleutherAI/gpt-neox/tree/main/megatron/model. |
|
|
5 |
""" |
|
|
6 |
import math |
|
|
7 |
from typing import Any, Optional, Tuple |
|
|
8 |
|
|
|
9 |
import torch |
|
|
10 |
import torch.nn as nn |
|
|
11 |
from typing_extensions import Self |
|
|
12 |
|
|
|
13 |
from lit_gpt.config import Config |
|
|
14 |
|
|
|
15 |
|
|
|
16 |
class GPT(nn.Module):
    """Decoder-only transformer: token embedding, a stack of `Block`s, a final norm and the LM head.

    Rotary position embeddings live in the non-persistent `cos`/`sin` buffers, sized to
    `max_seq_length`. The KV cache and its boolean attention mask are only allocated by
    `set_kv_cache` and are released by `clear_kv_cache`.
    """

    def __init__(self, config: Config) -> None:
        super().__init__()
        assert config.padded_vocab_size is not None
        self.config = config

        self.lm_head = nn.Linear(config.n_embd, config.padded_vocab_size, bias=config.lm_head_bias)
        self.transformer = nn.ModuleDict(
            dict(
                wte=nn.Embedding(config.padded_vocab_size, config.n_embd),
                h=nn.ModuleList(Block(config) for _ in range(config.n_layer)),
                ln_f=config.norm_class(config.n_embd, eps=config.norm_eps),
            )
        )
        # goes through the property setter below, which also builds the rope cache
        self.max_seq_length = self.config.block_size
        # causal mask used only together with the kv cache; created by `set_kv_cache`
        self.mask_cache: Optional[torch.Tensor] = None

    @property
    def max_seq_length(self) -> int:
        """Longest sequence the model is currently prepared to process (at most `config.block_size`)."""
        return self._max_seq_length

    @max_seq_length.setter
    def max_seq_length(self, value: int) -> None:
        """
        When doing inference, the sequences used might be shorter than the model's context length.
        This allows setting a smaller number to avoid allocating unused memory

        Raises:
            ValueError: if `value` exceeds `config.block_size`.
        """
        if value > self.config.block_size:
            raise ValueError(f"Cannot attend to {value}, block size is only {self.config.block_size}")
        self._max_seq_length = value
        if not hasattr(self, "cos"):
            # first call
            cos, sin = self.rope_cache()
            self.register_buffer("cos", cos, persistent=False)
            self.register_buffer("sin", sin, persistent=False)
        elif value != self.cos.size(0):
            # override
            self.cos, self.sin = self.rope_cache(device=self.cos.device)
        # the mask and kv cache size will get updated on `set_kv_cache`. we cannot update it here because we don't know
        # if the kv cache is expected

    def reset_parameters(self) -> None:
        """Recompute the rope cache, e.g. after the module was materialized from the meta device.

        Re-assigning `max_seq_length` is not enough for this: its setter only rebuilds `cos`/`sin`
        when the length changes, so buffers holding uninitialized memory (as left by
        `Module.to_empty`) would keep their garbage values. The current `max_seq_length` is kept.
        """
        self.cos, self.sin = self.rope_cache(device=self.cos.device)

    def _init_weights(self, module: nn.Module) -> None:
        """Meant to be used with `gpt.apply(gpt._init_weights)`."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx: torch.Tensor, input_pos: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Compute next-token logits.

        Args:
            idx: token ids; dimension 1 is the sequence length T.
            input_pos: positions of the T tokens. When given, the kv cache (set up by
                `set_kv_cache`) is used and the rope/mask rows are gathered at these positions.

        Returns:
            Logits of shape (b, t, padded_vocab_size).

        Raises:
            ValueError: if T exceeds `max_seq_length`.
            TypeError: if `input_pos` is given but `set_kv_cache` was not called.
        """
        T = idx.size(1)
        if self.max_seq_length < T:
            raise ValueError(f"Cannot forward sequence of length {T}, max seq length is only {self.max_seq_length}.")

        if input_pos is not None:  # use the kv cache
            cos = self.cos.index_select(0, input_pos)
            sin = self.sin.index_select(0, input_pos)
            if self.mask_cache is None:
                raise TypeError("You need to call `gpt.set_kv_cache()`")
            mask = self.mask_cache.index_select(2, input_pos)
        else:
            cos = self.cos[:T]
            sin = self.sin[:T]
            # no explicit mask: attention falls back to `is_causal=True`
            mask = None

        x = self.transformer.wte(idx)  # token embeddings of shape (b, t, n_embd)
        for block in self.transformer.h:
            x = block(x, cos, sin, mask, input_pos)
        x = self.transformer.ln_f(x)
        return self.lm_head(x)  # (b, t, vocab_size)

    @classmethod
    def from_name(cls, name: str, **kwargs: Any) -> Self:
        """Build a model from the named config (see `Config.from_name`)."""
        return cls(Config.from_name(name, **kwargs))

    def rope_cache(self, device: Optional[torch.device] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """Build the (cos, sin) rope tables for `max_seq_length` positions using the config's rope settings."""
        return build_rope_cache(
            seq_len=self.max_seq_length,
            n_elem=self.config.rope_n_elem,
            device=device,
            condense_ratio=self.config.rope_condense_ratio,
            base=self.config.rope_base,
        )

    def set_kv_cache(
        self,
        batch_size: int,
        rope_cache_length: Optional[int] = None,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> None:
        """Allocate a kv cache of `max_seq_length` entries in every block and (re)build the causal mask.

        Args:
            batch_size: batch dimension of the cache buffers.
            rope_cache_length: width of the rotary part of the cached keys; defaults to `self.cos.size(-1)`.
            device: device for the cache buffers and the mask.
            dtype: dtype for the cache buffers.
        """
        if rope_cache_length is None:
            rope_cache_length = self.cos.size(-1)
        max_seq_length = self.max_seq_length

        # initialize the kv cache for all blocks
        for block in self.transformer.h:
            block.attn.kv_cache = block.attn.build_kv_cache(
                batch_size, max_seq_length, rope_cache_length, device, dtype
            )

        if self.mask_cache is None or self.mask_cache.size(3) != max_seq_length:
            # passing `attn_mask` to SDPA downgrades it to use the inefficient implementation. since we only need the mask
            # for the kv-cache support (only during inference), we only create it in that situation
            # this will be resolved by https://github.com/pytorch/pytorch/issues/96099
            ones = torch.ones((max_seq_length, max_seq_length), device=device, dtype=torch.bool)
            self.mask_cache = torch.tril(ones).unsqueeze(0).unsqueeze(0)

    def clear_kv_cache(self) -> None:
        """Drop the causal mask and every block's kv cache."""
        self.mask_cache = None
        for block in self.transformer.h:
            block.attn.kv_cache = None
|
|
133 |
|
|
|
134 |
|
|
|
135 |
class Block(nn.Module):
    """One transformer layer: self-attention followed by an MLP, each with a residual connection.

    With `config.parallel_residual` both branches read normalized versions of the same input and
    their outputs are summed together with it. Otherwise the MLP runs on the result of the
    attention residual (sequential layout).
    """

    def __init__(self, config: Config) -> None:
        super().__init__()
        # submodules are registered in this order on purpose: it fixes the parameter/state-dict order
        self.norm_1 = config.norm_class(config.n_embd, eps=config.norm_eps)
        self.attn = CausalSelfAttention(config)
        if config.shared_attention_norm:
            # with a shared norm the MLP branch reuses `norm_1`'s output, so no second norm exists
            self.norm_2 = None
        else:
            self.norm_2 = config.norm_class(config.n_embd, eps=config.norm_eps)
        self.mlp = config.mlp_class(config)

        self.config = config

    def forward(
        self,
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        input_pos: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply the layer to `x`; `cos`, `sin`, `mask` and `input_pos` are forwarded to the attention."""
        attn_in = self.norm_1(x)
        attn_out = self.attn(attn_in, cos, sin, mask, input_pos)

        if not self.config.parallel_residual:
            if self.config.shared_attention_norm:
                raise NotImplementedError(
                    "No checkpoint amongst the ones we support uses this configuration"
                    " (non-parallel residual and shared attention norm)."
                )
            residual = attn_out + x
            return self.mlp(self.norm_2(residual)) + residual

        mlp_in = attn_in if self.config.shared_attention_norm else self.norm_2(x)
        return self.mlp(mlp_in) + attn_out + x
|
|
167 |
|
|
|
168 |
|
|
|
169 |
class CausalSelfAttention(nn.Module):
    """Causal self-attention covering multi-head, multi-query and grouped-query attention.

    `config.n_query_groups` sets the number of key/value heads:
    - equal to `n_head` gives MHA;
    - 1 gives MQA;
    - values in between give GQA.

    Rotary embeddings are applied to the first `config.rope_n_elem` channels of each query and key head.
    """

    def __init__(self, config: Config) -> None:
        super().__init__()
        shape = (config.n_head + 2 * config.n_query_groups) * config.head_size
        # key, query, value projections for all heads, but in a batch
        self.attn = nn.Linear(config.n_embd, shape, bias=config.bias)
        # output projection. its input is the concatenation of all `n_head` query-head outputs,
        # i.e. `n_head * head_size` features. for most configs this equals `n_embd`, but it need
        # not, so the width is derived from the heads rather than assumed
        self.proj = nn.Linear(config.head_size * config.n_head, config.n_embd, bias=config.bias)
        # disabled by default; `GPT.set_kv_cache` assigns it
        self.kv_cache: Optional[KVCache] = None

        self.config = config

    def forward(
        self,
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        input_pos: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Attend over `x` of shape (B, T, n_embd) and return a tensor of the same shape.

        Args:
            x: input activations.
            cos, sin: rope tables already sliced/gathered for the T positions (as `GPT.forward` does).
            mask: attention mask, only passed when the kv cache is used; otherwise SDPA applies
                causal masking itself.
            input_pos: positions of the T tokens in the kv cache, or None to run without a cache.

        Raises:
            TypeError: if `input_pos` is given but no kv cache was set up.
        """
        B, T, _ = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)

        qkv = self.attn(x)

        # assemble into a number of query groups to support MHA, MQA and GQA together (see `config.n_query_groups`)
        q_per_kv = self.config.n_head // self.config.n_query_groups
        total_qkv = q_per_kv + 2  # each group has 1+ queries, 1 key, and 1 value
        qkv = qkv.view(B, T, self.config.n_query_groups, total_qkv, self.config.head_size)
        qkv = qkv.permute(0, 2, 3, 1, 4)  # (B, n_query_groups, total_qkv, T, hs)

        # split batched computation into three
        q, k, v = qkv.split((q_per_kv, 1, 1), dim=2)

        # maybe repeat k and v if for the non multi-head attention cases
        # training: flash attention requires it
        # inference: multi-query would require a full kv cache so avoid it to limit its memory usage
        if self.config.n_query_groups != self.config.n_head and (input_pos is None or self.config.n_query_groups != 1):
            k = k.expand(B, self.config.n_query_groups, q_per_kv, T, self.config.head_size)
            v = v.expand(B, self.config.n_query_groups, q_per_kv, T, self.config.head_size)

        q = q.reshape(B, -1, T, self.config.head_size)  # (B, nh_q, T, hs)
        k = k.reshape(B, -1, T, self.config.head_size)  # (B, nh_k, T, hs)
        v = v.reshape(B, -1, T, self.config.head_size)  # (B, nh_v, T, hs)

        # rotate only the first `rope_n_elem` channels; the remainder passes through unchanged
        q_roped = apply_rope(q[..., : self.config.rope_n_elem], cos, sin)
        k_roped = apply_rope(k[..., : self.config.rope_n_elem], cos, sin)
        q = torch.cat((q_roped, q[..., self.config.rope_n_elem :]), dim=-1)
        k = torch.cat((k_roped, k[..., self.config.rope_n_elem :]), dim=-1)

        if input_pos is not None:
            if not isinstance(self.kv_cache, KVCache):
                raise TypeError("You need to call `gpt.set_kv_cache()`")
            k, v = self.kv_cache(input_pos, k, v)

        y = self.scaled_dot_product_attention(q, k, v, mask)

        # re-assemble all head outputs side by side: (B, T, n_head * head_size)
        y = y.reshape(B, T, self.config.head_size * self.config.n_head)

        # output projection
        return self.proj(y)

    def scaled_dot_product_attention(
        self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Run SDPA on (B, nh, T, hs) inputs and return the output transposed to (B, T, nh, hs).

        Without a `mask`, causal masking is requested via `is_causal=True`.
        """
        scale = 1.0 / math.sqrt(self.config.head_size)
        y = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, attn_mask=mask, dropout_p=0.0, scale=scale, is_causal=mask is None
        )
        return y.transpose(1, 2)

    def build_kv_cache(
        self,
        batch_size: int,
        max_seq_length: int,
        rope_cache_length: Optional[int] = None,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> "KVCache":
        """Create a `KVCache` sized for this layer.

        MQA (`n_query_groups == 1`) caches a single head because keys/values are not expanded
        during inference. Otherwise `n_head` heads are cached, matching the expanded k/v in `forward`.

        Raises:
            TypeError: if `rope_cache_length` is omitted while only part of each head is rotated.
        """
        heads = 1 if self.config.n_query_groups == 1 else self.config.n_head
        v_shape = (batch_size, heads, max_seq_length, self.config.head_size)
        if rope_cache_length is None:
            if self.config.rotary_percentage != 1.0:
                raise TypeError("Please pass the `rope_cache_length=gpt.cos.size(-1)` value")
            k_shape = v_shape
        else:
            k_shape = (
                batch_size,
                heads,
                max_seq_length,
                rope_cache_length + self.config.head_size - self.config.rope_n_elem,
            )
        return KVCache(k_shape, v_shape, device=device, dtype=dtype)
|
|
262 |
|
|
|
263 |
|
|
|
264 |
class GptNeoxMLP(nn.Module):
    """GPT-NeoX style feed-forward network: expand with `fc`, apply GELU, project back with `proj`."""

    def __init__(self, config: Config) -> None:
        super().__init__()
        # n_embd -> intermediate_size -> n_embd
        self.fc = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)
        self.proj = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias)

        self.config = config

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map `x` through the MLP; the GELU variant follows `config.gelu_approximate`."""
        hidden = torch.nn.functional.gelu(self.fc(x), approximate=self.config.gelu_approximate)
        return self.proj(hidden)
|
|
276 |
|
|
|
277 |
|
|
|
278 |
class LLaMAMLP(nn.Module):
    """LLaMA-style gated feed-forward: `proj(silu(fc_1(x)) * fc_2(x))`."""

    def __init__(self, config: Config) -> None:
        super().__init__()
        # `fc_1` feeds the SiLU gate, `fc_2` the value it multiplies
        self.fc_1 = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)
        self.fc_2 = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)
        self.proj = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated MLP to `x`."""
        gate = torch.nn.functional.silu(self.fc_1(x))
        return self.proj(gate * self.fc_2(x))
|
|
290 |
|
|
|
291 |
|
|
|
292 |
def build_rope_cache(
    seq_len: int, n_elem: int, device: Optional[torch.device] = None, base: int = 10000, condense_ratio: int = 1
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Precompute the cosine/sine tables for Rotary Position Embedding (RoPE).

    The rotation frequencies are theta_i = base^(-2(i-1)/d) for i in [1, ..., d/2], where
    d = `n_elem`. Positions are 0..seq_len-1, divided by `condense_ratio`. Each returned table
    has shape (seq_len, n_elem): the d/2 angles are laid out twice side by side, matching the
    half-split rotation done by `apply_rope`.

    Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
    transformers/rope/__init__.py. MIT License:
    https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
    """
    exponents = torch.arange(0, n_elem, 2, device=device).float() / n_elem
    inv_freq = 1.0 / (base**exponents)

    # position indexes, optionally condensed (scaled down) to stretch the usable context
    positions = torch.arange(seq_len, device=device) / condense_ratio

    # angle for every (position, frequency) pair, duplicated to cover both halves of the rotated slice
    angles = torch.outer(positions, inv_freq)
    angles = torch.cat((angles, angles), dim=-1)

    return angles.cos(), angles.sin()
|
|
311 |
|
|
|
312 |
|
|
|
313 |
def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Rotate the last dimension of `x` using precomputed rope tables.

    The last dimension is split into two halves (a, b). The result is
    `x * cos + concat(-b, a) * sin`, cast back to `x`'s dtype.
    """
    half = x.size(-1) // 2
    first, second = x[..., :half], x[..., half:]
    rotated = torch.cat((-second, first), dim=-1)
    return (x * cos + rotated * sin).type_as(x)
|
|
320 |
|
|
|
321 |
|
|
|
322 |
class KVCache(nn.Module):
    """Preallocated key/value buffers that are filled in place at given sequence positions.

    The buffers are registered as non-persistent, so they are not part of the state dict.
    """

    def __init__(
        self,
        k_shape: Tuple[int, int, int, int],
        v_shape: Tuple[int, int, int, int],
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> None:
        super().__init__()
        for buffer_name, buffer_shape in (("k", k_shape), ("v", v_shape)):
            self.register_buffer(
                buffer_name, torch.zeros(buffer_shape, device=device, dtype=dtype), persistent=False
            )

    def forward(self, input_pos: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Write `k`/`v` into the cache at `input_pos` along dim 2 and return the full cached tensors."""
        # follow the incoming activations' dtype (e.g. when running under AMP)
        self.k = self.k.to(k.dtype)
        self.v = self.v.to(v.dtype)
        return self.k.index_copy_(2, input_pos, k), self.v.index_copy_(2, input_pos, v)

    def reset_parameters(self) -> None:
        """Zero both cache buffers in place."""
        for buffer in (self.k, self.v):
            torch.nn.init.zeros_(buffer)