--- a
+++ b/src/dataset/collator.py
@@ -0,0 +1,193 @@
+import logging
+import warnings
+from typing import Dict, List, Tuple
+
+import torch
+from transformers import PreTrainedTokenizer
+
+
+class InstructionTuningCollator:
+    """Collate ``(question, answer, events)`` samples into a training batch.
+
+    Each sample is rendered with the tokenizer's chat template as a
+    system / user / assistant conversation and tokenized; the per-sample event
+    tensors are padded to a common length. Labels are a copy of ``input_ids``
+    in which every token up to and including the assistant response template,
+    and every padding position, is set to ``ignore_index``, so that only the
+    assistant answer contributes to the loss.
+    """
+
+    def __init__(
+        self,
+        tokenizer: PreTrainedTokenizer,
+        sys_prompt: str = "You are an AI assistant specialized in analyzing ICU patient data.",
+        ignore_index: int = -100
+    ) -> None:
+        """Configure the collator and infer the assistant response template.
+
+        Args:
+            tokenizer: Tokenizer whose chat template formats every sample.
+            sys_prompt: Content of the system turn prepended to every sample.
+            ignore_index: Label value for positions excluded from the loss.
+
+        Raises:
+            ValueError: If no response template can be inferred from the
+                tokenizer's chat template.
+        """
+        self.tokenizer = tokenizer
+        self.sys_prompt = sys_prompt
+        self.ignore_index = ignore_index
+        self.response_template, self.response_token_ids = self.infer_response_template()
+
+    def infer_response_template(self) -> Tuple[str, List[int]]:
+        """Infer the text and token ids that open the assistant turn.
+
+        The token-level strategy (v2) is tried first; the string-level
+        strategy (v1) is used when v2 yields an empty template.
+
+        Returns:
+            ``(response_template, response_token_ids)``.
+
+        Raises:
+            ValueError: If neither strategy produces any token ids. Without
+                them ``mask_instruction`` cannot locate the answer at all.
+        """
+        logging.warning("Infer response template with v2")
+        response_template, response_token_ids = self.infer_response_template_v2()
+        if response_template == "":
+            logging.warning("Infer response template with v1")
+            response_template, response_token_ids = self.infer_response_template_v1()
+        if not response_token_ids:
+            raise ValueError(
+                "Could not infer a response template from the tokenizer's chat template."
+            )
+        return response_template, response_token_ids
+
+    def infer_response_template_v1(self) -> Tuple[str, List[int]]:
+        """String-level inference: take the text that follows the user message.
+
+        Renders a one-turn chat with the generation prompt and keeps everything
+        after the probe message. The suffix is then tokenized on its own, which
+        may differ from how the same text is tokenized inside a full sample.
+
+        Returns:
+            ``(response_template, response_token_ids)``, or ``("", [])`` when
+            the probe message cannot be found in the rendered chat.
+        """
+        token = "Hi?"
+        chat = [
+            {"role": "user", "content": token},
+        ]
+        formatted_chat = self.tokenizer.apply_chat_template(
+            chat,
+            tokenize=False,
+            add_generation_prompt=True
+        )
+        # The user message is the last content before the generation prompt,
+        # so search from the right; also guard against find() returning -1,
+        # which would otherwise slice from an arbitrary offset.
+        token_pos = formatted_chat.rfind(token)
+        if token_pos == -1:
+            logging.warning(f"Probe message {token!r} not found in the formatted chat")
+            return "", []
+        response_template = formatted_chat[token_pos + len(token):]
+        response_token_ids = self.tokenizer.encode(response_template, add_special_tokens=False)
+        logging.warning(f"Inferred response template: {repr(response_template)}")
+        logging.warning(f"Inferred response template token ids: {response_token_ids}")
+        return response_template, response_token_ids
+
+    def infer_response_template_v2(self) -> Tuple[str, List[int]]:
+        """Token-level inference: diff the chat with and without the generation prompt.
+
+        Both renderings of a one-turn chat are tokenized, and the extra
+        trailing tokens of the one with the generation prompt are the template.
+
+        Returns:
+            ``(response_template, response_token_ids)``, or ``("", [])`` when
+            the tokens without the generation prompt are not a prefix of the
+            tokens with it (the suffix would then be meaningless).
+        """
+        token = "Hi?"
+        chat = [
+            {"role": "user", "content": token},
+        ]
+        prompt_text = self.tokenizer.apply_chat_template(
+            chat,
+            tokenize=False,
+            add_generation_prompt=False
+        )
+        full_text = self.tokenizer.apply_chat_template(
+            chat,
+            tokenize=False,
+            add_generation_prompt=True
+        )
+        prompt_ids = self.tokenizer.encode(prompt_text, add_special_tokens=False)
+        full_ids = self.tokenizer.encode(full_text, add_special_tokens=False)
+        if full_ids[:len(prompt_ids)] != prompt_ids:
+            # Tokenization merged across the boundary (or the template is not
+            # prefix-stable); let the caller fall back to v1.
+            logging.warning("Chat without generation prompt is not a token prefix; skipping v2")
+            return "", []
+        response_token_ids = full_ids[len(prompt_ids):]
+        response_template = self.tokenizer.decode(response_token_ids)
+        logging.warning(f"Inferred response template: {repr(response_template)}")
+        logging.warning(f"Inferred response template token ids: {response_token_ids}")
+        return response_template, response_token_ids
+
+    def apply_chat_template(self, q_text: str, a_text: str):
+        """Render one system / user / assistant exchange as untokenized text.
+
+        Args:
+            q_text: User (question) content.
+            a_text: Assistant (answer) content.
+
+        Returns:
+            The string produced by the tokenizer's chat template, without a
+            trailing generation prompt.
+        """
+        chat = [
+            {"role": "system", "content": self.sys_prompt},
+            {"role": "user", "content": q_text},
+            {"role": "assistant", "content": a_text}
+        ]
+        formatted_chat = self.tokenizer.apply_chat_template(
+            chat,
+            tokenize=False,
+            add_generation_prompt=False
+        )
+        return formatted_chat
+
+    @staticmethod
+    def pad_tensors(tensor_list, padding_value=0):
+        """Right-pad 2-D tensors with a varying first dimension into one batch.
+
+        Args:
+            tensor_list: Non-empty sequence of tensors shaped
+                ``(num_events_i, feature_dim)`` that share ``feature_dim``.
+            padding_value: Fill value for padded rows.
+
+        Returns:
+            ``(padded_tensor, is_padding)``: a float tensor of shape
+            ``(batch, max_num_events, feature_dim)`` and a bool tensor of shape
+            ``(batch, max_num_events)`` that is ``True`` on padded rows.
+        """
+        max_num_events = max(tensor.shape[0] for tensor in tensor_list)
+        feature_dim = tensor_list[0].shape[1]
+        batch_size = len(tensor_list)
+
+        padded_tensor = torch.full((batch_size, max_num_events, feature_dim), padding_value, dtype=torch.float)
+        is_padding = torch.ones((batch_size, max_num_events), dtype=torch.bool)
+
+        for i, tensor in enumerate(tensor_list):
+            num_events = tensor.shape[0]
+            padded_tensor[i, :num_events, :] = tensor
+            is_padding[i, :num_events] = False
+
+        return padded_tensor, is_padding
+
+    def mask_instruction(self, labels: torch.Tensor) -> torch.Tensor:
+        """Mask every token up to and including the response template, in place.
+
+        Args:
+            labels: ``(batch, seq_len)`` token ids; modified in place.
+
+        Returns:
+            The same ``labels`` tensor. When the template occurs several times
+            in a row, the last occurrence is used; rows in which it does not
+            occur are masked entirely (with a warning).
+        """
+        template_ids = self.response_token_ids
+        template_len = len(template_ids)
+        for i in range(len(labels)):
+            response_token_ids_start_idx = None
+
+            # Candidate starts are positions matching the first template token;
+            # confirm each by comparing the full window.
+            for idx in torch.where(labels[i] == template_ids[0])[0].tolist():
+                if template_ids == labels[i][idx: idx + template_len].tolist():
+                    response_token_ids_start_idx = idx
+
+            if response_token_ids_start_idx is None:
+                warnings.warn(
+                    f"Could not find response key `{self.response_template}` in the "
+                    f'following instance: {self.tokenizer.decode(labels[i])} '
+                    f"This instance will be ignored in loss calculation. "
+                    f"Note, if this happens often, consider increasing the `max_seq_length`."
+                )
+                labels[i, :] = self.ignore_index
+            else:
+                response_token_ids_end_idx = response_token_ids_start_idx + template_len
+                labels[i, :response_token_ids_end_idx] = self.ignore_index
+        return labels
+
+    def __call__(self, batch: List) -> Dict[str, torch.Tensor]:
+        """Turn a list of samples into model inputs.
+
+        Args:
+            batch: Samples indexed as ``data[0]`` (question text),
+                ``data[1]`` (answer text) and ``data[2]`` (2-D event tensor).
+
+        Returns:
+            Dict with ``input_ids``, ``pixel_values``, ``attention_mask``,
+            ``labels`` and ``pixel_values_is_padding``.
+        """
+        all_text = []
+        all_events = []
+        for data in batch:
+            text = self.apply_chat_template(
+                q_text=data[0],
+                a_text=data[1],
+            )
+            all_text.append(text)
+            all_events.append(data[2])
+
+        inputs = self.tokenizer(
+            all_text,
+            return_tensors="pt",
+            padding=True,
+            truncation=True,
+            add_special_tokens=False,
+        )
+        input_ids = inputs["input_ids"]
+        pixel_values, pixel_values_is_padding = self.pad_tensors(all_events)
+        attention_mask = inputs["attention_mask"]
+        labels = self.mask_instruction(input_ids.clone())
+        # Padding must never contribute to the loss. Mask by attention_mask
+        # rather than pad_token_id: the pad id may equal the EOS id, and the
+        # real EOS that closes the answer should stay supervised.
+        labels[attention_mask == 0] = self.ignore_index
+
+        return {
+            "input_ids": input_ids,
+            "pixel_values": pixel_values,
+            "attention_mask": attention_mask,
+            "labels": labels,
+            "pixel_values_is_padding": pixel_values_is_padding,
+        }
+
+
+if __name__ == "__main__":
+    # Smoke test: collate one batch, run a no-grad forward pass, then a
+    # training-mode forward/backward pass with the language model frozen.
+    from src.dataset.dataset import InstructionTuningDataset
+    from torch.utils.data import DataLoader
+    from src.model.init_llemr import init_llemr
+
+    # Smaller alternative backbone: "Qwen/Qwen2-0.5B-Instruct"
+    model_name = "lmsys/vicuna-7b-v1.5"
+    device = "cuda:0"
+    model, tok = init_llemr(model_name, hidden_size=1027)
+    model.to(device)
+
+    train_set = InstructionTuningDataset(split="train", source="event")
+    print(len(train_set))
+
+    loader = DataLoader(
+        train_set,
+        batch_size=8,
+        collate_fn=InstructionTuningCollator(tokenizer=tok),
+    )
+    batch = next(iter(loader))
+    shown_keys = ("input_ids", "pixel_values", "attention_mask", "labels", "pixel_values_is_padding")
+    for name in shown_keys:
+        print(batch[name].shape)
+
+    batch = {name: tensor.to(device) for name, tensor in batch.items()}
+
+    # Inference-mode pass.
+    with torch.no_grad():
+        out = model(**batch)
+    print(out.loss)
+    print(out.logits.shape)
+
+    # Training-mode pass with the language model frozen.
+    model.train()
+    for param in model.language_model.parameters():
+        param.requires_grad_(False)
+    out = model(**batch)
+    print(out.loss)
+    print(out.logits.shape)
+    out.loss.backward()
+    print("Success")