lit_gpt/adapter_v2.py
"""Implementation of the paper:

LLaMA-Adapter V2: Parameter-Efficient Visual Instruction Model
https://arxiv.org/abs/2304.15010

Port for Lit-GPT
"""
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, Type

import torch
import torch.nn as nn
from typing_extensions import Self

import lit_gpt
from lit_gpt.adapter import GPT as BaseModel
from lit_gpt.adapter import Block as BaseBlock
from lit_gpt.adapter import CausalSelfAttention as BaseCausalSelfAttention
from lit_gpt.adapter import Config as BaseConfig
from lit_gpt.model import KVCache
from lit_gpt.utils import map_old_state_dict_weights


@dataclass
class Config(BaseConfig):
    @property
    def mlp_class(self) -> Type:
        return getattr(lit_gpt.adapter_v2, self._mlp_class)


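# --- Illustrative usage sketch (not part of the original module) -------------------------
# `Config.mlp_class` looks the MLP implementation up by name in *this* module, so a config
# whose `_mlp_class` is "GptNeoxMLP" resolves to the AdapterV2Linear-based class defined
# further below rather than to `lit_gpt.model.GptNeoxMLP`. The model name used here is
# only an example of a GPT-NeoX-style config.
def _example_mlp_class_resolution() -> None:
    config = Config.from_name("pythia-160m")
    assert config.mlp_class is GptNeoxMLP  # the adapter_v2 variant defined below

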
def adapter_filter(key: str, value: Any) -> bool:
    adapter_substrings = (
        # regular adapter v1 parameters
        "adapter_wte",
        "gating_factor",
        # adapter v2: new bias and scale used in Linear
        "adapter_scale",
        "adapter_bias",
        # adapter v2: Norm parameters are now trainable
        "norm_1",
        "norm_2",
        "ln_f",
    )
    return any(s in key for s in adapter_substrings)


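# --- Illustrative usage sketch (not part of the original module) -------------------------
# `adapter_filter` is the predicate that decides which tensors belong to the adapter.
# A typical use is extracting only the adapter weights of a fine-tuned model, which is far
# smaller than the full checkpoint and is what gets saved or shared after fine-tuning.
def _example_extract_adapter_weights(model: "GPT") -> Dict[str, torch.Tensor]:
    # Keep only adapter tensors (adapter prompts, gates, per-output scale/bias, norms).
    return {name: tensor for name, tensor in model.state_dict().items() if adapter_filter(name, tensor)}

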
class AdapterV2Linear(torch.nn.Module):
    def __init__(self, in_features: int, out_features: int, **kwargs) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(in_features, out_features, **kwargs)
        self.adapter_bias = torch.nn.Parameter(torch.zeros(out_features), requires_grad=False)
        self.adapter_scale = torch.nn.Parameter(torch.ones(out_features), requires_grad=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.adapter_scale * (self.linear(x) + self.adapter_bias)

    def reset_parameters(self) -> None:
        nn.init.zeros_(self.adapter_bias)
        nn.init.ones_(self.adapter_scale)


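# --- Illustrative usage sketch (not part of the original module) -------------------------
# At initialization `adapter_scale` is all ones and `adapter_bias` is all zeros, so an
# AdapterV2Linear starts out numerically identical to the nn.Linear it wraps. Fine-tuning
# with `mark_only_adapter_v2_as_trainable` (defined at the bottom of this file) then trains
# only the per-output scale and bias while the wrapped weight stays frozen.
def _example_adapter_linear_is_identity_at_init() -> None:
    layer = AdapterV2Linear(in_features=8, out_features=4, bias=True)
    x = torch.randn(2, 8)
    assert torch.allclose(layer(x), layer.linear(x))

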
class GPT(BaseModel):
    def __init__(self, config: Config) -> None:
        # Skip the parent class __init__ altogether and replace it to avoid useless allocations
        nn.Module.__init__(self)
        assert config.padded_vocab_size is not None
        self.config = config

        self.lm_head = AdapterV2Linear(config.n_embd, config.padded_vocab_size, bias=config.lm_head_bias)
        self.transformer = nn.ModuleDict(
            dict(
                wte=nn.Embedding(config.padded_vocab_size, config.n_embd),
                h=nn.ModuleList(Block(config, i) for i in range(config.n_layer)),
                ln_f=config.norm_class(config.n_embd, eps=config.norm_eps),
            )
        )
        self.max_seq_length = self.config.block_size
        self.mask_cache: Optional[torch.Tensor] = None

    @classmethod
    def from_name(cls, name: str, **kwargs: Any) -> Self:
        return cls(Config.from_name(name, **kwargs))

    def _init_weights(self, module: nn.Module) -> None:
        """Meant to be used with `gpt.apply(gpt._init_weights)`. Unused method left for completeness."""
        super()._init_weights(module)
        if isinstance(module, AdapterV2Linear):
            module.reset_parameters()

    def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
        """For compatibility with base checkpoints."""
        mapping = {"lm_head.weight": "lm_head.linear.weight", "lm_head.bias": "lm_head.linear.bias"}
        state_dict = map_old_state_dict_weights(state_dict, mapping, prefix)
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)


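# --- Illustrative usage sketch (not part of the original module) -------------------------
# A base (non-adapter) checkpoint can be loaded into this GPT: the `_load_from_state_dict`
# hooks remap keys such as "lm_head.weight" -> "lm_head.linear.weight", and `strict=False`
# tolerates the adapter tensors that the base checkpoint does not contain. The checkpoint
# path and model name below are placeholders, not values prescribed by this module.
def _example_load_base_checkpoint() -> GPT:
    model = GPT.from_name("pythia-160m")
    state_dict = torch.load("checkpoints/EleutherAI/pythia-160m/lit_model.pth")  # hypothetical path
    model.load_state_dict(state_dict, strict=False)
    return model

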
class Block(BaseBlock):
    """The implementation is identical to `lit_gpt.model.Block` with the exception that
    we replace the attention layer where adaption is implemented."""

    def __init__(self, config: Config, block_idx: int) -> None:
        # Skip the parent class __init__ altogether and replace it to avoid useless allocations
        nn.Module.__init__(self)
        self.norm_1 = config.norm_class(config.n_embd, eps=config.norm_eps)
        self.attn = CausalSelfAttention(config, block_idx)
        if not config.shared_attention_norm:
            self.norm_2 = config.norm_class(config.n_embd, eps=config.norm_eps)
        self.mlp = config.mlp_class(config)

        self.config = config


class CausalSelfAttention(BaseCausalSelfAttention):
    """A modification of `lit_gpt.adapter.CausalSelfAttention` that uses the Adapter V2 Linear class"""

    def __init__(self, config: Config, block_idx: int) -> None:
        # Skip the parent class __init__ altogether and replace it to avoid useless allocations
        nn.Module.__init__(self)
        shape = (config.n_head + 2 * config.n_query_groups) * config.head_size
        # key, query, value projections for all heads, but in a batch
        self.attn = AdapterV2Linear(in_features=config.n_embd, out_features=shape, bias=config.bias)
        # output projection
        self.proj = AdapterV2Linear(config.n_embd, config.n_embd, bias=config.bias)
        # disabled by default
        self.kv_cache: Optional[KVCache] = None

        if block_idx >= config.adapter_start_layer:
            # adapter embedding layer
            self.adapter_wte = nn.Embedding(config.adapter_prompt_length, config.n_embd)
            # gate for adaption
            self.gating_factor = torch.nn.Parameter(torch.zeros(1, 1, config.n_head, 1))
            # kv cache for inference
            self.adapter_kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
        self.block_idx = block_idx

        self.config = config

    def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
        """For compatibility with base checkpoints."""
        mapping = {
            "attn.weight": "attn.linear.weight",
            "attn.bias": "attn.linear.bias",
            "proj.weight": "proj.linear.weight",
            "proj.bias": "proj.linear.bias",
        }
        state_dict = map_old_state_dict_weights(state_dict, mapping, prefix)
        # For compatibility with older checkpoints that stored `gating_factor` with shape
        # (1, n_head, 1, 1): permute it into the current (1, 1, n_head, 1) layout
        if (key := prefix + "gating_factor") in state_dict and state_dict[key].size(1) == self.config.n_head:
            state_dict[key] = state_dict[key].permute(0, 2, 1, 3)
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)


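# --- Illustrative usage sketch (not part of the original module) -------------------------
# The fused qkv projection has (n_head + 2 * n_query_groups) * head_size output features,
# and only blocks from `adapter_start_layer` onwards carry the learned adapter prompt
# (`adapter_wte`) and its zero-initialized gate (`gating_factor`). The model name below is
# only an example.
def _example_attention_layout() -> None:
    config = Config.from_name("pythia-160m")
    attn = CausalSelfAttention(config, block_idx=config.adapter_start_layer)
    assert attn.attn.linear.out_features == (config.n_head + 2 * config.n_query_groups) * config.head_size
    assert hasattr(attn, "adapter_wte") and hasattr(attn, "gating_factor")

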
class GptNeoxMLP(lit_gpt.model.GptNeoxMLP):
    def __init__(self, config: Config) -> None:
        nn.Module.__init__(self)
        self.fc = AdapterV2Linear(config.n_embd, config.intermediate_size, bias=config.bias)
        self.proj = AdapterV2Linear(config.intermediate_size, config.n_embd, bias=config.bias)

        self.config = config

    def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
        """For compatibility with base checkpoints."""
        mapping = {
            "fc.weight": "fc.linear.weight",
            "fc.bias": "fc.linear.bias",
            "proj.weight": "proj.linear.weight",
            "proj.bias": "proj.linear.bias",
        }
        state_dict = map_old_state_dict_weights(state_dict, mapping, prefix)
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)


class LLaMAMLP(lit_gpt.model.LLaMAMLP):
    def __init__(self, config: Config) -> None:
        nn.Module.__init__(self)
        self.fc_1 = AdapterV2Linear(config.n_embd, config.intermediate_size, bias=config.bias)
        self.fc_2 = AdapterV2Linear(config.n_embd, config.intermediate_size, bias=config.bias)
        self.proj = AdapterV2Linear(config.intermediate_size, config.n_embd, bias=config.bias)

    def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
        """For compatibility with base checkpoints."""
        mapping = {
            "fc_1.weight": "fc_1.linear.weight",
            "fc_1.bias": "fc_1.linear.bias",
            "fc_2.weight": "fc_2.linear.weight",
            "fc_2.bias": "fc_2.linear.bias",
            "proj.weight": "proj.linear.weight",
            "proj.bias": "proj.linear.bias",
        }
        state_dict = map_old_state_dict_weights(state_dict, mapping, prefix)
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)


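# --- Illustrative usage sketch (not part of the original module) -------------------------
# All of the `_load_from_state_dict` overrides above follow the same pattern: because each
# nn.Linear is now nested inside an AdapterV2Linear as `.linear`, base-checkpoint keys are
# renamed before loading, e.g. "transformer.h.0.mlp.proj.weight" becomes
# "transformer.h.0.mlp.proj.linear.weight". The sketch below assumes
# `map_old_state_dict_weights(state_dict, mapping, prefix)` renames `prefix + old_key` to
# `prefix + new_key` for every entry of `mapping`, as its use in this file suggests.
def _example_key_remapping() -> None:
    state_dict = {"transformer.h.0.mlp.proj.weight": torch.zeros(1)}
    mapping = {"proj.weight": "proj.linear.weight"}
    remapped = map_old_state_dict_weights(state_dict, mapping, prefix="transformer.h.0.mlp.")
    assert "transformer.h.0.mlp.proj.linear.weight" in remapped

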
def mark_only_adapter_v2_as_trainable(model: GPT) -> None:
    """Sets `requires_grad=False` for all non-adapter weights and `True` for the adapter weights."""
    for name, param in model.named_parameters():
        param.requires_grad = adapter_filter(name, param)
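

# --- Illustrative usage sketch (not part of the original module) -------------------------
# Typical fine-tuning setup: freeze everything except the adapter tensors, then build the
# optimizer over just the trainable parameters. The model name and hyperparameters are
# placeholders, not values prescribed by this module.
def _example_finetuning_setup() -> None:
    model = GPT.from_name("pythia-160m")
    mark_only_adapter_v2_as_trainable(model)
    trainable = [p for p in model.parameters() if p.requires_grad]
    print(f"trainable parameters: {sum(p.numel() for p in trainable):,}")
    optimizer = torch.optim.AdamW(trainable, lr=3e-3, weight_decay=0.02)
    # ... run the training loop with `optimizer` as usual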