src/train/train.py
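"""Instruction-tuning entry point for Llemr.

Loads (or initializes) the model and tokenizer, optionally freezes the language
model or wraps it with LoRA adapters, and fine-tunes it on the instruction-tuning
dataset with the Hugging Face Trainer.
"""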
import logging
from dataclasses import dataclass, field
from typing import Optional

import transformers
from peft import LoraConfig, get_peft_model
from transformers import (
    Trainer,
    TrainingArguments,
    AutoTokenizer,
    PreTrainedTokenizer,
)

from src.dataset.collator import InstructionTuningCollator
from src.dataset.dataset import InstructionTuningDataset
from src.model.init_llemr import init_llemr
from src.model.modeling_llemr import LlemrForConditionalGeneration
from src.model.utils import find_all_linear_names

logger = logging.getLogger(__name__)


@dataclass
class ModelArguments:
    name_or_path: Optional[str] = field(default=None)
    llm_pretrained_model_name_or_path: Optional[str] = field(default="Qwen/Qwen2-0.5B-Instruct")
    train_type: Optional[str] = field(
        default="train_both",
        metadata={
            "help": """
            1. train_multi_modal_projector
            2. train_both
            """
        },
    )
    use_lora: Optional[bool] = field(default=True)
    lora_r: int = 32
    lora_alpha: int = 16
    lora_dropout: float = 0.05
    lora_bias: str = "none"
    vision_hidden_size: int = 768


@dataclass
class DataArguments:
    source: Optional[str] = field(default="note")


def load_model(model_args: ModelArguments):
    if model_args.name_or_path is not None:
        # Resume from a previously saved Llemr checkpoint.
        logger.warning(f"Loading model {model_args.name_or_path} from pretrained")
        model = LlemrForConditionalGeneration.from_pretrained(
            model_args.name_or_path
        )
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.name_or_path,
            padding_side="left",
        )
    else:
        # Build a fresh Llemr model around the base LLM and the vision hidden size.
        logger.warning(f"Initializing model from {model_args.llm_pretrained_model_name_or_path}")
        model, tokenizer = init_llemr(
            model_args.llm_pretrained_model_name_or_path, model_args.vision_hidden_size
        )

    assert model_args.train_type in ["train_multi_modal_projector", "train_both"]
    if model_args.train_type == "train_multi_modal_projector":
        # Freeze the language model; only the multi-modal projector is updated.
        logger.warning("Training multi_modal_projector only")
        for param in model.language_model.parameters():
            param.requires_grad = False
    else:
        logger.warning("Training both the projector and the language model")

    if model_args.use_lora:
        # LoRA only makes sense when the language model itself is being trained.
        assert model_args.train_type == "train_both"
        logger.warning("Using LoRA")

        config = LoraConfig(
            r=model_args.lora_r,
            lora_alpha=model_args.lora_alpha,
            target_modules=find_all_linear_names(model),
            lora_dropout=model_args.lora_dropout,
            bias=model_args.lora_bias,
            task_type="CAUSAL_LM",
            # Keep the projector fully trainable and saved alongside the LoRA adapters.
            modules_to_save=["multi_modal_projector"],
        )
        model = get_peft_model(model, config)
    else:
        logger.warning("Not using LoRA")

    return model, tokenizer


def load_data(data_args: DataArguments, tokenizer: PreTrainedTokenizer):
    train_dataset = InstructionTuningDataset(
        split="train",
        source=data_args.source,
    )
    val_dataset = InstructionTuningDataset(
        split="val",
        source=data_args.source,
    )
    collator = InstructionTuningCollator(
        tokenizer=tokenizer,
    )
    return train_dataset, val_dataset, collator


def train():
    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments)
    )
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    model, tokenizer = load_model(model_args)
    train_dataset, val_dataset, collator = load_data(data_args, tokenizer)

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        data_collator=collator,
    )

    # Save the tokenizer up front so the output directory is usable as a checkpoint.
    tokenizer.save_pretrained(training_args.output_dir)
    trainer.train()


if __name__ == "__main__":
    train()
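
# Example invocation (illustrative sketch; the module path and hyperparameter
# values below are assumptions, not part of the original script):
#   python -m src.train.train \
#       --llm_pretrained_model_name_or_path Qwen/Qwen2-0.5B-Instruct \
#       --source note \
#       --output_dir outputs/llemr_instruction_tuning \
#       --num_train_epochs 1 \
#       --per_device_train_batch_size 4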