Diff of /finetune/adapter_v2.py [000000] .. [248dc9]


a b/finetune/adapter_v2.py
## This script is used to finetune the adapter v2 model on the entity extraction task.
## This script is adapted from the original script in the LIT repository: https://github.com/Lightning-AI/lit-gpt

import os
import sys
import time
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import lightning as L
import torch
from lightning.fabric.loggers import CSVLogger
from lightning.fabric.strategies import FSDPStrategy
from lightning.fabric.utilities import ThroughputMonitor

# support running without installing as a package
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))

from generate.base import generate
from lit_gpt.adapter_v2 import GPT, Block, Config, adapter_filter, mark_only_adapter_v2_as_trainable
from lit_gpt.tokenizer import Tokenizer
from lit_gpt.utils import (
    check_valid_checkpoint_dir,
    chunked_cross_entropy,
    get_default_supported_precision,
    lazy_load,
    num_parameters,
)
from scripts.prepare_entity_extraction_data import generate_prompt

eval_interval = 100
save_interval = 100
eval_iters = 100
eval_max_new_tokens = 35
log_interval = 1
devices = 1

# Hyperparameters
learning_rate = 3e-3
batch_size = 8 / devices
micro_batch_size = 1  # set to 1 so that finetuning fits into 12 GB of VRAM
gradient_accumulation_iters = batch_size // micro_batch_size
assert gradient_accumulation_iters > 0
max_seq_length = None  # assign value to truncate
epoch_size = 700  # train dataset size
num_epochs = 5
max_iters = num_epochs * (epoch_size // micro_batch_size) // devices
weight_decay = 0.02
warmup_steps = 2 * (epoch_size // micro_batch_size) // devices // gradient_accumulation_iters  # 2 epochs
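# With the defaults above, gradient_accumulation_iters = 8 // 1 = 8, so the optimizer steps once every
# 8 micro-batches (effective batch size = micro_batch_size * gradient_accumulation_iters = 8);
# max_iters = 5 * (700 // 1) // 1 = 3500 micro-batch iterations and warmup_steps = 2 * (700 // 1) // 1 // 8 = 175.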

hparams = {k: v for k, v in locals().items() if isinstance(v, (int, float, str)) and not k.startswith("_")}


def setup(
    data_dir: Path = Path("data/entity_extraction"),
    checkpoint_dir: Path = Path("checkpoints/stabilityai/stablelm-base-alpha-3b"),
    out_dir: Path = Path("out/adapter_v2/Stable-LM/entity_extraction"),
    precision: Optional[str] = None,
) -> None:
    """
    Finetune the adapter v2 model on the entity extraction task.

    Args:
        data_dir (Path): Path to the directory containing the dataset.
        checkpoint_dir (Path): Path to the directory containing the checkpoint.
        out_dir (Path): Path to the directory to save the finetuned model.
        precision (Optional[str]): Precision to use for training. Defaults to None, in which case the
            best precision supported on the current hardware is selected.

    Returns:
        None
    """
    precision = precision or get_default_supported_precision(training=True)

    fabric_devices = devices
    if fabric_devices > 1:
        strategy = FSDPStrategy(
            auto_wrap_policy={Block},
            activation_checkpointing_policy={Block},
            state_dict_type="full",
            limit_all_gathers=True,
            cpu_offload=False,
        )
    else:
        strategy = "auto"

    logger = CSVLogger(out_dir.parent, out_dir.name, flush_logs_every_n_steps=log_interval)
    fabric = L.Fabric(devices=fabric_devices, strategy=strategy, precision=precision, loggers=logger)
    fabric.print(hparams)
    fabric.launch(main, data_dir, checkpoint_dir, out_dir)


def main(fabric: L.Fabric, data_dir: Path, checkpoint_dir: Path, out_dir: Path) -> None:
    """
    Finetune the adapter v2 model on the entity extraction task.

    Args:
        fabric (Fabric): Fabric object.
        data_dir (Path): Path to the directory containing the dataset.
        checkpoint_dir (Path): Path to the directory containing the checkpoint.
        out_dir (Path): Path to the directory to save the finetuned model.

    Returns:
        None
    """
    check_valid_checkpoint_dir(checkpoint_dir)

    fabric.seed_everything(1337)  # same seed for every process to init model (FSDP)

    if fabric.global_rank == 0:
        os.makedirs(out_dir, exist_ok=True)

    train_data = torch.load(data_dir / "train.pt")
    val_data = torch.load(data_dir / "test.pt")
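    # Each .pt file is expected to hold a list of dicts with (at least) 1-D "input_ids" and "labels"
    # token tensors, as produced by scripts/prepare_entity_extraction_data.py (see get_batch below).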

    config = Config.from_name(name=checkpoint_dir.name)
    checkpoint_path = checkpoint_dir / "lit_model.pth"

    fabric.print(f"Loading model {str(checkpoint_path)!r} with {config.__dict__}")
    with fabric.init_module(empty_init=False):
        model = GPT(config)
    checkpoint = lazy_load(checkpoint_path)
    # strict=False because the adapter weights are not contained in the base checkpoint's state dict
    model.load_state_dict(checkpoint, strict=False)

    mark_only_adapter_v2_as_trainable(model)

    fabric.print(f"Number of trainable parameters: {num_parameters(model, requires_grad=True):,}")
    fabric.print(f"Number of non-trainable parameters: {num_parameters(model, requires_grad=False):,}")
    trainable_params = [p for p in model.parameters() if p.requires_grad]

    optimizer = torch.optim.AdamW(trainable_params, lr=learning_rate, weight_decay=weight_decay)
    model, optimizer = fabric.setup(model, optimizer)

    fabric.seed_everything(1337 + fabric.global_rank)

    train_time = time.perf_counter()
    train(fabric, model, optimizer, train_data, val_data, checkpoint_dir, out_dir)
    fabric.print(f"Training time: {(time.perf_counter()-train_time):.2f}s")
    if fabric.device.type == "cuda":
        fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB")

    # Save the final checkpoint at the end of training
    save_path = out_dir / "lit_model_adapter_finetuned.pth"
    save_adapter_v2_checkpoint(fabric, model, save_path)


def train(
    fabric: L.Fabric,
    model: GPT,
    optimizer: torch.optim.Optimizer,
    train_data: List[Dict],
    val_data: List[Dict],
    checkpoint_dir: Path,
    out_dir: Path,
) -> None:
    """
    Run the finetuning loop for the adapter v2 model on the entity extraction task.

    Args:
        fabric (Fabric): Fabric object.
        model (GPT): The model to finetune.
        optimizer (torch.optim.Optimizer): Optimizer to use for training.
        train_data (List[Dict]): Training data.
        val_data (List[Dict]): Validation data.
        checkpoint_dir (Path): Path to the directory containing the checkpoint.
        out_dir (Path): Path to the directory to save the finetuned model.

    Returns:
        None
    """
    tokenizer = Tokenizer(checkpoint_dir)
    longest_seq_length, longest_seq_ix = get_longest_seq_length(train_data)
    model.max_seq_length = min(longest_seq_length, max_seq_length or float("inf"))
    fabric.print(
        f"The longest sequence length in the train data is {longest_seq_length}, the model's maximum sequence length is"
        f" {model.max_seq_length} and context length is {model.config.block_size}"
    )

    validate(fabric, model, val_data, tokenizer, max_iters=2)  # sanity check

    throughput = ThroughputMonitor(fabric, window_size=50)
    step_count = 0
    total_lengths = 0
    total_t0 = time.perf_counter()

    for iter_num in range(1, max_iters + 1):
        if step_count <= warmup_steps:
            # linear warmup
            lr = learning_rate * step_count / warmup_steps
            for param_group in optimizer.param_groups:
                param_group["lr"] = lr

        iter_t0 = time.perf_counter()

        input_ids, targets = get_batch(fabric, train_data, longest_seq_ix if iter_num == 1 else None)

        is_accumulating = iter_num % gradient_accumulation_iters != 0
        with fabric.no_backward_sync(model, enabled=is_accumulating):
            logits = model(input_ids, lm_head_chunk_size=128)
            # shift the targets such that output n predicts token n+1
            logits[-1] = logits[-1][..., :-1, :]
            loss = chunked_cross_entropy(logits, targets[..., 1:])
            fabric.backward(loss / gradient_accumulation_iters)
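        # The loss is scaled by 1/gradient_accumulation_iters so the accumulated gradient corresponds to the
        # average over the effective batch; no_backward_sync skips gradient synchronization on the
        # accumulation-only iterations, and the optimizer only steps when is_accumulating is False (below).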

        if not is_accumulating:
            optimizer.step()
            optimizer.zero_grad()
            step_count += 1

        total_lengths += input_ids.numel()
        if iter_num % log_interval == 0:
            loss_item = loss.item()  # expensive device-to-host synchronization
            t1 = time.perf_counter()
            throughput.update(
                time=t1 - total_t0, batches=iter_num, samples=iter_num * micro_batch_size, lengths=total_lengths
            )
            throughput.compute_and_log(step=iter_num)
            fabric.print(
                f"iter {iter_num} step {step_count}: loss {loss_item:.4f}, iter time:"
                f" {(t1 - iter_t0) * 1000:.2f}ms{' (optimizer.step)' if not is_accumulating else ''}"
            )

        if not is_accumulating and step_count % eval_interval == 0:
            t0 = time.perf_counter()
            val_loss = validate(fabric, model, val_data, tokenizer, max_iters=eval_iters)
            t1 = time.perf_counter() - t0
            fabric.print(f"step {iter_num}: val loss {val_loss.item():.4f}, val time: {t1 * 1000:.2f}ms")
            fabric.barrier()
        if not is_accumulating and step_count % save_interval == 0:
            checkpoint_path = out_dir / f"iter-{iter_num:06d}-ckpt.pth"
            save_adapter_v2_checkpoint(fabric, model, checkpoint_path)


# the adapter "kv cache" cannot be initialized under `inference_mode`
@torch.no_grad()
def validate(fabric: L.Fabric, model: GPT, val_data: List[Dict], tokenizer: Tokenizer, max_iters: int) -> torch.Tensor:
    """
    Validate the adapter v2 model on the entity extraction task and print a sample generation.

    Args:
        fabric (Fabric): Fabric object.
        model (GPT): The model to validate.
        val_data (List[Dict]): Validation data.
        tokenizer (Tokenizer): Tokenizer to use for tokenizing the input.
        max_iters (int): Maximum number of iterations to run.

    Returns:
        torch.Tensor: Validation loss.
    """
    fabric.print("Validating ...")
    model.eval()

    losses = torch.zeros(max_iters)
    for k in range(max_iters):
        input_ids, targets = get_batch(fabric, val_data)
        logits = model(input_ids)
        losses[k] = chunked_cross_entropy(logits[..., :-1, :], targets[..., 1:], chunk_size=0)
    val_loss = losses.mean()

    # produce an example:
    sample = {"input": "Robert Johnson\nrobert.johnson@email.com\n789 Maple Lane, Chicago, IL 60601\n555-234-5678, United States\n\nRelationship to XYZ Pharma Inc.: Patient\nReason for contacting: Adverse Event\n\nMessage: I've been on Onglyza for a while, and I've noticed that I'm experiencing frequent painful urination. Is this a known side effect?"}
    prompt = generate_prompt(sample)
    encoded = tokenizer.encode(prompt, device=fabric.device)

    with fabric.init_tensor():
        # do not set `max_seq_length=max_returned_tokens` because memory is not a concern here
        model.set_kv_cache(batch_size=1)
    output = generate(model, encoded, max_returned_tokens=len(encoded) + eval_max_new_tokens, temperature=0.8)
    model.clear_kv_cache()
    output = tokenizer.decode(output)
    fabric.print(output)

    model.train()
    return val_loss


def get_batch(
    fabric: L.Fabric, data: List[Dict], longest_seq_ix: Optional[int] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    This function gets a batch of data.

    Args:
        fabric (Fabric): Fabric object.
        data (List[Dict]): Data to get a batch from.
        longest_seq_ix (Optional[int]): Index of the longest sequence. Defaults to None.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: A batch of input_ids and labels tensors.
    """
    ix = torch.randint(len(data), (micro_batch_size,))
    if longest_seq_ix is not None:
        # force the longest sample at the beginning so potential OOMs happen right away
        ix[0] = longest_seq_ix

    input_ids = [data[i]["input_ids"].type(torch.int64) for i in ix]
    labels = [data[i]["labels"].type(torch.int64) for i in ix]

    # this could be `longest_seq_length` to have a fixed size for all batches
    max_len = max(len(s) for s in input_ids)

    def pad_right(x, pad_id):
        # pad right based on the longest sequence
        n = max_len - len(x)
        return torch.cat((x, torch.full((n,), pad_id, dtype=x.dtype)))
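    # For example, with max_len = 5, pad_right(torch.tensor([11, 12, 13]), pad_id=0) returns
    # tensor([11, 12, 13, 0, 0]); labels are padded with -1 below so padded positions are excluded from the loss.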

    x = torch.stack([pad_right(x, pad_id=0) for x in input_ids])
    y = torch.stack([pad_right(x, pad_id=-1) for x in labels])

    # Truncate if needed
    if max_seq_length:
        x = x[:, :max_seq_length]
        y = y[:, :max_seq_length]

    if fabric.device.type == "cuda" and x.device.type == "cpu":
        x, y = fabric.to_device((x.pin_memory(), y.pin_memory()))
    else:
        x, y = fabric.to_device((x, y))
    return x, y


def get_longest_seq_length(data: List[Dict]) -> Tuple[int, int]:
    """
    This function gets the longest sequence length.

    Args:
        data (List[Dict]): Data to get the longest sequence length from.

    Returns:
        Tuple[int, int]: Longest sequence length and index of the longest sequence.
    """
    # find out the minimum max_seq_length required during fine-tuning (saves memory!)
    lengths = [len(d["input_ids"]) for d in data]
    longest_seq_length = max(lengths)
    longest_seq_ix = lengths.index(longest_seq_length)
    return longest_seq_length, longest_seq_ix


def save_adapter_v2_checkpoint(fabric: L.Fabric, model: torch.nn.Module, file_path: Path) -> None:
    """
    This function saves the adapter v2 checkpoint.

    Args:
        fabric (Fabric): Fabric object.
        model (torch.nn.Module): The model to save.
        file_path (Path): Path to the file to save the model to.

    Returns:
        None
    """
    fabric.print(f"Saving adapter v2 weights to {str(file_path)!r}")
    fabric.save(file_path, {"model": model}, filter={"model": adapter_filter})
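    # Only the adapter weights are written (selected by `adapter_filter`); at inference time the base
    # `lit_model.pth` checkpoint is loaded first and these finetuned weights are applied on top of it.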


if __name__ == "__main__":
    torch.set_float32_matmul_precision("high")

    from jsonargparse import CLI

    CLI(setup)
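
# Example invocation (illustrative only; the flags mirror the keyword arguments of `setup`, which
# jsonargparse exposes on the command line, and the paths shown are the defaults defined above):
#
#   python finetune/adapter_v2.py \
#       --data_dir data/entity_extraction \
#       --checkpoint_dir checkpoints/stabilityai/stablelm-base-alpha-3b \
#       --out_dir out/adapter_v2/Stable-LM/entity_extraction \
#       --precision bf16-true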