# demo.py

import argparse
import os
import random
import numpy as np
import torch
from torch.backends import cudnn

from chexpert_train import LitIGClassifier
from local_config import JAVA_HOME, JAVA_PATH

# Activate for a deterministic demo, else comment out
SEED = 16
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
cudnn.benchmark = False
cudnn.deterministic = True

# set java path
os.environ["JAVA_HOME"] = JAVA_HOME
os.environ["PATH"] = JAVA_PATH + os.environ["PATH"]
os.environ['GRADIO_TEMP_DIR'] = os.path.join(os.getcwd(), "gradio_tmp")

import dataclasses
import json
import time
from enum import auto, Enum
from typing import List, Any

import gradio as gr
from PIL import Image
from peft import PeftModelForCausalLM
from skimage import io
from torch import nn
from transformers import LlamaTokenizer
from torchvision.transforms import Compose, Resize, ToTensor, CenterCrop

from model.lavis import tasks
from model.lavis.common.config import Config
from model.lavis.data.ReportDataset import create_chest_xray_transform_for_inference, ExpandChannels
from model.lavis.models.blip2_models.modeling_llama_imgemb import LlamaForCausalLM


def parse_args():
    parser = argparse.ArgumentParser(description="Training")

    parser.add_argument("--cfg-path", required=True, help="path to configuration file.")
    parser.add_argument("--local_rank", type=int, default=0, help="local rank for distributed training.")
    parser.add_argument(
        "--options",
        nargs="+",
        help="override some settings in the used config, the key-value pair "
             "in xxx=yyy format will be merged into config file (deprecated), "
             "change to --cfg-options instead.",
    )

    args = parser.parse_args()

    return args


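# Conversation helpers: a lightweight template (system message, roles, turn history) used to build the Vicuna prompt.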
class SeparatorStyle(Enum):
    """Different separator style."""
    SINGLE = auto()
    TWO = auto()


@dataclasses.dataclass
class Conversation:
    """A class that keeps all conversation history."""
    system: str
    roles: List[str]
    messages: List[List[str]]
    offset: int
    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
    sep: str = "###"
    sep2: str = None

    # Used for gradio server
    skip_next: bool = False
    conv_id: Any = None

    def get_prompt(self):
        if self.sep_style == SeparatorStyle.SINGLE:
            ret = self.system
            for role, message in self.messages:
                if message:
                    ret += self.sep + " " + role + ": " + message
                else:
                    ret += self.sep + " " + role + ":"
            return ret
        elif self.sep_style == SeparatorStyle.TWO:
            seps = [self.sep, self.sep2]
            ret = self.system + seps[0]
            for i, (role, message) in enumerate(self.messages):
                if message:
                    ret += role + ": " + message + seps[i % 2]
                else:
                    ret += role + ":"
            return ret
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

    def clear(self):
        self.messages = []
        self.offset = 0
        self.skip_next = False

    def append_message(self, role, message):
        self.messages.append([role, message])

    def to_gradio_chatbot(self):
        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                ret.append([msg, None])
            else:
                ret[-1][-1] = msg
        return ret

    def copy(self):
        return Conversation(
            system=self.system,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            conv_id=self.conv_id)

    def dict(self):
        return {
            "system": self.system,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
            "sep": self.sep,
            "sep2": self.sep2,
            "conv_id": self.conv_id,
        }


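# Module-level setup: LAVIS config, inference image transforms, demo flags, and cached CheXpert predictions keyed by DICOM id.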
cfg = Config(parse_args())
vis_transforms = create_chest_xray_transform_for_inference(512, center_crop_size=448)
use_img = False
gen_report = True
pred_chexpert_labels = json.load(open('findings_classifier/predictions/structured_preds_chexpert_log_weighting_test_macro.json', 'r'))


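# Build the BLIP image encoder from the LAVIS task config; it stays on CPU and is moved to the GPU only while encoding an image.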
def init_blip(cfg):
    task = tasks.setup_task(cfg)
    model = task.build_model(cfg)
    model = model.to(torch.device('cpu'))
    return model


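# Load the CheXpert findings classifier checkpoint and the preprocessing transforms it expects.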
def init_chexpert_predictor():
    ckpt_path = "findings_classifier/checkpoints/chexpert_train/ChexpertClassifier-epoch=06-val_f1=0.36.ckpt"
    chexpert_cols = ["No Finding", "Enlarged Cardiomediastinum",
                     "Cardiomegaly", "Lung Opacity",
                     "Lung Lesion", "Edema",
                     "Consolidation", "Pneumonia",
                     "Atelectasis", "Pneumothorax",
                     "Pleural Effusion", "Pleural Other",
                     "Fracture", "Support Devices"]
    model = LitIGClassifier.load_from_checkpoint(ckpt_path, num_classes=14, class_names=chexpert_cols, strict=False)
    model.eval()
    model.cuda()
    model.half()
    cp_transforms = Compose([Resize(512), CenterCrop(488), ToTensor(), ExpandChannels()])

    return model, np.asarray(model.class_names), cp_transforms


def remap_to_uint8(array: np.ndarray, percentiles=None) -> np.ndarray:
    """Remap values in input so the output range is :math:`[0, 255]`.

    Percentiles can be used to specify the range of values to remap.
    This is useful to discard outliers in the input data.

    :param array: Input array.
    :param percentiles: Percentiles of the input values that will be mapped to ``0`` and ``255``.
        Passing ``None`` is equivalent to using percentiles ``(0, 100)`` (but faster).
    :returns: Array with ``0`` and ``255`` as minimum and maximum values.
    """
    array = array.astype(float)
    if percentiles is not None:
        len_percentiles = len(percentiles)
        if len_percentiles != 2:
            message = (
                'The value for percentiles should be a sequence of length 2,'
                f' but has length {len_percentiles}'
            )
            raise ValueError(message)
        a, b = percentiles
        if a >= b:
            raise ValueError(f'Percentiles must be in ascending order, but a sequence "{percentiles}" was passed')
        if a < 0 or b > 100:
            raise ValueError(f'Percentiles must be in the range [0, 100], but a sequence "{percentiles}" was passed')
        cutoff: np.ndarray = np.percentile(array, percentiles)
        array = np.clip(array, *cutoff)
    array -= array.min()
    array /= array.max()
    array *= 255
    return array.astype(np.uint8)


def load_image(path) -> Image.Image:
    """Load an image from disk.

    The image values are remapped to :math:`[0, 255]` and cast to 8-bit unsigned integers.

    :param path: Path to image.
    :returns: Image as ``Pillow`` ``Image``.
    """
    # Although ITK supports JPEG and PNG, we use Pillow for consistency with older trained models
    image = io.imread(path)

    image = remap_to_uint8(image)
    return Image.fromarray(image).convert("L")


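# Load the Vicuna-7B tokenizer and language model, register the <IMG> special token and an image projection layer,
# and attach the fine-tuned PEFT adapter weights (vicuna-7b-img-instruct checkpoint).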
def init_vicuna():
    use_embs = True

    vicuna_tokenizer = LlamaTokenizer.from_pretrained("lmsys/vicuna-7b-v1.3", use_fast=False, truncation_side="left", padding_side="left")
    lang_model = LlamaForCausalLM.from_pretrained("lmsys/vicuna-7b-v1.3", torch_dtype=torch.float16, device_map='auto')
    vicuna_tokenizer.pad_token = vicuna_tokenizer.unk_token

    if use_embs:
        lang_model.base_model.img_proj_layer = nn.Linear(768, lang_model.base_model.config.hidden_size).to(lang_model.base_model.device)
        vicuna_tokenizer.add_special_tokens({"additional_special_tokens": ["<IMG>"]})

    lang_model = PeftModelForCausalLM.from_pretrained(lang_model,
                                                      "checkpoints/vicuna-7b-img-instruct/checkpoint-4800",
                                                      torch_dtype=torch.float16, use_ram_optimized_load=False).half()
    # lang_model = PeftModelForCausalLM.from_pretrained(lang_model, "checkpoints/vicuna-7b-img-report/checkpoint-11200", torch_dtype=torch.float16, use_ram_optimized_load=False).half()
    return lang_model, vicuna_tokenizer


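# Instantiate all models once at startup and put them into evaluation mode.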
blip_model = init_blip(cfg)
lang_model, vicuna_tokenizer = init_vicuna()
blip_model.eval()
lang_model.eval()

cp_model, cp_class_names, cp_transforms = init_chexpert_predictor()


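# Main inference routine. If the latest input is an image path, look up (or predict) CheXpert findings,
# encode the image with BLIP and cache the Q-Former embedding to disk, and build the report-generation prompt;
# otherwise the text is treated as a free-chat turn. The assembled prompt is then passed to Vicuna.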
def get_response(input_text, dicom):
    global use_img, blip_model, lang_model, vicuna_tokenizer

    if input_text[-1].endswith(".png") or input_text[-1].endswith(".jpg"):
        image = load_image(input_text[-1])
        cp_image = cp_transforms(image)
        image = vis_transforms(image)
        dicom = input_text[-1].split('/')[-1].split('.')[0]
        if dicom in pred_chexpert_labels:
            findings = ', '.join(pred_chexpert_labels[dicom]).lower().strip()
        else:
            logits = cp_model(cp_image[None].half().cuda())
            preds_probs = torch.sigmoid(logits)
            preds = preds_probs > 0.5
            pred = preds[0].cpu().numpy()
            findings = cp_class_names[pred].tolist()
            findings = ', '.join(findings).lower().strip()

        if gen_report:
            input_text = (
                "Image information: " + "<IMG>" * 32 +  # 32 <IMG> placeholder tokens for the image embedding
                f". Predicted Findings: {findings}. You are to act as a radiologist and write the finding section of a chest x-ray radiology report for this X-ray image and the given predicted findings. "
                "Write in the style of a radiologist, write one fluent text without enumeration, be concise and don't provide explanations or reasons.")
        use_img = True

        blip_model = blip_model.to(torch.device('cuda'))
        qformer_embs = blip_model.forward_image(image[None].to(torch.device('cuda')))[0].cpu().detach()
        blip_model = blip_model.to(torch.device('cpu'))
        # save image embedding with torch
        torch.save(qformer_embs, 'current_chat_img.pt')
        if not gen_report:
            return None, findings

    else:  # free chat
        findings = None

    '''Generate prompt given input prompt'''
    conv.append_message(conv.roles[0], input_text)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()

    '''Call vicuna model to generate response'''
    inputs = vicuna_tokenizer(prompt, return_tensors="pt")  # for multiple inputs, use tokenizer.batch_encode_plus with padding=True
    input_ids = inputs["input_ids"].cuda()
    # lang_model = lang_model.cuda()
    generation_output = lang_model.generate(
        input_ids=input_ids,
        dicom=[dicom] if dicom is not None else None,
        use_img=use_img,
        return_dict_in_generate=True,
        output_scores=True,
        max_new_tokens=300
    )
    # lang_model = lang_model.cpu()

    preds = vicuna_tokenizer.batch_decode(generation_output.sequences, skip_special_tokens=True)
    new_pred = preds[0].split("ASSISTANT:")[-1]
    # remove the placeholder assistant message and store the actual reply
    conv.messages.pop()
    conv.append_message(conv.roles[1], new_pred)
    return new_pred, findings


'''Conversation template for prompt'''
conv = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
           "The assistant gives professional, detailed, and polite answers to the user's questions.",
    roles=["USER", "ASSISTANT"],
    messages=[],
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="</s>",
)

# Global variable to store the DICOM string
dicom = None


# Function to update the global DICOM string
def set_dicom(value):
    global dicom
    dicom = value


def add_text(history, text):
    history = history + [(text, None)]
    return history, gr.update(value="", interactive=False)


def add_file(history, file):
    history = history + [((file.name,), None)]
    return history


# Function to clear the chat history
def clear_history(button_name):
    global chat_history, use_img, conv
    chat_history = []
    conv.clear()
    use_img = False
    return []  # Return empty history to the Chatbot


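# Gradio callback: generate a response for the newest user turn and stream it back character by character.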
def bot(history):
    # You can now access the global `dicom` variable here if needed
    response, findings = get_response(history[-1][0], None)
    print(response)

    # show report generation prompt if first message after image
    if len(history) == 1:
        input_text = "You are to act as a radiologist and write the finding section of a chest x-ray radiology report for this X-ray image and the given predicted findings. Write in the style of a radiologist, write one fluent text without enumeration, be concise and don't provide explanations or reasons."
        if findings is not None:
            input_text = f"Image information: (img_tokens) Predicted Findings: {findings}. {input_text}"
        history.append([input_text, None])

    history[-1][1] = ""
    if response is not None:
        for character in response:
            history[-1][1] += character
            time.sleep(0.01)
            yield history


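# Gradio UI: chat window, text box, image upload button, and a button that clears the conversation state.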
if __name__ == '__main__':
    with gr.Blocks() as demo:
        chatbot = gr.Chatbot(
            [],
            elem_id="chatbot",
        )

        with gr.Row():
            txt = gr.Textbox(
                show_label=False,
                placeholder="Enter text and press enter, or upload an image",
                container=False,
            )

        with gr.Row():
            btn = gr.UploadButton("📁 Upload image", file_types=["image"], scale=1)
            clear_btn = gr.Button("Clear History", scale=1)

        clear_btn.click(clear_history, [chatbot], [chatbot])

        txt_msg = txt.submit(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
            bot, chatbot, chatbot
        )
        txt_msg.then(lambda: gr.update(interactive=True), None, [txt], queue=False)
        file_msg = btn.upload(add_file, [chatbot, btn], [chatbot], queue=False).then(
            bot, chatbot, chatbot
        )

    demo.queue()
    demo.launch()
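# Usage sketch (the config file is repository-specific and only shown as a placeholder here):
#   python demo.py --cfg-path <path/to/lavis_config.yaml>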