Diff of /demo.py [000000] .. [4abb48]

import argparse
import os
import random
import numpy as np
import torch
from torch.backends import cudnn

from chexpert_train import LitIGClassifier
from local_config import JAVA_HOME, JAVA_PATH

# Seed everything for a deterministic demo; comment out to allow non-deterministic runs.
SEED = 16
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
cudnn.benchmark = False
cudnn.deterministic = True

# Set the Java path and a local Gradio temp directory.
os.environ["JAVA_HOME"] = JAVA_HOME
os.environ["PATH"] = JAVA_PATH + os.environ["PATH"]
os.environ['GRADIO_TEMP_DIR'] = os.path.join(os.getcwd(), "gradio_tmp")

import dataclasses
import json
import time
from enum import auto, Enum
from typing import List, Any


import gradio as gr
from PIL import Image
from peft import PeftModelForCausalLM
from skimage import io
from torch import nn
from transformers import LlamaTokenizer
from torchvision.transforms import Compose, Resize, ToTensor, CenterCrop

from model.lavis import tasks
from model.lavis.common.config import Config
from model.lavis.data.ReportDataset import create_chest_xray_transform_for_inference, ExpandChannels
from model.lavis.models.blip2_models.modeling_llama_imgemb import LlamaForCausalLM


def parse_args():
    parser = argparse.ArgumentParser(description="Demo")

    parser.add_argument("--cfg-path", required=True, help="path to configuration file.")
    parser.add_argument("--local_rank", type=int, default=0, help="local rank for distributed training.")
    parser.add_argument(
        "--options",
        nargs="+",
        help="override some settings in the used config: key-value pairs "
        "in xxx=yyy format will be merged into the config file (deprecated, "
        "use --cfg-options instead).",
    )

    args = parser.parse_args()

    return args

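# Example invocation (the config path is illustrative only, not a file known to
# ship with the repo):
#   python demo.py --cfg-path model/lavis/projects/blip2/demo.yaml
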
class SeparatorStyle(Enum):
    """Different separator styles."""
    SINGLE = auto()
    TWO = auto()

@dataclasses.dataclass
class Conversation:
    """A class that keeps all conversation history."""
    system: str
    roles: List[str]
    messages: List[List[str]]
    offset: int
    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
    sep: str = "###"
    sep2: str = None

    # Used for gradio server
    skip_next: bool = False
    conv_id: Any = None

    def get_prompt(self):
        if self.sep_style == SeparatorStyle.SINGLE:
            ret = self.system
            for role, message in self.messages:
                if message:
                    ret += self.sep + " " + role + ": " + message
                else:
                    ret += self.sep + " " + role + ":"
            return ret
        elif self.sep_style == SeparatorStyle.TWO:
            seps = [self.sep, self.sep2]
            ret = self.system + seps[0]
            for i, (role, message) in enumerate(self.messages):
                if message:
                    ret += role + ": " + message + seps[i % 2]
                else:
                    ret += role + ":"
            return ret
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

    def clear(self):
        self.messages = []
        self.offset = 0
        self.skip_next = False

    def append_message(self, role, message):
        self.messages.append([role, message])

    def to_gradio_chatbot(self):
        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                ret.append([msg, None])
            else:
                ret[-1][-1] = msg
        return ret

    def copy(self):
        return Conversation(
            system=self.system,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            conv_id=self.conv_id)

    def dict(self):
        return {
            "system": self.system,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
            "sep": self.sep,
            "sep2": self.sep2,
            "conv_id": self.conv_id,
        }


cfg = Config(parse_args())
vis_transforms = create_chest_xray_transform_for_inference(512, center_crop_size=448)
use_img = False
gen_report = True
pred_chexpert_labels = json.load(open('findings_classifier/predictions/structured_preds_chexpert_log_weighting_test_macro.json', 'r'))


def init_blip(cfg):
    task = tasks.setup_task(cfg)
    model = task.build_model(cfg)
    model = model.to(torch.device('cpu'))
    return model

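# CheXpert findings classifier: a 14-label image classifier whose sigmoid outputs
# are thresholded at 0.5 in get_response() below and joined into a comma-separated
# findings string for the report prompt.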
def init_chexpert_predictor():
    ckpt_path = "findings_classifier/checkpoints/chexpert_train/ChexpertClassifier-epoch=06-val_f1=0.36.ckpt"
    chexpert_cols = ["No Finding", "Enlarged Cardiomediastinum",
                     "Cardiomegaly", "Lung Opacity",
                     "Lung Lesion", "Edema",
                     "Consolidation", "Pneumonia",
                     "Atelectasis", "Pneumothorax",
                     "Pleural Effusion", "Pleural Other",
                     "Fracture", "Support Devices"]
    model = LitIGClassifier.load_from_checkpoint(ckpt_path, num_classes=14, class_names=chexpert_cols, strict=False)
    model.eval()
    model.cuda()
    model.half()
    cp_transforms = Compose([Resize(512), CenterCrop(488), ToTensor(), ExpandChannels()])

    return model, np.asarray(model.class_names), cp_transforms


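# Worked example for the helper below (derived from the code, not from a test suite):
#   remap_to_uint8(np.array([0., 2., 4.])) -> array([0, 127, 255], dtype=uint8)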
def remap_to_uint8(array: np.ndarray, percentiles=None) -> np.ndarray:
    """Remap values in input so the output range is :math:`[0, 255]`.

    Percentiles can be used to specify the range of values to remap.
    This is useful to discard outliers in the input data.

    :param array: Input array.
    :param percentiles: Percentiles of the input values that will be mapped to ``0`` and ``255``.
        Passing ``None`` is equivalent to using percentiles ``(0, 100)`` (but faster).
    :returns: Array with ``0`` and ``255`` as minimum and maximum values.
    """
    array = array.astype(float)
    if percentiles is not None:
        len_percentiles = len(percentiles)
        if len_percentiles != 2:
            message = (
                'The value for percentiles should be a sequence of length 2,'
                f' but has length {len_percentiles}'
            )
            raise ValueError(message)
        a, b = percentiles
        if a >= b:
            raise ValueError(f'Percentiles must be in ascending order, but a sequence "{percentiles}" was passed')
        if a < 0 or b > 100:
            raise ValueError(f'Percentiles must be in the range [0, 100], but a sequence "{percentiles}" was passed')
        cutoff: np.ndarray = np.percentile(array, percentiles)
        array = np.clip(array, *cutoff)
    array -= array.min()
    array /= array.max()
    array *= 255
    return array.astype(np.uint8)


def load_image(path) -> Image.Image:
    """Load an image from disk.

    The image values are remapped to :math:`[0, 255]` and cast to 8-bit unsigned integers.

    :param path: Path to image.
    :returns: Image as ``Pillow`` ``Image``.
    """
    # Although ITK supports JPEG and PNG, we use Pillow for consistency with older trained models
    image = io.imread(path)

    image = remap_to_uint8(image)
    return Image.fromarray(image).convert("L")


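# Vicuna-7B with LoRA adapters. With use_embs=True, a 768 -> hidden_size linear
# projection and an <IMG> special token are added; the <IMG> placeholders in the
# prompt are presumably replaced by projected Q-Former image embeddings inside the
# customised LlamaForCausalLM (modeling_llama_imgemb).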
def init_vicuna():
    use_embs = True

    vicuna_tokenizer = LlamaTokenizer.from_pretrained("lmsys/vicuna-7b-v1.3", use_fast=False, truncation_side="left", padding_side="left")
    lang_model = LlamaForCausalLM.from_pretrained("lmsys/vicuna-7b-v1.3", torch_dtype=torch.float16, device_map='auto')
    vicuna_tokenizer.pad_token = vicuna_tokenizer.unk_token

    if use_embs:
        lang_model.base_model.img_proj_layer = nn.Linear(768, lang_model.base_model.config.hidden_size).to(lang_model.base_model.device)
        vicuna_tokenizer.add_special_tokens({"additional_special_tokens": ["<IMG>"]})

    lang_model = PeftModelForCausalLM.from_pretrained(lang_model,
                                                      "checkpoints/vicuna-7b-img-instruct/checkpoint-4800",
                                                      torch_dtype=torch.float16, use_ram_optimized_load=False).half()
    # lang_model = PeftModelForCausalLM.from_pretrained(lang_model, "checkpoints/vicuna-7b-img-report/checkpoint-11200", torch_dtype=torch.float16, use_ram_optimized_load=False).half()
    return lang_model, vicuna_tokenizer

blip_model = init_blip(cfg)
lang_model, vicuna_tokenizer = init_vicuna()
blip_model.eval()
lang_model.eval()

cp_model, cp_class_names, cp_transforms = init_chexpert_predictor()

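# Minimal sketch of calling get_response() outside of Gradio (the image path and
# question are purely illustrative):
#   report, findings = get_response(("images/some_frontal_view.png",), None)  # image turn -> generated report
#   answer, _ = get_response("Is there a pleural effusion?", None)            # free-text follow-up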
def get_response(input_text, dicom):
    global use_img, blip_model, lang_model, vicuna_tokenizer

    if input_text[-1].endswith(".png") or input_text[-1].endswith(".jpg"):
        image = load_image(input_text[-1])
        cp_image = cp_transforms(image)
        image = vis_transforms(image)
        dicom = input_text[-1].split('/')[-1].split('.')[0]
        if dicom in pred_chexpert_labels:
            findings = ', '.join(pred_chexpert_labels[dicom]).lower().strip()
        else:
            logits = cp_model(cp_image[None].half().cuda())
            preds_probs = torch.sigmoid(logits)
            preds = preds_probs > 0.5
            pred = preds[0].cpu().numpy()
            findings = cp_class_names[pred].tolist()
            findings = ', '.join(findings).lower().strip()

        if gen_report:
            input_text = (
                f"Image information: <IMG><IMG><IMG><IMG><IMG><IMG><IMG><IMG><IMG><IMG><IMG><IMG><IMG><IMG><IMG><IMG><IMG><IMG><IMG><IMG><IMG><IMG><IMG><IMG><IMG><IMG><IMG><IMG><IMG><IMG><IMG><IMG>. Predicted Findings: {findings}. You are to act as a radiologist and write the finding section of a chest x-ray radiology report for this X-ray image and the given predicted findings. "
                "Write in the style of a radiologist, write one fluent text without enumeration, be concise and don't provide explanations or reasons.")
        use_img = True

        blip_model = blip_model.to(torch.device('cuda'))
        qformer_embs = blip_model.forward_image(image[None].to(torch.device('cuda')))[0].cpu().detach()
        blip_model = blip_model.to(torch.device('cpu'))
        # save image embedding with torch
        torch.save(qformer_embs, 'current_chat_img.pt')
        if not gen_report:
            return None, findings  # return a tuple so callers that unpack two values keep working

    else:  # free chat: keep the text prompt as-is
        findings = None

    # Generate the prompt from the conversation history
    conv.append_message(conv.roles[0], input_text)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()

    # Call the vicuna model to generate a response
    inputs = vicuna_tokenizer(prompt, return_tensors="pt")  # for multiple inputs, use tokenizer.batch_encode_plus with padding=True
    input_ids = inputs["input_ids"].cuda()
    # lang_model = lang_model.cuda()
    generation_output = lang_model.generate(
        input_ids=input_ids,
        dicom=[dicom] if dicom is not None else None,
        use_img=use_img,
        return_dict_in_generate=True,
        output_scores=True,
        max_new_tokens=300
    )
    # lang_model = lang_model.cpu()

    preds = vicuna_tokenizer.batch_decode(generation_output.sequences, skip_special_tokens=True)
    new_pred = preds[0].split("ASSISTANT:")[-1]
    # replace the empty ASSISTANT placeholder with the generated reply
    conv.messages.pop()
    conv.append_message(conv.roles[1], new_pred)
    return new_pred, findings


# Conversation template for the prompt
conv = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
           "The assistant gives professional, detailed, and polite answers to the user's questions.",
    roles=["USER", "ASSISTANT"],
    messages=[],
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="</s>",
)
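
# With this template, get_prompt() produces prompts of the form (sketch):
#   "<system> USER: <message> ASSISTANT:"                                 # before generation
#   "<system> USER: <message> ASSISTANT: <reply></s>USER: <message> ..."  # as the chat continues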

# Global variable to store the DICOM string
dicom = None


# Function to update the global DICOM string
def set_dicom(value):
    global dicom
    dicom = value


def add_text(history, text):
    history = history + [(text, None)]
    return history, gr.update(value="", interactive=False)


def add_file(history, file):
    history = history + [((file.name,), None)]
    return history


# Function to clear the chat history
def clear_history(button_name):
    global chat_history, use_img, conv
    chat_history = []
    conv.clear()
    use_img = False
    return []  # Return empty history to the Chatbot


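# bot() streams the reply one character every 0.01 s so the Gradio chatbot shows
# a typewriter effect while the report or answer is written out.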
def bot(history):
    # You can now access the global `dicom` variable here if needed
    response, findings = get_response(history[-1][0], None)
    print(response)

    # show the report generation prompt if this is the first message after an image upload
    if len(history) == 1:
        input_text = "You are to act as a radiologist and write the finding section of a chest x-ray radiology report for this X-ray image and the given predicted findings. Write in the style of a radiologist, write one fluent text without enumeration, be concise and don't provide explanations or reasons."
        if findings is not None:
            input_text = f"Image information: (img_tokens) Predicted Findings: {findings}. {input_text}"
        history.append([input_text, None])

    history[-1][1] = ""
    if response is not None:
        for character in response:
            history[-1][1] += character
            time.sleep(0.01)
            yield history


if __name__ == '__main__':
    with gr.Blocks() as demo:

        chatbot = gr.Chatbot(
            [],
            elem_id="chatbot",
        )

        with gr.Row():
            txt = gr.Textbox(
                show_label=False,
                placeholder="Enter text and press enter, or upload an image",
                container=False,
            )

        with gr.Row():
            btn = gr.UploadButton("📁 Upload image", file_types=["image"], scale=1)
            clear_btn = gr.Button("Clear History", scale=1)

        clear_btn.click(clear_history, [chatbot], [chatbot])

        txt_msg = txt.submit(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
            bot, chatbot, chatbot
        )
        txt_msg.then(lambda: gr.update(interactive=True), None, [txt], queue=False)
        file_msg = btn.upload(add_file, [chatbot, btn], [chatbot], queue=False).then(
            bot, chatbot, chatbot
        )

    demo.queue()
    demo.launch()