--- a
+++ b/lit_gpt/adapter.py
@@ -0,0 +1,165 @@
"""Implementation of the paper:

LLaMA-Adapter: Efficient Fine-tuning of Language Models with Zero-init Attention
https://arxiv.org/abs/2303.16199

Port for Lit-GPT
"""
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from typing_extensions import Self

from lit_gpt.config import Config as BaseConfig
from lit_gpt.model import GPT as BaseModel
from lit_gpt.model import Block as BaseBlock
from lit_gpt.model import CausalSelfAttention as BaseCausalSelfAttention


@dataclass
class Config(BaseConfig):
    adapter_prompt_length: int = 10
    adapter_start_layer: int = 2


class GPT(BaseModel):
    """The implementation is identical to `lit_gpt.model.GPT` with the exception that
    the `Block` saves the layer index and passes it down to the attention layer."""

    def __init__(self, config: Config) -> None:
        nn.Module.__init__(self)
        assert config.padded_vocab_size is not None
        self.config = config

        self.lm_head = nn.Linear(config.n_embd, config.padded_vocab_size, bias=config.lm_head_bias)
        self.transformer = nn.ModuleDict(
            dict(
                wte=nn.Embedding(config.padded_vocab_size, config.n_embd),
                h=nn.ModuleList(Block(config, i) for i in range(config.n_layer)),
                ln_f=config.norm_class(config.n_embd, eps=config.norm_eps),
            )
        )
        self.max_seq_length = self.config.block_size
        self.mask_cache: Optional[torch.Tensor] = None

    def forward(
        self, idx: torch.Tensor, input_pos: Optional[torch.Tensor] = None, lm_head_chunk_size: int = 0
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        T = idx.size(1)
        if self.max_seq_length < T:
            raise ValueError(f"Cannot forward sequence of length {T}, max seq length is only {self.max_seq_length}.")

        if input_pos is not None:  # use the kv cache
            cos = self.cos.index_select(0, input_pos)
            sin = self.sin.index_select(0, input_pos)
            if self.mask_cache is None:
                raise TypeError("You need to call `gpt.set_kv_cache()`")
            mask = self.mask_cache.index_select(2, input_pos)
        else:
            cos = self.cos[:T]
            sin = self.sin[:T]
            mask = None

        x = self.transformer.wte(idx)  # token embeddings of shape (b, t, n_embd)
        for block in self.transformer.h:
            x = block(x, cos, sin, mask, input_pos)
        x = self.transformer.ln_f(x)
        if lm_head_chunk_size > 0:
            # chunk the lm head logits to reduce the peak memory used by autograd
            return [self.lm_head(x_i) for x_i in x.split(lm_head_chunk_size, dim=1)]
        return self.lm_head(x)  # (b, t, vocab_size)

    @classmethod
    def from_name(cls, name: str, **kwargs: Any) -> Self:
        return cls(Config.from_name(name, **kwargs))

    def _init_weights(self, module: nn.Module) -> None:
        """Meant to be used with `gpt.apply(gpt._init_weights)`. Unused method left for completeness."""
        super()._init_weights(module)
        if isinstance(module, CausalSelfAttention):
            module.reset_parameters()


class Block(BaseBlock):
    """The implementation is identical to `lit_gpt.model.Block` with the exception that
    we replace the attention layer where adaption is implemented."""

    def __init__(self, config: Config, block_idx: int) -> None:
        # Skip the parent class __init__ altogether and replace it to avoid useless allocations
        nn.Module.__init__(self)
        self.norm_1 = config.norm_class(config.n_embd, eps=config.norm_eps)
        self.attn = CausalSelfAttention(config, block_idx)
        if not config.shared_attention_norm:
            self.norm_2 = config.norm_class(config.n_embd, eps=config.norm_eps)
        self.mlp = config.mlp_class(config)

        self.config = config


class CausalSelfAttention(BaseCausalSelfAttention):
    """A modification of `lit_gpt.model.CausalSelfAttention` that adds the attention
    over the adaption prompt."""

    def __init__(self, config: Config, block_idx: int) -> None:
        super().__init__(config)
        if block_idx >= config.adapter_start_layer:
            # adapter embedding layer
            self.adapter_wte = nn.Embedding(config.adapter_prompt_length, config.n_embd)
            # gate for adaption
            self.gating_factor = torch.nn.Parameter(torch.zeros(1, 1, config.n_head, 1))
            # kv cache for inference
            self.adapter_kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
        self.block_idx = block_idx

    def scaled_dot_product_attention(
        self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        y = super().scaled_dot_product_attention(q, k, v, mask)
        if self.block_idx < self.config.adapter_start_layer:
            return y

        aT = self.config.adapter_prompt_length
        if self.adapter_kv_cache is not None:
            # since this uses the wte weights as the prefix and the kv cache is only used during inference, ak and av
            # are the same every call
            ak, av = self.adapter_kv_cache
        else:
            prefix = self.adapter_wte.weight.reshape(1, aT, self.config.n_embd)
            aqkv = self.attn(prefix)
            q_per_kv = self.config.n_head // self.config.n_query_groups
            aqkv = aqkv.view(1, aT, self.config.n_query_groups, q_per_kv + 2, self.config.head_size)
            aqkv = aqkv.permute(0, 2, 3, 1, 4)
            _, ak, av = aqkv.split((q_per_kv, 1, 1), dim=2)
            if self.config.n_query_groups != 1:
                # for MHA this is a no-op
                ak = ak.repeat_interleave(q_per_kv, dim=2)
                av = av.repeat_interleave(q_per_kv, dim=2)
            ak = ak.view(1, -1, aT, self.config.head_size)  # (1, nh_ak, aT, hs)
            av = av.view(1, -1, aT, self.config.head_size)  # (1, nh_av, aT, hs)
            self.adapter_kv_cache = (ak, av)

        T = q.size(2)
        amask = torch.ones(T, aT, dtype=torch.bool, device=q.device)
        ay = super().scaled_dot_product_attention(q, ak, av, amask)
        return y + self.gating_factor * ay

    def reset_parameters(self) -> None:
        torch.nn.init.zeros_(self.gating_factor)

    def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
        """For compatibility with older checkpoints."""
        if (key := prefix + "gating_factor") in state_dict and state_dict[key].size(1) == self.config.n_head:
            state_dict[key] = state_dict[key].permute(0, 2, 1, 3)
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)


def mark_only_adapter_as_trainable(model: GPT) -> None:
    """Sets `requires_grad=False` for all non-adapter weights."""
    for name, param in model.named_parameters():
        param.requires_grad = adapter_filter(name, param)


def adapter_filter(key: str, value: Any) -> bool:
    return "adapter_wte" in key or "gating_factor" in key
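
For reference, here is a minimal usage sketch of the pieces this diff adds: it builds the adapter-augmented `GPT`, freezes everything except the adapter weights via `mark_only_adapter_as_trainable`, and runs one dummy optimization step. The config name "pythia-70m", the learning rate, and the toy batch are illustrative assumptions, not part of this change.

# Minimal usage sketch, not part of the diff above. It exercises Config (with the adapter
# knobs), GPT, and mark_only_adapter_as_trainable.
# The config name "pythia-70m", the learning rate, and the toy batch are illustrative assumptions.
import torch
import torch.nn.functional as F

from lit_gpt.adapter import GPT, Config, mark_only_adapter_as_trainable

config = Config.from_name("pythia-70m", adapter_prompt_length=10, adapter_start_layer=2)
model = GPT(config)

# Freeze everything except the adapter prompt embeddings (`adapter_wte`) and the
# zero-initialized gates (`gating_factor`).
mark_only_adapter_as_trainable(model)
optimizer = torch.optim.AdamW((p for p in model.parameters() if p.requires_grad), lr=1e-3)

idx = torch.randint(0, config.padded_vocab_size, (1, 8))  # (batch, time) token ids
logits = model(idx)  # (1, 8, padded_vocab_size)
# Next-token loss on the dummy batch; only the adapter parameters receive gradients.
loss = F.cross_entropy(logits[:, :-1].reshape(-1, logits.size(-1)), idx[:, 1:].reshape(-1))
loss.backward()
optimizer.step()

Because `gating_factor` starts at zero, the adapted model initially reproduces the base model's outputs; the adaption prompt only influences attention as the gates are learned, which is the zero-init attention idea from the paper.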