--- /dev/null
+++ b/model/lavis/data/ReportDataset.py
@@ -0,0 +1,452 @@
+import dataclasses
+import json
+import os
+import re
+import time
+from enum import Enum, auto
+from pathlib import Path
+from typing import Any, List, Optional
+
+from local_config import PATH_TO_MIMIC_CXR, JAVA_HOME, JAVA_PATH
+
+# Set the Java path before the METEOR scorer (which shells out to Java) is imported.
+os.environ["JAVA_HOME"] = JAVA_HOME
+os.environ["PATH"] = JAVA_PATH + os.environ["PATH"]
+os.environ['GRADIO_TEMP_DIR'] = os.path.join(os.getcwd(), "gradio_tmp")
+
+import numpy as np
+import pandas as pd
+import torch
+from PIL import Image
+from nltk import word_tokenize
+from omegaconf import OmegaConf
+from pycocoevalcap.bleu.bleu import Bleu
+from pycocoevalcap.meteor.meteor import Meteor
+from pycocoevalcap.rouge.rouge import Rouge
+from skimage import io
+from sklearn.metrics import classification_report, accuracy_score
+from torchvision import transforms
+from torchvision.transforms import CenterCrop, Compose, Resize, ToTensor
+
+from model.lavis.processors import BaseProcessor
+from model.lavis.common.registry import registry
+from model.lavis.datasets.builders.base_dataset_builder import BaseDatasetBuilder
+from model.lavis.datasets.datasets.base_dataset import BaseDataset
+from model.lavis.datasets.datasets.caption_datasets import __DisplMixin
+
+
+@registry.register_processor("my_blip_caption")
+class MyBlipCaptionProcessor(BaseProcessor):
+    def __init__(self, prompt="", max_words=50):
+        self.prompt = prompt
+        self.max_words = max_words
+
+    def __call__(self, caption):
+        caption = self.prompt + self.pre_caption(caption)
+
+        return caption
+
+    @classmethod
+    def from_config(cls, cfg=None):
+        if cfg is None:
+            cfg = OmegaConf.create()
+
+        prompt = cfg.get("prompt", "")
+        max_words = cfg.get("max_words", 50)
+
+        return cls(prompt=prompt, max_words=max_words)
+
+    def pre_caption(self, caption):
+        # strip punctuation that tends to confuse the tokenizer
+        caption = re.sub(
+            r"([!\"()*#;~])",
+            " ",
+            caption,
+        )
+        # collapse repeated whitespace
+        caption = re.sub(
+            r"\s{2,}",
+            " ",
+            caption,
+        )
+        caption = caption.rstrip("\n")
+        caption = caption.strip(" ")
+
+        # truncate caption
+        caption_words = caption.split(" ")
+        if len(caption_words) > self.max_words:
+            caption = " ".join(caption_words[: self.max_words])
+
+        return caption
+
+
+class ExpandChannels:
+    """
+    Transforms an image with one channel to an image with three channels by copying
+    pixel intensities along the channel dimension.
+    """
+
+    def __call__(self, data: torch.Tensor) -> torch.Tensor:
+        """
+        :param data: Tensor of shape [1, H, W].
+        :return: Tensor with channel copied three times, shape [3, H, W].
+        """
+        if data.shape[0] != 1:
+            raise ValueError(f"Expected input of shape [1, H, W], found {data.shape}")
+        return torch.repeat_interleave(data, 3, dim=0)
+
+
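+# A minimal sketch of what ExpandChannels does (shapes are illustrative, not
+# taken from the training pipeline):
+#
+#     x = torch.rand(1, 224, 224)   # [1, H, W] grayscale tensor
+#     y = ExpandChannels()(x)       # [3, H, W], same values in every channel
+#     assert torch.equal(y[0], y[1]) and torch.equal(y[1], y[2])
+
+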
+ """ + + transforms = [Resize(resize), CenterCrop(center_crop_size), ToTensor(), ExpandChannels()] + return Compose(transforms) + +class SeparatorStyle(Enum): + """Different separator style.""" + SINGLE = auto() + TWO = auto() + + +@dataclasses.dataclass +class Conversation: + """A class that keeps all conversation history.""" + system: str + roles: List[str] + messages: List[List[str]] + offset: int + sep_style: SeparatorStyle = SeparatorStyle.SINGLE + sep: str = "###" + sep2: str = None + + # Used for gradio server + skip_next: bool = False + conv_id: Any = None + + def get_prompt(self): + if self.sep_style == SeparatorStyle.SINGLE: + ret = self.system + for role, message in self.messages: + if message: + ret += self.sep + " " + role + ": " + message + else: + ret += self.sep + " " + role + ":" + return ret + elif self.sep_style == SeparatorStyle.TWO: + seps = [self.sep, self.sep2] + ret = self.system + seps[0] + for i, (role, message) in enumerate(self.messages): + if message: + ret += role + ": " + message + seps[i % 2] + else: + ret += role + ":" + return ret + else: + raise ValueError(f"Invalid style: {self.sep_style}") + + def append_message(self, role, message): + self.messages.append([role, message]) + + def to_gradio_chatbot(self): + ret = [] + for i, (role, msg) in enumerate(self.messages[self.offset:]): + if i % 2 == 0: + ret.append([msg, None]) + else: + ret[-1][-1] = msg + return ret + + def copy(self): + return Conversation( + system=self.system, + roles=self.roles, + messages=[[x, y] for x, y in self.messages], + offset=self.offset, + sep_style=self.sep_style, + sep=self.sep, + sep2=self.sep2, + conv_id=self.conv_id) + + def dict(self): + return { + "system": self.system, + "roles": self.roles, + "messages": self.messages, + "offset": self.offset, + "sep": self.sep, + "sep2": self.sep2, + "conv_id": self.conv_id, + } + +class MyReportProcessor(): + def __init__(self, prompt="", max_words=50, prompt_neg=""): + self.prompt = prompt + self.max_words = max_words + self.prompt_neg = prompt_neg + + def __call__(self, findings, no_labels=False): + prompt = self.prompt + + if no_labels: + findings = "no common findings" # cannot write which findings as we don't no them + prompt = prompt.format(findings=findings) + + return prompt + + @classmethod + def from_config(cls, cfg=None): + if cfg is None: + cfg = OmegaConf.create() + + prompt = cfg.get("prompt", "") + max_words = cfg.get("max_words", 50) + + return cls(prompt=prompt, max_words=max_words) + + +class MIMIC_CXR_Dataset(BaseDataset, __DisplMixin): + def __init__(self, vis_processor, text_processor, vis_root, split, cfg, ann_paths=[], truncate=None): + """ + vis_root (string): Root directory of images (e.g. 
+class MIMIC_CXR_Dataset(BaseDataset, __DisplMixin):
+    def __init__(self, vis_processor, text_processor, vis_root, split, cfg, ann_paths=[], truncate=None):
+        """
+        vis_root (string): Root directory of images (e.g. coco/images/)
+        ann_paths (list): Paths to the annotation files.
+        """
+        super().__init__(vis_processor, text_processor, vis_root, ann_paths)
+
+        # load the split and report csv files
+        self.split = pd.read_csv(f'{PATH_TO_MIMIC_CXR}/mimic-cxr-jpg/2.0.0/mimic-cxr-2.0.0-split.csv')
+        self.cur_split = split
+        self.reports = pd.read_csv('mimic-cxr/reports_processed/mimic_cxr_sectioned.csv')
+        # drop reports where findings are nan
+        self.reports = self.reports.dropna(subset=['findings'])
+
+        self.use_pred_labels = True
+
+        self.chexpert = pd.read_csv('data/data_files/finding_chexbert_labels.csv')
+        self.chexpert_cols = ["No Finding", "Enlarged Cardiomediastinum",
+                              "Cardiomegaly", "Lung Opacity",
+                              "Lung Lesion", "Edema",
+                              "Consolidation", "Pneumonia",
+                              "Atelectasis", "Pneumothorax",
+                              "Pleural Effusion", "Pleural Other",
+                              "Fracture", "Support Devices"]
+
+        self.custom_epochs_per_epoch = 2 if split == 'train' and cfg.run_cfg.task != "image_text_pretrain_eval" and truncate is None else 1
+        self.current_custom_epoch = 0
+        self.vit_model = cfg.model_cfg['vit_model']
+        self.img_size = cfg.datasets_cfg.mimic_cxr.vis_processor.train.image_size  # should be 224 for coco models, 448 for biovil models
+        if self.vit_model == 'biovil':
+            self.vis_transforms = create_chest_xray_transform_for_inference(512, center_crop_size=self.img_size)
+        if split == 'train':
+            self.vis_augs = transforms.Compose([transforms.RandomAffine(degrees=30, shear=15),
+                                                transforms.ColorJitter(brightness=0.2, contrast=0.2)])
+
+        self.img_ids = {img_id: i for i, img_id in enumerate(self.reports['dicom_id'])}
+        self.id_to_dicom = {v: k for k, v in self.img_ids.items()}
+        self.split_ids = set(self.split.loc[self.split['split'] == split]['dicom_id'])
+
+        # keep all reports whose dicom_id belongs to the requested split
+        self.annotation = self.reports.loc[self.reports['dicom_id'].isin(self.split_ids)]
+        if truncate is not None:
+            self.annotation = self.annotation[:truncate]
+        # work on a copy so the column assignments below do not trigger SettingWithCopyWarning
+        self.annotation = self.annotation.copy()
+
+        self.annotation['findings'] = self.annotation['findings'].apply(lambda x: x.replace('\n', ''))
+
+        # Extract subject_id from Img_Folder (3rd path part); study_id is the note filename without the leading 's'
+        self.annotation['subject_id'] = self.annotation['Img_Folder'].apply(lambda x: int(x.split('/')[2].lstrip('p')))
+        self.annotation['study_id'] = self.annotation['Note_file'].apply(lambda x: int(x.lstrip('s').removesuffix('.txt')))
+
+        # Merge chexpert labels with the annotation dataframe
+        self.annotation = pd.merge(self.annotation, self.chexpert, how='left', on='dicom_id')
+
+        add_findings_in_prompt = cfg.run_cfg.get("add_findings_in_prompt", False)
+        self.prompt = cfg.datasets_cfg.mimic_cxr.text_processor.train.prompt if split == 'train' \
+            else cfg.datasets_cfg.mimic_cxr.text_processor.eval.prompt
+
+        self.text_processor = MyReportProcessor(
+            prompt=self.prompt, max_words=1000)
+
+        self.evaluator = MIMICEvalCap(self.annotation, self.img_ids)
+
+    def set_custom_epoch(self, custom_epoch):
+        self.current_custom_epoch = custom_epoch
+
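+    # Worked example of the custom-epoch indexing used by __getitem__ and
+    # __len__ below (numbers are illustrative): with 10,000 annotations and
+    # custom_epochs_per_epoch=2, __len__ reports 5,000; custom epoch 0 serves
+    # rows 0..4,999 and, after set_custom_epoch(1), custom epoch 1 serves rows
+    # 5,000..9,999, so one real pass over the data spans two loader epochs.
+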
+    def remap_to_uint8(self, array: np.ndarray, percentiles=None) -> np.ndarray:
+        """Remap values in input so the output range is :math:`[0, 255]`.
+
+        Percentiles can be used to specify the range of values to remap.
+        This is useful to discard outliers in the input data.
+
+        :param array: Input array.
+        :param percentiles: Percentiles of the input values that will be mapped to ``0`` and ``255``.
+            Passing ``None`` is equivalent to using percentiles ``(0, 100)`` (but faster).
+        :returns: Array with ``0`` and ``255`` as minimum and maximum values.
+        """
+        array = array.astype(float)
+        if percentiles is not None:
+            len_percentiles = len(percentiles)
+            if len_percentiles != 2:
+                message = (
+                    'The value for percentiles should be a sequence of length 2,'
+                    f' but has length {len_percentiles}'
+                )
+                raise ValueError(message)
+            a, b = percentiles
+            if a >= b:
+                raise ValueError(f'Percentiles must be in ascending order, but a sequence "{percentiles}" was passed')
+            if a < 0 or b > 100:
+                raise ValueError(f'Percentiles must be in the range [0, 100], but a sequence "{percentiles}" was passed')
+            cutoff: np.ndarray = np.percentile(array, percentiles)
+            array = np.clip(array, *cutoff)
+        array -= array.min()
+        array /= array.max()
+        array *= 255
+        return array.astype(np.uint8)
+
+    def load_image(self, path: Path) -> Image.Image:
+        """Load an image from disk.
+
+        The image values are remapped to :math:`[0, 255]` and cast to 8-bit unsigned integers.
+
+        :param path: Path to image.
+        :returns: Image as ``Pillow`` ``Image``.
+        """
+        # Although ITK supports JPEG and PNG, we use Pillow for consistency with older trained models
+        if path.suffix in [".jpg", ".jpeg", ".png"]:
+            image = io.imread(path)
+        else:
+            raise ValueError(f"Image type not supported, filename was: {path}")
+
+        image = self.remap_to_uint8(image)
+        return Image.fromarray(image).convert("L")
+
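+    # Sketch of remap_to_uint8 behaviour (input values are illustrative):
+    #
+    #     arr = np.array([0., 50., 100., 1000.])
+    #     self.remap_to_uint8(arr)           # -> [0, 12, 25, 255]; the outlier dominates
+    #     self.remap_to_uint8(arr, (0, 75))  # clips at the 75th percentile before rescaling
+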
" + "The assistant gives professional, detailed, and polite answers to the user's questions.", + roles=["USER", "ASSISTANT"], + messages=[], + offset=0, + sep_style=SeparatorStyle.TWO, + sep=" ", + sep2="</s>", + ) + conv.append_message(conv.roles[0], input_text) + conv.append_message(conv.roles[1], None) + prompt = conv.get_prompt() + + if "<IMG>" in prompt: + before_img, after_img = prompt.split("<IMG>") + prompt = (before_img, after_img) + + + return { + "image": image, + "text_input": prompt, + "text_output": caption, + "image_id": self.img_ids[ann["dicom_id"]], + # "index": index, + # "string_index": str(index) + } + + def __len__(self): + return len(self.annotation) // self.custom_epochs_per_epoch + + +@registry.register_builder("mimic_cxr") +class MIMIC_CXR_Builder(BaseDatasetBuilder): + train_dataset_cls = MIMIC_CXR_Dataset + eval_dataset_cls = MIMIC_CXR_Dataset + + DATASET_CONFIG_DICT = { + "default": "defaults_report.yaml" + } + + +class MIMICEvalCap: + def __init__(self, gts, img_id_map): + + self.gts = gts + + # invert img_id_map + self.dicom_to_id = img_id_map + self.id_to_dicom = {v: k for k, v in img_id_map.items()} + + print('setting up scorers...') + self.scorers = [ + (Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]), + (Meteor(), "METEOR"), + (Rouge(), "ROUGE_L") + ] + + + def preprocess(self, s): + s = s.replace('\n', '') + s = s.replace('<s>', '') + s = s.replace('</s>', '') + return s + + def evaluate(self, res): + + res = {self.id_to_dicom[elem["image_id"]]: elem["caption"] for elem in res} + res_keys_set = set(res.keys()) + gts = {} + gts_img_id = {} + for _, elem in self.gts.iterrows(): + dicom_id = elem["dicom_id"] + if dicom_id in res_keys_set: + gts[dicom_id] = [elem["findings"]] + gts_img_id[self.dicom_to_id[dicom_id]] = [elem["findings"]] + + # gts = {elem["dicom_id"]: [elem["findings"]] for _, elem in self.gts.iterrows() if elem["dicom_id"] in res.keys()} + # gts_img_id = {self.dicom_to_id[elem["findings"]]: [elem["Note"]] for _, elem in self.gts.iterrows() if elem["dicom_id"] in res.keys()} + assert res.keys() == gts.keys() + # ================================================= + # Pre-process sentences + # ================================================= + print('tokenization...') + for dicom in res.keys(): + pred_text = ' '.join(word_tokenize(self.preprocess(res[dicom]))).lower() + true_text = ' '.join(word_tokenize(self.preprocess(gts[dicom][0]))).lower() + + res[dicom] = [pred_text] + gts[dicom] = [true_text] + + # ================================================= + # Compute scores + # ================================================= + final_scores = {} + for scorer, method in self.scorers: + print('computing %s score...' % (scorer.method())) + score, scores = scorer.compute_score(gts, res) + if type(method) == list: + for sc, scs, m in zip(score, scores, method): + final_scores[m] = sc + #final_scores["elem_wise_" + str(m)] = scs + print("%s: %0.3f" % (m, sc)) + else: + print("%s: %0.3f" % (method, score)) + #final_scores["elem_wise_" + str(method)] = scores + final_scores[method] = score + + final_scores['agg_metrics'] = np.mean(list({k: v for k, v in final_scores.items() if "elem_wise" not in k}.values())) + + return final_scores, gts_img_id