Diff of /finetune/lora.py [000000] .. [248dc9]

## This script is used to finetune the model on the entity extraction task using LoRA
## This script is adapted from the original script in the lit-gpt repository: https://github.com/Lightning-AI/lit-gpt


import os
import sys
import time
from pathlib import Path
from typing import Dict, List, Literal, Optional, Tuple

import lightning as L
import torch
from lightning.fabric.loggers import CSVLogger
from lightning.fabric.plugins import BitsandbytesPrecision
from lightning.fabric.strategies import FSDPStrategy
from lightning.fabric.utilities import ThroughputMonitor

# support running without installing as a package
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))

from generate.base import generate
from lit_gpt.lora import GPT, Block, Config, lora_filter, mark_only_lora_as_trainable
from lit_gpt.tokenizer import Tokenizer
from lit_gpt.utils import (
    check_valid_checkpoint_dir,
    chunked_cross_entropy,
    get_default_supported_precision,
    load_checkpoint,
    num_parameters,
)
from scripts.prepare_entity_extraction_data import generate_prompt

eval_interval = 100
save_interval = 100
eval_iters = 100
eval_max_new_tokens = 38
log_interval = 1
devices = 1

# Hyperparameters
learning_rate = 3e-4
batch_size = 16
micro_batch_size = 4
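# each optimizer step processes batch_size samples, accumulated in micro-batches of micro_batch_size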
gradient_accumulation_iters = batch_size // micro_batch_size
assert gradient_accumulation_iters > 0
max_seq_length = None  # assign value to truncate
max_iters = 700  # train dataset size
weight_decay = 0.01
lora_r = 8
lora_alpha = 16
lora_dropout = 0.05
lora_query = True
lora_key = False
lora_value = True
lora_projection = False
lora_mlp = False
lora_head = False
warmup_steps = 100

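# collect all scalar and string settings defined above so they can be printed and logged as hyperparameters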
hparams = {k: v for k, v in locals().items() if isinstance(v, (int, float, str)) and not k.startswith("_")}


def setup(
    data_dir: Path = Path("data/entity_extraction"),
    checkpoint_dir: Path = Path("checkpoints/stabilityai/stablelm-base-alpha-3b"),
    out_dir: Path = Path("out/lora/Stable-LM/entity_extraction"),
    precision: Optional[str] = None,
    quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8-training"]] = None,
) -> None:
    """
    Set up Fabric and launch finetuning of the model on the entity extraction task using LoRA

    Args:
        data_dir: Path to the data directory containing the train.pt and test.pt files
        checkpoint_dir: Path to the directory containing the pre-trained model checkpoint
        out_dir: Path to the directory where the finetuned model will be saved
        precision: Precision to use for training
        quantize: Quantization to use for training

    Returns:
        None
    """
    precision = precision or get_default_supported_precision(training=True)

    plugins = None
    if quantize is not None and quantize.startswith("bnb."):
        if "mixed" in precision:
            raise ValueError("Quantization and mixed precision are not supported together.")
        dtype = {"16-true": torch.float16, "bf16-true": torch.bfloat16, "32-true": torch.float32}[precision]
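        # BitsandbytesPrecision expects the mode without the "bnb." prefix, e.g. "nf4" or "int8-training"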
        plugins = BitsandbytesPrecision(quantize[4:], dtype)
        precision = None

    if devices > 1:
        if quantize:
            raise NotImplementedError(
                "Quantization is currently not supported for multi-GPU training. Please set devices=1 when using the"
                " --quantize flag."
            )
        strategy = FSDPStrategy(
            auto_wrap_policy={Block},
            activation_checkpointing_policy={Block},
            state_dict_type="full",
            limit_all_gathers=True,
            cpu_offload=False,
        )
    else:
        strategy = "auto"

    logger = CSVLogger(out_dir.parent, out_dir.name, flush_logs_every_n_steps=log_interval)
    fabric = L.Fabric(devices=devices, strategy=strategy, precision=precision, loggers=logger, plugins=plugins)
    fabric.print(hparams)
    fabric.launch(main, data_dir, checkpoint_dir, out_dir)


def main(fabric: L.Fabric, data_dir: Path, checkpoint_dir: Path, out_dir: Path) -> None:
    """
    Main function for finetuning the model on the entity extraction task using LoRA

    Args:
        fabric: Lightning Fabric object
        data_dir: Path to the data directory containing the train.pt and test.pt files
        checkpoint_dir: Path to the directory containing the pre-trained model checkpoint
        out_dir: Path to the directory where the finetuned model will be saved

    Returns:
        None
    """
    check_valid_checkpoint_dir(checkpoint_dir)

    fabric.seed_everything(1337)  # same seed for every process to init model (FSDP)

    if fabric.global_rank == 0:
        os.makedirs(out_dir, exist_ok=True)

    train_data = torch.load(data_dir / "train.pt")
    val_data = torch.load(data_dir / "test.pt")

    if not any((lora_query, lora_key, lora_value, lora_projection, lora_mlp, lora_head)):
        fabric.print("Warning: all LoRA layers are disabled!")
    config = Config.from_name(
        name=checkpoint_dir.name,
        r=lora_r,
        alpha=lora_alpha,
        dropout=lora_dropout,
        to_query=lora_query,
        to_key=lora_key,
        to_value=lora_value,
        to_projection=lora_projection,
        to_mlp=lora_mlp,
        to_head=lora_head,
    )
    checkpoint_path = checkpoint_dir / "lit_model.pth"
    fabric.print(f"Loading model {str(checkpoint_path)!r} with {config.__dict__}")
    with fabric.init_module(empty_init=(devices > 1)):
        model = GPT(config)
    mark_only_lora_as_trainable(model)

    fabric.print(f"Number of trainable parameters: {num_parameters(model, requires_grad=True):,}")
    fabric.print(f"Number of non-trainable parameters: {num_parameters(model, requires_grad=False):,}")

    model = fabric.setup_module(model)

    trainable_params = [p for p in model.parameters() if p.requires_grad]
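    # bitsandbytes' paged AdamW pages optimizer state out of GPU memory under pressure, reducing OOM spikes when training quantized models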
    if isinstance(fabric.strategy.precision, BitsandbytesPrecision):
        import bitsandbytes as bnb

        optimizer = bnb.optim.PagedAdamW(trainable_params, lr=learning_rate, weight_decay=weight_decay)
    else:
        optimizer = torch.optim.AdamW(trainable_params, lr=learning_rate, weight_decay=weight_decay)
    optimizer = fabric.setup_optimizers(optimizer)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_iters // batch_size)

    # strict=False because the LoRA weights are not present in the pre-trained state dict
    load_checkpoint(fabric, model, checkpoint_path, strict=False)

    fabric.seed_everything(1337 + fabric.global_rank)

    train_time = time.perf_counter()
    train(fabric, model, optimizer, scheduler, train_data, val_data, checkpoint_dir, out_dir)
    fabric.print(f"Training time: {(time.perf_counter()-train_time):.2f}s")
    if fabric.device.type == "cuda":
        fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB")

    # Save the final LoRA checkpoint at the end of training
    save_path = out_dir / "lit_model_lora_finetuned.pth"
    save_lora_checkpoint(fabric, model, save_path)


def train(
    fabric: L.Fabric,
    model: GPT,
    optimizer: torch.optim.Optimizer,
    scheduler: torch.optim.lr_scheduler,
    train_data: List[Dict],
    val_data: List[Dict],
    checkpoint_dir: Path,
    out_dir: Path,
) -> None:
    """
    Function for training the model on the entity extraction task using LoRA

    Args:
        fabric: Lightning Fabric object
        model: GPT model
        optimizer: Optimizer for training
        scheduler: Learning rate scheduler
        train_data: Training data
        val_data: Validation data
        checkpoint_dir: Path to the directory containing the pre-trained model checkpoint
        out_dir: Path to the directory where the finetuned model will be saved

    Returns:
        None
    """
    tokenizer = Tokenizer(checkpoint_dir)
    longest_seq_length, longest_seq_ix = get_longest_seq_length(train_data)
    model.max_seq_length = min(longest_seq_length, max_seq_length or float("inf"))
    fabric.print(
        f"The longest sequence length in the train data is {longest_seq_length}, the model's maximum sequence length is"
        f" {model.max_seq_length} and context length is {model.config.block_size}"
    )

    validate(fabric, model, val_data, tokenizer, max_iters=2)  # sanity check

    throughput = ThroughputMonitor(fabric, window_size=50)
    step_count = 0
    total_lengths = 0
    total_t0 = time.perf_counter()

    for iter_num in range(1, max_iters + 1):
        if step_count <= warmup_steps:
            # linear warmup
            lr = learning_rate * step_count / warmup_steps
            for param_group in optimizer.param_groups:
                param_group["lr"] = lr

        iter_t0 = time.perf_counter()

        input_ids, targets = get_batch(fabric, train_data, longest_seq_ix if iter_num == 1 else None)

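        # gradients are synchronized and applied only once every gradient_accumulation_iters iterations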
        is_accumulating = iter_num % gradient_accumulation_iters != 0
        with fabric.no_backward_sync(model, enabled=is_accumulating):
            logits = model(input_ids, lm_head_chunk_size=128)
            # shift the targets such that output n predicts token n+1
            logits[-1] = logits[-1][..., :-1, :]
            loss = chunked_cross_entropy(logits, targets[..., 1:])
            fabric.backward(loss / gradient_accumulation_iters)

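        # step the optimizer only at accumulation boundaries; the cosine schedule takes over once warmup is done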
        if not is_accumulating:
            optimizer.step()
            optimizer.zero_grad()
            if step_count > warmup_steps:
                scheduler.step()
            step_count += 1

        total_lengths += input_ids.numel()
        if iter_num % log_interval == 0:
            loss_item = loss.item()  # expensive device-to-host synchronization
            t1 = time.perf_counter()
            throughput.update(
                time=t1 - total_t0, batches=iter_num, samples=iter_num * micro_batch_size, lengths=total_lengths
            )
            throughput.compute_and_log(step=iter_num)
            fabric.print(
                f"iter {iter_num} step {step_count}: loss {loss_item:.4f}, iter time:"
                f" {(t1 - iter_t0) * 1000:.2f}ms{' (optimizer.step)' if not is_accumulating else ''}"
            )

        if not is_accumulating and step_count % eval_interval == 0:
            t0 = time.perf_counter()
            val_loss = validate(fabric, model, val_data, tokenizer, max_iters=eval_iters)
            t1 = time.perf_counter() - t0
            fabric.print(f"step {iter_num}: val loss {val_loss.item():.4f}, val time: {t1 * 1000:.2f}ms")
            fabric.barrier()
        if not is_accumulating and step_count % save_interval == 0:
            checkpoint_path = out_dir / f"iter-{iter_num:06d}-ckpt.pth"
            save_lora_checkpoint(fabric, model, checkpoint_path)


# FSDP has issues with `inference_mode`
@torch.no_grad()
def validate(fabric: L.Fabric, model: GPT, val_data: List[Dict], tokenizer: Tokenizer, max_iters: int) -> torch.Tensor:
    """
    Function for validating the model on the entity extraction task using LoRA

    Args:
        fabric: Lightning Fabric object
        model: GPT model
        val_data: Validation data
        tokenizer: Tokenizer
        max_iters: Maximum number of iterations

    Returns:
        val_loss: Validation loss
    """
    fabric.print("Validating ...")
    model.eval()

    losses = torch.zeros(max_iters)
    for k in range(max_iters):
        input_ids, targets = get_batch(fabric, val_data)
        logits = model(input_ids)
        losses[k] = chunked_cross_entropy(logits[..., :-1, :], targets[..., 1:], chunk_size=0)
    val_loss = losses.mean()

    # produce an example:
    sample = {"input": "Robert Johnson\nrobert.johnson@email.com\n789 Maple Lane, Chicago, IL 60601\n555-234-5678, United States\n\nRelationship to XYZ Pharma Inc.: Patient\nReason for contacting: Adverse Event\n\nMessage: I've been on Onglyza for a while, and I've noticed that I'm experiencing frequent painful urination. Is this a known side effect?"}
    prompt = generate_prompt(sample)

    encoded = tokenizer.encode(prompt, device=fabric.device)
    with fabric.init_tensor():
        # do not set `max_seq_length=max_returned_tokens` because memory is not a concern here
        model.set_kv_cache(batch_size=1)
    output = generate(model, encoded, max_returned_tokens=len(encoded) + eval_max_new_tokens, temperature=0.8)
    model.clear_kv_cache()
    output = tokenizer.decode(output)
    fabric.print(output)

    model.train()
    return val_loss


def get_batch(
    fabric: L.Fabric, data: List[Dict], longest_seq_ix: Optional[int] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Function for getting a batch of data

    Args:
        fabric: Lightning Fabric object
        data: Data
        longest_seq_ix: Index of the longest sequence

    Returns:
        x: Input IDs
        y: Targets
    """
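    # sample micro_batch_size examples uniformly at random (with replacement) from the dataset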
    ix = torch.randint(len(data), (micro_batch_size,))
    if longest_seq_ix is not None:
        # force the longest sample at the beginning so potential OOMs happen right away
        ix[0] = longest_seq_ix

    input_ids = [data[i]["input_ids"].type(torch.int64) for i in ix]
    labels = [data[i]["labels"].type(torch.int64) for i in ix]

    # this could be `longest_seq_length` to have a fixed size for all batches
    max_len = max(len(s) for s in input_ids)

    def pad_right(x, pad_id):
        # pad right based on the longest sequence
        n = max_len - len(x)
        return torch.cat((x, torch.full((n,), pad_id, dtype=x.dtype)))

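    # pad inputs with 0 and labels with -1; -1 matches the ignore index used by chunked_cross_entropy, so padded positions do not contribute to the loss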
    x = torch.stack([pad_right(x, pad_id=0) for x in input_ids])
    y = torch.stack([pad_right(x, pad_id=-1) for x in labels])

    # Truncate if needed
    if max_seq_length:
        x = x[:, :max_seq_length]
        y = y[:, :max_seq_length]

    if fabric.device.type == "cuda" and x.device.type == "cpu":
        x, y = fabric.to_device((x.pin_memory(), y.pin_memory()))
    else:
        x, y = fabric.to_device((x, y))
    return x, y


def get_longest_seq_length(data: List[Dict]) -> Tuple[int, int]:
    """
    Function for getting the longest sequence length

    Args:
        data: Data

    Returns:
        longest_seq_length: Longest sequence length
        longest_seq_ix: Index of the longest sequence
    """
    # find out the minimum max_seq_length required during fine-tuning (saves memory!)
    lengths = [len(d["input_ids"]) for d in data]
    longest_seq_length = max(lengths)
    longest_seq_ix = lengths.index(longest_seq_length)
    return longest_seq_length, longest_seq_ix


def save_lora_checkpoint(fabric: L.Fabric, model: torch.nn.Module, file_path: Path) -> None:
    """
    Function for saving the LoRA checkpoint

    Args:
        fabric: Lightning Fabric object
        model: GPT model
        file_path: Path to the LoRA checkpoint file

    Returns:
        None
    """
    fabric.print(f"Saving LoRA weights to {str(file_path)!r}")
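    # lora_filter keeps only the LoRA adapter weights, so this checkpoint is small and does not duplicate the pre-trained base weights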
    fabric.save(file_path, {"model": model}, filter={"model": lora_filter})


if __name__ == "__main__":
    torch.set_float32_matmul_precision("high")

    from jsonargparse import CLI

    CLI(setup)