<p align="center">
<img src="images/logo.png" width="40%">
</p>

# Instruction Tuning Large Language Models to Understand Electronic Health Records

**Authors:** Zhenbang Wu, Anant Dadu, Michael Nalls, Faraz Faghri, Jimeng Sun

**Published at:** NeurIPS 2024 Datasets and Benchmarks Track (Spotlight)

[[📑Paper](https://openreview.net/pdf?id=Dgy5WVgPd2)] [[🪧Poster](./poster.pdf)] [[📽️Slides](./slides.pdf)]

## Release

- [19 Dec 2024] The trained model weights are released
- [11 Dec 2024] A sample dataset with ~100 patients is added
- [11 Dec 2024] Code for dataset creation, model training, and response evaluation is released

## Contents

- [Core Dependencies](#core-dependencies)
- [Data Download](#data-download)
- [Model Download](#model-download)
- [Evaluate](#evaluate)
- [Train](#train)
- [Dataset Creation](#dataset-creation)
- [Notes on Model Enhancements](#notes-on-model-enhancements)
- [Citation](#citation)

## Core Dependencies

```
python 3.9
torch 2.3.0
transformers 4.44.0
peft 0.10.0
```
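
To set up a matching environment, something like `pip install torch==2.3.0 transformers==4.44.0 peft==0.10.0` inside a Python 3.9 environment should work (the exact CUDA build of `torch` will depend on your machine).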

## Data Download

<p align="center">
<img src="images/dataset.png" width="100%">
</p>

The **MIMIC-Instr** dataset will be hosted on [PhysioNet](https://physionet.org/) once the preparation and review process is complete.

A sample dataset generated from the [MIMIC-IV Demo](https://physionet.org/content/mimic-iv-demo/2.2/) database is available in the `sample_data` directory.

For early access to the full dataset, please reach out to Zhenbang Wu (zw12@illinois.edu) with your CITI training report.

## Model Download

<p align="center">
<img src="images/model.png" width="100%">
</p>

The pre-trained model checkpoints can be found on the Hugging Face model hub: [zzachw12/llemr-v1](https://huggingface.co/zzachw12/llemr-v1).

You can load the model using the following code snippet:

```python
from peft import PeftModel
from src.model.init_llemr import init_llemr

# Define paths for the base model and LoRA weights
llm_pretrained_model_name_or_path = "lmsys/vicuna-7b-v1.5"
lora_name_or_path = "zzachw12/llemr-v1"

# Initialize the base model and tokenizer
model, tokenizer = init_llemr(llm_pretrained_model_name_or_path, hidden_size=1027)

# Integrate the LoRA weights into the model
model = PeftModel.from_pretrained(model, lora_name_or_path)
```

**Note:** This model requires pre-computed event embeddings generated by BiomedBERT. Follow [Evaluate](#evaluate) to preprocess the data, generate the responses, and evaluate the model.
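
Once loaded, the adapter-augmented model can be exercised like a standard Hugging Face causal LM. Below is a text-only smoke test; it is a sketch only, assuming the object returned by `init_llemr` supports the usual `generate` API, and it omits the pre-computed event embeddings that real LLEMR queries are conditioned on (see [Evaluate](#evaluate) for the actual flow):

```python
import torch

# Hypothetical text-only prompt; real queries also condition on the patient's
# pre-computed event embeddings, which this smoke test does not supply.
prompt = "USER: Summarize this patient's ICU stay. ASSISTANT:"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=128)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```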

## Evaluate

1. Download the MIMIC-Instr dataset from PhysioNet
2. Run steps 1, 4, 7, and 8 in [Dataset Creation](#dataset-creation) to prepare the event sequence data and pre-compute the event embeddings
3. Generate the model responses with [query_llemr.ipynb](src/eval/query_llemr.ipynb)
4. Compare the model responses with the GPT-4 reference answers using [eval.ipynb](src/eval/eval.ipynb) (requires the Azure OpenAI service); see the judge sketch after this list
5. Summarize the results with [summary_eval.ipynb](src/eval/summary_eval.ipynb)
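
Step 4 follows the common LLM-as-judge pattern: a judge model grades each candidate response against the GPT-4 reference answer. A minimal sketch of that call, assuming an Azure OpenAI deployment and the `openai` v1 client (the deployment name, API version, and grading prompt below are illustrative, not the repository's actual ones):

```python
import os
from openai import AzureOpenAI

client = AzureOpenAI(
    azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT"],
    api_key=os.environ["AZURE_OPENAI_API_KEY"],
    api_version="2024-02-01",
)

def judge(question: str, reference: str, candidate: str) -> str:
    # Ask the judge model to grade the candidate against the reference answer.
    completion = client.chat.completions.create(
        model="gpt-4",  # your Azure deployment name (illustrative)
        messages=[
            {"role": "system", "content": "You are a strict grader of clinical QA answers."},
            {"role": "user", "content": (
                f"Question: {question}\n"
                f"Reference answer: {reference}\n"
                f"Candidate answer: {candidate}\n"
                "Rate the candidate from 1 to 10 and explain briefly."
            )},
        ],
    )
    return completion.choices[0].message.content
```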

## Train

1. Download the MIMIC-Instr dataset from PhysioNet
2. Run steps 1, 4, 7, and 8 in [Dataset Creation](#dataset-creation) to prepare the event sequence data and pre-compute the event embeddings
3. Run the training script [train.py](src/train/train.py) (a LoRA sketch follows this list):
    - CMD: `sh src/train/train.sh`
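
The script fine-tunes the LLM with LoRA adapters via `peft`. As a rough sketch of what such a setup looks like — the rank, alpha, dropout, and target modules below are illustrative assumptions, not the values used by `train.sh`:

```python
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

# Load the Vicuna base model used throughout this repository.
base = AutoModelForCausalLM.from_pretrained("lmsys/vicuna-7b-v1.5")

# Illustrative LoRA settings; consult src/train/train.sh for the real ones.
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    task_type="CAUSAL_LM",
)
model = get_peft_model(base, lora_config)
model.print_trainable_parameters()  # sanity check: only adapter weights train
```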

## Dataset Creation

1. Download the [MIMIC-IV](https://physionet.org/content/mimiciv/2.2/) dataset into the `raw_data` directory
2. Download the [MIMIC-IV-Note](https://physionet.org/content/mimic-iv-note/2.2/) dataset into the `raw_data` directory
3. Run the following Jupyter notebook to select the patient cohort: [01_cohort_selection.ipynb](src/preprocess/01_cohort_selection.ipynb)
4. Run the following Jupyter notebooks to prepare the event sequence data:
    1. Extract events:
        - [02_event_static.ipynb](src/preprocess/02_event_static.ipynb)
        - [02_event_hosp_diagnoses_icd.ipynb](src/preprocess/02_event_hosp_diagnoses_icd.ipynb)
        - [02_event_hosp_labevents.ipynb](src/preprocess/02_event_hosp_labevents.ipynb)
        - [02_event_hosp_microbiologyevents.ipynb](src/preprocess/02_event_hosp_microbiologyevents.ipynb)
        - [02_event_hosp_prescriptions.ipynb](src/preprocess/02_event_hosp_prescriptions.ipynb)
        - [02_event_hosp_transfers.ipynb](src/preprocess/02_event_hosp_transfers.ipynb)
        - [02_event_icu_chartevents.ipynb](src/preprocess/02_event_icu_chartevents.ipynb)
        - [02_event_icu_inputevents.ipynb](src/preprocess/02_event_icu_inputevents.ipynb)
        - [02_event_icu_outputevents.ipynb](src/preprocess/02_event_icu_outputevents.ipynb)
        - [02_event_icu_procedureevents.ipynb](src/preprocess/02_event_icu_procedureevents.ipynb)
    2. Merge events: [03_merge_events.ipynb](src/preprocess/03_merge_events.ipynb)
5. Run the following Jupyter notebooks to generate the instruction tuning data (only needed if you want to regenerate it yourself):
    1. Generate the schema alignment subset:
        - [04_template_qa_event.ipynb](src/preprocess/04_template_qa_event.ipynb)
        - [04_paraphrase_qa_event.ipynb](src/preprocess/04_paraphrase_qa_event.ipynb) (requires the Azure OpenAI service)
    2. Generate the instruction following subset:
        - [04_generate_qa_note.ipynb](src/preprocess/04_generate_qa_note.ipynb) (requires the Azure OpenAI service)
6. Split the data into train, validation, and test sets: [05_data_split.ipynb](src/preprocess/05_data_split.ipynb)
7. Pre-compute the event embeddings with [06_precompute_event_embeddings.py](src/preprocess/06_precompute_event_embeddings.py) (see the embedding sketch after this list):
    - CMD: `sh src/preprocess/precompute_event_embeddings.sh`
8. Generate the GPT-4 reference answers with [query_gpt4.ipynb](src/eval/query_gpt4.ipynb)
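
Conceptually, step 7 encodes each event's text description into a fixed-size vector with BiomedBERT and appends the event's timestamp and numeric value, which is presumably how the encoder's 1024-d output grows to the `hidden_size=1027` expected by the model (see [Notes on Model Enhancements](#notes-on-model-enhancements)). A simplified sketch; the CLS pooling and the exact layout of the three extra features are assumptions, and the real implementation lives in the preprocessing script:

```python
import torch
from transformers import AutoModel, AutoTokenizer

# BiomedBERT-large produces 1024-d text embeddings; the extra features appended
# below are what bring the dimensionality to the model's expected 1027.
encoder_name = "microsoft/BiomedNLP-BiomedBERT-large-uncased-abstract"
tokenizer = AutoTokenizer.from_pretrained(encoder_name)
encoder = AutoModel.from_pretrained(encoder_name).eval()

@torch.no_grad()
def embed_event(text: str, rel_time_hours: float, value: float) -> torch.Tensor:
    inputs = tokenizer(text, return_tensors="pt", truncation=True)
    cls = encoder(**inputs).last_hidden_state[:, 0, :]  # [1, 1024] CLS embedding
    # Toy time/value features (day index, hour of day, raw value); illustrative only.
    extra = torch.tensor([[rel_time_hours / 24.0, rel_time_hours % 24.0, value]])
    return torch.cat([cls, extra], dim=-1)  # [1, 1027]

emb = embed_event("labevent: glucose 155 mg/dL", rel_time_hours=36.0, value=155.0)
print(emb.shape)  # torch.Size([1, 1027])
```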

## Notes on Model Enhancements

This repository incorporates several minor improvements over the original implementation described in the paper:

1. **Enhanced Event Encoder:**
    - Replaced ClinicalBERT (`emilyalsentzer/Bio_ClinicalBERT`) with BiomedBERT-large (`microsoft/BiomedNLP-BiomedBERT-large-uncased-abstract`), improving the quality of the event embeddings
2. **Improved Event Embedding:**
    - Concatenated event timestamps and numeric values (where available) to the final event embeddings, giving a better representation of time-sensitive and quantitative data
3. **Expanded Dataset:**
    - Increased the size of the clinical reasoning subset to 100K examples, doubling the original 50K subset for more comprehensive coverage
4. **Unified Training Approach:**
    - Adopted a single-step training process that uses the schema alignment and clinical reasoning subsets simultaneously, streamlining the training pipeline

These improvements collectively enhance the model's ability to interpret and reason over EHR data, delivering better performance than the original version.

## Citation

If you find this work useful, please cite:

```
@inproceedings{
wu2024instruction,
title={Instruction Tuning Large Language Models to Understand Electronic Health Records},
author={Zhenbang Wu and Anant Dadu and Michael Nalls and Faraz Faghri and Jimeng Sun},
booktitle={The Thirty-eighth Conference on Neural Information Processing Systems Datasets and Benchmarks Track},
year={2024},
url={https://openreview.net/forum?id=Dgy5WVgPd2}
}
```

\* Note: The teaser image above the title was generated by ChatGPT.