Switch to unified view

a b/src/model/modeling_llemr.py
1
from typing import Optional, List
2
3
import torch
4
from transformers import LlavaForConditionalGeneration
5
6
7
class LlemrForConditionalGeneration(LlavaForConditionalGeneration):
    """LLaVA conditional-generation model that accepts a padding mask for ``pixel_values``.

    The extra ``pixel_values_is_padding`` argument is consumed in ``forward``
    (to drop masked entries before delegating to the parent class) and is
    carried through ``prepare_inputs_for_generation`` so it is available on
    each generation step.
    """

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        pixel_values: torch.FloatTensor = None,
        pixel_values_is_padding: torch.BoolTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        vision_feature_layer: Optional[int] = None,
        vision_feature_select_strategy: Optional[str] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """Drop padded ``pixel_values`` entries, then run the parent forward pass.

        When both ``pixel_values`` and ``pixel_values_is_padding`` are given,
        only the entries where the mask is ``False`` are kept, and a new axis
        is inserted at position 1 of the result. All remaining arguments are
        forwarded to the parent ``forward`` unchanged;
        ``pixel_values_is_padding`` itself is not forwarded.

        NOTE(review): the expected shapes of ``pixel_values`` and the mask are
        not visible here -- confirm against the data collator.

        Returns:
            Whatever the parent class ``forward`` returns.
        """
        mask_applies = pixel_values is not None and pixel_values_is_padding is not None
        if mask_applies:
            # Boolean indexing keeps only the non-padding entries.
            keep = ~pixel_values_is_padding
            pixel_values = pixel_values[keep].unsqueeze(1)

        passthrough = dict(
            input_ids=input_ids,
            pixel_values=pixel_values,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            vision_feature_layer=vision_feature_layer,
            vision_feature_select_strategy=vision_feature_select_strategy,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        return super().forward(**passthrough)

    def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels):
        """Validate the inputs, then delegate the merge to the parent class.

        Two preconditions are asserted before delegating:

        * ``image_features.shape[1]`` must equal 1.
        * No sequence in ``input_ids`` may end with ``self.pad_token_id``
          (i.e. the batch is expected to be left-padded).

        Returns:
            The 4-tuple ``(final_embedding, final_attention_mask,
            final_labels, position_ids)`` produced by the parent
            implementation.

        Raises:
            AssertionError: If either precondition above is violated.
        """
        assert image_features.shape[1] == 1, "Only one image patch is supported."

        # A pad token in the last column means at least one row is right-padded.
        last_token_is_pad = input_ids[:, -1] == torch.tensor(self.pad_token_id)
        is_left_padded = not torch.sum(last_token_is_pad)
        assert is_left_padded, "Input ids should be left-padded."

        merged = super()._merge_input_ids_with_image_features(
            image_features=image_features,
            inputs_embeds=inputs_embeds,
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )
        embeddings, merged_attention_mask, merged_labels, merged_position_ids = merged
        return embeddings, merged_attention_mask, merged_labels, merged_position_ids

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, inputs_embeds=None, pixel_values=None, attention_mask=None,
        pixel_values_is_padding=None, **kwargs
    ):
        """Build the parent's generation inputs and attach the padding mask.

        The parent ``prepare_inputs_for_generation`` is called with every
        argument except ``pixel_values_is_padding``; the mask is then stored
        in the returned mapping under the ``"pixel_values_is_padding"`` key.

        Returns:
            The parent's model-input mapping with the extra mask entry.
        """
        generation_inputs = super().prepare_inputs_for_generation(
            input_ids=input_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            pixel_values=pixel_values,
            attention_mask=attention_mask,
            **kwargs,
        )
        generation_inputs["pixel_values_is_padding"] = pixel_values_is_padding
        return generation_inputs