lit_gpt/adapter_v2.py

"""Implementation of the paper:

LLaMA-Adapter V2: Parameter-Efficient Visual Instruction Model
https://arxiv.org/abs/2304.15010

Port for Lit-GPT
"""
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, Type

import torch
import torch.nn as nn
from typing_extensions import Self

import lit_gpt
from lit_gpt.adapter import GPT as BaseModel
from lit_gpt.adapter import Block as BaseBlock
from lit_gpt.adapter import CausalSelfAttention as BaseCausalSelfAttention
from lit_gpt.adapter import Config as BaseConfig
from lit_gpt.model import KVCache
from lit_gpt.utils import map_old_state_dict_weights


@dataclass
class Config(BaseConfig):
    @property
    def mlp_class(self) -> Type:
        return getattr(lit_gpt.adapter_v2, self._mlp_class)


def adapter_filter(key: str, value: Any) -> bool:
    adapter_substrings = (
        # regular adapter v1 parameters
        "adapter_wte",
        "gating_factor",
        # adapter v2: new bias and scale used in Linear
        "adapter_scale",
        "adapter_bias",
        # adapter v2: Norm parameters are now trainable
        "norm_1",
        "norm_2",
        "ln_f",
    )
    return any(s in key for s in adapter_substrings)
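
# Example (illustrative, not part of the upstream module): the filter matches on
# substrings of fully qualified parameter names, keeping the adapter prompt/gate, the
# Adapter V2 scale/bias, and the norm weights while rejecting the frozen base weights.
# The key names below are hypothetical but follow the module layout defined in this file:
#
#     adapter_filter("transformer.h.0.attn.attn.adapter_bias", None)   # -> True
#     adapter_filter("transformer.h.0.norm_1.weight", None)            # -> True
#     adapter_filter("transformer.h.0.attn.attn.linear.weight", None)  # -> False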


class AdapterV2Linear(torch.nn.Module):
    def __init__(self, in_features: int, out_features: int, **kwargs) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(in_features, out_features, **kwargs)
        self.adapter_bias = torch.nn.Parameter(torch.zeros(out_features), requires_grad=False)
        self.adapter_scale = torch.nn.Parameter(torch.ones(out_features), requires_grad=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.adapter_scale * (self.linear(x) + self.adapter_bias)

    def reset_parameters(self) -> None:
        nn.init.zeros_(self.adapter_bias)
        nn.init.ones_(self.adapter_scale)
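
# Usage sketch (for illustration only): `AdapterV2Linear` wraps a regular `nn.Linear`
# with a per-output-feature scale and bias. With the defaults above (scale = 1, bias = 0)
# it is numerically identical to the wrapped layer, which is why base checkpoints can be
# loaded without changing the model's initial behaviour:
#
#     layer = AdapterV2Linear(in_features=8, out_features=4, bias=False)
#     x = torch.randn(2, 8)
#     assert torch.equal(layer(x), layer.linear(x))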


class GPT(BaseModel):
    def __init__(self, config: Config) -> None:
        # Skip the parent class __init__ altogether and replace it to avoid useless allocations
        nn.Module.__init__(self)
        assert config.padded_vocab_size is not None
        self.config = config

        self.lm_head = AdapterV2Linear(config.n_embd, config.padded_vocab_size, bias=config.lm_head_bias)
        self.transformer = nn.ModuleDict(
            dict(
                wte=nn.Embedding(config.padded_vocab_size, config.n_embd),
                h=nn.ModuleList(Block(config, i) for i in range(config.n_layer)),
                ln_f=config.norm_class(config.n_embd, eps=config.norm_eps),
            )
        )
        self.max_seq_length = self.config.block_size
        self.mask_cache: Optional[torch.Tensor] = None

    @classmethod
    def from_name(cls, name: str, **kwargs: Any) -> Self:
        return cls(Config.from_name(name, **kwargs))

    def _init_weights(self, module: nn.Module) -> None:
        """Meant to be used with `gpt.apply(gpt._init_weights)`. Unused method left for completeness."""
        super()._init_weights(module)
        if isinstance(module, AdapterV2Linear):
            module.reset_parameters()

    def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
        """For compatibility with base checkpoints."""
        mapping = {"lm_head.weight": "lm_head.linear.weight", "lm_head.bias": "lm_head.linear.bias"}
        state_dict = map_old_state_dict_weights(state_dict, mapping, prefix)
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
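
# Loading sketch (an assumption, not part of the upstream module): thanks to the key
# remapping above and the identity initialisation of the adapter scale/bias, a plain base
# checkpoint can be loaded directly into the adapter model. The checkpoint path is
# hypothetical, and `strict=False` is needed because the adapter-specific parameters are
# absent from the base state dict:
#
#     model = GPT(config)  # `config` built elsewhere, e.g. via Config.from_name(...)
#     base_state = torch.load("checkpoints/lit_model.pth")
#     model.load_state_dict(base_state, strict=False)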


class Block(BaseBlock):
    """The implementation is identical to `lit_gpt.model.Block` with the exception that
    we replace the attention layer where adaption is implemented."""

    def __init__(self, config: Config, block_idx: int) -> None:
        # Skip the parent class __init__ altogether and replace it to avoid useless allocations
        nn.Module.__init__(self)
        self.norm_1 = config.norm_class(config.n_embd, eps=config.norm_eps)
        self.attn = CausalSelfAttention(config, block_idx)
        if not config.shared_attention_norm:
            self.norm_2 = config.norm_class(config.n_embd, eps=config.norm_eps)
        self.mlp = config.mlp_class(config)

        self.config = config


class CausalSelfAttention(BaseCausalSelfAttention):
    """A modification of `lit_gpt.adapter.CausalSelfAttention` that uses the Adapter V2 Linear class"""

    def __init__(self, config: Config, block_idx: int) -> None:
        # Skip the parent class __init__ altogether and replace it to avoid useless allocations
        nn.Module.__init__(self)
        shape = (config.n_head + 2 * config.n_query_groups) * config.head_size
        # key, query, value projections for all heads, but in a batch
        self.attn = AdapterV2Linear(in_features=config.n_embd, out_features=shape, bias=config.bias)
        # output projection
        self.proj = AdapterV2Linear(config.n_embd, config.n_embd, bias=config.bias)
        # disabled by default
        self.kv_cache: Optional[KVCache] = None

        if block_idx >= config.adapter_start_layer:
            # adapter embedding layer
            self.adapter_wte = nn.Embedding(config.adapter_prompt_length, config.n_embd)
            # gate for adaption
            self.gating_factor = torch.nn.Parameter(torch.zeros(1, 1, config.n_head, 1))
            # kv cache for inference
            self.adapter_kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
        self.block_idx = block_idx

        self.config = config

    def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
        """For compatibility with base checkpoints."""
        mapping = {
            "attn.weight": "attn.linear.weight",
            "attn.bias": "attn.linear.bias",
            "proj.weight": "proj.linear.weight",
            "proj.bias": "proj.linear.bias",
        }
        state_dict = map_old_state_dict_weights(state_dict, mapping, prefix)
        # For compatibility with older checkpoints that stored `gating_factor` with shape
        # (1, n_head, 1, 1) instead of the current (1, 1, n_head, 1)
        if (key := prefix + "gating_factor") in state_dict and state_dict[key].size(1) == self.config.n_head:
            state_dict[key] = state_dict[key].permute(0, 2, 1, 3)
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)


class GptNeoxMLP(lit_gpt.model.GptNeoxMLP):
    def __init__(self, config: Config) -> None:
        nn.Module.__init__(self)
        self.fc = AdapterV2Linear(config.n_embd, config.intermediate_size, bias=config.bias)
        self.proj = AdapterV2Linear(config.intermediate_size, config.n_embd, bias=config.bias)

        self.config = config

    def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
        """For compatibility with base checkpoints."""
        mapping = {
            "fc.weight": "fc.linear.weight",
            "fc.bias": "fc.linear.bias",
            "proj.weight": "proj.linear.weight",
            "proj.bias": "proj.linear.bias",
        }
        state_dict = map_old_state_dict_weights(state_dict, mapping, prefix)
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)


class LLaMAMLP(lit_gpt.model.LLaMAMLP):
    def __init__(self, config: Config) -> None:
        nn.Module.__init__(self)
        self.fc_1 = AdapterV2Linear(config.n_embd, config.intermediate_size, bias=config.bias)
        self.fc_2 = AdapterV2Linear(config.n_embd, config.intermediate_size, bias=config.bias)
        self.proj = AdapterV2Linear(config.intermediate_size, config.n_embd, bias=config.bias)

    def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
        """For compatibility with base checkpoints."""
        mapping = {
            "fc_1.weight": "fc_1.linear.weight",
            "fc_1.bias": "fc_1.linear.bias",
            "fc_2.weight": "fc_2.linear.weight",
            "fc_2.bias": "fc_2.linear.bias",
            "proj.weight": "proj.linear.weight",
            "proj.bias": "proj.linear.bias",
        }
        state_dict = map_old_state_dict_weights(state_dict, mapping, prefix)
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)


def mark_only_adapter_v2_as_trainable(model: GPT) -> None:
    """Sets requires_grad=False for all non-adapter weights"""
    for name, param in model.named_parameters():
        param.requires_grad = adapter_filter(name, param)
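
# Usage sketch (an assumption, for illustration only): a typical Adapter V2 fine-tuning
# run freezes the base weights and persists only the adapter parameters. The config name
# and file paths below are hypothetical:
#
#     config = Config.from_name("pythia-160m")
#     model = GPT(config)
#     mark_only_adapter_v2_as_trainable(model)
#     # ... fine-tune, then save only the (small) adapter state using `adapter_filter`
#     adapter_state = {k: v for k, v in model.state_dict().items() if adapter_filter(k, v)}
#     torch.save(adapter_state, "out/adapter_v2_finetuned.pth")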