Diff of /data/create_data.py [000000] .. [4abb48]

import argparse
import dataclasses
import json
import os
from enum import auto, Enum
from pathlib import Path
from typing import List, Any
import random

import numpy as np
import pandas as pd
import torch
from omegaconf import OmegaConf
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from transformers import AutoTokenizer
from torch.utils.data.sampler import Sampler

from data.instruct_tasks import create_direct_task_data, create_cp_task_data, create_correction_task_data, create_nle_task_data
from local_config import VIS_ROOT, PATH_TO_MIMIC_CXR
from model.lavis.models.blip2_models.modeling_llama_imgemb import LlamaForCausalLM


class MyReportProcessor:
    def __init__(self, prompt="", max_words=50, prompt_neg=""):
        self.prompt = prompt
        self.max_words = max_words
        self.prompt_neg = prompt_neg

    def __call__(self, findings, no_labels=False):
        prompt = self.prompt

        if no_labels:
            findings = "no common findings"  # cannot state which findings, as we don't know them
        prompt = prompt.format(findings=findings)

        return prompt

    @classmethod
    def from_config(cls, cfg=None):
        if cfg is None:
            cfg = OmegaConf.create()

        prompt = cfg.get("prompt", "")
        max_words = cfg.get("max_words", 50)

        return cls(prompt=prompt, max_words=max_words)


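# Illustrative usage (the template below is hypothetical; the real templates are loaded
# from vicuna_prompts.json):
#   proc = MyReportProcessor(prompt="Findings: {findings}. Write the report.")
#   proc("cardiomegaly, edema")  -> "Findings: cardiomegaly, edema. Write the report."
#   proc("", no_labels=True)     -> "Findings: no common findings. Write the report."

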
class SeparatorStyle(Enum):
    """Different separator styles."""
    SINGLE = auto()
    TWO = auto()


@dataclasses.dataclass
class Conversation:
    """A class that keeps all conversation history."""
    system: str
    roles: List[str]
    messages: List[List[str]]
    offset: int
    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
    sep: str = "###"
    sep2: str = None

    # Used for gradio server
    skip_next: bool = False
    conv_id: Any = None

    def get_prompt(self):
        if self.sep_style == SeparatorStyle.SINGLE:
            ret = self.system
            for role, message in self.messages:
                if message:
                    ret += self.sep + " " + role + ": " + message
                else:
                    ret += self.sep + " " + role + ":"
            return ret
        elif self.sep_style == SeparatorStyle.TWO:
            seps = [self.sep, self.sep2]
            ret = self.system + seps[0]
            for i, (role, message) in enumerate(self.messages):
                if message:
                    ret += role + ": " + message + seps[i % 2]
                else:
                    ret += role + ":"
            return ret
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

    def append_message(self, role, message):
        self.messages.append([role, message])

    def dict(self):
        return {
            "system": self.system,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
            "sep": self.sep,
            "sep2": self.sep2,
            "conv_id": self.conv_id,
        }


def create_conv():
108
    conv = Conversation(
109
        system="A chat between a curious user and an artificial intelligence assistant acting as an experienced radiologist. "
110
               "The assistant gives professional, detailed, and polite answers to the user's questions.",
111
        roles=["USER", "ASSISTANT"],
112
        messages=[],
113
        offset=0,
114
        sep_style=SeparatorStyle.TWO,
115
        sep=" ",
116
        sep2="</s>",
117
    )
118
    return conv
119
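# Sketch of how SeparatorStyle.TWO renders (sep=" ", sep2="</s>"); the message text is made up:
#   conv = create_conv()
#   conv.append_message("USER", "Please write a report.")
#   conv.append_message("ASSISTANT", None)
#   conv.get_prompt()
#   -> "<system text> USER: Please write a report. ASSISTANT:"
# The trailing "ASSISTANT:" leaves the completion slot open for the model.

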
class MIMIC_Text_Dataset(Dataset):
    def __init__(self, split, truncate=None, prompt_type="basic", use_indication=False):
        super().__init__()

        # load csv files
        self.split = pd.read_csv(
            f'{PATH_TO_MIMIC_CXR}/mimic-cxr-jpg/2.0.0/mimic-cxr-2.0.0-split.csv')
        self.reports = pd.read_csv('mimic-cxr/reports_processed/mimic_cxr_sectioned.csv')
        # drop reports where findings are nan
        self.reports = self.reports.dropna(subset=['findings'])

        self.img_ids = {img_id: i for i, img_id in enumerate(self.reports['dicom_id'])}
        self.chexpert = pd.read_csv('data/data_files/finding_chexbert_labels.csv')
        self.chexpert_cols = ["No Finding", "Enlarged Cardiomediastinum",
                              "Cardiomegaly", "Lung Opacity",
                              "Lung Lesion", "Edema",
                              "Consolidation", "Pneumonia",
                              "Atelectasis", "Pneumothorax",
                              "Pleural Effusion", "Pleural Other",
                              "Fracture", "Support Devices"]

        self.use_indication = use_indication

        self.vis_root = VIS_ROOT

        self.prompt_type = prompt_type

        self.split_ids = set(self.split.loc[self.split['split'] == split]['dicom_id'])
        self.train_ids = set(self.split.loc[self.split['split'] == 'train']['dicom_id'])

        # keep only the reports whose dicom_id belongs to the requested split
        self.annotation = self.reports.loc[self.reports['dicom_id'].isin(self.split_ids)]
        if truncate is not None:
            self.annotation = self.annotation[:truncate]
        self.annotation = self.annotation.copy()  # avoid SettingWithCopyWarning on the column assignments below

        self.annotation['findings'] = self.annotation['findings'].apply(lambda x: x.replace('\n', ''))

        # Extract subject_id from Img_Folder (3rd path component, leading 'p' stripped) and
        # study_id from the note filename (leading 's' and '.txt' suffix stripped)
        self.annotation['subject_id'] = self.annotation['Img_Folder'].apply(lambda x: int(x.split('/')[2].lstrip('p')))
        self.annotation['study_id'] = self.annotation['Note_file'].apply(lambda x: int(x.replace('.txt', '').lstrip('s')))

        # Merge chexpert labels with annotation dataframe
        self.annotation = pd.merge(self.annotation, self.chexpert, how='left', left_on=['dicom_id'],
                                   right_on=['dicom_id'])

        # for every row add a string of comma-separated positive labels
        self.annotation['positive_labels'] = self.annotation.apply(
            lambda x: self.convert_to_finding_labels(x[self.chexpert_cols].values, self.chexpert_cols), axis=1)

        # maybe use transforms from here: ResNet50_Weights.IMAGENET1K_V2.transforms
        # read prompt templates from json
        prompts = json.loads(Path("vicuna_prompts.json").read_text(encoding="UTF-8"))
        self.text_processor = MyReportProcessor(
            prompt=prompts[prompt_type], max_words=1000,
            prompt_neg=prompts[prompt_type.replace("matching_examples", "neg_matching_examples")])

    def convert_to_finding_labels(self, chexpert_labels, columns, label=1):
        # get indices where the value equals `label` (1 = positive, -1 = uncertain)
        indices = np.where(chexpert_labels == label)
        # get the corresponding column names and join them into a string
        labels = ", ".join([columns[i] for i in indices[0]])
        return labels

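    # Illustrative: for row values [1, 0, 1, 0, ...] over chexpert_cols this returns
    # "No Finding, Cardiomegaly"; calling with label=-1 would collect the uncertain findings instead.
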
    def __getitem__(self, index):
        ann = self.annotation.iloc[index]
        # if self.use_indication:
        #     indication = self.indications[study_id]
        #     if indication == "":
        #         indication = "Indication not given."
        caption = ann["findings"].strip()
        chexpert_labels = ann[self.chexpert_cols].astype(float).values
        chexpert_label_str = ann["positive_labels"]
        dicom_id = ann["dicom_id"]

        # no labels if every column is nan, 0 (negative) or -1 (uncertain)
        no_labels = np.all((np.isnan(chexpert_labels)) | (chexpert_labels == 0) | (chexpert_labels == -1.))
        finding_string = chexpert_label_str.lower().strip()

        input_text = self.text_processor(findings=finding_string, no_labels=no_labels)

        # if self.use_indication:
        #     input_text = "Indication: " + indication + " " + input_text

        # conversation template for vicuna v1.3
        conv = Conversation(
            system="A chat between a curious user and an artificial intelligence assistant acting as an experienced radiologist. "
                   "The assistant gives professional, detailed, and polite answers to the user's questions.",
            roles=["USER", "ASSISTANT"],
            messages=[],
            offset=0,
            sep_style=SeparatorStyle.TWO,
            sep=" ",
            sep2="</s>",
        )
        conv.append_message(conv.roles[0], input_text)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        return {
            "text_input": prompt,
            "text_target": caption,
            "ig_label_string": finding_string,
            "chexpert_labels": chexpert_labels,
            "chexpert_cols": self.chexpert_cols,
            "dicom": dicom_id,
            "img_path": ann["Img_Folder"] + "/" + ann["Img_Filename"],
        }

    def __len__(self):
        return len(self.annotation)


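# Sampler that yields a precomputed index list in fixed order; unlike SubsetRandomSampler
# it does not reshuffle, and the list may contain duplicates (the cloned finding examples
# produced by stratified_sample below).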
class SubsetSampler(Sampler):
    def __init__(self, indices):
        self.indices = indices

    def __iter__(self):
        return iter(self.indices)

    def __len__(self):
        return len(self.indices)


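# Sketch of the ratio math with made-up numbers: with simulated_epochs=2 and 6,500 unique
# finding reports, finding_indices grows to 13,000; new_dataset_size = 13000 * 14 / 13 = 14000,
# so new_no_finding_count = 1000, i.e. no-finding reports end up as exactly 1/14 of the sample.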
def stratified_sample(df, simulated_epochs=1):
    # We want to reduce the number of no-finding examples to 1/14th of the dataset.
    # First separate the dataset into two groups: no finding and finding.
    # A report counts as "no finding" if 'No Finding' is positive or if no label at all is positive.
    no_findings_indices = df.annotation[(df.annotation['No Finding'] == 1) | ((df.annotation[df.chexpert_cols] == 1).sum(1) == 0)].index
    finding_indices = df.annotation.index.difference(no_findings_indices)
    no_findings_indices = no_findings_indices.tolist()
    finding_indices = finding_indices.tolist()

    # We strive to lose as little no-finding data as possible, so instead of only reducing
    # the number of no-finding examples, we also increase the number of finding examples
    # by cloning them (simulated epochs).
    finding_indices = finding_indices * simulated_epochs
    # subsample the no-finding examples so they make up 1/14th of the new dataset
    new_dataset_size = len(finding_indices) * 14 / 13
    new_no_finding_count = int(new_dataset_size / 14)
    # merge both groups into the final index list
    all_indices = finding_indices + random.sample(no_findings_indices, new_no_finding_count)
    return all_indices


def create_report_data_vicuna_specific_stratified(prompt_type):
    val_dataset = MIMIC_Text_Dataset(split="train", truncate=None, prompt_type=prompt_type)
    stratified_indices = stratified_sample(val_dataset, simulated_epochs=2)
    sampler = SubsetSampler(stratified_indices)
    data_loader = DataLoader(val_dataset, batch_size=200, num_workers=200, sampler=sampler)

    report_jsons = []
    for _, batch in tqdm(enumerate(data_loader)):
        # iterate over batch elements
        for i in range(len(batch["text_input"])):
            text_input = batch["text_input"][i]
            text_target = batch["text_target"][i]
            dicom = batch["dicom"][i]

            # one instruction/output record per report
            reports_json = {
                "instruction": text_input,
                "input": "",
                "output": text_target,
                "dicom": dicom,
            }
            report_jsons.append(reports_json)

    # Save the JSON data to a file
    with open("data/data_files/mimic_cxr_reports_stratified.json", "w") as f:
        json.dump(report_jsons, f, ensure_ascii=False, indent=4)


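# Note: the records above follow the Alpaca-style {"instruction", "input", "output"} schema;
# the extra "dicom" field is kept so each record can be traced back to its image.

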
'''
This function saves instruct data JSONs for all the different tasks we defined:
- easy language: EL (DONE)
- correction: CO (DONE)
- summarization: SU (DONE)
- reasoning: RE (based on MIMIC-NLE) (DONE)
- region QA: RQA (DONE)
- CP binary QA: CPbQA (DONE)
- CP all QA: CPaQA (DONE)

For every report we sample one task and one prompt, then save the report, the question (task),
and the answer generated by Vicuna (or taken from the dataset ground truth).
'''


def create_report_data_vicuna_instruct_large():
    lang_model = LlamaForCausalLM.from_pretrained("lmsys/vicuna-13b-v1.3", torch_dtype=torch.float16, device_map='auto', load_in_8bit=False)
    tokenizer = AutoTokenizer.from_pretrained("lmsys/vicuna-13b-v1.3", use_fast=False, truncation_side="left", padding_side="left")
    tokenizer.pad_token = tokenizer.unk_token

    val_dataset = MIMIC_Text_Dataset(split="train", truncate=None, prompt_type="img_matching_examples_ig2_noexamples")
    # split randomly into 6 portions of roughly 1/6th each
    split_size = len(val_dataset) // 6
    remainder = len(val_dataset) % 6

    # correction (CO) data is sampled elsewhere, so its portion is discarded here
    val_dataset_EL, _, val_dataset_SU, val_dataset_EX, val_dataset_RQA, val_dataset_CPQA = torch.utils.data.random_split(
        val_dataset, [split_size + (i < remainder) for i in range(6)])

    # split val_dataset_CPQA in 2
    split_size = len(val_dataset_CPQA) // 2
    remainder = len(val_dataset_CPQA) % 2
    val_dataset_CPbQA, val_dataset_CPaQA = torch.utils.data.random_split(val_dataset_CPQA, [split_size + (i < remainder) for i in range(2)])

    # create output directory
    if not os.path.exists("data/large_instruct_data"):
        os.makedirs("data/large_instruct_data")

    # create data
    create_direct_task_data(lang_model, tokenizer, val_dataset_EL, task_name="EL")
    create_direct_task_data(lang_model, tokenizer, val_dataset_SU, task_name="SU")
    create_direct_task_data(lang_model, tokenizer, val_dataset_RQA, task_name="RQA")
    create_cp_task_data(val_dataset_CPbQA, task_name="CPbQA")
    create_cp_task_data(val_dataset_CPaQA, task_name="CPaQA")

    create_correction_task_data(lang_model, tokenizer)
    create_nle_task_data()


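# Split-size sketch with made-up numbers: for a dataset of 20 samples and 6 portions,
# split_size = 3 and remainder = 2, so the portion lengths are [4, 4, 3, 3, 3, 3]
# ((i < remainder) adds 1 to the first portions), which sum to 20 as random_split requires.

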
'''
Fuse the instruct data with the report generation task into one dataset JSON.
'''


def fuse_instruct_dataset(prompt_type="img_matching_examples_ig2_noexamples_IMG_findings"):
    # get report generation data
    val_dataset = MIMIC_Text_Dataset(split="train", truncate=None, prompt_type=prompt_type)
    stratified_indices = stratified_sample(val_dataset, simulated_epochs=2)
    sampler = SubsetSampler(stratified_indices)
    data_loader = DataLoader(val_dataset, batch_size=200, sampler=sampler, num_workers=200)
    report_jsons = []
    for _, batch in tqdm(enumerate(data_loader)):
        # iterate over batch elements
        for i in range(len(batch["text_input"])):
            text_input = batch["text_input"][i]
            text_target = batch["text_target"][i]
            dicom = batch["dicom"][i]

            # one instruction/output record per report
            reports_json = {
                "instruction": text_input,
                "input": "",
                "output": text_target,
                "dicom": dicom,
            }
            report_jsons.append(reports_json)

    task_jsons = []
    with open("vicuna_prompts.json", "r") as f:
        prompts = json.load(f)
    report_prompt = prompts[prompt_type]

    # get instruct data
    for task in ["EL", "RE", "CO", "SU", "RQA", "CPbQA", "CPaQA"]:
        print("Creating data for " + task)
        with open(f"data/large_instruct_data/instruct_large_{task}.json", "r") as f:
            task_data = json.load(f)

        for elem in tqdm(task_data):
            report = elem["gt_report"] if task != "CO" else elem["incorrect_report"]

            conv = create_conv()
            conv.append_message(conv.roles[0], report_prompt)
            conv.append_message(conv.roles[1], report)
            conv.append_message(conv.roles[0], elem["task"])
            conv.append_message(conv.roles[1], None)

            instruction = conv.get_prompt()

            # get the element with the same dicom directly from val_dataset.annotation
            orig_elem = val_dataset.annotation[val_dataset.annotation["dicom_id"] == elem["dicom"]].iloc[0]

            if isinstance(orig_elem['positive_labels'], float) and np.isnan(orig_elem['positive_labels']):
                finding_str = "no common findings"
            else:
                finding_str = orig_elem['positive_labels'].lower().strip()
            instruction = instruction.format(findings=finding_str)

            task_json = {
                "instruction": instruction,
                "input": "",
                "output": elem["output"].lower().strip() if task == "CPaQA" else elem["output"].strip(),
                "dicom": elem["dicom"],
            }
            task_jsons.append(task_json)

    # combine and shuffle report and task jsons
    combined_jsons = report_jsons + task_jsons
    random.shuffle(combined_jsons)

    # save to json
    with open("data/data_files/mimic_cxr_instruct_stratified.json", "w") as f:
        json.dump(combined_jsons, f, indent=4)


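# Sketch of a fused two-turn instruction produced by fuse_instruct_dataset (texts made up):
#   "<system text> USER: <report prompt with findings> ASSISTANT: <report></s>USER: <task question> ASSISTANT:"
# The model is trained to complete the final assistant turn with the task answer.

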
if __name__ == '__main__':
    # args parser
    parser = argparse.ArgumentParser()
    parser.add_argument('--mode', type=str, default='RG', help='RG or INS')
    args = parser.parse_args()

    ''' Create data to train RaDialog-RG model'''
    if args.mode == 'RG':
        create_report_data_vicuna_specific_stratified(prompt_type="img_matching_examples_ig2_noexamples_IMG_findings")

    ''' Create data to train RaDialog-INS model'''
    if args.mode == 'INS':
        create_report_data_vicuna_instruct_large()
        fuse_instruct_dataset()

    # This code is meant for understanding how our instruct dataset is created.
    # Due to randomness in the sampling and model predictions, a newly generated dataset could be slightly different.
    # To exactly reproduce our results, please use the instruct dataset we published and use 'fuse_instruct_dataset' to merge it with your MIMIC data.