[780764]: / src / model / modeling_llemr.py

Download this file

77 lines (70 with data), 3.2 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
from typing import Optional, List
import torch
from transformers import LlavaForConditionalGeneration
class LlemrForConditionalGeneration(LlavaForConditionalGeneration):
    """Variant of ``LlavaForConditionalGeneration`` that takes a padding mask for ``pixel_values``.

    Three parent hooks are overridden:

    * ``forward`` removes the ``pixel_values`` entries flagged by
      ``pixel_values_is_padding`` before delegating to the parent.
    * ``_merge_input_ids_with_image_features`` adds two sanity checks (one
      feature per image, left-padded ``input_ids``) in front of the parent
      merge.
    * ``prepare_inputs_for_generation`` adds ``pixel_values_is_padding`` to
      the inputs built by the parent.
    """

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        pixel_values: torch.FloatTensor = None,
        pixel_values_is_padding: torch.BoolTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        vision_feature_layer: Optional[int] = None,
        vision_feature_select_strategy: Optional[str] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """Filter padded ``pixel_values`` entries, then run the parent forward pass.

        Args:
            pixel_values_is_padding: Boolean mask used to index
                ``pixel_values``; entries where it is ``True`` are dropped.
                NOTE(review): presumably it spans the leading dimension(s)
                of ``pixel_values`` -- confirm against the data collator.

        All other arguments are handed to the parent ``forward`` unchanged.
        ``pixel_values_is_padding`` itself is not forwarded.

        Returns:
            Whatever the parent ``forward`` returns.
        """
        mask_is_usable = pixel_values is not None and pixel_values_is_padding is not None
        if mask_is_usable:
            # Keep only the non-padding entries, then insert a singleton
            # axis at position 1.
            keep = ~pixel_values_is_padding
            pixel_values = pixel_values[keep].unsqueeze(1)

        parent_kwargs = {
            "input_ids": input_ids,
            "pixel_values": pixel_values,
            "attention_mask": attention_mask,
            "position_ids": position_ids,
            "past_key_values": past_key_values,
            "inputs_embeds": inputs_embeds,
            "vision_feature_layer": vision_feature_layer,
            "vision_feature_select_strategy": vision_feature_select_strategy,
            "labels": labels,
            "use_cache": use_cache,
            "output_attentions": output_attentions,
            "output_hidden_states": output_hidden_states,
            "return_dict": return_dict,
        }
        return super().forward(**parent_kwargs)

    def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels):
        """Validate the inputs, then delegate the merge to the parent class.

        Two preconditions are asserted before delegating:

        * ``image_features.shape[1]`` equals 1.
        * No row of ``input_ids`` has ``self.pad_token_id`` as its last
          token, which this method treats as proof of left-padding.

        Returns:
            The 4-tuple ``(final_embedding, final_attention_mask,
            final_labels, position_ids)`` unpacked from the parent's result.

        Raises:
            AssertionError: If either precondition fails.
        """
        patches_per_image = image_features.shape[1]
        assert patches_per_image == 1, "Only one image patch is supported."

        # Count rows whose final token is the pad id; any hit means the
        # batch is not left-padded.
        last_tokens = input_ids[:, -1]
        trailing_pad_hits = last_tokens == torch.tensor(self.pad_token_id)
        left_padding = not torch.sum(trailing_pad_hits)
        assert left_padding, "Input ids should be left-padded."

        merged = super()._merge_input_ids_with_image_features(
            image_features=image_features,
            inputs_embeds=inputs_embeds,
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )
        # Unpack explicitly so a result with a different arity still fails here.
        final_embedding, final_attention_mask, final_labels, position_ids = merged
        return final_embedding, final_attention_mask, final_labels, position_ids

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, inputs_embeds=None, pixel_values=None, attention_mask=None,
        pixel_values_is_padding=None, **kwargs
    ):
        """Build generation inputs via the parent and attach the padding mask.

        ``pixel_values_is_padding`` is not passed to the parent
        implementation. It is stored in the returned mapping under the key
        ``"pixel_values_is_padding"``, which matches the parameter name that
        ``forward`` accepts.

        Returns:
            The parent's model-input mapping with the extra key added.
        """
        generation_inputs = super().prepare_inputs_for_generation(
            input_ids=input_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            pixel_values=pixel_values,
            attention_mask=attention_mask,
            **kwargs,
        )
        generation_inputs["pixel_values_is_padding"] = pixel_values_is_padding
        return generation_inputs