--- a
+++ b/data/instruct_tasks.py
@@ -0,0 +1,360 @@
+import dataclasses
+import json
+import random
+from enum import Enum, auto
+from pathlib import Path
+from typing import List, Any
+
+import numpy as np
+import pandas as pd
+import torch
+from torch.utils.data import DataLoader, Dataset
+from tqdm import tqdm
+
+from local_config import PATH_TO_MIMIC_NLE
+
+
+class SeparatorStyle(Enum):
+    """Different separator styles."""
+    SINGLE = auto()
+    TWO = auto()
+
+
+@dataclasses.dataclass
+class Conversation:
+    """A class that keeps all conversation history."""
+    system: str
+    roles: List[str]
+    messages: List[List[str]]
+    offset: int
+    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
+    sep: str = "###"
+    sep2: str = None
+
+    # Used for gradio server
+    skip_next: bool = False
+    conv_id: Any = None
+
+    def get_prompt(self):
+        if self.sep_style == SeparatorStyle.SINGLE:
+            ret = self.system
+            for role, message in self.messages:
+                if message:
+                    ret += self.sep + " " + role + ": " + message
+                else:
+                    ret += self.sep + " " + role + ":"
+            return ret
+        elif self.sep_style == SeparatorStyle.TWO:
+            seps = [self.sep, self.sep2]
+            ret = self.system + seps[0]
+            for i, (role, message) in enumerate(self.messages):
+                if message:
+                    ret += role + ": " + message + seps[i % 2]
+                else:
+                    ret += role + ":"
+            return ret
+        else:
+            raise ValueError(f"Invalid style: {self.sep_style}")
+
+    def append_message(self, role, message):
+        self.messages.append([role, message])
+
+    def dict(self):
+        return {
+            "system": self.system,
+            "roles": self.roles,
+            "messages": self.messages,
+            "offset": self.offset,
+            "sep": self.sep,
+            "sep2": self.sep2,
+            "conv_id": self.conv_id,
+        }
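+
+
+# Editor's note (illustrative, not part of the original code): with the TWO separator
+# style configured in create_conv() below, Conversation.get_prompt() renders a
+# Vicuna-style prompt such as
+#   "<system> USER: Report: ... <instruction> ASSISTANT: <answer></s>USER: ..."
+# User turns are closed with sep (" ") and assistant turns with sep2 ("</s>"); the
+# final empty assistant turn leaves the prompt ending in "ASSISTANT:" for generation.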
" + "The assistant gives professional, detailed, and polite answers to the user's questions.", + roles=["USER", "ASSISTANT"], + messages=[], + offset=0, + sep_style=SeparatorStyle.TWO, + sep=" ", + sep2="</s>", + ) + return conv + + +def create_direct_task_data(lang_model, tokenizer, val_dataset, task_name): + prompts = pd.read_csv(f"data/instruct_prompts/{task_name}_prompts.csv")["instruction"].tolist() + data_loader = DataLoader(val_dataset, batch_size=12, shuffle=False, num_workers=0) + report_jsons = [] + print("Dataloader len: ", len(data_loader)) + for _, batch in tqdm(enumerate(data_loader)): + + # Create prompts for every report + # sample batchsize questions from EL_prompts + batch_prompts = random.choices(prompts, k=len(batch["text_input"])) + batch_instructions = [] + for text_target, prompt in zip(batch["text_target"], batch_prompts): + conv = create_conv() + conv.append_message(conv.roles[0], "Report: " + text_target + "\n" + prompt) + conv.append_message(conv.roles[1], None) + batch_instructions.append(conv.get_prompt()) + + inputs = tokenizer.batch_encode_plus(batch_instructions, return_tensors="pt", padding=True) + input_ids = inputs["input_ids"].to(torch.device("cuda")) + + # generate answers with no-lora vicuna + generation_output = lang_model.generate( + input_ids=input_ids, + dicom=None, + return_dict_in_generate=True, + output_scores=True, + max_new_tokens=256 + ) + preds = tokenizer.batch_decode(generation_output.sequences, skip_special_tokens=True) + preds = [p.split("ASSISTANT:")[1] for idx, p in enumerate(preds)] + + # iterate over batch elements + for i in range(len(batch["text_input"])): + text_target = batch["text_target"][i] # GT report + task_prompt = batch_prompts[i] + task_instruction = batch_instructions[i] + answer = preds[i] + dicom = batch["dicom"][i] + + # sample random prompt for every report + reports_json = { + "gt_report": text_target, + "task": task_prompt, + "instruction": task_instruction, + "input": "", + "output": answer, + "dicom": dicom, + "task_type": task_name + } + report_jsons.append(reports_json) + + # save + with open(f"data/large_instruct_data/instruct_large_{task_name}.json", "w") as f: + json.dump(report_jsons, f, ensure_ascii=False, indent=4) + + +def create_cp_task_data(val_dataset, task_name): + prompts = pd.read_csv(f"data/instruct_prompts/{task_name}_prompts.csv")["instruction"].tolist() + data_loader = DataLoader(val_dataset, batch_size=200, shuffle=False, num_workers=200) + report_jsons = [] + for _, batch in tqdm(enumerate(data_loader)): + + # Create prompts for every report + # sample batchsize questions from EL_prompts + batch_prompts = random.choices(prompts, k=len(batch["text_input"])) + + # iterate over batch elements + for i in range(len(batch["text_input"])): + text_target = batch["text_target"][i] # GT report + task_prompt = batch_prompts[i] + cp_indices = np.where(batch["chexpert_labels"][i] == 1.) 
+            cp_findings = [val_dataset.dataset.dataset.chexpert_cols[i] for i in cp_indices[0]]
+
+            if task_name == "CPbQA":  # binary QA
+                if "No Finding" in cp_findings:
+                    cp_findings.remove("No Finding")
+                # with 60% probability sample a present finding from cp_findings (answer 'yes'),
+                # otherwise sample an absent finding from chexpert_cols - cp_findings (answer 'no')
+                if random.random() < 0.6 and len(cp_findings) > 0:
+                    finding = random.choice(cp_findings)  # answer: yes
+                    answer = 'yes'
+                else:
+                    finding = random.choice(list(set(val_dataset.dataset.dataset.chexpert_cols[1:]) - set(cp_findings)))  # answer: no
+                    answer = 'no'
+                task_prompt = task_prompt.replace("<X>", finding)
+
+            elif task_name == "CPaQA":  # give all findings
+                answer = ', '.join(cp_findings)
+
+            dicom = batch["dicom"][i]
+
+            # sample random prompt for every report
+            reports_json = {
+                "gt_report": text_target,
+                "task": task_prompt,
+                "input": "",
+                "output": answer,
+                "dicom": dicom,
+                "task_type": task_name
+            }
+            report_jsons.append(reports_json)
+
+    # save
+    with open(f"data/large_instruct_data/instruct_large_{task_name}.json", "w") as f:
+        json.dump(report_jsons, f, ensure_ascii=False, indent=4)
+
+
+class CorrectionDataset(Dataset):
+    def __init__(self, data):
+        self.data = data
+
+    def __len__(self):
+        return len(self.data)
+
+    def __getitem__(self, idx):
+        sample = self.data[idx]
+
+        fp = sample["fp"]
+        fn = sample["fn"]
+        fp_str = ', '.join(fp)
+        fp_str = fp_str.rsplit(', ', 1)
+        fp_str = ' and '.join(fp_str)
+        fn_str = ', '.join(fn)
+        fn_str = fn_str.rsplit(', ', 1)
+        fn_str = ' and '.join(fn_str)
+
+        gt_report = sample["gt_report"]
+        pred_report = sample["pred_report"]
+        dicom = sample["dicom"]
+        return {'gt_report': gt_report, 'pred_report': pred_report, 'fp': fp_str, 'fn': fn_str, 'dicom': dicom}
+
+
+def create_correction_task_data(lang_model, tokenizer):
+    # load correction json
+    with open("data/instruct_prompts/instruct_task_correction_preds.json") as f:
+        correction_preds = json.load(f)
+
+    # create pytorch dataset from json
+    correction_dataset = CorrectionDataset(correction_preds)
+    data_loader = DataLoader(correction_dataset, batch_size=12, shuffle=False, num_workers=12)
+
+    prompts_both = pd.read_csv(f"data/instruct_prompts/CO_both_prompts.csv")["instruction"].tolist()
+    prompts_add = pd.read_csv(f"data/instruct_prompts/CO_add_prompts.csv")["instruction"].tolist()
+    prompts_rem = pd.read_csv(f"data/instruct_prompts/CO_rem_prompts.csv")["instruction"].tolist()
+    report_jsons = []
+    for _, batch in tqdm(enumerate(data_loader)):
+        # use a very clear, fixed prompt for data generation -> in training use random prompts
+
+        fixed_batch_prompts = []
+        for fp, fn in zip(batch["fp"], batch["fn"]):
+            fixed_corr_prompt = "Please provide an adapted report. "
+            if fp != "":
+                fixed_corr_prompt += f"Do not mention {fp}. "
+            if fn != "":
+                fixed_corr_prompt += f"Mention {fn}. "
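+            # Illustrative example (editor's addition): with fp = "Edema" and fn = "Pleural Effusion",
+            # the fixed prompt becomes
+            # "Please provide an adapted report. Do not mention Edema. Mention Pleural Effusion."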
" + + if fp == "" and fn == "": + fixed_corr_prompt = "NOCHANGE" + fixed_batch_prompts.append(fixed_corr_prompt.strip()) + + batch_prompts = [] + for fp, fn in zip(batch["fp"], batch["fn"]): + if fp == "" and fn == "": + batch_prompts.append("NOCHANGE") + elif fp == "": + batch_prompts.append(random.choice(prompts_add).replace("<add>", fn)) + elif fn == "": + batch_prompts.append(random.choice(prompts_rem).replace("<rem>", fp)) + else: + batch_prompts.append(random.choice(prompts_both).replace("<add>", fn).replace("<rem>", fp)) + + batch_instructions = [] + for pred_report, prompt in zip(batch["pred_report"], fixed_batch_prompts): + conv = create_conv() + conv.append_message(conv.roles[0], "Please write a radiology report for the given x-ray.") + conv.append_message(conv.roles[1], pred_report) + conv.append_message(conv.roles[0], prompt) + conv.append_message(conv.roles[1], None) + batch_instructions.append(conv.get_prompt()) + + inputs = tokenizer.batch_encode_plus(batch_instructions, return_tensors="pt", padding=True) + input_ids = inputs["input_ids"].to(torch.device("cuda")) + + # generate answers with no-lora vicuna + generation_output = lang_model.generate( + input_ids=input_ids, + dicom=None, + return_dict_in_generate=True, + output_scores=True, + max_new_tokens=256 + ) + preds = tokenizer.batch_decode(generation_output.sequences, skip_special_tokens=True) + preds = [p.split("ASSISTANT:")[-1].strip() for idx, p in enumerate(preds)] + + # iterate over batch elements + for i in range(len(batch["pred_report"])): + gt_report = batch["gt_report"][i] # GT report + incorrect_report = batch["pred_report"][i] # predicted report that will be corrected + task_prompt = batch_prompts[i] + task_instruction = batch_instructions[i] + answer = preds[i] + dicom = batch["dicom"][i] + + if task_prompt == "NOCHANGE": + continue # we don't want to train for correction on already correct reports + # sample random prompt for every report + reports_json = { + "gt_report": gt_report, + "incorrect_report": incorrect_report, + "task": task_prompt, + "instruction": task_instruction, + "input": "", + "output": answer, + "dicom": dicom, + "task_type": 'CO' + } + report_jsons.append(reports_json) + + # save + with open(f"data/large_instruct_data/instruct_large_CO.json", "w") as f: + json.dump(report_jsons, f, ensure_ascii=False, indent=4) + + +def create_nle_task_data(): + MIMIC_DIAGNOSISLIST = ['Atelectasis', 'Consolidation', 'Edema', 'Enlarged Cardiomediastinum', 'Lung Lesion', 'Lung Opacity', 'Pleural Effusion', + 'Pleural Other', 'Pneumonia', 'Pneumothorax'] + # load mimic_nle json + mimic_nle = [] + with open(f'{PATH_TO_MIMIC_NLE}/mimic-nle/mimic-nle-train.json', 'r') as f: + for line in f: + obj = json.loads(line) + mimic_nle.append(obj) + + prompts = pd.read_csv(f"data/instruct_prompts/RE_prompts.csv")["instruction"].tolist() + report_jsons = [] + reports = pd.read_csv('mimic-cxr/reports_processed/mimic_cxr_sectioned.csv') + reports = reports.dropna(subset=['findings']) + reports['findings'] = reports['findings'].apply(lambda x: x.replace('\n', '')) + + for sample in tqdm(mimic_nle): + report_id = sample["report_ID"] + gt_report = reports[reports["Note_file"] == f"{report_id}.txt"]["findings"].tolist() + if len(gt_report) == 0: # report did have no findings section + continue + gt_report = gt_report[0] + + nle = sample['nle'] + if nle not in gt_report: # sort out samples that reference the impression instead of the findings section + continue + + dicom = reports[reports["Note_file"] == 
f"{report_id}.txt"]["dicom_id"].tolist()[0] + task_prompt = random.choice(prompts) + + diagnoses = [d for idx, d in enumerate(MIMIC_DIAGNOSISLIST) if sample["diagnosis_label"][idx] == 1] + diagnoses_string = ", ".join(diagnoses) + diagnoses_string = diagnoses_string.rsplit(', ', 1) + diagnoses_string = ' and '.join(diagnoses_string) + task_prompt = task_prompt.replace("<X>", diagnoses_string) + + # sample random prompt for every report + reports_json = { + "gt_report": gt_report, + "task": task_prompt, + "input": "", + "output": sample['nle'], + "dicom": dicom, + "task_type": 'RE' + } + report_jsons.append(reports_json) + + # save + print(len(report_jsons)) + with open(f"data/large_instruct_data/instruct_large_RE.json", "w") as f: + json.dump(report_jsons, f, ensure_ascii=False, indent=4)