# minigpt4/conversation/conversation.py
import argparse
import time
from PIL import Image

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaTokenizer
from transformers import StoppingCriteria, StoppingCriteriaList

import dataclasses
from enum import auto, Enum
from typing import Any, List, Optional, Tuple

from minigpt4.common.registry import registry


class SeparatorStyle(Enum):
    """Different separator styles."""
    SINGLE = auto()
    TWO = auto()


@dataclasses.dataclass
class Conversation:
    """A class that keeps all conversation history."""
    system: str
    roles: List[str]
    messages: List[List[str]]
    offset: int
    # system_img: List[Image.Image] = []
    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
    sep: str = "###"
    sep2: Optional[str] = None

    skip_next: bool = False
    conv_id: Any = None

    def get_prompt(self):
        if self.sep_style == SeparatorStyle.SINGLE:
            ret = self.system + self.sep
            for role, message in self.messages:
                if message:
                    ret += role + ": " + message + self.sep
                else:
                    ret += role + ":"
            return ret
        elif self.sep_style == SeparatorStyle.TWO:
            seps = [self.sep, self.sep2]
            ret = self.system + seps[0]
            for i, (role, message) in enumerate(self.messages):
                if message:
                    ret += role + ": " + message + seps[i % 2]
                else:
                    ret += role + ":"
            return ret
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

    def append_message(self, role, message):
        self.messages.append([role, message])

    def to_gradio_chatbot(self):
        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                ret.append([msg, None])
            else:
                ret[-1][-1] = msg
        return ret

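    # Illustrative pairing (comment only, not executed): with offset=0 and
    # messages [["Human", "Hi"], ["Assistant", "Hello!"]], to_gradio_chatbot()
    # returns [["Hi", "Hello!"]], the format gradio's Chatbot widget expects.
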
    def copy(self):
        return Conversation(
            system=self.system,
            # system_img=self.system_img,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            conv_id=self.conv_id)

    def dict(self):
        return {
            "system": self.system,
            # "system_img": self.system_img,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
            "sep": self.sep,
            "sep2": self.sep2,
            "conv_id": self.conv_id,
        }

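# Illustrative usage (comment only, not executed): with the SINGLE separator
# style, get_prompt() joins the system text and "Role: message" turns with
# "###", leaving a trailing "Role:" for the model to complete:
#
#   conv = Conversation(system="Sys.", roles=["Human", "Assistant"],
#                       messages=[], offset=0)
#   conv.append_message(conv.roles[0], "Hi")
#   conv.append_message(conv.roles[1], None)
#   conv.get_prompt()  # -> "Sys.###Human: Hi###Assistant:"

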
class StoppingCriteriaSub(StoppingCriteria):

    def __init__(self, stops=None, encounters=1):
        super().__init__()
        # avoid a shared mutable default argument; `encounters` is unused but
        # kept for backward compatibility
        self.stops = stops if stops is not None else []

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # stop as soon as the generated sequence ends with any stop id sequence
        for stop in self.stops:
            if torch.all(stop == input_ids[0][-len(stop):]).item():
                return True
        return False


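# Illustrative example (comment only, not executed): the criterion fires when
# the generated ids end with any configured stop sequence, e.g. the two LLaMA
# encodings of '###' used by Chat below:
#
#   crit = StoppingCriteriaSub(stops=[torch.tensor([835]),
#                                     torch.tensor([2277, 29937])])
#   ids = torch.tensor([[1, 22172, 2277, 29937]])
#   crit(ids, scores=None)  # -> True, the tail matches [2277, 29937]

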
CONV_VISION = Conversation(
    system="Give the following image: <Img>ImageContent</Img>. "
           "You will be able to see the image once I provide it to you. Please answer my questions.",
    roles=("Human", "Assistant"),
    messages=[],
    offset=2,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
)


class Chat:
    def __init__(self, model, vis_processor, device='cuda:0'):
        self.device = device
        self.model = model
        self.vis_processor = vis_processor
        stop_words_ids = [torch.tensor([835]).to(self.device),
                          torch.tensor([2277, 29937]).to(self.device)]  # '###' can be encoded in two different ways
        self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])

    def ask(self, text, conv):
        if len(conv.messages) > 0 and conv.messages[-1][0] == conv.roles[0] \
                and conv.messages[-1][1][-6:] == '</Img>':  # the last message is an image
            conv.messages[-1][1] = ' '.join([conv.messages[-1][1], text])
        else:
            conv.append_message(conv.roles[0], text)

    def answer(self, conv, img_list, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.9,
               repetition_penalty=1.0, length_penalty=1.0, temperature=1.0, max_length=2000):
        conv.append_message(conv.roles[1], None)
        embs = self.get_context_emb(conv, img_list)

        current_max_len = embs.shape[1] + max_new_tokens
        if current_max_len - max_length > 0:
            print('Warning: The number of tokens in the current conversation exceeds the max length. '
                  'The model will not see the context outside this range.')
        begin_idx = max(0, current_max_len - max_length)

        embs = embs[:, begin_idx:]

        outputs = self.model.llama_model.generate(
            inputs_embeds=embs,
            max_new_tokens=max_new_tokens,
            stopping_criteria=self.stopping_criteria,
            num_beams=num_beams,
            do_sample=True,
            min_length=min_length,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            length_penalty=length_penalty,
            temperature=temperature,
        )
        output_token = outputs[0]
        if output_token[0] == 0:  # the model might output an unknown token <unk> at the beginning; remove it
            output_token = output_token[1:]
        if output_token[0] == 1:  # some users find a start token <s> at the beginning; remove it
            output_token = output_token[1:]
        output_text = self.model.llama_tokenizer.decode(output_token, skip_special_tokens=True)
        output_text = output_text.split('###')[0]  # remove the stop sign '###'
        output_text = output_text.split('Assistant:')[-1].strip()
        conv.messages[-1][1] = output_text
        return output_text, output_token.cpu().numpy()

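    # Worked example (comment only) of the truncation above: with
    # max_length=2000 and max_new_tokens=300, a context of 1900 embedding
    # positions gives current_max_len = 2200 and begin_idx = 200, so only the
    # last 1700 context positions are kept, leaving room for 300 new tokens
    # within the 2000-position budget.
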
    def upload_img(self, image, conv, img_list):
        if isinstance(image, str):  # an image path
            raw_image = Image.open(image).convert('RGB')
            image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
        elif isinstance(image, Image.Image):
            raw_image = image
            image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
        elif isinstance(image, torch.Tensor):
            if len(image.shape) == 3:  # add a batch dimension if missing
                image = image.unsqueeze(0)
            image = image.to(self.device)

        image_emb, _ = self.model.encode_img(image)
        img_list.append(image_emb)
        conv.append_message(conv.roles[0], "<Img><ImageHere></Img>")
        msg = "Received."
        # self.conv.append_message(self.conv.roles[1], msg)
        return msg

    def get_context_emb(self, conv, img_list):
        prompt = conv.get_prompt()
        prompt_segs = prompt.split('<ImageHere>')
        assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images."
        seg_tokens = [
            self.model.llama_tokenizer(
                seg, return_tensors="pt", add_special_tokens=i == 0).to(self.device).input_ids
            # only add bos to the first seg
            for i, seg in enumerate(prompt_segs)
        ]
        seg_embs = [self.model.llama_model.model.embed_tokens(seg_t) for seg_t in seg_tokens]
        # interleave text-segment embeddings with image embeddings, then append the final text segment
        mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]]
        mixed_embs = torch.cat(mixed_embs, dim=1)
        return mixed_embs
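

# Illustrative end-to-end usage (comment only, not executed; `model` and
# `vis_processor` are assumed to be built from the MiniGPT-4 configs, and
# 'example.jpg' is a hypothetical path):
#
#   chat = Chat(model, vis_processor, device='cuda:0')
#   conv = CONV_VISION.copy()
#   img_list = []
#   chat.upload_img('example.jpg', conv, img_list)  # appends "<Img><ImageHere></Img>"
#   chat.ask('What is in this image?', conv)        # merged into the <Img> turn
#   answer, _ = chat.answer(conv, img_list)
#
# In get_context_emb, one image splits the prompt into two text segments, so
# the final embedding sequence is [text_before, image_emb, text_after]
# concatenated along the sequence dimension.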