data/create_data.py

import argparse
import dataclasses
import json
import os
from enum import auto, Enum
from pathlib import Path
from typing import List, Any
import random

import numpy as np
import pandas as pd
import torch
from omegaconf import OmegaConf
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from transformers import AutoTokenizer
from torch.utils.data.sampler import Sampler

from data.instruct_tasks import create_direct_task_data, create_cp_task_data, create_correction_task_data, create_nle_task_data
from local_config import VIS_ROOT, PATH_TO_MIMIC_CXR
from model.lavis.models.blip2_models.modeling_llama_imgemb import LlamaForCausalLM


class MyReportProcessor:
    def __init__(self, prompt="", max_words=50, prompt_neg=""):
        self.prompt = prompt
        self.max_words = max_words
        self.prompt_neg = prompt_neg

    def __call__(self, findings, no_labels=False):
        prompt = self.prompt

        if no_labels:
            findings = "no common findings"  # cannot state which findings, as we don't know them
        prompt = prompt.format(findings=findings)

        return prompt

    @classmethod
    def from_config(cls, cfg=None):
        if cfg is None:
            cfg = OmegaConf.create()

        prompt = cfg.get("prompt", "")
        max_words = cfg.get("max_words", 50)

        return cls(prompt=prompt, max_words=max_words)
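
# Illustrative usage (the template string here is hypothetical, not from the repo):
#   proc = MyReportProcessor(prompt="Write a radiology report. Findings: {findings}")
#   proc("cardiomegaly, edema")  # -> "Write a radiology report. Findings: cardiomegaly, edema"
#   proc("", no_labels=True)     # -> "Write a radiology report. Findings: no common findings"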


class SeparatorStyle(Enum):
    """Different separator styles."""
    SINGLE = auto()
    TWO = auto()


@dataclasses.dataclass
class Conversation:
    """A class that keeps all conversation history."""
    system: str
    roles: List[str]
    messages: List[List[str]]
    offset: int
    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
    sep: str = "###"
    sep2: str = None

    # Used for gradio server
    skip_next: bool = False
    conv_id: Any = None

    def get_prompt(self):
        if self.sep_style == SeparatorStyle.SINGLE:
            ret = self.system
            for role, message in self.messages:
                if message:
                    ret += self.sep + " " + role + ": " + message
                else:
                    ret += self.sep + " " + role + ":"
            return ret
        elif self.sep_style == SeparatorStyle.TWO:
            seps = [self.sep, self.sep2]
            ret = self.system + seps[0]
            for i, (role, message) in enumerate(self.messages):
                if message:
                    ret += role + ": " + message + seps[i % 2]
                else:
                    ret += role + ":"
            return ret
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

    def append_message(self, role, message):
        self.messages.append([role, message])

    def dict(self):
        return {
            "system": self.system,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
            "sep": self.sep,
            "sep2": self.sep2,
            "conv_id": self.conv_id,
        }
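
# Example of the prompt rendered by get_prompt() with SeparatorStyle.TWO, sep=" " and
# sep2="</s>" (system text abbreviated, <question> is illustrative):
#   "A chat between ... questions. USER: <question> ASSISTANT:"
# An appended message of None leaves a trailing "ASSISTANT:" open for generation; the
# "</s>" separator is only emitted after a filled-in assistant turn.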


def create_conv():
    conv = Conversation(
        system="A chat between a curious user and an artificial intelligence assistant acting as an experienced radiologist. "
               "The assistant gives professional, detailed, and polite answers to the user's questions.",
        roles=["USER", "ASSISTANT"],
        messages=[],
        offset=0,
        sep_style=SeparatorStyle.TWO,
        sep=" ",
        sep2="</s>",
    )
    return conv


class MIMIC_Text_Dataset(Dataset):
    def __init__(self, split, truncate=None, prompt_type="basic", use_indication=False):
        super().__init__()

        # load split and report csv files
        self.split = pd.read_csv(f'{PATH_TO_MIMIC_CXR}/mimic-cxr-jpg/2.0.0/mimic-cxr-2.0.0-split.csv')
        self.reports = pd.read_csv('mimic-cxr/reports_processed/mimic_cxr_sectioned.csv')
        # drop reports whose findings section is NaN
        self.reports = self.reports.dropna(subset=['findings'])

        self.img_ids = {img_id: i for i, img_id in enumerate(self.reports['dicom_id'])}
        self.chexpert = pd.read_csv('data/data_files/finding_chexbert_labels.csv')
        self.chexpert_cols = ["No Finding", "Enlarged Cardiomediastinum",
                              "Cardiomegaly", "Lung Opacity",
                              "Lung Lesion", "Edema",
                              "Consolidation", "Pneumonia",
                              "Atelectasis", "Pneumothorax",
                              "Pleural Effusion", "Pleural Other",
                              "Fracture", "Support Devices"]

        self.use_indication = use_indication

        self.vis_root = VIS_ROOT

        self.prompt_type = prompt_type

        self.split_ids = set(self.split.loc[self.split['split'] == split]['dicom_id'])
        self.train_ids = set(self.split.loc[self.split['split'] == 'train']['dicom_id'])

        # keep all reports whose dicom_id belongs to the requested split
        self.annotation = self.reports.loc[self.reports['dicom_id'].isin(self.split_ids)]
        if truncate is not None:
            self.annotation = self.annotation[:truncate]

        self.annotation['findings'] = self.annotation['findings'].apply(lambda x: x.replace('\n', ''))

        # Extract subject_id from the 3rd path component of Img_Folder, and study_id from the
        # note filename without the leading 's'
        self.annotation['subject_id'] = self.annotation['Img_Folder'].apply(lambda x: int(x.split('/')[2].lstrip('p')))
        self.annotation['study_id'] = self.annotation['Note_file'].apply(lambda x: int(x.lstrip('s').rstrip('.txt')))

        # Merge chexpert labels with the annotation dataframe
        self.annotation = pd.merge(self.annotation, self.chexpert, how='left', left_on=['dicom_id'], right_on=['dicom_id'])

        # for every row, add a comma-separated string of the positive labels
        self.annotation['positive_labels'] = self.annotation.apply(
            lambda x: self.convert_to_finding_labels(x[self.chexpert_cols].values, self.chexpert_cols), axis=1)

        # maybe use transforms from here: ResNet50_Weights.IMAGENET1K_V2.transforms
        # read prompts from json
        prompts = json.loads(Path("vicuna_prompts.json").read_text(encoding="UTF-8"))
        self.text_processor = MyReportProcessor(
            prompt=prompts[prompt_type], max_words=1000,
            prompt_neg=prompts[prompt_type.replace("matching_examples", "neg_matching_examples")])

    def convert_to_finding_labels(self, chexpert_labels, columns, label=1):
        # Get indices where the value equals `label` (default: positive findings)
        indices = np.where(chexpert_labels == label)
        # Get the corresponding column names and join them into a string
        labels = ", ".join([columns[i] for i in indices[0]])
        return labels
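
    # e.g. for chexpert_labels np.array([0, 1, 1]) with columns ["No Finding", "Cardiomegaly",
    # "Edema"], this returns "Cardiomegaly, Edema" (illustrative values; the real labels come
    # from finding_chexbert_labels.csv)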

    def __getitem__(self, index):
        ann = self.annotation.iloc[index]
        # if self.use_indication:
        #     indication = self.indications[study_id]
        #     if indication == "":
        #         indication = "Indication not given."
        caption = ann["findings"].strip()
        chexpert_labels = ann[self.chexpert_cols].astype(float).values
        chexpert_label_str = ann["positive_labels"]
        dicom_id = ann["dicom_id"]

        # check if all label columns are nan, 0, or -1 -> no positive labels
        no_labels = np.all((np.isnan(chexpert_labels)) | (chexpert_labels == 0) | (chexpert_labels == -1.))
        finding_string = chexpert_label_str.lower().strip()

        input_text = self.text_processor(findings=finding_string, no_labels=no_labels)

        # if self.use_indication:
        #     input_text = "Indication: " + indication + " " + input_text

        # conversation template for vicuna v1.3
        conv = Conversation(
            system="A chat between a curious user and an artificial intelligence assistant acting as an experienced radiologist. "
                   "The assistant gives professional, detailed, and polite answers to the user's questions.",
            roles=["USER", "ASSISTANT"],
            messages=[],
            offset=0,
            sep_style=SeparatorStyle.TWO,
            sep=" ",
            sep2="</s>",
        )
        conv.append_message(conv.roles[0], input_text)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        return {
            "text_input": prompt,
            "text_target": caption,
            "ig_label_string": finding_string,
            "chexpert_labels": chexpert_labels,
            "chexpert_cols": self.chexpert_cols,
            "dicom": dicom_id,
            "img_path": ann["Img_Folder"] + "/" + ann["Img_Filename"],
        }

    def __len__(self):
        return len(self.annotation)


class SubsetSampler(Sampler):
    def __init__(self, indices):
        self.indices = indices

    def __iter__(self):
        return (self.indices[i] for i in range(len(self.indices)))

    def __len__(self):
        return len(self.indices)
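
# Note: unlike torch's SubsetRandomSampler, this sampler yields the given indices in order,
# so a precomputed stratified index list is consumed deterministically.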


def stratified_sample(df, simulated_epochs=1):
    # We want to reduce the number of "no finding" examples to 1/14th of the dataset.
    # First, separate the dataset into two groups: no finding and finding.
    # A row counts as "no finding" if either "No Finding" is set or no chexpert column is positive.
    no_findings_indices = df.annotation[(df.annotation['No Finding'] == 1) | ((df.annotation[df.chexpert_cols] == 1).sum(axis=1) == 0)].index
    finding_indices = df.annotation.index.difference(no_findings_indices)
    no_findings_indices = no_findings_indices.tolist()
    finding_indices = finding_indices.tolist()

    # We strive to lose as little no-finding data as possible. So instead of just reducing the
    # number of no-finding examples, we increase the number of finding examples by cloning them
    # (simulated epochs).
    finding_indices = finding_indices * simulated_epochs
    # subsample the no-finding examples so they make up 1/14th of the new dataset
    new_dataset_size = len(finding_indices) * 14 / 13
    new_no_finding_count = int(new_dataset_size / 14)
    # merge both groups into the final index list
    all_indices = finding_indices + random.sample(no_findings_indices, new_no_finding_count)
    return all_indices
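
# Worked example of the ratio arithmetic (illustrative numbers, not dataset statistics):
# with 13,000 finding examples and simulated_epochs=2 -> 26,000 finding indices,
# new_dataset_size = 26,000 * 14 / 13 = 28,000 and new_no_finding_count = 28,000 / 14 = 2,000,
# so no-finding examples make up exactly 1/14th of the combined index list.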


def create_report_data_vicuna_specific_stratified(prompt_type):
    val_dataset = MIMIC_Text_Dataset(split="train", truncate=None, prompt_type=prompt_type)
    stratified_indices = stratified_sample(val_dataset, simulated_epochs=2)
    sampler = SubsetSampler(stratified_indices)
    data_loader = DataLoader(val_dataset, batch_size=200, num_workers=200, sampler=sampler)

    report_jsons = []
    for _, batch in tqdm(enumerate(data_loader)):
        # iterate over batch elements
        for i in range(len(batch["text_input"])):
            text_input = batch["text_input"][i]
            text_target = batch["text_target"][i]
            dicom = batch["dicom"][i]

            # build one instruction record for this report
            reports_json = {
                "instruction": text_input,
                "input": "",
                "output": text_target,
                "dicom": dicom,
            }
            report_jsons.append(reports_json)

    # Save the JSON data to a file
    with open("data/data_files/mimic_cxr_reports_stratified.json", "w") as f:
        json.dump(report_jsons, f, ensure_ascii=False, indent=4)
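
# Shape of one record in mimic_cxr_reports_stratified.json (field values illustrative):
#   {"instruction": "<vicuna-formatted prompt>", "input": "", "output": "<findings section>",
#    "dicom": "<dicom_id>"}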


'''
This function saves instruct-data JSONs for all the different tasks we defined:
- easy language: EL DONE
- correction: CO DONE
- summarization: SU DONE
- reasoning: RE (based on MIMIC-NLE) DONE
- region QA: RQA DONE
- CP binary QA: CPbQA DONE
- CP all QA: CPaQA DONE

For every report we sample one task and one prompt, and save the report, the question (task),
and the answer generated by vicuna (or taken from the dataset ground truth).
'''


def create_report_data_vicuna_instruct_large():
    lang_model = LlamaForCausalLM.from_pretrained("lmsys/vicuna-13b-v1.3", torch_dtype=torch.float16, device_map='auto', load_in_8bit=False)
    tokenizer = AutoTokenizer.from_pretrained("lmsys/vicuna-13b-v1.3", use_fast=False, truncation_side="left", padding_side="left")
    tokenizer.pad_token = tokenizer.unk_token

    val_dataset = MIMIC_Text_Dataset(split="train", truncate=None, prompt_type="img_matching_examples_ig2_noexamples")

    # randomly split the dataset into six portions of 1/6th each, one per task
    # (the correction task data is sampled elsewhere, so one portion stays unused)
    split_size = len(val_dataset) // 6
    remainder = len(val_dataset) % 6
    val_dataset_EL, _, val_dataset_SU, val_dataset_EX, val_dataset_RQA, val_dataset_CPQA = torch.utils.data.random_split(
        val_dataset, [split_size + (i < remainder) for i in range(6)])

    # split val_dataset_CPQA in two
    split_size = len(val_dataset_CPQA) // 2
    remainder = len(val_dataset_CPQA) % 2
    val_dataset_CPbQA, val_dataset_CPaQA = torch.utils.data.random_split(val_dataset_CPQA, [split_size + (i < remainder) for i in range(2)])

    # create output directory
    if not os.path.exists("data/large_instruct_data"):
        os.makedirs("data/large_instruct_data")

    # create data
    create_direct_task_data(lang_model, tokenizer, val_dataset_EL, task_name="EL")
    create_direct_task_data(lang_model, tokenizer, val_dataset_SU, task_name="SU")
    create_direct_task_data(lang_model, tokenizer, val_dataset_RQA, task_name="RQA")
    create_cp_task_data(val_dataset_CPbQA, task_name="CPbQA")
    create_cp_task_data(val_dataset_CPaQA, task_name="CPaQA")

    create_correction_task_data(lang_model, tokenizer)
    create_nle_task_data()
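
# Split-length arithmetic used above (illustrative): for len(val_dataset) = 100,
# split_size = 16 and remainder = 4, so the portion lengths are [17, 17, 17, 17, 16, 16];
# the boolean (i < remainder) adds 0 or 1 and the lengths sum back to 100.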


'''
Fuse the instruct data with the report generation task into one dataset JSON.
'''


def fuse_instruct_dataset(prompt_type="img_matching_examples_ig2_noexamples_IMG_findings"):
    # get report generation data
    val_dataset = MIMIC_Text_Dataset(split="train", truncate=None, prompt_type=prompt_type)
    stratified_indices = stratified_sample(val_dataset, simulated_epochs=2)
    sampler = SubsetSampler(stratified_indices)
    data_loader = DataLoader(val_dataset, batch_size=200, sampler=sampler, num_workers=200)
    report_jsons = []
    for _, batch in tqdm(enumerate(data_loader)):
        # iterate over batch elements
        for i in range(len(batch["text_input"])):
            text_input = batch["text_input"][i]
            text_target = batch["text_target"][i]
            dicom = batch["dicom"][i]

            # build one instruction record for this report
            reports_json = {
                "instruction": text_input,
                "input": "",
                "output": text_target,
                "dicom": dicom,
            }
            report_jsons.append(reports_json)

    task_jsons = []
    with open("vicuna_prompts.json", "r") as f:
        prompts = json.load(f)
    report_prompt = prompts[prompt_type]

    # get instruct data
    for task in ["EL", "RE", "CO", "SU", "RQA", "CPbQA", "CPaQA"]:
        print("Creating data for " + task)
        with open(f"data/large_instruct_data/instruct_large_{task}.json", "r") as f:
            task_data = json.load(f)

        for elem in tqdm(task_data):
            report = elem["gt_report"] if task != "CO" else elem["incorrect_report"]

            conv = create_conv()
            conv.append_message(conv.roles[0], report_prompt)
            conv.append_message(conv.roles[1], report)
            conv.append_message(conv.roles[0], elem["task"])
            conv.append_message(conv.roles[1], None)

            instruction = conv.get_prompt()

            # look up the element with the same dicom directly in val_dataset.annotation
            orig_elem = val_dataset.annotation[val_dataset.annotation["dicom_id"] == elem["dicom"]].iloc[0]

            if isinstance(orig_elem['positive_labels'], float) and np.isnan(orig_elem['positive_labels']):
                finding_str = "no common findings"
            else:
                finding_str = orig_elem['positive_labels'].lower().strip()
            instruction = instruction.format(findings=finding_str)

            task_json = {
                "instruction": instruction,
                "input": "",
                "output": elem["output"].lower().strip() if task == "CPaQA" else elem["output"].strip(),
                "dicom": elem["dicom"],
            }
            task_jsons.append(task_json)

    # combine and shuffle report and task jsons
    combined_jsons = report_jsons + task_jsons
    random.shuffle(combined_jsons)

    # save to json
    with open("data/data_files/mimic_cxr_instruct_stratified.json", "w") as f:
        json.dump(combined_jsons, f, indent=4)


if __name__ == '__main__':
    # argument parser
    parser = argparse.ArgumentParser()
    parser.add_argument('--mode', type=str, default='RG', help='RG or INS')
    args = parser.parse_args()

    '''Create data to train the RaDialog-RG model'''
    if args.mode == 'RG':
        create_report_data_vicuna_specific_stratified(prompt_type="img_matching_examples_ig2_noexamples_IMG_findings")

    '''Create data to train the RaDialog-INS model'''
    if args.mode == 'INS':
        create_report_data_vicuna_instruct_large()
        fuse_instruct_dataset()

    # This code is meant for understanding how our instruct dataset is created.
    # Due to randomness in the sampling and model predictions, a newly generated dataset could differ slightly.
    # To exactly reproduce our results, please use the instruct dataset we published and use 'fuse_instruct_dataset' to merge it with your MIMIC data.
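
# Example invocations (illustrative; assumes the repo root as working directory and that
# the MIMIC-CXR paths in local_config.py point to local data):
#   python -m data.create_data --mode RG    # report-generation training data
#   python -m data.create_data --mode INS   # instruct-tuning training data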