--- a/demo.py
+++ b/demo.py
@@ -0,0 +1,400 @@
+"""Gradio demo: upload a chest X-ray, generate a findings report, and chat about the image."""
+import argparse
+import os
+import random
+import numpy as np
+import torch
+from torch.backends import cudnn
+
+from chexpert_train import LitIGClassifier
+from local_config import JAVA_HOME, JAVA_PATH
+
+# Seed everything for a deterministic demo; comment out this block for non-deterministic runs.
+SEED = 16
+random.seed(SEED)
+np.random.seed(SEED)
+torch.manual_seed(SEED)
+cudnn.benchmark = False
+cudnn.deterministic = True
+
+# Set the Java path from local_config and a local Gradio temp directory.
+os.environ["JAVA_HOME"] = JAVA_HOME
+os.environ["PATH"] = JAVA_PATH + os.environ["PATH"]
+os.environ['GRADIO_TEMP_DIR'] = os.path.join(os.getcwd(), "gradio_tmp")
+
+import dataclasses
+import json
+import time
+from enum import auto, Enum
+from typing import List, Any
+
+
+import gradio as gr
+from PIL import Image
+from peft import PeftModelForCausalLM
+from skimage import io
+from torch import nn
+from transformers import LlamaTokenizer
+from torchvision.transforms import Compose, Resize, ToTensor, CenterCrop
+
+from model.lavis import tasks
+from model.lavis.common.config import Config
+from model.lavis.data.ReportDataset import create_chest_xray_transform_for_inference, ExpandChannels
+from model.lavis.models.blip2_models.modeling_llama_imgemb import LlamaForCausalLM
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="Demo")
+
+    parser.add_argument("--cfg-path", required=True, help="path to configuration file.")
+    parser.add_argument("--local_rank", type=int, default=0, help="local rank for distributed training.")
+    parser.add_argument(
+        "--options",
+        nargs="+",
+        help="override some settings in the used config, the key-value pair "
+             "in xxx=yyy format will be merged into the config file (deprecated), "
+             "change to --cfg-options instead.",
+    )
+
+    args = parser.parse_args()
+
+    return args
+
+class SeparatorStyle(Enum):
+    """Different separator style."""
+    SINGLE = auto()
+    TWO = auto()
+
+@dataclasses.dataclass
+class Conversation:
+    """A class that keeps all conversation history."""
+    system: str
+    roles: List[str]
+    messages: List[List[str]]
+    offset: int
+    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
+    sep: str = "###"
+    sep2: str = None
+
+    # Used for gradio server
+    skip_next: bool = False
+    conv_id: Any = None
+
+    def get_prompt(self):
+        if self.sep_style == SeparatorStyle.SINGLE:
+            ret = self.system
+            for role, message in self.messages:
+                if message:
+                    ret += self.sep + " " + role + ": " + message
+                else:
+                    ret += self.sep + " " + role + ":"
+            return ret
+        elif self.sep_style == SeparatorStyle.TWO:
+            seps = [self.sep, self.sep2]
+            ret = self.system + seps[0]
+            for i, (role, message) in enumerate(self.messages):
+                if message:
+                    ret += role + ": " + message + seps[i % 2]
+                else:
+                    ret += role + ":"
+            return ret
+        else:
+            raise ValueError(f"Invalid style: {self.sep_style}")
+
+    def clear(self):
+        self.messages = []
+        self.offset = 0
+        self.skip_next = False
+
+    def append_message(self, role, message):
+        self.messages.append([role, message])
+
+    def to_gradio_chatbot(self):
+        ret = []
+        for i, (role, msg) in enumerate(self.messages[self.offset:]):
+            if i % 2 == 0:
+                ret.append([msg, None])
+            else:
+                ret[-1][-1] = msg
+        return ret
+
+    def copy(self):
+        return Conversation(
+            system=self.system,
+            roles=self.roles,
+            messages=[[x, y] for x, y in self.messages],
+            offset=self.offset,
+            sep_style=self.sep_style,
+            sep=self.sep,
+            sep2=self.sep2,
+            conv_id=self.conv_id)
+
+    def dict(self):
+        return {
+            "system": self.system,
+            "roles": self.roles,
+            "messages": self.messages,
+            "offset": self.offset,
+            "sep": self.sep,
+            "sep2": self.sep2,
+            "conv_id": self.conv_id,
+        }
+
+
+cfg = Config(parse_args())
+vis_transforms = create_chest_xray_transform_for_inference(512, center_crop_size=448)
+use_img = False  # set to True once an image has been uploaded in the current chat
+gen_report = True  # generate a findings report automatically right after an image is uploaded
+# Precomputed CheXpert label predictions, keyed by DICOM id.
+pred_chexpert_labels = json.load(open('findings_classifier/predictions/structured_preds_chexpert_log_weighting_test_macro.json', 'r'))
+
+def init_blip(cfg):
+    task = tasks.setup_task(cfg)
+    model = task.build_model(cfg)
+    # Keep the BLIP-2 alignment module on the CPU; it is moved to the GPU only while encoding an image.
+    model = model.to(torch.device('cpu'))
+    return model
+
+def init_chexpert_predictor():
+    ckpt_path = "findings_classifier/checkpoints/chexpert_train/ChexpertClassifier-epoch=06-val_f1=0.36.ckpt"
+    chexpert_cols = ["No Finding", "Enlarged Cardiomediastinum",
+                     "Cardiomegaly", "Lung Opacity",
+                     "Lung Lesion", "Edema",
+                     "Consolidation", "Pneumonia",
+                     "Atelectasis", "Pneumothorax",
+                     "Pleural Effusion", "Pleural Other",
+                     "Fracture", "Support Devices"]
+    model = LitIGClassifier.load_from_checkpoint(ckpt_path, num_classes=14, class_names=chexpert_cols, strict=False)
+    model.eval()
+    model.cuda()
+    model.half()
+    cp_transforms = Compose([Resize(512), CenterCrop(488), ToTensor(), ExpandChannels()])
+
+    return model, np.asarray(model.class_names), cp_transforms
+
+
+def remap_to_uint8(array: np.ndarray, percentiles=None) -> np.ndarray:
+    """Remap values in input so the output range is :math:`[0, 255]`.
+
+    Percentiles can be used to specify the range of values to remap.
+    This is useful to discard outliers in the input data.
+
+    :param array: Input array.
+    :param percentiles: Percentiles of the input values that will be mapped to ``0`` and ``255``.
+        Passing ``None`` is equivalent to using percentiles ``(0, 100)`` (but faster).
+    :returns: Array with ``0`` and ``255`` as minimum and maximum values.
+    """
+    array = array.astype(float)
+    if percentiles is not None:
+        len_percentiles = len(percentiles)
+        if len_percentiles != 2:
+            message = (
+                'The value for percentiles should be a sequence of length 2,'
+                f' but has length {len_percentiles}'
+            )
+            raise ValueError(message)
+        a, b = percentiles
+        if a >= b:
+            raise ValueError(f'Percentiles must be in ascending order, but a sequence "{percentiles}" was passed')
+        if a < 0 or b > 100:
+            raise ValueError(f'Percentiles must be in the range [0, 100], but a sequence "{percentiles}" was passed')
+        cutoff: np.ndarray = np.percentile(array, percentiles)
+        array = np.clip(array, *cutoff)
+    array -= array.min()
+    array /= array.max()
+    array *= 255
+    return array.astype(np.uint8)
+
+
+def load_image(path) -> Image.Image:
+    """Load an image from disk.
+
+    The image values are remapped to :math:`[0, 255]` and cast to 8-bit unsigned integers.
+
+    :param path: Path to image.
+    :returns: Image as ``Pillow`` ``Image``.
+    """
+    # Although ITK supports JPEG and PNG, we use Pillow for consistency with older trained models
+    image = io.imread(path)
+
+    image = remap_to_uint8(image)
+    return Image.fromarray(image).convert("L")
+
+
+def init_vicuna():
+    use_embs = True
+
+    vicuna_tokenizer = LlamaTokenizer.from_pretrained("lmsys/vicuna-7b-v1.3", use_fast=False, truncation_side="left", padding_side="left")
+    lang_model = LlamaForCausalLM.from_pretrained("lmsys/vicuna-7b-v1.3", torch_dtype=torch.float16, device_map='auto')
+    vicuna_tokenizer.pad_token = vicuna_tokenizer.unk_token
+
+    if use_embs:
+        lang_model.base_model.img_proj_layer = nn.Linear(768, lang_model.base_model.config.hidden_size).to(lang_model.base_model.device)
+        vicuna_tokenizer.add_special_tokens({"additional_special_tokens": ["<IMG>"]})
+
+    lang_model = PeftModelForCausalLM.from_pretrained(lang_model,
+                                                      "checkpoints/vicuna-7b-img-instruct/checkpoint-4800",
+                                                      torch_dtype=torch.float16, use_ram_optimized_load=False).half()
+    # lang_model = PeftModelForCausalLM.from_pretrained(lang_model, f"checkpoints/vicuna-7b-img-report/checkpoint-11200", torch_dtype=torch.float16, use_ram_optimized_load=False).half()
+    return lang_model, vicuna_tokenizer
+
+blip_model = init_blip(cfg)
+lang_model, vicuna_tokenizer = init_vicuna()
+blip_model.eval()
+lang_model.eval()
+
+cp_model, cp_class_names, cp_transforms = init_chexpert_predictor()
+
+def get_response(input_text, dicom):
+    global use_img, blip_model, lang_model, vicuna_tokenizer
+
+    # Gradio file uploads arrive as a one-element tuple, so input_text[-1] is the image path;
+    # plain text messages are strings and fall through to the free-chat branch.
+    if input_text[-1].endswith(".png") or input_text[-1].endswith(".jpg"):
+        image = load_image(input_text[-1])
+        cp_image = cp_transforms(image)
+        image = vis_transforms(image)
+        dicom = input_text[-1].split('/')[-1].split('.')[0]
+        if dicom in pred_chexpert_labels:
+            findings = ', '.join(pred_chexpert_labels[dicom]).lower().strip()
+        else:
+            logits = cp_model(cp_image[None].half().cuda())
+            preds_probs = torch.sigmoid(logits)
+            preds = preds_probs > 0.5
+            pred = preds[0].cpu().numpy()
+            findings = cp_class_names[pred].tolist()
+            findings = ', '.join(findings).lower().strip()
+
+        if gen_report:
+            # The <IMG> placeholders mark where the projected Q-Former image embeddings are injected.
+            input_text = (
+                f"Image information: <IMG><IMG><IMG><IMG><IMG><IMG><IMG><IMG><IMG><IMG><IMG><IMG><IMG><IMG><IMG><IMG><IMG><IMG><IMG><IMG><IMG><IMG><IMG><IMG><IMG><IMG><IMG><IMG><IMG><IMG><IMG><IMG>. Predicted Findings: {findings}. You are to act as a radiologist and write the finding section of a chest x-ray radiology report for this X-ray image and the given predicted findings. "
+                "Write in the style of a radiologist, write one fluent text without enumeration, be concise and don't provide explanations or reasons.")
+        use_img = True
+
+        blip_model = blip_model.to(torch.device('cuda'))
+        qformer_embs = blip_model.forward_image(image[None].to(torch.device('cuda')))[0].cpu().detach()
+        blip_model = blip_model.to(torch.device('cpu'))
+        # save image embedding with torch
+        torch.save(qformer_embs, 'current_chat_img.pt')
+        if not gen_report:
+            return None, findings  # keep the (response, findings) pair so bot() can unpack the result
+
+    else:  # free chat
+        findings = None
+
+    # Build the full conversation prompt from the chat history.
+    conv.append_message(conv.roles[0], input_text)
+    conv.append_message(conv.roles[1], None)
+    prompt = conv.get_prompt()
+
+    # Generate the response with the Vicuna model.
+    inputs = vicuna_tokenizer(prompt, return_tensors="pt")  # for multiple inputs, use tokenizer.batch_encode_plus with padding=True
+    input_ids = inputs["input_ids"].cuda()
+    # lang_model = lang_model.cuda()
+    generation_output = lang_model.generate(
+        input_ids=input_ids,
+        dicom=[dicom] if dicom is not None else None,  # custom arguments consumed by the image-aware LLaMA wrapper
+        use_img=use_img,
+        return_dict_in_generate=True,
+        output_scores=True,
+        max_new_tokens=300
+    )
+    # lang_model = lang_model.cpu()
+
+    preds = vicuna_tokenizer.batch_decode(generation_output.sequences, skip_special_tokens=True)
+    new_pred = preds[0].split("ASSISTANT:")[-1]
+    # Replace the placeholder assistant turn with the generated answer.
+    conv.messages.pop()
+    conv.append_message(conv.roles[1], new_pred)
+    return new_pred, findings
+
+
+# Conversation template used to build the Vicuna prompt.
+conv = Conversation(
+    system="A chat between a curious user and an artificial intelligence assistant. "
+           "The assistant gives professional, detailed, and polite answers to the user's questions.",
+    roles=["USER", "ASSISTANT"],
+    messages=[],
+    offset=0,
+    sep_style=SeparatorStyle.TWO,
+    sep=" ",
+    sep2="</s>",
+)
+
+# Global variable to store the DICOM string
+dicom = None
+
+
+# Function to update the global DICOM string
+def set_dicom(value):
+    global dicom
+    dicom = value
+
+
+def add_text(history, text):
+    history = history + [(text, None)]
+    return history, gr.update(value="", interactive=False)
+
+
+def add_file(history, file):
+    history = history + [((file.name,), None)]
+    return history
+
+
+# Function to clear the chat history
+def clear_history(history):
+    global use_img, conv
+    conv.clear()
+    use_img = False
+    return []  # Return empty history to the Chatbot
+
+
+def bot(history):
+    # You can now access the global `dicom` variable here if needed
+    response, findings = get_response(history[-1][0], None)
+    print(response)
+
+    # Show the report-generation prompt in the chat if this is the first message after an image upload.
+    if len(history) == 1:
+        input_text = "You are to act as a radiologist and write the finding section of a chest x-ray radiology report for this X-ray image and the given predicted findings. Write in the style of a radiologist, write one fluent text without enumeration, be concise and don't provide explanations or reasons."
+        if findings is not None:
+            # "(img_tokens)" is only a display stand-in for the <IMG> placeholder tokens sent to the model.
+            input_text = f"Image information: (img_tokens) Predicted Findings: {findings}. {input_text}"
+            history.append([input_text, None])
+
+    history[-1][1] = ""
+    if response is not None:
+        # Stream the response into the chat window character by character.
+        for character in response:
+            history[-1][1] += character
+            time.sleep(0.01)
+            yield history
+
+
+if __name__ == '__main__':
+    with gr.Blocks() as demo:
+        chatbot = gr.Chatbot(
+            [],
+            elem_id="chatbot",
+        )
+
+        with gr.Row():
+            txt = gr.Textbox(
+                show_label=False,
+                placeholder="Enter text and press enter, or upload an image",
+                container=False,
+            )
+
+        with gr.Row():
+            btn = gr.UploadButton("📁 Upload image", file_types=["image"], scale=1)
+            clear_btn = gr.Button("Clear History", scale=1)
+
+        clear_btn.click(clear_history, [chatbot], [chatbot])
+
+        # Submitting text disables the textbox, runs the bot, then re-enables the textbox.
+        txt_msg = txt.submit(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
+            bot, chatbot, chatbot
+        )
+        txt_msg.then(lambda: gr.update(interactive=True), None, [txt], queue=False)
+        file_msg = btn.upload(add_file, [chatbot, btn], [chatbot], queue=False).then(
+            bot, chatbot, chatbot
+        )
+
+        demo.queue()
+        demo.launch()
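Usage sketch (not part of the diff; the config path below is a placeholder): the demo is a plain argparse script, so it is launched directly, and --cfg-path must point to the LAVIS-style config matching the trained BLIP-2 alignment module, with the classifier and Vicuna adapter checkpoints referenced in demo.py available locally.

    # placeholder config path; substitute the config shipped with the repository
    python demo.py --cfg-path path/to/blip2_demo_config.yaml

Gradio then serves the chat UI on its default local port; uploading a .png or .jpg X-ray triggers CheXpert label prediction and report generation, and subsequent messages continue a free-chat conversation about the same image.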