Diff of /data/create_data.py [000000] .. [4abb48]

--- a
+++ b/data/create_data.py
@@ -0,0 +1,432 @@
+import argparse
+import dataclasses
+import json
+import os
+from enum import auto, Enum
+from pathlib import Path
+from typing import List, Any
+import random
+
+import numpy as np
+import pandas as pd
+import torch
+from omegaconf import OmegaConf
+from torch.utils.data import Dataset, DataLoader
+from tqdm import tqdm
+from transformers import AutoTokenizer
+from torch.utils.data.sampler import Sampler
+
+from data.instruct_tasks import create_direct_task_data, create_cp_task_data, create_correction_task_data, create_nle_task_data
+from local_config import VIS_ROOT, PATH_TO_MIMIC_CXR
+from model.lavis.models.blip2_models.modeling_llama_imgemb import LlamaForCausalLM
+
+
+class MyReportProcessor():
+    def __init__(self, prompt="", max_words=50, prompt_neg=""):
+        self.prompt = prompt
+        self.max_words = max_words
+        self.prompt_neg = prompt_neg
+
+    def __call__(self, findings, no_labels=False):
+        prompt = self.prompt
+
+        if no_labels:
+            findings = "no common findings"  # cannot write which findings as we don't no them
+        prompt = prompt.format(findings=findings)
+
+        return prompt
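+
+    # Usage sketch (hypothetical prompt string, for illustration only):
+    # MyReportProcessor(prompt="Findings: {findings}")("cardiomegaly, edema")
+    # returns "Findings: cardiomegaly, edema"; with no_labels=True, the findings
+    # slot is filled with "no common findings" instead.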
+
+    @classmethod
+    def from_config(cls, cfg=None):
+        if cfg is None:
+            cfg = OmegaConf.create()
+
+        prompt = cfg.get("prompt", "")
+        max_words = cfg.get("max_words", 50)
+
+        return cls(prompt=prompt, max_words=max_words)
+
+
+class SeparatorStyle(Enum):
+    """Different separator style."""
+    SINGLE = auto()
+    TWO = auto()
+
+
+@dataclasses.dataclass
+class Conversation:
+    """A class that keeps all conversation history."""
+    system: str
+    roles: List[str]
+    messages: List[List[str]]
+    offset: int
+    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
+    sep: str = "###"
+    sep2: str = None
+
+    # Used for gradio server
+    skip_next: bool = False
+    conv_id: Any = None
+
+    def get_prompt(self):
+        if self.sep_style == SeparatorStyle.SINGLE:
+            ret = self.system
+            for role, message in self.messages:
+                if message:
+                    ret += self.sep + " " + role + ": " + message
+                else:
+                    ret += self.sep + " " + role + ":"
+            return ret
+        elif self.sep_style == SeparatorStyle.TWO:
+            seps = [self.sep, self.sep2]
+            ret = self.system + seps[0]
+            for i, (role, message) in enumerate(self.messages):
+                if message:
+                    ret += role + ": " + message + seps[i % 2]
+                else:
+                    ret += role + ":"
+            return ret
+        else:
+            raise ValueError(f"Invalid style: {self.sep_style}")
+
+    def append_message(self, role, message):
+        self.messages.append([role, message])
+
+    def dict(self):
+        return {
+            "system": self.system,
+            "roles": self.roles,
+            "messages": self.messages,
+            "offset": self.offset,
+            "sep": self.sep,
+            "sep2": self.sep2,
+            "conv_id": self.conv_id,
+        }
+
+
+def create_conv():
+    conv = Conversation(
+        system="A chat between a curious user and an artificial intelligence assistant acting as an experienced radiologist. "
+               "The assistant gives professional, detailed, and polite answers to the user's questions.",
+        roles=["USER", "ASSISTANT"],
+        messages=[],
+        offset=0,
+        sep_style=SeparatorStyle.TWO,
+        sep=" ",
+        sep2="</s>",
+    )
+    return conv
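+
+
+# A minimal sketch (not called anywhere in this file) of how get_prompt() renders
+# a single turn with SeparatorStyle.TWO and the separators chosen in create_conv():
+def _demo_conv_prompt():
+    conv = create_conv()
+    conv.append_message(conv.roles[0], "Describe the findings.")
+    conv.append_message(conv.roles[1], None)
+    # -> "<system text> USER: Describe the findings. ASSISTANT:"
+    return conv.get_prompt()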
+
+
+class MIMIC_Text_Dataset(Dataset):
+    def __init__(self, split, truncate=None, prompt_type="basic", use_indication=False):
+        super().__init__()
+
+        # load csv file
+        self.split = pd.read_csv(
+            f'{PATH_TO_MIMIC_CXR}/mimic-cxr-jpg/2.0.0/mimic-cxr-2.0.0-split.csv')
+        self.reports = pd.read_csv('mimic-cxr/reports_processed/mimic_cxr_sectioned.csv')
+        # drop reports where findings are nan
+        self.reports = self.reports.dropna(subset=['findings'])
+
+        self.img_ids = {img_id: i for i, img_id in enumerate(self.reports['dicom_id'])}
+        self.chexpert = pd.read_csv('data/data_files/finding_chexbert_labels.csv')
+        self.chexpert_cols = ["No Finding", "Enlarged Cardiomediastinum",
+                              "Cardiomegaly", "Lung Opacity",
+                              "Lung Lesion", "Edema",
+                              "Consolidation", "Pneumonia",
+                              "Atelectasis", "Pneumothorax",
+                              "Pleural Effusion", "Pleural Other",
+                              "Fracture", "Support Devices"]
+
+        self.use_indication = use_indication
+
+        self.vis_root = VIS_ROOT
+
+        self.prompt_type = prompt_type
+
+        self.split_ids = set(self.split.loc[self.split['split'] == split]['dicom_id'])
+        self.train_ids = set(self.split.loc[self.split['split'] == 'train']['dicom_id'])
+
+        # get all dicom_ids where "split" is split
+        self.annotation = self.reports.loc[self.reports['dicom_id'].isin(self.split_ids)].copy()  # copy to avoid mutating a view of self.reports below
+        if truncate is not None:
+            self.annotation = self.annotation[:truncate]
+
+        self.annotation['findings'] = self.annotation['findings'].apply(lambda x: x.replace('\n', ''))
+
+        # Extract subject_id from Img_Folder (3rd path component, minus the leading 'p') and study_id from Note_file (drop the leading 's' and the '.txt' extension)
+        self.annotation['subject_id'] = self.annotation['Img_Folder'].apply(lambda x: int(x.split('/')[2].lstrip('p')))
+        self.annotation['study_id'] = self.annotation['Note_file'].apply(lambda x: int(x.split('.')[0].lstrip('s')))
+
+        # Merge chexpert labels with annotation dataframe
+        self.annotation = pd.merge(self.annotation, self.chexpert, how='left', on='dicom_id')
+
+        # for every row add a string of comma-separated positive labels
+        self.annotation['positive_labels'] = self.annotation.apply(lambda x: self.convert_to_finding_labels(x[self.chexpert_cols].values,
+                                                                                                            self.chexpert_cols), axis=1)
+
+        # maybe use transforms from here: ResNet50_Weights.IMAGENET1K_V2.transforms
+        # read prompt from json
+        prompts = json.loads(Path(f"vicuna_prompts.json").read_text(encoding="UTF-8"))
+        self.text_processor = MyReportProcessor(
+            prompt=prompts[prompt_type], max_words=1000,
+            prompt_neg=prompts[prompt_type.replace("matching_examples", "neg_matching_examples")])
+
+    def convert_to_finding_labels(self, chexpert_labels, columns, label=1):
+        # Get indices where value is 1
+        indices = np.where(chexpert_labels == label)
+        # Get the corresponding column names and join them into a string
+        labels = ", ".join([columns[i] for i in indices[0]])
+        return labels
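+    # Example: with the columns above, a label vector with 1 at indices 1 and 13
+    # yields "Enlarged Cardiomediastinum, Support Devices"; NaN entries fail the
+    # equality check in np.where, so unmentioned findings are skipped.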
+
+    def __getitem__(self, index):
+        ann = self.annotation.iloc[index]
+        # if self.use_indication:
+        #     indication = self.indications[study_id]
+        #     if indication == "":
+        #         indication = "Indication not given."
+        caption = ann["findings"].strip()
+        chexpert_labels = ann[self.chexpert_cols].astype(float).values
+        chexpert_label_str = ann["positive_labels"]
+        dicom_id = ann["dicom_id"]
+
+        # no positive findings if every label is nan, 0, or -1 (uncertain)
+        no_labels = np.all((np.isnan(chexpert_labels)) | (chexpert_labels == 0) | (chexpert_labels == -1.))
+        finding_string = chexpert_label_str.lower().strip()
+
+        input_text = self.text_processor(findings=finding_string, no_labels=no_labels)
+
+        # if self.use_indication:
+        #     input_text = "Indication: " + indication + " " + input_text
+
+        # template for vicuna v1.3
+        conv = Conversation(
+            system="A chat between a curious user and an artificial intelligence assistant acting as an experienced radiologist. "
+                   "The assistant gives professional, detailed, and polite answers to the user's questions.",
+            roles=["USER", "ASSISTANT"],
+            messages=[],
+            offset=0,
+            sep_style=SeparatorStyle.TWO,
+            sep=" ",
+            sep2="</s>",
+        )
+        conv.append_message(conv.roles[0], input_text)
+        conv.append_message(conv.roles[1], None)
+        prompt = conv.get_prompt()
+
+        return {
+            "text_input": prompt,
+            "text_target": caption,
+            "ig_label_string": finding_string,
+            "chexpert_labels": chexpert_labels,
+            "chexpert_cols": self.chexpert_cols,
+            "dicom": dicom_id,
+            "img_path": ann["Img_Folder"] + "/" + ann["Img_Filename"],
+        }
+
+    def __len__(self):
+        return len(self.annotation)
+
+
+class SubsetSampler(Sampler):
+    def __init__(self, indices):
+        self.indices = indices
+
+    def __iter__(self):
+        return (self.indices[i] for i in range(len(self.indices)))
+
+    def __len__(self):
+        return len(self.indices)
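+
+# Note: unlike torch.utils.data.SubsetRandomSampler, this sampler yields the given
+# indices in order, so repeated runs traverse the stratified subset identically.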
+
+
+def stratified_sample(df, simulated_epochs=1):
+    # We want to reduce the number of no-finding examples to 1/14th of the dataset. We achieve this by first separating the dataset into two groups: no finding and finding.
+    # a report counts as "no finding" if the No Finding label is positive or if no label is positive at all
+    no_findings_indices = df.annotation[(df.annotation['No Finding'] == 1) | ((df.annotation[df.chexpert_cols] == 1).sum(1) == 0)].index
+    finding_indices = df.annotation.index.difference(no_findings_indices)
+    no_findings_indices = no_findings_indices.tolist()
+    finding_indices = finding_indices.tolist()
+
+    # to drop as few no-finding examples as possible, we replicate the finding examples (simulated epochs) instead of only subsampling the no-finding group
+    finding_indices = finding_indices * simulated_epochs
+    # subsample the no-finding examples so they make up 1/14th of the new dataset
+    new_dataset_size = len(finding_indices) * 14 / 13
+    new_no_finding_count = int(new_dataset_size / 14)
+    # merge considering the new dataset size
+    all_indices = finding_indices + random.sample(no_findings_indices, new_no_finding_count)
+    return all_indices
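+
+# Worked example with hypothetical counts: 10,000 finding reports and
+# simulated_epochs=2 give 20,000 finding indices; the target size is then
+# 20,000 * 14 / 13 ≈ 21,538, so int(21,538 / 14) = 1,538 no-finding examples
+# are drawn, i.e. roughly 1/14th of the merged index list.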
+
+
+def create_report_data_vicuna_specific_stratified(prompt_type):
+    val_dataset = MIMIC_Text_Dataset(split="train", truncate=None, prompt_type=prompt_type)
+    stratified_indices = stratified_sample(val_dataset, simulated_epochs=2)
+    sampler = SubsetSampler(stratified_indices)
+    data_loader = DataLoader(val_dataset, batch_size=200, num_workers=200, sampler=sampler)
+
+    report_jsons = []
+    for batch in tqdm(data_loader):
+        # iterate over batch elements
+        for i in range(len(batch["text_input"])):
+            text_input = batch["text_input"][i]
+            text_target = batch["text_target"][i]
+            dicom = batch["dicom"][i]
+
+            # build an Alpaca-style record for this report
+            reports_json = {
+                "instruction": text_input,
+                "input": "",
+                "output": text_target,
+                "dicom": dicom,
+            }
+            report_jsons.append(reports_json)
+
+    # Save the JSON data to a file
+    with open("data/data_files/mimic_cxr_reports_stratified.json", "w") as f:
+        json.dump(report_jsons, f, ensure_ascii=False, indent=4)
+
+
+'''
+this method saves instruct data jsons for all the different tasks we defined:
+- easy language: EL DONE
+- correction: CO DONE
+- summarization: SU DONE
+- reasoning: RE (based on MIMIC-NLE) DONE
+- region QA: RQA DONE
+- CP binary QA: CPbQA DONE
+- CP all QA: CPaQA DONE
+
+for every report we sample one task and one prompt, and save the report, the question (task), and the answer generated by Vicuna (or taken from the dataset ground truth)
+'''
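+
+# For illustration, each saved record follows the Alpaca-style schema used below:
+# {
+#     "instruction": "<conversation prompt ending in 'ASSISTANT:'>",
+#     "input": "",
+#     "output": "<Vicuna answer or dataset ground truth>",
+#     "dicom": "<dicom_id>"
+# }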
+
+
+def create_report_data_vicuna_instruct_large():
+    lang_model = LlamaForCausalLM.from_pretrained("lmsys/vicuna-13b-v1.3", torch_dtype=torch.float16, device_map='auto', load_in_8bit=False)
+    tokenizer = AutoTokenizer.from_pretrained("lmsys/vicuna-13b-v1.3", use_fast=False, truncation_side="left", padding_side="left")
+    tokenizer.pad_token = tokenizer.unk_token
+
+    val_dataset = MIMIC_Text_Dataset(split="train", truncate=None, prompt_type="img_matching_examples_ig2_noexamples")
+    # split in 6 portions of 1/6th each, randomly
+    split_size = len(val_dataset) // 6
+    remainder = len(val_dataset) % 6
+
+    # correction (CO) data is sampled elsewhere, hence the unused second split
+    val_dataset_EL, _, val_dataset_SU, val_dataset_EX, val_dataset_RQA, val_dataset_CPQA = torch.utils.data.random_split(
+        val_dataset, [split_size + (i < remainder) for i in range(6)])
+
+    # split val_dataset_CPQA in 2
+    split_size = len(val_dataset_CPQA) // 2
+    remainder = len(val_dataset_CPQA) % 2
+    val_dataset_CPbQA, val_dataset_CPaQA = torch.utils.data.random_split(val_dataset_CPQA, [split_size + (i < remainder) for i in range(2)])
+
+    # create output directory
+    os.makedirs("data/large_instruct_data", exist_ok=True)
+
+    # create data
+    create_direct_task_data(lang_model, tokenizer, val_dataset_EL, task_name="EL")
+    create_direct_task_data(lang_model, tokenizer, val_dataset_SU, task_name="SU")
+    create_direct_task_data(lang_model, tokenizer, val_dataset_RQA, task_name="RQA")
+    create_cp_task_data(val_dataset_CPbQA, task_name="CPbQA")
+    create_cp_task_data(val_dataset_CPaQA, task_name="CPaQA")
+
+    create_correction_task_data(lang_model, tokenizer)
+    create_nle_task_data()
+
+
+'''
+fuse the instruct data with the report generation task into one dataset JSON
+'''
+
+
+def fuse_instruct_dataset(prompt_type="img_matching_examples_ig2_noexamples_IMG_findings"):
+    # get report generation data
+    val_dataset = MIMIC_Text_Dataset(split="train", truncate=None, prompt_type=prompt_type)
+    stratified_indices = stratified_sample(val_dataset, simulated_epochs=2)
+    sampler = SubsetSampler(stratified_indices)
+    data_loader = DataLoader(val_dataset, batch_size=200, sampler=sampler, num_workers=200)
+    report_jsons = []
+    for batch in tqdm(data_loader):
+        # iterate over batch elements
+        for i in range(len(batch["text_input"])):
+            text_input = batch["text_input"][i]
+            text_target = batch["text_target"][i]
+            dicom = batch["dicom"][i]
+
+            # build an Alpaca-style record for this report
+            reports_json = {
+                "instruction": text_input,
+                "input": "",
+                "output": text_target,
+                "dicom": dicom,
+            }
+            report_jsons.append(reports_json)
+
+    task_jsons = []
+    with open(f"vicuna_prompts.json", "r") as f:
+        prompts = json.load(f)
+    report_prompt = prompts[prompt_type]
+
+    # get instruct data
+    for task in ["EL", "RE", "CO", "SU", "RQA", "CPbQA", "CPaQA"]:
+        print("Creating data for " + task)
+        with open(f"data/large_instruct_data/instruct_large_{task}.json", "r") as f:
+            task_data = json.load(f)
+
+        for elem in tqdm(task_data):
+            report = elem["gt_report"] if task != "CO" else elem["incorrect_report"]
+
+            conv = create_conv()
+            conv.append_message(conv.roles[0], report_prompt)
+            conv.append_message(conv.roles[1], report)
+            conv.append_message(conv.roles[0], elem["task"])
+            conv.append_message(conv.roles[1], None)
+
+            instruction = conv.get_prompt()
+
+            # get the matching row from val_dataset.annotation via the dicom id
+            orig_elem = val_dataset.annotation[val_dataset.annotation["dicom_id"] == elem["dicom"]].iloc[0]
+
+            if isinstance(orig_elem['positive_labels'], float) and np.isnan(orig_elem['positive_labels']):
+                finding_str = "no common findings"
+            else:
+                finding_str = orig_elem['positive_labels'].lower().strip()
+            instruction = instruction.format(findings=finding_str)
+
+            task_json = {
+                "instruction": instruction,
+                "input": "",
+                "output": elem["output"].lower().strip() if task == "CPaQA" else elem["output"].strip(),
+                "dicom": elem["dicom"],
+            }
+            task_jsons.append(task_json)
+
+    # combine and shuffle report and task jsons
+    combined_jsons = report_jsons + task_jsons
+    random.shuffle(combined_jsons)
+
+    # save to json
+    with open(f"data/data_files/mimic_cxr_instruct_stratified.json", "w") as f:
+        json.dump(combined_jsons, f, indent=4)
+
+
+if __name__ == '__main__':
+    # args parser
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--mode', type=str, default='RG', help='RG or INS')
+    args = parser.parse_args()
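+
+    # Example invocations (assuming the repo root is on PYTHONPATH):
+    #   python -m data.create_data --mode RG   # report-generation training data
+    #   python -m data.create_data --mode INS  # instruct-tuning data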
+
+    ''' Create data to train RaDialog-RG model'''
+    if args.mode == 'RG':
+        create_report_data_vicuna_specific_stratified(prompt_type="img_matching_examples_ig2_noexamples_IMG_findings")
+
+    ''' Create data to train RaDialog-INS model'''
+    if args.mode == 'INS':
+        create_report_data_vicuna_instruct_large()
+        fuse_instruct_dataset()
+
+    # This code is meant for understanding how our instruct dataset is created.
+    # Due to randomness in the sampling and model predictions, a newly generated dataset could be slightly different.
+    # To exactly reproduce our results, please use the instruct dataset we published and use 'fuse_instruct_dataset' to merge with your MIMIC data.
\ No newline at end of file