data/instruct_tasks.py
import dataclasses
import json
import random
from enum import Enum, auto
from pathlib import Path
from typing import List, Any

import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

from local_config import PATH_TO_MIMIC_NLE


class SeparatorStyle(Enum):
    """Different separator styles."""
    SINGLE = auto()
    TWO = auto()


@dataclasses.dataclass
class Conversation:
    """A class that keeps all conversation history."""
    system: str
    roles: List[str]
    messages: List[List[str]]
    offset: int
    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
    sep: str = "###"
    sep2: str = None

    # Used for gradio server
    skip_next: bool = False
    conv_id: Any = None

    def get_prompt(self):
        if self.sep_style == SeparatorStyle.SINGLE:
            ret = self.system
            for role, message in self.messages:
                if message:
                    ret += self.sep + " " + role + ": " + message
                else:
                    ret += self.sep + " " + role + ":"
            return ret
        elif self.sep_style == SeparatorStyle.TWO:
            seps = [self.sep, self.sep2]
            ret = self.system + seps[0]
            for i, (role, message) in enumerate(self.messages):
                if message:
                    ret += role + ": " + message + seps[i % 2]
                else:
                    ret += role + ":"
            return ret
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

    def append_message(self, role, message):
        self.messages.append([role, message])

    def dict(self):
        return {
            "system": self.system,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
            "sep": self.sep,
            "sep2": self.sep2,
            "conv_id": self.conv_id,
        }


def create_conv():
    conv = Conversation(
        system="A chat between a curious user and an artificial intelligence assistant acting as an experienced radiologist. "
               "The assistant gives professional, detailed, and polite answers to the user's questions.",
        roles=["USER", "ASSISTANT"],
        messages=[],
        offset=0,
        sep_style=SeparatorStyle.TWO,
        sep=" ",
        sep2="</s>",
    )
    return conv

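
# For reference, a single-turn prompt assembled with create_conv() and SeparatorStyle.TWO
# looks roughly like this (placeholders, not real data):
#
#   conv = create_conv()
#   conv.append_message(conv.roles[0], "Report: <report text>\n<task prompt>")
#   conv.append_message(conv.roles[1], None)
#   conv.get_prompt()
#   # -> "<system prompt> USER: Report: <report text>\n<task prompt> ASSISTANT:"
#
# The open "ASSISTANT:" turn is what the language model completes; sep2 ("</s>") is only
# appended after finished assistant turns.

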
def create_direct_task_data(lang_model, tokenizer, val_dataset, task_name):
    """Generate instruct samples for `task_name` by prompting the language model with each GT report."""
    prompts = pd.read_csv(f"data/instruct_prompts/{task_name}_prompts.csv")["instruction"].tolist()
    data_loader = DataLoader(val_dataset, batch_size=12, shuffle=False, num_workers=0)
    report_jsons = []
    print("Dataloader len: ", len(data_loader))
    for _, batch in tqdm(enumerate(data_loader)):

        # Create prompts for every report:
        # sample batch-size many instructions from the task's prompt list
        batch_prompts = random.choices(prompts, k=len(batch["text_input"]))
        batch_instructions = []
        for text_target, prompt in zip(batch["text_target"], batch_prompts):
            conv = create_conv()
            conv.append_message(conv.roles[0], "Report: " + text_target + "\n" + prompt)
            conv.append_message(conv.roles[1], None)
            batch_instructions.append(conv.get_prompt())

        inputs = tokenizer.batch_encode_plus(batch_instructions, return_tensors="pt", padding=True)
        input_ids = inputs["input_ids"].to(torch.device("cuda"))

        # generate answers with no-lora vicuna
        generation_output = lang_model.generate(
            input_ids=input_ids,
            dicom=None,
            return_dict_in_generate=True,
            output_scores=True,
            max_new_tokens=256
        )
        preds = tokenizer.batch_decode(generation_output.sequences, skip_special_tokens=True)
        preds = [p.split("ASSISTANT:")[1] for p in preds]

        # iterate over batch elements
        for i in range(len(batch["text_input"])):
            text_target = batch["text_target"][i]  # GT report
            task_prompt = batch_prompts[i]
            task_instruction = batch_instructions[i]
            answer = preds[i]
            dicom = batch["dicom"][i]

            # assemble one instruct record per report
            reports_json = {
                "gt_report": text_target,
                "task": task_prompt,
                "instruction": task_instruction,
                "input": "",
                "output": answer,
                "dicom": dicom,
                "task_type": task_name
            }
            report_jsons.append(reports_json)

    # save
    with open(f"data/large_instruct_data/instruct_large_{task_name}.json", "w") as f:
        json.dump(report_jsons, f, ensure_ascii=False, indent=4)

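
# Each entry written to data/large_instruct_data/instruct_large_<task_name>.json by the
# function above has the following shape (values are placeholders, not real data):
#
#   {
#       "gt_report":   "<ground-truth findings section>",
#       "task":        "<sampled instruction>",
#       "instruction": "<full conversation-formatted prompt>",
#       "input":       "",
#       "output":      "<answer generated by the language model>",
#       "dicom":       "<dicom id>",
#       "task_type":   "<task_name>"
#   }

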
def create_cp_task_data(val_dataset, task_name):
    """Create CheXpert-label QA samples: binary yes/no ("CPbQA") or list-all-findings ("CPaQA")."""
    prompts = pd.read_csv(f"data/instruct_prompts/{task_name}_prompts.csv")["instruction"].tolist()
    data_loader = DataLoader(val_dataset, batch_size=200, shuffle=False, num_workers=200)
    report_jsons = []
    for _, batch in tqdm(enumerate(data_loader)):

        # Create prompts for every report:
        # sample batch-size many instructions from the task's prompt list
        batch_prompts = random.choices(prompts, k=len(batch["text_input"]))

        # iterate over batch elements
        for i in range(len(batch["text_input"])):
            text_target = batch["text_target"][i]  # GT report
            task_prompt = batch_prompts[i]
            cp_indices = np.where(batch["chexpert_labels"][i] == 1.)
            cp_findings = [val_dataset.dataset.dataset.chexpert_cols[j] for j in cp_indices[0]]

            if task_name == "CPbQA":  # binary QA
                if "No Finding" in cp_findings:
                    cp_findings.remove("No Finding")
                # ~60% of the time ask about a finding from cp_findings (answer "yes"),
                # otherwise ask about a finding not present in cp_findings (answer "no")
                if random.random() < 0.6 and len(cp_findings) > 0:
                    finding = random.choice(cp_findings)  # answer: yes
                    answer = 'yes'
                else:
                    finding = random.choice(list(set(val_dataset.dataset.dataset.chexpert_cols[1:]) - set(cp_findings)))  # answer: no
                    answer = 'no'
                task_prompt = task_prompt.replace("<X>", finding)

            elif task_name == "CPaQA":  # give all findings
                answer = ', '.join(cp_findings)

            dicom = batch["dicom"][i]

            # assemble one instruct record per report
            reports_json = {
                "gt_report": text_target,
                "task": task_prompt,
                "input": "",
                "output": answer,
                "dicom": dicom,
                "task_type": task_name
            }
            report_jsons.append(reports_json)

    # save
    with open(f"data/large_instruct_data/instruct_large_{task_name}.json", "w") as f:
        json.dump(report_jsons, f, ensure_ascii=False, indent=4)

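
# Illustrative CPbQA sample (the exact prompt wording comes from CPbQA_prompts.csv and is
# only assumed here): if the CheXpert labels mark "Edema" as present, a prompt such as
# "Is there evidence of <X> in the image?" becomes
#   task:   "Is there evidence of Edema in the image?"
#   output: "yes"
# whereas asking about a finding that is not labelled as present yields output "no".

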
class CorrectionDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]

        fp = sample["fp"]
        fn = sample["fn"]
        fp_str = ', '.join(fp)
        fp_str = fp_str.rsplit(', ', 1)
        fp_str = ' and '.join(fp_str)
        fn_str = ', '.join(fn)
        fn_str = fn_str.rsplit(', ', 1)
        fn_str = ' and '.join(fn_str)

        gt_report = sample["gt_report"]
        pred_report = sample["pred_report"]
        dicom = sample["dicom"]
        return {'gt_report': gt_report, 'pred_report': pred_report, 'fp': fp_str, 'fn': fn_str, 'dicom': dicom}

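
# CorrectionDataset turns the false-positive (fp) and false-negative (fn) finding lists into
# natural-language enumerations, e.g. (placeholder findings):
#   fp = ["Edema", "Pneumothorax", "Consolidation"]  ->  "Edema, Pneumothorax and Consolidation"
#   fn = ["Cardiomegaly"]                            ->  "Cardiomegaly"
#   fn = []                                          ->  ""

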
def create_correction_task_data(lang_model, tokenizer):
    """Generate report-correction (CO) samples: the language model rewrites predicted reports given their false-positive/false-negative findings."""
    # load correction json
    with open("data/instruct_prompts/instruct_task_correction_preds.json") as f:
        correction_preds = json.load(f)

    # create pytorch dataset from json
    correction_dataset = CorrectionDataset(correction_preds)
    data_loader = DataLoader(correction_dataset, batch_size=12, shuffle=False, num_workers=12)

    prompts_both = pd.read_csv("data/instruct_prompts/CO_both_prompts.csv")["instruction"].tolist()
    prompts_add = pd.read_csv("data/instruct_prompts/CO_add_prompts.csv")["instruction"].tolist()
    prompts_rem = pd.read_csv("data/instruct_prompts/CO_rem_prompts.csv")["instruction"].tolist()
    report_jsons = []
    for _, batch in tqdm(enumerate(data_loader)):
        # use a very clear, fixed prompt for data generation -> random prompts are used during training

        fixed_batch_prompts = []
        for fp, fn in zip(batch["fp"], batch["fn"]):
            fixed_corr_prompt = "Please provide an adapted report. "
            if fp != "":
                fixed_corr_prompt += f"Do not mention {fp}. "
            if fn != "":
                fixed_corr_prompt += f"Mention {fn}. "

            if fp == "" and fn == "":
                fixed_corr_prompt = "NOCHANGE"
            fixed_batch_prompts.append(fixed_corr_prompt.strip())

        batch_prompts = []
        for fp, fn in zip(batch["fp"], batch["fn"]):
            if fp == "" and fn == "":
                batch_prompts.append("NOCHANGE")
            elif fp == "":
                batch_prompts.append(random.choice(prompts_add).replace("<add>", fn))
            elif fn == "":
                batch_prompts.append(random.choice(prompts_rem).replace("<rem>", fp))
            else:
                batch_prompts.append(random.choice(prompts_both).replace("<add>", fn).replace("<rem>", fp))

        batch_instructions = []
        for pred_report, prompt in zip(batch["pred_report"], fixed_batch_prompts):
            conv = create_conv()
            conv.append_message(conv.roles[0], "Please write a radiology report for the given x-ray.")
            conv.append_message(conv.roles[1], pred_report)
            conv.append_message(conv.roles[0], prompt)
            conv.append_message(conv.roles[1], None)
            batch_instructions.append(conv.get_prompt())

        inputs = tokenizer.batch_encode_plus(batch_instructions, return_tensors="pt", padding=True)
        input_ids = inputs["input_ids"].to(torch.device("cuda"))

        # generate answers with no-lora vicuna
        generation_output = lang_model.generate(
            input_ids=input_ids,
            dicom=None,
            return_dict_in_generate=True,
            output_scores=True,
            max_new_tokens=256
        )
        preds = tokenizer.batch_decode(generation_output.sequences, skip_special_tokens=True)
        preds = [p.split("ASSISTANT:")[-1].strip() for p in preds]

        # iterate over batch elements
        for i in range(len(batch["pred_report"])):
            gt_report = batch["gt_report"][i]  # GT report
            incorrect_report = batch["pred_report"][i]  # predicted report that will be corrected
            task_prompt = batch_prompts[i]
            task_instruction = batch_instructions[i]
            answer = preds[i]
            dicom = batch["dicom"][i]

            if task_prompt == "NOCHANGE":
                continue  # we don't want to train for correction on already correct reports

            # assemble one correction record per report
            reports_json = {
                "gt_report": gt_report,
                "incorrect_report": incorrect_report,
                "task": task_prompt,
                "instruction": task_instruction,
                "input": "",
                "output": answer,
                "dicom": dicom,
                "task_type": 'CO'
            }
            report_jsons.append(reports_json)

    # save
    with open("data/large_instruct_data/instruct_large_CO.json", "w") as f:
        json.dump(report_jsons, f, ensure_ascii=False, indent=4)

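
# For reference, the conversation fed to the language model for each correction sample is
# (placeholders, not real data):
#
#   USER:      Please write a radiology report for the given x-ray.
#   ASSISTANT: <predicted, partly incorrect report>
#   USER:      Please provide an adapted report. Do not mention <fp findings>. Mention <fn findings>.
#   ASSISTANT: (left open; the model's completion becomes the "output" field)

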
def create_nle_task_data():
    """Create reasoning (RE) samples from MIMIC-NLE: the prompt names the diagnoses, the NLE sentence is the target."""
    MIMIC_DIAGNOSISLIST = ['Atelectasis', 'Consolidation', 'Edema', 'Enlarged Cardiomediastinum', 'Lung Lesion', 'Lung Opacity', 'Pleural Effusion',
                           'Pleural Other', 'Pneumonia', 'Pneumothorax']
    # load mimic_nle json (one json object per line)
    mimic_nle = []
    with open(f'{PATH_TO_MIMIC_NLE}/mimic-nle/mimic-nle-train.json', 'r') as f:
        for line in f:
            obj = json.loads(line)
            mimic_nle.append(obj)

    prompts = pd.read_csv("data/instruct_prompts/RE_prompts.csv")["instruction"].tolist()
    report_jsons = []
    reports = pd.read_csv('mimic-cxr/reports_processed/mimic_cxr_sectioned.csv')
    reports = reports.dropna(subset=['findings'])
    reports['findings'] = reports['findings'].apply(lambda x: x.replace('\n', ''))

    for sample in tqdm(mimic_nle):
        report_id = sample["report_ID"]
        gt_report = reports[reports["Note_file"] == f"{report_id}.txt"]["findings"].tolist()
        if len(gt_report) == 0:  # report has no findings section
            continue
        gt_report = gt_report[0]

        nle = sample['nle']
        if nle not in gt_report:  # filter out samples whose explanation references the impression instead of the findings section
            continue

        dicom = reports[reports["Note_file"] == f"{report_id}.txt"]["dicom_id"].tolist()[0]
        task_prompt = random.choice(prompts)

        # format the positive diagnoses as a natural-language enumeration and insert them into the prompt
        diagnoses = [d for idx, d in enumerate(MIMIC_DIAGNOSISLIST) if sample["diagnosis_label"][idx] == 1]
        diagnoses_string = ", ".join(diagnoses)
        diagnoses_string = diagnoses_string.rsplit(', ', 1)
        diagnoses_string = ' and '.join(diagnoses_string)
        task_prompt = task_prompt.replace("<X>", diagnoses_string)

        # assemble one reasoning record per report
        reports_json = {
            "gt_report": gt_report,
            "task": task_prompt,
            "input": "",
            "output": sample['nle'],
            "dicom": dicom,
            "task_type": 'RE'
        }
        report_jsons.append(reports_json)

    # save
    print("Number of RE samples:", len(report_jsons))
    with open("data/large_instruct_data/instruct_large_RE.json", "w") as f:
        json.dump(report_jsons, f, ensure_ascii=False, indent=4)
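

if __name__ == "__main__":
    # Minimal usage sketch (an assumption, not part of the original module): only the
    # MIMIC-NLE reasoning task runs standalone; the other create_*_task_data functions
    # additionally require a loaded language model, tokenizer, and validation dataset.
    create_nle_task_data()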