Diff of /data/instruct_tasks.py [000000] .. [4abb48]

import dataclasses
import json
import random
from enum import Enum, auto
from pathlib import Path
from typing import Any, List, Optional

import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

from local_config import PATH_TO_MIMIC_NLE


class SeparatorStyle(Enum):
    """Different separator styles."""
    SINGLE = auto()
    TWO = auto()


@dataclasses.dataclass
class Conversation:
    """A class that keeps all conversation history."""
    system: str
    roles: List[str]
    messages: List[List[str]]
    offset: int
    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
    sep: str = "###"
    sep2: Optional[str] = None

    # Used for the gradio server
    skip_next: bool = False
    conv_id: Any = None

    def get_prompt(self):
        if self.sep_style == SeparatorStyle.SINGLE:
            ret = self.system
            for role, message in self.messages:
                if message:
                    ret += self.sep + " " + role + ": " + message
                else:
                    ret += self.sep + " " + role + ":"
            return ret
        elif self.sep_style == SeparatorStyle.TWO:
            seps = [self.sep, self.sep2]
            ret = self.system + seps[0]
            for i, (role, message) in enumerate(self.messages):
                if message:
                    ret += role + ": " + message + seps[i % 2]
                else:
                    ret += role + ":"
            return ret
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

    def append_message(self, role, message):
        self.messages.append([role, message])

    def dict(self):
        return {
            "system": self.system,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
            "sep": self.sep,
            "sep2": self.sep2,
            "conv_id": self.conv_id,
        }


def create_conv():
    conv = Conversation(
        system="A chat between a curious user and an artificial intelligence assistant acting as an experienced radiologist. "
               "The assistant gives professional, detailed, and polite answers to the user's questions.",
        roles=["USER", "ASSISTANT"],
        messages=[],
        offset=0,
        sep_style=SeparatorStyle.TWO,
        sep=" ",
        sep2="</s>",
    )
    return conv


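# Illustrative sketch (not called by the data pipeline): shows how a one-turn
# exchange renders with SeparatorStyle.TWO, sep=" " and sep2="</s>".
def _demo_prompt_format():
    conv = create_conv()
    conv.append_message(conv.roles[0], "Report: ...\nExplain the findings.")
    conv.append_message(conv.roles[1], None)
    # get_prompt() returns:
    # "<system prompt> USER: Report: ...\nExplain the findings. ASSISTANT:"
    return conv.get_prompt()

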
def create_direct_task_data(lang_model, tokenizer, val_dataset, task_name):
    prompts = pd.read_csv(f"data/instruct_prompts/{task_name}_prompts.csv")["instruction"].tolist()
    data_loader = DataLoader(val_dataset, batch_size=12, shuffle=False, num_workers=0)
    report_jsons = []
    print("Dataloader len: ", len(data_loader))
    for batch in tqdm(data_loader):

        # sample one task prompt per report in the batch
        batch_prompts = random.choices(prompts, k=len(batch["text_input"]))
        batch_instructions = []
        for text_target, prompt in zip(batch["text_target"], batch_prompts):
            conv = create_conv()
            conv.append_message(conv.roles[0], "Report: " + text_target + "\n" + prompt)
            conv.append_message(conv.roles[1], None)
            batch_instructions.append(conv.get_prompt())

        inputs = tokenizer.batch_encode_plus(batch_instructions, return_tensors="pt", padding=True)
        input_ids = inputs["input_ids"].to(torch.device("cuda"))

        # generate answers with the plain (no-LoRA) Vicuna model
        generation_output = lang_model.generate(
            input_ids=input_ids,
            dicom=None,
            return_dict_in_generate=True,
            output_scores=True,
            max_new_tokens=256
        )
        preds = tokenizer.batch_decode(generation_output.sequences, skip_special_tokens=True)
        preds = [p.split("ASSISTANT:")[1] for p in preds]

        # iterate over batch elements
        for i in range(len(batch["text_input"])):
            text_target = batch["text_target"][i]  # GT report
            task_prompt = batch_prompts[i]
            task_instruction = batch_instructions[i]
            answer = preds[i]
            dicom = batch["dicom"][i]

            # assemble the JSON record for this sample
            reports_json = {
                "gt_report": text_target,
                "task": task_prompt,
                "instruction": task_instruction,
                "input": "",
                "output": answer,
                "dicom": dicom,
                "task_type": task_name
            }
            report_jsons.append(reports_json)

    # save
    with open(f"data/large_instruct_data/instruct_large_{task_name}.json", "w") as f:
        json.dump(report_jsons, f, ensure_ascii=False, indent=4)


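# Each record written above has the following shape (values illustrative):
# {
#     "gt_report":   "<ground-truth findings section>",
#     "task":        "<sampled task prompt>",
#     "instruction": "<full conversation prompt fed to the model>",
#     "input":       "",
#     "output":      "<model-generated answer>",
#     "dicom":       "<dicom id>",
#     "task_type":   "<task_name, e.g. 'EL'>"
# }

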
def create_cp_task_data(val_dataset, task_name):
    prompts = pd.read_csv(f"data/instruct_prompts/{task_name}_prompts.csv")["instruction"].tolist()
    data_loader = DataLoader(val_dataset, batch_size=200, shuffle=False, num_workers=200)
    report_jsons = []
    for batch in tqdm(data_loader):

        # sample one task prompt per report in the batch
        batch_prompts = random.choices(prompts, k=len(batch["text_input"]))

        # iterate over batch elements
        for i in range(len(batch["text_input"])):
            text_target = batch["text_target"][i]  # GT report
            task_prompt = batch_prompts[i]
            cp_indices = np.where(batch["chexpert_labels"][i] == 1.)
            cp_findings = [val_dataset.dataset.dataset.chexpert_cols[j] for j in cp_indices[0]]

            if task_name == "CPbQA":  # binary QA
                if "No Finding" in cp_findings:
                    cp_findings.remove("No Finding")
                # with 60% probability sample a positive finding from cp_findings (answer "yes"),
                # otherwise sample from the remaining CheXpert classes (answer "no")
                if random.random() < 0.6 and len(cp_findings) > 0:
                    finding = random.choice(cp_findings)
                    answer = 'yes'
                else:
                    finding = random.choice(list(set(val_dataset.dataset.dataset.chexpert_cols[1:]) - set(cp_findings)))
                    answer = 'no'
                task_prompt = task_prompt.replace("<X>", finding)

            elif task_name == "CPaQA":  # list all findings
                answer = ', '.join(cp_findings)

            dicom = batch["dicom"][i]

            # assemble the JSON record for this sample
            reports_json = {
                "gt_report": text_target,
                "task": task_prompt,
                "input": "",
                "output": answer,
                "dicom": dicom,
                "task_type": task_name
            }
            report_jsons.append(reports_json)

    # save
    with open(f"data/large_instruct_data/instruct_large_{task_name}.json", "w") as f:
        json.dump(report_jsons, f, ensure_ascii=False, indent=4)


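# Example (illustrative): for CPbQA, a prompt template such as
#   "Does the image show signs of <X>?"
# becomes "Does the image show signs of Pleural Effusion?" after the <X>
# substitution above, with output "yes" or "no".

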
class CorrectionDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]

        fp = sample["fp"]  # false-positive findings in the predicted report
        fn = sample["fn"]  # false-negative (missed) findings
        # join findings into a natural-language enumeration: "a, b and c"
        fp_str = ', '.join(fp)
        fp_str = fp_str.rsplit(', ', 1)
        fp_str = ' and '.join(fp_str)
        fn_str = ', '.join(fn)
        fn_str = fn_str.rsplit(', ', 1)
        fn_str = ' and '.join(fn_str)

        gt_report = sample["gt_report"]
        pred_report = sample["pred_report"]
        dicom = sample["dicom"]
        return {'gt_report': gt_report, 'pred_report': pred_report, 'fp': fp_str, 'fn': fn_str, 'dicom': dicom}


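# Example (illustrative): ["Edema", "Pneumonia", "Atelectasis"] is rendered as
# "Edema, Pneumonia and Atelectasis" by the join/rsplit logic above.

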
def create_correction_task_data(lang_model, tokenizer):
    # load correction json
    with open("data/instruct_prompts/instruct_task_correction_preds.json") as f:
        correction_preds = json.load(f)

    # create a pytorch dataset from the json
    correction_dataset = CorrectionDataset(correction_preds)
    data_loader = DataLoader(correction_dataset, batch_size=12, shuffle=False, num_workers=12)

    prompts_both = pd.read_csv("data/instruct_prompts/CO_both_prompts.csv")["instruction"].tolist()
    prompts_add = pd.read_csv("data/instruct_prompts/CO_add_prompts.csv")["instruction"].tolist()
    prompts_rem = pd.read_csv("data/instruct_prompts/CO_rem_prompts.csv")["instruction"].tolist()
    report_jsons = []
    for batch in tqdm(data_loader):
        # use a very clear, fixed prompt for data generation; random prompts are used in training

        fixed_batch_prompts = []
        for fp, fn in zip(batch["fp"], batch["fn"]):
            fixed_corr_prompt = "Please provide an adapted report. "
            if fp != "":
                fixed_corr_prompt += f"Do not mention {fp}. "
            if fn != "":
                fixed_corr_prompt += f"Mention {fn}. "

            if fp == "" and fn == "":
                fixed_corr_prompt = "NOCHANGE"
            fixed_batch_prompts.append(fixed_corr_prompt.strip())

        batch_prompts = []
        for fp, fn in zip(batch["fp"], batch["fn"]):
            if fp == "" and fn == "":
                batch_prompts.append("NOCHANGE")
            elif fp == "":
                batch_prompts.append(random.choice(prompts_add).replace("<add>", fn))
            elif fn == "":
                batch_prompts.append(random.choice(prompts_rem).replace("<rem>", fp))
            else:
                batch_prompts.append(random.choice(prompts_both).replace("<add>", fn).replace("<rem>", fp))

        batch_instructions = []
        for pred_report, prompt in zip(batch["pred_report"], fixed_batch_prompts):
            conv = create_conv()
            conv.append_message(conv.roles[0], "Please write a radiology report for the given x-ray.")
            conv.append_message(conv.roles[1], pred_report)
            conv.append_message(conv.roles[0], prompt)
            conv.append_message(conv.roles[1], None)
            batch_instructions.append(conv.get_prompt())

        inputs = tokenizer.batch_encode_plus(batch_instructions, return_tensors="pt", padding=True)
        input_ids = inputs["input_ids"].to(torch.device("cuda"))

        # generate answers with the plain (no-LoRA) Vicuna model
        generation_output = lang_model.generate(
            input_ids=input_ids,
            dicom=None,
            return_dict_in_generate=True,
            output_scores=True,
            max_new_tokens=256
        )
        preds = tokenizer.batch_decode(generation_output.sequences, skip_special_tokens=True)
        preds = [p.split("ASSISTANT:")[-1].strip() for p in preds]

        # iterate over batch elements
        for i in range(len(batch["pred_report"])):
            gt_report = batch["gt_report"][i]  # GT report
            incorrect_report = batch["pred_report"][i]  # predicted report that will be corrected
            task_prompt = batch_prompts[i]
            task_instruction = batch_instructions[i]
            answer = preds[i]
            dicom = batch["dicom"][i]

            if task_prompt == "NOCHANGE":
                continue  # we don't want to train for correction on already correct reports
            # assemble the JSON record for this sample
            reports_json = {
                "gt_report": gt_report,
                "incorrect_report": incorrect_report,
                "task": task_prompt,
                "instruction": task_instruction,
                "input": "",
                "output": answer,
                "dicom": dicom,
                "task_type": 'CO'
            }
            report_jsons.append(reports_json)

    # save
    with open("data/large_instruct_data/instruct_large_CO.json", "w") as f:
        json.dump(report_jsons, f, ensure_ascii=False, indent=4)


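# Example (illustrative): with fp = "Edema" and fn = "Pneumothorax and Pleural Effusion",
# the fixed prompt above becomes
#   "Please provide an adapted report. Do not mention Edema. Mention Pneumothorax and Pleural Effusion."

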
def create_nle_task_data():
    MIMIC_DIAGNOSISLIST = ['Atelectasis', 'Consolidation', 'Edema', 'Enlarged Cardiomediastinum', 'Lung Lesion', 'Lung Opacity', 'Pleural Effusion',
                           'Pleural Other', 'Pneumonia', 'Pneumothorax']
    # load the mimic-nle json (one json object per line)
    mimic_nle = []
    with open(f'{PATH_TO_MIMIC_NLE}/mimic-nle/mimic-nle-train.json', 'r') as f:
        for line in f:
            obj = json.loads(line)
            mimic_nle.append(obj)

    prompts = pd.read_csv("data/instruct_prompts/RE_prompts.csv")["instruction"].tolist()
    report_jsons = []
    reports = pd.read_csv('mimic-cxr/reports_processed/mimic_cxr_sectioned.csv')
    reports = reports.dropna(subset=['findings'])
    reports['findings'] = reports['findings'].apply(lambda x: x.replace('\n', ''))

    for sample in tqdm(mimic_nle):
        report_id = sample["report_ID"]
        gt_report = reports[reports["Note_file"] == f"{report_id}.txt"]["findings"].tolist()
        if len(gt_report) == 0:  # report has no findings section
            continue
        gt_report = gt_report[0]

        nle = sample['nle']
        if nle not in gt_report:  # skip samples that reference the impression instead of the findings section
            continue

        dicom = reports[reports["Note_file"] == f"{report_id}.txt"]["dicom_id"].tolist()[0]
        task_prompt = random.choice(prompts)

        diagnoses = [d for idx, d in enumerate(MIMIC_DIAGNOSISLIST) if sample["diagnosis_label"][idx] == 1]
        # join diagnoses into a natural-language enumeration: "a, b and c"
        diagnoses_string = ", ".join(diagnoses)
        diagnoses_string = diagnoses_string.rsplit(', ', 1)
        diagnoses_string = ' and '.join(diagnoses_string)
        task_prompt = task_prompt.replace("<X>", diagnoses_string)

        # assemble the JSON record for this sample
        reports_json = {
            "gt_report": gt_report,
            "task": task_prompt,
            "input": "",
            "output": sample['nle'],
            "dicom": dicom,
            "task_type": 'RE'
        }
        report_jsons.append(reports_json)

    # save
    print(len(report_jsons))
    with open("data/large_instruct_data/instruct_large_RE.json", "w") as f:
        json.dump(report_jsons, f, ensure_ascii=False, indent=4)
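

# Hypothetical driver (sketch only): the real entry point lives elsewhere in
# the repo; `load_vicuna()` and `build_val_dataset()` are placeholder names.
# if __name__ == "__main__":
#     lang_model, tokenizer = load_vicuna()
#     val_dataset = build_val_dataset()
#     create_direct_task_data(lang_model, tokenizer, val_dataset, "EL")
#     create_cp_task_data(val_dataset, "CPbQA")
#     create_cp_task_data(val_dataset, "CPaQA")
#     create_correction_task_data(lang_model, tokenizer)
#     create_nle_task_data()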