Diff of /lit_gpt/adapter.py [000000] .. [248dc9]

--- /dev/null
+++ b/lit_gpt/adapter.py
@@ -0,0 +1,171 @@
+"""Implementation of the paper:
+
+LLaMA-Adapter: Efficient Fine-tuning of Language Models with Zero-init Attention
+https://arxiv.org/abs/2303.16199
+
+Port for Lit-GPT
+"""
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+from typing_extensions import Self
+
+from lit_gpt.config import Config as BaseConfig
+from lit_gpt.model import GPT as BaseModel
+from lit_gpt.model import Block as BaseBlock
+from lit_gpt.model import CausalSelfAttention as BaseCausalSelfAttention
+
+
+@dataclass
+class Config(BaseConfig):
+    adapter_prompt_length: int = 10  # length of the learnable adaption prompt (one per adapted attention layer)
+    adapter_start_layer: int = 2  # adaption starts at this block index; earlier blocks run unmodified
+
+
+class GPT(BaseModel):
+    """The implementation is identical to `lit_gpt.model.GPT` with the exception that
+    the `Block` saves the layer index and passes it down to the attention layer."""
+
+    def __init__(self, config: Config) -> None:
+        nn.Module.__init__(self)
+        assert config.padded_vocab_size is not None
+        self.config = config
+
+        self.lm_head = nn.Linear(config.n_embd, config.padded_vocab_size, bias=config.lm_head_bias)
+        self.transformer = nn.ModuleDict(
+            dict(
+                wte=nn.Embedding(config.padded_vocab_size, config.n_embd),
+                h=nn.ModuleList(Block(config, i) for i in range(config.n_layer)),
+                ln_f=config.norm_class(config.n_embd, eps=config.norm_eps),
+            )
+        )
+        self.max_seq_length = self.config.block_size
+        self.mask_cache: Optional[torch.Tensor] = None
+
+    def forward(
+        self, idx: torch.Tensor, input_pos: Optional[torch.Tensor] = None, lm_head_chunk_size: int = 0
+    ) -> Union[torch.Tensor, List[torch.Tensor]]:
+        T = idx.size(1)
+        if self.max_seq_length < T:
+            raise ValueError(f"Cannot forward sequence of length {T}, max seq length is only {self.max_seq_length}.")
+
+        if input_pos is not None:  # use the kv cache
+            cos = self.cos.index_select(0, input_pos)
+            sin = self.sin.index_select(0, input_pos)
+            if self.mask_cache is None:
+                raise TypeError("You need to call `gpt.set_kv_cache()`")
+            mask = self.mask_cache.index_select(2, input_pos)
+        else:
+            cos = self.cos[:T]
+            sin = self.sin[:T]
+            mask = None
+
+        x = self.transformer.wte(idx)  # token embeddings of shape (b, t, n_embd)
+        for block in self.transformer.h:
+            x = block(x, cos, sin, mask, input_pos)
+        x = self.transformer.ln_f(x)
+        if lm_head_chunk_size > 0:
+            # chunk the lm head logits to reduce the peak memory used by autograd
+            return [self.lm_head(x_i) for x_i in x.split(lm_head_chunk_size, dim=1)]
+        return self.lm_head(x)  # (b, t, vocab_size)
+
+    @classmethod
+    def from_name(cls, name: str, **kwargs: Any) -> Self:
+        return cls(Config.from_name(name, **kwargs))
+
+    def _init_weights(self, module: nn.Module) -> None:
+        """Meant to be used with `gpt.apply(gpt._init_weights)`. Unused method left for completeness."""
+        super()._init_weights(module)
+        if isinstance(module, CausalSelfAttention):
+            module.reset_parameters()
+
+
+class Block(BaseBlock):
+    """The implementation is identical to `lit_gpt.model.Block` with the exception that
+    we replace the attention layer where adaption is implemented."""
+
+    def __init__(self, config: Config, block_idx: int) -> None:
+        # Skip the parent class __init__ altogether and replace it to avoid useless allocations
+        nn.Module.__init__(self)
+        self.norm_1 = config.norm_class(config.n_embd, eps=config.norm_eps)
+        self.attn = CausalSelfAttention(config, block_idx)
+        if not config.shared_attention_norm:
+            self.norm_2 = config.norm_class(config.n_embd, eps=config.norm_eps)
+        self.mlp = config.mlp_class(config)
+
+        self.config = config
+
+
+class CausalSelfAttention(BaseCausalSelfAttention):
+    """A modification of `lit_gpt.model.CausalSelfAttention` that adds the attention
+    over the adaption prompt."""
+
+    def __init__(self, config: Config, block_idx: int) -> None:
+        super().__init__(config)
+        if block_idx >= config.adapter_start_layer:
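+            # earlier blocks are untouched; adaption prompts are only inserted from `adapter_start_layer` onward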
+            # adapter embedding layer
+            self.adapter_wte = nn.Embedding(config.adapter_prompt_length, config.n_embd)
+            # gate for adaption
+            self.gating_factor = torch.nn.Parameter(torch.zeros(1, 1, config.n_head, 1))
+            # kv cache for inference
+            self.adapter_kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
+        self.block_idx = block_idx
+
+    def scaled_dot_product_attention(
+        self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: Optional[torch.Tensor] = None
+    ) -> torch.Tensor:
+        y = super().scaled_dot_product_attention(q, k, v, mask)
+        if self.block_idx < self.config.adapter_start_layer:
+            return y
+
+        aT = self.config.adapter_prompt_length
+        if self.adapter_kv_cache is not None:
+            # the prefix comes from the adapter_wte weights, and this kv cache is only used during inference,
+            # so ak and av are identical on every call
+            ak, av = self.adapter_kv_cache
+        else:
+            prefix = self.adapter_wte.weight.reshape(1, aT, self.config.n_embd)
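+            # reuse this layer's combined qkv projection (`self.attn` from the base class) on the learned prompt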
+            aqkv = self.attn(prefix)
+            q_per_kv = self.config.n_head // self.config.n_query_groups
+            aqkv = aqkv.view(1, aT, self.config.n_query_groups, q_per_kv + 2, self.config.head_size)
+            aqkv = aqkv.permute(0, 2, 3, 1, 4)
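+            # only the prompt's keys and values are needed; its query slice is discarded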
+            _, ak, av = aqkv.split((q_per_kv, 1, 1), dim=2)
+            if self.config.n_query_groups != 1:
+                # for MHA this is a no-op
+                ak = ak.repeat_interleave(q_per_kv, dim=2)
+                av = av.repeat_interleave(q_per_kv, dim=2)
+            ak = ak.view(1, -1, aT, self.config.head_size)  # (1, nh_ak, aT, hs)
+            av = av.view(1, -1, aT, self.config.head_size)  # (1, nh_av, aT, hs)
+            self.adapter_kv_cache = (ak, av)
+
+        T = q.size(2)
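+        # the adaption prompt is a prefix, so every query position may attend to all aT prompt tokens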
+        amask = torch.ones(T, aT, dtype=torch.bool, device=q.device)
+        ay = super().scaled_dot_product_attention(q, ak, av, amask)
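+        # zero-init attention: `gating_factor` starts at zero, so training begins from the pretrained model's output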
+        return y + self.gating_factor * ay
+
+    def reset_parameters(self) -> None:
+        torch.nn.init.zeros_(self.gating_factor)
+
+    def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
+        """For compatibility with older checkpoints."""
+        if (key := prefix + "gating_factor") in state_dict and state_dict[key].size(1) == self.config.n_head:
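+            # older checkpoints stored the gate as (1, n_head, 1, 1); permute to the current (1, 1, n_head, 1) shape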
+            state_dict[key] = state_dict[key].permute(0, 2, 1, 3)
+        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
+
+
+def mark_only_adapter_as_trainable(model: GPT) -> None:
+    """Sets `requires_grad=False` for all non-adapter weights."""
+    for name, param in model.named_parameters():
+        param.requires_grad = adapter_filter(name, param)
+
+
+def adapter_filter(key: str, value: Any) -> bool:
+    return "adapter_wte" in key or "gating_factor" in key