<p align="center">
    <img src="images/logo.png" width="40%">
</p>

# Instruction Tuning Large Language Models to Understand Electronic Health Records

**Authors:** Zhenbang Wu, Anant Dadu, Michael Nalls, Faraz Faghri, Jimeng Sun

**Published at:** NeurIPS 2024 Datasets and Benchmarks Track (Spotlight)

[[📑Paper](https://openreview.net/pdf?id=Dgy5WVgPd2)] [[🪧Poster](./poster.pdf)] [[📽️Slides](./slides.pdf)]

## Release
- [19 Dec 2024] The trained model weights are released
- [11 Dec 2024] A sample dataset with ~100 patients is added
- [11 Dec 2024] Code for dataset creation, model training, and response evaluation is released

## Contents
- [Core Dependencies](#core-dependencies)
- [Data Download](#data-download)
- [Model Download](#model-download)
- [Evaluate](#evaluate)
- [Train](#train)
- [Dataset Creation](#dataset-creation)
- [Notes on Model Enhancements](#notes-on-model-enhancements)
- [Citation](#citation)

## Core Dependencies
```
python 3.9
torch 2.3.0
transformers 4.44.0
peft 0.10.0
```

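To confirm that an environment matches the pins listed above, a small check along the following lines may help. This is a minimal sketch and not part of the repository: the helper names `installed_version` and `check_versions` are hypothetical, and the comparison ignores local build tags such as `+cu121` on the assumption that they do not matter here.

```python
import sys
from importlib import metadata

# Pins copied from the dependency list above; Python itself is checked separately.
REQUIRED = {"torch": "2.3.0", "transformers": "4.44.0", "peft": "0.10.0"}


def installed_version(package):
    """Return the installed version string, or None if the package is missing."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def check_versions(required, lookup=installed_version):
    """Map each package name to (required, installed, matches)."""
    report = {}
    for package, wanted in required.items():
        found = lookup(package)
        # Strip a local build tag (e.g. "2.3.0+cu121") before comparing.
        matches = found is not None and found.split("+")[0] == wanted
        report[package] = (wanted, found, matches)
    return report


if __name__ == "__main__":
    print("python", sys.version_info[:2], "(expected (3, 9))")
    for package, (wanted, found, ok) in check_versions(REQUIRED).items():
        status = "OK" if ok else "MISMATCH"
        print(f"{package}: required {wanted}, installed {found}, {status}")
```

The `lookup` parameter exists only so the check can be exercised without the packages installed; in normal use the default is enough.
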
## Data Download

<p align="center">
    <img src="images/dataset.png" width="100%">
</p>

The **MIMIC-Instr** dataset will be hosted on [PhysioNet](https://physionet.org/) once the preparation and review process is complete.

A sample dataset generated from the [MIMIC-IV Demo](https://physionet.org/content/mimic-iv-demo/2.2/) database is available in the `sample_data` directory.

For early access to the full dataset, please reach out to Zhenbang Wu (zw12@illinois.edu) with your CITI training report.

## Model Download

<p align="center">
    <img src="images/model.png" width="100%">
</p>

The pre-trained model checkpoints can be found on the Hugging Face model hub: [zzachw12/llemr-v1](https://huggingface.co/zzachw12/llemr-v1).

You can load the model using the following code snippet:

```python
from peft import PeftModel
from src.model.init_llemr import init_llemr

# Define paths for the base model and LoRA weights
llm_pretrained_model_name_or_path = "lmsys/vicuna-7b-v1.5"
lora_name_or_path = "zzachw12/llemr-v1"

# Initialize the base model and tokenizer
model, tokenizer = init_llemr(llm_pretrained_model_name_or_path, hidden_size=1027)

# Integrate the LoRA weights into the model
model = PeftModel.from_pretrained(model, lora_name_or_path)
```

**Note:** This model requires pre-computed event embeddings generated by BiomedBERT. Follow [Evaluate](#evaluate) to preprocess the data, generate the response, and evaluate the model.

## Evaluate

1. Download the MIMIC-Instr dataset from PhysioNet

2. Run steps 1, 4, 7, 8 in [Dataset Creation](#dataset-creation) to prepare the event sequence data and pre-compute the event embeddings

3. Generate the model response with [query_llemr.ipynb](src/eval/query_llemr.ipynb)

4. Compare the model response against the GPT-4 reference answer with [eval.ipynb](src/eval/eval.ipynb) (requires Azure OpenAI Service)

5. Summarize the results with [summary_eval.ipynb](src/eval/summary_eval.ipynb)

## Train

1. Download the MIMIC-Instr dataset from PhysioNet

2. Run steps 1, 4, 7, 8 in [Dataset Creation](#dataset-creation) to prepare the event sequence data and pre-compute the event embeddings

3. Run the training script [train.py](src/train/train.py):
   - CMD: `sh src/train/train.sh`

## Dataset Creation

1. Download the [MIMIC-IV](https://physionet.org/content/mimiciv/2.2/) dataset into the `raw_data` directory

2. Download the [MIMIC-IV-Note](https://physionet.org/content/mimic-iv-note/2.2/) dataset into the `raw_data` directory

3. Run the following Jupyter notebook to select the patient cohort: [01_cohort_selection.ipynb](src/preprocess/01_cohort_selection.ipynb)

4. Run the following Jupyter notebooks to prepare the event sequence data:
   - 1. Extract events:
     - [02_event_static.ipynb](src/preprocess/02_event_static.ipynb)
     - [02_event_hosp_diagnoses_icd.ipynb](src/preprocess/02_event_hosp_diagnoses_icd.ipynb)
     - [02_event_hosp_labevents.ipynb](src/preprocess/02_event_hosp_labevents.ipynb)
     - [02_event_hosp_microbiologyevents.ipynb](src/preprocess/02_event_hosp_microbiologyevents.ipynb)
     - [02_event_hosp_prescriptions.ipynb](src/preprocess/02_event_hosp_prescriptions.ipynb)
     - [02_event_hosp_transfers.ipynb](src/preprocess/02_event_hosp_transfers.ipynb)
     - [02_event_icu_chartevents.ipynb](src/preprocess/02_event_icu_chartevents.ipynb)
     - [02_event_icu_inputevents.ipynb](src/preprocess/02_event_icu_inputevents.ipynb)
     - [02_event_icu_outputevents.ipynb](src/preprocess/02_event_icu_outputevents.ipynb)
     - [02_event_icu_procedureevents.ipynb](src/preprocess/02_event_icu_procedureevents.ipynb)
   - 2. Merge events: [03_merge_events.ipynb](src/preprocess/03_merge_events.ipynb)

5. Run the following Jupyter notebooks to generate the instruction tuning data:
   - Run this step only if you want to generate the instruction tuning data on your own
   - 1. Generate the schema alignment subset:
     - [04_template_qa_event.ipynb](src/preprocess/04_template_qa_event.ipynb)
     - [04_paraphrase_qa_event.ipynb](src/preprocess/04_paraphrase_qa_event.ipynb) (requires Azure OpenAI Service)
   - 2. Generate the instruction following subset:
     - [04_generate_qa_note.ipynb](src/preprocess/04_generate_qa_note.ipynb) (requires Azure OpenAI Service)

6. Split the data into train, validation, and test sets:
   - [05_data_split.ipynb](src/preprocess/05_data_split.ipynb)

7. Pre-compute the event embeddings with [06_precompute_event_embeddings.py](src/preprocess/06_precompute_event_embeddings.py):
   - CMD: `sh src/preprocess/precompute_event_embeddings.sh`

8. Generate the GPT-4 reference answer with [query_gpt4.ipynb](src/eval/query_gpt4.ipynb)

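Step 4 above extracts events per source table and then merges them into a single sequence. The sketch below illustrates one way such a merge could work, using a time-ordered merge from the standard library. It is an illustration only and does not reproduce `03_merge_events.ipynb`: the `Event` record, its fields, and the toy values are all assumptions made for the example.

```python
import heapq
from datetime import datetime
from typing import Dict, Iterable, List, NamedTuple, Optional


class Event(NamedTuple):
    """Hypothetical event record; field names are illustrative."""
    time: datetime
    source: str                    # table the event came from, e.g. "labevents"
    text: str                      # short textual description of the event
    value: Optional[float] = None  # numeric value, when the event has one


def merge_events(streams: Dict[str, Iterable[Event]]) -> List[Event]:
    """Merge per-table event lists into one chronologically ordered sequence."""
    # heapq.merge expects each input to be sorted already, so sort per table first.
    ordered = [sorted(events, key=lambda e: e.time) for events in streams.values()]
    return list(heapq.merge(*ordered, key=lambda e: e.time))


# Toy data, made up for the example.
labs = [
    Event(datetime(2024, 1, 1, 8, 0), "labevents", "Glucose", 110.0),
    Event(datetime(2024, 1, 1, 6, 0), "labevents", "Sodium", 140.0),
]
prescriptions = [Event(datetime(2024, 1, 1, 7, 0), "prescriptions", "Insulin")]

timeline = merge_events({"labevents": labs, "prescriptions": prescriptions})
for event in timeline:
    print(event.time.isoformat(), event.source, event.text, event.value)
```

Sorting each table independently and merging afterwards keeps memory use proportional to the inputs and preserves a stable order for events that share a timestamp.
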
## Notes on Model Enhancements

This repository incorporates several minor improvements over the original implementation described in the paper:

1. **Enhanced Event Encoder:**
   - Replaced ClinicalBERT (`emilyalsentzer/Bio_ClinicalBERT`) with BiomedBERT-large (`microsoft/BiomedNLP-BiomedBERT-large-uncased-abstract`), improving the quality of event embeddings

2. **Improved Event Embedding:**
   - Concatenated event timestamps and numeric values (where available) to the final event embeddings, resulting in better representation of time-sensitive and quantitative data

3. **Expanded Dataset:**
   - Increased the size of the clinical reasoning subset to 100K examples, doubling the data from the original 50K subset for more comprehensive coverage

4. **Unified Training Approach:**
   - Adopted a single-step training process that integrates schema alignment and clinical reasoning subsets simultaneously, streamlining the training pipeline

Together, these changes improve the model's ability to interpret and reason over EHR data relative to the version described in the paper.

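The second item, concatenating timestamps and numeric values to the event embeddings, can be pictured with the sketch below. A BERT-large encoder such as BiomedBERT-large produces 1024-dimensional vectors, so three extra features would yield the `hidden_size=1027` used in the loading snippet. Which three features the repository actually appends is not stated in this README; the layout here (time offset, numeric value, presence flag) and the function name `build_event_vector` are assumptions for illustration.

```python
from typing import List, Optional

TEXT_DIM = 1024  # hidden size of a BERT-large text encoder


def build_event_vector(
    text_embedding: List[float],
    hours_since_admission: float,
    value: Optional[float],
) -> List[float]:
    """Append time and numeric features to a text embedding (assumed layout)."""
    if len(text_embedding) != TEXT_DIM:
        raise ValueError(f"expected {TEXT_DIM} dimensions, got {len(text_embedding)}")
    # Use 0.0 as a placeholder when no numeric value exists, plus a flag so the
    # model can tell "missing" apart from a true zero.
    has_value = 0.0 if value is None else 1.0
    numeric = 0.0 if value is None else float(value)
    return list(text_embedding) + [float(hours_since_admission), numeric, has_value]


# A zero vector stands in for a real encoder output.
vector = build_event_vector([0.0] * TEXT_DIM, hours_since_admission=2.5, value=7.2)
print(len(vector))
```

In practice the time and value features would likely need normalization before concatenation so that their scale is comparable to the text embedding; that step is omitted here.
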
## Citation

If you find this work useful, please cite:
```
@inproceedings{
    wu2024instruction,
    title={Instruction Tuning Large Language Models to Understand Electronic Health Records},
    author={Zhenbang Wu and Anant Dadu and Michael Nalls and Faraz Faghri and Jimeng Sun},
    booktitle={The Thirty-eight Conference on Neural Information Processing Systems Datasets and Benchmarks Track},
    year={2024},
    url={https://openreview.net/forum?id=Dgy5WVgPd2}
}
```

\* Note: The teaser image above the title is generated by ChatGPT.