# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-Apache2
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
from pathlib import Path
from unittest.mock import MagicMock

import torch
import torch.nn as nn
from nemo.lightning import io, teardown
from nemo.lightning.pytorch.utils import dtype_from_hf
from transformers import AutoConfig as HFAutoConfig
from transformers import AutoModel

from bionemo.amplify.model import AMPLIFYConfig
from bionemo.amplify.tokenizer import BioNeMoAMPLIFYTokenizer
from bionemo.llm.lightning import BionemoLightningModule
from bionemo.llm.model.biobert.lightning import biobert_lightning_module


@io.model_importer(BionemoLightningModule, "hf")
class HFAMPLIFYImporter(io.ModelConnector[AutoModel, BionemoLightningModule]):
    """Converts a Hugging Face AMPLIFY model to a NeMo AMPLIFY model."""

    def init(self) -> BionemoLightningModule:
        """Initialize the converted model."""
        return biobert_lightning_module(self.config, tokenizer=self.tokenizer)

    def apply(self, output_path: Path) -> Path:
        """Run the conversion: load the HF checkpoint, convert the weights, and save a NeMo checkpoint."""
        source = AutoModel.from_pretrained(str(self), trust_remote_code=True, torch_dtype="auto")
        target = self.init()
        trainer = self.nemo_setup(target)
        self.convert_state(source, target)
        self.nemo_save(output_path, trainer)
        teardown(trainer, target)
        return output_path
    def convert_state(self, source, target):
        """Convert the HF state dict to the NeMo state dict."""
        mapping = {
            "transformer_encoder.*.wo.weight": "encoder.layers.*.self_attention.linear_proj.weight",
            "transformer_encoder.*.ffn.w12.weight": "encoder.layers.*.mlp.linear_fc1.weight",
            "transformer_encoder.*.ffn.w3.weight": "encoder.layers.*.mlp.linear_fc2.weight",
            "transformer_encoder.*.attention_norm.weight": "encoder.layers.*.self_attention.linear_qkv.layer_norm_weight",
            "transformer_encoder.*.ffn_norm.weight": "encoder.layers.*.mlp.linear_fc1.layer_norm_weight",
            "layer_norm_2.weight": "encoder.final_layernorm.weight",
        }
        # lm_head.bias
        return io.apply_transforms(
            source,
            target,
            mapping=mapping,
            transforms=[_import_qkv_weight, _pad_embeddings, _pad_bias, _pad_output_weights],
        )
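
    # How the mapping above behaves (a sketch of NeMo's wildcard expansion, to the
    # best of our understanding): each "*" is matched against the layer index, so a
    # single entry fans out across all layers. For example, with num_layers=24:
    #
    #   "transformer_encoder.0.wo.weight"  -> "encoder.layers.0.self_attention.linear_proj.weight"
    #   "transformer_encoder.23.wo.weight" -> "encoder.layers.23.self_attention.linear_proj.weight"
    #
    # Keys that need more than a rename (QKV fusion, vocabulary padding) are handled
    # by the @io.state_transform functions passed via `transforms` instead.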

    @property
    def tokenizer(self) -> BioNeMoAMPLIFYTokenizer:
        """We just have the one tokenizer for AMPLIFY."""
        return BioNeMoAMPLIFYTokenizer()

    @property
    def config(self) -> AMPLIFYConfig:
        """Returns the transformed AMPLIFY config given the model tag."""
        source = HFAutoConfig.from_pretrained(str(self), trust_remote_code=True)
        output = AMPLIFYConfig(
            num_layers=source.num_hidden_layers,
            hidden_size=source.hidden_size,
            ffn_hidden_size=source.intermediate_size,
            position_embedding_type="rope",
            num_attention_heads=source.num_attention_heads,
            seq_length=source.max_length,
            fp16=(dtype_from_hf(source) == torch.float16),
            bf16=(dtype_from_hf(source) == torch.bfloat16),
            params_dtype=dtype_from_hf(source),
        )
        return output
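

# Usage sketch (hypothetical, not part of this module): io.ModelConnector instances
# are constructed from a checkpoint path or Hugging Face tag, so converting a public
# AMPLIFY checkpoint might look like the following. The "chandar-lab/AMPLIFY_120M"
# tag and the output path are assumptions for illustration only.
#
#   maybe_mock_xformers()  # defined at the bottom of this file
#   importer = HFAMPLIFYImporter("chandar-lab/AMPLIFY_120M")
#   importer.apply(Path("/tmp/amplify_nemo_checkpoint"))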


@io.state_transform(
    source_key=(
        "transformer_encoder.*.q.weight",
        "transformer_encoder.*.k.weight",
        "transformer_encoder.*.v.weight",
    ),
    target_key="encoder.layers.*.self_attention.linear_qkv.weight",
)
def _import_qkv_weight(ctx: io.TransformCTX, query, key, value):
    """Fuse the separate HF query, key, and value weights into NeMo's packed QKV weight."""
    concat_weights = torch.cat((query, key, value), dim=0)
    input_shape = concat_weights.size()
    np = ctx.target.config.num_attention_heads
    # Rearrange from [q_all_heads; k_all_heads; v_all_heads] (the concatenation above)
    # to the per-head interleaved layout [q_0; k_0; v_0; q_1; k_1; v_1; ...] that the
    # fused linear_qkv weight expects.
    concat_weights = concat_weights.view(3, np, -1, query.size()[-1])
    concat_weights = concat_weights.transpose(0, 1).contiguous()
    concat_weights = concat_weights.view(*input_shape)
    return concat_weights
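

# Shape walk-through for _import_qkv_weight (illustrative numbers only; the real
# values come from the checkpoint's config): with hidden_size=640 and
# num_attention_heads=10, each of q/k/v is [640, 640], the concatenation is
# [1920, 640], the view yields [3, 10, 64, 640], and the transpose/flatten produces
# a [1920, 640] tensor whose rows are grouped per head as q_h, k_h, v_h.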


@io.state_transform(
    source_key="encoder.weight",
    target_key="embedding.word_embeddings.weight",
)
def _pad_embeddings(ctx: io.TransformCTX, source_embed):
    """Pad the embedding layer to the new input dimension.

    This allows us to pad the input vocabulary to a dimension that's better suited for tensor core utilization.
    """
    nemo_embedding_dimension = ctx.target.config.make_vocab_size_divisible_by
    hf_embedding_dimension = source_embed.size(0)
    num_padding_rows = nemo_embedding_dimension - hf_embedding_dimension
    # Match the source dtype/device so the concatenation below works for non-fp32 checkpoints.
    padding_rows = torch.zeros(num_padding_rows, source_embed.size(1), dtype=source_embed.dtype, device=source_embed.device)
    return torch.cat((source_embed, padding_rows), dim=0)


@io.state_transform(
    source_key="decoder.weight",
    target_key="output_layer.weight",
)
def _pad_output_weights(ctx: io.TransformCTX, source_embed):
    """Pad the output weight to the new input dimension since AMPLIFY does not tie output weights to input embeddings."""
    nemo_embedding_dimension = ctx.target.config.make_vocab_size_divisible_by
    hf_embedding_dimension = source_embed.size(0)
    num_padding_rows = nemo_embedding_dimension - hf_embedding_dimension
    # Match the source dtype/device so the concatenation below works for non-fp32 checkpoints.
    padding_rows = torch.zeros(num_padding_rows, source_embed.size(1), dtype=source_embed.dtype, device=source_embed.device)
    return torch.cat((source_embed, padding_rows), dim=0)


@io.state_transform(
    source_key="decoder.bias",
    target_key="output_layer.bias",
)
def _pad_bias(ctx: io.TransformCTX, source_bias):
    """Pad the output bias to the padded vocabulary dimension."""
    nemo_embedding_dimension = ctx.target.config.make_vocab_size_divisible_by
    hf_embedding_dimension = source_bias.size(0)
    output_bias = torch.zeros(nemo_embedding_dimension, dtype=source_bias.dtype, device=source_bias.device)
    output_bias[:hf_embedding_dimension] = source_bias
    return output_bias
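

# Padding arithmetic for the three transforms above (illustrative; the AMPLIFY
# vocabulary size of 27 is an assumption based on the upstream tokenizer): with
# make_vocab_size_divisible_by=32, the transforms append 32 - 27 = 5 zero rows (or
# bias entries), growing e.g. a [27, 640] embedding to [32, 640]. Note the target
# size is the config value itself rather than a multiple of it, which suffices here
# because the vocabulary is smaller than a single block.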


class SwiGLU(nn.Module):
    """Mock SwiGLU module.

    This module is a mock implementation of the SwiGLU module that is only used to ensure we can load weight matrices
    correctly from the Hugging Face checkpoint without installing xformers in the framework container.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        out_features: int | None = None,
        bias: bool = True,
        *,
        _pack_weights: bool = True,
    ) -> None:
        """Create a SwiGLU module."""
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        # The packed w12 projection only exists when _pack_weights is set; the separate
        # w1/w2 projections are created regardless so either checkpoint layout can load.
        self.w12: nn.Linear | None = None
        if _pack_weights:
            self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
        self.w1 = nn.Linear(in_features, hidden_features, bias=bias)
        self.w2 = nn.Linear(in_features, hidden_features, bias=bias)
        self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
        self.hidden_features = hidden_features
        self.out_features = out_features
        self.in_features = in_features

    def forward(self, x):  # noqa: D102
        raise NotImplementedError("This SwiGLU is a mock module and should not be used.")


def maybe_mock_xformers():
    """Optionally mock the xformers library to import amplify without the dependency."""
    if "xformers" not in sys.modules:
        ops_mock = MagicMock()
        ops_mock.memory_efficient_attention = MagicMock()
        ops_mock.SwiGLU = SwiGLU
        sys.modules["xformers"] = MagicMock()
        sys.modules["xformers.ops"] = ops_mock