Diff of /src/dataset/collator.py [000000] .. [780764]

import logging
import warnings
from typing import Dict, List, Tuple

import torch
from transformers import PreTrainedTokenizer


class InstructionTuningCollator:
    def __init__(
        self,
        tokenizer: PreTrainedTokenizer,
        sys_prompt: str = "You are an AI assistant specialized in analyzing ICU patient data.",
        ignore_index: int = -100
    ) -> None:
        self.tokenizer = tokenizer
        self.sys_prompt = sys_prompt
        self.ignore_index = ignore_index
        # The response template marks where the assistant's answer begins; it is
        # inferred once here and reused by mask_instruction() to exclude the
        # instruction portion from the loss.
        self.response_template, self.response_token_ids = self.infer_response_template()

    def infer_response_template(self) -> Tuple[str, List[int]]:
        # Prefer the token-level inference (v2); fall back to string slicing (v1)
        # if v2 yields an empty template.
        logging.warning("Inferring response template with v2")
        response_template, response_token_ids = self.infer_response_template_v2()
        if response_template == "":
            logging.warning("Inferring response template with v1")
            response_template, response_token_ids = self.infer_response_template_v1()
        return response_template, response_token_ids

    def infer_response_template_v1(self) -> Tuple[str, List[int]]:
        # String-level inference: format a dummy single-turn chat with the
        # generation prompt appended, then take everything after the user text
        # as the response template.
        token = "Hi?"
        chat = [
            {"role": "user", "content": token},
        ]
        formatted_chat = self.tokenizer.apply_chat_template(
            chat,
            tokenize=False,
            add_generation_prompt=True
        )
        response_template = formatted_chat[formatted_chat.find(token) + len(token):]
        response_token_ids = self.tokenizer.encode(response_template, add_special_tokens=False)
        logging.warning(f"Inferred response template: {repr(response_template)}")
        logging.warning(f"Inferred response template token ids: {response_token_ids}")
        return response_template, response_token_ids

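    # Illustrative sketch (not part of the original file): for a hypothetical
    # chat template that renders the dummy chat as
    #   "<s>[INST] Hi? [/INST]"
    # the slice after "Hi?" would give " [/INST]" as the response template.
    # The exact string depends entirely on the tokenizer's chat template.
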
    def infer_response_template_v2(self) -> Tuple[str, List[int]]:
        # Token-level inference: encode the dummy chat with and without the
        # generation prompt; the trailing token ids that only appear in the
        # add_generation_prompt=True version form the response template.
        token = "Hi?"
        chat = [
            {"role": "user", "content": token},
        ]
        formatted_chat_wo_gen = self.tokenizer.apply_chat_template(
            chat,
            tokenize=False,
            add_generation_prompt=False
        )
        formatted_chat = self.tokenizer.apply_chat_template(
            chat,
            tokenize=False,
            add_generation_prompt=True
        )
        formatted_chat_wo_gen = self.tokenizer.encode(formatted_chat_wo_gen, add_special_tokens=False)
        formatted_chat = self.tokenizer.encode(formatted_chat, add_special_tokens=False)
        response_token_ids = formatted_chat[len(formatted_chat_wo_gen):]
        response_template = self.tokenizer.decode(response_token_ids)
        logging.warning(f"Inferred response template: {repr(response_template)}")
        logging.warning(f"Inferred response template token ids: {response_token_ids}")
        return response_template, response_token_ids

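    # Illustrative sketch (not part of the original file): for a Vicuna-style
    # template the two renderings might look roughly like
    #   without generation prompt: "... USER: Hi?"
    #   with generation prompt:    "... USER: Hi? ASSISTANT:"
    # so the trailing ids decode to something like " ASSISTANT:". This assumes
    # the template only appends tokens when add_generation_prompt=True. If the
    # two renderings are identical, the inferred template is empty and the
    # caller falls back to v1.
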
    def apply_chat_template(self, q_text: str, a_text: str) -> str:
        # Build a full system/user/assistant conversation and render it with the
        # tokenizer's chat template (no generation prompt, since the answer is
        # already included).
        chat = [
            {"role": "system", "content": self.sys_prompt},
            {"role": "user", "content": q_text},
            {"role": "assistant", "content": a_text}
        ]
        formatted_chat = self.tokenizer.apply_chat_template(
            chat,
            tokenize=False,
            add_generation_prompt=False
        )
        return formatted_chat

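    # Illustrative sketch (not part of the original file): the rendered string
    # depends on the tokenizer, but conceptually
    #   apply_chat_template("What is the diagnosis?", "Sepsis.")
    # yields one training sequence containing the system prompt, the question,
    # and the answer, e.g. (Vicuna-style, approximate):
    #   "<system prompt> USER: What is the diagnosis? ASSISTANT: Sepsis.</s>"
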
    @staticmethod
    def pad_tensors(tensor_list, padding_value=0):
        # Pad a list of (num_events, feature_dim) tensors to a common length and
        # return a boolean mask marking the padded positions.
        max_num_events = max(tensor.shape[0] for tensor in tensor_list)
        feature_dim = tensor_list[0].shape[1]
        batch_size = len(tensor_list)

        padded_tensor = torch.full((batch_size, max_num_events, feature_dim), padding_value, dtype=torch.float)
        is_padding = torch.ones((batch_size, max_num_events), dtype=torch.bool)

        for i, tensor in enumerate(tensor_list):
            num_events = tensor.shape[0]
            padded_tensor[i, :num_events, :] = tensor
            is_padding[i, :num_events] = False

        return padded_tensor, is_padding

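    # Illustrative sketch (not part of the original file):
    #   a = torch.zeros(3, 4)   # 3 events, feature dim 4
    #   b = torch.ones(5, 4)    # 5 events, feature dim 4
    #   padded, is_padding = InstructionTuningCollator.pad_tensors([a, b])
    #   padded.shape      -> torch.Size([2, 5, 4])
    #   is_padding.shape  -> torch.Size([2, 5])
    #   is_padding[0]     -> tensor([False, False, False,  True,  True])
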
    def mask_instruction(self, labels: torch.Tensor) -> torch.Tensor:
        # Set every token up to and including the response template to
        # ignore_index so only the assistant's answer contributes to the loss.
        for i in range(len(labels)):
            response_token_ids_start_idx = None

            # Scan for the response template; if it occurs more than once, the
            # last occurrence is kept.
            for idx in torch.where(labels[i] == self.response_token_ids[0])[0]:
                if self.response_token_ids == labels[i][idx: idx + len(self.response_token_ids)].tolist():
                    response_token_ids_start_idx = idx

            if response_token_ids_start_idx is None:
                warnings.warn(
                    f"Could not find response key `{self.response_template}` in the "
                    f"following instance: {self.tokenizer.decode(labels[i])} "
                    f"This instance will be ignored in loss calculation. "
                    f"Note: if this happens often, consider increasing the `max_seq_length`."
                )
                labels[i, :] = self.ignore_index
            else:
                response_token_ids_end_idx = response_token_ids_start_idx + len(self.response_token_ids)
                labels[i, :response_token_ids_end_idx] = self.ignore_index
        return labels

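    # Illustrative sketch (not part of the original file), with hypothetical
    # token ids: if response_token_ids were [42, 7] and one label row were
    #   [ 5,  9, 42,  7, 11, 13]
    # the masked row would be
    #   [-100, -100, -100, -100, 11, 13]
    # i.e. everything up to and including the response template is ignored.
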
    def __call__(self, batch: List) -> Dict[str, torch.Tensor]:
        # Each dataset item is expected to be (question text, answer text,
        # event tensor of shape (num_events, feature_dim)).
        all_text = []
        all_events = []
        for data in batch:
            text = self.apply_chat_template(
                q_text=data[0],
                a_text=data[1],
            )
            all_text.append(text)
            all_events.append(data[2])

        inputs = self.tokenizer(
            all_text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            add_special_tokens=False,
        )
        input_ids = inputs["input_ids"]
        # Padded event features are passed under the `pixel_values` key expected
        # by the model.
        pixel_values, pixel_values_is_padding = self.pad_tensors(all_events)
        attention_mask = inputs["attention_mask"]
        labels = self.mask_instruction(input_ids.clone())

        return {
            "input_ids": input_ids,
            "pixel_values": pixel_values,
            "attention_mask": attention_mask,
            "labels": labels,
            "pixel_values_is_padding": pixel_values_is_padding,
        }

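    # Illustrative sketch (not part of the original file): for a batch of B
    # items whose longest text has L tokens and whose longest event sequence
    # has E events with feature dim D, the collator returns
    #   input_ids               (B, L)
    #   attention_mask          (B, L)
    #   labels                  (B, L)   instruction tokens set to -100
    #   pixel_values            (B, E, D)
    #   pixel_values_is_padding (B, E)
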

if __name__ == "__main__":
    from src.dataset.dataset import InstructionTuningDataset
    from torch.utils.data import DataLoader
    from src.model.init_llemr import init_llemr

    # llm_pretrained_model_name_or_path = "Qwen/Qwen2-0.5B-Instruct"
    llm_pretrained_model_name_or_path = "lmsys/vicuna-7b-v1.5"
    device = "cuda:0"
    llemr, tokenizer = init_llemr(llm_pretrained_model_name_or_path, hidden_size=1027)
    llemr.to(device)

    dataset = InstructionTuningDataset(split="train", source="event")
    print(len(dataset))

    collator = InstructionTuningCollator(
        tokenizer=tokenizer,
    )
    loader = DataLoader(
        dataset,
        batch_size=8,
        collate_fn=collator,
    )
    batch = next(iter(loader))
    print(batch["input_ids"].shape)
    print(batch["pixel_values"].shape)
    print(batch["attention_mask"].shape)
    print(batch["labels"].shape)
    print(batch["pixel_values_is_padding"].shape)

    for key, value in batch.items():
        batch[key] = value.to(device)
    with torch.no_grad():
        outputs = llemr(**batch)
    print(outputs.loss)
    print(outputs.logits.shape)

    llemr.train()
    for parameters in llemr.language_model.parameters():
        parameters.requires_grad = False
    outputs = llemr(**batch)
    print(outputs.loss)
    print(outputs.logits.shape)
    outputs.loss.backward()
    print("Success")